In [10]:
pip install tf-keras sentence-transformers transformers faiss-cpu datasets


Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.1.2 -> 24.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


# Document Preparation

In [50]:
import os

def load_documents(directory):
    """Read every ``.txt`` file directly inside ``directory``.

    Files are visited in sorted filename order so that the position of each
    document (which later becomes its row id in the dataset / FAISS index) is
    the same on every run; ``os.listdir`` alone returns entries in arbitrary
    order.

    Args:
        directory: Path of the folder to scan (not recursive).

    Returns:
        A ``(corpus, filenames)`` tuple of two parallel lists: the UTF-8
        decoded content of each file and the corresponding file name.

    Raises:
        OSError: If ``directory`` cannot be listed or a file cannot be read.
        UnicodeDecodeError: If a file is not valid UTF-8.
    """
    corpus = []
    filenames = []
    for filename in sorted(os.listdir(directory)):
        path = os.path.join(directory, filename)
        # Skip other extensions and anything that is not a regular file
        # (e.g. a sub-directory that happens to be named "something.txt").
        if not (filename.endswith(".txt") and os.path.isfile(path)):
            continue
        with open(path, 'r', encoding='utf-8') as file:
            corpus.append(file.read())
        filenames.append(filename)
    return corpus, filenames

# Folder (relative to the working directory) scanned for .txt source documents.
docs_path = 'docs'
# corpus holds the file contents; filenames holds the matching file names.
corpus, filenames = load_documents(docs_path)
print(f"Loaded {len(corpus)} documents.")

Loaded 10 documents.


# Embeddings using DPRQuestionEncoder

In [52]:
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import torch

# Load the DPR question-encoder tokenizer and model from a single checkpoint id
# so the two can never drift apart.
DPR_QUESTION_ENCODER_CHECKPOINT = "facebook/dpr-question_encoder-single-nq-base"
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(DPR_QUESTION_ENCODER_CHECKPOINT)
question_encoder = DPRQuestionEncoder.from_pretrained(DPR_QUESTION_ENCODER_CHECKPOINT)

