In [9]:
import pandas as pd
import json
import random
import torch
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

In [12]:
# Load the dev split from ./Data/dev.csv and preview its first rows.
# The preview shows query_id, query, positive_docs, negative_docs and lang columns,
# plus an "Unnamed: 0" column (presumably a saved index — TODO confirm).
dev = pd.read_csv('./Data/dev.csv')
dev.head()

Unnamed: 0,query_id,query,positive_docs,negative_docs,lang
0,q-en-1,What is the syntax for the shorthand of the co...,doc-en-8,"['doc-en-9', 'doc-en-10', 'doc-en-11', 'doc-en...",en
1,q-en-49,What other companies did the FRC investigate K...,doc-en-391,"['doc-en-392', 'doc-en-393', 'doc-en-394', 'do...",en
2,q-en-54,When did Canadian journalists hold a protest t...,doc-en-431,"['doc-en-432', 'doc-en-433', 'doc-en-434', 'do...",en
3,q-en-80,What is the full name of the plant species des...,doc-en-637,"['doc-en-638', 'doc-en-639', 'doc-en-640', 'do...",en
4,q-en-82,Who ordered the evacuation of the settlers in ...,doc-en-652,"['doc-en-653', 'doc-en-654', 'doc-en-655', 'do...",en


In [13]:
# Load the test split from ./Data/test.csv and preview its first rows.
# The preview shows id, query_id, query and lang columns (no relevance labels).
test = pd.read_csv('./Data/test.csv')
test.head()

Unnamed: 0,id,query_id,query,lang
0,0,q-en-0,What organization proposed listing PFOA under ...,en
1,1,q-en-2,What type of coating do ZM1130 - ZM1132 have?,en
2,2,q-en-4,What year did Deutsche Bank sell its stake in ...,en
3,3,q-en-5,Who expressed exasperation when Raphael and Mo...,en
4,4,q-en-7,Who commissioned Amy Beach to compose a choral...,en


In [14]:
# Load the training split from ./Data/train.csv and preview its first rows.
# The preview shows query_id, query, positive_docs, negative_docs and lang columns.
train = pd.read_csv('./Data/train.csv')
train.head()

Unnamed: 0,query_id,query,positive_docs,negative_docs,lang
0,q-en-425512,What is the connection between AAA and Lucha U...,doc-en-798457,"['doc-en-810925', 'doc-en-634020', 'doc-en-143...",en
1,q-en-16636,What is the medical use of iloperidone?,doc-en-121692,"['doc-en-177976', 'doc-en-700330', 'doc-en-567...",en
2,q-en-282671,Who was the provisional administrator in 1940?,doc-en-750259,"['doc-en-805362', 'doc-en-413387', 'doc-en-827...",en
3,q-en-216614,What was the critical reception of the film se...,doc-en-703883,"['doc-en-685958', 'doc-en-84060', 'doc-en-2046...",en
4,q-en-156120,What was the main Spanish record of the year i...,doc-en-648393,"['doc-en-4307', 'doc-en-761696', 'doc-en-79426...",en


In [15]:
# Load the example submission from ./Data/sample_submission.csv and preview it.
# The preview shows one row per id with a `docids` column holding a list-like
# string of document ids.
sample_submission = pd.read_csv('./Data/sample_submission.csv')
sample_submission.head()

Unnamed: 0,id,docids
0,0,"['doc-en-0', 'doc-de-14895', 'doc-en-829265', ..."
1,1,"['doc-en-447132', 'doc-en-773190', 'doc-en-504..."
2,2,"['doc-en-32', 'doc-en-414951', 'doc-en-564939'..."
3,3,"['doc-en-822169', 'doc-en-441656', 'doc-en-814..."
4,4,"['doc-en-5056', 'doc-en-772925', 'doc-en-72232..."


In [6]:
# Load the whole document corpus into memory.
# Later cells iterate over it and read the 'docid', 'text' and 'lang' keys of
# each entry.
# The encoding is given explicitly: the corpus holds non-English text, and
# without it open() falls back to the platform's locale encoding, which is not
# UTF-8 everywhere and would then corrupt or fail to decode the file.
with open("./Data/corpus.json/corpus.json", "r", encoding="utf-8") as f:
    documents = json.load(f)

