##### Copyright 2024 Google LLC.

In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Gemma - RAG with ChromaDB

This cookbook demonstrates how you can build a minimal-ish Retrieval-Augmented Generation (RAG) system without using any orchestration tool like LangChain or LlamaIndex. It only uses [ChromaDB](https://www.trychroma.com/) as the vector database for storing and querying embeddings.

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/RAG_with_ChomaDB.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

## Setup

### Select the Colab runtime
To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:

1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **T4 GPU**.


### Gemma setup on Hugging Face
This cookbook uses the instruction-tuned Gemma 7B model through Hugging Face, so you will need to:

* Get access to Gemma on [huggingface.co](https://huggingface.co) by accepting the Gemma license on the Hugging Face page of the specific model, i.e., [Gemma 7B IT](https://huggingface.co/google/gemma-7b-it).
* Generate a [Hugging Face access token](https://huggingface.co/docs/hub/en/security-tokens) and configure it as a Colab secret 'HF_TOKEN'.

## Retrieval-Augmented Generation (RAG)

Large Language Models (LLMs) can learn new abilities without being directly trained on them. However, LLMs have been known to "hallucinate" when asked about things they were not trained on, partly because they are unaware of events after their training. It is also very difficult to trace the sources LLMs draw their responses from. For reliable, scalable applications, it is important that an LLM provides responses that are grounded in facts and is able to cite its information sources.

A common approach to overcoming these constraints is Retrieval-Augmented Generation (RAG), which augments the prompt sent to an LLM with relevant data retrieved from an external knowledge base through an Information Retrieval (IR) mechanism. The knowledge base can be your own corpus of documents, databases, or APIs.
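The retrieve, augment, generate flow can be sketched in a few lines of plain Python. In this minimal sketch, `rag_answer`, `keyword_retrieve`, and `echo_generate` are hypothetical helpers: word overlap stands in for embedding-based retrieval and an echo function stands in for the LLM, so the example runs without a model or a database.

```python
import re
from typing import Callable, List


def tokenize(text: str) -> set:
    """Lowercased word set, used by the toy retriever below."""
    return set(re.findall(r"\w+", text.lower()))


def build_augmented_prompt(question: str, passages: List[str]) -> str:
    """Augment the question with retrieved passages, numbered so they can be cited."""
    context = "\n".join(f"{i + 1}. {p}" for i, p in enumerate(passages))
    return f"User question: {question}\n\nContext:\n{context}\n"


def rag_answer(
    question: str,
    retrieve: Callable[[str, int], List[str]],
    generate: Callable[[str], str],
    k: int = 3,
) -> str:
    """Retrieve k passages, augment the prompt with them, then generate an answer."""
    passages = retrieve(question, k)
    prompt = build_augmented_prompt(question, passages)
    return generate(prompt)


# Stand-in components (hypothetical), for illustration only.
DOCS = [
    "Gemma is a family of open models.",
    "ChromaDB stores embeddings.",
]


def keyword_retrieve(question: str, k: int) -> List[str]:
    """Rank documents by word overlap with the question (a stand-in for embedding search)."""
    q = tokenize(question)
    return sorted(DOCS, key=lambda d: len(q & tokenize(d)), reverse=True)[:k]


def echo_generate(prompt: str) -> str:
    """Return the prompt unchanged; a real system would call an LLM here."""
    return prompt


print(rag_answer("What is Gemma?", keyword_retrieve, echo_generate, k=1))
```

Keeping retrieval and generation as swappable functions mirrors the rest of this cookbook, where ChromaDB plays the retriever and Gemma plays the generator.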

### Chunking the data

To improve the relevance of the content returned by the vector database during retrieval, break large documents into smaller pieces, or chunks, when ingesting them.
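HtmlChunker works on the page's HTML structure. As a simplified plain-text analogue, the sketch below packs whole paragraphs (assumed to be separated by blank lines) into chunks of at most `max_words` words, splitting any paragraph longer than that budget. `chunk_by_words` is a hypothetical helper, not part of HtmlChunker; its word budget plays a role similar to `max_words_per_aggregate_passage`.

```python
import re
from typing import List


def chunk_by_words(text: str, max_words: int = 200) -> List[str]:
    """Greedily pack whole paragraphs into chunks of at most max_words words.

    Paragraphs are assumed to be separated by blank lines; a paragraph longer
    than max_words is split on word boundaries.
    """
    chunks: List[str] = []
    current: List[str] = []
    for paragraph in re.split(r"\n\s*\n", text):
        words = paragraph.split()
        # Split an oversized paragraph into pieces of at most max_words words.
        pieces = [words[i:i + max_words] for i in range(0, len(words), max_words)]
        for piece in pieces:
            # Start a new chunk when adding this piece would exceed the budget.
            if current and len(current) + len(piece) > max_words:
                chunks.append(" ".join(current))
                current = []
            current.extend(piece)
    if current:
        chunks.append(" ".join(current))
    return chunks


sample = (
    "Gemma is a family of open models.\n\n"
    "Gemma 2 is still pretraining.\n\n"
    "PaliGemma is a vision-language model."
)
for chunk in chunk_by_words(sample, max_words=10):
    print(chunk)
```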

In this cookbook, you will use the [Google I/O 2024 Gemma family expansion launch blog](https://developers.googleblog.com/en/gemma-family-and-toolkit-expansion-io-2024/) as the sample document and Google's [Open Source HtmlChunker](https://github.com/google/labs-prototypes/tree/main/seeds/chunker-python) to chunk it into passages.

In [None]:
!pip install google-labs-html-chunker

from urllib.request import urlopen

from google_labs_html_chunker.html_chunker import HtmlChunker

# Download the raw HTML of the sample blog post
with urlopen(
    "https://developers.googleblog.com/en/gemma-family-and-toolkit-expansion-io-2024/"
) as f:
    html = f.read().decode("utf-8")

# Chunk the page with HtmlChunker: greedily merge sibling nodes into passages
# of at most 200 words, skipping non-content tags such as scripts and styles
chunker = HtmlChunker(
    max_words_per_aggregate_passage=200,
    greedily_aggregate_sibling_nodes=True,
    html_tags_to_exclude={"noscript", "script", "style"},
)
passages = chunker.chunk(html)



Take a look at what the chunked text looks like.

In [None]:
for passage in passages:
    print(passage)

Introducing PaliGemma, Gemma 2, and an Upgraded Responsible AI Toolkit
            
            
            
            - Google Developers Blog
Products Develop Android Chrome ChromeOS Cloud Firebase Flutter Google Assistant Google Maps Platform Google Workspace TensorFlow YouTube Grow Firebase Google Ads Google Analytics Google Play Search Web Push and Notification APIs Earn AdMob Google Ads API Google Pay Google Play Billing Interactive Media Ads Solutions Events Learn Community Groups Google Developer Groups Google Developer Student Clubs Woman Techmakers Google Developer Experts Tech Equity Collective Programs Accelerator Solution Challenge DevFest Stories All Stories Developer Profile Blog Search English English Español (Latam) Bahasa Indonesia 日本語 한국어 Português (Brasil) 简体中文
Products More Solutions Events Learn Community More Developer Profile Blog Develop Android Chrome ChromeOS Cloud Firebase Flutter Google Assistant Google Maps Platform Google Workspace TensorFlow YouTube G

## Index the chunks with a vector database

You will now use ChromaDB, an open-source embedding database, to index the passages.
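Conceptually, a vector database embeds each document at indexing time, embeds the query at search time, and returns the documents whose vectors are closest to the query's. ChromaDB does this with a learned embedding model and an approximate nearest-neighbor index. The toy sketch below makes those mechanics concrete in pure Python under deliberately simplified assumptions: `embed`, `cosine`, and `ToyVectorIndex` are hypothetical, a bag-of-words count vector stands in for a learned embedding, and a brute-force cosine-similarity scan stands in for Chroma's index and distance metric. It does mimic the nested shape of Chroma's query results (one inner list per query text), which is why the retrieval code reads `results["documents"][0]`.

```python
import math
import re
from collections import Counter
from typing import Dict, List


def embed(text: str) -> Counter:
    """Toy 'embedding': counts of the lowercased words (stand-in for a neural model)."""
    return Counter(re.findall(r"\w+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse term-count vectors."""
    dot = sum(count * b[term] for term, count in a.items())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


class ToyVectorIndex:
    """Hypothetical brute-force index whose add/query shapes mimic a Chroma collection."""

    def __init__(self) -> None:
        self.ids: List[str] = []
        self.documents: List[str] = []
        self.vectors: List[Counter] = []

    def add(self, documents: List[str], ids: List[str]) -> None:
        # Embed each document once, at indexing time.
        for doc_id, doc in zip(ids, documents):
            self.ids.append(doc_id)
            self.documents.append(doc)
            self.vectors.append(embed(doc))

    def query(self, query_texts: List[str], n_results: int = 3) -> Dict[str, List[List[str]]]:
        # One inner list per query text, so the first query's hits are at index [0].
        results: Dict[str, List[List[str]]] = {"ids": [], "documents": []}
        for text in query_texts:
            q = embed(text)
            scores = [cosine(q, v) for v in self.vectors]
            top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n_results]
            results["ids"].append([self.ids[i] for i in top])
            results["documents"].append([self.documents[i] for i in top])
        return results


index = ToyVectorIndex()
index.add(
    documents=[
        "Gemma 2 is still pretraining.",
        "The LLM Comparator is a visual tool for side-by-side evaluation.",
    ],
    ids=["0", "1"],
)
results = index.query(query_texts=["What is the LLM Comparator?"], n_results=1)
print(results["documents"][0])
```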

In [None]:
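!pip install chromadb
import chromadb

# Create an in-memory Chroma client and a collection to hold the passages
chroma_client = chromadb.Client()
collection = chroma_client.create_collection(name="cookbook_collection")
# Add each passage with a unique string ID; when no embeddings are supplied,
# Chroma computes them with the collection's default embedding function
collection.add(documents=passages, ids=[str(i) for i in range(len(passages))])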



Next, retrieve the passages most relevant to the user question from the vector database to use as context, and assemble a prompt from both the user question and the retrieved context.

In [None]:
prompt_template = """You are an expert in answering user questions. You always understand user questions well, and then provide high-quality answers based on the information provided in the context.

If the provided context does not contain relevant information, just respond "I could not find the answer based on the context you provided."

User question: {}

Context:
{}
"""

user_question = "how many parameters does Gemma 2 have?"

# Retrieve the 3 passages most similar to the question
results = collection.query(query_texts=[user_question], n_results=3)

# Chroma returns one list of documents per query text; number the passages
# of the first (and only) query to form the context
context = "\n".join(
    [f"{i+1}. {passage}" for i, passage in enumerate(results["documents"][0])]
)
prompt = prompt_template.format(user_question, context)

Here is the final prompt that will be sent to Gemma.

In [None]:
print(prompt)

You are an expert in answering user questions. You always understand user questions well, and then provide high-quality answers based on the information provided in the context.

If the provided context does not contain relevant information, just respond "I could not find the answer based on the context you provided."

User question: how many parameters does Gemma 2 have?

Context:
1. Gemma 2 is still pretraining. This chart shows performance from the latest Gemma 2 checkpoint along with benchmark pretraining metrics. Source: Hugging Face Open LLM Leaderboard (April 22, 2024) and Grok announcement blog
2. Stay tuned for the official launch of Gemma 2 in the coming weeks! Expanding the Responsible Generative AI Toolkit For this reason we're expanding our Responsible Generative AI Toolkit to help developers conduct more robust model evaluations by releasing the LLM Comparator in open source. The LLM Comparator is a new interactive and visual tool to perform effective side-by-side evaluat

### Generate the answer

Now load the Gemma model in 4-bit quantized mode using Hugging Face Transformers.
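Why 4-bit? A back-of-the-envelope estimate of weight memory is parameters × bits per parameter ÷ 8 bytes. Assuming roughly 8.5 billion parameters for Gemma 7B (counting embeddings), 16-bit weights need about 17 GB, more than the T4's 16 GB, while 4-bit weights need about 4.3 GB. The estimate ignores activations, the KV cache, and layers that stay in higher precision, so treat it as a rough lower bound. `weight_memory_gb` is a hypothetical helper for this arithmetic.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate memory for the model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Assumption: Gemma 7B has roughly 8.5 billion parameters, counting embeddings.
GEMMA_7B_PARAMS = 8.5e9

for bits in (16, 4):
    print(f"{bits}-bit weights: ~{weight_memory_gb(GEMMA_7B_PARAMS, bits):.1f} GB")
```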

In [None]:
!pip install bitsandbytes accelerate
import torch
import transformers
from transformers import AutoTokenizer, BitsAndBytesConfig

model = "google/gemma-1.1-7b-it"

tokenizer = AutoTokenizer.from_pretrained(model)
# Load the weights in 4-bit precision via bitsandbytes so the model fits
# in the memory of a T4 GPU
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    model_kwargs={
        "torch_dtype": torch.float16,
        "quantization_config": BitsAndBytesConfig(load_in_4bit=True),
    },
)



`low_cpu_mem_usage` was None, now set to True since model is quantized.



Finally, generate the answer.
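Before generation, `apply_chat_template` wraps the RAG prompt in the turn markers that Gemma's instruction-tuned models expect. The sketch below approximates that format by hand; `format_gemma_chat` is a hypothetical helper, and the authoritative template is the one shipped with the tokenizer, which may differ in details across versions. It also shows why the generation code slices the pipeline output by the prompt's length: by default, the text-generation pipeline returns the prompt followed by the completion.

```python
from typing import Dict, List


def format_gemma_chat(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
    """Approximate Gemma's chat format (its template has no system role)."""
    text = "<bos>"
    for message in messages:
        # Gemma calls the assistant turn "model".
        role = "model" if message["role"] == "assistant" else message["role"]
        text += f"<start_of_turn>{role}\n{message['content'].strip()}<end_of_turn>\n"
    if add_generation_prompt:
        # Open a model turn so that generation continues as the model's reply.
        text += "<start_of_turn>model\n"
    return text


chat_prompt = format_gemma_chat(
    [{"role": "user", "content": "How many parameters does Gemma 2 have?"}]
)
print(chat_prompt)

# Simulated pipeline output (hypothetical): the prompt followed by the completion.
generated_text = chat_prompt + "Gemma 2 has 27 billion parameters."
print(generated_text[len(chat_prompt):])  # prints only the completion
```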

In [None]:
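messages = [
    {"role": "user", "content": prompt},
]
# Wrap the RAG prompt in Gemma's chat template and open a model turn
chat_prompt = pipeline.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# A low temperature keeps sampling close to deterministic
outputs = pipeline(chat_prompt, max_new_tokens=256, do_sample=True, temperature=0.1)
# generated_text holds the prompt followed by the completion; print only the completion
print(outputs[0]["generated_text"][len(chat_prompt) :])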

Gemma 2 has **27 billion parameters**.

The context explicitly states that Gemma 2 has 27 billion parameters.


Gemma is able to provide the correct answer based on the context.