In [7]:
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
from pymilvus import utility
connections.connect(
    host="localhost", 
    port="19530"
)
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("all-mpnet-base-v2")

In [8]:
import json 

with open('documents-with-ids.json', 'rt') as f_in:
    documents = json.load(f_in)

In [37]:
utility.drop_collection("course_info")


In [38]:


fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="text_id", dtype=DataType.VARCHAR, max_length=535),
    FieldSchema(name="section", dtype=DataType.VARCHAR, max_length=535),
    FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=50000),
    FieldSchema(name="question", dtype=DataType.VARCHAR, max_length=535),
    FieldSchema(name="course", dtype=DataType.VARCHAR, max_length=535),
    FieldSchema(name="question_vector", dtype=DataType.FLOAT_VECTOR, dim=768), 
    FieldSchema(name="text_vector", dtype=DataType.FLOAT_VECTOR, dim=768), 
    FieldSchema(name="question_text_vector", dtype=DataType.FLOAT_VECTOR, dim=768)]

schema = CollectionSchema(fields=fields,enable_dynamic_field=True)

collection = Collection(name="course_info", schema=schema)

index_params = {
    "metric_type": "IP",
    "params": {},
}

collection.create_index("question_vector", index_params)
collection.create_index("text_vector", index_params)
collection.create_index("question_text_vector", index_params)


entities = []
for doc in documents: 
    try: 
        question = doc['question']
        text = doc['text']
        qt = question + ' ' + text
        doc['question_vector'] = model.encode(question)
        doc['text_vector'] = model.encode(text)
        doc['question_text_vector'] = model.encode(qt)
        entity = {
                "text_id": doc['text_id'],
                "id": doc['id'],
                "section": doc['section'],
                "text": text,
                "question": question,
                "course": doc['course'],
                "question_vector": doc['question_vector'],
                "text_vector": doc['text_vector'], 
                "question_text_vector": doc['question_text_vector'], 
            }
        entities.append(entity)
    except KeyError as e:
        print(f"Missing key {e} in document {doc['text_id']}")
        continue
    
collection.insert(entities)

