In [1]:
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
import pandas as pd
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
import torch

# Location of the Q&A dataset; the same file is read twice below.
csv_path = 'epomoc_without_nan.csv'

# DataFrame view of the CSV, used later to look up the question text by row.
df = pd.read_csv(csv_path)

# LangChain view of the CSV: the loader turns the file into Document objects,
# taking each document's "source" metadata from the "answer" column.
# NOTE(review): the page content presumably contains every CSV column
# (including the question itself) - confirm that this is intended for the
# retrieval evaluation further down.
loader = CSVLoader(file_path=csv_path, source_column="answer")
data = loader.load()

In [4]:
# Number of leading documents that are indexed and then used as queries.
number_of_samples = 5120

# First-stage model: produces the embeddings used by the vector store.
base_model_name = "sdadas/mmlw-retrieval-roberta-large"
base_model = HuggingFaceEmbeddings(model_name=base_model_name)

# Second-stage model: cross-encoder used to rerank retrieved candidates.
# torch.nn.Identity() is a pass-through module, so the configured activation
# leaves its input unchanged; inputs are limited to 512 tokens via max_length
# (presumably both kwargs are forwarded to the underlying cross-encoder -
# TODO confirm).
rerank_model_name = "sdadas/polish-reranker-large-ranknet"
rerank_model = HuggingFaceCrossEncoder(
    model_name=rerank_model_name,
    model_kwargs={
        "default_activation_function": torch.nn.Identity(),
        "max_length": 512,
    },
)

In [5]:
# Keep only the first `number_of_samples` loaded documents as the corpus.
data_part = data[:number_of_samples]

# Splitter configured with a 2100-character chunk size and no overlap
# between consecutive chunks.
chunk_size = 2100
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size,
    chunk_overlap=0,
)

In [6]:
# Candidate budget per stage: the vector search returns 20 chunks and the
# reranker keeps the best 10 of them.
base_model_top_k = 20
rerank_top_k = 10

# Stage 1: split the corpus into chunks, embed them and index them in Chroma.
# NOTE(review): no collection name or persist directory is passed here;
# confirm whether re-running this cell in the same kernel adds the chunks to
# the existing collection a second time.
all_splits = text_splitter.split_documents(data_part)
vectorstore = Chroma.from_documents(documents=all_splits, embedding=base_model)
retriever = vectorstore.as_retriever(search_kwargs={"k": base_model_top_k})

# Stage 2: wrap the vector retriever so its candidates are rescored by the
# cross-encoder and truncated to `rerank_top_k`.
compressor = CrossEncoderReranker(model=rerank_model, top_n=rerank_top_k)
compression_retriever = ContextualCompressionRetriever(
    base_retriever=retriever,
    base_compressor=compressor,
)

In [7]:
# Retrieval evaluation: every question is sent through the two-stage
# retriever, and a hit is counted when a chunk originating from the same CSV
# row shows up among the first 1 / 3 / 5 / 10 reranked results.

# Never evaluate more rows than the DataFrame holds; otherwise the lookup
# below would fail and the percentages would use the wrong denominator.
num_eval = min(number_of_samples, len(df))
if num_eval == 0:
    raise ValueError("No samples available for evaluation")

top1_acc = 0
top3_acc = 0
top5_acc = 0
top10_acc = 0

for idx in range(num_eval):
    # Lightweight progress indicator every 50 queries.
    if idx % 50 == 0:
        print("idx: " + str(idx))

    # Positional lookup: `idx` is a row position, matching the "row" metadata
    # that is compared against it below.
    query = "pytanie: " + df["question"].iloc[idx]

    # `invoke` is the supported entry point; `get_relevant_documents` is
    # deprecated in recent LangChain releases.
    compressed_docs = compression_retriever.invoke(query)

    # Source row of each returned chunk, in ranked order. A dedicated loop
    # variable is used so the outer `idx` is not shadowed.
    # NOTE(review): several chunks of one row can occupy multiple slots, so
    # the metric is computed at chunk level - confirm this is intended.
    docs_ids = [doc.metadata["row"] for doc in compressed_docs]

    top1_acc += idx in docs_ids[:1]
    top3_acc += idx in docs_ids[:3]
    top5_acc += idx in docs_ids[:5]
    top10_acc += idx in docs_ids[:10]

print("Top 1 accuracy: " + str(top1_acc * 100/num_eval))
print("Top 3 accuracy: " + str(top3_acc * 100/num_eval))
print("Top 5 accuracy: " + str(top5_acc * 100/num_eval))
print("Top 10 accuracy: " + str(top10_acc * 100/num_eval))

idx: 0
idx: 50
idx: 100
idx: 150
idx: 200
idx: 250
idx: 300
idx: 350
idx: 400
idx: 450
idx: 500
idx: 550
idx: 600
idx: 650
idx: 700
idx: 750
idx: 800
idx: 850
idx: 900
idx: 950
idx: 1000
idx: 1050
idx: 1100
idx: 1150
idx: 1200
idx: 1250
idx: 1300
idx: 1350
idx: 1400
idx: 1450
idx: 1500
idx: 1550
idx: 1600
idx: 1650
idx: 1700
idx: 1750
idx: 1800
idx: 1850
idx: 1900
idx: 1950
idx: 2000
idx: 2050
idx: 2100
idx: 2150
idx: 2200
idx: 2250
idx: 2300
idx: 2350
idx: 2400
idx: 2450
idx: 2500
idx: 2550
idx: 2600
idx: 2650
idx: 2700
idx: 2750
idx: 2800
idx: 2850
idx: 2900
idx: 2950
idx: 3000
idx: 3050
idx: 3100
idx: 3150
idx: 3200
idx: 3250
idx: 3300
idx: 3350
idx: 3400
idx: 3450
idx: 3500
idx: 3550
idx: 3600
idx: 3650
idx: 3700
idx: 3750
idx: 3800
idx: 3850
idx: 3900
idx: 3950
idx: 4000
idx: 4050
idx: 4100
idx: 4150
idx: 4200
idx: 4250
idx: 4300
idx: 4350
idx: 4400
idx: 4450
idx: 4500
idx: 4550
idx: 4600
idx: 4650
idx: 4700
idx: 4750
idx: 4800
idx: 4850
idx: 4900
idx: 4950
idx: 5000
idx: 5050
idx