# example of BIRD setup and inference (WIP, no class-based inference!)

In [1]:
import sys

# Make the in-repo `text2sql` package importable from the notebook's directory.
# Guarded so that re-running this cell does not keep appending duplicate entries.
if "../src" not in sys.path:
    sys.path.append("../src")

# smoke test: confirm the package resolves by printing its hello message
from text2sql import hello
print(hello.message)

hello, world!


In [2]:
import json
import os

import numpy as np

## load data

In [3]:
# Root directory of the BIRD dataset. Set the BIRD_DATASET_PATH environment variable to
# point at your own copy; the fallback is the location this notebook was originally run with.
bird_dataset_path = os.environ.get(
    "BIRD_DATASET_PATH",
    "/home/derek/PythonProjects/gena/data/text2sql_datasets/bird",
)
# training data and databases
bird_train_database_path = os.path.join(bird_dataset_path, "train", "train_databases")
bird_train_json_path = os.path.join(bird_dataset_path, "train", "train.json")
# dev data and databases
bird_dev_database_path = os.path.join(bird_dataset_path, "dev_20240627", "dev_databases")
bird_dev_json_path = os.path.join(bird_dataset_path, "dev_20240627", "dev.json")

### create dataset loaders for sqlite files

In [4]:
from text2sql.data import SqliteDataset

In [5]:
# create dataset loaders
# one SqliteDataset per split, each constructed from that split's database directory
train_dataset = SqliteDataset(bird_train_database_path)
# report how many databases the loader exposes via get_databases()
print(f"loaded training dataset containing {len(train_dataset.get_databases())} databases")
dev_dataset = SqliteDataset(bird_dev_database_path)
print(f"loaded dev dataset containing {len(dev_dataset.get_databases())} databases")

loaded training dataset containing 69 databases
loaded dev dataset containing 11 databases


In [6]:
# demonstrate querying a database - query to get list of tables
# arguments are (database name, SQL string); sqlite_master is SQLite's built-in schema table.
# The call is the cell's last expression, so its return value is displayed as the cell output.
train_dataset.query_database("beer_factory", "SELECT name FROM sqlite_master WHERE type='table';")

[{'name': 'customers'},
 {'name': 'geolocation'},
 {'name': 'location'},
 {'name': 'rootbeerbrand'},
 {'name': 'rootbeer'},
 {'name': 'rootbeerreview'},
 {'name': 'transaction'}]

### load json data

In [7]:
# load train data
# parse the whole train split JSON file into memory
with open(bird_train_json_path, "r") as f:
    bird_train_json = json.load(f)
print(f"loaded {len(bird_train_json)} training examples from {bird_train_json_path}")
# peek at the field names of the first record
print(f"each sample includes keys: {list(bird_train_json[0].keys())}")

loaded 9428 training examples from /home/derek/PythonProjects/gena/data/text2sql_datasets/bird/train/train.json
each sample includes keys: ['db_id', 'question', 'evidence', 'SQL']


In [8]:
# display the first training record in full
bird_train_json[0]

{'db_id': 'movie_platform',
 'question': 'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.',
 'evidence': 'released in the year 1945 refers to movie_release_year = 1945;',
 'SQL': 'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1'}

In [9]:
# load dev data
# parse the whole dev split JSON file into memory
with open(bird_dev_json_path, "r") as f:
    bird_dev_json = json.load(f)
print(f"loaded {len(bird_dev_json)} dev examples from {bird_dev_json_path}")
# peek at the field names of the first record
print(f"each sample includes keys: {list(bird_dev_json[0].keys())}")

loaded 1534 dev examples from /home/derek/PythonProjects/gena/data/text2sql_datasets/bird/dev_20240627/dev.json
each sample includes keys: ['question_id', 'db_id', 'question', 'evidence', 'SQL', 'difficulty']


In [10]:
# display the first dev record in full
bird_dev_json[0]

{'question_id': 0,
 'db_id': 'california_schools',
 'question': 'What is the highest eligible free rate for K-12 students in the schools in Alameda County?',
 'evidence': 'Eligible free rate for K-12 = `Free Meal Count (K-12)` / `Enrollment (K-12)`',
 'SQL': "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1",
 'difficulty': 'simple'}

