In [38]:
!pip install transformers datasets torch torchvision torchtext


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [39]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
from datasets import load_dataset
import random
import logging

# Set random seeds for reproducibility.
# Every RNG source used in this notebook is seeded: PyTorch, NumPy and the
# standard-library `random` module. torch.manual_seed is called first because
# it returns a Generator object, which would otherwise be echoed as cell output.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


In [40]:
# Load datasets
max_samples = 10000  # Upper bound on the number of rows kept from each dataset


def load_capped_split(dataset_name, limit, split="train"):
    """Load one split of a Hub dataset and keep at most `limit` leading rows.

    Parameters:
    - dataset_name: dataset identifier passed to `load_dataset`.
    - limit: maximum number of rows to keep.
    - split: split name to load (default "train").

    Returns:
    - The dataset restricted to its first `min(limit, len(dataset))` rows.
    """
    dataset = load_dataset(dataset_name, split=split)
    # Clamp the range so a split shorter than `limit` is returned whole
    # instead of `select` failing on out-of-range indices.
    return dataset.select(range(min(limit, len(dataset))))


hotpotqa_data = load_capped_split("BeIR/hotpotqa-generated-queries", max_samples)
nq_data = load_capped_split("BeIR/nq-generated-queries", max_samples)
fiqa_data = load_capped_split("BeIR/fiqa-generated-queries", max_samples)

# Check the loaded datasets
print(hotpotqa_data)
print(nq_data)
print(fiqa_data)


Dataset({
    features: ['_id', 'title', 'text', 'query'],
    num_rows: 10000
})
Dataset({
    features: ['_id', 'title', 'text', 'query'],
    num_rows: 10000
})
Dataset({
    features: ['_id', 'title', 'text', 'query'],
    num_rows: 10000
})


In [41]:
# Show the summary (features and row count) of each loaded dataset, one per line
print(hotpotqa_data, nq_data, fiqa_data, sep="\n")


Dataset({
    features: ['_id', 'title', 'text', 'query'],
    num_rows: 10000
})
Dataset({
    features: ['_id', 'title', 'text', 'query'],
    num_rows: 10000
})
Dataset({
    features: ['_id', 'title', 'text', 'query'],
    num_rows: 10000
})


In [42]:
# Function to preprocess data
def preprocess_data(data):
    """Collect the 'text' and 'query' values found in an iterable of records.

    Each record is inspected once. A record contributes to the passage list
    when it has a 'text' key and to the query list when it has a 'query' key;
    the two checks are independent, so the returned lists only stay aligned
    when every record carries both keys.

    Returns:
    - (passages, queries): two lists, in iteration order.
    """
    texts, questions = [], []
    for record in data:
        # Route each field of interest to its own output list
        for field, bucket in (('text', texts), ('query', questions)):
            if field in record:
                bucket.append(record[field])
    return texts, questions

# Run the preprocessing once per dataset, in a fixed order
(
    (hotpotqa_passages, hotpotqa_queries),
    (nq_passages, nq_queries),
    (fiqa_passages, fiqa_queries),
) = (preprocess_data(source) for source in (hotpotqa_data, nq_data, fiqa_data))

# Concatenate the per-dataset lists: HotpotQA first, then NQ, then FiQA
passages = [*hotpotqa_passages, *nq_passages, *fiqa_passages]
queries = [*hotpotqa_queries, *nq_queries, *fiqa_queries]


In [43]:
# NOTE(review): this cell recomputes the same variables as the preceding cell.
# Extract passages and queries from each dataset
hotpotqa_passages, hotpotqa_queries = preprocess_data(data=hotpotqa_data)
nq_passages, nq_queries = preprocess_data(data=nq_data)
fiqa_passages, fiqa_queries = preprocess_data(data=fiqa_data)

# Build the combined corpus, keeping the HotpotQA -> NQ -> FiQA order
passages = []
passages.extend(hotpotqa_passages)
passages.extend(nq_passages)
passages.extend(fiqa_passages)

queries = []
queries.extend(hotpotqa_queries)
queries.extend(nq_queries)
queries.extend(fiqa_queries)


In [44]:
# Peek at the processed HotpotQA data: first 5 queries, then first 5 passages
print(hotpotqa_queries[:5], hotpotqa_passages[:5], sep="\n")


['anarchism defined', 'what age do children with autism develop', 'what is albedo', 'what is the word for the first letter in latin', 'where is alabama']
['Anarchism is a political philosophy that advocates self-governed societies based on voluntary institutions. These are often described as stateless societies, although several authors have defined them more specifically as institutions based on non-hierarchical free associations. Anarchism holds the state to be undesirable, unnecessary and harmful.', "Autism is a neurodevelopmental disorder characterized by impaired social interaction, impaired verbal and non-verbal communication, and restricted and repetitive behavior. Parents usually notice signs in the first two years of their child's life. These signs often develop gradually, though some children with autism reach their developmental milestones at a normal pace and then regress. The diagnostic criteria require that symptoms become apparent in early childhood, typically before age

In [45]:
# Choose the compute device: CUDA when available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Small sentence-embedding checkpoint: encoder weights and matching tokenizer
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Place the encoder on the chosen device (the cell displays the model summary)
model.to(device)


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 384, padding_idx=0)
    (position_embeddings): Embedding(512, 384)
    (token_type_embeddings): Embedding(2, 384)
    (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-5): 6 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=384, out_features=384, bias=True)
            (key): Linear(in_features=384, out_features=384, bias=True)
            (value): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=384, out_features=384, bias=True)
            (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)


