Text-to-SQL Guide Using Olympic Medal Data from CSV

Step 1: Install Required Libraries

In [1]:
# Install necessary libraries
%pip install llama-index llama-index-llms-groq groq llama-index-embeddings-huggingface sqlalchemy pandas python-dotenv

You should consider upgrading via the '/Users/taurangela/Desktop/Github/Database-Query-Retrival/env/bin/python -m pip install --upgrade pip' command.[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.


Step 2: Import Libraries

In [2]:
import os
import pandas as pd
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    text
)

from dotenv import load_dotenv
from llama_index.core import SQLDatabase
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq

import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


Step 3: Load Data from CSV

In [3]:
# Load the CSV data into a DataFrame
csv_file_path = "/Users/taurangela/Desktop/Github/Database-Query-Retrival/data/olympics2024.csv"  # Replace with your CSV file path
df = pd.read_csv(csv_file_path)
df.head()

Unnamed: 0,Rank,Country,Country Code,Gold,Silver,Bronze,Total
0,1,United States,US,40,44,42,126
1,2,China,CHN,40,27,24,91
2,3,Japan,JPN,20,12,13,45
3,4,Australia,AUS,18,19,16,53
4,5,France,FRA,16,26,22,64


Step 4: Create Database Schema

In [4]:
# Create in-memory SQLite database
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# Define the Olympic medals table with correct column names
table_name = "olympic_medals"
olympic_medals_table = Table(
    table_name,
    metadata_obj,
    Column("Rank", Integer),
    Column("Country", String(16)),
    Column("Country Code", String(3)),
    Column("Gold", Integer),
    Column("Silver", Integer),
    Column("Bronze", Integer),
    Column("Total", Integer),
)

In [5]:


# Drop the table if it exists
with engine.connect() as connection:
    connection.execute(text(f"DROP TABLE IF EXISTS {table_name}"))

# Create the table in the database
metadata_obj.create_all(engine)

Step 5: Insert Data into the Table

In [6]:
# Insert the data from the DataFrame into the SQL table
df.to_sql(table_name, con=engine, if_exists='append', index=False)

91

In [7]:
# Verify the insertion by querying the table
query = text(f"SELECT * FROM {table_name}")
with engine.connect() as connection:
    results = connection.execute(query).fetchall()
    print(f"Executed Query: {query}")
    for result in results:
        print(result)

