# Grounding LLM statistics facts using Retrieval Interleaved Generation (RIG)

In this notebook, we share promising early research that tackles the challenge of provenance for real-world statistical data. The notebook uses DataGemma, the first open-source models designed to connect large language models (LLMs) with the extensive real-world data housed in Google's Data Commons.

This novel approach fine-tunes Gemma 2 to recognize when it needs to replace a generated number with more accurate information from Data Commons. Think of it as the model double-checking its work against a trusted source. More technical details of this approach can be found in **TODO:(paper link)**.

This demo is based on a fine-tuned Gemma 2 27B model.

Please read [Gemma Terms of Use](https://ai.google.dev/gemma/terms).

***Disclaimer:***

*You're accessing a very early version of DataGemma. It is meant for trusted tester use (primarily for academic and research use) and not yet ready for commercial or general public use. This version was trained on a very small corpus of examples and may exhibit unintended, and at times controversial or inflammatory behavior. Please anticipate errors and limitations as we actively develop this large language model interface.*

*Your feedback and evaluations are critical to refining DataGemma's performance and will directly contribute to its training process. Known limitations are detailed in the paper, and we encourage you to consult it for a comprehensive understanding of DataGemma's current capabilities.*

## Step 0: Setup

To run this Colab notebook, you need credentials for two services:

*   **Hugging Face Token**. To obtain a token, log in to your Hugging Face account and open [token settings](https://huggingface.co/settings/tokens) to create a new token. Copy the token and store it in the Colab notebook's `Secrets` section under the name `HF_TOKEN`.

*   **Data Commons API Key**. Register for an API key at the Data Commons [API key portal](https://apikey.datacommons.org). Once you have the key, store it in the Colab notebook's `Secrets` section under the name `DC_API_KEY`.
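
If you run the notebook outside Colab, `google.colab.userdata` cannot be imported. The snippet below is a minimal sketch of a helper that reads a secret from Colab's `Secrets` panel and otherwise falls back to an environment variable with the same name (`HF_TOKEN`, `DC_API_KEY`). The helper name `get_secret` and the environment-variable fallback are our own illustrative assumptions, not part of DataGemma.

```python
import os
from typing import Optional


def get_secret(name: str) -> Optional[str]:
    """Read a secret from Colab's Secrets panel, or from an env var outside Colab.

    The environment-variable fallback is an illustrative assumption,
    not something the DataGemma notebook itself does.
    """
    try:
        from google.colab import userdata  # importable only inside Colab
    except ImportError:
        # Not running in Colab: look for an environment variable of the same name.
        return os.environ.get(name)
    return userdata.get(name)


# Returns None outside Colab if the environment variable is not set.
HF_TOKEN = get_secret("HF_TOKEN")
DC_API_KEY = get_secret("DC_API_KEY")
```

With this helper, the `userdata.get(...)` calls in Step 1 could be swapped for `get_secret(...)` when running locally.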

Then install the required libraries.

In [1]:
!pip install -q git+https://github.com/shifucun/dc-llm-tools@fix
!pip install -q bitsandbytes accelerate

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for data_gemma (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.5/137.5 MB[0m [31m15.9 MB/s[0m eta [36m0:00:00[0m
[?25h

## Step 1: Load the model

This section loads the fine-tuned Gemma 2 27B model from Hugging Face and creates a transformer model wrapper that can be used in the Retrieval Interleaved Generation (RIG) workflow. More technical details of this approach can be found in **TODO:(paper link)**.
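
The cell below loads the weights with 4-bit NF4 quantization (`nf4_config`). As a rough, back-of-the-envelope sketch of why that matters, the snippet estimates the memory needed for the weights alone, assuming a nominal 27 billion parameters. It ignores quantization constants, activations, and the KV cache, so real usage will be higher; the function name is hypothetical.

```python
NOMINAL_PARAMS = 27e9  # assumption: rounded nominal parameter count of Gemma 2 27B
GIB = 1024 ** 3        # bytes per GiB


def weight_memory_gib(num_params: float, bits_per_param: float) -> float:
    """Approximate memory for model weights alone, in GiB."""
    return num_params * bits_per_param / 8 / GIB


# Compare unquantized bfloat16 weights with 4-bit NF4 weights.
for label, bits in [("bfloat16", 16), ("4-bit NF4", 4)]:
    print(f"{label:>10}: ~{weight_memory_gib(NOMINAL_PARAMS, bits):.1f} GiB")
```

Under these assumptions, bfloat16 weights need roughly 50 GiB versus roughly 13 GiB for NF4, about a 4x reduction; treat both figures as lower bounds.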

In [2]:
import torch

import data_gemma as dg

from google.colab import userdata
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Initialize Data Commons API client
DC_API_KEY = userdata.get('DC_API_KEY')
dc = dg.DataCommons(api_key=DC_API_KEY)


# Get finetuned Gemma2 model from HuggingFace
HF_TOKEN = userdata.get('HF_TOKEN')

nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_compute_dtype=torch.bfloat16
)

model_name = 'gg-hf/data-gemma-rig-27b-it' # TODO(update)
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             device_map="auto",
                                             quantization_config=nf4_config,
                                             torch_dtype=torch.bfloat16,
                                             token=HF_TOKEN)

# Build the LLM Model stub to use in RIG flow
hfm = dg.HFBasic(model, tokenizer)

## Step 2: Pick or enter a query for RIG

You can select a query from the sample list or enter your own query to test RIG. If you run both cells below, the one run last sets `QUERY`.


In [5]:

#@title Pick a query from a sample list { run: "auto" }
QUERY = "What progress has Pakistan made against health goals?" #@param ["What percentage of the Indian population lives in slums?","In which countries are more women getting college degrees than men?","What is the percentage of the financial sector in GDP for different countries like the United States and China, based on the latest data?","Which New Jersey Towns have the best commute times for workers?","Does India have more people living in the urban areas or rural areas?  How does that vary by states?  Are the districts with the most urban population also located in the states with the most urban population?","What are some interesting trends in Sunnyvale spanning gender, age, race, immigration, health conditions, economic conditions, crime and education?","In the US states with the highest foreign language speakers, how does the unemployment rate compare to the national average?","There is probably a delay between a high percentage of females attending school and females being in elected positions in the government. Can you show me what that delay looks like in different African countries?","Which US counties share a very similar demographic composition to the US overall in terms of gender, age and racial breakdown?","Compare Cambridge, MA and Palo Alto, CA in terms of demographics, education, and economy stats.","How does the size of household compare across counties in Utah vs. California?  Does it change between owned vs. rental properties?","When comparing median ages across different racial groups in various cities in California, what differences emerge?","Based on the distribution of foreign language speakers compare the diversity of people in: NYC, Seattle, Austin, Chicago and Tampa","Are there significant differences in the prevalence of various types of disabilities (such as vision, hearing, mobility, cognitive) between Dallas and Houston?","It might be expected that countries producing the most food waste have a lower prevalence of undernourishment. Does the data support this?","Are there countries in the world where the forest area has actually increased?","What progress has Pakistan made against health goals?","Does an increase in female participation in education result in a higher number of women holding political office?","Which countries have the highest life expectancy?","Which countries have the lowest poverty rates?","Which countries have the highest GDP?","Has the use of renewables increased globally?","Has the average lifespan increased globally?"]


In [3]:
#@title Use your own query (Please see disclaimer at the top)
QUERY = 'In the US states with the highest foreign language speakers, how does the unemployment rate compare to the national average?' #@param {type:"string"}

## Step 3: Run RIG and Print Output


In [6]:
print(f"[QUERY]: {QUERY}\n")
ans = dg.RIGFlow(llm=hfm,
                 data_fetcher=dc,
                 in_context=False).query(query=QUERY)

print(ans.answer())

[QUERY]: What progress has Pakistan made against health goals?

... [RIG] Calling FINETUNED Model
... calling HF Pipeline API "What progress has Pakistan made against health goa..."
... calling DC with "what was the life expectancy in Pakistan in 2000?"
... calling DC with "what was the life expectancy in Pakistan in 2020?"
... calling DC with "what was the maternal mortality rate in Pakistan in 2000?"
... calling DC with "what was the maternal mortality rate in Pakistan in 2018?"
... [RIG] Calling DC Evaluate


Pakistan has made some progress against its health goals, but significant challenges remain. 

**Here are some key points:**

**Progress made:**

* **Increased life expectancy:** Life expectancy at birth has increased from [__DC__#1(62.102 yr [1] || 61.8 years)] in 2000 to [__DC__#2(66.269 yr [2] || 67.2 years)] in 2020.
* **Reduced maternal mortality:** Maternal mortality ratio has declined from [__DC__#3(387.3715 Per 100,000 live births [3]* || 276 per 100,000 live births)] i

# More Information on Retrieval Interleaved Generation (RIG)

Retrieval Interleaved Generation (RIG): This novel approach fine-tunes Gemma 2 to recognize when it needs to replace a generated number with more accurate information from Data Commons. Think of it as the model double-checking its work against a trusted source.

Here's how RIG works:
1. User Query: A user submits a query to the LLM.
2. Initial Response & Data Commons Query: The DataGemma model (based on the 27-billion-parameter (27B) Gemma 2 model and fully fine-tuned for this RIG task) generates a response that includes one or more natural language queries for Data Commons' existing natural language interface, each phrased specifically to retrieve the relevant data.
3. Data Retrieval & Correction: Data Commons is queried, and the data are retrieved. These data, along with source information and a link, are then used to replace potentially inaccurate numbers in the initial response.
4. Final Response with Source Link: The final response is presented to the user, including a link to the source data and metadata in Data Commons for transparency and verification.
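
The steps above can be illustrated with a toy, self-contained sketch. Everything in it is hypothetical: the marker syntax `[__DC__("query") --> "llm value"]`, the function name `rig_postprocess`, and the dictionary standing in for Data Commons are assumptions made for illustration, not the DataGemma API or the model's actual output format. Only the final annotation mirrors the `[__DC__#N(Data Commons value || LLM value)]` form visible in the Step 3 output.

```python
import re
from typing import Callable, Optional

# Hypothetical marker the fine-tuned model might emit around each generated statistic:
#   [__DC__("natural-language query") --> "value the LLM generated"]
# Assumes neither the query nor the value contains a double quote.
MARKER = re.compile(r'\[__DC__\("(?P<query>[^"]*)"\) --> "(?P<llm_value>[^"]*)"\]')


def rig_postprocess(draft: str, fetch: Callable[[str], Optional[str]]) -> str:
    """Replace each marker with a numbered annotation holding both values.

    `fetch` stands in for a Data Commons lookup and returns None when no data is found.
    """
    counter = 0

    def annotate(match: re.Match) -> str:
        nonlocal counter
        dc_value = fetch(match.group("query"))
        if dc_value is None:
            # No data found: keep the LLM's own value, unannotated.
            return match.group("llm_value")
        counter += 1
        return f"[__DC__#{counter}({dc_value} || {match.group('llm_value')})]"

    return MARKER.sub(annotate, draft)


# Step 2 stand-in: a draft response the model might produce (hypothetical).
draft = ('Life expectancy rose from '
         '[__DC__("what was the life expectancy in Pakistan in 2000?") --> "61.8 years"] '
         'in 2000.')

# Step 3 stand-in: a tiny lookup table instead of a real Data Commons call.
fake_data_commons = {"what was the life expectancy in Pakistan in 2000?": "62.102 yr [1]"}

print(rig_postprocess(draft, fake_data_commons.get))
```

Keeping the model's value next to the retrieved one, rather than overwriting it, lets a reader see where the two disagree.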

In the example above, note the questions sent to Data Commons (e.g., "what was the life expectancy in Pakistan in 2000?"). Their answers are compared with the model's initial response in `[__DC__#1(62.102 yr [1] || 61.8 years)]`: `61.8 years` is the value generated by Gemma 2 27B, and `62.102 yr [1]` is the value retrieved from Data Commons together with its citation marker. DataGemma is trained to query Data Commons with "what was the life expectancy in Pakistan in 2000?". Statistics from the World Bank, along with citations to the original source, are returned by Data Commons and substituted into the final response.
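
To check how often Data Commons disagreed with the model, the annotations in an answer can be parsed. This is a minimal sketch, assuming every annotation has the form `[__DC__#N(Data Commons value || LLM value)]` seen in the printed output. The regex, the `DCAnnotation` class, and `parse_annotations` are our own, inferred from that example rather than taken from DataGemma's documentation.

```python
import re
from dataclasses import dataclass
from typing import List

# Inferred from the printed output: [__DC__#<id>(<Data Commons value> || <LLM value>)]
ANNOTATION = re.compile(r"\[__DC__#(\d+)\((.*?) \|\| (.*?)\)\]")


@dataclass
class DCAnnotation:
    index: int      # the #N counter
    dc_value: str   # value returned by Data Commons, including its citation marker
    llm_value: str  # value originally generated by the model


def parse_annotations(text: str) -> List[DCAnnotation]:
    """Extract every Data Commons annotation from a RIG answer."""
    return [DCAnnotation(int(i), dc.strip(), llm.strip())
            for i, dc, llm in ANNOTATION.findall(text)]


answer = ("Life expectancy at birth has increased from "
          "[__DC__#1(62.102 yr [1] || 61.8 years)] in 2000 to "
          "[__DC__#2(66.269 yr [2] || 67.2 years)] in 2020.")

for a in parse_annotations(answer):
    print(f"#{a.index}: Data Commons = {a.dc_value!r}, model = {a.llm_value!r}")
```

After running Step 3, the same helper could be applied to the real output with `parse_annotations(ans.answer())`.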