In [37]:
import pandas as pd
import numpy as np
import json
import utils

Evaluation metrics: \
string matching score \
executability \
execution score  

In [3]:
# Load the Spider training examples. JSON is UTF-8 by specification, so pass the encoding
# explicitly instead of relying on the platform default (e.g. cp1252 on Windows).
with open('spider_data/train_spider.json', encoding='utf-8') as f:
    query_data = json.load(f)

In [5]:
# Load the Spider database/table metadata. JSON is UTF-8 by specification, so pass the encoding
# explicitly instead of relying on the platform default (e.g. cp1252 on Windows).
with open('spider_data/tables.json', encoding='utf-8') as f:
    table_data = json.load(f)

In [16]:
# Index the tables.json records by database id so get_schema() can look them up directly.
table_data_dict = dict((record['db_id'], record) for record in table_data)

In [10]:
# Number of training questions per database (most frequent first), saved for inspection.
(
    pd.Series([record['db_id'] for record in query_data])
    .value_counts()
    .to_csv('temp.csv')
)

In [79]:
def get_schema(query, table_data_dict):
    """Describe the table a single-table Spider query reads from.

    Parameters
    ----------
    query : dict
        One Spider example (an element of ``train_spider.json``). Only the first entry of
        ``query['sql']['from']['table_units']`` is used, so this is meant for single-table queries.
    table_data_dict : dict
        ``tables.json`` records keyed by ``db_id``.

    Returns
    -------
    tbl_schema
        The result of ``utils.format_topk_sql({table_name: [(column_name, column_type), ...]},
        shuffle=False)``.
    df : pandas.DataFrame
        One row per column of that table with ``name`` and ``type`` columns plus an all-NaN
        ``comment`` column.

    Raises
    ------
    ValueError
        If the first FROM unit is not a plain ``'table_unit'`` entry (such as a FROM-clause
        subquery), because there is then no table index to look up.
    """
    db_info = table_data_dict[query['db_id']]

    first_unit = query['sql']['from']['table_units'][0]
    unit_kind, tbl_id = first_unit[0], first_unit[1]
    if unit_kind != 'table_unit':
        # Without this guard a non-integer unit payload reaches the list index below and fails
        # with an unhelpful TypeError.
        raise ValueError(
            f"get_schema expects a 'table_unit' FROM entry, got {unit_kind!r} "
            f"(db_id={query['db_id']!r})"
        )
    table_name = db_info['table_names_original'][tbl_id]

    # Each column_names_original entry is [table_index, column_name]; keep this table's columns,
    # paired with the parallel column_types entry.
    columns = [
        (names[1], col_type)
        for names, col_type in zip(db_info['column_names_original'], db_info['column_types'])
        if names[0] == tbl_id
    ]

    columns_df = pd.DataFrame(columns, columns=['name', 'type'])
    columns_df['comment'] = np.nan
    tbl_schema = utils.format_topk_sql({table_name: columns}, shuffle=False)

    return tbl_schema, columns_df

In [14]:
# Use the first training example as a worked example for the single-query cells below.
# Ending the cell with `query` displays it, which is what the recorded output shows.
query = query_data[0]
query

{'db_id': 'department_management',
 'query': 'SELECT count(*) FROM head WHERE age  >  56',
 'query_toks': ['SELECT',
  'count',
  '(',
  '*',
  ')',
  'FROM',
  'head',
  'WHERE',
  'age',
  '>',
  '56'],
 'query_toks_no_value': ['select',
  'count',
  '(',
  '*',
  ')',
  'from',
  'head',
  'where',
  'age',
  '>',
  'value'],
 'question': 'How many heads of the departments are older than 56 ?',
 'question_toks': ['How',
  'many',
  'heads',
  'of',
  'the',
  'departments',
  'are',
  'older',
  'than',
  '56',
  '?'],
 'sql': {'from': {'table_units': [['table_unit', 1]], 'conds': []},
  'select': [False, [[3, [0, [0, 0, False], None]]]],
  'where': [[False, 3, [0, [0, 10, False], None], 56.0, None]],
  'groupBy': [],
  'having': [],
  'orderBy': [],
  'limit': None,
  'intersect': None,
  'union': None,
  'except': None}}

In [80]:
# Keep queries whose FROM clause is exactly one plain table reference. Entries look like
# ['table_unit', table_index]; other kinds (presumably FROM-clause subqueries) carry no table
# index, so get_schema() cannot describe them and they are excluded here.
single_table_queries = [
    example
    for example in query_data
    if len(example['sql']['from']['table_units']) == 1
    and example['sql']['from']['table_units'][0][0] == 'table_unit'
]

In [56]:
# Prompt layout used with the sqlcoder model below. The text is written flush-left so no
# indentation is sent to the model; the previously indented template put 24 leading spaces on
# every line, and the recorded outputs show the model echoing that indent before its closing
# fence. The template ends right after the opening ```sql fence so the model continues with SQL.
# NOTE(review): the prompt says "Postgres", but Spider databases are SQLite; confirm this matches
# the dialect used for the executability / execution-score checks.
prompt_template = """### Instructions:
Your task is to convert a question into a SQL query, given a Postgres database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float

### Input:
Generate a SQL query that answers the question {question}.
This query will run on a database whose schema is represented in this string:
{db_schema}


### Response:
Based on your instructions, here is the SQL query I have generated to answer the question {question}:
```sql
"""

