In [1]:
import pandas as pd
from pathlib import Path

In [2]:
# The data directory is a sibling of the notebook's working directory.
# `resolve()` already yields an absolute path, so no further normalisation is needed.
data_path = Path.cwd().resolve().parent.joinpath("data")
documents_filepath = data_path.joinpath("documents.csv")
test_data_filepath = data_path.joinpath("test_query_context_pairs.csv")

In [3]:
# Cut-off rank k: passed as `k` to the retrievers and as `top_k` to the metrics.
TOP_K = 6

## Load the test data

In [4]:
# Load the (question, context) evaluation pairs; the question_id / context_id
# columns are what the ground truth is built from further down.
test_data = pd.read_csv(test_data_filepath)
test_data

Unnamed: 0,question,context,question_id,context_id
0,Which Japanese carrier survived the first wave...,With the Japanese CAP out of position and the ...,00d9f4e0b8a0c4654bc202f7be8cb243,05bb76ec5fcf0a6b793b5bf3607fdbb9
1,What are anarchists not against?,Anarchists are against the State but are not a...,9ec4aab7fcf6a247539c9f1e3094f37b,15f04e0814d60718a465c034077269dd
2,How did troops react to the missile?,"On 1 May, The Sun claimed to have 'sponsored' ...",d1d6e5c2409d498f00c690e14669e5d0,23d120965eaf2a679c23d7d88eb3fe1a
3,Who primarily occupies the complexes surroundi...,Ann Arbor's residential neighborhoods contain ...,6e99378749ee466ce63713c931d3f740,eb70588ed7264bcd249d6705e1b90547
4,Does God have a gender?,"In monotheism and henotheism, God is conceived...",e6908fed63cb1c828f422793452a5806,1d644948160ff5feb5f6bfae66d8525e
...,...,...,...,...
3948,What percent of the GDP was spent on health?,Despite repeated efforts by the Tajik governme...,2e9b3ababee633b342a5caa2740d456d,6ba7a5fb7bed5afb82e70a427d0111eb
3949,During what time did cotton become widely used...,The earliest evidence of cotton use in South A...,9111bdc85cb44e518f467e302e617473,72f0f5f533a36ccb4bff1c043971b38f
3950,What legends show the importance of the propri...,"In most parts of medieval Europe, the upper cl...",ef07514867301df64dc26a44d656aa7d,ef91a3c8fc57d80a8df4352ca375af8b
3951,What is its longest river?,"Galicia is poetically known as the ""country of...",323a472ab4577d5267630b452529c7cd,7f7a0127db2528c8b2bd89566f1f9d07


## Prepare the data for evaluation

In [5]:
# Collapse the pairs to one row per question, gathering its context ids in a list.
df = (
    test_data
    .groupby("question_id")
    .agg({"context_id": list})
    .reset_index()
)

# Ground truth: question_id -> [(context_id, 1), ...]; every listed context
# gets a relevance of 1.
test_set = {
    question_id: [(context_id, 1) for context_id in context_ids]
    for question_id, context_ids in zip(df["question_id"], df["context_id"])
}
len(test_set)

3953

## Prepare the metrics

In [6]:
from src.metrics import RankMetrics

# Single metrics helper; its `compute_metrics` is called after each retrieval run below.
metrics = RankMetrics()

## Random Retriever

In [7]:
from src.indexer import DatabaseIndexer
from src.retriever import RandomRetriever
from src.clients import InMemoryDatabaseClient
from src.repositories import CsvDocumentsRepository

# Repository configured with the documents CSV path and its id/content column names.
documents_repository = CsvDocumentsRepository(
    document_content_column="document_content",
    document_id_column="document_id",
    path=documents_filepath,
)

# One in-memory client instance is shared by the indexer and the retriever.
in_memory_db_client = InMemoryDatabaseClient()
db_indexer = DatabaseIndexer(client=in_memory_db_client)
random_retriever = RandomRetriever(
    random_state=42,
    client=in_memory_db_client,
)

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/joao.barroca/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/joao.barroca/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


### Index documents

In [8]:
# Fetch the documents via `get_all()` and hand them to the indexer that was
# built on the in-memory client.
docs = documents_repository.get_all()
db_indexer.index(docs)

### Run retrieval

In [9]:
# One retrieval call per ground-truth query id; a placeholder query string is
# sent every time. Each prediction is a list of (document id, relevance) pairs.
predictions = {
    query_id: [
        (hit.document.id, hit.relevance)
        for hit in random_retriever.retrieve(query="dummy", k=TOP_K)
    ]
    for query_id in test_set
}
len(predictions)

3953

### Get metrics

