In [1]:
# !pip install torch transformers accelerate bitsandbytes scipy tabulate

In [1]:
!nvidia-smi

Tue Sep 19 08:22:14 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    On   | 00000000:4F:00.0 Off |                  Off |
| 30%   25C    P8    22W / 300W |      1MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import os
import sqlite3
from typing import List, Optional, Tuple

import torch
import transformers
from tabulate import tabulate
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
print(transformers.__version__)
print(torch.cuda.is_available())

4.33.2
True


# Load the model

Set the parameter `load_in_4bit=True` to bring the model size down to about 12GB; otherwise the model will be too big.
If you work on an A100 machine, you can instead load the model in bfloat16 by passing `torch_dtype=torch.bfloat16`.

You can read more here: https://huggingface.co/blog/4bit-transformers-bitsandbytes

In [3]:
# Hugging Face Hub id of the text-to-SQL model.
sql_model_name = "defog/sqlcoder"

# One cache location shared by the tokenizer and the model weights.
# Set the HF_CACHE_DIR environment variable to relocate it on another machine.
CACHE_DIR = os.environ.get("HF_CACHE_DIR", '/home/mpham/workspace/huawei-arena-2023/.cache')

sql_tokenizer = AutoTokenizer.from_pretrained(sql_model_name, cache_dir=CACHE_DIR)
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# model repository; only keep it for repositories you trust.
sql_model = AutoModelForCausalLM.from_pretrained(
    sql_model_name,
    trust_remote_code=True,
    load_in_4bit=True,
    device_map="auto",
    use_cache=True,
    cache_dir=CACHE_DIR,
)

# The same model/tokenizer pair currently serves both the SQL-generation step
# and the answer-generation step.
answer_model = sql_model
answer_tokenizer = sql_tokenizer

# Alternative (disabled): use a separate chat model for answer generation.
# Access tokens must come from the environment, never from the notebook source.
# A token that was previously committed here should be revoked.
# answer_model_name = "meta-llama/Llama-2-7b-chat-hf"
# token = os.environ["HF_TOKEN"]

# answer_tokenizer = AutoTokenizer.from_pretrained(answer_model_name, token=token, cache_dir=CACHE_DIR)
# answer_model = transformers.pipeline(
#     "text-generation",
#     model='openai-gpt',
#     # torch_dtype=torch.bfloat16,
#     # device_map="auto",
#     # token=token
# )

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
def generate_sql_query(model, tokenizer, question: str, db_schema: str, tables_hints: List[str]) -> str:
    """Ask the LLM for one SQL statement that answers ``question``.

    Parameters:
        model: Causal language model exposing ``generate``. Its ``device``
            attribute, when present, decides where the inputs are placed.
        tokenizer: Tokenizer matching ``model``.
        question (str): Natural-language question to translate.
        db_schema (str): Schema description interpolated into the prompt.
        tables_hints (List[str]): Accepted for interface compatibility; this
            function does not forward it to the prompt builder.

    Returns:
        str: The text generated after the last ```sql fence, cut at the first
        closing fence and at the first ";", stripped, with ";" appended.
    """
    prompt = generate_sql_query_generation_prompt(question, db_schema)

    # Generation stops when the model emits the closing code fence.
    eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]

    # Follow the model's device instead of hard-coding "cuda", so the function
    # also works when the model was placed elsewhere (e.g. on CPU).
    device = getattr(model, "device", "cuda")
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=eos_token_id,
        pad_token_id=eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1
    )

    # The decoded text contains the prompt followed by the completion, so keep
    # only what follows the last opening fence.
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    sql_query = outputs[0].split("```sql")[-1].split("```")[0].split(";")[0].strip() + ";"

    # CUDA housekeeping is only valid when a CUDA device exists;
    # torch.cuda.synchronize() raises on CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    return sql_query

In [5]:
def generate_answer_with_context(model, tokenizer, question: str, schema: List[str]) -> str:
    """Ask the LLM for a natural-language answer based on the query result.

    Parameters:
        model: Causal language model exposing ``generate``. Its ``device``
            attribute, when present, decides where the inputs are placed.
        tokenizer: Tokenizer matching ``model``.
        question (str): The user's original question.
        schema: The query result to show the model; it is interpolated into
            the prompt by ``generate_answer_generation_prompt``.

    Returns:
        str: The text generated after the last ```ans fence, cut at the first
        closing fence and stripped.
    """
    answer_generation_prompt = generate_answer_generation_prompt(question, schema)
    print('answer prompt:', answer_generation_prompt)

    # Generation stops when the model emits the closing code fence.
    eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]

    # Follow the model's device instead of hard-coding "cuda".
    device = getattr(model, "device", "cuda")
    inputs = tokenizer(answer_generation_prompt, return_tensors="pt").to(device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=eos_token_id,
        pad_token_id=eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1
    )

    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    # Unlike the SQL step, the answer is free text: do not cut it at the first
    # ";" and do not append one.
    answer = outputs[0].split("```ans")[-1].split("```")[0].strip()

    # CUDA housekeeping is only valid when a CUDA device exists;
    # torch.cuda.synchronize() raises on CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    return answer

