In [1]:
import pandas as pd
import numpy as np
import json
import utils
import tqdm

metrics: \
sqtring matching score \
executability \
execution score  

In [2]:
with open('spider_data/train_spider.json') as f:
    query_data = json.load(f)

In [3]:
with open('spider_data/tables.json') as f:
    table_data = json.load(f)

In [4]:
table_data_dict = {rec['db_id']:rec for rec in table_data}

In [5]:
def get_schema(query, table_data_dict):
    db = query['db_id']
    tbl_id = query['sql']['from']['table_units'][0][1]
    table_name = table_data_dict[db]['table_names_original'][tbl_id]

    columns = []
    for names,type in zip(table_data_dict[db]['column_names_original'],table_data_dict[db]['column_types']):
        if names[0]==tbl_id:
            columns.append((names[1], type))
    
    tbl_details = {
        table_name:columns
    }
    
    df = pd.DataFrame(columns, columns = ['name','type'])
    df['comment'] = np.nan
    tbl_schema = utils.format_topk_sql(tbl_details, shuffle=False)

    return tbl_schema, df, table_name

In [6]:
single_table_queries = [query for query in query_data if len(query['sql']['from']['table_units'])==1]

In [7]:
prompt_template = """ 
                        ### Instructions:
                        Your task is to convert a question into a SQL query, given a Postgres database schema.
                        Adhere to these rules:
                        - **Deliberately go through the question and database schema word by word** to appropriately answer the question
                        - **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
                        - When creating a ratio, always cast the numerator as float
                        
                        ### Input:
                        Generate a SQL query that answers the question {question}.
                        This query will run on a database whose schema is represented in this string:
                        {db_schema}
                        
                        
                        ### Response:
                        Based on your instructions, here is the SQL query I have generated to answer the question {question}:
                        ```sql
                        """

In [60]:
model_name = 'sqlcoder'

sql_query, prompt = utils.getModelResult(tbl_schema, query['question'], model_name, table_name, df)

Running pre-processing...
Pruned schema: 
CREATE TABLE head (
  age DECIMAL(38, 0), --None
  head_ID DECIMAL(38, 0), --None
  name TEXT, --None
  born_state TEXT, --None
);



CREATE TABLE head (
  age DECIMAL(38, 0), --None
  head_ID DECIMAL(38, 0), --None
  name TEXT, --None
  born_state TEXT, --None
);


prompt:  
                        ### Instructions:
                        Your task is to convert a question into a SQL query, given a Postgres database schema.
                        Adhere to these rules:
                        - **Deliberately go through the question and database schema word by word** to appropriately answer the question
                        - **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
                        - When creating a ratio, always cast the numerator as float
                        
                        ### Input:
                        Generate 

In [16]:
from importlib import reload
reload(utils)

<module 'utils' from '/workspace/codegen/codegen/utils.py'>

In [10]:
import requests

OLLAMA_URL = 'http://127.0.0.1:11434'
class OLLAMA:
    def __init__(self, OLLAMA_URL, model_name):
        self.model_name = model_name
        self.ollama_url = OLLAMA_URL
        self.ollama_endpoint = '/api/generate'

    def run(self, prompt):
        data = {
            'model': self.model_name,
            'prompt': prompt,
            'stream': False,
            "options":{"temperature":0.0}
        }

        headers = {
            'Accept': 'application/json',
            'Content-Type': 'application/json'
        }

        resp = requests.post(url = f'{self.ollama_url}{self.ollama_endpoint}',
                             data = json.dumps(data),
                             headers = headers)
        query = resp.json()['response']
        # print(f'JSON resp: {query}')
        return query

In [11]:
model_name = 'sqlcoder:7b'
prompt = prompt_template.format(question=single_table_queries[0]['question'], db_schema=tbl_schema)
ollama = OLLAMA(OLLAMA_URL=utils.OLLAMA_URL, model_name=model_name)
generated_query, resp = ollama.run(prompt)

NameError: name 'tbl_schema' is not defined

In [14]:
model_name = 'sqlcoder:7b'
results = []
i=0
for query in tqdm.tqdm(single_table_queries):
    i+=1
    try:
        tbl_schema, _, _ = get_schema(query, table_data_dict)
        
        prompt = prompt_template.format(question=query['question'], db_schema=tbl_schema)
        ollama = OLLAMA(OLLAMA_URL=utils.OLLAMA_URL, model_name=model_name)
        generated_query = ollama.run(prompt)
    
        results.append(generated_query)
    except:
        results.append('failure\n')

    if i%30==0:
        break

  1%|          | 29/4361 [01:20<3:19:50,  2.77s/it]


In [13]:
results

[' SELECT COUNT(*) FROM head WHERE age > 56;\n                        ```',
 ' SELECT head.name, head.born_state, head.age FROM head ORDER BY head.age NULLS LAST;\n                        ```',
 ' SELECT CAST(creation AS integer), name, budget_in_billions FROM department;\n                        ```',
 ' SELECT MAX(Budget_in_Billions) AS max_budget, MIN(Budget_in_Billons) AS min_budget FROM department;\n                        ```',
 ' AVG(department.num_employees) FROM department WHERE department.ranking BETWEEN 10 AND 15;\n                        ```']

In [10]:
tbl_schema, df, table_name = get_schema(single_table_queries[0], table_data_dict)

In [15]:
len(df)

4

In [None]:
model_name = 'sqlcoder:7b'
results = []
i=0
for query in tqdm.tqdm(single_table_queries):
    i+=1
    try:
        tbl_schema, df, table_name = get_schema(query, table_data_dict)

        sql_query, prompt = utils.getModelResult(tbl_schema, query['question'], model_name, table_name, df)
    
        results.append(sql_query)
    except:
        results.append('failure\n')
        
    # if i%50==0:
    #     with open('results_codegen.txt', 'a') as f:
    #         for rec in results:
    #             if '\n' in rec:
    #                 f.write(rec)
    #             else:
    #                 f.write(f"{rec}\n")
    #     results = []

