# LLM RAG example

**There are things an LLM doesn't know and must look up.**

For more background, see NVIDIA's overview of retrieval-augmented generation:

https://blogs.nvidia.com/blog/what-is-retrieval-augmented-generation/



## High-Level Intuition of Retrieval-Augmented Generation (RAG)

**Retrieval-Augmented Generation (RAG)** is a **method that improves how AI models answer questions by looking up relevant information before generating a response.**

1. **Retriever**: Finds relevant information from a large collection of documents based on the question asked.
2. **Generator**: Uses both the question and the found information to create a more accurate and informative answer.

**Why RAG?**
- **Better Answers**: By grounding its response in retrieved text, the AI can answer questions about facts it was never trained on, and its answers are more accurate and useful.

**How It Works**
1. **Ask a Question**: You provide a question.
2. **Look Up Information**: The model searches a database for related information.
3. **Generate Answer**: The model uses the found information to generate a well-informed answer.

In short, RAG combines looking up information and generating text to give better answers.
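The retrieve-then-generate flow can be sketched without any machine learning at all. The toy example below is only an illustration: `retrieve` ranks documents by word overlap and `generate` is a placeholder template, both hypothetical stand-ins for the neural retriever and language model used later in the tutorial.

```python
import re


def tokenize(text):
    """Lowercase the text and return its set of alphabetic words."""
    return set(re.findall(r"[a-z]+", text.lower()))


def retrieve(question, documents, top_k=1):
    """Rank documents by how many words they share with the question."""
    question_words = tokenize(question)
    ranked = sorted(
        documents,
        key=lambda doc: len(question_words & tokenize(doc)),
        reverse=True,
    )
    return ranked[:top_k]


def generate(question, context_docs):
    """Placeholder for a language model call: it just echoes the context."""
    return f"Based on: {' '.join(context_docs)}"


documents = [
    "The sky is blue.",
    "The sun is bright.",
    "AI is transforming the world.",
]
question = "What color is the sky?"

found = retrieve(question, documents)   # step 2: look up information
answer = generate(question, found)      # step 3: generate an answer
print(answer)  # Based on: The sky is blue.
```

A real system replaces word overlap with embedding similarity and the template with a generative model, but the control flow stays the same.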

# Tutorial

## Retrieval-Augmented Generation (RAG) Overview

### RAG combines two components:

- **Retriever**: Retrieves relevant documents from a predefined knowledge base.
- **Generator**: Generates text based on the retrieved documents.

### Steps to Implement RAG:

1. **Load a Pre-trained Language Model and a Retriever**: Use a small language model and a retriever model.
2. **Create a Knowledge Base**: Prepare a simple corpus of documents.
3. **Retrieve Relevant Documents**: Use the retriever to find documents relevant to the input query.
4. **Generate a Response**: Use the language model to generate a response based on the retrieved documents.

## Explanation of the Code

The provided code demonstrates a simple implementation of a Retrieval-Augmented Generation (RAG) system using smaller pre-trained models from the Hugging Face Transformers library. It involves six main steps, followed by a usage example:

#### 1. Install Necessary Packages
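The code begins with the install command for the required libraries: `transformers`, `datasets`, and `sentencepiece`. The command is commented out, so uncomment it if the packages are not yet installed. These libraries provide tools for working with pre-trained language models and datasets, and `sentencepiece` is specifically used for tokenization with the T5 model. The script itself never imports `datasets`.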

#### 2. Import Libraries
Next, the code imports the model and tokenizer classes it needs from the `transformers` library, along with two third-party libraries: NumPy for numerical operations and PyTorch (`torch`) for tensor handling. The `transformers` library provides access to various pre-trained models and tokenizers.

#### 3. Load Pre-trained Models and Tokenizers
The script loads pre-trained models and tokenizers for both the retriever and the generator components:
- **Retriever**: Utilizes `DistilBERT`, a smaller and faster variant of BERT, for retrieving relevant documents. The `DistilBertTokenizer` tokenizes input text for the `DistilBertModel`, which generates embeddings for the input text.
- **Generator**: Uses `T5`, a model designed for text generation tasks. The `T5Tokenizer` tokenizes input text for the `T5ForConditionalGeneration` model, which generates text responses based on the input context.

#### 4. Create a Knowledge Base
The script creates a simple knowledge base, which is a list of documents containing basic factual sentences. These documents are encoded using the retriever model to create embeddings that represent the content of each document.
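The script turns each document into a single vector by averaging the retriever's per-token vectors. The sketch below shows that averaging step in plain Python with made-up numbers. It also skips padding positions through an attention mask, which is a common variation and an assumption here: the script's `.mean(dim=1)` averages over every position, padding included. The function name `mean_pool` is hypothetical.

```python
def mean_pool(token_vectors, attention_mask):
    """Average token vectors, skipping positions where the mask is 0 (padding)."""
    dim = len(token_vectors[0])
    totals = [0.0] * dim
    count = 0
    for vector, keep in zip(token_vectors, attention_mask):
        if keep:
            for i, value in enumerate(vector):
                totals[i] += value
            count += 1
    if count == 0:
        raise ValueError("attention_mask selects no tokens")
    return [total / count for total in totals]


# Hypothetical 2-dimensional token vectors; the last position is padding.
token_vectors = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]
attention_mask = [1, 1, 0]

pooled = mean_pool(token_vectors, attention_mask)
print(pooled)  # [2.0, 3.0]
```

Without the mask, the padding row would pull the average toward zero, so short documents in a padded batch would get slightly different embeddings than they would on their own.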

#### 5. Retrieve Relevant Documents
A function `retrieve_docs` is defined to retrieve relevant documents from the knowledge base based on a given query. The function:
- Encodes the query using the retriever tokenizer and model.
- Computes the similarity between the query embeddings and the document embeddings using dot products.
- Retrieves the top k most relevant documents based on the computed similarity scores.
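The scoring and ranking steps above can be illustrated with small hand-written vectors. This is a minimal sketch in plain Python, assuming 2-dimensional embeddings invented for the example; the helper names `dot` and `top_k_by_dot_product` are hypothetical and do not appear in the script.

```python
def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def top_k_by_dot_product(query_vec, doc_vecs, k=1):
    """Return the indices of the k highest-scoring documents and all scores."""
    scores = [dot(query_vec, doc_vec) for doc_vec in doc_vecs]
    ranked = sorted(range(len(doc_vecs)), key=lambda i: scores[i], reverse=True)
    return ranked[:k], scores


# Made-up embeddings: document 0 points almost the same way as the query.
doc_vecs = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
query_vec = [0.9, 0.1]

indices, scores = top_k_by_dot_product(query_vec, doc_vecs, k=2)
print(indices)  # [0, 2]
```

A raw dot product favors vectors with large magnitudes. Dividing each vector by its length before scoring gives cosine similarity, which compares direction only; the script uses the raw dot product.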

#### 6. Generate a Response
Another function `generate_response` is defined to generate a text response based on the retrieved documents. The function:
- Retrieves relevant documents using the `retrieve_docs` function.
- Concatenates the retrieved documents to form a context.
- Tokenizes the query and context using the T5 tokenizer.
- Generates a response using the T5 model based on the tokenized input.
- Decodes the generated response from token IDs back to text.
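The context-building step is plain string formatting. The sketch below isolates it so the exact text handed to the T5 tokenizer is visible; the `question: ... context: ...` layout is the one the script uses, while the function name `build_prompt` is hypothetical.

```python
def build_prompt(query, retrieved_docs):
    """Join the retrieved documents and wrap them in the question/context layout."""
    context = " ".join(retrieved_docs)
    return f"question: {query} context: {context}"


prompt = build_prompt("What color is the sky?", ["The sky is blue."])
print(prompt)  # question: What color is the sky? context: The sky is blue.
```

With `top_k` greater than 1, the documents are joined with single spaces in ranked order, so the most relevant text comes first and is the last to be cut if the input is truncated.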

#### Example Usage
The script includes an example usage section where a query ("What color is the sky?") is processed to generate a response. The query is passed to the `generate_response` function, which retrieves relevant documents from the knowledge base and generates an informative response using the T5 model. The response is then printed out, demonstrating the RAG system in action.

```python
# Install necessary packages (uncomment to run inside a notebook)
# !pip install transformers datasets sentencepiece

import numpy as np
import torch
from transformers import DistilBertTokenizer, DistilBertModel, T5Tokenizer, T5ForConditionalGeneration

# 1. Load Smaller Pre-trained Models and Tokenizers
# Using DistilBERT for the retriever
retriever_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
retriever_model = DistilBertModel.from_pretrained('distilbert-base-uncased')

# Using T5 for the generator
generator_tokenizer = T5Tokenizer.from_pretrained('t5-small')
generator_model = T5ForConditionalGeneration.from_pretrained('t5-small')

# 2. Create a Knowledge Base
documents = [
    "The sky is blue.",
    "The sun is bright.",
    "There are many stars in the universe.",
    "AI is transforming the world."
]

# Encode the documents using the retriever: one vector per document,
# obtained by averaging the token embeddings.
# torch.no_grad() skips gradient tracking, which inference does not need.
contexts = retriever_tokenizer(documents, truncation=True, padding=True, return_tensors="pt")
with torch.no_grad():
    context_embeddings = retriever_model(**contexts).last_hidden_state.mean(dim=1)

# 3. Retrieve Relevant Documents
def retrieve_docs(query, top_k=1):
    # Encode the query using the retriever
    query_inputs = retriever_tokenizer(query, return_tensors="pt")
    with torch.no_grad():
        query_embeddings = retriever_model(**query_inputs).last_hidden_state.mean(dim=1)

    # Compute dot product between query and context embeddings; shape (1, num_docs)
    scores = torch.matmul(query_embeddings, context_embeddings.T)

    # Retrieve top k documents; indexing row 0 always yields a list of indices
    top_k_indices = torch.topk(scores, k=top_k, dim=-1).indices[0].tolist()
    return [documents[i] for i in top_k_indices]

# 4. Generate a Response
def generate_response(query):
    retrieved_docs = retrieve_docs(query)
    context = " ".join(retrieved_docs)

    # Tokenize input for T5
    inputs = generator_tokenizer(f"question: {query} context: {context}", return_tensors="pt", max_length=512, truncation=True)

    # Generate response
    output = generator_model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], num_return_sequences=1)

    # Decode token IDs back to text
    return generator_tokenizer.decode(output[0], skip_special_tokens=True)

# Example usage
query = "What color is the sky?"
response = generate_response(query)
print(f"Query: {query}")
print(f"Response: {response}")
```

Output:

```
Query: What color is the sky?
Response: blue
```


