In [1]:
import json
import pandas as pd
from tqdm.auto import tqdm
from elasticsearch import Elasticsearch
from sentence_transformers import SentenceTransformer

In [2]:
with open('document_with_ids.json', 'rt') as f_in:
    documents = json.load(f_in)

In [3]:
model_name = 'multi-qa-MiniLM-L6-cos-v1'
model = SentenceTransformer(model_name)

### Indexing

In [4]:
for doc in tqdm(documents):
    question = doc['question']
    text = doc['text']
    qt = question + ' ' + text

    doc['question_vector'] = model.encode(question)
    doc['text_vector'] = model.encode(text)
    doc['question_text_vector'] = model.encode(qt)
    

  0%|          | 0/948 [00:00<?, ?it/s]

In [5]:
es_client = Elasticsearch("http://localhost:9200")

index_settings = {
    "settings": {
        "number_of_shards" : 1,
        "number_of_replicas" : 0
    },
    "mappings" : {
        "properties" : {
            "text" : {"type" : "text"},
            "section" : {"type" : "text"},
            "question" : {"type" : "text"},
            "course" : {"type" : "keyword"},
            "id" : {"type" : "keyword"},
            "question_vector" : {
                "type" : "dense_vector",
                "dims" : 384,
                "index" : True,
                "similarity": "cosine"
            },
            "text_vector" : {
                "type" : "dense_vector",
                "dims" : 384,
                "index" : True,
                "similarity": "cosine"
            },
            "question_text_vector" : {
                "type" : "dense_vector",
                "dims" : 384,
                "index" : True,
                "similarity": "cosine"
            },
        }
    }
}

index_name = "course_questions"
es_client.indices.delete(index = index_name, ignore_unavailable=True)
es_client.indices.create(index = index_name, body = index_settings)

ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'course_questions'})

In [6]:
for doc in tqdm(documents):
    es_client.index(index=index_name, document=doc)

  0%|          | 0/948 [00:00<?, ?it/s]

#### Retrieval stage

In [11]:
from langchain.embeddings import SentenceTransformerEmbeddings
from typing import Dict
from langchain_elasticsearch import ElasticsearchRetriever

In [13]:
es_url = "http://localhost:9200"

In [20]:
course = "data-engineering-zoomcamp"
query = "Can i still join the course?"
embeddings = SentenceTransformerEmbeddings(model_name = 'multi-qa-MiniLM-L6-cos-v1')

In [21]:
def hybrid_query(search_query: str) -> Dict:
    vector = embeddings.embed_query(search_query)  # same embeddings as for indexing
    return {
        "query" : {
            "bool": {
                "must": {
                    "multi_match": {
                        "query": search_query,
                        "fields": ["question", "text", "section"],
                        "type": "best_fields",
                        "boost" :  0.9
                    }
                },
                "filter": {
                    "term": {
                        "course": course
                    }
                }
            }
        },
        "knn" : {
            "field" : "question_text_vector",
            "query_vector" : vector,
            "k" : 5,
            "num_candidates" : 10000,
            "boost" :  0.1,
            "filter" : {
                "term" : {
                    "course" : course
                        }
                    }
        },
        #"rank" : {"rrf": {} },
    }


hybrid_retriever = ElasticsearchRetriever.from_es_params(
    index_name=index_name,
    body_func=hybrid_query,
    content_field="text",
    url=es_url,
)

In [22]:
hybrid_retriever.invoke(query)

[Document(metadata={'_index': 'course_questions', '_id': 'dQy2J5YBbbXtz81UKbLh', '_score': 19.912674, '_source': {'section': 'General course-related questions', 'question': 'Course - Can I still join the course after the start date?', 'course': 'data-engineering-zoomcamp', 'id': '7842b56a', 'question_vector': [0.0030358924996107817, -0.002387200016528368, 0.035881660878658295, 0.02099882811307907, -0.018282320350408554, 0.06715093553066254, -0.10277318954467773, -0.11509547382593155, -0.06606752425432205, -0.004973369650542736, -0.002861724467948079, 0.10543154180049896, -0.000814331229776144, 0.08418365567922592, 0.027047153562307358, -0.03135377913713455, -0.05154325067996979, -0.04948996752500534, 0.05349848419427872, 0.004741473123431206, -0.13610857725143433, 0.01654152013361454, -0.0778471976518631, 0.06462235748767853, 0.03814755752682686, -0.040936168283224106, 0.032365839928388596, -0.017055029049515724, 0.05001968517899513, -0.003753466298803687, -0.0441180020570755, 0.002292