In [1]:
import pandas as pd
import numpy as np
import json
import utils
import tqdm

metrics: \
string matching score \
executability \
execution score  

In [2]:
# Load the Spider examples; each record is a dict (see the `query` display further down).
# JSON text is UTF-8, so the encoding is pinned instead of relying on the platform default.
with open('spider_data/train_spider.json', encoding='utf-8') as f:
    query_data = json.load(f)

In [3]:
# Load the per-database schema records (table names, column names/types, keys).
# JSON text is UTF-8, so the encoding is pinned instead of relying on the platform default.
with open('spider_data/tables.json', encoding='utf-8') as f:
    table_data = json.load(f)

In [4]:
# Index the schema records by their database id for direct lookup.
table_data_dict = dict((record['db_id'], record) for record in table_data)

In [5]:
# Keep only the examples whose parsed FROM clause lists more than one table unit.
multi_table_queries = [
    example
    for example in query_data
    if len(example['sql']['from']['table_units']) > 1
]

In [7]:
# Pick the first multi-table example and look up the original table names of its database.
query = multi_table_queries[0]
db = query['db_id']
table_names = table_data_dict[db]['table_names_original']

In [11]:
# Render one schema snippet per table of `db` via utils.format_topk_sql.
# The loop variable is deliberately NOT called `type`: at notebook top level that
# would rebind the builtin `type` for every later cell in the kernel.
schema_list = []
db_column_names = table_data_dict[db]['column_names_original']
db_column_types = table_data_dict[db]['column_types']
for i, table in enumerate(table_names):
    # Each column entry is [table_index, column_name]; keep the ones owned by table `i`.
    columns = [
        (col_name, col_type)
        for (tbl_idx, col_name), col_type in zip(db_column_names, db_column_types)
        if tbl_idx == i
    ]

    tbl_details = {
        table: columns
    }

    tbl_schema = utils.format_topk_sql(tbl_details, shuffle=False)
    schema_list.append(tbl_schema)

In [15]:
# Show the per-table snippets back to back, with no separator between them.
print(*schema_list, sep="")


CREATE TABLE department (
  Department_ID number, 
  Name text, 
  Creation text, 
  Ranking number, 
  Budget_in_Billions number, 
  Num_Employees number, 
);


CREATE TABLE head (
  head_ID number, 
  name text, 
  born_state text, 
  age number, 
);


CREATE TABLE management (
  department_ID number, 
  head_ID number, 
  temporary_acting text, 
);




In [9]:
# Schema record of the example database, displayed to inspect its layout.
td = table_data_dict['department_management']
td

{'column_names': [[-1, '*'],
  [0, 'department id'],
  [0, 'name'],
  [0, 'creation'],
  [0, 'ranking'],
  [0, 'budget in billions'],
  [0, 'num employees'],
  [1, 'head id'],
  [1, 'name'],
  [1, 'born state'],
  [1, 'age'],
  [2, 'department id'],
  [2, 'head id'],
  [2, 'temporary acting']],
 'column_names_original': [[-1, '*'],
  [0, 'Department_ID'],
  [0, 'Name'],
  [0, 'Creation'],
  [0, 'Ranking'],
  [0, 'Budget_in_Billions'],
  [0, 'Num_Employees'],
  [1, 'head_ID'],
  [1, 'name'],
  [1, 'born_state'],
  [1, 'age'],
  [2, 'department_ID'],
  [2, 'head_ID'],
  [2, 'temporary_acting']],
 'column_types': ['text',
  'number',
  'text',
  'text',
  'number',
  'number',
  'number',
  'number',
  'text',
  'text',
  'number',
  'number',
  'number',
  'text'],
 'db_id': 'department_management',
 'foreign_keys': [[12, 7], [11, 1]],
 'primary_keys': [1, 7, 11],
 'table_names': ['department', 'head', 'management'],
 'table_names_original': ['department', 'head', 'management']}

In [13]:
# Resolve every foreign-key pair of `td` into two endpoint tuples of the form
# (table_index, table_name, column_name, column_type).
# Table names are taken from 'table_names_original' so they agree with the
# 'column_names_original' identifiers and with the names used in the rendered schema.
fk_list = []
for fk in td['foreign_keys']:
    endpoints = []
    for col_idx in (fk[0], fk[1]):
        tbl_idx, col_name = td['column_names_original'][col_idx]
        endpoints.append(
            (tbl_idx, td['table_names_original'][tbl_idx], col_name, td['column_types'][col_idx])
        )

    fk_list.append((endpoints[0], endpoints[1]))

