In [1]:
import os
import json
import torch
import torch.optim as optim
import transformers

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from transformers import T5TokenizerFast, T5ForConditionalGeneration, MT5ForConditionalGeneration
from transformers.optimization import Adafactor
from transformers.trainer_utils import set_seed
from utils.spider_metric.evaluator import EvaluateTool
from utils.load_dataset import Text2SQLDataset
from utils.text2sql_decoding_utils import decode_sqls, decode_natsqls

In [3]:
batch_size = 16

In [4]:
model = T5ForConditionalGeneration.from_pretrained('models/text2natsql-t5-large/checkpoint-21216/')
if torch.cuda.is_available():
    model = model.cuda()
    

In [5]:
tokenizer = T5TokenizerFast.from_pretrained('models/text2natsql-t5-large/checkpoint-21216')

In [8]:
import sqlite3
conn = sqlite3.Connection('inference_db/test.db')
# conn = sqlite3.Connection('database/concert_singer/concert_singer.sqlite')

cursor = conn.cursor()

cursor.execute("SELECT name,sql FROM sqlite_master WHERE type='table'")
tables = cursor.fetchall()
print(tables[0][1])

CREATE TABLE Employees (
    EmployeeID INT PRIMARY KEY,
    Name VARCHAR(100),
    JobTitle VARCHAR(100),
    DepartmentID INT,
    Salary DECIMAL(10, 2),
    FOREIGN KEY (DepartmentID) REFERENCES Departments(DepartmentID)
)


In [9]:
print(tables[0][1])

CREATE TABLE Employees (
    EmployeeID INT PRIMARY KEY,
    Name VARCHAR(100),
    JobTitle VARCHAR(100),
    DepartmentID INT,
    Salary DECIMAL(10, 2),
    FOREIGN KEY (DepartmentID) REFERENCES Departments(DepartmentID)
)


In [10]:
sql_input = ''

In [11]:
for table in tables:
    table_name = table[0]
    sql_input += f'| {table_name} : '
    cursor.execute("PRAGMA table_info({})".format(table_name))
    columns = cursor.fetchall()
    for column in columns:
        # print(column)
    
        sql_input += f"{table_name}.{column[1]}, "
    sql_input += f"{table_name}.* "
    

In [12]:
sql_input

'| Employees : Employees.EmployeeID, Employees.Name, Employees.JobTitle, Employees.DepartmentID, Employees.Salary, Employees.* | Departments : Departments.DepartmentID, Departments.Name, Departments.ManagerID, Departments.* | Projects : Projects.ProjectID, Projects.Name, Projects.DepartmentID, Projects.StartDate, Projects.EndDate, Projects.* | Tasks : Tasks.TaskID, Tasks.Description, Tasks.ProjectID, Tasks.AssigneeID, Tasks.* '

In [83]:
nlp_input = "Retrieve all employees names and their departments names"
inputs = nlp_input + sql_input

In [84]:
tokenized_inputs = tokenizer(
            inputs, 
            return_tensors="pt",
            padding = "max_length",
            max_length = 512,
            truncation = True
        )

In [85]:
encoder_input_ids = tokenized_inputs["input_ids"]
encoder_input_attention_mask = tokenized_inputs["attention_mask"]
if torch.cuda.is_available():
        encoder_input_ids = encoder_input_ids.cuda()
        encoder_input_attention_mask = encoder_input_attention_mask.cuda()

In [86]:
num_beams = 8
num_return_sequences = 8

In [87]:
with torch.no_grad():
    model_outputs = model.generate(
        input_ids = encoder_input_ids,
        attention_mask = encoder_input_attention_mask,
        max_length = 256,
        decoder_start_token_id = model.config.decoder_start_token_id,
        num_beams = 2,
        num_return_sequences = 1
    )

In [88]:
torch.cuda.empty_cache()

In [89]:
model_outputs.shape

torch.Size([1, 22])

In [90]:
model_outputs[0,0]

tensor(0, device='cuda:0')

In [91]:
model_outputs

tensor([[    0,  1738,     3,   834,    45,     3,   834,  1820,  1738, 15871,
             7,     5, 23954,     6,  1775,     7,     5, 23954,    45, 15871,
             7,     1]], device='cuda:0')

In [92]:
model_outputs

tensor([[    0,  1738,     3,   834,    45,     3,   834,  1820,  1738, 15871,
             7,     5, 23954,     6,  1775,     7,     5, 23954,    45, 15871,
             7,     1]], device='cuda:0')

In [98]:
model_outputs = model_outputs.view(len(1), opt.num_return_sequences, model_outputs.shape[1])

TypeError: object of type 'int' has no len()

In [93]:
pred_sequence = tokenizer.decode(model_outputs[0], skip_special_tokens = True)

In [94]:
pred_sequence

'select _ from _ | select Employees.Name, Departments.Name from Employees'

In [95]:
pred_sql = pred_sequence.split("|")[-1].strip()
pred_sql

'select Employees.Name, Departments.Name from Employees'

In [97]:
# cursor.execute('select Employees.Name from Employees where Employees.Salary > 70000')
cursor.execute('''SELECT e.Name AS EmployeeName, e.JobTitle, d.Name AS DepartmentName
FROM Employees e
INNER JOIN Departments d ON e.DepartmentID = d.DepartmentID;
''')
cursor.fetchall()

[('John Doe', 'Software Engineer', 'Engineering'),
 ('Jane Smith', 'Data Analyst', 'Data Science'),
 ('Michael Johnson', 'Project Manager', 'Project Management'),
 ('Emily Brown', 'HR Specialist', 'Human Resources'),
 ('David Wilson', 'Marketing Manager', 'Marketing')]