# Prompt Engineering

In [6]:
# Prompt for the text-to-SQL step.
# str.format placeholders: {question} (used twice), {db_schema}, {tables_hints};
# formatting requires a value for each of them.
# The prompt ends with an opened ```sql fence so that the completion starts
# directly with the query.
SQL_QUERY_PROMPT_TEMPLATE = """### Instructions:
Your task is to convert a question into a SQL query, given a SQLlite database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float
### Input:
Generate a SQL query that answers the question `{question}`.
This query will run on a database whose schema is represented in this string:
{db_schema}

These are some hints to help you in this task. 
Some of these hints can be wrong, so only use relevant ones:
{tables_hints}

### Response:
Based on your instructions, here is the SQL query I have generated to answer the question `{question}`:
```sql
"""

# Prompt for the answer-generation step.
# str.format placeholders: {returned_schema} (the tabular query result) and
# {question}. The prompt ends with an opened ```ans fence.
ANSWER_GENERATION_PROMPT_TEMPLATE = """
Generate a suitable answer to a prompt based on the extracted tabular information.
The information extracted from the database is as follows:
{returned_schema}
Each row represents a data point, and the columns are separated by "|".
Your answer should be short, concise and straight to the point.
The prompt is as followed: `{question}`

### Response:
Based on your instructions, here is the answer I have generated to give an appropriate response:
```ans
"""

In [7]:
def generate_sql_query_generation_prompt(question, db_schema, tables_hints=None):
    """Fill SQL_QUERY_PROMPT_TEMPLATE for one question.

    Parameters:
        question: Natural-language question to translate into SQL.
        db_schema: Schema description; interpolated via ``str.format``.
        tables_hints: Optional hints. A string is used as-is; any other
            non-empty iterable is rendered as one "- item" line per element;
            ``None`` or an empty value becomes "No hints provided.".

    Returns:
        str: The fully formatted prompt.
    """
    # The template has a {tables_hints} placeholder, so a value must always be
    # supplied; omitting it makes str.format raise KeyError.
    if not tables_hints:
        hints_text = "No hints provided."
    elif isinstance(tables_hints, str):
        hints_text = tables_hints
    else:
        hints_text = "\n".join(f"- {hint}" for hint in tables_hints)

    return SQL_QUERY_PROMPT_TEMPLATE.format(
        question=question,
        db_schema=db_schema,
        tables_hints=hints_text,
    )

def generate_answer_generation_prompt(question, data):
    """Fill ANSWER_GENERATION_PROMPT_TEMPLATE with the question and ``str(data)``."""
    return ANSWER_GENERATION_PROMPT_TEMPLATE.format(question=question, returned_schema=str(data))

In [8]:
def format_sql_execution(sql_response, sql_schema):
    """Render query rows as a text table in tabulate's "jira" format.

    ``sql_response`` supplies the rows and ``sql_schema`` the column headers.
    """
    return tabulate(sql_response, headers=sql_schema, tablefmt="jira")

# Database

In [9]:
def get_schemas(cursor, table_hints=None):
    '''
    Collect the column layout of every table in the database.

    Parameters:
        cursor: An open sqlite3 cursor on the database to inspect.
        table_hints: Currently unused; every table listed in sqlite_master
            is returned regardless of its value.

    Returns:
        dict: Maps each table name to the rows returned by
        ``PRAGMA table_info`` for that table.
    '''
    table_rows = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table'"
    ).fetchall()

    tables = {}
    for (table_name,) in table_rows:
        # PRAGMA statements take no bound parameters, so the name is inlined
        # as a string literal; double any single quote so that names such as
        # "it's" do not break the statement.
        quoted_name = table_name.replace("'", "''")
        tables[table_name] = cursor.execute(
            "PRAGMA table_info('%s')" % quoted_name
        ).fetchall()
    return tables

# Main functions

