In [1]:
%%capture --no-stderr
# Install the notebook's dependencies from requirements.txt; stdout is captured (hidden) but
# stderr is not (--no-stderr), so installation errors remain visible.
%pip install --upgrade --quiet -r requirements.txt

In [None]:
import ast
import json
import os
# from llama_index.llms.huggingface import HuggingFaceLLM
# from llama_index.core.tools import QueryEngineTool, ToolMetadata
# from llama_index.core.agent import ReActAgent

# Model: load_model returns the model object and a chat-style text-generation pipeline
# (`pipe` is called with lists of {"role", "content"} messages throughout the notebook).
# The directory name suggests Llama 3.2 3B Instruct — confirm against agents/.
from agents import load_model
MODEL_DIR = 'agents/Llama-3_2-3B-Instruct'
model, pipe = load_model(MODEL_DIR)

# Retriever factory: called later with the per-table metadata file paths; its result exposes
# `.retrieve(prompt)`.
from agents import llamaindex_retriever

# Prompts: metadata generation (system prompt), text-to-SQL, and SQL error checking.
from agents import prompts
system_prompt, sql_prompt, checker_prompt = (
    prompts.metadata_prompt,
    prompts.sql_prompt,
    prompts.checker_prompt,
)








## Connect to the Chinook database (local SQLite copy)

In [2]:
from langchain_community.utilities import SQLDatabase

# Open Chinook.db read-only (SQLite URI `mode=ro`). This notebook executes LLM-generated SQL
# against this handle, so read-only access keeps a hallucinated DELETE/DROP/UPDATE from
# modifying the file. It also makes a missing Chinook.db raise an error instead of silently
# creating an empty database (which would yield an empty table list further down).
db = SQLDatabase.from_uri("sqlite:///file:Chinook.db?mode=ro&uri=true")

## Create metadata

### Samples and Info Schemas

In [3]:
# List tables
l_tables = db.get_usable_table_names()

# Number of example rows captured per table for the metadata prompt and text files.
N_SAMPLE_ROWS = 5

# Create metadata: for every table keep a few sample rows plus its CREATE TABLE statement.
d_metadata = {}
for table in l_tables:
    # Quote the identifier (doubling embedded quotes) so names with spaces or reserved
    # words still form valid SQL.
    quoted_table = '"' + table.replace('"', '""') + '"'
    sample = db.run(f"SELECT * FROM {quoted_table} LIMIT {N_SAMPLE_ROWS};", fetch="cursor")
    # Bind the table name as a parameter rather than splicing it into a string literal.
    infoschema = db.run(
        "SELECT sql FROM sqlite_master WHERE type='table' AND name=:name;",
        fetch="cursor",
        parameters={"name": table},
    )
    schema_row = infoschema.mappings().first()
    if schema_row is None:
        raise LookupError(f"No CREATE TABLE statement found in sqlite_master for {table!r}")
    # Store metadata; rows are materialized as plain dicts so they are JSON-serializable
    # and detached from the SQLAlchemy result objects.
    d_metadata[table] = {
        "sample": [dict(row) for row in sample.mappings()],
        "infoschema": schema_row['sql'],
    }

### GenAI Metadata

In [7]:
# Attempts per table when the LLM reply cannot be parsed into the expected JSON object.
MAX_RETRIES = 3
# Keys the metadata prompt must produce; both are copied into d_metadata below.
REQUIRED_KEYS = ("description", "use case")


