<a href="https://colab.research.google.com/github/khushigupta20/Small-Prototype-of-RAG/blob/main/small_prototype_of_rag_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from transformers import DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer
from transformers import BartTokenizer, BartForConditionalGeneration

# Initialize the retriever (DPR)
question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')

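# Initialize the generator (BART). Note: facebook/bart-large-cnn is fine-tuned for
# CNN/DailyMail news summarization, not question answering, so it tends to copy its
# input and may continue with text that the retrieved context does not support.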
generator_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
generator_model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')

# Example contexts (passages)
contexts = [
    "We analyzed data on 30,312 women from the EstBB cohort.",
    "They entered the cohort between 2002 and 2011, were between 20 and 89 years, without a history of breast cancer, and with full 5-year follow-up by 2015.",
    "We examined PRS and other potential risk factors as possible predictors in Cox regression models for breast cancer incidence.",
    "With 10-fold cross-validation we estimated 3- and 5-year breast cancer incidence predicted by age alone and by PRS plus age, fitting models on 90% of the data.",
    "Calibration, discrimination, and reclassification were calculated on the left-out folds to express prognostic performance.",
]

# Encode all contexts in one padded batch; no_grad skips gradient tracking during inference
with torch.no_grad():
    context_inputs = context_tokenizer(contexts, padding=True, truncation=True, return_tensors='pt')
    context_embeddings = context_encoder(**context_inputs).pooler_output  # shape: (num_contexts, 768)

# Find the context whose embedding has the highest dot product with the query embedding
def retrieve(query):
    with torch.no_grad():
        query_inputs = question_tokenizer(query, return_tensors='pt')
        query_embedding = question_encoder(**query_inputs).pooler_output  # shape: (1, 768)

    # (num_contexts, 768) @ (768, 1) -> one inner-product score per context
    scores = torch.matmul(context_embeddings, query_embedding.T)
    closest_context_idx = torch.argmax(scores).item()

    return contexts[closest_context_idx]

# Generate an answer by feeding the query followed by the retrieved context to BART
def generate_answer(query):
    retrieved_context = retrieve(query)
    print(f"Retrieved context: {retrieved_context}")

    input_text = query + " " + retrieved_context
    # truncation=True keeps long inputs within BART's 1,024-token input limit
    input_ids = generator_tokenizer(input_text, return_tensors='pt', truncation=True).input_ids

    # max_length is counted in tokens: the output is capped at 50 tokens and can stop mid-sentence
    output_ids = generator_model.generate(input_ids, max_length=50, num_beams=5, early_stopping=True)
    answer = generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return answer

# Example query
query = "Analyzation is done on how many women?"
answer = generate_answer(query)
print(f"Answer: {answer}")


Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.


Retrieved context: We analyzed data on 30,312 women from the EstBB cohort.




Answer: Analyzation is done on how many women? We analyzed data on 30,312 women from the EstBB cohort. We found that the majority of them were women of color. We also found that women were more likely than men to
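In `retrieve`, every passage is scored by the inner product of its DPR context embedding with the question embedding, and only the `torch.argmax` passage is passed to BART. A question that needs two passages at once (for example, the cohort size from the first passage and the age range from the second) therefore cannot be answered from a single retrieval. What follows is a minimal, dependency-free sketch of the same inner-product scoring extended to the top k passages, which are then joined with the query the way `generate_answer` joins one passage. It is not part of the original prototype: the 3-dimensional toy vectors stand in for the 768-dimensional DPR `pooler_output` embeddings, and `dot`, `top_k_by_dot_product`, `build_generator_input`, and the `toy_*` names are hypothetical.

```python
import heapq
from typing import List, Sequence, Tuple


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Inner product, the similarity that retrieve() computes with torch.matmul."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same dimension")
    return sum(a * b for a, b in zip(u, v))


def top_k_by_dot_product(
    query_vec: Sequence[float],
    context_vecs: Sequence[Sequence[float]],
    k: int,
) -> List[Tuple[int, float]]:
    """Return (index, score) pairs for the k highest-scoring contexts, best first."""
    scored = ((i, dot(query_vec, vec)) for i, vec in enumerate(context_vecs))
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


def build_generator_input(query: str, passages: List[str]) -> str:
    """Join the query and retrieved passages with spaces, as generate_answer does for one passage."""
    return " ".join([query, *passages])


# Hypothetical toy data; real DPR pooler outputs are 768-dimensional.
toy_contexts = ["passage A", "passage B", "passage C"]
toy_context_vecs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]
toy_query_vec = [1.0, 0.5, 0.0]

hits = top_k_by_dot_product(toy_query_vec, toy_context_vecs, k=2)
print(hits)  # [(0, 1.0), (2, 0.75)]
print(build_generator_input("toy question?", [toy_contexts[i] for i, _ in hits]))
# toy question? passage A passage C
```

With the notebook's own tensors, the PyTorch counterpart of `top_k_by_dot_product` would be `torch.topk(scores.flatten(), k)` on the `scores` tensor computed inside `retrieve`; its `.indices` select the passages to concatenate. Passing more context to `facebook/bart-large-cnn` does not by itself make it a question-answering model, so the copied-input behaviour seen in the answer above may persist.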