In [46]:
# Function to get embeddings for passages
def get_embeddings(passages, batch_size=64):
    """Embed texts with the bi-encoder using attention-mask-aware mean pooling.

    Texts are tokenized in batches (padded to the longest text in the batch,
    truncated to 128 tokens), encoded without gradient tracking, and pooled by
    averaging token states over real tokens only. Padding positions are
    excluded from the average, so the embedding of a text does not depend on
    which other texts share its batch.

    Parameters:
    - passages: list of strings to embed.
    - batch_size: number of texts encoded per forward pass (default 64).

    Returns:
    - A CPU tensor with one row per input text and one column per hidden unit.
      For empty input, a tensor of shape (0, hidden_size) is returned.
    """
    if len(passages) == 0:
        # torch.cat([]) raises, so return an explicitly empty matrix instead.
        return torch.empty((0, model.config.hidden_size))

    embeddings = []
    for i in range(0, len(passages), batch_size):
        batch_passages = passages[i:i + batch_size]
        inputs = tokenizer(batch_passages, return_tensors='pt', truncation=True, padding=True, max_length=128)
        inputs = {k: v.to(device) for k, v in inputs.items()}  # Move input tensors to the model's device
        with torch.no_grad():
            token_states = model(**inputs).last_hidden_state
            # 1.0 for real tokens, 0.0 for padding; broadcast over the hidden axis
            mask = inputs['attention_mask'].unsqueeze(-1).to(token_states.dtype)
            summed = (token_states * mask).sum(dim=1)
            # Guard against division by zero for a row with no real tokens
            token_counts = mask.sum(dim=1).clamp(min=1e-9)
            embedding = (summed / token_counts).cpu()  # Move back to CPU after processing
        embeddings.append(embedding)
    return torch.cat(embeddings)

# Embed the whole combined corpus (default batch size) and report the matrix size
passage_embeddings = get_embeddings(passages=passages)
print(passage_embeddings.size())


torch.Size([30000, 384])


In [47]:
num_passages, num_embeddings = len(passages), passage_embeddings.shape[0]

# Compare the passage count with the number of embedding rows and report the result
print(
    f"The number of passages ({num_passages}) matches the number of embeddings ({num_embeddings})."
    if num_passages == num_embeddings
    else f"Mismatch: Number of passages is {num_passages}, but number of embeddings is {num_embeddings}."
)


The number of passages (30000) matches the number of embeddings (30000).


In [48]:
# Cross-encoder checkpoint used for reranking: tokenizer and sequence-classification model
rank_tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
rank_model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")

