In [2]:
import os
import lamini
from dotenv import load_dotenv

In [4]:
# Load environment variables (python-dotenv also reads a local .env file) and
# configure the Lamini client. Fail fast with a clear message when no key is set;
# otherwise os.getenv returns None and lamini.api_key would silently be None.
load_dotenv()
# The original (mixed-case) name wins; the conventional upper-case spelling is a
# fallback so either form of the variable works on case-sensitive platforms.
lamini_api_key = os.getenv('lamini_API_KEY') or os.getenv('LAMINI_API_KEY')
if not lamini_api_key:
    raise RuntimeError(
        "No Lamini API key found: set lamini_API_KEY (or LAMINI_API_KEY) "
        "in the environment or in a .env file."
    )
lamini.api_key = lamini_api_key

In [11]:
# Model served through Lamini; named once so it is easy to swap for experiments.
LLM_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
llm = lamini.Lamini(LLM_MODEL_NAME)

In [3]:
def make_llama_3_prompt(user, system=""):
    """Wrap a user message (and optional system message) in the Llama 3 chat template.

    Args:
        user: Text of the user turn.
        system: Optional system instructions. An empty string or None omits
            the system turn entirely.

    Returns:
        The full prompt string. It ends with an open assistant header, so the
        model's generation continues as the assistant's reply.
    """
    # Truthiness rather than `!= ""` so that system=None does not get rendered
    # as the literal text "None" inside a system turn.
    system_turn = (
        f"<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
        if system
        else ""
    )
    user_turn = f"<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
    assistant_header = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return f"<|begin_of_text|>{system_turn}{user_turn}{assistant_header}"

In [13]:
def get_schema():
    """Describe the nba_roster table for the LLM prompt.

    Returns one line per column, formatted as "<index>|<column spec>" and
    terminated by a newline. Each spec gives the column's type, an example
    value, and the placeholder used for missing values where one applies.
    """
    column_specs = (
        'Team|TEXT eg. "Toronto Raptors"',
        'NAME|TEXT eg. "Otto Porter Jr."',
        'Jersey|TEXT eg. "0" and when null has a value "NA"',
        'POS|TEXT eg. "PF"',
        'AGE|INT eg. "22" in years',
        'HT|TEXT eg. `6\' 7"` or `6\' 10"`',
        'WT|TEXT eg. "232 lbs" ',
        'COLLEGE|TEXT eg. "Michigan" and when null has a value "--"',
        'SALARY|TEXT eg. "$9,945,830" and when null has a value "--"',
    )
    # Column indexes are derived from position, so they always stay contiguous.
    return "".join(f"{position}|{spec}\n" for position, spec in enumerate(column_specs))

In [14]:
# Natural-language question the model should translate into SQL.
question = "Who is the highest paid NBA player?"

In [15]:
# Baseline system prompt: persona + table schema + task. Kept intentionally
# minimal so the hallucination diagnosed below can be reproduced.
system = (
    "You are an NBA analyst with 15 years of experience writing complex SQL queries. "
    "Consider the nba_roster table with the following schema:\n"
    f"{get_schema()}\n"
    "\n"
    "Write a sqlite query to answer the following question. Follow instructions exactly"
)
prompt = make_llama_3_prompt(question, system)

In [16]:
# Ask the model for structured output: a dict with a single "sqlite_query" string.
generated_query = llm.generate(
    prompt,
    output_type={"sqlite_query": "str"},
    max_new_tokens=200,
)
generated_query

{'sqlite_query': 'SELECT NAME FROM nba_roster ORDER BY SALARY DESC LIMIT 1'}

## Diagnose Hallucinations

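The **wrong** query looks like this. It filters out missing salaries (`'--'`), but `SALARY` is stored as text such as `"$9,945,830"`. When SQLite casts text to `REAL`, it keeps only the longest numeric prefix and returns `0.0` if there is none. Every salary starts with `$`, so every row casts to `0.0` and the `ORDER BY` does not rank players by pay at all. The model's own output above, `ORDER BY SALARY DESC` on the raw text, fails in a different way: it sorts the strings lexicographically, so `"$9,945,830"` ranks above `"$40,000,000"`.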

```sql
SELECT NAME, SALARY
FROM nba_roster
WHERE salary != '--'
ORDER BY CAST(SALARY AS REAL) DESC
LIMIT 1;
```


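The **correct** query first strips the `$` and `,` characters with `REPLACE`. The remaining digits then cast cleanly to an integer, so salaries are compared numerically: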

```sql
SELECT salary, name 
FROM nba_roster
WHERE salary != '--'
ORDER BY CAST(REPLACE(REPLACE(salary, '$', ''), ',','') AS INTEGER) DESC
LIMIT 1;
```