In [10]:
def connect_fun(database_name: str) -> sqlite3.Connection:
    """
    Open an SQLite database and hand back the connection.

    Parameters:
        database_name (str): Name or path of the SQLite database file.

    Returns:
        sqlite3.Connection or None: The open connection, or None when
        sqlite3 raised an error while connecting (the error is printed).

    Example usage:
        connection = connect_fun('your_database_name.db')
        if connection:
            print("Connected")
        else:
            print("Connection failed.")
    """
    try:
        opened = sqlite3.connect(database_name)
    except sqlite3.Error as e:
        print(f"Error connecting to the database: {e}")
        opened = None
    return opened


def query_fun(question: str, conn: sqlite3.Connection, tables_hints: Optional[List[str]]=None, debug:bool=False) -> Tuple[Optional[str], str]:
    """
    Answer a question by generating SQL, running it and summarising the result.

    Parameters:
        question (str): The user's question.
        conn (sqlite3.Connection): A connection to the SQLite database.
        tables_hints (List[str], optional): Table names to consider in the query.
        debug (bool): When True, print the schemas, the generated SQL and the
            formatted execution result.

    Returns:
        Tuple[Optional[str], str]: ``(sql_query, answer)``. On failure the
        second element is an error message and the first is the SQL generated
        so far (``None`` if generation itself failed), so callers can always
        unpack two values.

    Example usage:
        question = "How many customers are there in the database?"
        table_hints = ["customers"]
        connection = sqlite3.connect("your_database.db")
        sql_query, answer = query_fun(question, connection, table_hints)
        print(answer)
    """
    sql_query = None
    cursor = None
    try:
        cursor = conn.cursor()

        # Step 0: Get related tables based on all schemas and table hints
        related_schemas = get_schemas(cursor, tables_hints)

        if debug:
            print("Related schemas: ", related_schemas)

        # Step 1: Generate an SQL query based on the question and table hints.
        sql_query = generate_sql_query(sql_model, sql_tokenizer, question, related_schemas, tables_hints)

        if debug:
            print("SQL query: ", sql_query)

        # Step 2: Execute the SQL query and fetch the results.
        # NOTE(review): the statement is model-generated and executed verbatim;
        # it is not restricted to read-only queries. Use a read-only connection
        # for databases that matter.
        response = cursor.execute(sql_query)

        # Step 3: Obtain records from response and schema information (column names) from the cursor description.
        records = response.fetchall()
        response_schema = [desc[0] for desc in cursor.description]
        sql_response = format_sql_execution(records, response_schema)

        if debug:
            print("SQL execution response: ", sql_response)

        # Step 4: Process the query result and generate an answer with context using LLM.
        answer = generate_answer_with_context(answer_model, answer_tokenizer, question, sql_response)

        return sql_query, answer

    except sqlite3.Error as e:
        print(f"SQLite Error: {e}")
        return sql_query, "An error occurred while processing the query."
    except Exception as e:
        print(f"Error: {e}")
        return sql_query, "An error occurred."
    finally:
        # Release the cursor whether or not the pipeline succeeded.
        if cursor is not None:
            cursor.close()

# Experiments

![Chinook ER diagram](/mnt/4TBSSD/pmkhoi/huawei-sql/huawei-arena-2023/imgs/chinook-er-diagram.png)

In [11]:
# Open the Chinook SQLite file used by the experiments below.
# The path is absolute and machine-specific; adjust it when running elsewhere.
# connect_fun returns None if sqlite3 reports an error while connecting.
connection = connect_fun('/home/mpham/workspace/huawei-arena-2023/data/chinook/Chinook_Sqlite.sqlite')

In [12]:
# Inspect the schema dictionary on its own, outside of query_fun.
cursor = connection.cursor()
# Maps each table name to its PRAGMA table_info rows.
related_schemas = get_schemas(cursor, table_hints=None)

In [13]:
related_schemas

