In [11]:
import torch
import chromadb

from preprocessor import Preprocessor

from FlagEmbedding import BGEM3FlagModel
from transformers import XLMRobertaModel, XLMRobertaTokenizer
from sentence_transformers import SentenceTransformer

## Loading models and databases

In [6]:
# XLM-RoBERTa (base) encoder plus its matching tokenizer, both loaded from the
# same "xlm-roberta-base" checkpoint so token ids line up with the model vocab.
xlm_model = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path="xlm-roberta-base")
xlm_tokenizer = XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path="xlm-roberta-base")

In [12]:
# Open the persisted Chroma store for the XLM-RoBERTa run (per the directory
# name) and grab its "pdf_chunks" collection.
# NOTE(review): absolute, machine-specific path — will not resolve elsewhere.
xlm_client = chromadb.PersistentClient(path=(
    "/home/murad/Documents/self-study/contextual_embeddings/databases/xlm_collection"
))
xlm_collection = xlm_client.get_collection(name="pdf_chunks")

In [7]:
# BGE-M3 encoder from the BAAI/bge-m3 checkpoint, requesting fp16 weights.
bge_model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [13]:
# Open the persisted Chroma store for the BGE-M3 run (per the directory name)
# and grab its "pdf_chunks" collection.
# NOTE(review): absolute, machine-specific path — will not resolve elsewhere.
bge_client = chromadb.PersistentClient(path=(
    "/home/murad/Documents/self-study/contextual_embeddings/databases/bge_collection"
))
bge_collection = bge_client.get_collection(name="pdf_chunks")

In [8]:
# LaBSE sentence encoder, loaded via sentence-transformers.
labse_model = SentenceTransformer("sentence-transformers/LaBSE")

In [14]:
# Open the persisted Chroma store for the LaBSE run (per the directory name)
# and grab its "pdf_chunks" collection.
# NOTE(review): absolute, machine-specific path — will not resolve elsewhere.
labse_client = chromadb.PersistentClient(path=(
    "/home/murad/Documents/self-study/contextual_embeddings/databases/labse_collection"
))
labse_collection = labse_client.get_collection(name="pdf_chunks")

## Query

In [9]:
# Azerbaijani for "What is unsupervised learning?"; each model cell below
# embeds this same string and searches its own collection with it.
query = "Nəzarətsiz öyrənmə nədir?"

### XLM-RoBERTa model result

In [22]:
# Embed the query with XLM-RoBERTa, using the final hidden state of the first
# token as the sentence vector.
# NOTE(review): this pooling must match whatever was used to build
# xlm_collection, otherwise the distances are not comparable — confirm against
# the indexing code.
xlm_model.eval()  # disable dropout so the embedding is deterministic

with torch.no_grad():  # inference only; no autograd graph needed
    # truncation=True caps the input at the tokenizer's max length (512 for
    # xlm-roberta-base). Without it, a long query overflows the model's
    # position embeddings and the forward pass raises. Padding is pointless
    # for a single string, so it is not requested.
    inputs = xlm_tokenizer(query, return_tensors="pt", truncation=True)
    outputs = xlm_model(**inputs)
    # last_hidden_state is (batch, sequence, hidden); take token 0 of the one
    # and only query and convert to a plain list of floats for Chroma.
    query_embedding_xlm = outputs.last_hidden_state[0, 0, :].tolist()

In [23]:
# Five nearest chunks to the XLM-RoBERTa query vector, with their distances,
# metadata and text.
xlm_result = xlm_collection.query(
    query_embeddings=query_embedding_xlm,
    include=["distances", "metadatas", "documents"],
    n_results=5,
)

In [24]:
# Dump the retrieved chunk texts (the str() of the nested documents list) to disk.
# NOTE(review): distances/metadatas are fetched above but not saved here.
with open("xlm_result.txt", mode="w", encoding="utf-8") as out_file:
    print(xlm_result["documents"], file=out_file, end="")

### BGE-M3 model result

In [28]:
# Dense BGE-M3 vector for the query, turned into a plain Python list for Chroma.
query_embedding_bge = bge_model.encode(
    sentences=query,
    batch_size=12,
    max_length=1024,
)["dense_vecs"].tolist()

In [29]:
# Five nearest chunks to the BGE-M3 query vector, with their distances,
# metadata and text.
bge_result = bge_collection.query(
    query_embeddings=query_embedding_bge,
    include=["distances", "metadatas", "documents"],
    n_results=5,
)

In [30]:
# Dump the retrieved chunk texts (the str() of the nested documents list) to disk.
# NOTE(review): distances/metadatas are fetched above but not saved here.
with open("bge_result.txt", mode="w", encoding="utf-8") as out_file:
    print(bge_result["documents"], file=out_file, end="")

### LaBSE model result

In [32]:
# LaBSE embedding of the query, turned into a plain Python list for Chroma.
query_embedding_labse = labse_model.encode(sentences=query).tolist()

In [33]:
# Five nearest chunks to the LaBSE query vector, with their distances,
# metadata and text.
labse_result = labse_collection.query(
    query_embeddings=query_embedding_labse,
    include=["distances", "metadatas", "documents"],
    n_results=5,
)

In [34]:
# Dump the retrieved chunk texts (the str() of the nested documents list) to disk.
# NOTE(review): distances/metadatas are fetched above but not saved here.
with open("labse_result.txt", mode="w", encoding="utf-8") as out_file:
    print(labse_result["documents"], file=out_file, end="")