(insert count: 948, delete count: 0, upsert count: 0, timestamp: 453070284642320387, success count: 948, err count: 0

In [39]:
collection.release()

In [40]:
collection.load()

In [41]:
query_question_vector = model.encode('Windows or Mac?')
res = collection.search(
    anns_field = "question_vector", 
    param={"metric_type": "IP", "params": {}},
    data=[query_question_vector],
    output_fields=['question'],
    limit=5, # Max. number of search results to return
)

In [59]:
def search(field, vector, course):
    # Perform the search operation
    res = collection.search(
        anns_field=f"{field}", 
        filter=f"course == '{course}'",
        param={"metric_type": "IP", "params": {}},
        data=[vector],
        output_fields=["text_id", "text", "section", "question", "course", "id"], 
        limit=5,  # Max. number of search results to return
    )

    # Initialize an empty list to hold the results
    result_docs = []

    # Loop through the hits
    for hits in res:
        for hit in hits:
            # Append each hit as a dictionary containing the desired fields
            hit_dict = {
                "text_id": hit.entity.get("text_id"),
                "text": hit.entity.get("text"),
                "section": hit.entity.get("section"),
                "question": hit.entity.get("question"),
                "course": hit.entity.get("course"),
            }
            result_docs.append(hit_dict)
    
    # Return the list of result documents
    return result_docs


In [60]:
def question_vector_search(q):
    question = q['question']
    course = q['course']

    v_q = model.encode(question)

    return search('question_vector', v_q, course)

In [61]:
import pandas as pd

In [47]:
df_ground_truth = pd.read_csv('ground-truth-data.csv')

In [48]:
ground_truth = df_ground_truth.to_dict(orient='records')

In [49]:
ground_truth[0]

{'question': "What happens if I can't make the first Office Hours live session on January 15th?",
 'course': 'data-engineering-zoomcamp',
 'document': 'c02e79ef'}

In [52]:
def hit_rate(relevance_total):
    cnt = 0

    for line in relevance_total:
        if True in line:
            cnt = cnt + 1

    return cnt / len(relevance_total)

In [53]:
def mrr(relevance_total):
    total_score = 0.0

    for line in relevance_total:
        for rank in range(len(line)):
            if line[rank] == True:
                total_score = total_score + 1 / (rank + 1)

    return total_score / len(relevance_total)

In [62]:
from tqdm.auto import tqdm
def evaluate(ground_truth, search_function):
    relevance_total = []

    for q in tqdm(ground_truth):
        doc_id = q['document']
        results = search_function(q)
        relevance = [d['text_id'] == doc_id for d in results]
        relevance_total.append(relevance)

    return {
        'hit_rate': hit_rate(relevance_total),
        'mrr': mrr(relevance_total),
    }

In [63]:
evaluate(ground_truth, question_vector_search)

  0%|          | 0/490 [00:00<?, ?it/s]

[0.6064250469207764, 0.6026979684829712, 0.4027674198150635, 0.39699310064315796, 0.3929474949836731]
[0.7053327560424805, 0.6221023201942444, 0.4539065361022949, 0.452272891998291, 0.3691049814224243]
[0.6450055837631226, 0.5516902208328247, 0.5282288789749146, 0.5247162580490112, 0.517805278301239]
[0.345976859331131, 0.3028748631477356, 0.28106316924095154, 0.26523321866989136, 0.2498122602701187]
[0.6287825107574463, 0.5827332735061646, 0.5691347718238831, 0.5569307804107666, 0.45663100481033325]
[0.6162812113761902, 0.5746455192565918, 0.5283095836639404, 0.521682858467102, 0.4735865592956543]
[0.5207227468490601, 0.4680786728858948, 0.4623643755912781, 0.43360427021980286, 0.4229459762573242]
[0.5066039562225342, 0.5025250911712646, 0.47515857219696045, 0.46582216024398804, 0.44527825713157654]
[0.42600134015083313, 0.39423298835754395, 0.3918437957763672, 0.3153612017631531, 0.3127375841140747]
[0.5769420862197876, 0.5620656609535217, 0.5585223436355591, 0.4972800612449646, 0.44

{'hit_rate': 0.5102040816326531, 'mrr': 0.3724489795918368}

In [64]:
def text_vector_search(q):
    question = q['question']
    course = q['course']

    v_q = model.encode(question)

    return search('text_vector', v_q, course)

In [66]:
evaluate(ground_truth, text_vector_search)

  0%|          | 0/490 [00:00<?, ?it/s]

[0.48471200466156006, 0.42020291090011597, 0.3132360279560089, 0.30395057797431946, 0.3016015291213989]
[0.7552348971366882, 0.5677367448806763, 0.47873666882514954, 0.33706554770469666, 0.2902269661426544]
[0.5818074941635132, 0.4915899932384491, 0.48451441526412964, 0.4642796516418457, 0.45209529995918274]
[0.2777358889579773, 0.27583563327789307, 0.26475828886032104, 0.2433934211730957, 0.24309846758842468]
[0.6088165044784546, 0.5483426451683044, 0.5183913707733154, 0.4989057183265686, 0.4760700464248657]
[0.5093430280685425, 0.47342562675476074, 0.4355292320251465, 0.4080060124397278, 0.3990536332130432]
[0.4719848036766052, 0.3939969539642334, 0.38785114884376526, 0.38163191080093384, 0.3522842526435852]
[0.588912844657898, 0.5384359359741211, 0.5033280849456787, 0.47249719500541687, 0.4414674639701843]
[0.3906805217266083, 0.36312156915664673, 0.30827081203460693, 0.2913321256637573, 0.2721174657344818]
[0.5004897117614746, 0.49797767400741577, 0.47169578075408936, 0.46386277675

{'hit_rate': 0.6081632653061224, 'mrr': 0.4627551020408162}

In [67]:
def question_text_search(q):
    question = q['question']
    course = q['course']

    v_q = model.encode(question)

    return search('question_text_vector', v_q, course)

evaluate(ground_truth, question_text_search)

  0%|          | 0/490 [00:00<?, ?it/s]

[0.6051977872848511, 0.5266731381416321, 0.4812019467353821, 0.412638783454895, 0.3657471537590027]
[0.7239812612533569, 0.5995064973831177, 0.5921218395233154, 0.465165913105011, 0.41230401396751404]
[0.5553195476531982, 0.5116694569587708, 0.5102090835571289, 0.4876810908317566, 0.4576725959777832]
[0.26496511697769165, 0.24907606840133667, 0.24434655904769897, 0.23708072304725647, 0.23491474986076355]
[0.5712003111839294, 0.5706140995025635, 0.5649617910385132, 0.52945876121521, 0.48173901438713074]
[0.5883840322494507, 0.5498811602592468, 0.4830922484397888, 0.4766666889190674, 0.4744686782360077]
[0.47057849168777466, 0.44496190547943115, 0.42656123638153076, 0.4103424549102783, 0.37271609902381897]
[0.5471240282058716, 0.4952797591686249, 0.49523061513900757, 0.48749426007270813, 0.4767211675643921]
[0.4061802625656128, 0.3961114287376404, 0.38307860493659973, 0.31307488679885864, 0.2987065613269806]
[0.6296807527542114, 0.5566585063934326, 0.4700133204460144, 0.4685996472835541,

{'hit_rate': 0.6877551020408164, 'mrr': 0.5305102040816329}