<a href="https://colab.research.google.com/github/RajuDasa/llm_engineering/blob/week5_branch/week5/community-contributions/raju/RAG_exercise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## RAG using Gemini models, a Hugging Face dataset, and ChromaDB

In [None]:
!pip install -q chromadb

In [None]:
#sample datasets - https://huggingface.co/datasets/rag-datasets/rag-mini-wikipedia
#ref code - "https://github.com/" + "google-gemini/cookbook/blob/main/examples/" + "chromadb/Vectordb_with_chroma.ipynb"
#Colab notebook - make sure HF_TOKEN and GEMINI_API_KEY are present in secrets

from google import genai
from google.genai import types
from google.colab import userdata
from chromadb import EmbeddingFunction, Client
from datasets import load_dataset, load_dataset_builder
from huggingface_hub import login
from sentence_transformers import SentenceTransformer  #direct version

#hf_token = userdata.get('HF_TOKEN')
#login(hf_token, add_to_git_credential=True) #HF

# Model identifiers used by the embedding function and the answer generator.
EMBEDDING_MODEL_ID = "gemini-embedding-001"
LLM_MODEL_ID = "gemini-2.5-flash"  # alternative: "gemini-2.5-flash-lite"

# Authenticate the Gemini client with the key stored in Colab secrets.
GEMINI_API_KEY = userdata.get('GEMINI_API_KEY')
gemini = genai.Client(api_key=GEMINI_API_KEY)

**Load Datasets:**

In [None]:
#ds_builder = load_dataset_builder("rag-datasets/rag-mini-wikipedia", "text-corpus")
#ds_builder.info.features  # {'passage': Value('string'), 'id': Value('int64')}
#ds_builder.info.description # empty

In [None]:
# The HF corpus has ~3,200 passages in total; only the first 500 are used, to
# keep indexing short. Records have the features 'passage' and 'id'.
ds = load_dataset(
    "rag-datasets/rag-mini-wikipedia",
    "text-corpus",
    split="passages",
    streaming=True,
)
ds_delta = list(ds.take(500))

In [None]:
# Download the first 100 Q&A pairs (of 918 in total) for testing and evaluation.
# Records have the features 'question', 'answer' and 'id'.
ds_qa = load_dataset(
    "rag-datasets/rag-mini-wikipedia",
    "question-answer",
    split="test",
    streaming=True,
)
dsqa_delta = list(ds_qa.take(100))

In [None]:
# Peek at the last five downloaded Q&A records.
dsqa_delta[-5:]

In [None]:
#filtered_ds_delta = [record for record in ds_delta if 'Coolidge' in record.get('passage', '')]
#filtered_ds_delta[:5]

**Prepare vector store:**

In [None]:
# Gemini-API embedding function: higher quality, but rate-limited / billed.
class CustomEmbeddingFunction(EmbeddingFunction):
  """Chroma embedding function backed by the Gemini embedding API.

  On a free account the API can fail with
  `429 RESOURCE_EXHAUSTED. limit: 100` when many passages are embedded.

  Args:
    embed_model_id: Gemini embedding model id (e.g. EMBEDDING_MODEL_ID).
  """

  def __init__(self, embed_model_id):
    self.model_id = embed_model_id
    self.title = "wiki query"
    # NOTE(review): the same RETRIEVAL_DOCUMENT task type is also used when
    # Chroma embeds query texts through this object; Gemini has a separate
    # RETRIEVAL_QUERY task type that may suit queries better — confirm.
    self.task_type = "RETRIEVAL_DOCUMENT"

  def __call__(self, docs):
    """Embed `docs` and return one embedding vector per input text, in order."""
    response = gemini.models.embed_content(
        model=self.model_id,
        contents=docs,
        config=types.EmbedContentConfig(
          task_type=self.task_type,
          title=self.title
        )
    )
    # Chroma expects exactly one embedding per document. Returning only
    # response.embeddings[0] broke every batch with more than one document.
    return [embedding.values for embedding in response.embeddings]

In [None]:
#using small model in local (free)
class AllMiniEmbeddingFunction(EmbeddingFunction):
  """Chroma embedding function using the local `all-MiniLM-L6-v2` model.

  Runs locally through sentence-transformers, so there are no API limits or
  costs. Embeddings are returned normalized (`normalize_embeddings=True`).
  """

  def __init__(self):
    # The model is loaded once, when this function object is created.
    self.model = SentenceTransformer("all-MiniLM-L6-v2")

  def __call__(self, docs):
    """Return the normalized embeddings produced by `encode` for `docs`."""
    return self.model.encode(docs, normalize_embeddings=True)


