In [1]:
# Load all required Libraries
import pandas as pd
import transformers, torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
import gc
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType

  from .autonotebook import tqdm as notebook_tqdm


# Read Passages from the Datasets and Drop rows if they are NA or empty

In [2]:
passages = pd.read_parquet("hf://datasets/rag-datasets/rag-mini-wikipedia/data/passages.parquet/part.0.parquet")

print(passages.shape)
passages.head()

(3200, 1)


Unnamed: 0_level_0,passage
id,Unnamed: 1_level_1
0,"Uruguay (official full name in ; pron. , Eas..."
1,"It is bordered by Brazil to the north, by Arge..."
2,Montevideo was founded by the Spanish in the e...
3,The economy is largely based in agriculture (m...
4,"According to Transparency International, Urugu..."


# Do EDA on the passage dataset
- You can try to find the maximum and minimum length of the passages before indexing (just a direction)

In [3]:
# Code for EDA

# Calculate passage lengths
passages['passage_length'] = passages['passage'].str.len()

# Display distribution of passage lengths
print("\nPassage Length Statistics:")
print(passages['passage_length'].describe())

# Show examples of shortest and longest passages
print(f"\nShortest passage (length = {passages['passage_length'].min()}):")
print(passages.loc[passages['passage_length'].idxmin(), 'passage'])

print(f"\nLongest passage (length = {passages['passage_length'].max()}):")
print(passages.loc[passages['passage_length'].idxmax(), 'passage'][:500] + "..." if passages['passage_length'].max() > 500 else passages.loc[passages['passage_length'].idxmax(), 'passage'])



Passage Length Statistics:
count        3200.0
mean     389.848125
std      348.368869
min             1.0
25%           108.0
50%           299.0
75%           574.0
max          2515.0
Name: passage_length, dtype: Float64

Shortest passage (length = 1):
|

Longest passage (length = 2515):
As Ford approached his ninetieth year, he began to experience significant health problems associated with old age. He suffered two minor strokes at the 2000 Republican National Convention, but made a quick recovery.  Gerald Ford recovering after strokes. BBC, August 2, 2000.  Retrieved on December 31, 2006.  In January 2006, he spent 11 days at the Eisenhower Medical Center near his residence at Rancho Mirage, California, for treatment of pneumonia.  Former President Ford, 92, hospitalized with ...


# Tokenize Text and Generate Embeddings using Sentence Transformers

In [4]:
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Encode Text
embeddings = embedding_model.encode(passages['passage'].tolist())

print(f"Generated embeddings shape: {embeddings.shape}")
print(f"Embedding dimension: {embeddings.shape[1]}")

Generated embeddings shape: (3200, 384)
Embedding dimension: 384


# Create Milvus Client and Insert your Embeddings to your DB

In [5]:
# Define every column of your schema

id_ = FieldSchema(
    name="id",
    dtype=DataType.INT64,
    is_primary=True,
    auto_id=True
)

passage = FieldSchema(
    name="passage",
    dtype=DataType.VARCHAR,
    max_length=65535
)

embedding = FieldSchema(
    name="embedding",
    dtype=DataType.FLOAT_VECTOR,
    dim=384
)

In [6]:
schema = CollectionSchema(
    fields=[id_, passage, embedding],
    description="RAG Wikipedia Mini Collection Schema"
)

In [7]:
client = MilvusClient("../data/rag_wikipedia_mini.db")

# Create the Collection with Collection Name = "rag_mini"
client.create_collection(
    collection_name="rag_mini",
    schema=schema
)

print("Collection 'rag_mini' created.")

  from pkg_resources import DistributionNotFound, get_distribution
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collection 'rag_mini' created.


**Convert your Pandas Dataframe to a list of dictionaries**
- The Dictionary at least have 3 keys [id, passage, embedding]

In [None]:
# Create the dataframe with the required columns

passages_df = pd.DataFrame({
    'passage': passages['passage'],     # Use the passage column from the existing passages dataframe
    'embedding': embeddings.tolist()    # Convert the embeddings to a list format
})

rag_data = passages_df.to_dict('records')

print(rag_data)

In [None]:
# Code to insert the data to your DB
res = client.insert(collection_name="rag_mini", data=rag_data)

print(res)

{'insert_count': 3200, 'ids': [461246114837561344, 461246114837561345, 461246114837561346, 461246114837561347, 461246114837561348, 461246114837561349, 461246114837561350, 461246114837561351, 461246114837561352, 461246114837561353, 461246114837561354, 461246114837561355, 461246114837561356, 461246114837561357, 461246114837561358, 461246114837561359, 461246114837561360, 461246114837561361, 461246114837561362, 461246114837561363, 461246114837561364, 461246114837561365, 461246114837561366, 461246114837561367, 461246114837561368, 461246114837561369, 461246114837561370, 461246114837561371, 461246114837561372, 461246114837561373, 461246114837561374, 461246114837561375, 461246114837561376, 461246114837561377, 461246114837561378, 461246114837561379, 461246114837561380, 461246114837561381, 461246114837561382, 461246114837561383, 461246114837561384, 461246114837561385, 461246114837561386, 461246114837561387, 461246114837561388, 461246114837561389, 461246114837561390, 461246114837561391, 461246114

- Do a Sanity Check on your database 

**Do not delete the below line during your submission**

In [None]:
print("Entity count:", client.get_collection_stats("rag_mini")["row_count"])
print("Collection schema:", client.describe_collection("rag_mini"))

Entity count: 3200
Collection schema: {'collection_name': 'rag_mini', 'auto_id': True, 'num_shards': 0, 'description': 'RAG Wikipedia Mini Collection Schema', 'fields': [{'field_id': 100, 'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'params': {}, 'auto_id': True, 'is_primary': True}, {'field_id': 101, 'name': 'passage', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 65535}}, {'field_id': 102, 'name': 'embedding', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 384}}], 'functions': [], 'aliases': [], 'collection_id': 0, 'consistency_level': 0, 'properties': {}, 'num_partitions': 0, 'enable_dynamic_field': False}


# Steps to Fetch Results
- Read the Question Dataset
- Clean the Question Dataset if necessary (Drop Questions with NaN etc.)
- Convert Each Query to a Vector Embedding (Use the same embedding model you used to embed your document)
- Try for a Single Question First
- Load Collection into Memory after creating Index for Search on your embedding field (This is an essential step before you can search in your db)
- Search and Fetch Top N Results

In [None]:
import pandas as pd

queries = pd.read_parquet("hf://datasets/rag-datasets/rag-mini-wikipedia/data/test.parquet/part.0.parquet")

print("Question Dataset Shape:", queries.shape)
print("\nFirst few rows:")
queries.head()

Question Dataset Shape: (918, 2)

First few rows:


Unnamed: 0_level_0,question,answer
id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,Was Abraham Lincoln the sixteenth President of...,yes
2,Did Lincoln sign the National Banking Act of 1...,yes
4,Did his mother die of pneumonia?,no
6,How many long was Lincoln's formal education?,18 months
8,When did Lincoln begin his political career?,1832


In [None]:
# Check for NaN values and drop if necessary
print("Checking for NaN values in each column:")
print(queries.isnull().sum())

## Drop rows with NaN values if any exist
# queries_clean = queries.dropna()

Checking for NaN values in each column:
question    0
answer      0
dtype: int64


In [None]:
# Convert Each Query to Vector Embedding using the same embedding model
print("Converting queries to embeddings using the same model...")

# Use the same embedding model that was used for passages
query_embeddings = embedding_model.encode(queries['question'].tolist())

print(f"Query embeddings shape: {query_embeddings.shape}")
print(f"Embedding dimension: {query_embeddings.shape[1]}")

Converting queries to embeddings using the same model...
Query embeddings shape: (918, 384)
Embedding dimension: 384


In [None]:
query = queries['question'].iloc[917]     # Your single query
query_embedding = query_embeddings[917]

print(f"Single query: {query}")
print(f"Single query embedding shape: {query_embedding.shape}")
print(f"Single query embedding (first 10 values): {query_embedding[:10]}")


Single query: What happened in 1917?
Single query embedding shape: (384,)
Single query embedding (first 10 values): [-0.00681219  0.05897511 -0.05983922  0.01842949  0.06863772  0.00521405
 -0.0055782   0.0070867  -0.10614451 -0.01090951]


#### Create Index on the embedding column on your DB

In [None]:
index_params = MilvusClient.prepare_index_params()

# Add an index on the embedding field
index_params.add_index(
    field_name="embedding",
    index_type="IVF_FLAT",
    metric_type="COSINE",
    params={"nlist": 1024}
)

# Create the index
try:
    client.create_index(
        collection_name="rag_mini",
        index_params=index_params
    )
    print("Index created successfully on embedding field")
except Exception as e:
    print(f"Index creation result: {e}")

# Load collection into memory (required for search)
client.load_collection(collection_name="rag_mini")
print("Collection loaded into memory")

Index created successfully on embedding field
Collection loaded into memory


In [None]:
# Search the db with query embedding (simplified)

def search_and_fetch_top_n_passages(query_emb, limit):
    """
    Search for similar passages in the vector database
    
    Args:
        query_emb: Query embedding vector
        limit: Number of top results to return
    
    Returns:
        Search results from Milvus
    """
    search_params = {
        "metric_type": "COSINE",
        "params": {"nprobe": 10}
    }
    
    output_ = client.search(
        collection_name="rag_mini",
        data=[query_emb.tolist()],
        anns_field="embedding",
        search_params=search_params,
        limit=limit,
        output_fields=["passage"]
    )
    return output_

# Use the function
output_ = search_and_fetch_top_n_passages(query_embedding, 3)

print("Search results:")
print(f"Number of results: {len(output_[0])}")
for i, result in enumerate(output_[0]):
    print(f"\nResult {i+1}:")
    print(f"Distance: {result['distance']}")
    print(f"Passage: {result['entity']['passage'][:200]}...")  # Show first 200 chars

Search results:
Number of results: 3

Result 1:
Distance: 0.5092912912368774
Passage: President Wilson before Congress, announcing the break in official relations with Germany. February 3, 1917....

Result 2:
Distance: 0.5046647787094116
Passage: After Russia left the war in 1917 following the Bolshevik Revolution the Allies sent troops, presumably, to prevent a German or Bolshevik takeover of allied-provided weapons, munitions and other suppl...

Result 3:
Distance: 0.5011445879936218
Passage: Wilson had ignored the problems of demobilization after the war, and the process was chaotic and violent. Four million soldiers were sent home with little planning, little money, and few benefits.  A ...


## Now get the Context 
- Initially use the first passage ONLY as your context
- In Later Experiments, you must try at least 2 different passage selection strategies (Top 3 / Top 5 / Top 10) and pass to your prompt

In [None]:
# Extract context from search results
# Initially using the first passage only as context

context = output_[0][0]['entity']['passage']
print(f"Context (first passage): {context}")

Context (first passage): President Wilson before Congress, announcing the break in official relations with Germany. February 3, 1917.


**Develop your Prompt**

In [None]:
system_prompt = """You are a helpful assistant that answers questions based on the provided context. 
Use only the information from the context to answer the question. 
If the context doesn't contain enough information to answer the question, say "not enough information".
Be concise and accurate in your response."""

# Extract top n passages as context
top_n_passages = []

n = 1

for i in range(min(n, len(output_[0]))):
    top_n_passages.append(output_[0][i]['entity']['passage'])

context = "\n\t".join(top_n_passages)

prompt = f"""{system_prompt}\n
Context: {context}\n
Question: {query}"""

print(prompt)

You are a helpful assistant that answers questions based on the provided context. 
Use only the information from the context to answer the question. 
If the context doesn't contain enough information to answer the question, say "not enough information".
Be concise and accurate in your response.

Context: President Wilson before Congress, announcing the break in official relations with Germany. February 3, 1917.

Question: What happened in 1917?


# RAG Response for a Single Query

In [None]:
# Load the LLM Model you want to use
# Using a smaller model
model_name = "google/flan-t5-base"

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        dtype=torch.float32
    )
    
    print(f"Loaded model: {model_name}")
    print(f"Model parameters: {model.num_parameters():,}")
    
except Exception as e:
    print(f"Failed to load model: {e}")
    raise

Loaded model: google/flan-t5-base
Model parameters: 247,577,856


In [None]:
# Generate answer with proper memory management
try:
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_length=150,
            num_beams=4,
            early_stopping=True,
            do_sample=False
        )
    
    # Decode and extract answer
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Query: {query}")
    print(f"Context: {context}")
    print(f"Generated Answer: {answer}")
    
    # Clear memory after generation
    del inputs, outputs
    gc.collect()
        
except Exception as e:
    print(f"Answer generation failed: {e}")
    raise

Query: What happened in 1917?
Context: President Wilson before Congress, announcing the break in official relations with Germany. February 3, 1917.
Generated Answer: break in official relations with Germany.


# Generate Responses for all the Queries in the Dataset

In [None]:
# Generate responses for all queries in the dataset
generated_answers = pd.read_csv('../results/rag_generated_answers.csv')['generated_answer'].tolist()
context_list = pd.read_csv('../results/rag_generated_answers.csv')['top_1_context'].tolist()

system_prompt = """You are a helpful assistant that answers questions based on the provided context. 
Use only the information from the context to answer the question. 
If the context doesn't contain enough information to answer the question, say so.
Be concise and accurate in your response."""

n = 3
max_queries = len(queries) # 3  # Limit for testing

for idx, (question_item, embedding_item) in enumerate(zip(queries['question'][:max_queries], query_embeddings[:max_queries])):
    print(f"Processing query {idx + 1}/{max_queries}")
        
    if generated_answers[idx] != "":
        print(f"Skipping query {idx + 1} since it has a generated answer.")
        continue

    try:
        search_results = search_and_fetch_top_n_passages(embedding_item, n)

        # Extract top n passages as context
        top_n_passages = []
        for i in range(min(n, len(search_results[0]))):
            top_n_passages.append(search_results[0][i]['entity']['passage'])

        context = "\n\t".join(top_n_passages)

        prompt = f"""{system_prompt}\n
        Context: {context}\n
        Question: {question_item}"""
        
        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                max_length=150,
                num_beams=4,
                early_stopping=True,
                do_sample=False
            )
        
        # Decode the generated answer
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_answers.append(answer)
        context_list.append(search_results[0][0]['entity']['passage'])
        
        # Clear memory after each query
        del inputs, outputs
        gc.collect()
        
        print(f"Completed query {idx + 1}/{max_queries}")
        
    except Exception as e:
        print(f"Failed to process query {idx + 1}: {e}")
        generated_answers.append("Error generating answer")
        continue

# Add generated answers to the queries dataframe
queries['generated_answer'] = generated_answers + [""] * (len(queries) - len(generated_answers))
queries['top_1_context'] = context_list
print(f"Generated answers for {len(generated_answers)} queries")

# this took 23:50 so there has to be a better way

Processing query 1/918
Skipping query 1 since it has a generated answer.
Processing query 2/918
Skipping query 2 since it has a generated answer.
Processing query 3/918
Skipping query 3 since it has a generated answer.
Processing query 4/918
Skipping query 4 since it has a generated answer.
Processing query 5/918
Skipping query 5 since it has a generated answer.
Processing query 6/918
Skipping query 6 since it has a generated answer.
Processing query 7/918
Skipping query 7 since it has a generated answer.
Processing query 8/918
Skipping query 8 since it has a generated answer.
Processing query 9/918
Skipping query 9 since it has a generated answer.
Processing query 10/918
Skipping query 10 since it has a generated answer.
Processing query 11/918
Skipping query 11 since it has a generated answer.
Processing query 12/918
Skipping query 12 since it has a generated answer.
Processing query 13/918
Skipping query 13 since it has a generated answer.
Processing query 14/918
Skipping query 14 s

In [None]:
# Save results to CSV
queries.to_csv("../results/rag_generated_answers.csv", index=False)

In [None]:
queries[['question', 'answer', 'generated_answer','top_1_context']].head(100)

Unnamed: 0_level_0,question,answer,generated_answer,top_1_context
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,Was Abraham Lincoln the sixteenth President of...,yes,yes.,Young Abraham Lincoln
2,Did Lincoln sign the National Banking Act of 1...,yes,House of Representatives.,Lincoln believed in the Whig theory of the pre...
4,Did his mother die of pneumonia?,no,No.,An autopsy performed after his death revealed ...
6,How many long was Lincoln's formal education?,18 months,18 months.,Lincoln's formal education consisted of about ...
8,When did Lincoln begin his political career?,1832,1832.,"Lincoln began his political career in 1832, at..."
...,...,...,...,...
221,Was Coolidge the thirteenth President of the U...,No,no,"Coolidge with his Vice President, Charles G. D..."
223,Was Calvin Coolidge Republican?,Yes,no.,Calvin Coolidge as a young legislator
225,Was Calvin Coolidge a governor of MassachuS08_...,Yes,yes,Calvin Coolidge as a young legislator
227,When was Coolidge born?,July 4 1872,"January 3, 1921.",Calvin Coolidge as a young legislator