In [10]:
# Score the random-retriever predictions against the ground truth at cut-off TOP_K.
metrics.compute_metrics(ground_truth=test_set, predictions=predictions, top_k=TOP_K)

{'map@6': 0.00025297242600556537,
 'mrr@6': 0.00025297242600556537,
 'ndcg@6': 0.00037750641783917114}

## Vector Search Retriever + Random Encoder

In [7]:
from src.indexer import DatabaseIndexer
from src.retriever import VectorSearchRetriever
from src.clients import ChromaDatabaseClient
from src.encoders import RandomEncoder
from src.repositories import CsvDocumentsRepository

# Repository configured with the documents CSV path and its id/content column names.
documents_repository = CsvDocumentsRepository(
    document_content_column="document_content",
    document_id_column="document_id",
    path=documents_filepath,
)

# The encoder and the Chroma client are each created once and shared by the
# indexer and the retriever.
random_encoder = RandomEncoder(random_state=42, embedding_dim=16)
vector_db_client = ChromaDatabaseClient(collection_name="open_domain")

db_indexer = DatabaseIndexer(encoder=random_encoder, client=vector_db_client)
random_vector_search_retriever = VectorSearchRetriever(
    encoder=random_encoder,
    client=vector_db_client,
)

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/joao.barroca/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/joao.barroca/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


### Index documents

In [8]:
# Fetch the documents via `get_all()` and index them through the indexer that
# was built with the random encoder and the Chroma client.
docs = documents_repository.get_all()
db_indexer.index(docs)

### Run retrieval

In [9]:
# question_id -> question text lookup; if an id occurs on several rows the
# last row's text is the one kept.
query_mapper = dict(zip(test_data["question_id"], test_data["question"]))
len(query_mapper)

3953

In [10]:
# Retrieve the top TOP_K documents for each ground-truth query, one call per
# query, and keep (document id, relevance) pairs keyed by query id.
predictions = {
    query_id: [
        (hit.document.id, hit.relevance)
        for hit in random_vector_search_retriever.retrieve(
            query=query_mapper[query_id],
            k=TOP_K,
        )
    ]
    for query_id in test_set
}
len(predictions)

3953

### Get metrics

In [11]:
# Score the random-encoder vector-search predictions against the ground truth at cut-off TOP_K.
metrics.compute_metrics(ground_truth=test_set, predictions=predictions, top_k=TOP_K)

{'map@6': 0.000674593136014841,
 'mrr@6': 0.000674593136014841,
 'ndcg@6': 0.0007556632786945307}

## Vector Search Retriever + TF-IDF Encoder

In [7]:
from src.indexer import DatabaseIndexer
from src.retriever import VectorSearchRetriever
from src.clients import ChromaDatabaseClient
from src.encoders import TfIdfEncoder
from src.processors import TextProcessor
from src.repositories import CsvDocumentsRepository

# Repository configured with the documents CSV path and its id/content column names.
documents_repository = CsvDocumentsRepository(
    path=documents_filepath,
    document_id_column="document_id",
    document_content_column="document_content"
)

text_processor = TextProcessor()

# Locate the fitted TF-IDF model relative to the project root instead of a
# user-specific absolute path, so the notebook also runs on other machines.
# NOTE(review): assumes `data_path.parent` is the project root that contains
# `models/encoders/` — confirm against the repository layout.
tf_idf_model_filepath = data_path.parent / "models" / "encoders" / "tf-idf-encoder.joblib"

tf_idf_encoder = TfIdfEncoder(
    # Passed as a string, as before.
    model_filepath=str(tf_idf_model_filepath),
    text_processor=text_processor,
)

# NOTE(review): the same collection name is used by the other vector-search
# sections with different encoders — presumably each section starts from an
# empty collection; confirm.
vector_db_client = ChromaDatabaseClient(collection_name="open_domain")

# The indexer and the retriever share the same client and encoder instances.
db_indexer = DatabaseIndexer(client=vector_db_client, encoder=tf_idf_encoder)

tf_idf_retriever = VectorSearchRetriever(client=vector_db_client, encoder=tf_idf_encoder)

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/joao.barroca/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/joao.barroca/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


### Index documents

In [8]:
# Fetch the documents via `get_all()` and index them through the indexer that
# was built with the TF-IDF encoder and the Chroma client.
docs = documents_repository.get_all()
db_indexer.index(docs)

### Run retrieval

In [9]:
# Map each question_id to its question text (the last row wins for a repeated id).
query_mapper = dict(zip(test_data["question_id"], test_data["question"]))
len(query_mapper)

3953

In [10]:
# Question texts in ground-truth key order, sent to the retriever in one batch.
queries = [query_mapper[query_id] for query_id in test_set]
batch_results = tf_idf_retriever.batch_retrieve(queries=queries, k=TOP_K)