# with open('results_codegen.txt', 'a') as f:
#     for rec in results:
#         if '\n' in rec:
#             f.write(rec)
#         else:
#             f.write(f"{rec}\n")

  0%|          | 0/4361 [00:00<?, ?it/s]

Running pre-processing...


  0%|          | 1/4361 [00:00<55:40,  1.31it/s]

Running pre-processing...


  0%|          | 2/4361 [00:01<1:08:14,  1.06it/s]

Running pre-processing...


  0%|          | 3/4361 [00:03<1:20:54,  1.11s/it]

Running pre-processing...


  0%|          | 4/4361 [00:04<1:37:14,  1.34s/it]

Running pre-processing...


  0%|          | 5/4361 [00:06<1:33:55,  1.29s/it]

Post-processing failed!
Invalid expression / Unexpected token. Line 1, Col: 35.
   AVG(department.num_employees) [4mFROM[0m department WHERE department.ranking BETWEEN 10 AND 15 <traceback object at 0x75c3a28b3f00>
Running pre-processing...


  0%|          | 6/4361 [00:06<1:20:21,  1.11s/it]

Running pre-processing...


  0%|          | 7/4361 [00:08<1:24:22,  1.16s/it]

Running pre-processing...


  0%|          | 8/4361 [00:10<1:48:15,  1.49s/it]

Running pre-processing...


  0%|          | 9/4361 [00:10<1:30:08,  1.24s/it]

Running pre-processing...


  0%|          | 10/4361 [00:12<1:32:29,  1.28s/it]

Running pre-processing...


  0%|          | 11/4361 [00:13<1:25:22,  1.18s/it]

Running pre-processing...


  0%|          | 12/4361 [00:13<1:15:25,  1.04s/it]

Running pre-processing...


  0%|          | 13/4361 [00:14<1:05:22,  1.11it/s]

Running pre-processing...


  0%|          | 14/4361 [00:15<1:07:36,  1.07it/s]

Running pre-processing...


  0%|          | 15/4361 [00:17<1:25:59,  1.19s/it]

Running pre-processing...


  0%|          | 16/4361 [00:18<1:24:57,  1.17s/it]

Running pre-processing...


  0%|          | 17/4361 [00:19<1:24:28,  1.17s/it]

Running pre-processing...


  0%|          | 18/4361 [00:20<1:25:24,  1.18s/it]

Running pre-processing...


  0%|          | 19/4361 [00:22<1:26:14,  1.19s/it]

Running pre-processing...


  0%|          | 20/4361 [00:23<1:29:59,  1.24s/it]

Running pre-processing...


  0%|          | 21/4361 [00:24<1:32:16,  1.28s/it]

Running pre-processing...


  1%|          | 22/4361 [00:25<1:28:52,  1.23s/it]

Running pre-processing...


  1%|          | 23/4361 [00:27<1:26:07,  1.19s/it]

Running pre-processing...


  1%|          | 24/4361 [00:27<1:20:39,  1.12s/it]

Running pre-processing...


  1%|          | 25/4361 [00:28<1:16:50,  1.06s/it]

Running pre-processing...


  1%|          | 26/4361 [00:29<1:13:23,  1.02s/it]

Running pre-processing...


  1%|          | 27/4361 [00:30<1:10:50,  1.02it/s]

Running pre-processing...


  1%|          | 28/4361 [00:31<1:16:11,  1.06s/it]

Running pre-processing...


  1%|          | 29/4361 [00:33<1:19:17,  1.10s/it]

Running pre-processing...


  1%|          | 30/4361 [00:34<1:20:49,  1.12s/it]

Running pre-processing...


  1%|          | 31/4361 [00:35<1:21:53,  1.13s/it]

Running pre-processing...


  1%|          | 32/4361 [00:37<1:36:34,  1.34s/it]

Running pre-processing...


  1%|          | 33/4361 [00:38<1:36:41,  1.34s/it]

Running pre-processing...


  1%|          | 34/4361 [00:40<1:38:27,  1.37s/it]

Running pre-processing...


  1%|          | 35/4361 [00:41<1:39:31,  1.38s/it]

Running pre-processing...


In [15]:
results

['SELECT COUNT(*) FROM head WHERE age > 56',
 'SELECT head.name, head.born_state, head.age FROM head ORDER BY head.age',
 'SELECT DATE_PART(YEAR, CAST(creation AS DATE)) AS creation_year, name, budget_in_billions FROM department',
 'SELECT MAX(Budget_in_Billions) AS max_budget, MIN(Budget_in_Billons) AS min_budget FROM department',
 ' AVG(department.num_employees) FROM department WHERE department.ranking BETWEEN 10 AND 15;\n                        ```',
 "SELECT head.name FROM head WHERE head.born_state <> 'California'",
 'SELECT DISTINCT head.born_state FROM head GROUP BY head.born_state HAVING COUNT(head.head_id) >= 3',
 "SELECT DATE_PART(YEAR, TO_DATE(creation, 'YYYY-MM-DD')) AS YEAR, COUNT(*) AS number_of_departments FROM department GROUP BY YEAR ORDER BY number_of_departments DESC NULLS LAST LIMIT 1",
 'SELECT COUNT(*) AS total_acting_status FROM management',
 'SELECT COUNT(*) AS num_departments FROM department WHERE NOT "Head" IS NULL AND NOT "Head" IN (\'John\', \'Jane\')',
 "SE