# Embed texts with the DPR *question* encoder loaded above.
# NOTE(review): passages are embedded with the question encoder here; DPR also
# ships a separate context encoder intended for passages — confirm this choice.
def create_embeddings(texts, max_length=512, batch_size=16):
    """Encode one or more texts into DPR pooled embeddings.

    Uses the module-level ``tokenizer`` and ``question_encoder``. A later cell
    rebinds the global name ``tokenizer`` to a ``RagTokenizer``, so this
    function should be run before that cell.

    Args:
        texts: A single string or a sequence of strings.
        max_length: Maximum number of tokens kept per text. The tokenizer has
            no predefined maximum, so without an explicit value
            ``truncation=True`` is ignored and over-long texts are passed to
            the encoder untruncated.
        batch_size: Number of texts encoded per forward pass, which bounds
            peak memory for large corpora.

    Returns:
        A ``torch.Tensor`` with one row per input text (the encoder's
        ``pooler_output``), in input order.

    Raises:
        ValueError: If ``texts`` is empty or ``batch_size`` is not positive.
    """
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        raise ValueError("create_embeddings() needs at least one text to encode.")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer.")

    batch_embeddings = []
    for start in range(0, len(texts), batch_size):
        inputs = tokenizer(
            texts[start:start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            batch_embeddings.append(question_encoder(**inputs).pooler_output)
    return torch.cat(batch_embeddings, dim=0)

# Embed the whole corpus; the resulting tensor is later converted to lists and
# stored in the dataset's "embeddings" column.
doc_embeddings = create_embeddings(corpus)
print("Document embeddings created.")

Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Document embeddings created.


# Save the Dataset and Index

In [54]:
from datasets import Dataset
import os

# Build the dataset columns: a title, the document text, and one embedding
# vector per document (all three lists are aligned by position).
dataset_dict = {
    "title": [f"Document {i}" for i in range(len(corpus))],  # Placeholder titles: "Document 0", "Document 1", ...
    "text": corpus,
    "embeddings": doc_embeddings.tolist()  # Tensor -> nested Python lists so it can be stored in the dataset
}

dataset = Dataset.from_dict(dataset_dict)

# Output locations, relative to the current working directory
dataset_path = "rag_dataset"
index_path = "rag_index"
os.makedirs(dataset_path, exist_ok=True)
os.makedirs(index_path, exist_ok=True)

# Persist the dataset first, then build a FAISS index over the "embeddings"
# column and write that index to its own file.
# NOTE(review): the save happens before the index is attached — presumably
# because a dataset with an attached index cannot be saved directly; confirm
# before reordering these lines.
# NOTE(review): no metric is passed to add_faiss_index; RAG ranks passages by
# inner product, so confirm the default index metric is the intended one.
dataset.save_to_disk(dataset_path)
dataset.add_faiss_index(column="embeddings")
dataset.get_index("embeddings").save(os.path.join(index_path, "faiss_index"))

Saving the dataset (0/1 shards):   0%|          | 0/10 [00:00<?, ? examples/s]

Saving the dataset (1/1 shards): 100%|██████████| 10/10 [00:00<00:00, 330.08 examples/s]
100%|██████████| 1/1 [00:00<00:00, 200.70it/s]


# Load Dataset and Index for RAG

In [66]:
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

# One checkpoint id for retriever, generator and tokenizer, chosen to match the
# model class used below: "facebook/rag-sequence-nq" is the checkpoint published
# for RagSequenceForGeneration, whereas "facebook/rag-token-nq" was fine-tuned
# for RagTokenForGeneration.
RAG_CHECKPOINT = "facebook/rag-sequence-nq"

# Load dataset and index
dataset = Dataset.load_from_disk(dataset_path)
dataset.load_faiss_index("embeddings", os.path.join(index_path, "faiss_index"))

# Initialize RAG retriever and generator components
retriever = RagRetriever.from_pretrained(RAG_CHECKPOINT, index_name="embeddings", use_dummy_dataset=False, indexed_dataset=dataset)
rag_model = RagSequenceForGeneration.from_pretrained(RAG_CHECKPOINT)
rag_model.set_retriever(retriever)
# NOTE: this rebinds the global `tokenizer` (previously the DPR question-encoder
# tokenizer) to the RAG tokenizer; cells below rely on the RAG one.
tokenizer = RagTokenizer.from_pretrained(RAG_CHECKPOINT)
print("RAG model components initialized.")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called fr

RAG model components initialized.


# Retrieval and Generation Function

In [71]:
import numpy as np

def rag_retrieve_and_generate(query, n_docs=5):
    """Answer ``query`` by retrieving documents and generating with RAG.

    Relies on the module-level ``tokenizer``, ``rag_model``, ``retriever`` and
    ``dataset`` created in earlier cells.

    Args:
        query: The question text.
        n_docs: Number of documents to retrieve and condition generation on.
            It is forwarded to both the retriever and ``generate`` so the two
            always agree.

    Returns:
        The decoded text of the first generated sequence.

    Raises:
        AssertionError: If the query embedding width differs from the width of
            the stored document embeddings.
    """
    # Encode the query
    input_ids = tokenizer(query, return_tensors="pt").input_ids

    # Embed the question; inference only, so no autograd graph is needed.
    with torch.no_grad():
        question_hidden_states = rag_model.question_encoder(input_ids)[0]
    question_hidden_states_np = question_hidden_states.detach().cpu().numpy()

    # Read a single row rather than materialising the whole embeddings column.
    doc_embedding_shape = np.array(dataset[0]["embeddings"]).shape

    # Debugging: Print shapes of the embeddings
    print(f"Query embedding shape: {question_hidden_states_np.shape}")
    print(f"Document embedding shape: {doc_embedding_shape}")

    # Check dimensions
    assert question_hidden_states_np.shape[1] == doc_embedding_shape[0], \
        f"Dimension mismatch: query embedding has dimension {question_hidden_states_np.shape[1]} but document embeddings have dimension {doc_embedding_shape[0]}"

    # Use the retriever
    results = retriever(
        question_input_ids=input_ids,
        question_hidden_states=question_hidden_states_np,
        n_docs=n_docs,
        return_tensors="pt"
    )

    # Debugging: Inspect the contents of results
    print("Retriever results keys:", results.keys())

    # Extract the necessary values from the results
    doc_indices = results["doc_ids"]
    context_input_ids = results["context_input_ids"]
    context_attention_mask = results["context_attention_mask"]
    retrieved_doc_embeds = results["retrieved_doc_embeds"]

    # doc_scores must be one relevance score per retrieved document, i.e. the
    # inner product between the question embedding and each document
    # embedding: (batch, 1, dim) x (batch, dim, n_docs) -> (batch, n_docs).
    # Passing the raw embeddings here would hand the generator a tensor of the
    # wrong shape and meaning.
    doc_scores = torch.bmm(
        question_hidden_states.unsqueeze(1),
        retrieved_doc_embeds.float().transpose(1, 2),
    ).squeeze(1)

    # Get the document contents
    retrieved_docs = [dataset[idx.item()]["text"] for idx in doc_indices[0]]

    # Debugging: Print the retrieved documents
    print(f"Retrieved documents: {retrieved_docs}")

    # Generate the response using RAG
    generated_ids = rag_model.generate(
        input_ids=input_ids,
        context_input_ids=context_input_ids,
        context_attention_mask=context_attention_mask,
        doc_scores=doc_scores,
        n_docs=n_docs,
        num_return_sequences=1
    )

    # Decode the generated response
    response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return response

# Inference

In [73]:
# Run the full retrieve-then-generate pipeline on a sample question.
query = "What is Blockchain?"
response = rag_retrieve_and_generate(query)
print(response)

Query embedding shape: (1, 768)
Document embedding shape: (768,)
Retriever results keys: dict_keys(['context_input_ids', 'context_attention_mask', 'retrieved_doc_embeds', 'doc_ids'])
Retrieved documents: ['Blockchain is a decentralized digital ledger that records transactions across many computers so that the record cannot be altered retroactively without the alteration of all subsequent blocks and the consensus of the network. Blockchain technology is the backbone of cryptocurrencies like Bitcoin.\n', 'Deep learning is a class of machine learning algorithms that use multiple layers of artificial neural networks to model complex patterns in data. It has been particularly successful in areas such as image and speech recognition, natural language processing, and autonomous driving.\n', 'Big data refers to data sets that are so large or complex that traditional data processing applications are inadequate to deal with them. Challenges include data capture, storage, analysis, data curation,