Executed Query: SELECT * FROM olympic_medals
(1, 'United States', 'US', 40, 44, 42, 126)
(2, 'China', 'CHN', 40, 27, 24, 91)
(3, 'Japan', 'JPN', 20, 12, 13, 45)
(4, 'Australia', 'AUS', 18, 19, 16, 53)
(5, 'France', 'FRA', 16, 26, 22, 64)
(6, 'Netherlands', 'NED', 15, 7, 12, 34)
(7, 'Great Britain', 'GBG', 14, 22, 29, 65)
(8, 'South Korea', 'KOR', 13, 9, 10, 32)
(9, 'Italy', 'ITA', 12, 13, 15, 40)
(10, 'Germany', 'GER', 12, 13, 8, 33)
(11, 'New Zealand', 'NZ', 10, 7, 3, 20)
(12, 'Canada', 'CAN', 9, 7, 11, 27)
(13, 'Uzbekistan', 'UZB', 8, 2, 3, 13)
(14, 'Hungary', 'HUN', 6, 7, 6, 19)
(15, 'Spain', 'SPA', 5, 4, 9, 18)
(16, 'Sweden', 'SWE', 4, 4, 3, 11)
(17, 'Kenya', 'KEN', 4, 2, 5, 11)
(18, 'Norway', 'NOR', 4, 1, 3, 8)
(19, 'Ireland', 'IRE', 4, 0, 3, 7)
(20, 'Brazil', 'BRZ', 3, 7, 10, 20)
(21, 'Iran', 'IRN', 3, 6, 3, 12)
(22, 'Ukraine', 'UKR', 3, 5, 4, 12)
(23, 'Romania', 'ROM', 3, 4, 2, 9)
(24, 'Georgia', 'GEO', 3, 3, 1, 7)
(25, 'Belgium', 'BEL', 3, 1, 6, 10)
(26, 'Bulgaria', 'BUL', 3, 1

Step 6: Query the Table Using LLM

In [8]:
# Load environment variables from .env file
load_dotenv()

# Get the Groq API key from environment variables
GROQ_API_KEY = os.getenv('GROQ_API_KEY')

# Ensure the API key is loaded correctly
#print(f"Loaded API Key: {GROQ_API_KEY}")


In [9]:
# Initialize the Groq LLM and HuggingFace embedding model
llm = Groq(model="llama3-70b-8192", api_key=GROQ_API_KEY)
embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
print(embed_model)

model_name='sentence-transformers/all-MiniLM-L6-v2' embed_batch_size=10 callback_manager=<llama_index.core.callbacks.base.CallbackManager object at 0x1358bb3a0> num_workers=None max_length=256 normalize=True query_instruction=None text_instruction=None cache_folder=None


Step 7: Perform a Natural Language Query

In [10]:
# Define your prompt to ask for both SQL query and result
prompt = [
    """
    You are an expert in converting English questions to SQL queries!
    The SQL database has the name OLYMPIC_MEDALS and has the following columns - Rank, Country, Country Code, Gold, Silver, Bronze, Total

    For example,
    Example 1 - Which country won the most gold medals? 
    The SQL command will be something like this: 
    SELECT Country FROM OLYMPIC_MEDALS ORDER BY Gold DESC LIMIT 1;
    The result will be: United States
    
    Example 2 - Show the top 5 countries by total medals.
    The SQL command will be something like this: 
    SELECT Country FROM OLYMPIC_MEDALS ORDER BY Total DESC LIMIT 5;
    The result will be: United States, China, Russia, etc.
    
    Provide the SQL query and the result. The SQL code should not have ``` in the beginning or end, and the word SQL should not be included in the output.
    """
]

In [11]:
# Initialize the query engine with custom prompt
query_engine = NLSQLTableQueryEngine(
    sql_database=SQLDatabase(engine, include_tables=["olympic_medals"]), 
    tables=["olympic_medals"], 
    llm=llm, 
    embed_model=embed_model,
    prompt=prompt
)

In [12]:
# Perform a natural language query
query_str = "Which country won the most gold medals?"
response = query_engine.query(query_str)
response


Response(response='The country that won the most gold medals is the United States, with a total of 40 gold medals.', source_nodes=[NodeWithScore(node=TextNode(id_='364c488c-3914-41ad-b2a8-39cb252bf176', embedding=None, metadata={'sql_query': 'SELECT Country, Gold FROM olympic_medals ORDER BY Gold DESC LIMIT 1;', 'result': [('United States', 40)], 'col_keys': ['Country', 'Gold']}, excluded_embed_metadata_keys=['sql_query', 'result', 'col_keys'], excluded_llm_metadata_keys=['sql_query', 'result', 'col_keys'], relationships={}, text="[('United States', 40)]", mimetype='text/plain', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)], metadata={'364c488c-3914-41ad-b2a8-39cb252bf176': {'sql_query': 'SELECT Country, Gold FROM olympic_medals ORDER BY Gold DESC LIMIT 1;', 'result': [('United States', 40)], 'col_keys': ['Country', 'Gold']}, 'sql_query': 'SELECT Country, Gold FROM olympic_

In [13]:
# Extract SQL query and results from the response
metadata = response.source_nodes[0].metadata
sql_query = metadata.get('sql_query', 'No SQL query found')
result = metadata.get('result', 'No result found')

# Display the original response
print(f"Original Response: {response.response}")

# Display the SQL query, result, and the original response
print(f"SQL Query: {sql_query}")
print(f"Result: {result}")


Original Response: The country that won the most gold medals is the United States, with a total of 40 gold medals.
SQL Query: SELECT Country, Gold FROM olympic_medals ORDER BY Gold DESC LIMIT 1;
Result: [('United States', 40)]