In [6]:
def get_schema(query, table_data_dict):
    """Build the schema text and foreign-key metadata for the database of `query`.

    Args:
        query: Example dict; only ``query['db_id']`` is read.
        table_data_dict: Mapping of db_id to a schema record providing the keys
            'table_names_original', 'column_names_original', 'column_types'
            and 'foreign_keys'.

    Returns:
        A 4-tuple ``(schema, schema_list, table_names, fk_list)`` where
        ``schema_list`` holds one string per table as produced by
        ``utils.format_topk_sql``, ``schema`` is their concatenation,
        ``table_names`` is the record's 'table_names_original' list, and
        ``fk_list`` holds one ``(source, target)`` pair per foreign key, each
        endpoint being ``(table_index, table_name, column_name, column_type)``.

    Raises:
        KeyError: if the db_id or one of the expected record keys is missing.
    """
    td = table_data_dict[query['db_id']]
    table_names = td['table_names_original']
    column_names = td['column_names_original']
    column_types = td['column_types']

    def endpoint(col_idx):
        # Column entries are [table_index, column_name]. The table name comes from
        # the *original* names so it matches the identifiers in the rendered schema.
        tbl_idx, col_name = column_names[col_idx]
        return (tbl_idx, table_names[tbl_idx], col_name, column_types[col_idx])

    fk_list = [(endpoint(fk[0]), endpoint(fk[1])) for fk in td['foreign_keys']]

    schema_list = []
    for tbl_idx, table in enumerate(table_names):
        columns = [
            (col_name, col_type)
            for (owner_idx, col_name), col_type in zip(column_names, column_types)
            if owner_idx == tbl_idx
        ]
        schema_list.append(utils.format_topk_sql({table: columns}, shuffle=False))

    schema = "".join(schema_list)
    return schema, schema_list, table_names, fk_list

In [7]:
# Prompt text sent to the model. `{question}` and `{db_schema}` are placeholders filled
# in with str.format; the text deliberately ends on an opening ```sql fence.
prompt_template = """ 
                        ### Instructions:
                        Your task is to convert a question into a SQL query, given a Postgres database schema.
                        Adhere to these rules:
                        - **Deliberately go through the question and database schema word by word** to appropriately answer the question
                        - **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
                        - **Do not use column aliases** 
                        - When creating a ratio, always cast the numerator as float
                        
                        ### Input:
                        Generate a SQL query that answers the question {question}.
                        This query will run on a database whose schema is represented in this string:
                        {db_schema}
                        
                        
                        ### Response:
                        Based on your instructions, here is the SQL query I have generated to answer the question {question}:
                        ```sql
                        """

In [60]:
# NOTE(review): `table_name` and `df` are not defined in any cell visible in this
# notebook, so this cell appears to rely on leftover kernel state — confirm or remove.
model_name = 'sqlcoder'

sql_query, prompt = utils.getModelResult(tbl_schema, query['question'], model_name, table_name, df)

Running pre-processing...
Pruned schema: 
CREATE TABLE head (
  age DECIMAL(38, 0), --None
  head_ID DECIMAL(38, 0), --None
  name TEXT, --None
  born_state TEXT, --None
);



CREATE TABLE head (
  age DECIMAL(38, 0), --None
  head_ID DECIMAL(38, 0), --None
  name TEXT, --None
  born_state TEXT, --None
);


prompt:  
                        ### Instructions:
                        Your task is to convert a question into a SQL query, given a Postgres database schema.
                        Adhere to these rules:
                        - **Deliberately go through the question and database schema word by word** to appropriately answer the question
                        - **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
                        - When creating a ratio, always cast the numerator as float
                        
                        ### Input:
                        Generate 

In [16]:
from importlib import reload
reload(utils)

<module 'utils' from '/workspace/codegen/codegen/utils.py'>

In [8]:
import requests