In [10]:
# Load the pretrained 'bert-base-uncased' tokenizer and encoder.
# Both are used as notebook-level globals by embed_documents and retrieve_documents.
# NOTE(review): the corpus appears to contain non-English documents (an 'ar'
# language filter and 'doc-de-…' ids show up elsewhere in this notebook) —
# confirm this checkpoint is the intended one for those languages.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

Python(77865) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


In [10]:
def embed_documents(docs):
    """Encode each document with BERT and return its [CLS] vector.

    Relies on the notebook-level ``tokenizer`` and ``model``.

    Args:
        docs: Iterable of mappings, each providing a ``"docid"`` and a
            ``"text"`` entry.

    Returns:
        dict mapping every ``docid`` to the squeezed hidden state found at
        sequence position 0 of the model's last layer. Documents are encoded
        one at a time, with the text truncated to at most 512 tokens.
    """
    cls_by_docid = {}

    progress = tqdm(docs, desc="Embedding documents")
    for entry in progress:
        key, body = entry["docid"], entry["text"]

        # Tokenise the document text into model-ready tensors.
        encoded = tokenizer(
            body,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )

        # Inference only: gradients are not needed for the forward pass.
        with torch.no_grad():
            last_hidden = model(**encoded).last_hidden_state

        # Position 0 holds the [CLS] token, used here as the document vector.
        cls_by_docid[key] = last_hidden[:, 0, :].squeeze()

    return cls_by_docid

In [11]:
# Keep the corpus entries tagged lang == 'ar', then report the size of that
# subset followed by the size of the full corpus.
doc_ar = list(filter(lambda entry: entry['lang'] == 'ar', documents))
print(len(doc_ar), len(documents), sep="\n")

NameError: name 'documents' is not defined

In [31]:
# Embed only the first 20 corpus documents (presumably to keep this trial run
# short — the full corpus is not embedded here).
doc_embeddings = embed_documents(documents[:20])

In [32]:
def retrieve_documents(query, embeddings, top_k=10):
    """Rank embedded documents by cosine similarity to a query.

    The query is encoded with the notebook-level ``tokenizer`` and ``model``
    and represented by the last-layer hidden state at sequence position 0
    (the [CLS] token), i.e. the same representation ``embed_documents`` stores
    for each document.

    Args:
        query: Query text handed to the tokenizer (truncated to 512 tokens).
        embeddings: Mapping ``doc_id -> embedding`` as produced by
            ``embed_documents``.
        top_k: Maximum number of results to return.

    Returns:
        List of ``(doc_id, similarity)`` tuples sorted by decreasing
        similarity, containing at most ``top_k`` entries. An empty list is
        returned when ``embeddings`` is empty.
    """
    # Tokenise the query.
    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)

    # Forward pass through BERT; inference only, so no gradients.
    with torch.no_grad():
        outputs = model(**inputs)

    # [CLS] hidden state as the query representation, shaped once as a single row.
    query_embedding = outputs.last_hidden_state[:, 0, :].squeeze().reshape(1, -1)

    # Nothing to rank; torch.stack would also reject an empty list.
    if not embeddings:
        return []

    # Stack all document vectors into one (n_docs, dim) matrix so the cosine
    # similarities come from a single call instead of one call per document.
    doc_ids = list(embeddings.keys())
    doc_matrix = torch.stack(
        [torch.as_tensor(embeddings[doc_id]).reshape(-1) for doc_id in doc_ids]
    )
    scores = cosine_similarity(query_embedding, doc_matrix)[0].tolist()

    # Sort by similarity (stable, so ties keep insertion order) and keep top_k.
    results = sorted(zip(doc_ids, scores), key=lambda x: x[1], reverse=True)[:top_k]

    return results

In [33]:
# Example query against the embedded subset; with the default top_k=10 this
# displays up to ten (doc_id, similarity) pairs, best match first.
retrieve_documents("How to cook a cake", doc_embeddings)

[('doc-en-8773', 0.6215211153030396),
 ('doc-en-16038', 0.6109321117401123),
 ('doc-en-14104', 0.5731789469718933),
 ('doc-en-4639', 0.5707907676696777),
 ('doc-en-3128', 0.564104437828064),
 ('doc-en-16475', 0.5610914826393127),
 ('doc-en-5745', 0.5493555665016174),
 ('doc-en-9696', 0.5429341793060303),
 ('doc-en-3366', 0.5365767478942871),
 ('doc-en-4033', 0.523850679397583)]