In [2]:
from typing import List
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from infection import prompt as sprompt
from infection.databases import SQL3Database
from infection.trustworthiness.hallucination import fix_sql_hallucination
import sqlite3

In [3]:
import torch
# Inference is pinned to the CPU; the model loaded in a later cell is moved onto this device.
device = torch.device('cpu')

In [4]:
# Hugging Face checkpoint used for text-to-SQL generation; the commented-out
# assignments are alternative checkpoints.
# NOTE(review): the flan-t5 checkpoint is presumably a seq2seq model and would need
# AutoModelForSeq2SeqLM rather than the AutoModelForCausalLM loader used in the
# next cell — confirm before switching.
# model_name = "juierror/flan-t5-text2sql-with-schema-v2"
model_name = 'NumbersStation/nsql-350M'
# model_name = 'NumbersStation/nsql-2B'

In [5]:
# Load the tokenizer and causal-LM weights for `model_name`, then place the model on `device`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# Switch the model to evaluation mode; the cells shown here only run generation.
model.eval()
def connect_func(database_name: str, database_type: str = 'sqlite3'):
    """Open ``database_name`` through ``SQL3Database``.

    Args:
        database_name: Value handed straight to the ``SQL3Database`` constructor.
        database_type: Accepted for interface compatibility; it is not consulted,
            a ``SQL3Database`` is always created.

    Returns:
        The ``SQL3Database`` instance, or ``None`` when construction raised
        ``sqlite3.Error`` (the error is printed rather than propagated).
    """
    connection = None
    try:
        connection = SQL3Database(database_name)
    except sqlite3.Error as e:
        # Best-effort: report the failure and let the caller see None.
        print(f"Error connecting to the database: {e}")
    return connection

Downloading (…)okenizer_config.json:   0%|          | 0.00/237 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/1.08k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.51G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/115 [00:00<?, ?B/s]

In [6]:
# Natural-language question to translate into SQL. Swap in one of the
# commented-out alternatives below to try a different query.
question = "What's the statistic code used for fully vaccinated?"
# question = 'How many different age groups were tracked for covid vaccinations?'
# question = "What was the biggest vaccination rate achieved?"
# question = "Which electoral area has worst latest fully vaccinated rate?"

In [17]:
import re

# Opening quote character -> closing quote character for SQL identifiers/strings.
_QUOTE_CLOSERS = {'"': '"', "`": "`", "'": "'", "[": "]"}


def _split_column_definitions(statement, start):
    """Split the parenthesised body that begins at ``start`` on top-level commas.

    ``start`` is the index just after the opening ``(`` of the column list.
    Commas inside nested parentheses (``DECIMAL(10, 2)``) or inside quotes are
    not treated as separators. Returns the stripped, non-empty definitions, or
    ``None`` when the opening parenthesis is never closed.
    """
    parts = []
    current = []
    depth = 1
    closing_quote = None
    for char in statement[start:]:
        if closing_quote is not None:
            # Inside a quoted identifier/string: copy verbatim until it closes.
            current.append(char)
            if char == closing_quote:
                closing_quote = None
            continue
        if char in _QUOTE_CLOSERS:
            closing_quote = _QUOTE_CLOSERS[char]
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                # Matching close of the column list: finish, dropping empty
                # entries such as the one left by a trailing comma.
                parts.append("".join(current))
                return [part.strip() for part in parts if part.strip()]
        elif char == "," and depth == 1:
            parts.append("".join(current))
            current = []
            continue
        current.append(char)
    return None


def _column_name(definition):
    """Return the column name that starts a non-empty, stripped column definition."""
    closer = {'"': '"', "`": "`", "[": "]"}.get(definition[0])
    if closer is not None:
        end = definition.find(closer, 1)
        if end != -1:
            # Quoted identifier: keep inner text, which may contain spaces.
            return definition[1:end]
    return definition.split()[0]


def transform_sql_schema_to_list(sql_schema):
    """Convert ``CREATE TABLE`` statements into a list of ``{table: [columns]}`` dicts.

    The schema text is split on ``;`` and every statement that starts with
    ``CREATE TABLE <name> (`` (keyword case and whitespace are flexible) is parsed.
    Compared with a single non-greedy regex this handles:

    * statements spread over several lines,
    * types that carry their own parentheses, e.g. ``varchar(10)`` or
      ``DECIMAL(10, 2)``,
    * a trailing comma or an empty column list (no ``IndexError``),
    * quoted column names such as ``"Age Group"`` (returned without quotes).

    Table-level constraints are not filtered out; their first word is reported
    like a column name.

    Args:
        sql_schema: Text containing zero or more ``CREATE TABLE`` statements.

    Returns:
        One single-key dict per table, in order of first appearance. A table
        defined twice keeps its first position with the columns of the later
        definition. Statements that do not match, or whose column list is
        never closed, are skipped.
    """
    header_pattern = re.compile(r"CREATE\s+TABLE\s+(\w+)\s*\(", re.IGNORECASE)

    schema_dict = {}
    for statement in sql_schema.split(";"):
        statement = statement.strip()
        header = header_pattern.match(statement)
        if header is None:
            continue

        definitions = _split_column_definitions(statement, header.end())
        if definitions is None:
            # Unbalanced parentheses: not a complete CREATE TABLE statement.
            continue

        schema_dict[header.group(1)] = [_column_name(d) for d in definitions]

    return [{table_name: columns} for table_name, columns in schema_dict.items()]

