# Retrieval-Augmented Generation (RAG)

In [11]:
import chromadb
import dotenv
from pathlib import Path
from agents import Agent, Runner, function_tool, trace

# Load environment variables from a local .env file into the process environment.
# NOTE(review): presumably this supplies the API key the agents SDK needs — confirm.
dotenv.load_dotenv()

True

Connect to the Chroma vector database collections that back our RAG lookup tools:

In [12]:
# We populated the RAG with the data from the data/calories.csv file in
# the rag_setup.ipynb notebook (the Q&A collection presumably comes from
# the same setup notebook — confirm). This cell only opens the existing
# collections; run the setup notebook first if they are missing.
CHROMA_DIR = Path("../chroma")  # persisted Chroma store, relative to this notebook

chroma_client = chromadb.PersistentClient(path=str(CHROMA_DIR))
# Per-food calorie records, queried by calorie_lookup_tool.
nutrition_db = chroma_client.get_collection(name="nutrition_db")
# Nutrition question/answer documents, queried by nutrition_qa_lookup_tool.
nutrition_qa_db = chroma_client.get_collection(name="nutrition_qna")

In [13]:
results = nutrition_db.query(query_texts=["banana"], n_results=2)
# One query text -> one hit list at index 0; show each hit's metadata
# (sorted by key for a stable display) followed by its document text.
for doc_text, meta in zip(results["documents"][0], results["metadatas"][0]):
    print(sorted(meta.items()))
    print(doc_text)
    print("\n")

[('calories_per_100g', 89.0), ('food_category', 'fruits'), ('food_item', 'banana'), ('keywords', 'banana_fruits'), ('kj_per_100g', 374.0), ('serving_info', '100g')]
Food: Banana
        Category: Fruits
        Nutritional Information:
        - Calories: 89 per 100g
        - Energy: 374 kJ per 100g
        - Serving size reference: 100g

        This is a fruits food item that provides 89 calories per 100 grams.


[('calories_per_100g', 50.0), ('food_category', '(fruit)juices'), ('food_item', 'banana juice'), ('keywords', 'banana_juice_(fruit)juices'), ('kj_per_100g', 210.0), ('serving_info', '100ml')]
Food: Banana Juice
        Category: (Fruit)Juices
        Nutritional Information:
        - Calories: 50 per 100g
        - Energy: 210 kJ per 100g
        - Serving size reference: 100ml

        This is a (fruit)juices food item that provides 50 calories per 100 grams.




In [14]:
@function_tool
def calorie_lookup_tool(query: str, max_results: int = 3) -> str:
    """
    Tool function for a RAG database to look up calorie information for specific food items, but not for meals.

    Args:
        query: The food item to look up.
        max_results: The maximum number of results to return.

    Returns:
        A string containing the nutrition information.
    """

    # A single query text yields a single hit list, found at index 0.
    results = nutrition_db.query(query_texts=[query], n_results=max_results)

    if not results["documents"][0]:
        return f"No nutrition information found for: {query}"

    # Only the structured metadata is needed for the agent-facing summary;
    # the document text itself is not forwarded.
    # NOTE(review): some items (e.g. juices) carry serving_info "100ml" in
    # their metadata but are still reported "per 100g" here — confirm units.
    formatted_results = []
    for metadata in results["metadatas"][0]:
        food_item = metadata["food_item"].title()
        calories = metadata["calories_per_100g"]
        category = metadata["food_category"].title()

        formatted_results.append(
            f"{food_item} ({category}): {calories} calories per 100g"
        )

    return "Nutrition Information:\n" + "\n".join(formatted_results)

Let's test this out: 

_The following cell only works if `calorie_lookup_tool` is not decorated with `@function_tool`: the decorator turns the function into a tool object for the agent, so it can no longer be called directly._

In [None]:
# calorie_lookup_tool('bananas')

'Nutrition Information:\nBanana (Fruits): 89.0 calories per 100g\nBanana Juice ((Fruit)Juices): 50.0 calories per 100g\nBanana Nut Bread (Pastries,Breads&Rolls): 326.0 calories per 100g'

In [15]:
# System prompt for the calorie agent. The text, including its leading
# newline and indentation, is passed to the Agent verbatim.
calorie_agent_instructions = """
    You are a helpful nutrition assistant giving out calorie information.
    You give concise answers.
    If you need to look up calorie information, use the calorie_lookup_tool.
    """