In [None]:
def get_chroma_db(documents, name, embedding_function=None, batch_size=100):
  """Create (or reuse) a Chroma collection and index `documents` into it.

  Args:
    documents: iterable of records with a 'passage' (text) and an 'id' key.
    name: collection name; an existing collection with this name is reused.
    embedding_function: Chroma embedding function to use; defaults to a new
      AllMiniEmbeddingFunction() (local model, no API cost).
    batch_size: number of passages written (and embedded) per call.
      NOTE(review): keep this within the embedding backend's batch limit.

  Returns:
    (collection, client) tuple.

  Raises:
    ValueError: if batch_size is smaller than 1.
  """
  if batch_size < 1:
    raise ValueError(f"batch_size must be >= 1, got {batch_size}")
  if embedding_function is None:
    embedding_function = AllMiniEmbeddingFunction()

  client = Client()
  db = client.get_or_create_collection(
      name=name,
      embedding_function=embedding_function,
  )

  # De-duplicate by id (last record wins, like upsert): Chroma rejects
  # duplicate ids inside a single write call.
  passages_by_id = {str(record['id']): record['passage'] for record in documents}
  ids = list(passages_by_id)
  passages = list(passages_by_id.values())

  # Write in batches instead of one call per passage: far fewer embedding
  # requests (the per-passage loop exhausted the Gemini free-tier limit).
  # upsert instead of add keeps re-running this cell idempotent when the
  # collection already holds these ids.
  for start in range(0, len(ids), batch_size):
    db.upsert(
        ids=ids[start:start + batch_size],
        documents=passages[start:start + batch_size],
    )
  return db, client

In [None]:
# Index the sampled passages into the "Wiki_DB" collection; `db` is read as a
# global by get_context().
db, client = get_chroma_db(ds_delta, "Wiki_DB")

In [None]:
#clear coll to start fresh
#Client().delete_collection(name="Wiki_DB")

In [None]:
def get_context(query, n_results=8, collection=None):
  """Retrieve the passages closest to `query` from the vector store.

  Args:
    query: question text to search with.
    n_results: maximum number of passages to return (default 8).
    collection: Chroma collection to search; defaults to the global `db`.

  Returns:
    List of passage strings in the order returned by Chroma; [] if none.
  """
  if collection is None:
    collection = db
  result = collection.query(query_texts=[query], n_results=n_results)
  # 'documents' holds one list per query text; exactly one query was sent.
  # It can be None (e.g. when documents are not included), so fall back to [].
  documents = result.get('documents') or [[]]
  return documents[0] or []


In [None]:
# If the topic is not covered by the indexed passages, retrieval still returns
# the nearest (irrelevant) passages.
get_context('Do beetles antennae function primarily as organs of smell')

**Connect with Model:**

In [None]:
def generate_prompt(**kwargs):
  """Fill the grounded-QA prompt template.

  Expects keyword arguments `query` and `context` (KeyError if either is
  missing); other keyword arguments are ignored. Returns the prompt text.
  """
  template_lines = [
    "",
    "    You are a helpful assistant who answers questions using the context provided below.",
    "    Respond with short answer, e.g: 'yes', 'no', '18 months'.",
    "    Only use the provided context as grounding true source.",
    "    If no context is provided or it is irrelevant, respond with - <I don't know>",
    "",
    "    QUESTION: '{query}'",
    "    CONTEXT: '{context}'",
    "",
    "    ANSWER:",
    "  ",
  ]
  # Only the template is scanned for {placeholders}; braces inside the
  # substituted query/context values are inserted verbatim.
  template = "\n".join(template_lines)
  return template.format(**kwargs)

In [None]:
def get_answer_for(question):
  """Answer `question` with the LLM, grounded on retrieved passages.

  Retrieved passages (get_context) are joined with newlines, inserted into the
  prompt from generate_prompt, and sent to LLM_MODEL_ID. Returns the response's
  `.text`. NOTE(review): this may be None when the response carries no text
  part — confirm how callers should handle that.
  """
  passages = "\n".join(get_context(question))
  response = gemini.models.generate_content(
      model=LLM_MODEL_ID,
      contents=generate_prompt(query=question, context=passages),
  )
  return response.text

**Test and Evaluate RAG:**


In [None]:
# Manual test: ask one question end to end (retrieve, prompt, generate).
get_answer_for("Who or what vary greatly in form within the coleoptera?")

In [None]:
import random

#Auto test - 5 tests per topic
def evaluate(qa_data, limit=5, answer_fn=None):
  """Run the RAG pipeline on the first `limit` Q&A records and print results.

  Args:
    qa_data: list of dicts with 'id', 'question' and 'answer' keys.
    limit: maximum number of records to evaluate (default 5).
    answer_fn: callable mapping a question to a generated answer; defaults to
      get_answer_for.
  """
  if not qa_data:
    print("Empty Q&A set")
    return

  if answer_fn is None:
    answer_fn = get_answer_for

  # print (not display): display() shows a str's repr, so the "\n" separators
  # appeared literally. Single quotes inside the f-string keep it valid on
  # Python < 3.12 (nested same-type quotes are a SyntaxError there).
  for qa in qa_data[:max(0, limit)]:
    generated = answer_fn(qa['question'])
    print(qa['id'])
    print(f"Question: {qa['question']}\n Actual: {qa['answer']}\n Generated: {generated}\n\n")


In [None]:
# Many downloaded questions refer to topics absent from the indexed passages
# (ds_delta). Evaluate only questions mentioning TOPIC, which is expected to
# appear in the indexed passages.
TOPIC = "Coolidge"
qa_data = list(filter(lambda record: TOPIC in record.get('question', ''), dsqa_delta))
evaluate(qa_data)