llama-3-8b-Instruct-bnb-4bit-synthetic_text_to_sql-lora-3epochs-Q5_K_M

#### Libraries

In [1]:
import dspy
import random
from dotenv import load_dotenv
from dspy.datasets import DataLoader
from dspy.evaluate import Evaluate
from dspy.teleprompt import BootstrapFewShotWithRandomSearch, LabeledFewShot

#from src.starling import StarlingLM # <- Custom Local Model Client for llama3-8b

_ = load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


#### LLM

The model that will be used in this notebook is [llama3-8b](https://huggingface.co/Nexusflow/Starling-LM-7B-beta), and the evaluation model will be [GPT-4 Turbo](https://openai.com/gpt-4).

In [2]:
# Share generation args between models
generation_args = {
    "temperature":0,
    "max_tokens":500,
    "stop":"\n\n",
    "model_type":"chat",
    "n": 1
}
# Model specific args
model_info = {
    "gpt-4": {"model": "gpt-4-0125-preview", "api_base": "https://api.openai.com/v1/"},
    "starling": {"model": "Nexusflow/Starling-LM-7B-beta"}
}

In [3]:
# Set up the models
# lm = StarlingLM(**model_info["starling"], **generation_args)
# evaluator_lm = dspy.OpenAI(**model_info["gpt-4"], **generation_args)

lm = dspy.OllamaLocal(model="Phi-3-medium-4k-instruct-synthetic_text_to_sql-lora-3epochs-q5_k_m:latest", base_url='http://localhost:11435')
# evaluator_lm = dspy.OpenAI(**model_info["gpt-4"], **generation_args)
evaluator_lm = dspy.OllamaLocal(model='llama3:70b', base_url='http://localhost:11434')

dspy.configure(lm=lm)

In [4]:
# Testing inference of Starling
lm("What is the capital of Colombia?")

['Bogotá.']

#### Load dataset

The dataset that will be used in this notebook is [gretelai/synthetic_text_to_sql](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql)

In [5]:
# Define random seed
random.seed(1399)

In [6]:
# # Load dataset
# dl = DataLoader()
# trainset = dl.from_huggingface(
#     dataset_name="gretelai/synthetic_text_to_sql", # Dataset name from Huggingface
#     fields=("sql_prompt", "sql_context", "sql"), # Fields needed
#     input_keys=("sql_prompt", "sql_context"), # What our model expects to recieve to generate an output
#     split="train"
# )

# testset = dl.from_huggingface(
#     dataset_name="gretelai/synthetic_text_to_sql", # Dataset name from Huggingface
#     fields=("sql_prompt", "sql_context", "sql"), # Fields needed
#     input_keys=("sql_prompt", "sql_context"), # What our model expects to recieve to generate an output
#     split="test"
# )


# trainset = dl.sample(dataset=trainset, n=100)
# testset = dl.sample(dataset=testset, n=75)
# finetuneset = dl.sample(dataset=trainset, n=100)

# _trainval = dl.train_test_split(dataset=trainset, test_size=0.25, random_state=1399) # 25% of training data for validation
# trainset, valset = _trainval["train"], _trainval["test"]

# len(trainset), len(valset), len(testset)

In [7]:
# Load dataset
dl = DataLoader()

testset = dl.from_huggingface(
    dataset_name="gretelai/synthetic_text_to_sql", # Dataset name from Huggingface
    fields=("sql_prompt", "sql_context", "sql"), # Fields needed
    input_keys=("sql_prompt", "sql_context"), # What our model expects to recieve to generate an output
    split="test"
)
# print(len(testset))

testset = dl.sample(dataset=testset, n=200)

# Calculate the sizes of each set
total_size = len(testset)
train_size = int(total_size * 0.4)  # 40% of the data
val_size = int(total_size * 0.2)  # 20% of the data

# Split the dataset into train, validation and test sets
trainset = testset[:train_size]
valset = testset[train_size:train_size + val_size]
testset = testset[train_size + val_size:]

print(len(trainset), len(valset), len(testset))

80 40 80


In [8]:
# Verify an example of the dataset
sample = trainset[0]
for k, v in sample.items():
    print(f"\n{k.upper()}:\n")
    print(v)


SQL_PROMPT:

What is the average monthly data usage for each mobile network operator?

SQL_CONTEXT:

CREATE TABLE mobile_operators (operator_id INT, operator_name VARCHAR(50)); CREATE TABLE mobile_plans (plan_id INT, plan_name VARCHAR(50), operator_id INT, data_limit INT); CREATE TABLE usage (usage_id INT, subscriber_id INT, plan_id INT, usage_amount INT, usage_date DATE);

SQL:

SELECT o.operator_name, AVG(u.usage_amount) AS avg_monthly_data_usage FROM mobile_operators o INNER JOIN mobile_plans p ON o.operator_id = p.operator_id INNER JOIN usage u ON p.plan_id = u.plan_id WHERE u.usage_date >= DATEADD(month, -1, GETDATE()) GROUP BY o.operator_name;


In [9]:
# # Verify an example of the dataset
# sample = dl.sample(dataset=trainset, n=1)[0]
# for k, v in sample.items():
#     print(f"\n{k.upper()}:\n")
#     print(v)

#### Signature (Input/Output)

In [10]:
class TextToSql(dspy.Signature):
    """Transform a natural language query into a SQL query."""

    sql_prompt = dspy.InputField(desc="Natural language query")
    sql_context = dspy.InputField(desc="Context for the query")
    sql = dspy.OutputField(desc="SQL query")

### Inference

#### Baseline Inference

In [11]:
generate_sql_query = dspy.Predict(signature=TextToSql)

result = generate_sql_query(
    sql_prompt=sample["sql_prompt"],
    sql_context=sample["sql_context"]
)

for k, v in result.items():
    print(f"\n{k.upper()}:\n")
    print(v)


SQL:

SELECT mo.operator_name, AVG(u.usage_amount) as avg_data_usage FROM mobile_operators mo JOIN mobile_plans mp ON mo.operator_id = mp.operator_id JOIN usage u ON mp.plan_id = u.plan_id GROUP BY mo.operator_name;


#### ChainOfThought Inference

In [12]:
generate_sql_query = dspy.ChainOfThought(signature=TextToSql)

result = generate_sql_query(
    sql_prompt=sample["sql_prompt"],
    sql_context=sample["sql_context"]
)

for k, v in result.items():
    print(f"\n{k.upper()}:\n")
    print(v)


RATIONALE:

CREATE VIEW avg_monthly_data_usage AS SELECT mobile_plans.operator_id, AVG(usage_amount / 12) as avg_monthly_data_usage FROM usage JOIN mobile_plans ON usage.plan_id = mobile_plans.plan_id GROUP BY mobile_plans.operator_id;

SQL:

CREATE VIEW avg_monthly_data_usage AS SELECT mobile_plans.operator_name, AVG(usage_amount / 12) as avg_monthly_data_usage FROM usage JOIN mobile_plans ON usage.plan_id = mobile_plans.plan_id GROUP BY mobile_plans.operator_name;


### Metric of evaluation

#### Metric definition

In [13]:
class Correctness(dspy.Signature):
    """Assess if the SQL query accurately answers the given natural language query based on the provided context."""

    sql_prompt = dspy.InputField(desc="Natural language query ")
    sql_context = dspy.InputField(desc="Context for the query")
    sql = dspy.InputField(desc="SQL query")
    correct = dspy.OutputField(desc="Indicate whether the SQL query correctly answers the natural language query based on the given context", prefix="Yes/No:")

In [14]:
def correctness_metric(example, pred, trace=None):
    sql_prompt, sql_context, sql = example.sql_prompt, example.sql_context, pred.sql

    correctness = dspy.Predict(Correctness)

    with dspy.context(lm=evaluator_lm): 
        correct = correctness(
            sql_prompt=sql_prompt,
            sql_context=sql_context,
            sql=sql,
        )
    
    score = int(correct.correct=="Yes")

    if trace is not None:
        return score == 1

    return score

#### Evaluate single data point

In [15]:
_correctness = correctness_metric(
    example=sample,
    pred=result
)
print(f"Correct SQL query: {'Yes' if _correctness else 'No'}")

Correct SQL query: Yes


In [16]:
evaluator_lm.inspect_history(n=1)




Assess if the SQL query accurately answers the given natural language query based on the provided context.

---

Follow the following format.

Sql Prompt: Natural language query

Sql Context: Context for the query

Sql: SQL query

Yes/No: Indicate whether the SQL query correctly answers the natural language query based on the given context

---

Sql Prompt: What is the average monthly data usage for each mobile network operator?

Sql Context: CREATE TABLE mobile_operators (operator_id INT, operator_name VARCHAR(50)); CREATE TABLE mobile_plans (plan_id INT, plan_name VARCHAR(50), operator_id INT, data_limit INT); CREATE TABLE usage (usage_id INT, subscriber_id INT, plan_id INT, usage_amount INT, usage_date DATE);

Sql: CREATE VIEW avg_monthly_data_usage AS SELECT mobile_plans.operator_name, AVG(usage_amount / 12) as avg_monthly_data_usage FROM usage JOIN mobile_plans ON usage.plan_id = mobile_plans.plan_id GROUP BY mobile_plans.operator_name;

Yes/No: Yes





'\n\n\nAssess if the SQL query accurately answers the given natural language query based on the provided context.\n\n---\n\nFollow the following format.\n\nSql Prompt: Natural language query\n\nSql Context: Context for the query\n\nSql: SQL query\n\nYes/No: Indicate whether the SQL query correctly answers the natural language query based on the given context\n\n---\n\nSql Prompt: What is the average monthly data usage for each mobile network operator?\n\nSql Context: CREATE TABLE mobile_operators (operator_id INT, operator_name VARCHAR(50)); CREATE TABLE mobile_plans (plan_id INT, plan_name VARCHAR(50), operator_id INT, data_limit INT); CREATE TABLE usage (usage_id INT, subscriber_id INT, plan_id INT, usage_amount INT, usage_date DATE);\n\nSql: CREATE VIEW avg_monthly_data_usage AS SELECT mobile_plans.operator_name, AVG(usage_amount / 12) as avg_monthly_data_usage FROM usage JOIN mobile_plans ON usage.plan_id = mobile_plans.plan_id GROUP BY mobile_plans.operator_name;\n\nYes/No:\x1b[32

#### Evaluate entire dataset - GPT 3.5

<div style="background-color: #F0F0F0; padding: 10px; border-radius: 5px;"> <p style="color: #4B4B4B; font-size: 18px; font-weight: bold; margin: 0;"> 📊 Baseline Evaluation </p> <p style="color: #4B4B4B; font-size: 16px; margin: 5px 0 0;"> Without any optimization, <strong>Starling7B</strong> achieves an <strong>80% correctness in validation (25 samples)</strong> and <strong>70.07% correctness in test (75 samples).</strong> </p> </div>

In [None]:
print("GPT 3.5 Turbo - Validation Score: \n")
with dspy.context(lm=dspy.OpenAI(model="gpt-3.5-turbo-0125", api_base="https://api.openai.com/v1/", **generation_args)):
    evaluate = Evaluate(devset=valset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
    evaluate(generate_sql_query)

In [None]:
print("GPT 3.5 Turbo - Test Score: \n")
with dspy.context(lm=dspy.OpenAI(model="gpt-3.5-turbo-0125", api_base="https://api.openai.com/v1/", **generation_args)):
    evaluate = Evaluate(devset=testset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
    evaluate(generate_sql_query)

#### Evaluate entire dataset - llama3-8b

<div style="background-color: #FFCCCB; padding: 10px; border-radius: 5px;"> <p style="color: #8B0000; font-size: 18px; font-weight: bold; margin: 0;"> ⚠️ Evaluation Stage 1 </p> <p style="color: #8B0000; font-size: 16px; margin: 5px 0 0;"> Without any optimization, <strong>llama3-8b</strong> achieves an <strong>72% correctness in validation (25 samples)</strong> and <strong>50.67% correctness in test (75 samples).</strong> </p> </div>

In [17]:
print("llama3-8b - Validation Score: \n")
evaluate = Evaluate(devset=valset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
evaluate(generate_sql_query)

llama3-8b - Validation Score: 



Average Metric: 15 / 40  (37.5): 100%|██████████| 40/40 [03:33<00:00,  5.33s/it]


37.5

In [18]:
print("llama3-8b - Test Score: \n")
evaluate = Evaluate(devset=testset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
evaluate(generate_sql_query)

llama3-8b - Test Score: 



Average Metric: 28 / 80  (35.0): 100%|██████████| 80/80 [06:31<00:00,  4.90s/it]


35.0

### Optimize for Text2SQL

#### Create program

In [19]:
# Define the program ~ You can think of this a Pytorch model.
class TextToSqlProgram(dspy.Module):
    def __init__(self):
        super().__init__()
        self.program = dspy.ChainOfThought(signature=TextToSql)
    
    def forward(self, sql_prompt, sql_context):
        return self.program(
            sql_prompt=sql_prompt,
            sql_context=sql_context
        )

### FewShot

In [20]:
# Execute the optimizer -> this only adds few shots to the prompt
optimizer = LabeledFewShot(k=4)
optmized_program = optimizer.compile(student=TextToSqlProgram(), trainset=trainset)

In [21]:
optmized_program(sql_context=sample["sql_context"], sql_prompt=sample["sql_prompt"])

Prediction(
    rationale='Sql Prompt: What is the average monthly data usage for each mobile network operator?\nSql Context: CREATE TABLE mobile_operators (operator_id INT, operator_name VARCHAR(50)); CREATE TABLE mobile_plans (plan_id INT, plan_name VARCHAR(50), operator_id INT, data_limit INT); CREATE TABLE usage (usage_id INT, subscriber_id INT, plan_id INT, usage_amount INT, usage_date DATE);',
    sql='SELECT mo.operator_name, AVG(u.usage_amount) AS mobile? What is the average monthly data usage for each mobile network operator?\n\nSql Context: CREATE TABLE mobile_operators ('
)

#### What is happening inside?

In [22]:
lm.inspect_history(n=1)




Transform a natural language query into a SQL query.

---

Follow the following format.

Sql Prompt: Natural language query

Sql Context: Context for the query

Reasoning: Let's think step by step in order to ${produce the sql}. We ...

Sql: SQL query

---

Sql Prompt: How many solar power projects were completed in California and Texas in 2020 and 2021?
Sql Context: CREATE TABLE solar_projects (project_id INT, state VARCHAR(50), completion_year INT); INSERT INTO solar_projects (project_id, state, completion_year) VALUES (1, 'California', 2020), (2, 'Texas', 2021), (3, 'California', 2019), (4, 'Texas', 2020), (5, 'California', 2021), (6, 'Texas', 2019), (7, 'California', 2018), (8, 'Texas', 2018);
Sql: SELECT state, COUNT(*) FROM solar_projects WHERE completion_year IN (2020, 2021) AND state IN ('California', 'Texas') GROUP BY state;

---

Sql Prompt: What is the sum of lanthanum imports to Norway and Sweden for the years 2018 and 2019?
Sql Context: CREATE TABLE lanthanum_imports (y

"\n\n\nTransform a natural language query into a SQL query.\n\n---\n\nFollow the following format.\n\nSql Prompt: Natural language query\n\nSql Context: Context for the query\n\nReasoning: Let's think step by step in order to ${produce the sql}. We ...\n\nSql: SQL query\n\n---\n\nSql Prompt: How many solar power projects were completed in California and Texas in 2020 and 2021?\nSql Context: CREATE TABLE solar_projects (project_id INT, state VARCHAR(50), completion_year INT); INSERT INTO solar_projects (project_id, state, completion_year) VALUES (1, 'California', 2020), (2, 'Texas', 2021), (3, 'California', 2019), (4, 'Texas', 2020), (5, 'California', 2021), (6, 'Texas', 2019), (7, 'California', 2018), (8, 'Texas', 2018);\nSql: SELECT state, COUNT(*) FROM solar_projects WHERE completion_year IN (2020, 2021) AND state IN ('California', 'Texas') GROUP BY state;\n\n---\n\nSql Prompt: What is the sum of lanthanum imports to Norway and Sweden for the years 2018 and 2019?\nSql Context: CREATE

#### Evaluate the optimized program


<div style="background-color: #FFF8DC; padding: 10px; border-radius: 5px;"> <p style="color: #DAA520; font-size: 18px; font-weight: bold; margin: 0;"> 🌟 Evaluation Stage 2 </p> <p style="color: #DAA520; font-size: 16px; margin: 5px 0 0;"> With <strong>Few Shot</strong> optimization, <strong>llama3-8b</strong> achieves an <strong>64% correctness in validation (25 samples)</strong> and <strong>60% correctness in test (75 samples).</strong> </p> </div>

In [23]:
print("llama3-8b + FewShotOptimizer - Validation Score: \n")
evaluate = Evaluate(devset=valset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
evaluate(optmized_program)

llama3-8b + FewShotOptimizer - Validation Score: 



Average Metric: 27 / 40  (67.5): 100%|██████████| 40/40 [02:45<00:00,  4.13s/it]


67.5

In [24]:
print("llama3-8b + FewShotOptimizer - Test Score: \n")
evaluate = Evaluate(devset=testset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
evaluate(optmized_program)

llama3-8b + FewShotOptimizer - Test Score: 



Average Metric: 38 / 80  (47.5): 100%|██████████| 80/80 [05:49<00:00,  4.37s/it]


47.5

### BootstrapFewShotWithRandomSearch

[DSPy docs](https://dspy-docs.vercel.app/docs/building-blocks/optimizers) recommend that in a setup like the one with have at hand, with ~50 samples, the best option is to use `BootstrapFewShotWithRandomSearch`:

![image](assets/dspy.png)

In [25]:
optimizer2 = BootstrapFewShotWithRandomSearch(metric=correctness_metric, max_bootstrapped_demos=2, num_candidate_programs=8, num_threads=5)
optmized_program_2 = optimizer2.compile(student = TextToSqlProgram(), trainset=trainset, valset=valset)

Average Metric: 17 / 40  (42.5): 100%|██████████| 40/40 [03:30<00:00,  5.26s/it]
Average Metric: 25 / 40  (62.5): 100%|██████████| 40/40 [02:47<00:00,  4.18s/it]
  5%|▌         | 4/80 [01:01<19:33, 15.45s/it]
Average Metric: 25 / 40  (62.5): 100%|██████████| 40/40 [02:54<00:00,  4.35s/it]
  4%|▍         | 3/80 [00:40<17:19, 13.50s/it]
Average Metric: 23 / 40  (57.5): 100%|██████████| 40/40 [02:54<00:00,  4.35s/it]
  1%|▏         | 1/80 [00:09<13:09,  9.99s/it]
Average Metric: 29 / 40  (72.5): 100%|██████████| 40/40 [02:29<00:00,  3.75s/it]
  2%|▎         | 2/80 [00:26<17:18, 13.32s/it]
Average Metric: 26 / 40  (65.0): 100%|██████████| 40/40 [02:51<00:00,  4.29s/it]
  1%|▏         | 1/80 [00:11<14:48, 11.25s/it]
Average Metric: 23 / 40  (57.5): 100%|██████████| 40/40 [03:12<00:00,  4.82s/it]
  2%|▎         | 2/80 [00:23<15:21, 11.82s/it]
Average Metric: 24 / 40  (60.0): 100%|██████████| 40/40 [03:00<00:00,  4.51s/it]
  2%|▎         | 2/80 [00:21<14:09, 10.89s/it]
Average Metric: 24 / 40

In [26]:
optmized_program_2(sql_context=sample["sql_context"], sql_prompt=sample["sql_prompt"])

Prediction(
    rationale='Sql: SELECT m.operator_name, AVG(u.usage_amount) as avg_monthly_data_usage FROM mobile_operators m INNER JOIN mobile_plans mp ON m.operator_id = mp.operator_id INNER JOIN usage u ON mp.plan_id = u.plan_id GROUP BY m.operator_name;',
    sql='SELECT m.operator_name, AVG(u.usage_amount) as avg_monthly_data_usage FROM mobile_operators m INNER JOIN mobile_plans mp ON m.operator_id = mp.operator_id INNER JOIN usage u ON mp.plan_id = u.plan_id GROUP BY m.operator_name;'
)

In [27]:
lm.inspect_history(n=1)




Transform a natural language query into a SQL query.

---

Sql Prompt: What is the total number of labor rights advocacy events for each region, by region name?
Sql Context: CREATE TABLE Region (Id INT, Name VARCHAR(50)); INSERT INTO Region (Id, Name) VALUES (1, 'Region A'), (2, 'Region B'), (3, 'Region C'); CREATE TABLE AdvocacyEvents (Id INT, RegionId INT, EventCount INT); INSERT INTO AdvocacyEvents (Id, RegionId, EventCount) VALUES (1, 1, 50), (2, 1, 30), (3, 2, 70), (4, 2, 80), (5, 3, 60), (6, 3, 40);
Sql: SELECT R.Name, SUM(A.EventCount) as TotalEvents FROM Region R JOIN AdvocacyEvents A ON R.Id = A.RegionId GROUP BY R.Name;

Sql Prompt: Find and delete duplicate records in the resource_depletion table
Sql Context: CREATE TABLE resource_depletion (id INT, resource VARCHAR(255), depletion_rate DECIMAL(10,2));
Sql: DELETE t1 FROM resource_depletion t1 INNER JOIN (SELECT id, resource, depletion_rate, COUNT(*) FROM resource_depletion GROUP BY resource, depletion_rate HAVING COUNT(*

'\n\n\nTransform a natural language query into a SQL query.\n\n---\n\nSql Prompt: What is the total number of labor rights advocacy events for each region, by region name?\nSql Context: CREATE TABLE Region (Id INT, Name VARCHAR(50)); INSERT INTO Region (Id, Name) VALUES (1, \'Region A\'), (2, \'Region B\'), (3, \'Region C\'); CREATE TABLE AdvocacyEvents (Id INT, RegionId INT, EventCount INT); INSERT INTO AdvocacyEvents (Id, RegionId, EventCount) VALUES (1, 1, 50), (2, 1, 30), (3, 2, 70), (4, 2, 80), (5, 3, 60), (6, 3, 40);\nSql: SELECT R.Name, SUM(A.EventCount) as TotalEvents FROM Region R JOIN AdvocacyEvents A ON R.Id = A.RegionId GROUP BY R.Name;\n\nSql Prompt: Find and delete duplicate records in the resource_depletion table\nSql Context: CREATE TABLE resource_depletion (id INT, resource VARCHAR(255), depletion_rate DECIMAL(10,2));\nSql: DELETE t1 FROM resource_depletion t1 INNER JOIN (SELECT id, resource, depletion_rate, COUNT(*) FROM resource_depletion GROUP BY resource, depletion

#### Evaluate the optimized program

<div style="background-color: #E0F8E0; padding: 10px; border-radius: 5px;"> <p style="color: #006400; font-size: 18px; font-weight: bold; margin: 0;"> ✅ Evaluation Stage 3 </p> <p style="color: #006400; font-size: 16px; margin: 5px 0 0;"> With <strong>BootstrapFewShotWithRandomSearch</strong> optimization, <strong>llama3-8b</strong> achieves an <strong>80% correctness in validation (25 samples)</strong> and <strong>68% correctness in test (75 samples).</strong> </p> </div>

In [28]:
print("llama3-8b + BootstrapFewShotWithRandomSearch - Validation Score: \n")
evaluate = Evaluate(devset=valset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
evaluate(optmized_program_2)

llama3-8b + BootstrapFewShotWithRandomSearch - Validation Score: 



Average Metric: 29 / 40  (72.5): 100%|██████████| 40/40 [02:11<00:00,  3.30s/it]


72.5

In [29]:
print("llama3-8b + BootstrapFewShotWithRandomSearch - Test Score: \n")
evaluate = Evaluate(devset=testset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
evaluate(optmized_program_2)

llama3-8b + BootstrapFewShotWithRandomSearch - Test Score: 



Average Metric: 54 / 80  (67.5): 100%|██████████| 80/80 [04:03<00:00,  3.04s/it]


67.5