# Data Setup and Embedding Generation

This notebook handles:
1. Loading the Wikipedia passages and questions
2. Generating embeddings for passages
3. Setting up the Milvus vector database
4. Creating the database schema and indexes

**Run this notebook first before running naive_rag.ipynb or enhanced_rag.ipynb**


In [1]:
# Load all required Libraries
import pandas as pd
import transformers, torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModelForCausalLM
import gc
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType
from sentence_transformers import SentenceTransformer
import numpy as np
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# Load Passages and Questions

In [2]:
# Load passages from the dataset
passages = pd.read_parquet("hf://datasets/rag-datasets/rag-mini-wikipedia/data/passages.parquet/part.0.parquet")

print(f"Loaded {len(passages)} passages")
print(f"Passage shape: {passages.shape}")
print("\nFirst few passages:")
passages.head()

Loaded 3200 passages
Passage shape: (3200, 1)

First few passages:


Unnamed: 0_level_0,passage
id,Unnamed: 1_level_1
0,"Uruguay (official full name in ; pron. , Eas..."
1,"It is bordered by Brazil to the north, by Arge..."
2,Montevideo was founded by the Spanish in the e...
3,The economy is largely based in agriculture (m...
4,"According to Transparency International, Urugu..."


In [3]:
# Load questions from the dataset
queries = pd.read_parquet("hf://datasets/rag-datasets/rag-mini-wikipedia/data/test.parquet/part.0.parquet")

print(f"Loaded {len(queries)} questions")
print(f"Question shape: {queries.shape}")
print("\nFirst few questions:")
queries.head()

Loaded 918 questions
Question shape: (918, 2)

First few questions:


Unnamed: 0_level_0,question,answer
id,Unnamed: 1_level_1,Unnamed: 2_level_1
0,Was Abraham Lincoln the sixteenth President of...,yes
2,Did Lincoln sign the National Banking Act of 1...,yes
4,Did his mother die of pneumonia?,no
6,How many long was Lincoln's formal education?,18 months
8,When did Lincoln begin his political career?,1832


In [4]:
# Initialize embedding model
print("Loading sentence transformer model...")
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
print(f"Model loaded: {embedding_model.get_sentence_embedding_dimension()} dimensions")

Loading sentence transformer model...
Model loaded: 384 dimensions


In [5]:
# Generate embeddings for passages
print("Generating embeddings for passages...")
passage_texts = passages['passage'].tolist()

# Generate embeddings in batches to manage memory
batch_size = 100
passage_embeddings = []

for i in tqdm(range(0, len(passage_texts), batch_size), desc="Processing passages"):
    batch = passage_texts[i:i+batch_size]
    batch_embeddings = embedding_model.encode(batch, show_progress_bar=False)
    passage_embeddings.extend(batch_embeddings)
    
    # Clear memory periodically
    if i % 500 == 0:
        gc.collect()

passage_embeddings = np.array(passage_embeddings)
print(f"Generated passage embeddings: {passage_embeddings.shape}")
print(f"Embedding dimension: {passage_embeddings.shape[1]}")


Generating embeddings for passages...


Processing passages: 100%|██████████| 32/32 [00:49<00:00,  1.54s/it]

Generated passage embeddings: (3200, 384)
Embedding dimension: 384





In [6]:
# Generate embeddings for queries
print("Generating embeddings for queries...")
query_texts = queries['question'].tolist()

# Generate embeddings in batches
batch_size = 50
query_embeddings = []

for i in tqdm(range(0, len(query_texts), batch_size), desc="Processing queries"):
    batch = query_texts[i:i+batch_size]
    batch_embeddings = embedding_model.encode(batch, show_progress_bar=False)
    query_embeddings.extend(batch_embeddings)
    
    # Clear memory periodically
    if i % 200 == 0:
        gc.collect()

query_embeddings = np.array(query_embeddings)
print(f"Generated query embeddings: {query_embeddings.shape}")
print(f"Embedding dimension: {query_embeddings.shape[1]}")


Generating embeddings for queries...


Processing queries: 100%|██████████| 19/19 [00:05<00:00,  3.63it/s]

Generated query embeddings: (918, 384)
Embedding dimension: 384





In [7]:
# Setup Milvus Vector Database
# Define database schema
id_ = FieldSchema(
    name="id",
    dtype=DataType.INT64,
    is_primary=True,
    auto_id=True
)

passage = FieldSchema(
    name="passage",
    dtype=DataType.VARCHAR,
    max_length=65535
)

embedding = FieldSchema(
    name="embedding",
    dtype=DataType.FLOAT_VECTOR,
    dim=384  # all-MiniLM-L6-v2 dimension
)

schema = CollectionSchema(
    fields=[id_, passage, embedding],
    description="RAG Wikipedia Mini Collection Schema"
)

print("Database schema defined:")
print(f"- ID field: {id_.name} ({id_.dtype})")
print(f"- Passage field: {passage.name} ({passage.dtype})")
print(f"- Embedding field: {embedding.name} ({embedding.dtype}, dim={embedding.dim})")

Database schema defined:
- ID field: id (5)
- Passage field: passage (21)
- Embedding field: embedding (101, dim=384)


In [8]:
# Create Milvus client and collection
client = MilvusClient("../data/processed/rag_wikipedia_mini.db")

# Create the collection
client.create_collection(
    collection_name="rag_mini",
    schema=schema
)

print("Collection 'rag_mini' created successfully!")

  from pkg_resources import DistributionNotFound, get_distribution
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collection 'rag_mini' created successfully!


In [9]:
# Prepare data for insertion
print("Preparing data for database insertion...")

# Create dataframe with required columns
passages_df = pd.DataFrame({
    'passage': passages['passage'],
    'embedding': passage_embeddings.tolist()
})

# Convert to list of dictionaries
rag_data = passages_df.to_dict('records')
print(f"Prepared {len(rag_data)} records for insertion")


Preparing data for database insertion...
Prepared 3200 records for insertion


In [10]:
# Insert data into database
print("Inserting data into Milvus database...")

# Insert in batches to manage memory
batch_size = 100
total_inserted = 0

for i in tqdm(range(0, len(rag_data), batch_size), desc="Inserting data"):
    batch = rag_data[i:i+batch_size]
    
    try:
        result = client.insert(collection_name="rag_mini", data=batch)
        total_inserted += result['insert_count']
        
        # Clear memory periodically
        if i % 500 == 0:
            gc.collect()
            
    except Exception as e:
        print(f"Error inserting batch {i//batch_size + 1}: {e}")
        continue

print(f"\nSuccessfully inserted {total_inserted} records")


Inserting data into Milvus database...


Inserting data: 100%|██████████| 32/32 [00:01<00:00, 22.31it/s]


Successfully inserted 3200 records





In [11]:
# Create index for efficient searching
print("Creating index for efficient searching...")

index_params = MilvusClient.prepare_index_params()

# Add an index on the embedding field
index_params.add_index(
    field_name="embedding",
    index_type="IVF_FLAT",
    metric_type="COSINE",
    params={"nlist": 1024}
)

# Create the index
try:
    client.create_index(
        collection_name="rag_mini",
        index_params=index_params
    )
    print("Index created successfully on embedding field")
except Exception as e:
    print(f"Index creation result: {e}")

# Load collection into memory (required for search)
client.load_collection(collection_name="rag_mini")
print("Collection loaded into memory")


Creating index for efficient searching...
Index created successfully on embedding field
Collection loaded into memory


In [12]:
# Save embeddings for reuse
print("Saving embeddings for reuse...")

# Save query embeddings
np.save("../data/processed/query_embeddings.npy", query_embeddings)
print(f"Query embeddings saved: {query_embeddings.shape}")

# Save passage embeddings
np.save("../data/processed/passage_embeddings.npy", passage_embeddings)
print(f"Passage embeddings saved: {passage_embeddings.shape}")

# Save queries dataframe
queries.to_csv("../data/processed/queries.csv", index=False)
print("Queries dataframe saved")

print("\n✅ Data setup completed successfully!")
print("You can now run naive_rag.ipynb or enhanced_rag.ipynb")


Saving embeddings for reuse...
Query embeddings saved: (918, 384)
Passage embeddings saved: (3200, 384)
Queries dataframe saved

✅ Data setup completed successfully!
You can now run naive_rag.ipynb or enhanced_rag.ipynb