# Put the reranker on the same device as the embedding model (the cell displays the model summary)
rank_model.to(device)


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 384, padding_idx=0)
      (position_embeddings): Embedding(512, 384)
      (token_type_embeddings): Embedding(2, 384)
      (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (LayerNorm): LayerNorm((384,), eps=1e

In [49]:
# Function to rerank passages based on a query
def rerank_passages(query, passages, passage_embeddings):
    """Rank passages by cosine similarity between the query and passage embeddings.

    NOTE(review): a later cell defines another `rerank_passages(question,
    passages_batch)`; when cells run in order that definition replaces this
    one. Consider giving the two functions distinct names.

    Parameters:
    - query: query string; embedded with `get_embeddings`.
    - passages: not used by the computation; kept so the signature is unchanged.
    - passage_embeddings: 2-D tensor with one row per passage, or an iterable
      of 1-D embedding tensors.

    Returns:
    - NumPy array of passage indices ordered from most to least similar.
      Equal scores keep their original index order.
    """
    query_embedding = get_embeddings([query])  # one row: the query embedding

    if not torch.is_tensor(passage_embeddings):
        rows = list(passage_embeddings)
        if not rows:
            return np.array([], dtype=np.int64)
        passage_embeddings = torch.stack(rows)

    # One broadcast call scores every passage at once instead of looping over
    # rows in Python and converting each score with .item().
    scores = torch.nn.functional.cosine_similarity(query_embedding, passage_embeddings, dim=1)

    # Negating and using a stable sort yields descending order with
    # deterministic tie-breaking.
    return np.argsort(-scores.cpu().numpy(), kind="stable")


In [50]:
# Function to rerank passages based on the question
def rerank_passages(question, passages_batch):
    """Order a batch of passages by cross-encoder relevance to a question.

    Each (question, passage) pair is scored jointly by `rank_model`. When the
    model emits one logit per pair, that logit is the score; otherwise the
    softmax probability of class index 1 is used.

    Parameters:
    - question: the query string.
    - passages_batch: list of passage strings to score against the question.

    Returns:
    - 1-D NumPy array of indices into `passages_batch`, best match first.
      The array is 1-D even for a single passage, and empty for an empty batch.
    """
    if len(passages_batch) == 0:
        return np.array([], dtype=np.int64)

    # Prepare the input for the cross-encoder
    inputs = rank_tokenizer([[question, passage] for passage in passages_batch], return_tensors='pt', padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}  # Move input tensors to the model's device

    # Get the model's predictions (scores)
    with torch.no_grad():
        logits = rank_model(**inputs).logits  # Raw scores from the model

    if logits.shape[1] == 1:
        # Drop only the trailing axis: a bare squeeze() would also drop the
        # batch axis when the batch holds a single passage.
        scores = logits.squeeze(-1)
    else:
        scores = logits.softmax(dim=1)[:, 1]  # Positive-class probabilities

    return torch.argsort(scores, descending=True).cpu().numpy()


In [51]:
# Process and rerank passages, one fixed-size batch at a time.
# The tunable values are named once here instead of being repeated inline.
RERANK_BATCH_SIZE = 64   # passages scored per cross-encoder call
TOP_PER_BATCH = 1        # number of top-ranked passages printed per batch
PREVIEW_CHARS = 50       # characters of each passage shown in the preview
demo_query = "what age do children with autism develop?"

for i in range(0, len(passages), RERANK_BATCH_SIZE):
    batch_passages = passages[i:i + RERANK_BATCH_SIZE]  # Select the current batch

    # Rerank this batch
    ranked_batch_indices = rerank_passages(demo_query, batch_passages)

    # Get the ranked passages for this batch
    ranked_batch_passages = [batch_passages[idx] for idx in ranked_batch_indices]

    # Print a limited number of ranked passages
    print("Ranked passages for batch:")
    for passage in ranked_batch_passages[:TOP_PER_BATCH]:
        print(passage[:PREVIEW_CHARS])  # Print only the start of each passage
        print("...")  # Indicate that the passage is longer


Ranked passages for batch:
Autism is a neurodevelopmental disorder characteri
...
Ranked passages for batch:
Alexander III of Macedon (20/21 July 356 BC – 10/1
...
Ranked passages for batch:
Ataxia is a neurological sign consisting of lack o
...
Ranked passages for batch:
April 15 is the day of the year in the Gregorian c
...
Ranked passages for batch:
August 13 is the day of the year in the Gregorian 
...
Ranked passages for batch:
The Americans with Disabilities Act of 1990 ( /121
...
Ranked passages for batch:
August 18 is the day of the year in the Gregorian 
...
Ranked passages for batch:
Alexander III married at the age of 9 "Alaxandair 
...
Ranked passages for batch:
The age of consent is the age at which a person is
...
Ranked passages for batch:
Hermann Kolbe ("Adolph Wilhelm Hermann Kolbe", 27 
...
Ranked passages for batch:
April 14 is the day of the year in the Gregorian c
...
Ranked passages for batch:
Anne Brontë ( , "commonly" ; 17 January 1820 – 28 
...
Ranked passages 

In [54]:
from sklearn.metrics import ndcg_score

def evaluate_retrieval(ranked_passages, relevance_scores, k=10):
    """
    Evaluate a ranking using NDCG@k (Normalized Discounted Cumulative Gain).

    The score is computed directly from the ranked order: DCG sums
    relevance / log2(position + 1) over the top-k ranked passages (positions
    start at 1), and is normalised by the DCG of the ideal ordering, i.e. the
    k largest values in `relevance_scores`. sklearn's `ndcg_score` is not used
    here because it expects per-item predicted scores aligned with the
    relevance list, not a list of ranked indices.

    Parameters:
    - ranked_passages: The indices of passages in the ranked order (best first).
      Each index is expected to appear at most once.
    - relevance_scores: A list of relevance scores for each passage (higher is more relevant),
      indexed by passage index.
    - k: Number of top passages to consider for NDCG calculation (default is 10).

    Returns:
    - ndcg: The NDCG score for the top-k ranked passages, or 0.0 when no
      passage has positive relevance.

    Raises:
    - ValueError: if k is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")

    relevance = np.asarray(relevance_scores, dtype=float)
    top_ranked = np.asarray(ranked_passages, dtype=int)[:k]

    def dcg(gains):
        # Position p (1-based) is discounted by log2(p + 1)
        discounts = np.log2(np.arange(2, gains.size + 2))
        return float(np.sum(gains / discounts))

    # Ideal ordering: all relevance scores sorted from high to low, cut at k
    ideal_dcg = dcg(np.sort(relevance)[::-1][:k])
    if ideal_dcg == 0:
        return 0.0

    return dcg(relevance[top_ranked]) / ideal_dcg

# Worked example with hand-written inputs
# Graded relevance label for each passage, indexed by passage position
relevance_scores = [3, 2, 1, 0, 0, 0]
# Passage indices in the order a model ranked them (best first)
ranked_passages = [0, 2, 1, 4, 3]

ndcg = evaluate_retrieval(
    ranked_passages=ranked_passages,
    relevance_scores=relevance_scores,
    k=5,
)
print(f"NDCG@5: {ndcg}")


NDCG@5: 0.9725044904464192
