In [32]:
pip install transformers datasets faiss-cpu


Note: you may need to restart the kernel to use updated packages.


In [100]:
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration,DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer
from sentence_transformers import SentenceTransformer
import numpy as np
from datasets import Dataset
import torch

In [134]:

# 1. Load DPR models
q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
c_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
c_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
q_encoder.to(device)
c_encoder.to(device)

# 2. Your manual
manual_chunks = [
    "Step 1: Preheat the oven to 180°C (356°F). Ensure the racks are in the middle position.",
    "Step 2: In a large mixing bowl, combine 2 cups of flour, 1 cup of sugar, 1 teaspoon of baking powder, and 1/2 teaspoon of salt.",
    "Step 3: In another bowl, beat 2 eggs with 1 cup of milk and 1/4 cup of melted butter.",
    "Step 4: Slowly combine the wet ingredients with the dry ingredients until smooth. Do not overmix.",
    "Step 5: Grease a 9x13 inch baking pan and pour the batter evenly into the pan.",
    "Step 6: Bake for 30-35 minutes, or until a toothpick inserted in the center comes out clean.",
    "Step 7: Remove the cake from the oven and let it cool for 15 minutes before slicing.",
    "Step 8: Optional: Frost with chocolate or vanilla frosting once cooled."
]
# 3. Compute embeddings
context_embeddings = []
for chunk in manual_chunks:
    inputs = c_tokenizer(chunk, return_tensors="pt", truncation=True).to(device)
    emb = c_encoder(**inputs).pooler_output.detach().cpu().numpy()
    context_embeddings.append(emb[0])
context_embeddings = np.stack(context_embeddings)  # shape = (num_chunks, 768)



Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the

In [135]:
titles = [f"manual_chunk_{i}" for i in range(len(manual_chunks))]

dataset = Dataset.from_dict({
    "title": titles,
    "text": manual_chunks,
    "embeddings": list(context_embeddings)  # must be a list of arrays
})

dataset.add_faiss_index("embeddings")



  0%|          | 0/1 [00:00<?, ?it/s]

Dataset({
    features: ['title', 'text', 'embeddings'],
    num_rows: 8
})

In [136]:
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")

# Create retriever with your dataset
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base",
    index_name="custom",
    passages_path=None,
    indexed_dataset=dataset, 
)
retriever.n_docs = 2
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)


model.to(device)


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called fr

RagTokenForGeneration(
  (rag): RagModel(
    (question_encoder): DPRQuestionEncoder(
      (question_encoder): DPREncoder(
        (bert_model): BertModel(
          (embeddings): BertEmbeddings(
            (word_embeddings): Embedding(30522, 768, padding_idx=0)
            (position_embeddings): Embedding(512, 768)
            (token_type_embeddings): Embedding(2, 768)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (encoder): BertEncoder(
            (layer): ModuleList(
              (0-11): 12 x BertLayer(
                (attention): BertAttention(
                  (self): BertSdpaSelfAttention(
                    (query): Linear(in_features=768, out_features=768, bias=True)
                    (key): Linear(in_features=768, out_features=768, bias=True)
                    (value): Linear(in_features=768, out_features=768, bias=True)
                    (dropout): Dropout(

In [144]:

query = (
    "How do I bake the cake?"
)

# Tokenize the inputs
inputs = tokenizer(query, return_tensors="pt").to(device)

# RAG will now retrieve from your manual dataset + generate
outputs = model.generate(**inputs, max_new_tokens=50)

answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print("Answer:", answer)


Answer: 55. manual_chunk_6 / Step 7: Remove the cake from the
