In [1]:
import pandas as pd
import numpy as np
import json
import utils
import tqdm

In [2]:
prompt_template = """ 
                        ### Instructions:
                        Your task is to convert a question into a SQL query, given a Postgres database schema.
                        Adhere to these rules:
                        - **Deliberately go through the question and database schema word by word** to appropriately answer the question
                        - **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
                        - When creating a ratio, always cast the numerator as float
                        
                        ### Input:
                        Generate a SQL query that answers the question {question}.
                        This query will run on a database whose schema is represented in this string:
                        {db_schema}
                        
                        
                        ### Response:
                        Based on your instructions, here is the SQL query I have generated to answer the question {question}:
                        ```sql
                        """

In [21]:
import requests

OLLAMA_URL = 'http://127.0.0.1:11434'
class OLLAMA:
    def __init__(self, OLLAMA_URL, model_name):
        self.model_name = model_name
        self.ollama_url = OLLAMA_URL
        self.ollama_endpoint = '/api/generate'

    def run(self, prompt):
        data = {
            'model': self.model_name,
            'prompt': prompt,
            'stream': False,
            "options":{"temperature":0.0}
        }

        headers = {
            'Accept': 'application/json',
            'Content-Type': 'application/json'
        }

        resp = requests.post(url = f'{self.ollama_url}{self.ollama_endpoint}',
                             data = json.dumps(data),
                             headers = headers)
        query = resp.json()['response']
        # print(f'JSON resp: {query}')
        return query

In [10]:
db_schema = """CREATE TABLE user (
    id NUMBER ,
    username VARCHAR2(50) NOT NULL,
    email VARCHAR2(100),
    password_hash CHAR(64),
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP,
    last_login DATE,
    age NUMBER CHECK (age >= 0),
    is_active CHAR(1) DEFAULT 'Y',
    profile_views NUMBER(19),
    account_balance NUMBER(10,2),
    referral_code VARCHAR2(20),
    subscription_type VARCHAR2(30),
    subscription_expiry DATE,
    email_verified CHAR(1),
    phone_number VARCHAR2(20),
    country VARCHAR2(50),
    city VARCHAR2(50),
    zip_code VARCHAR2(10),
    address_line1 CLOB,
    address_line2 CLOB,
    bio CLOB,
    avatar_url CLOB,
    favorite_color VARCHAR2(30),
    preferred_language VARCHAR2(30),
    timezone VARCHAR2(50),
    login_attempts NUMBER DEFAULT 0,
    password_reset_token CHAR(64),
    password_reset_expiry TIMESTAMP,
    session_token CHAR(64),
    session_expiry TIMESTAMP,
    device_type VARCHAR2(30),
    os_version VARCHAR2(50),
    browser_info CLOB,
    two_factor_enabled CHAR(1),
    two_factor_secret CHAR(32),
    date_of_birth DATE,
    gender VARCHAR2(10),
    employment_status VARCHAR2(30),
    job_title VARCHAR2(100),
    company_name VARCHAR2(100),
    website_url CLOB,
    linkedin_url CLOB,
    github_url CLOB,
    twitter_handle VARCHAR2(50),
    instagram_handle VARCHAR2(50),
    youtube_channel CLOB,
    notification_preferences CLOB, 
    metadata CLOB
);
"""

## Example 1

In [11]:
query = """get count of user id by country and pin code"""
filtered_schema = utils.preprocess_table(query, db_schema, 'user')

In [14]:
print(filtered_schema)


CREATE TABLE user (
  country VARCHAR(50), --None
  phone_number VARCHAR(20), --None
  zip_code VARCHAR(10), --None
  id DECIMAL(38, 0), --None
  account_balance DECIMAL(10, 2), --None
  date_of_birth DATE, --None
  two_factor_secret CHAR(32), --None
  city VARCHAR(50), --None
  age DECIMAL(38, 0), --CHECK (age >= 0)
  username VARCHAR(50), --NOT NULL
  referral_code VARCHAR(20), --None
);




In [23]:
%%time
model_name = 'sqlcoder:7b'
prompt = prompt_template.format(question=query, db_schema=db_schema)
ollama = OLLAMA(OLLAMA_URL=utils.OLLAMA_URL, model_name=model_name)
generated_query = ollama.run(prompt)
generated_query

CPU times: user 6.06 ms, sys: 3.9 ms, total: 9.96 ms
Wall time: 2.55 s


' SELECT u.id AS user_id, c.country_code, p.pin_code, COUNT(*) OVER (PARTITION BY c.country_code, p.pin_code) AS total_users FROM user u JOIN country c ON u.country = c.country_code JOIN pin p ON u.zip_code = p.pin_code;\n                        ```'

In [24]:
%%time
model_name = 'sqlcoder:7b'
prompt = prompt_template.format(question=query, db_schema=filtered_schema)
ollama = OLLAMA(OLLAMA_URL=utils.OLLAMA_URL, model_name=model_name)
generated_query = ollama.run(prompt)
generated_query

CPU times: user 6.6 ms, sys: 2.68 ms, total: 9.28 ms
Wall time: 1.57 s


' SELECT u.country, u.zip_code, COUNT(u.id) AS user_count FROM user AS u GROUP BY u.country, u.zip_code;\n                        ```'

## Example 2

In [27]:
query = """get sum of views where balance is more than 1000"""
filtered_schema = utils.preprocess_table(query, db_schema, 'user')
print(filtered_schema)


CREATE TABLE user (
  account_balance DECIMAL(10, 2), --None
  profile_views DECIMAL(19), --None
  referral_code VARCHAR(20), --None
  age DECIMAL(38, 0), --CHECK (age >= 0)
  two_factor_enabled CHAR(1), --None
  subscription_expiry DATE, --None
  is_active CHAR(1), --DEFAULT 'Y'
  instagram_handle VARCHAR(50), --None
  subscription_type VARCHAR(30), --None
  created_at TIMESTAMP, --DEFAULT CURRENT_TIMESTAMP()
);




In [28]:
%%time
model_name = 'sqlcoder:7b'
prompt = prompt_template.format(question=query, db_schema=db_schema)
ollama = OLLAMA(OLLAMA_URL=utils.OLLAMA_URL, model_name=model_name)
generated_query = ollama.run(prompt)
generated_query

CPU times: user 6.42 ms, sys: 1.13 ms, total: 7.55 ms
Wall time: 1.55 s


' SELECT SUM(views) AS total_views FROM user WHERE account_balance > 1000;\n                        ```'

In [29]:
%%time
model_name = 'sqlcoder:7b'
prompt = prompt_template.format(question=query, db_schema=filtered_schema)
ollama = OLLAMA(OLLAMA_URL=utils.OLLAMA_URL, model_name=model_name)
generated_query = ollama.run(prompt)
generated_query

CPU times: user 2.57 ms, sys: 6.7 ms, total: 9.27 ms
Wall time: 1.15 s


' SELECT SUM(profile_views) AS total_views FROM user WHERE account_balance > 1000;\n                        ```'