# Example usage: two single-line CREATE TABLE statements are parsed into one
# {table: [columns]} dict per table and the resulting list is printed.
sql_schema = """
CREATE TABLE head (age INTEGER);
CREATE TABLE body (height FLOAT, weight FLOAT);
"""

transformed_schema = transform_sql_schema_to_list(sql_schema)
print(transformed_schema)


[{'head': ['age']}, {'body': ['height', 'weight']}]


In [13]:
# %%
# Open the example COVID-vaccination database and render its schema text
# (add_examples=1 is forwarded to format_schemas).
connection = connect_func('../data/example-data/example-covid-vaccinations.sqlite3')
if connection is None:
    # connect_func prints the sqlite3 error and returns None; fail here with a
    # clear message instead of an AttributeError on the next line.
    raise RuntimeError("Could not open the example database; see the error printed above.")
schemas = connection.format_schemas(add_examples=1)
# query = sprompt.SQL_QUERY_PROMPT_TEMPLATE.format(question=question, db_schema=schemas, tables_hints=None)
# Prompt layout: schema text, an instruction line, the question as a SQL comment,
# then an opened ```sql fence for the model to complete.
prompt = """
{schemas}
**Using valid SQLite, answer the following questions for the tables provided above**.
-- {question}
```sql
"""
query = prompt.format(schemas=schemas, question=question)
print(query)


CREATE TABLE covid_vaccinations (
 	STATISTIC_CODE varchar(10),
	Statistic_Label varchar(30),
	TLIST(M1) INT,
	Month varchar(20),
	C03898V04649 varchar(30),
	Local Electoral Area varchar(50),
	C02076V03371 varchar(10),
	Age Group varchar(30),
	UNIT varchar(10),
	VALUE float,
);
SELECT * FROM covid_vaccinations LIMIT 1;
| STATISTIC_CODE   | Statistic_Label   |   TLIST(M1) | Month        | C03898V04649                         | Local Electoral Area                 |   C02076V03371 | Age Group    | UNIT   |   VALUE |
|------------------|-------------------|-------------|--------------|--------------------------------------|--------------------------------------|----------------|--------------|--------|---------|
| CDC45C01         | Fully Vaccinated  |      202101 | 2021 January | 2ae19629-3eff-13a3-e055-000000000001 | Borris-In-Ossory-Mountmellick, Laois |            247 | 5 - 11 years | %      |       0 |


**Using valid SQLite, answer the following questions for the tables provided ab

In [14]:

# prompt = """Convert question and tables into SQL query. 
# schemas = 'covid_vaccinations(STATISTIC_CODE,Statistic_Label,TLIST(M1),Month,C03898V04649,"Local Electoral Area",C02076V03371,"Age Group",UNIT,VALUE)'
# tables: {schemas}.
# question: {question}""".format(question=question, schemas=schemas)
# # example rows: covid_vaccinations(CDC45C01|Fully Vaccinated|202101|2021 January|2ae19629-3eff-13a3-e055-000000000001|Borris-In-Ossory-Mountmellick, Laois|247|5 - 11 years|%|0|)

# print(prompt)

In [16]:
# Use the closing Markdown fence as the end-of-sequence (and padding) token so
# generation ends once the model closes the SQL block.
# NOTE(review): assumes "```" is a single token in this tokenizer's vocabulary — confirm.
eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
# Keep the input tensors on the same device as the model so changing `device`
# does not break generation.
inputs = tokenizer(query, return_tensors="pt").to(device)

# Greedy decoding: no sampling, a single beam, one returned sequence.
generated_ids = model.generate(
    **inputs,
    num_return_sequences=1,
    eos_token_id=eos_token_id,
    pad_token_id=eos_token_id,
    max_new_tokens=500,
    do_sample=False,
    num_beams=1,
)

decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# The decoded text still contains the prompt: keep what follows the last ```sql
# fence, cut at the next fence, then keep only the first statement and
# terminate it with a single ';'.
generated_sql = decoded[0].split("```sql")[-1].split("```")[0]
outputs = generated_sql.split(";")[0].strip() + ";"

print(outputs)

SELECT STATISTIC_CODE FROM covid_vaccinations WHERE Statistic_Label = 'Fully Vaccinated';


In [56]:
# check_sql_hallucination(schemas=connection.get_schemas(), sql_query=outputs)
# outputs = "SELECT * FROM covid_vaccinations WHERE C03898V04649 = '202101' AND Month = '2ae19629-3eff-13a3-e055-000000000001' ORDER BY `Age Group` DESC;"
# outputs = "SELECT `Local Electoral Area` FROM covid_vaccinations ORDER BY VALUE LIMIT 1;"
# outputs = "SELECT Distinct STATISTIC_CODE FROM covid_vaccinations WHERE `TLIST(M1)` = 202101;"

In [57]:
# Run the generated SQL against the example database and display whatever execute_sql returns.
# NOTE(review): the generated query has no DISTINCT/LIMIT, so every matching row is
# echoed into the notebook output — consider bounding the result before sharing.
connection.execute_sql(outputs)

([('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',),
  ('CDC45C01',