# Agent whose only tool is the calorie RAG lookup.
calorie_agent = Agent(
    name="Nutrition Assistant",
    instructions=calorie_agent_instructions,
    tools=[calorie_lookup_tool],
)

In [16]:
with trace("Nutrition Assistant with RAG"):
    result = await Runner.run(
        calorie_agent,
        "How many calories are in total in a banana and an apple? Also give calories per 100g",
    )
    print(result.final_output)

- Calories per 100 g: banana 89 kcal, apple 52 kcal
- Typical medium banana (~118 g): ~105 kcal
- Typical medium apple (~182 g): ~95 kcal
- Estimated total (one medium banana + one medium apple): ~200 kcal


In [17]:
# Sanity-check the Q&A collection directly (not through the tool yet).
results = nutrition_qa_db.query(query_texts=["malnutrition"], n_results=2)
for qa_text, qa_meta in zip(results["documents"][0], results["metadatas"][0]):
    print(sorted(qa_meta.items()))
    print(qa_text)
    print("\n")

[('is_pregnancy', False)]
Question: Which factors are cited as the main causes of malnutrition?
        Answer: The text mentions several key reasons for malnourishment. These include a lack of understanding about selecting nutritious foods, financial constraints and diseases caused by pathogens, drought-related scarcity, unequal distribution of available foods, societal stagnation and disputes, transportation issues, population growth, suboptimal weaning practices, agricultural techniques, ineffective resource management, terrain differences, crop damage from insects, and overused land due to excessive cultivation.

        This Q&A pair provides information about nutrition and health topics.


[('is_pregnancy', True)]
Question: Which demographic is primarily at risk for malnutrition?
        Answer: Malnourishment principally affects young children under two years old; nonetheless, it also impacts individuals below five years of age, teenagers, pregnant or nursing mothers, the elderl

In [22]:
@function_tool
def nutrition_qa_lookup_tool(query: str, max_results: int = 3) -> str:
    """
    Tool function for a RAG database to look up answers to nutrition questions (for example malnutrition or pregnancy).

    Args:
        query: The nutrition question or topic to look up.
        max_results: The maximum number of results to return.

    Returns:
        A string containing the related question/answer documents.
    """

    # A single query text yields a single hit list, found at index 0.
    results = nutrition_qa_db.query(query_texts=[query], n_results=max_results)
    documents = results["documents"][0]

    if not documents:
        return f"No information found for: {query}"

    # The stored documents are already "Question: ... Answer: ..." text,
    # so they are handed to the agent verbatim.
    return "Related answers to your question:\n" + "\n".join(documents)

In [20]:
# nutrition_qa_lookup_tool("What foods shouldn't be eaten in pregnancy?")

In [23]:
# System prompt for the Q&A agent. The text, including its leading
# newline and indentation, is passed to the Agent verbatim.
nutrition_qa_instructions = """
    You are a helpful nutrition assistant giving out nutrition information, especially around pregnancy.
    You give concise answers.
    If you need to look up information, use the nutrition_qa_lookup_tool.
    """

# Agent whose only tool is the nutrition Q&A RAG lookup.
nutrition_qa_agent = Agent(
    name="Nutrition Q&A Assistant",
    instructions=nutrition_qa_instructions,
    tools=[nutrition_qa_lookup_tool],
)

In [24]:
with trace("Nutrition Q&A with RAG"):
    result = await Runner.run(
        calorie_agent,
        "What foods should be avoided in pregnancy? Why?",
    )
    print(result.final_output)

Here are foods to avoid or limit in pregnancy and why:

- Raw or undercooked eggs/meat and unpasteurized dairy: risk of Salmonella and other bacteria.
- Unpasteurized milk/dairy and soft cheeses (e.g., Brie, feta, camembert, blue cheese) unless clearly labeled pasteurized: higher risk of listeria.
- deli meats, hot dogs, and smoked meats unless heated until steaming: listeria risk.
- Raw or undercooked seafood, including sushi with raw fish and shellfish: foodborne illness and parasites.
- Certain fish high in mercury (e.g., shark, swordfish, king mackerel, tilefish, bigeye tuna): mercury can affect fetal development; limit albacore tuna and tuna steaks.
- Unwashed fruits/vegetables: risk of toxoplasmosis and other contaminants.
- Raw sprouts (alfalfa, bean sprouts): bacteria risk that is hard to wash off.
- Alcohol: can cause fetal alcohol spectrum disorders; no safe amount established.
- Excess vitamin A from supplements/animal liver: high intake can harm fetal development.

General 