{'Album': [(0, 'AlbumId', 'INTEGER', 1, None, 1),
  (1, 'Title', 'NVARCHAR(160)', 1, None, 0),
  (2, 'ArtistId', 'INTEGER', 1, None, 0)],
 'Artist': [(0, 'ArtistId', 'INTEGER', 1, None, 1),
  (1, 'Name', 'NVARCHAR(120)', 0, None, 0)],
 'Customer': [(0, 'CustomerId', 'INTEGER', 1, None, 1),
  (1, 'FirstName', 'NVARCHAR(40)', 1, None, 0),
  (2, 'LastName', 'NVARCHAR(20)', 1, None, 0),
  (3, 'Company', 'NVARCHAR(80)', 0, None, 0),
  (4, 'Address', 'NVARCHAR(70)', 0, None, 0),
  (5, 'City', 'NVARCHAR(40)', 0, None, 0),
  (6, 'State', 'NVARCHAR(40)', 0, None, 0),
  (7, 'Country', 'NVARCHAR(40)', 0, None, 0),
  (8, 'PostalCode', 'NVARCHAR(10)', 0, None, 0),
  (9, 'Phone', 'NVARCHAR(24)', 0, None, 0),
  (10, 'Fax', 'NVARCHAR(24)', 0, None, 0),
  (11, 'Email', 'NVARCHAR(60)', 1, None, 0),
  (12, 'SupportRepId', 'INTEGER', 0, None, 0)],
 'Employee': [(0, 'EmployeeId', 'INTEGER', 1, None, 1),
  (1, 'LastName', 'NVARCHAR(20)', 1, None, 0),
  (2, 'FirstName', 'NVARCHAR(20)', 1, None, 0),
  (3, 'Ti

In [15]:
# Questions to try against the Chinook database; only uncommented entries are active.
questions = [
    # "What is the highest sales of three salesman person? Give me the salesperson's name and his or her total sales",
    # "In 1981 which team picked overall 148?"

    "Find me 5 random song track names"
]

In [18]:
%%time
# Run the full pipeline on the first question. The result is unpacked into the
# generated SQL and the generated answer; debug=True prints the intermediate steps.
sql_command, text_response = query_fun(
    question=questions[0],
    conn=connection,
    debug=True
)

Related schemas:  {'Album': [(0, 'AlbumId', 'INTEGER', 1, None, 1), (1, 'Title', 'NVARCHAR(160)', 1, None, 0), (2, 'ArtistId', 'INTEGER', 1, None, 0)], 'Artist': [(0, 'ArtistId', 'INTEGER', 1, None, 1), (1, 'Name', 'NVARCHAR(120)', 0, None, 0)], 'Customer': [(0, 'CustomerId', 'INTEGER', 1, None, 1), (1, 'FirstName', 'NVARCHAR(40)', 1, None, 0), (2, 'LastName', 'NVARCHAR(20)', 1, None, 0), (3, 'Company', 'NVARCHAR(80)', 0, None, 0), (4, 'Address', 'NVARCHAR(70)', 0, None, 0), (5, 'City', 'NVARCHAR(40)', 0, None, 0), (6, 'State', 'NVARCHAR(40)', 0, None, 0), (7, 'Country', 'NVARCHAR(40)', 0, None, 0), (8, 'PostalCode', 'NVARCHAR(10)', 0, None, 0), (9, 'Phone', 'NVARCHAR(24)', 0, None, 0), (10, 'Fax', 'NVARCHAR(24)', 0, None, 0), (11, 'Email', 'NVARCHAR(60)', 1, None, 0), (12, 'SupportRepId', 'INTEGER', 0, None, 0)], 'Employee': [(0, 'EmployeeId', 'INTEGER', 1, None, 1), (1, 'LastName', 'NVARCHAR(20)', 1, None, 0), (2, 'FirstName', 'NVARCHAR(20)', 1, None, 0), (3, 'Title', 'NVARCHAR(30)',

In [20]:
sql_command

'SELECT track.name\nFROM   track\nWHERE  track.trackid in (SELECT trackid\n                          FROM   track\n                          ORDER BY random()\n                          LIMIT 5);'

In [19]:
text_response

'SELECT songs.name\nFROM   songs\nORDER BY random()\nLIMIT  5;'

# Testcases

In [None]:
# Hand-written DDL for a small sales database, passed to generate_sql_query as
# the schema text. The trailing "--" lines describe which columns can be joined.
sales_db_schema = """
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id
"""

# Test question for the sales schema.
# NOTE: the backslash joins the two source lines inside the string literal, so
# the leading spaces of the second line are part of the question text.
question = "What product has the biggest fall in sales in 2022 compared to 2021? \
            Give me the product name, the sales amount in both years, and the difference."

In [9]:
%%time
# Generate SQL for the sales question directly from the hand-written schema (no database connection involved).
print(generate_sql_query(sql_model, sql_tokenizer, question=question, db_schema=sales_db_schema, tables_hints=None))

WITH year_2021 AS (
  SELECT sales.product_id,
         sum(sales.quantity) AS sales_2021
  FROM   sales
  WHERE  sales.sale_date >= '2021-01-01'
     AND sales.sale_date < '2022-01-01'
  GROUP BY sales.product_id
  ORDER BY sales_2021 desc
), year_2022 AS (
  SELECT sales.product_id,
         sum(sales.quantity) AS sales_2022
  FROM   sales
  WHERE  sales.sale_date >= '2022-01-01'
  GROUP BY sales.product_id
  ORDER BY sales_2022 desc
)
SELECT products.name,
       year_2021.sales_2021,
       year_2022.sales_2022,
       year_2022.sales_2022 - year_2021.sales_2021 AS difference
FROM   products join year_2021 on products.product_id = year_2021.product_id
    join year_2022 on products.product_id = year_2022.product_id
WHERE  year_2022.sales_2022 - year_2021.sales_2021 > 0
LIMIT 1;
CPU times: user 53.6 s, sys: 966 ms, total: 54.6 s
Wall time: 55.2 s


In [15]:
# DDL for a small social-network style database (categories, posts, comments,
# users, reactions), passed to generate_sql_query as the schema text.
# Every table is introduced by a "--" SQL comment so the text is valid SQL.
sns_schema = """
-- Create Category table
CREATE TABLE Category (
    id INTEGER PRIMARY KEY,
    title TEXT,
    parent_id INTEGER,
    FOREIGN KEY (parent_id) REFERENCES Category(id)
);

-- Create Post table
CREATE TABLE Post (
    id INTEGER PRIMARY KEY,
    title TEXT,
    content TEXT,
    created_date DATETIME,
    last_modified_date DATETIME,
    created_id INTEGER,
    last_modified_id INTEGER,
    category_id INTEGER,
    FOREIGN KEY (created_id) REFERENCES User(id),
    FOREIGN KEY (last_modified_id) REFERENCES User(id),
    FOREIGN KEY (category_id) REFERENCES Category(id)
);

-- Create Comment table
CREATE TABLE Comment (
    id INTEGER PRIMARY KEY,
    post_id INTEGER,
    content TEXT,
    created_date DATETIME,
    last_modified_date DATETIME,
    created_id INTEGER,
    parent_id INTEGER,
    FOREIGN KEY (post_id) REFERENCES Post(id),
    FOREIGN KEY (created_id) REFERENCES User(id),
    FOREIGN KEY (parent_id) REFERENCES Comment(id)
);

-- Create User table
CREATE TABLE User (
    id INTEGER PRIMARY KEY,
    username TEXT,
    email TEXT,
    first_name TEXT,
    last_name TEXT
);

-- Create Reaction table
CREATE TABLE Reaction (
    id INTEGER PRIMARY KEY,
    user_id INTEGER,
    content_id INTEGER,
    content_type TEXT,
    reaction_type TEXT,
    FOREIGN KEY (user_id) REFERENCES User(id),
    FOREIGN KEY (content_id) REFERENCES Post(id) ON DELETE CASCADE,
    FOREIGN KEY (content_id) REFERENCES Comment(id) ON DELETE CASCADE
);
"""

# Test questions for the social-network schema.
# NOTE: each backslash joins source lines inside the string literal, so the
# leading spaces of the continuation lines are part of the question text.
sns_questions = [
    'write SQL query to list all posts that belong to the category "Vacancies" \
    and its subordinates as long as the posts need to have at least 5 comments and 3 reactions.',

    'write SQL query to list all posts that belong to the category "Vacancies" \
    and its subordinates as long as the posts need to have at least 5 comments \
    and at least 10 reactions across either the posts or their comments.',

    'what are the full name of the top 5 users whose posts or comments having \
    the highest number of reactions of "Like" or "Love".',

    "Who is the most active users in commenting on across the posts that contain \
    the keyword 'education' and 'policy'  as well as the posts' comments from June 2022 to September 2023"
    
]

In [16]:
%%time
# Generate SQL for the first social-network question from the hand-written schema.
print(generate_sql_query(sql_model, sql_tokenizer, question=sns_questions[0], db_schema=sns_schema, tables_hints=None))



WITH RECURSIVE subordinates(id, depth) AS (
    SELECT c.id, 0
    FROM   category c
    WHERE  c.parent_id IS NULL
    UNION ALL
    SELECT c.id, s.depth + 1
    FROM   category c JOIN subordinates s ON c.parent_id = s.id
), posts_with_comments_and_reactions(post_id) AS (
    SELECT p.id
    FROM   post p JOIN comment c ON p.id = c.post_id
    WHERE  p.category_id in (SELECT id
                            FROM   category
                            WHERE  title = 'Vacancies')
    GROUP BY p.id
    HAVING count(*) >= 5
    UNION ALL
    SELECT r.content_id
    FROM   reaction r
    WHERE  r.content_type = 'post'
    GROUP BY r.content_id
    HAVING count(*) >= 3
)
SELECT p.id,
       p.title,
       p.content,
       p.created_date,
       p.last_modified_date,
       p.created_id,
       p.last_modified_id,
       p.category_id
FROM   post p
WHERE  p.id in (SELECT post_id
                 FROM   posts_with_comments_and_reactions) and p.category_id in (SELECT id
                       