In [None]:
tables = json.load(open('NatSQL/NatSQLv1_6/tables_for_natsql.json','r'))
table_dict = dict()
for t in tables:
    table_dict[t["db_id"]] = t
    

OperationalError: near ",": syntax error

In [183]:
table_dict

{'perpetrator': {'column_names': [[0, '*'],
   [0, 'id'],
   [0, 'people id'],
   [0, 'date'],
   [0, 'year'],
   [0, 'location'],
   [0, 'country'],
   [0, 'killed'],
   [0, 'injured'],
   [1, '*'],
   [1, 'id'],
   [1, 'name'],
   [1, 'height'],
   [1, 'weight'],
   [1, 'home town']],
  'column_names_original': [[0, '*'],
   [0, 'Perpetrator_ID'],
   [0, 'People_ID'],
   [0, 'Date'],
   [0, 'Year'],
   [0, 'Location'],
   [0, 'Country'],
   [0, 'Killed'],
   [0, 'Injured'],
   [1, '*'],
   [1, 'People_ID'],
   [1, 'Name'],
   [1, 'Height'],
   [1, 'Weight'],
   [1, 'Home Town']],
  'column_types': ['others',
   'number',
   'number',
   'text',
   'year',
   'text',
   'text',
   'number',
   'number',
   'others',
   'number',
   'text',
   'number',
   'number',
   'text'],
  'db_id': 'perpetrator',
  'foreign_keys': [[2, 10]],
  'primary_keys': [1, 10],
  'table_names': ['perpetrator', 'people'],
  'table_names_original': ['perpetrator', 'people'],
  'original_primary_keys': [1, 1

In [131]:
def get_cursor_from_path(sqlite_path):
    try:
        if not os.path.exists(sqlite_path):
            print("Openning a new connection %s" % sqlite_path)
        connection = sqlite3.connect(sqlite_path, check_same_thread = False)
    except Exception as e:
        print(sqlite_path)
        raise e
    connection.text_factory = lambda b: b.decode(errors="ignore")
    cursor = connection.cursor()
    return cursor

In [132]:
def execute_sql(cursor, sql):
    cursor.execute(sql)

    return cursor.fetchall()

## For running SQL queries

In [143]:
final_sqls  = []

for batch_id in range(model_outputs.shape[0]):
    for seq_id in range(model_outputs.shape[1]): # model_outputs[1] represents num_of return_sequences
        pred_executable_sql = "sql placeholder"
        cursor = get_cursor_from_path('inference_db/test.db')
        pred_sequence = tokenizer.decode(model_outputs[batch_id,seq_id:], skip_special_tokens=True)
        
        pred_sql = pred_sequence.split("|")[-1].strip()
        pred_sql = pred_sql.replace("='", "= '").replace("!=", " !=").replace(",", " ,")
        try:
            # Note: execute_sql will be success for empty string
            assert len(pred_sql) > 0, "pred sql is empty!"

            results = execute_sql(cursor, pred_sql)
            # if the current sql has no execution error, we record and return it
            pred_executable_sql = pred_sql
            cursor.close()
            cursor.connection.close()
            break
        except Exception as e:
            print(pred_sql)
            print(e)
            cursor.close()
            cursor.connection.close()
            

        final_sqls.append(pred_executable_sql)

select Employees.Name from Tasks where Tasks.ProjectID = 'Website Development'
no such column: Employees.Name
select Employees.Name from Tasks where Tasks.ProjectID = 'Website Development'
no such column: Employees.Name
select Employees.Name from Tasks where Tasks.ProjectID = 'Website Development'
no such column: Employees.Name
select Employees.Name from Tasks where Tasks.ProjectID = 'Website Development'
no such column: Employees.Name
select Employees.Name from Tasks where Tasks.ProjectID = 'Website Development'
no such column: Employees.Name
select Employees.Name from Tasks where Tasks.ProjectID = 'Website Development'
no such column: Employees.Name
select Employees.Name from Tasks where Tasks.ProjectID = 'Website Development'
no such column: Employees.Name
select Employees.Name from Tasks where Tasks.ProjectID = 'Website Development'
no such column: Employees.Name
select Employees.Name from Tasks where Tasks.ProjectID = 'Website Development'
no such column: Employees.Name
select Emp

## For natsqls 

In [None]:
final_sqls  = []

for batch_id in range(model_outputs.shape[0]):
    for seq_id in range(model_outputs.shape[1]): # model_outputs[1] represents num_of return_sequences
        pred_executable_sql = "sql placeholder"
        cursor = get_cursor_from_path('inference_db/test.db')
        pred_sequence = tokenizer.decode(model_outputs[batch_id,seq_id:], skip_special_tokens=True)
        
        pred_natsql = pred_sequence.split("|")[-1].strip()
        pred_natsql = pred_natsql.replace("='", "= '").replace("!=", " !=").replace(",", " ,")
        old_pred_natsql = pred_natsql
        
        pred_natsql = fix_fata_errors_in_natsql(pred_natsql, batch_tc_original[batch_id])
        try:
            # Note: execute_sql will be success for empty string
            assert len(pred_sql) > 0, "pred sql is empty!"

            results = execute_sql(cursor, pred_sql)
            # if the current sql has no execution error, we record and return it
            pred_executable_sql = pred_sql
            cursor.close()
            cursor.connection.close()
            break
        except Exception as e:
            print(pred_sql)
            print(e)
            cursor.close()
            cursor.connection.close()
            

        final_sqls.append(pred_executable_sql)