In [60]:
model_name = 'sqlcoder'

# tbl_schema, table_name and df are otherwise only locals of get_schema(), so derive them here
# to keep this cell runnable on a fresh kernel.
tbl_schema, df = get_schema(query, table_data_dict)
tbl_id = query['sql']['from']['table_units'][0][1]
table_name = table_data_dict[query['db_id']]['table_names_original'][tbl_id]

sql_query, prompt = utils.getModelResult(tbl_schema, query['question'], model_name, table_name, df)

Running pre-processing...
Pruned schema: 
CREATE TABLE head (
  age DECIMAL(38, 0), --None
  head_ID DECIMAL(38, 0), --None
  name TEXT, --None
  born_state TEXT, --None
);



CREATE TABLE head (
  age DECIMAL(38, 0), --None
  head_ID DECIMAL(38, 0), --None
  name TEXT, --None
  born_state TEXT, --None
);


prompt:  
                        ### Instructions:
                        Your task is to convert a question into a SQL query, given a Postgres database schema.
                        Adhere to these rules:
                        - **Deliberately go through the question and database schema word by word** to appropriately answer the question
                        - **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
                        - When creating a ratio, always cast the numerator as float
                        
                        ### Input:
                        Generate 

In [65]:
import importlib

# Re-execute utils.py so edits to it take effect without restarting the kernel.
reload = importlib.reload
reload(utils)

<module 'utils' from 'C:\\Users\\91876\\Documents\\codegen_nl2sql\\utils.py'>

In [76]:
import requests

OLLAMA_URL = 'http://127.0.0.1:9565'


class OLLAMA:
    """Minimal client for an Ollama server's ``/api/generate`` endpoint (non-streaming).

    Parameters
    ----------
    OLLAMA_URL : str
        Base URL of the server, e.g. ``'http://127.0.0.1:9565'``. The upper-case name is kept
        because callers pass it by keyword (``OLLAMA(OLLAMA_URL=..., model_name=...)``).
    model_name : str
        Name of the model to run, e.g. ``'sqlcoder'``.
    timeout : float, tuple or None, optional
        Forwarded to ``requests.post``. The default ``(connect, read)`` pair stops a dead or stuck
        server from blocking the notebook forever; pass ``None`` to wait indefinitely.
    """

    def __init__(self, OLLAMA_URL, model_name, timeout=(10, 600)):
        self.model_name = model_name
        self.ollama_url = OLLAMA_URL
        self.ollama_endpoint = '/api/generate'
        self.timeout = timeout

    def run(self, prompt):
        """Send ``prompt`` to the model and return ``(response_text, parsed_json_body)``.

        The sampling temperature is fixed at 0.1 and the response text is also printed.

        Raises
        ------
        requests.HTTPError
            If the server answers with an error status, rather than failing later on a missing
            ``'response'`` key.
        """
        body = {
            'model': self.model_name,
            'prompt': prompt,
            'stream': False,  # request one JSON object instead of a stream of chunks
            'options': {'temperature': 0.1},
        }

        headers = {
            'Accept': 'application/json',
            'Content-Type': 'application/json',
        }

        resp = requests.post(url=f'{self.ollama_url}{self.ollama_endpoint}',
                             data=json.dumps(body),
                             headers=headers,
                             timeout=self.timeout)
        resp.raise_for_status()

        # Parse the body once and reuse it for both return values.
        result = resp.json()
        query = result['response']
        print(f'JSON resp: {query}')
        return query, result

In [77]:
model_name = 'sqlcoder'

# Recompute the schema here: tbl_schema is otherwise only a local of get_schema(), so this cell
# depended on hidden kernel state.
tbl_schema, _ = get_schema(query, table_data_dict)
prompt = prompt_template.format(question=query['question'], db_schema=tbl_schema)
ollama = OLLAMA(OLLAMA_URL=utils.OLLAMA_URL, model_name=model_name)
generated_query, resp = ollama.run(prompt)

JSON resp:  SELECT COUNT(*) FROM head WHERE age > 56;
                        ```


In [83]:
model_name = 'sqlcoder'
# Smoke-test size (the loop previously stopped after one query via a stray `break`);
# set to None to generate SQL for every single-table query.
MAX_QUERIES = 1

# One client is enough for the whole run; it holds no per-query state.
ollama = OLLAMA(OLLAMA_URL=utils.OLLAMA_URL, model_name=model_name)

results = []  # generated SQL, in the same order as single_table_queries
# A distinct loop name keeps the `query` example from earlier cells intact.
for example in single_table_queries[:MAX_QUERIES]:
    tbl_schema, _ = get_schema(example, table_data_dict)

    prompt = prompt_template.format(question=example['question'], db_schema=tbl_schema)
    generated_query, resp = ollama.run(prompt)

    results.append(generated_query)

JSON resp:  SELECT COUNT(*) FROM head WHERE age > 56;
                        ```
