#### Libraries

In [1]:
import dspy
import random
from dotenv import load_dotenv
from dspy.datasets import DataLoader
from dspy.evaluate import Evaluate
from dspy.teleprompt import BootstrapFewShotWithRandomSearch, LabeledFewShot

#from src.starling import StarlingLM # <- Custom Local Model Client for llama3-8b

_ = load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


#### LLM

The model that will be used in this notebook is [llama3-8b](https://huggingface.co/Nexusflow/Starling-LM-7B-beta), and the evaluation model will be [GPT-4 Turbo](https://openai.com/gpt-4).

In [2]:
# Share generation args between models
generation_args = {
    "temperature":0,
    "max_tokens":500,
    "stop":"\n\n",
    "model_type":"chat",
    "n": 1
}
# Model specific args
model_info = {
    "gpt-4": {"model": "gpt-4-0125-preview", "api_base": "https://api.openai.com/v1/"},
    "starling": {"model": "Nexusflow/Starling-LM-7B-beta"}
}

In [3]:
# Set up the models
# lm = StarlingLM(**model_info["starling"], **generation_args)
# evaluator_lm = dspy.OpenAI(**model_info["gpt-4"], **generation_args)

lm = dspy.OllamaLocal(model='llama3-8b-instruct-fp16-lora-60-steps:latest', base_url='http://localhost:11435')
# evaluator_lm = dspy.OpenAI(**model_info["gpt-4"], **generation_args)
evaluator_lm = dspy.OllamaLocal(model='llama3:70b', base_url='http://localhost:11434')

dspy.configure(lm=lm)

In [4]:
# Testing inference of Starling
lm("What is the capital of Colombia?")

['The capital of Colombia is Bogotá.<|eot_id|>']

#### Load dataset

The dataset that will be used in this notebook is [gretelai/synthetic_text_to_sql](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql)

In [5]:
# Define random seed
random.seed(1399)

In [6]:
# Load dataset
dl = DataLoader()
trainset = dl.from_huggingface(
    dataset_name="gretelai/synthetic_text_to_sql", # Dataset name from Huggingface
    fields=("sql_prompt", "sql_context", "sql"), # Fields needed
    input_keys=("sql_prompt", "sql_context"), # What our model expects to recieve to generate an output
    split="train"
)

testset = dl.from_huggingface(
    dataset_name="gretelai/synthetic_text_to_sql", # Dataset name from Huggingface
    fields=("sql_prompt", "sql_context", "sql"), # Fields needed
    input_keys=("sql_prompt", "sql_context"), # What our model expects to recieve to generate an output
    split="test"
)


trainset = dl.sample(dataset=trainset, n=100)
testset = dl.sample(dataset=testset, n=75)
finetuneset = dl.sample(dataset=trainset, n=100)

_trainval = dl.train_test_split(dataset=trainset, test_size=0.25, random_state=1399) # 25% of training data for validation
trainset, valset = _trainval["train"], _trainval["test"]

len(trainset), len(valset), len(testset)

(75, 25, 75)

In [7]:
# Verify an example of the dataset
sample = dl.sample(dataset=trainset, n=1)[0]
for k, v in sample.items():
    print(f"\n{k.upper()}:\n")
    print(v)


SQL_PROMPT:

List the top 5 countries with the highest number of satellites in orbit as of 2022-01-01, ordered by the number of satellites in descending order.

SQL_CONTEXT:

CREATE TABLE countries(id INT, name VARCHAR(255), population INT, satellites_in_orbit INT, last_census_date DATE);

SQL:

SELECT name, satellites_in_orbit FROM countries WHERE last_census_date <= '2022-01-01' GROUP BY name ORDER BY satellites_in_orbit DESC LIMIT 5;


#### Signature (Input/Output)

In [8]:
class TextToSql(dspy.Signature):
    """Transform a natural language query into a SQL query."""

    sql_prompt = dspy.InputField(desc="Natural language query")
    sql_context = dspy.InputField(desc="Context for the query")
    sql = dspy.OutputField(desc="SQL query")

### Inference

#### Baseline Inference

In [9]:
generate_sql_query = dspy.Predict(signature=TextToSql)

result = generate_sql_query(
    sql_prompt=sample["sql_prompt"],
    sql_context=sample["sql_context"]
)

for k, v in result.items():
    print(f"\n{k.upper()}:\n")
    print(v)


SQL:

Here is the SQL query:

INSERT INTO countries (id, name, population, satellites_in_orbit, last_census_date)
VALUES (1, 'United States', 331002651, 1955, '2022-01-01');
INSERT INTO countries (id, name, population, satellites_in_orbit, last_census_date)
VALUES (2, 'China', 1439323776, 1550, '2022-01-01');
INSERT INTO countries (id, name, population, satellites_in_orbit, last_census_date)
VALUES (3, 'Russia', 145934027, 1405, '2022-01-01');
INSERT INTO countries (id, name, population,


#### ChainOfThought Inference

In [10]:
generate_sql_query = dspy.ChainOfThought(signature=TextToSql)

result = generate_sql_query(
    sql_prompt=sample["sql_prompt"],
    sql_context=sample["sql_context"]
)

for k, v in result.items():
    print(f"\n{k.upper()}:\n")
    print(v)


RATIONALE:

Sql Prompt: List the top 5 countries with the highest number of satellites in orbit as of 2022-01-01, ordered by the number of satellites in descending order.
Sql Context: CREATE TABLE countries(id INT, name VARCHAR(255), population INT, satellites_in_orbit INT, last_census_date DATE);
Reasoning: Let's think step by step in order to produce the sql. We need to filter the data for the specific date (2022-01-01) and then sort it by the number of satellites in orbit in descending order.

SQL:

SELECT name, satellites_in_orbit FROM countries WHERE last_census_date = '2022-01-01' ORDER BY satellites_in_orbit DESC LIMIT 5


### Metric of evaluation

#### Metric definition

In [11]:
class Correctness(dspy.Signature):
    """Assess if the SQL query accurately answers the given natural language query based on the provided context."""

    sql_prompt = dspy.InputField(desc="Natural language query ")
    sql_context = dspy.InputField(desc="Context for the query")
    sql = dspy.InputField(desc="SQL query")
    correct = dspy.OutputField(desc="Indicate whether the SQL query correctly answers the natural language query based on the given context", prefix="Yes/No:")

In [12]:
def correctness_metric(example, pred, trace=None):
    sql_prompt, sql_context, sql = example.sql_prompt, example.sql_context, pred.sql

    correctness = dspy.Predict(Correctness)

    with dspy.context(lm=evaluator_lm): 
        correct = correctness(
            sql_prompt=sql_prompt,
            sql_context=sql_context,
            sql=sql,
        )
    
    score = int(correct.correct=="Yes")

    if trace is not None:
        return score == 1

    return score

#### Evaluate single data point

In [13]:
_correctness = correctness_metric(
    example=sample,
    pred=result
)
print(f"Correct SQL query: {'Yes' if _correctness else 'No'}")

Correct SQL query: Yes


In [14]:
evaluator_lm.inspect_history(n=1)




Assess if the SQL query accurately answers the given natural language query based on the provided context.

---

Follow the following format.

Sql Prompt: Natural language query

Sql Context: Context for the query

Sql: SQL query

Yes/No: Indicate whether the SQL query correctly answers the natural language query based on the given context

---

Sql Prompt: List the top 5 countries with the highest number of satellites in orbit as of 2022-01-01, ordered by the number of satellites in descending order.

Sql Context: CREATE TABLE countries(id INT, name VARCHAR(255), population INT, satellites_in_orbit INT, last_census_date DATE);

Sql: SELECT name, satellites_in_orbit FROM countries WHERE last_census_date = '2022-01-01' ORDER BY satellites_in_orbit DESC LIMIT 5

Yes/No: Yes





"\n\n\nAssess if the SQL query accurately answers the given natural language query based on the provided context.\n\n---\n\nFollow the following format.\n\nSql Prompt: Natural language query\n\nSql Context: Context for the query\n\nSql: SQL query\n\nYes/No: Indicate whether the SQL query correctly answers the natural language query based on the given context\n\n---\n\nSql Prompt: List the top 5 countries with the highest number of satellites in orbit as of 2022-01-01, ordered by the number of satellites in descending order.\n\nSql Context: CREATE TABLE countries(id INT, name VARCHAR(255), population INT, satellites_in_orbit INT, last_census_date DATE);\n\nSql: SELECT name, satellites_in_orbit FROM countries WHERE last_census_date = '2022-01-01' ORDER BY satellites_in_orbit DESC LIMIT 5\n\nYes/No:\x1b[32m Yes\x1b[0m\n\n\n"

#### Evaluate entire dataset - GPT 3.5

<div style="background-color: #F0F0F0; padding: 10px; border-radius: 5px;"> <p style="color: #4B4B4B; font-size: 18px; font-weight: bold; margin: 0;"> 📊 Baseline Evaluation </p> <p style="color: #4B4B4B; font-size: 16px; margin: 5px 0 0;"> Without any optimization, <strong>Starling7B</strong> achieves an <strong>80% correctness in validation (25 samples)</strong> and <strong>70.07% correctness in test (75 samples).</strong> </p> </div>

In [16]:
print("GPT 3.5 Turbo - Validation Score: \n")
with dspy.context(lm=dspy.OpenAI(model="gpt-3.5-turbo-0125", api_base="https://api.openai.com/v1/", **generation_args)):
    evaluate = Evaluate(devset=valset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
    evaluate(generate_sql_query)

GPT 3.5 Turbo - Validation Score: 



  0%|          | 0/25 [00:00<?, ?it/s]

Average Metric: 20 / 25  (80.0): 100%|██████████| 25/25 [00:09<00:00,  2.77it/s]

Average Metric: 20 / 25  (80.0%)



  df = df.applymap(truncate_cell)


In [17]:
print("GPT 3.5 Turbo - Test Score: \n")
with dspy.context(lm=dspy.OpenAI(model="gpt-3.5-turbo-0125", api_base="https://api.openai.com/v1/", **generation_args)):
    evaluate = Evaluate(devset=testset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
    evaluate(generate_sql_query)

GPT 3.5 Turbo - Test Score: 



Average Metric: 53 / 75  (70.7): 100%|██████████| 75/75 [00:14<00:00,  5.06it/s] 

Average Metric: 53 / 75  (70.7%)



  df = df.applymap(truncate_cell)


#### Evaluate entire dataset - llama3-8b

<div style="background-color: #FFCCCB; padding: 10px; border-radius: 5px;"> <p style="color: #8B0000; font-size: 18px; font-weight: bold; margin: 0;"> ⚠️ Evaluation Stage 1 </p> <p style="color: #8B0000; font-size: 16px; margin: 5px 0 0;"> Without any optimization, <strong>llama3-8b</strong> achieves an <strong>72% correctness in validation (25 samples)</strong> and <strong>50.67% correctness in test (75 samples).</strong> </p> </div>

In [15]:
print("llama3-8b - Validation Score: \n")
evaluate = Evaluate(devset=valset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
evaluate(generate_sql_query)

llama3-8b - Validation Score: 



Average Metric: 16 / 25  (64.0): 100%|██████████| 25/25 [01:51<00:00,  4.47s/it]


64.0

In [16]:
print("llama3-8b - Test Score: \n")
evaluate = Evaluate(devset=testset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
evaluate(generate_sql_query)

llama3-8b - Test Score: 



Average Metric: 45 / 75  (60.0): 100%|██████████| 75/75 [05:02<00:00,  4.04s/it]


60.0

### Optimize for Text2SQL

#### Create program

In [17]:
# Define the program ~ You can think of this a Pytorch model.
class TextToSqlProgram(dspy.Module):
    def __init__(self):
        super().__init__()
        self.program = dspy.ChainOfThought(signature=TextToSql)
    
    def forward(self, sql_prompt, sql_context):
        return self.program(
            sql_prompt=sql_prompt,
            sql_context=sql_context
        )

### FewShot

In [18]:
# Execute the optimizer -> this only adds few shots to the prompt
optimizer = LabeledFewShot(k=4)
optmized_program = optimizer.compile(student=TextToSqlProgram(), trainset=trainset)

In [19]:
optmized_program(sql_context=sample["sql_context"], sql_prompt=sample["sql_prompt"])

Prediction(
    rationale='Here is the SQL query:',
    sql="SELECT name, satellites_in_orbit FROM countries WHERE last_census_date = '2022-01-01' ORDER BY satellites_in_orbit DESC LIMIT 5;<|eot_id|>"
)

#### What is happening inside?

In [20]:
lm.inspect_history(n=1)




Transform a natural language query into a SQL query.

---

Follow the following format.

Sql Prompt: Natural language query

Sql Context: Context for the query

Reasoning: Let's think step by step in order to ${produce the sql}. We ...

Sql: SQL query

---

Sql Prompt: What is the total number of electric vehicles sold by manufacturer 'XYZ'?
Sql Context: CREATE TABLE sales_data (manufacturer VARCHAR(10), vehicle_type VARCHAR(10), quantity INT);
Sql: SELECT manufacturer, SUM(quantity) FROM sales_data WHERE vehicle_type = 'Electric' AND manufacturer = 'XYZ' GROUP BY manufacturer;

---

Sql Prompt: What is the total number of amphibians in the 'animals' table with a population size greater than 1000?
Sql Context: CREATE TABLE animals (id INT, name VARCHAR(50), species VARCHAR(50), population_size INT); INSERT INTO animals (id, name, species, population_size) VALUES (1, 'Frog', 'Anura', 1200);
Sql: SELECT COUNT(*) FROM animals WHERE species = 'Anura' AND population_size > 1000;

---

Sq

"\n\n\nTransform a natural language query into a SQL query.\n\n---\n\nFollow the following format.\n\nSql Prompt: Natural language query\n\nSql Context: Context for the query\n\nReasoning: Let's think step by step in order to ${produce the sql}. We ...\n\nSql: SQL query\n\n---\n\nSql Prompt: What is the total number of electric vehicles sold by manufacturer 'XYZ'?\nSql Context: CREATE TABLE sales_data (manufacturer VARCHAR(10), vehicle_type VARCHAR(10), quantity INT);\nSql: SELECT manufacturer, SUM(quantity) FROM sales_data WHERE vehicle_type = 'Electric' AND manufacturer = 'XYZ' GROUP BY manufacturer;\n\n---\n\nSql Prompt: What is the total number of amphibians in the 'animals' table with a population size greater than 1000?\nSql Context: CREATE TABLE animals (id INT, name VARCHAR(50), species VARCHAR(50), population_size INT); INSERT INTO animals (id, name, species, population_size) VALUES (1, 'Frog', 'Anura', 1200);\nSql: SELECT COUNT(*) FROM animals WHERE species = 'Anura' AND popu

#### Evaluate the optimized program


<div style="background-color: #FFF8DC; padding: 10px; border-radius: 5px;"> <p style="color: #DAA520; font-size: 18px; font-weight: bold; margin: 0;"> 🌟 Evaluation Stage 2 </p> <p style="color: #DAA520; font-size: 16px; margin: 5px 0 0;"> With <strong>Few Shot</strong> optimization, <strong>llama3-8b</strong> achieves an <strong>64% correctness in validation (25 samples)</strong> and <strong>60% correctness in test (75 samples).</strong> </p> </div>

In [21]:
print("llama3-8b + FewShotOptimizer - Validation Score: \n")
evaluate = Evaluate(devset=valset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
evaluate(optmized_program)

llama3-8b + FewShotOptimizer - Validation Score: 



Average Metric: 16 / 25  (64.0): 100%|██████████| 25/25 [01:27<00:00,  3.49s/it]


64.0

In [22]:
print("llama3-8b + FewShotOptimizer - Test Score: \n")
evaluate = Evaluate(devset=testset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
evaluate(optmized_program)

llama3-8b + FewShotOptimizer - Test Score: 



Average Metric: 51 / 75  (68.0): 100%|██████████| 75/75 [03:36<00:00,  2.89s/it]


68.0

### BootstrapFewShotWithRandomSearch

[DSPy docs](https://dspy-docs.vercel.app/docs/building-blocks/optimizers) recommend that in a setup like the one with have at hand, with ~50 samples, the best option is to use `BootstrapFewShotWithRandomSearch`:

![image](assets/dspy.png)

In [23]:
optimizer2 = BootstrapFewShotWithRandomSearch(metric=correctness_metric, max_bootstrapped_demos=2, num_candidate_programs=8, num_threads=5)
optmized_program_2 = optimizer2.compile(student = TextToSqlProgram(), trainset=trainset, valset=valset)

  0%|          | 0/25 [00:00<?, ?it/s]

Average Metric: 16 / 25  (64.0): 100%|██████████| 25/25 [01:52<00:00,  4.52s/it]
Average Metric: 18 / 25  (72.0): 100%|██████████| 25/25 [01:31<00:00,  3.65s/it]
  5%|▌         | 4/75 [00:53<15:52, 13.42s/it]
Average Metric: 14 / 25  (56.0): 100%|██████████| 25/25 [01:50<00:00,  4.41s/it]
  4%|▍         | 3/75 [00:33<13:12, 11.00s/it]
Average Metric: 14 / 25  (56.0): 100%|██████████| 25/25 [01:46<00:00,  4.26s/it]
  1%|▏         | 1/75 [00:09<11:37,  9.43s/it]
Average Metric: 14 / 25  (56.0): 100%|██████████| 25/25 [01:42<00:00,  4.11s/it]
  1%|▏         | 1/75 [00:09<11:33,  9.37s/it]
Average Metric: 16 / 25  (64.0): 100%|██████████| 25/25 [01:37<00:00,  3.91s/it]
  1%|▏         | 1/75 [00:07<09:24,  7.63s/it]
Average Metric: 10 / 25  (40.0): 100%|██████████| 25/25 [02:03<00:00,  4.93s/it]
  4%|▍         | 3/75 [00:49<19:42, 16.43s/it]
Average Metric: 17 / 25  (68.0): 100%|██████████| 25/25 [01:31<00:00,  3.67s/it]
  3%|▎         | 2/75 [00:19<11:41,  9.61s/it]
Average Metric: 15 / 25

In [24]:
optmized_program_2(sql_context=sample["sql_context"], sql_prompt=sample["sql_prompt"])

Prediction(
    rationale='Here is the SQL query:',
    sql="SELECT c.name, c.satellites_in_orbit FROM countries c WHERE c.last_census_date <= '20-01-01' ORDER BY c.satellites_in_orbit DESC LIMIT 5;\n\nLet me explain the query.\n\n### Sql Context: CREATE TABLE countries(id INT, name VARCHAR(255), population INT, satellites_in_orbit INT, last_census_date DATE);\n\nThis creates a table called `countries` with five columns: `id`, `name`, `population`, `satellites_in_orbit`, and `last_census_date`.\n\n### Sql Context: INSERT INTO countries (id, name, population, satellites_in_orbit, last_census_date) VALUES (1, 'USA',"
)

In [25]:
lm.inspect_history(n=1)




Transform a natural language query into a SQL query.

---

Follow the following format.

Sql Prompt: Natural language query

Sql Context: Context for the query

Reasoning: Let's think step by step in order to ${produce the sql}. We ...

Sql: SQL query

---

Sql Prompt: What is the total number of electric vehicles sold by manufacturer 'XYZ'?
Sql Context: CREATE TABLE sales_data (manufacturer VARCHAR(10), vehicle_type VARCHAR(10), quantity INT);
Sql: SELECT manufacturer, SUM(quantity) FROM sales_data WHERE vehicle_type = 'Electric' AND manufacturer = 'XYZ' GROUP BY manufacturer;

---

Sql Prompt: What is the total number of amphibians in the 'animals' table with a population size greater than 1000?
Sql Context: CREATE TABLE animals (id INT, name VARCHAR(50), species VARCHAR(50), population_size INT); INSERT INTO animals (id, name, species, population_size) VALUES (1, 'Frog', 'Anura', 1200);
Sql: SELECT COUNT(*) FROM animals WHERE species = 'Anura' AND population_size > 1000;

---

Sq

"\n\n\nTransform a natural language query into a SQL query.\n\n---\n\nFollow the following format.\n\nSql Prompt: Natural language query\n\nSql Context: Context for the query\n\nReasoning: Let's think step by step in order to ${produce the sql}. We ...\n\nSql: SQL query\n\n---\n\nSql Prompt: What is the total number of electric vehicles sold by manufacturer 'XYZ'?\nSql Context: CREATE TABLE sales_data (manufacturer VARCHAR(10), vehicle_type VARCHAR(10), quantity INT);\nSql: SELECT manufacturer, SUM(quantity) FROM sales_data WHERE vehicle_type = 'Electric' AND manufacturer = 'XYZ' GROUP BY manufacturer;\n\n---\n\nSql Prompt: What is the total number of amphibians in the 'animals' table with a population size greater than 1000?\nSql Context: CREATE TABLE animals (id INT, name VARCHAR(50), species VARCHAR(50), population_size INT); INSERT INTO animals (id, name, species, population_size) VALUES (1, 'Frog', 'Anura', 1200);\nSql: SELECT COUNT(*) FROM animals WHERE species = 'Anura' AND popu

#### Evaluate the optimized program

<div style="background-color: #E0F8E0; padding: 10px; border-radius: 5px;"> <p style="color: #006400; font-size: 18px; font-weight: bold; margin: 0;"> ✅ Evaluation Stage 3 </p> <p style="color: #006400; font-size: 16px; margin: 5px 0 0;"> With <strong>BootstrapFewShotWithRandomSearch</strong> optimization, <strong>llama3-8b</strong> achieves an <strong>80% correctness in validation (25 samples)</strong> and <strong>68% correctness in test (75 samples).</strong> </p> </div>

In [26]:
print("llama3-8b + BootstrapFewShotWithRandomSearch - Validation Score: \n")
evaluate = Evaluate(devset=valset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
evaluate(optmized_program_2)

llama3-8b + BootstrapFewShotWithRandomSearch - Validation Score: 



Average Metric: 17 / 25  (68.0): 100%|██████████| 25/25 [01:22<00:00,  3.32s/it]


68.0

In [27]:
print("llama3-8b + BootstrapFewShotWithRandomSearch - Test Score: \n")
evaluate = Evaluate(devset=testset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)
evaluate(optmized_program_2)

llama3-8b + BootstrapFewShotWithRandomSearch - Test Score: 



Average Metric: 45 / 75  (60.0): 100%|██████████| 75/75 [04:30<00:00,  3.60s/it]


60.0