OLLAMA_URL = 'http://127.0.0.1:11434'
class OLLAMA:
    """Minimal client for the `/api/generate` endpoint of an Ollama server.

    Args:
        OLLAMA_URL: Base URL of the server, e.g. 'http://127.0.0.1:11434'.
        model_name: Name of the model to run, e.g. 'sqlcoder:7b'.
        timeout: Seconds to wait for the HTTP response before giving up. Without a
            timeout `requests` can block forever on a stalled server.
    """

    def __init__(self, OLLAMA_URL, model_name, timeout=600):
        self.model_name = model_name
        self.ollama_url = OLLAMA_URL
        self.ollama_endpoint = '/api/generate'
        self.timeout = timeout

    def run(self, prompt):
        """Send `prompt` to the model and return the text under the 'response' key.

        The request is non-streaming and uses temperature 0.0.

        Raises:
            requests.RequestException: on connection errors, timeouts, or a
                non-2xx HTTP status.
            KeyError: if the JSON body has no 'response' field.
        """
        payload = {
            'model': self.model_name,
            'prompt': prompt,
            'stream': False,
            "options": {"temperature": 0.0}
        }

        headers = {
            'Accept': 'application/json',
            'Content-Type': 'application/json'
        }

        resp = requests.post(url=f'{self.ollama_url}{self.ollama_endpoint}',
                             data=json.dumps(payload),
                             headers=headers,
                             timeout=self.timeout)
        # Surface HTTP errors explicitly instead of failing later on a missing key.
        resp.raise_for_status()
        return resp.json()['response']

In [14]:
# Baseline run: prompt the model with the full schema text for each of the first
# 500 multi-table examples and collect the raw responses.
model_name = 'sqlcoder:7b'
# One client is enough; it holds no per-request state.
ollama = OLLAMA(OLLAMA_URL=utils.OLLAMA_URL, model_name=model_name)
results = []
for query in tqdm.tqdm(multi_table_queries[:500]):
    try:
        tbl_schema, _, _, _ = get_schema(query, table_data_dict)

        prompt = prompt_template.format(question=query['question'], db_schema=tbl_schema)
        generated_query = ollama.run(prompt)

        results.append(generated_query)
    except Exception as exc:
        # `Exception` (not a bare except) lets Ctrl-C interrupt the run. The sentinel
        # keeps `results` aligned with the input order; the cause is reported
        # through tqdm so the progress bar is not corrupted.
        tqdm.tqdm.write(f"generation failed for db {query.get('db_id')!r}: {exc!r}")
        results.append('failure\n')

100%|██████████| 500/500 [23:49<00:00,  2.86s/it]


In [38]:
import pickle

In [15]:
import pickle

# Persist the collected generations to disk.
with open('results_mt_v2.pkl', 'wb') as sink:
    pickle.dump(results, sink)

In [54]:
from importlib import reload
reload(utils)

<module 'utils' from '/workspace/codegen/codegen/utils.py'>

## preprocess

In [16]:
# Re-select the first multi-table example and display it for the preprocessing section.
query = multi_table_queries[0]
query