## embed the training data for retrieval

In [11]:
from text2sql.engine.embeddings import SentenceTransformerEmbedder

  from tqdm.autonotebook import tqdm, trange


In [12]:
# create sentence-transformers embedder with LaBSE model
# NOTE(review): model_path looks like a Hugging Face model id, so the weights are presumably
# downloaded on first use - confirm against SentenceTransformerEmbedder
embedder = SentenceTransformerEmbedder(
    model_path="sentence-transformers/LaBSE"
)



In [13]:
# embed the training questions, caching the embeddings on disk so re-runs skip the model pass
train_queries = [example["question"] for example in bird_train_json]
train_embeddings_file = "bird_query_labse_embeddings.npy"
if not os.path.isfile(train_embeddings_file):
    print(f"generating train embeddings and saving to '{train_embeddings_file}'")
    train_embeddings = embedder.embed(train_queries, verbose=True)
    # write to the same variable that is checked and loaded, so changing the filename
    # in one place cannot leave the cache check and the save pointing at different files
    np.save(train_embeddings_file, train_embeddings)
else:
    print(f"loading train embeddings from existing file '{train_embeddings_file}'")
    train_embeddings = np.load(train_embeddings_file)

loading train embeddings from existing file 'bird_query_labse_embeddings.npy'


## create local retriever

We could also use Weaviate as the retrieval backend, but for ease of setup we use the local retriever here.

In [14]:
from text2sql.engine.retrieval import LocalRetriever

In [15]:
# build the retriever from the training embeddings and their source records
# NOTE(review): presumably embeddings[i] must correspond to data[i]; both derive from
# bird_train_json in the same order - confirm against LocalRetriever
local_retriever = LocalRetriever(embeddings=train_embeddings, data=bird_train_json)

In [16]:
# search an (existing) query as a sanity test; top result should be the query itself
# (this string is the "question" of the first training record)
search_query = "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity."
# embed the query, then fetch the 3 nearest training records
search_results = local_retriever.query(embedder.embed(search_query), top_k=3)
# pretty-print the hits as JSON
print(json.dumps(search_results, indent=2))