def _parse_llm_json(text):
    """Parse the JSON object in an LLM reply, tolerating common formatting slips.

    Surrounding markdown code fences are removed first. Parsing strategies are tried in order:
    1. strict JSON;
    2. a Python-literal dict (single-quoted strings), which a blanket ``'`` -> ``"`` swap
       corrupts whenever a value contains an apostrophe;
    3. that quote swap itself, as a last resort.

    Raises:
        json.JSONDecodeError (a ValueError): if no strategy yields a value.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.strip("`").strip()
        if cleaned.lower().startswith("json"):
            cleaned = cleaned[len("json"):].strip()
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(cleaned)
    except (ValueError, SyntaxError, TypeError):
        pass
    return json.loads(cleaned.replace("'", '"'))


for table in l_tables:
    print(f"Processing table: {table}")
    # Send only the raw inputs (sample rows + schema). Interpolating d_metadata[table] directly
    # would, when this cell is re-run, also feed back the description/use case generated before.
    table_context = {key: d_metadata[table][key] for key in ("sample", "infoschema")}
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{table_context}"},
    ]

    # Generate and parse. The pipeline is called only inside the retry loop, so each attempt
    # costs exactly one generation.
    for attempt in range(MAX_RETRIES):
        try:
            response = pipe(messages)
            # The pipeline returns the chat history with the assistant reply appended last.
            data_dict = _parse_llm_json(response[0]['generated_text'][-1]['content'])
            if not isinstance(data_dict, dict):
                raise TypeError(f"expected a JSON object, got {type(data_dict).__name__}")
            missing = [key for key in REQUIRED_KEYS if key not in data_dict]
            if missing:
                # Raised inside the try so an incomplete reply is retried instead of being fatal.
                raise KeyError(f"missing keys: {missing}")
            break  # Exit loop if successful
        except Exception as e:
            print(f"Error processing table {table}: {e}")
            if attempt == MAX_RETRIES - 1:
                raise  # Re-raise the exception if max retries reached
            print(f"Retrying... ({attempt + 1}/{MAX_RETRIES})")

    # Store metadata
    d_metadata[table]['description'] = data_dict['description']
    d_metadata[table]['use case'] = data_dict['use case']

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Processing table: Album


Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Processing table: Artist


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Processing table: Customer


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Processing table: Employee


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Processing table: Genre


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Processing table: Invoice


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Processing table: InvoiceLine


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Processing table: MediaType


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Processing table: Playlist


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Error processing table Playlist: Expecting ',' delimiter: line 3 column 207 (char 513)
Retrying... (1/3)


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Processing table: PlaylistTrack


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Processing table: Track


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


In [None]:
# Convert RowMapping to Dict (a no-op when the rows are already plain dicts)
for table in l_tables:
    d_metadata[table]['sample'] = [dict(row) for row in d_metadata[table]['sample']]

# Save dictionary to a JSON file, creating the output directory if it does not exist yet.
# default=str is a fallback for column values json cannot encode natively (e.g. bytes).
os.makedirs('data', exist_ok=True)
with open('data/metadata.json', 'w') as json_file:
    json.dump(d_metadata, json_file, indent=4, default=str)

In [26]:
# Reload the saved metadata so later cells can run without regenerating it with the LLM.
with open('data/metadata.json') as metadata_fh:
    d_metadata = json.load(metadata_fh)

In [28]:
# Create one plain-text metadata file per table; these files are the documents passed to the
# RAG retriever further down.
os.makedirs('data/metadata', exist_ok=True)
for table in l_tables:
    table_meta = d_metadata[table]
    # Explicit UTF-8: sample rows may contain non-ASCII text (accented names/titles) that the
    # platform default encoding (e.g. cp1252 on Windows) cannot always represent.
    with open(f'data/metadata/{table}.txt', 'w', encoding='utf-8') as txt_file:
        txt_file.write(f"\n\nTable: {table}\n\n")
        txt_file.write(f"Description: \n{table_meta['description']}\n\n")
        txt_file.write(f"Use Case: \n{table_meta['use case']}\n\n")
        txt_file.write(f"Schema: \n{table_meta['infoschema']}\n\n")
        txt_file.write("Sample:\n")
        for row in table_meta['sample']:
            txt_file.write(f"{row}\n")

## Build the RAG retriever over the table metadata

In [None]:
# Index the per-table metadata text files so relevant tables can be retrieved per question.
metadata_files = [f'data/metadata/{table}.txt' for table in l_tables]
db_engine = llamaindex_retriever(metadata_files)

In [None]:
prompt = "Which album has the most tracks?"
tables = db_engine.retrieve(prompt)
# Retrieved table-metadata documents, concatenated as schema context for the model.
schema_context = '\n'.join(node.text for node in tables)
messages = [
    {"role": "system", "content": sql_prompt},
    {"role": "sqlagent", "content": schema_context},
    {"role": "user", "content": prompt},
]
response = pipe(messages)
# The assistant reply follows the three input messages; strip markdown SQL fences from it.
raw_answer = response[0]['generated_text'][3]['content']
query = raw_answer.replace('```sql', '').replace('```', '').strip()
print(query)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


SELECT T1.Title, COUNT(T2.TrackId) AS NumTracks 
FROM Album AS T1 
JOIN Track AS T2 ON T1.AlbumId = T2.AlbumId 
GROUP BY T1.Title 
ORDER BY NumTracks DESC LIMIT 1


In [13]:
# Execute the generated SQL; db.run returns the rows as a string, displayed as the cell output.
db.run(query)

"[('Greatest Hits', 57)]"

In [14]:
prompt = "What were the total sales from customers in Brazil?"
tables = db_engine.retrieve(prompt)
context_text = '\n'.join(node.text for node in tables)
messages = [
    dict(role="system", content=sql_prompt),
    dict(role="sqlagent", content=context_text),
    dict(role="user", content=prompt),
]
response = pipe(messages)
# Take the assistant reply (index 3, after the three inputs) and remove markdown SQL fences.
answer = response[0]['generated_text'][3]['content']
for fence in ('```sql', '```'):
    answer = answer.replace(fence, '')
query = answer.strip()
print(query)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


SELECT SUM(T2.Total) FROM Customer AS T1 INNER JOIN Invoice AS T2 ON T1.CustomerId = T2.CustomerId WHERE T1.Country = 'Brazil'


In [16]:
# Execute the generated SQL; db.run returns the rows as a string, displayed as the cell output.
db.run(query)

'[(190.1,)]'

In [17]:
prompt = "What were the top 10 total sales in $ by track from customers in Brazil?"
tables = db_engine.retrieve(prompt)
# Retrieved table-metadata documents, shared by the generator and the checker.
schema_context = '\n'.join(t.text for t in tables)
messages = [
        {"role": "system", "content": sql_prompt},
        {"role": "sqlagent", "content": schema_context},
        {"role": "user", "content": prompt},
    ]
response = pipe(messages)
# The pipeline returns the chat history with the assistant reply appended last.
query = response[0]['generated_text'][-1]['content'].replace('```sql', '').replace('```', '').strip()
print(query)
try:
    db.run(query)
except Exception as e:
    # Ask the checker to repair the query. It needs both the SQL that failed and the error
    # text; message content must be a string, not the exception object itself.
    # NOTE(review): confirm this "Query/Error" layout matches what checker_prompt expects.
    messages = [
            {"role": "system", "content": checker_prompt},
            {"role": "sqlagent", "content": schema_context},
            {"role": "user", "content": f"Query:\n{query}\n\nError:\n{e}"},
        ]
    response = pipe(messages)
    query = response[0]['generated_text'][-1]['content'].replace('```sql', '').replace('```', '').strip()
    print(query)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


SELECT T1.Name, SUM(T2.Total) AS TotalSales 
FROM Track AS T1 
INNER JOIN InvoiceLine AS T3 ON T1.TrackId = T3.TrackId 
INNER JOIN Invoice AS T2 ON T3.InvoiceId = T2.InvoiceId 
INNER JOIN Customer AS T4 ON T2.CustomerId = T4.CustomerId 
WHERE T4.Country = 'Brazil' 
GROUP BY T1.Name 
ORDER BY TotalSales DESC 
LIMIT 10


In [18]:
# Execute the (possibly checker-corrected) SQL; the returned row string is the cell output.
db.run(query)

'[(\'Untitled\', 27.72), (\'Água de Beber\', 13.86), (\'Your Blue Room\', 13.86), ("You\'ve Been A Long Time Coming", 13.86), ("You\'re My Best Friend", 13.86), (\'X-9 2001\', 13.86), (\'Why Go\', 13.86), (\'Wanted Dread And Alive\', 13.86), (\'Vai Valer\', 13.86), (\'TriboTchan\', 13.86)]'