# Grounding LLM statistical facts using Retrieval Augmented Generation (RAG)

In this notebook, we share promising, early research advancements that tackle the challenge of provenance around real-world statistical data. This notebook connects to DataGemma, the first open-source models designed to connect large language models with the extensive, real-world data housed within Google's Data Commons.

This notebook uses Retrieval Augmented Generation (RAG), an established approach in which relevant information is retrieved from Data Commons before the LLM generates text, giving the LLM a factual foundation for its response. This implementation is only possible because Gemini 1.5 Pro's long context window allows us to append the Data Commons data to the user query. More technical details of this approach can be found in **TODO:(paper link)**.

This demo is based on a fine-tuned Gemma 2 27B model.

Please read the [Gemma Terms of Use](https://ai.google.dev/gemma/terms).

***Disclaimer:***

*You're accessing a very early version of DataGemma. It is meant for trusted tester use (primarily for academic and research use) and not yet ready for commercial or general public use. This version was trained on a very small corpus of examples and may exhibit unintended, and at times controversial or inflammatory, behavior. Please anticipate errors and limitations as we actively develop this large language model interface.*

*Your feedback and evaluations are critical to refining DataGemma's performance and will directly contribute to its training process. Known limitations are detailed in the paper, and we encourage you to consult it for a comprehensive understanding of DataGemma's current capabilities.*

## Step 0: Setup

To run this Colab notebook, you need an A100 GPU with the High-RAM runtime. With this configuration, the notebook takes about 20 minutes to run end to end.

You also need authentication for model and data access:

*   **Hugging Face Token**. To obtain a token, log in to your Hugging Face account and create a new token in [token settings](https://huggingface.co/settings/tokens). Copy the token and store it in the Colab notebook's `Secrets` section under the name `HF_TOKEN`.

*   **Data Commons API Key**. Register for an API key at the Data Commons [API key portal](https://apikeys.datacommons.org). Once you have the key, store it in the Colab notebook's `Secrets` section under the name `DC_API_KEY`.

*   **Gemini 1.5 Pro API Key**. Register for an API key at [Google AI Studio](https://aistudio.google.com/app/apikey). Once you have the key, store it in the Colab notebook's `Secrets` section under the name `GEMINI_API_KEY`.

Toggle the "Notebook access" button to enable the secrets.
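`userdata.get` only works inside a Colab runtime. If you adapt this notebook to run elsewhere (not something the notebook itself sets up), a small helper such as the hypothetical `get_secret` below could try Colab Secrets first and fall back to environment variables with the same names. This is a sketch under that assumption, not part of `data_gemma`.

```python
import os


def get_secret(name: str) -> str:
    """Hypothetical helper: read a secret from Colab Secrets, else from the environment."""
    try:
        # Only importable inside a Colab runtime.
        from google.colab import userdata
        value = userdata.get(name)
    except Exception:
        # Not running in Colab, or the secret is missing / notebook access is off.
        value = os.environ.get(name)
    if not value:
        raise KeyError(f"Secret {name!r} not found in Colab Secrets or environment variables")
    return value
```

For example, `DC_API_KEY = get_secret('DC_API_KEY')` would stand in for the corresponding `userdata.get` call in Step 1.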



Then install the required libraries.

```python
!pip install -q git+https://github.com/datacommonsorg/llm-tools
!pip install -q bitsandbytes accelerate
```

## Step 1: Load the model

This section loads the fine-tuned Gemma 2 27B model from Hugging Face and creates a transformer model wrapper that can be used in the Retrieval Augmented Generation (RAG) workflow. It also sets up the Data Commons client and the Gemini 1.5 Pro model used later in the flow. More technical details of this approach can be found in **TODO:(paper link)**.

```python
import torch

import data_gemma as dg

from google.colab import userdata
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Initialize the Data Commons API client (used to retrieve statistical tables)
DC_API_KEY = userdata.get('DC_API_KEY')
dc = dg.DataCommons(api_key=DC_API_KEY)

# Get the Gemini 1.5 Pro model (used to generate the final answer)
GEMINI_API_KEY = userdata.get('GEMINI_API_KEY')
gemini_model = dg.GoogleAIStudio(model='gemini-1.5-pro', api_keys=[GEMINI_API_KEY])

# Get the fine-tuned Gemma 2 model from Hugging Face (used to generate Data Commons questions)
HF_TOKEN = userdata.get('HF_TOKEN')

# Store the weights in 4-bit NF4 format and run computations in bfloat16
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_name = 'gg-hf/data-gemma-rag-27b-it' # TODO(update)
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             device_map="auto",
                                             quantization_config=nf4_config,
                                             torch_dtype=torch.bfloat16,
                                             token=HF_TOKEN)

# Wrap the fine-tuned model so the RAG flow can call it
hfm = dg.HFBasic(model, tokenizer)
```
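A back-of-the-envelope estimate of weight storage helps explain the 4-bit NF4 configuration. This is a sketch under stated assumptions: exactly 27 billion parameters, weights only (no activations, KV cache, or quantization metadata), and a 40 GB A100, which is an assumption about the Colab A100 variant.

```python
PARAMS = 27e9  # assumption: nominal parameter count of the 27B model
GB = 1e9       # decimal gigabytes


def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Memory needed to store the weights alone, in decimal GB."""
    return num_params * bits_per_param / 8 / GB


bf16_gb = weight_memory_gb(PARAMS, 16)  # 54.0 GB
nf4_gb = weight_memory_gb(PARAMS, 4)    # 13.5 GB

A100_GB = 40  # assumption: 40 GB A100 variant
for label, size in [("bfloat16", bf16_gb), ("nf4", nf4_gb)]:
    verdict = "fits" if size < A100_GB else "does not fit"
    print(f"{label:>8}: ~{size:.1f} GB of weights, {verdict} in {A100_GB} GB")
```

Note that `bnb_4bit_compute_dtype=torch.bfloat16` sets the dtype used for computation; the stored weights remain in 4-bit NF4.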

## Step 2: Pick or enter a query for RAG

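Select a query from the sample list, or enter your own query, to test RAG. Both cells set the same `QUERY` variable, so the cell you run last determines which query is used.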


```python
#@title Pick a query from a sample list { run: "auto" }
QUERY = "Do the US states with high coal fired power also have high rates of COPD?" #@param ["Do the US states with high coal fired power also have high rates of COPD?", "Is obesity in America continuing to increase?", "Which US states have the highest percentage of uninsured children?", "Which US states have the highest cancer rates?", "How have CO2 emissions changed in France over the last 10 years?", "How many US households have individuals over than 65 in them", "Which New Jersey schools have the highest student to teacher ratio?", "Show me a breakdown of income distribution for Seattle.", "Which New Jersey cities have the best commute times for workers?", "If you excluded the SF/Bay Area from California, what would the GDP then be?", "What are the highest paid jobs in Texas?", "Does India have more people living in the urban areas or rural areas?  How does that vary by states?  Are the districts with the most urban population also located in the states with the most urban population?", "Can you find a district in India each where: 1. there are more muslims than hindus or christians or sikhs;  2. more christians than the rest;  3. more sikhs than the rest.", "What are some interesting trends in Sunnyvale spanning gender, age, race, immigration, health conditions, economic conditions, crime and education?", "Which US States are the best environmentally?", "Where are the most violent places in the world?", "Compare Cambridge, MA and Palo Alto, CA in terms of demographics, education, and economy stats.", "What trends can be observed among the countries that are the top consumers, importers and exporters of electricity?", "Give me some farming statistics about Kern county, CA.", "What is the fraction households below poverty status receive food stamps in the US?  How does that vary across states?", "Is there evidence that single-parent families are more likely to be below the poverty line compared to married-couple families in the US?", "At what points in the past did house prices in bay area counties dip?", "What patterns emerge from statistics on safe birth rates across states in India?", "Based on the distribution of foreign language speakers compare the diversity of people in: NYC, Seattle, Austin, Chicago and Tampa", "Are there significant differences in the prevalence of various types of disabilities (such as vision, hearing, mobility, cognitive) between Dallas and Houston?", "Are there states in the US that stand out as outliers in terms of the prevalence of drinking and smoking?", "Has the use of renewables increased globally?", "Has the average lifespan increased globally?"]
```


```python
#@title Use your own query (Please see disclaimer at the top)
QUERY = 'In the US states with the highest foreign language speakers, how does the unemployment rate compare to the national average?' #@param {type:"string"}
```

## Step 3: Run RAG and Print Output


```python
print(f"[QUERY]: {QUERY}\n")
ans = dg.RAGFlow(llm_question=hfm, llm_answer=gemini_model, data_fetcher=dc).query(query=QUERY)
print(ans.answer())
```

Example output (truncated) from a run with the Sunnyvale sample query:

```
[QUERY]: What are some interesting trends in Sunnyvale spanning gender, age, race, immigration, health conditions, economic conditions, crime and education?

... [RAG] Calling FINETUNED model for DC questions
... calling HF Pipeline API ""
Your role is that of a Question Generator.  Give..."
... [RAG] Making DC Calls
... calling DC for table with "What is the percentage of Sunnyvale residents born outside of Sunnyvale private schools?"
... calling DC for table with "What is the percentage of Sunnyvale residents born outside of Sunnyvale school districts?"
... calling DC for table with "What is the percentage of Sunnyvale residents born in Sunnyvale private schools?"
... calling DC for table with "What is the percentage of Sunnyvale residents born in Sunnyvale charter schools?"
... calling DC for table with "What is the percentage of Sunnyvale residents born in Sunnyvale school districts?"
... calling DC for table with "What is the population of Sunnyvale residents born in the US?"... c
```

## More information on Retrieval Augmented Generation (RAG)

Retrieval Augmented Generation (RAG) retrieves relevant information from Data Commons before the LLM generates text, giving the LLM a factual foundation for its response. This implementation is only possible because Gemini 1.5 Pro's long context window allows us to append the Data Commons data to the user query.

Here's how RAG works:

1.   User Query: A user submits a query to the LLM.
2.   Query Analysis & Data Commons Query Generation: The DataGemma model (Gemma 2 27B, fully fine-tuned for this RAG task) analyzes the user's query and generates one or more corresponding natural-language queries that Data Commons' existing natural language interface can understand.
3.   Data Retrieval from Data Commons: Data Commons is queried using this natural language query, and relevant data tables, source information, and links are retrieved.
4.   Augmented Prompt: The retrieved information is added to the original user query, creating an augmented prompt.
5.   Final Response Generation: A larger LLM (Gemini 1.5 Pro) uses this augmented prompt, including the retrieved data, to generate a comprehensive and grounded response.
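The five steps can be sketched in plain Python. This is a minimal, hypothetical illustration of the control flow, not the `data_gemma` implementation: `ask_question_model`, `fetch_table`, and `ask_answer_model` are stand-ins for the fine-tuned DataGemma model, the Data Commons natural language interface, and Gemini 1.5 Pro, and the prompt wording is an assumption.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class RetrievedTable:
    """One Data Commons result: the question asked, the table as text, and its source."""
    question: str
    table: str
    source: str


def rag_answer(
    query: str,
    ask_question_model: Callable[[str], List[str]],
    fetch_table: Callable[[str], Optional[RetrievedTable]],
    ask_answer_model: Callable[[str], str],
) -> str:
    """Hypothetical RAG control flow; each callable stands in for a real component."""
    # Step 2: the fine-tuned model turns the user query into natural-language
    # questions that the Data Commons interface can answer.
    dc_questions = ask_question_model(query)

    # Step 3: query Data Commons once per question, skipping questions with no data.
    tables = []
    for question in dc_questions:
        table = fetch_table(question)
        if table is not None:
            tables.append(table)

    # Step 4: append the retrieved tables to the original query
    # (the prompt wording here is an assumption).
    context = "\n\n".join(
        f"Q: {t.question}\n{t.table}\nSource: {t.source}" for t in tables
    )
    prompt = f"{query}\n\nUse the following Data Commons tables:\n\n{context}"

    # Step 5: the larger model answers from the augmented prompt.
    return ask_answer_model(prompt)
```

Passing the three components in as callables mirrors how `dg.RAGFlow` receives `llm_question`, `llm_answer`, and `data_fetcher` in Step 3, and makes each stage easy to stub out when testing.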

In the example above (whose output is truncated), 14 questions are asked of Data Commons (e.g., "What is the population of Sunnyvale?") and the corresponding data tables are retrieved. The data in these tables is used to compose a final response with coherent information and insight.
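Because the retrieved tables come with source information and links, one way to keep provenance visible to the answer model is to render each table together with its source before appending it to the prompt. The helper names below are hypothetical, and the 4-characters-per-token estimate is a common rule of thumb for English text, not Gemini's actual tokenizer. With a dozen or more tables in one prompt, a rough size check like this can flag prompts that approach the model's documented context limit.

```python
from typing import Sequence


def table_to_markdown(
    title: str,
    header: Sequence[str],
    rows: Sequence[Sequence[object]],
    source_url: str,
) -> str:
    """Hypothetical helper: render one retrieved table as Markdown with its source link."""
    lines = [
        f"**{title}**",
        "",
        "| " + " | ".join(header) + " |",
        "| " + " | ".join("---" for _ in header) + " |",
    ]
    for row in rows:
        lines.append("| " + " | ".join(str(cell) for cell in row) + " |")
    # Keep the provenance next to the data it describes.
    lines.append(f"Source: {source_url}")
    return "\n".join(lines)


def estimate_tokens(text: str) -> int:
    """Rough heuristic (assumption): about 4 characters per token, rounded up."""
    return (len(text) + 3) // 4
```

Each rendered table can then be joined into the augmented prompt, and `estimate_tokens` applied to the full prompt before it is sent.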