{'db_id': 'department_management',
 'query': "SELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id  =  T2.department_id JOIN head AS T3 ON T2.head_id  =  T3.head_id WHERE T3.born_state  =  'Alabama'",
 'query_toks': ['SELECT',
  'DISTINCT',
  'T1.creation',
  'FROM',
  'department',
  'AS',
  'T1',
  'JOIN',
  'management',
  'AS',
  'T2',
  'ON',
  'T1.department_id',
  '=',
  'T2.department_id',
  'JOIN',
  'head',
  'AS',
  'T3',
  'ON',
  'T2.head_id',
  '=',
  'T3.head_id',
  'WHERE',
  'T3.born_state',
  '=',
  "'Alabama",
  "'"],
 'query_toks_no_value': ['select',
  'distinct',
  't1',
  '.',
  'creation',
  'from',
  'department',
  'as',
  't1',
  'join',
  'management',
  'as',
  't2',
  'on',
  't1',
  '.',
  'department_id',
  '=',
  't2',
  '.',
  'department_id',
  'join',
  'head',
  'as',
  't3',
  'on',
  't2',
  '.',
  'head_id',
  '=',
  't3',
  '.',
  'head_id',
  'where',
  't3',
  '.',
  'born_state',
  '=',
  'value'],
 'que

In [22]:
# Unpack the combined schema text, the per-table snippets, the table names and the foreign-key pairs.
schema, s_list, table_names, fk = get_schema(query, table_data_dict)

In [45]:
# Build a human-readable list of joinable columns and make sure both sides of every
# foreign key are declared in the per-table schema text.
# `slist2` was not defined in any visible cell; it is created here as a copy of
# `s_list` so the cell runs on a fresh kernel and re-running it never mutates `s_list`.
slist2 = list(s_list)

fk_string = ''
for row in fk:
    fk_string += f"{row[0][1]}.{row[0][2]} joins with {row[1][1]}.{row[1][2]} \n"

    # Each endpoint is (table_index, table_name, column_name, column_type).
    for tbl_idx, _tbl_name, col_name, col_type in (row[0], row[1]):
        tbl_sql = slist2[tbl_idx]

        # Match a declaration line ("  <column> <type>, ") rather than any substring:
        # a plain `in` test treats 'id' as present whenever e.g. 'head_id' exists.
        # Assumes one column per line as in the printed schema above — TODO confirm.
        declared = any(
            line.lstrip().startswith(f'{col_name} ') for line in tbl_sql.splitlines()
        )
        if declared:
            continue

        idx = tbl_sql.find(');')
        if idx == -1:
            # No closing ');' to insert before; leave the text untouched.
            continue
        slist2[tbl_idx] = tbl_sql[:idx] + f'  {col_name} {col_type}, \n' + tbl_sql[idx:]

In [9]:
# Pruned-schema run: prune every table with utils.preprocess_table, re-add any
# foreign-key column the pruning dropped, append join hints, then query the model
# for each of the first 500 multi-table examples.
def _declares_column(table_sql, column):
    """Return True when a line of `table_sql` starts with `column` followed by a space.

    A line-based check avoids the false positives of a substring test, where 'id'
    would count as present whenever e.g. 'head_id' exists. Assumes one column
    declaration per line ("  <column> <type>, ") — TODO confirm for preprocess_table.
    """
    prefix = f'{column} '
    return any(line.lstrip().startswith(prefix) for line in table_sql.splitlines())


def _ensure_column(table_sql, column, col_type):
    """Return `table_sql` with `column` declared, inserting it before the closing ');' if missing."""
    if _declares_column(table_sql, column):
        return table_sql
    close_at = table_sql.find(');')
    if close_at == -1:
        # Nothing to insert before; leave the text untouched.
        return table_sql
    return table_sql[:close_at] + f'  {column} {col_type}, \n' + table_sql[close_at:]


model_name = 'sqlcoder:7b'
top_k = 8  # set low topk value
# One client is enough; it holds no per-request state.
ollama = OLLAMA(OLLAMA_URL=utils.OLLAMA_URL, model_name=model_name)
results = []
for query in tqdm.tqdm(multi_table_queries[:500]):
    try:
        schema, s_list, table_names, fk = get_schema(query, table_data_dict)
        pruned_schema_list = [
            utils.preprocess_table(query['question'], table_sql, table, top_k)
            for table_sql, table in zip(s_list, table_names)
        ]

        fk_string = ''
        for row in fk:
            fk_string += f"-- {row[0][1]}.{row[0][2]} can be joined with {row[1][1]}.{row[1][2]} \n"

            # Each endpoint is (table_index, table_name, column_name, column_type).
            for tbl_idx, _tbl_name, col_name, col_type in (row[0], row[1]):
                pruned_schema_list[tbl_idx] = _ensure_column(
                    pruned_schema_list[tbl_idx], col_name, col_type
                )

        # Plain concatenation: nesting the same quote type inside an f-string only
        # parses on Python 3.12+.
        final_schema = "".join(pruned_schema_list) + "\n\n" + fk_string

        prompt = prompt_template.format(question=query['question'], db_schema=final_schema)
        generated_query = ollama.run(prompt)
        results.append(generated_query)
    except Exception as exc:
        # `Exception` (not a bare except) lets Ctrl-C interrupt the run. The sentinel
        # keeps `results` aligned with the input order; the cause is reported
        # through tqdm so the progress bar is not corrupted.
        tqdm.tqdm.write(f"generation failed for db {query.get('db_id')!r}: {exc!r}")
        results.append('-1')

100%|██████████| 500/500 [33:20<00:00,  4.00s/it]


In [11]:
# Spot-check a single generated response.
results[100]

' SELECT a.name FROM flight f JOIN aircraft a ON f.aid = a.aid WHERE f.flno = 99;\n                        ```'

In [10]:
import pickle

# Persist the generations from the pruned-schema run to disk.
with open('results_preprocessing_mt.pkl', 'wb') as sink:
    pickle.dump(results, sink)