# Pair each query id with its result list by position.
# NOTE(review): assumes `batch_retrieve` returns one result list per query, in
# input order — confirm.
predictions = {
    query_id: [(hit.document.id, hit.relevance) for hit in hits]
    for query_id, hits in zip(test_set, batch_results)
}
len(predictions)

3953

### Get metrics

In [23]:
# Score the TF-IDF vector-search predictions against the ground truth at cut-off TOP_K.
metrics.compute_metrics(ground_truth=test_set, predictions=predictions, top_k=TOP_K)

{'map@6': 0.2013449700649296,
 'mrr@6': 0.2013449700649296,
 'ndcg@6': 0.22658679706296206}

## Vector Search Retriever + Sentence-Transformer Encoder

In [7]:
from src.indexer import DatabaseIndexer
from src.retriever import VectorSearchRetriever
from src.clients import ChromaDatabaseClient
from src.encoders import SentenceTransformersEncoder
from src.repositories import CsvDocumentsRepository

# Repository configured with the documents CSV path and its id/content column names.
documents_repository = CsvDocumentsRepository(
    path=documents_filepath,
    document_id_column="document_id",
    document_content_column="document_content"
)

# Locate the sentence-transformers model relative to the project root instead
# of a user-specific absolute path, so the notebook also runs on other machines.
# NOTE(review): assumes `data_path.parent` is the project root that contains
# `models/encoders/` — confirm against the repository layout.
st_model_filepath = (
    data_path.parent / "models" / "encoders" / "sentence-transformers" / "mpnet-base-deus-v2"
)

st_encoder = SentenceTransformersEncoder(
    # Passed as a string, as before.
    model_filepath=str(st_model_filepath),
)

# NOTE(review): the same collection name is used by the other vector-search
# sections with different encoders — presumably each section starts from an
# empty collection; confirm.
vector_db_client = ChromaDatabaseClient(collection_name="open_domain")

# The indexer and the retriever share the same client and encoder instances.
db_indexer = DatabaseIndexer(client=vector_db_client, encoder=st_encoder)

st_retriever = VectorSearchRetriever(client=vector_db_client, encoder=st_encoder)

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/joao.barroca/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/joao.barroca/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
  from .autonotebook import tqdm as notebook_tqdm


### Index documents

In [8]:
# Fetch the documents via `get_all()` and index them through the indexer that
# was built with the sentence-transformers encoder and the Chroma client.
docs = documents_repository.get_all()
db_indexer.index(docs)

### Run retrieval

In [9]:
# Lookup from question_id to question text; a repeated id keeps its last row's text.
query_mapper = dict(zip(test_data["question_id"], test_data["question"]))
len(query_mapper)

3953

In [10]:
# Question texts in ground-truth key order, sent to the retriever in one batch.
queries = [query_mapper[query_id] for query_id in test_set]
batch_results = st_retriever.batch_retrieve(queries=queries, k=TOP_K)

# Pair each query id with its result list by position.
# NOTE(review): assumes `batch_retrieve` returns one result list per query, in
# input order — confirm.
predictions = {
    query_id: [(hit.document.id, hit.relevance) for hit in hits]
    for query_id, hits in zip(test_set, batch_results)
}
len(predictions)

3953

### Get metrics

In [11]:
# Score the sentence-transformers vector-search predictions against the ground truth at cut-off TOP_K.
metrics.compute_metrics(ground_truth=test_set, predictions=predictions, top_k=TOP_K)

{'map@6': 0.564322455519015,
 'mrr@6': 0.564322455519015,
 'ndcg@6': 0.6020025240200427}

In [65]:
import random
from pprint import pprint

# Spot-check a single, randomly chosen query: print its metrics, then the
# relevant documents followed by the retrieved ones.
query_ids = list(test_set.keys())
docs_mapper = {doc.id: doc for doc in docs}

# Unseeded on purpose: each run of the cell inspects a different query.
query_id = random.choice(query_ids)
print("Query: ", query_mapper[query_id])
relevant_docs = test_set[query_id]
retrieved_docs = predictions[query_id]
print(
    metrics.compute_metrics(
        ground_truth={query_id: relevant_docs},
        predictions={query_id: retrieved_docs},
        # Same cut-off as the full evaluation rather than a separate literal.
        top_k=TOP_K,
    )
)
print(relevant_docs)
for doc_id, _ in relevant_docs:
    pprint(docs_mapper[doc_id].content)
print(retrieved_docs)
for doc_id, _ in retrieved_docs:
    pprint(docs_mapper[doc_id].content)