[
  {
    "id": 0,
    "distance": 0.0,
    "data": {
      "db_id": "movie_platform",
      "question": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
      "evidence": "released in the year 1945 refers to movie_release_year = 1945;",
      "SQL": "SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1"
    }
  },
  {
    "id": 537,
    "distance": 0.2576185464859009,
    "data": {
      "db_id": "movies_4",
      "question": "List the movies released in 1945.",
      "evidence": "List the movies refers to title; released in 1945 refers to release_date LIKE '1945%'",
      "SQL": "SELECT title FROM movie WHERE CAST(STRFTIME('%Y', release_date) AS INT) = 1945"
    }
  },
  {
    "id": 4709,
    "distance": 0.40609586238861084,
    "data": {
      "db_id": "disney",
      "question": "List the titles of movies directed by Jack Kinney that were released before 1947.",
      "evidence

## create prompt formatter and LLM generator

In [50]:
from dotenv import load_dotenv
# load variables from a .env file into the process environment; the AZURE_OPENAI_* settings
# used to configure the generator are read from os.environ
load_dotenv()
from text2sql.engine.prompts import BasicFewShotPromptFormatter
from text2sql.engine.generation import AzureGenerator, BedrockGenerator

In [51]:
# prompt formatter constructed with no arguments (its defaults)
formatter = BasicFewShotPromptFormatter()

In [52]:
# alternative backend, currently disabled: a Bedrock-hosted model
# model = "meta.llama3-1-8b-instruct-v1:0"
# generator = BedrockGenerator(
#     region_name="us-west-2",
#     model=model,
# )

# active backend: Azure OpenAI, configured entirely from environment variables.
# os.environ.get returns None for any variable that is unset, so a missing setting
# is passed through to AzureGenerator as None rather than raising here.
model = os.environ.get("AZURE_OPENAI_GEN_MODEL")
generator = AzureGenerator(
    api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
    azure_endpoint=os.environ.get("AZURE_OPENAI_API_ENDPOINT"),
    api_version=os.environ.get("AZURE_OPENAI_API_VERSION"),
    model=model,
)

# echo the configured model name so the run records which model was used
print(f"using '{model}'")

using 'gena-4o'


## predict

In [65]:
# load one training sample
training_sample = bird_train_json[0]
# pull out the fields used below: target database, natural-language question, gold SQL
sample_db_name = training_sample["db_id"]
sample_query = training_sample["question"]
sample_sql = training_sample["SQL"]
print(sample_db_name)
print(sample_query)
print(sample_sql)

movie_platform
Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.
SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1


In [66]:
# create chat messages for LLM input
system_message = "Write sqlite SQL query only and with no explanation. Some example questions and corresponding SQL queries are provided based on similar problems."
schema_description = train_dataset.describe_database_schema(sample_db_name, mode="datagrip")  # mode="basic_types_relations")
# The sample comes from the retrieval index itself, so its own record (and gold SQL) must not
# leak into the few-shot examples. Filter on the question text rather than assuming the
# self-match is always ranked first; this also drops exact-duplicate questions.
# Fetch one extra hit so that 3 examples remain after filtering, then cap at 3.
retrieved_examples = local_retriever.query(embedder.embed(sample_query), top_k=4)
few_shot_examples = [
    example for example in retrieved_examples if example["data"]["question"] != sample_query
][:3]
# create chat messages
messages = formatter.generate_messages(
    system_message=system_message,
    schema_description=schema_description,
    query=sample_query,
    few_shot_examples=few_shot_examples,
    few_shot_query_key="question",
    few_shot_target_key="SQL",
)

In [67]:
# display the full chat message list (system message, few-shot turns, final user prompt)
messages

[{'role': 'system',
  'content': 'Write sqlite SQL query only and with no explanation. Some example questions and corresponding SQL queries are provided based on similar problems.'},
 {'role': 'user',
  'content': 'similar example query: List the movies released in 1945.'},
 {'role': 'assistant',
  'content': "SELECT title FROM movie WHERE CAST(STRFTIME('%Y', release_date) AS INT) = 1945"},
 {'role': 'user',
  'content': 'similar example query: List the titles of movies directed by Jack Kinney that were released before 1947.'},
 {'role': 'assistant',
  'content': "SELECT T1.movie_title FROM characters AS T1 INNER JOIN director AS T2 ON T1.movie_title = T2.name WHERE T2.director = 'Jack Kinney' AND SUBSTR(T1.release_date, LENGTH(T1.release_date) - 1, LENGTH(T1.release_date)) < '47'"},
 {'role': 'user',
  'content': 'similar example query: Provide the list of the longest movies. Arrange these titles in alphabetical order.'},
 {'role': 'assistant',
  'content': 'SELECT title FROM film WHE

In [68]:
# this is the final user prompt
# (the last message in the list; its printed content starts with the target schema)
print(messages[-1]["content"])

target schema:
movie_platform schema:
    + tables
        lists: table
            + columns
                user_id: INTEGER
                list_id: INTEGER
                list_title: TEXT
                list_movie_number: INTEGER
                list_update_timestamp_utc: TEXT
                list_creation_timestamp_utc: TEXT
                list_followers: INTEGER
                list_url: TEXT
                list_comments: INTEGER
                list_description: TEXT
                list_cover_image_url: TEXT
                list_first_image_url: TEXT
                list_second_image_url: TEXT
                list_third_image_url: TEXT
            + keys
                lists_pk: PK (list_id)
            + foreign-keys
                lists_user_id_fk: foreign key (user_id) -> lists_users[.lists_users_pk] (user_id)

        movies: table
            + columns
                movie_id: INTEGER
                movie_title: TEXT
                movie_release_year: INTEGER
    

In [69]:
# inference
# only leading/trailing newlines are stripped from the model output here
prediction = generator.generate(messages).strip("\n")  # need better cleaning later
# print gold SQL and predicted SQL in fenced blocks for side-by-side comparison
print(f"target:\n```\n{sample_sql}\n```\n\n")
print(f"prediction:\n```\n{prediction}\n```\n\n")

target:
```
SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1
```


prediction:
```
SELECT movie_title 
FROM movies 
WHERE movie_release_year = 1945 
ORDER BY movie_popularity DESC;
```




### verify

Automated verification can be added later; for now, we check the prediction by running it directly against the database.

In [70]:
# execute the predicted SQL against the sample's database
# NOTE(review): prediction was already stripped of newlines when it was generated,
# so this second strip appears redundant
results = train_dataset.query_database(sample_db_name, prediction.strip("\n"))
# show at most the first three rows
print(results[:3])

[{'movie_title': 'Brief Encounter'}, {'movie_title': 'Children of Paradise'}, {'movie_title': 'Rome, Open City'}]


## predict (dev set example)

In [71]:
# load one dev sample
# NOTE: this overwrites the sample_db_name / sample_query / sample_sql values that were
# set for the training example earlier in the notebook
dev_sample = bird_dev_json[10]
sample_db_name = dev_sample["db_id"]
sample_query = dev_sample["question"]
sample_sql = dev_sample["SQL"]
print(sample_db_name)
print(sample_query)
print(sample_sql)

california_schools
For the school with the highest average score in Reading in the SAT test, what is its FRPM count for students aged 5-17?
SELECT T2.`FRPM Count (Ages 5-17)` FROM satscores AS T1 INNER JOIN frpm AS T2 ON T1.cds = T2.CDSCode ORDER BY T1.AvgScrRead DESC LIMIT 1


In [72]:
# create chat messages for LLM input
system_message = "Write sqlite SQL query only and with no explanation. Some example questions and corresponding SQL queries are provided based on similar problems."
# the schema comes from the dev dataset loader, since the sample's database belongs to the dev split
schema_description = dev_dataset.describe_database_schema(sample_db_name, mode="datagrip")
# the retriever holds training records only, so all 3 hits are kept as few-shot examples
few_shot_examples = local_retriever.query(embedder.embed(sample_query), top_k=3)
# create chat messages
messages = formatter.generate_messages(
    system_message=system_message,
    schema_description=schema_description,
    query=sample_query,
    few_shot_examples=few_shot_examples,
    few_shot_query_key="question",
    few_shot_target_key="SQL",
)

In [73]:
# this is the final user prompt
# (the last message in the list; its printed content starts with the target schema)
print(messages[-1]["content"])

target schema:
california_schools schema:
    + tables
        frpm: table
            + columns
                CDSCode: TEXT
                Academic Year: TEXT
                County Code: TEXT
                District Code: INTEGER
                School Code: TEXT
                County Name: TEXT
                District Name: TEXT
                School Name: TEXT
                District Type: TEXT
                School Type: TEXT
                Educational Option Type: TEXT
                NSLP Provision Status: TEXT
                Charter School (Y/N): INTEGER
                Charter School Number: TEXT
                Charter Funding Type: TEXT
                IRC: INTEGER
                Low Grade: TEXT
                High Grade: TEXT
                Enrollment (K-12): REAL
                Free Meal Count (K-12): REAL
                Percent (%) Eligible Free (K-12): REAL
                FRPM Count (K-12): REAL
                Percent (%) Eligible FRPM (K-12): REAL
    

In [74]:
# inference
# only leading/trailing newlines are stripped from the model output here
prediction = generator.generate(messages).strip("\n")  # need better cleaning later
# print gold SQL and predicted SQL in fenced blocks for side-by-side comparison
print(f"target:\n```\n{sample_sql}\n```\n\n")
print(f"prediction:\n```\n{prediction}\n```\n\n")

target:
```
SELECT T2.`FRPM Count (Ages 5-17)` FROM satscores AS T1 INNER JOIN frpm AS T2 ON T1.cds = T2.CDSCode ORDER BY T1.AvgScrRead DESC LIMIT 1
```


prediction:
```
SELECT f.`FRPM Count (Ages 5-17)` 
FROM satscores s 
JOIN frpm f ON s.cds = f.CDSCode 
ORDER BY s.AvgScrRead DESC 
LIMIT 1;
```




In [75]:
# execute the predicted SQL against the dev sample's database
# NOTE(review): prediction was already stripped of newlines when it was generated,
# so this second strip appears redundant
results = dev_dataset.query_database(sample_db_name, prediction.strip("\n"))
# show at most the first three rows
print(results[:3])

[{'FRPM Count (Ages 5-17)': 136.0}]