Query:  What new name was given to Revue Productions in 1966?
{'map@6': 0.5, 'mrr@6': 0.5, 'ndcg@6': 0.6309297535714575}
[('c8a087f77d6abde08025355228fe54a6', 1)]
('The long-awaited takeover of Universal Pictures by MCA, Inc. happened in '
 'mid-1962 as part of the MCA-Decca Records merger. The company reverted in '
 'name to Universal Pictures. As a final gesture before leaving the talent '
 'agency business, virtually every MCA client was signed to a Universal '
 'contract. In 1964 MCA formed Universal City Studios, Inc., merging the '
 'motion pictures and television arms of Universal Pictures Company and Revue '
 'Productions (officially renamed as Universal Television in 1966). And so, '
 'with MCA in charge, Universal became a full-blown, A-film movie studio, with '
 'leading actors and directors under contract; offering slick, commercial '
 'films; and a studio tour subsidiary launched in 1964. Television production '
 "made up much of the studio's output, with Universal heavily

In [46]:
import numpy as np

# Manual similarity check for one hand-picked query against two hand-picked documents.
query = "When was ZE's Mutant Disco released?"
doc_id = "acf263c9a76c20354124f79e4ed5f7f7"
r_doc_id = "b8861b923b9373375bc13d44ef4b0089"

# Encode the query and both document texts with the sentence-transformers encoder.
query_vector = np.array(st_encoder.encode(query))
doc_vector = np.array(st_encoder.encode(docs_mapper[doc_id].content))
r_doc_vector = np.array(st_encoder.encode(docs_mapper[r_doc_id].content))
print(query_vector.shape, doc_vector.shape, r_doc_vector.shape)

# Dot product of the query vector with each document vector, shown as a pair.
(query_vector @ doc_vector, query_vector @ r_doc_vector)

((768,), (768,), (768,))

## TODO: Hybrid Retrieval

Keyword matching could be combined with vector search through the query filters, for example:
- `where={"metadata_field": "is_equal_to_this"}`
- `where_document={"$contains": "search_string"}`
- `where_document={"$contains": {"text": "hello"}}`

## Results

In [None]:
# Metrics recorded from earlier runs (cut-off k = 6), kept as literals so the
# retrievers can be compared without re-running retrieval.

random_retrieval_metrics = {
    'map@6': 0.00025297242600556537,
    'mrr@6': 0.00025297242600556537,
    'ndcg@6': 0.00037750641783917114
}

# NOTE(review): these values differ from the ones printed in the
# "Vector Search Retriever + Random Encoder" section (map@6 = 0.000674...);
# confirm which run is current.
vector_retrieval_with_random_encoder_metrics = {
    'map@6': 0.0006408634792140989,
    'mrr@6': 0.0006408634792140989,
    'ndcg@6': 0.0007840291377309857
}

vector_retrieval_with_tfidf_encoder_metrics = {
    'map@6': 0.2013449700649296,
    'mrr@6': 0.2013449700649296,
    'ndcg@6': 0.22658679706296206
}

# One entry per sentence-transformers model. Previously every model was
# assigned to the same variable, so only the last result survived the cell.
sentence_transformers_metrics_by_model = {
    # foundational model
    'all-MiniLM-L6-v2': {
        'map@6': 0.6757019984821655,
        'mrr@6': 0.6757019984821655,
        'ndcg@6': 0.7099112326985552
    },
    # foundational model
    'all-mpnet-base-v2': {
        'map@6': 0.7226663293700986,
        'mrr@6': 0.7226663293700986,
        'ndcg@6': 0.7568168385788774
    },
    # fine-tuned model - multi-qa data - 215M (question, answer) pairs
    'multi-qa-mpnet-base-cos-v1': {
        'map@6': 0.7291086938190404,
        'mrr@6': 0.7291086938190404,
        'ndcg@6': 0.7592275627579353
    },
    # fine-tuned model - DEUS data
    'all-mpnet-base-v2-deus': {
        'map@6': 0.7313896618601906,
        'mrr@6': 0.7313896618601906,
        'ndcg@6': 0.763747792759782
    },
    # fine-tuned model - DEUS data
    # NOTE(review): the encoder cell loads "mpnet-base-deus-v2" and prints
    # these same numbers — confirm whether this label should carry the "-v2".
    'mpnet-base-deus': {
        'map@6': 0.564322455519015,
        'mrr@6': 0.564322455519015,
        'ndcg@6': 0.6020025240200427
    },
}

# Backward-compatible name: holds the last recorded model's metrics, exactly
# as it did after the repeated reassignments. Copied so that editing one dict
# does not change the other.
vector_retrieval_with_sentence_transformers_encoder_metrics = dict(
    sentence_transformers_metrics_by_model['mpnet-base-deus']
)