In [None]:
!pip install torch transformers accelerate bitsandbytes scipy 

In [30]:
# Show the visible GPU(s) and their current memory usage.
# (The huggingface/tokenizers fork warning in the output can be silenced by setting
# TOKENIZERS_PARALLELISM, as the warning itself suggests.)
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Mon Sep 18 21:42:27 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
| 48%   55C    P8    18W / 350W |  19115MiB / 24268MiB |      0%      Default |
|                               |   

In [2]:
import torch
# Sanity check: confirm PyTorch can see a CUDA GPU before loading the model.
torch.cuda.is_available()

True

In [3]:
import sqlite3
from typing import List, Optional

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

In [4]:
import transformers
# Record the transformers version used to produce the outputs in this notebook.
print(transformers.__version__)

4.31.0


# Load the model

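Load the model quantized so that it fits in GPU memory: set `load_in_4bit=True` for a model size of about 12GB (the cell below uses `load_in_8bit=True`); otherwise the model will be too big. 
If you work on an A100 machine, you can instead load the unquantized model with `torch_dtype=torch.bfloat16`.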

Can read more here: https://huggingface.co/blog/4bit-transformers-bitsandbytes

In [5]:
# Load SQLCoder and its tokenizer. The weights are quantized to 8-bit so the model fits in
# GPU memory; the commented-out lines are the bfloat16 and 4-bit alternatives described above.
model_name = "defog/sqlcoder"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # NOTE(review): trust_remote_code lets the model repository run its own Python code; only use with trusted repos.
    trust_remote_code=True,
    # torch_dtype=torch.bfloat16,
    load_in_8bit=True,
    # load_in_4bit=True,
    device_map="auto",  # let accelerate place the model on the available device(s)
    use_cache=True,  # reuse cached key/value states during generation
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

# Main functions required by the Huawei organisers

### Build a Prompt template to get an SQL query 

In [6]:
# Prompt for SQLCoder. Placeholders: {question} (used twice) and {db_schema} (DDL text).
# The prompt ends with an opening ```sql fence: the model continues with the query, and the
# closing ``` it emits is used as the stop token by the generation code.
SQL_QUERY_PROMPT_TEMPLATE = """### Instructions:
Your task is to convert a question into a SQL query, given a SQLite database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float
### Input:
Generate a SQL query that answers the question `{question}`.
This query will run on a database whose schema is represented in this string:
{db_schema}

### Response:
Based on your instructions, here is the SQL query I have generated to answer the question `{question}`:
```sql
"""

In [7]:
def generate_sql_query_generation_prompt(question, db_schema):
    """
    Build the SQL-generation prompt by substituting the question and the
    database schema text into SQL_QUERY_PROMPT_TEMPLATE.
    """
    placeholders = {"question": question, "db_schema": db_schema}
    return SQL_QUERY_PROMPT_TEMPLATE.format(**placeholders)

### Build a Prompt template to generate an answer based on the query results

In [53]:
# Prompt for turning query results into a short text answer. Placeholders: {question}
# (used twice) and {returned_schema}, which receives the data returned by the SQL query.
# The prompt ends with an opening "```text generated answer" fence; the answer-extraction
# code splits on that exact fence, so keep it unchanged.
ANSWER_GENERATION_PROMPT_TEMPLATE = """### Instructions:
Your task is to convert information returned from a database into a text answer to a question.

### Input:
Generate an answer to the question `{question}` based on the extracted tabular information.
The information extracted from the database is as follows:
{returned_schema}
Each row represents a data point, and the columns are separated by "|".

Your answer should be short, concise and straight to the point.

### Response with the following format:
Based on your question and the returned schema, here is the answer I have generated to answer the question `{question}`:
```text generated answer
"""

In [51]:
def generate_answer_generation_prompt(question, returned_schema):
    """
    Build the answer-generation prompt.

    `returned_schema` (the data returned by the SQL query) is converted with str()
    before substitution, so non-string results such as lists of rows are accepted.
    """
    data_as_text = str(returned_schema)
    return ANSWER_GENERATION_PROMPT_TEMPLATE.format(
        returned_schema=data_as_text,
        question=question,
    )

### Function to connect to database and return a connection object

In [10]:
def connect_fun(database_name: str) -> Optional[sqlite3.Connection]:
    """
    Connect to an SQLite database and return a connection object.

    Parameters:
        database_name (str): The name (or path) of the SQLite database file to connect to.

    Returns:
        sqlite3.Connection or None: A connection object if the connection is successful,
        or None if there is an error (the error message is printed).

    Note:
        By default SQLite creates a new, empty database file when `database_name` does not
        exist, so a mistyped path still "succeeds" here and only fails later on queries.

    Example usage:
        db_name = 'your_database_name.db'
        connection = connect_fun(db_name)
        
        if connection:
            print(f"Connected to {db_name}")
            # You can now use 'connection' to interact with the database.
        else:
            print("Connection failed.")
    """
    try:
        return sqlite3.connect(database_name)
    except sqlite3.Error as e:
        # Best-effort helper: report the problem and let the caller check for None.
        print(f"Error connecting to the database: {e}")
        return None

### Function to query a database and return an answer based on the context returned by the SQL query

Example: 

    Question: What are the months with the highest sales in the year? 
    Answer:   The highest sales happened in November, December and January 
    Prompt flow: 
        Question (str) -> **LLM SQL query model** -> SQL query -> returned rows -> **LLM to generate answer with context based on the returned rows**

In [17]:
def generate_sql_query(question: str, db_schema: str, tables_hints: List[str], num_beams: int = 5,
                       max_new_tokens: int = 400) -> str:
    """
    Generate an SQL query that answers `question`, using the loaded SQLCoder model.

    Parameters:
        question (str): Natural-language question to translate into SQL.
        db_schema (str): Schema text (DDL) inserted verbatim into the prompt.
        tables_hints (List[str] or None): Currently not used to build the prompt; accepted so
            callers can pass table hints without changing the signature.
        num_beams (int): Beam width for deterministic (do_sample=False) decoding; 1 = greedy.
        max_new_tokens (int): Upper bound on the number of generated tokens (default 400).

    Returns:
        str: The first SQL statement produced by the model, terminated with ";".
    """
    prompt = generate_sql_query_generation_prompt(question, db_schema)
    # The prompt ends with an opening ```sql fence, so the closing ``` ends the query.
    eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
    # Send inputs to wherever the model lives instead of hard-coding "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=eos_token_id,
        pad_token_id=eos_token_id,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=num_beams,
    )

    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    # Keep the text after the last ```sql fence and before the closing fence, then only the
    # first statement. NOTE: a ";" inside a string literal would truncate the query here.
    sql_query = outputs[0].split("```sql")[-1].split("```")[0].split(";")[0].strip() + ";"

    # Release cached GPU memory between calls; guarded because synchronize() fails without CUDA.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    return sql_query


In [63]:
def generate_answer_with_context(question: str, returned_data: str, num_beams: int = 5,
                                 max_new_tokens: int = 400) -> str:
    """
    Turn the data returned by an SQL query into a short natural-language answer.

    Parameters:
        question (str): The user's original question.
        returned_data (str): Query results as text (e.g. '|'-separated rows).
        num_beams (int): Beam width for deterministic (do_sample=False) decoding (default 5).
        max_new_tokens (int): Upper bound on the number of generated tokens (default 400).

    Returns:
        str: The generated answer with surrounding whitespace removed.
    """
    answer_generation_prompt = generate_answer_generation_prompt(question, returned_data)

    # Stop generation at the closing ``` fence that ends the answer block.
    eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
    # Send inputs to wherever the model lives instead of hard-coding "cuda".
    inputs = tokenizer(answer_generation_prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=eos_token_id,
        pad_token_id=eos_token_id,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=num_beams,
    )

    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    # Keep what follows the last "```text generated answer" fence up to the closing fence,
    # dropping the newlines/indentation the model emits around the answer.
    answer = outputs[0].split("```text generated answer")[-1].split("```")[0].strip()

    # Release cached GPU memory between calls; guarded because synchronize() fails without CUDA.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    return answer
    
    

### Combine all functions together

In [13]:
def _schema_from_connection(conn: sqlite3.Connection, tables_hints: Optional[List[str]]) -> str:
    """Return the stored CREATE TABLE statements of `conn`, limited to `tables_hints` when any match."""
    rows = conn.execute(
        "SELECT name, sql FROM sqlite_master WHERE type = 'table' AND sql IS NOT NULL ORDER BY name"
    ).fetchall()
    # Skip SQLite's internal bookkeeping tables (e.g. sqlite_sequence).
    rows = [(name, sql) for name, sql in rows if not name.startswith("sqlite_")]
    # SQLite table names are case-insensitive, so compare hints case-insensitively.
    wanted = {hint.lower() for hint in tables_hints or []}
    selected = [sql for name, sql in rows if name.lower() in wanted]
    if not selected:
        # No hints, or none of them names an existing table: fall back to every table.
        selected = [sql for _, sql in rows]
    return "\n\n".join(sql + ";" for sql in selected)


def _format_rows(columns: List[str], rows: List[tuple]) -> str:
    """Render a result set as a '|'-separated header line followed by one line per row."""
    lines = [" | ".join(columns)]
    lines.extend(" | ".join(str(value) for value in row) for row in rows)
    return "\n".join(lines)


def query_fun(question: str, tables_hints: List[str], conn: sqlite3.Connection,
              db_schema: Optional[str] = None, max_rows: int = 50) -> str:
    """
    Generate an answer to a question based on an SQLite database and question context.

    Flow: question -> SQL query (LLM) -> rows fetched from `conn` -> text answer (LLM).

    Parameters:
        question (str): The user's question.
        tables_hints (List[str]): List of table names to consider in the query. When
            `db_schema` is not given, the schema shown to the model is restricted to these
            tables (all tables if the list is empty/None or matches no table).
        conn (sqlite3.Connection): A connection to the SQLite database.
        db_schema (str, optional): Schema text for the SQL prompt. When None, it is built
            from the CREATE TABLE statements stored in the database's sqlite_master table.
        max_rows (int): Maximum number of result rows passed to the answer model (default 50).

    Returns:
        str: The answer to the question, or a generic error message if any step fails.

    Example usage:
        question = "How many customers are there in the database?"
        table_hints = ["customers"]
        connection = sqlite3.connect("your_database.db")
        answer = query_fun(question, table_hints, connection)
        print(answer)
    """
    cursor = None
    try:
        if db_schema is None:
            db_schema = _schema_from_connection(conn, tables_hints)

        # Step 1: Generate an SQL query based on the question, schema and table hints.
        sql_query = generate_sql_query(question, db_schema, tables_hints)

        # Step 2: Execute the SQL query.
        # SECURITY NOTE(review): this runs model-generated SQL; consider a read-only
        # connection if questions can come from untrusted users.
        cursor = conn.cursor()
        cursor.execute(sql_query)

        # Step 3: Column names come from cursor.description, which is None for statements
        # that produce no result set; fetch at most `max_rows` rows to bound the prompt size.
        columns = [desc[0] for desc in cursor.description] if cursor.description else []
        rows = cursor.fetchmany(max_rows) if columns else []

        # Step 4: Let the LLM phrase an answer from the returned rows.
        returned_data = _format_rows(columns, rows)
        return generate_answer_with_context(question, returned_data)

    except sqlite3.Error as e:
        print(f"SQLite Error: {e}")
        return "An error occurred while processing the query."
    except Exception as e:
        print(f"Error: {e}")
        return "An error occurred."
    finally:
        if cursor is not None:
            cursor.close()

# Now test the solution

### Test the function with sales db and question

In [25]:
# DDL (with column comments and join hints) for a toy sales database. In this notebook it is
# only used as prompt text: it is passed as `db_schema` to generate_sql_query below.
sales_db_schema = """
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id
"""

In [24]:
# Implicit string concatenation keeps the long question readable without embedding the source
# indentation in the text (a backslash continuation inside the literal would include it).
question = (
    "What product has the biggest fall in sales in 2022 compared to 2021? "
    "Give me the product name, the sales amount in both years, and the difference."
)

In [29]:
%%time
# Beam search with num_beams=5 (also the default of generate_sql_query).
# NOTE(review): the recorded query orders by `difference DESC` with difference = 2022 - 2021,
# so LIMIT 1 returns the largest increase rather than the biggest fall the question asks for.
print(generate_sql_query(question=question, db_schema=sales_db_schema, tables_hints=None, num_beams=5))

WITH sales_2021 AS (
  SELECT sales.product_id,
         sum(sales.quantity) AS sales_2021
  FROM   sales
  WHERE  sales.sale_date >= '2021-01-01'
     AND sales.sale_date <= '2021-12-31'
  GROUP BY sales.product_id), sales_2022 AS (
  SELECT sales.product_id,
         sum(sales.quantity) AS sales_2022
  FROM   sales
  WHERE  sales.sale_date >= '2022-01-01'
     AND sales.sale_date <= '2022-12-31'
  GROUP BY sales.product_id)
SELECT products.name,
       sales_2021.sales_2021,
       sales_2022.sales_2022,
       sales_2022.sales_2022 - sales_2021.sales_2021 AS difference
FROM   products
    LEFT JOIN sales_2021 ON products.product_id = sales_2021.product_id
    LEFT JOIN sales_2022 ON products.product_id = sales_2022.product_id
ORDER BY difference DESC
LIMIT 1;
CPU times: user 19.7 s, sys: 49.2 ms, total: 19.7 s
Wall time: 19.7 s


In [27]:
%%time
# Beam search with num_beams=3.
# NOTE(review): the recorded query again orders by `difference DESC` (2022 - 2021), which
# selects the largest increase, not the biggest fall.
print(generate_sql_query(question=question, db_schema=sales_db_schema, tables_hints=None, num_beams=3))

WITH sales_2021 AS (
  SELECT sales.product_id,
         sum(sales.quantity) AS sales_2021
  FROM   sales
  WHERE  sales.sale_date >= '2021-01-01'
     AND sales.sale_date <= '2021-12-31'
  GROUP BY sales.product_id
), sales_2022 AS (
  SELECT sales.product_id,
         sum(sales.quantity) AS sales_2022
  FROM   sales
  WHERE  sales.sale_date >= '2022-01-01'
     AND sales.sale_date <= '2022-12-31'
  GROUP BY sales.product_id
)
SELECT products.name,
       sales_2021.sales_2021,
       sales_2022.sales_2022,
       sales_2022.sales_2022 - sales_2021.sales_2021 AS difference
FROM   products
    LEFT JOIN sales_2021 ON products.product_id = sales_2021.product_id
    LEFT JOIN sales_2022 ON products.product_id = sales_2022.product_id
ORDER BY difference DESC
LIMIT 1;
CPU times: user 28 s, sys: 36.9 ms, total: 28 s
Wall time: 28 s


In [31]:
%%time
# Greedy decoding: num_beams=1 together with do_sample=False inside generate_sql_query.
# NOTE(review): in the recorded query the final SELECT has no ORDER BY (the ORDER BYs sit inside
# the CTEs), so LIMIT 1 returns an arbitrary product rather than the biggest fall.
print(generate_sql_query(question=question, db_schema=sales_db_schema, tables_hints=None, num_beams=1))

WITH year_2021 AS (
  SELECT sales.product_id,
         sum(sales.quantity) AS sales_2021
  FROM   sales
  WHERE  sales.sale_date >= '2021-01-01'
     AND sales.sale_date < '2022-01-01'
  GROUP BY sales.product_id
  ORDER BY sales_2021 desc
), year_2022 AS (
  SELECT sales.product_id,
         sum(sales.quantity) AS sales_2022
  FROM   sales
  WHERE  sales.sale_date >= '2022-01-01'
     AND sales.sale_date < '2023-01-01'
  GROUP BY sales.product_id
  ORDER BY sales_2022 desc
)
SELECT products.name,
       year_2021.sales_2021,
       year_2022.sales_2022,
       year_2022.sales_2022 - year_2021.sales_2021 AS difference
FROM   products join year_2021 on products.product_id = year_2021.product_id
join year_2022 on products.product_id = year_2022.product_id
LIMIT 1;
CPU times: user 16.9 s, sys: 30.7 ms, total: 16.9 s
Wall time: 16.9 s


In [None]:
# Second sales question. NOTE(review): the wording is ambiguous ("three salesman person"); it
# presumably asks for the top three salespeople by total sales. The sales table has only a
# quantity column, so "total sales" could mean summed quantity or price * quantity.
question = "What is the highest sales of three salesman person? Give me the salesperson's name and his or her total sales"

In [None]:
# Uses generate_sql_query's default num_beams; this cell was not executed in the saved notebook.
print(generate_sql_query(question=question, db_schema=sales_db_schema, tables_hints=None))

In [32]:
%%time
question = "In 1981 which team picked overall 148?"
test_db_schema = """
CREATE TABLE table_name_8 (team VARCHAR, year VARCHAR, overall_pick VARCHAR)
"""
print(generate_sql_query(question=question, db_schema=test_db_schema, tables_hints=None, num_beams=1))


SELECT table_name_8.team
FROM   table_name_8
WHERE  table_name_8.year = '1981'
   and table_name_8.overall_pick = '148';
CPU times: user 3 s, sys: 24.2 ms, total: 3.03 s
Wall time: 3.02 s


### Get a dummy database

![DB Image](./imgs/chinook-er-diagram.png)

In [None]:
import os
import sqlite3
import pandas as pd
# Define the SQLite database file name.
db_file = "./sample_db/Chinook_Sqlite.sqlite"

# sqlite3.connect() silently creates a new, empty database when the file does not exist,
# which would only surface later as "no such table" errors. Fail fast instead.
if not os.path.isfile(db_file):
    raise FileNotFoundError(f"Sample database not found: {db_file}")

# Connect to the SQLite database.
## creating a connection
conn = sqlite3.connect(db_file)

## list the tables (and their CREATE statements) stored in the database
tables = pd.read_sql("""SELECT *
                        FROM sqlite_master
                        WHERE type='table';""", conn)

tables

In [None]:
# DDL of the Chinook sample database, kept as text to use as `db_schema` prompt context.
# Presumably copied from the CREATE statements stored in sqlite_master (listed above) — TODO confirm.
# NOTE(review): the statements are not separated by ";".
chinook_db_schema ="""
CREATE TABLE [Album]
(
    [AlbumId] INTEGER  NOT NULL,
    [Title] NVARCHAR(160)  NOT NULL,
    [ArtistId] INTEGER  NOT NULL,
    CONSTRAINT [PK_Album] PRIMARY KEY  ([AlbumId]),
    FOREIGN KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId]) 
		ON DELETE NO ACTION ON UPDATE NO ACTION
)
CREATE TABLE [Artist]
(
    [ArtistId] INTEGER  NOT NULL,
    [Name] NVARCHAR(120),
    CONSTRAINT [PK_Artist] PRIMARY KEY  ([ArtistId])
)
CREATE TABLE [Customer]
(
    [CustomerId] INTEGER  NOT NULL,
    [FirstName] NVARCHAR(40)  NOT NULL,
    [LastName] NVARCHAR(20)  NOT NULL,
    [Company] NVARCHAR(80),
    [Address] NVARCHAR(70),
    [City] NVARCHAR(40),
    [State] NVARCHAR(40),
    [Country] NVARCHAR(40),
    [PostalCode] NVARCHAR(10),
    [Phone] NVARCHAR(24),
    [Fax] NVARCHAR(24),
    [Email] NVARCHAR(60)  NOT NULL,
    [SupportRepId] INTEGER,
    CONSTRAINT [PK_Customer] PRIMARY KEY  ([CustomerId]),
    FOREIGN KEY ([SupportRepId]) REFERENCES [Employee] ([EmployeeId]) 
		ON DELETE NO ACTION ON UPDATE NO ACTION
)
CREATE TABLE [Employee]
(
    [EmployeeId] INTEGER  NOT NULL,
    [LastName] NVARCHAR(20)  NOT NULL,
    [FirstName] NVARCHAR(20)  NOT NULL,
    [Title] NVARCHAR(30),
    [ReportsTo] INTEGER,
    [BirthDate] DATETIME,
    [HireDate] DATETIME,
    [Address] NVARCHAR(70),
    [City] NVARCHAR(40),
    [State] NVARCHAR(40),
    [Country] NVARCHAR(40),
    [PostalCode] NVARCHAR(10),
    [Phone] NVARCHAR(24),
    [Fax] NVARCHAR(24),
    [Email] NVARCHAR(60),
    CONSTRAINT [PK_Employee] PRIMARY KEY  ([EmployeeId]),
    FOREIGN KEY ([ReportsTo]) REFERENCES [Employee] ([EmployeeId]) 
		ON DELETE NO ACTION ON UPDATE NO ACTION
)
CREATE TABLE [Genre]
(
    [GenreId] INTEGER  NOT NULL,
    [Name] NVARCHAR(120),
    CONSTRAINT [PK_Genre] PRIMARY KEY  ([GenreId])
)
CREATE TABLE [Invoice]
(
    [InvoiceId] INTEGER  NOT NULL,
    [CustomerId] INTEGER  NOT NULL,
    [InvoiceDate] DATETIME  NOT NULL,
    [BillingAddress] NVARCHAR(70),
    [BillingCity] NVARCHAR(40),
    [BillingState] NVARCHAR(40),
    [BillingCountry] NVARCHAR(40),
    [BillingPostalCode] NVARCHAR(10),
    [Total] NUMERIC(10,2)  NOT NULL,
    CONSTRAINT [PK_Invoice] PRIMARY KEY  ([InvoiceId]),
    FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) 
		ON DELETE NO ACTION ON UPDATE NO ACTION
)
CREATE TABLE [InvoiceLine]
(
    [InvoiceLineId] INTEGER  NOT NULL,
    [InvoiceId] INTEGER  NOT NULL,
    [TrackId] INTEGER  NOT NULL,
    [UnitPrice] NUMERIC(10,2)  NOT NULL,
    [Quantity] INTEGER  NOT NULL,
    CONSTRAINT [PK_InvoiceLine] PRIMARY KEY  ([InvoiceLineId]),
    FOREIGN KEY ([InvoiceId]) REFERENCES [Invoice] ([InvoiceId]) 
		ON DELETE NO ACTION ON UPDATE NO ACTION,
    FOREIGN KEY ([TrackId]) REFERENCES [Track] ([TrackId]) 
		ON DELETE NO ACTION ON UPDATE NO ACTION
)
CREATE TABLE [MediaType]
(
    [MediaTypeId] INTEGER  NOT NULL,
    [Name] NVARCHAR(120),
    CONSTRAINT [PK_MediaType] PRIMARY KEY  ([MediaTypeId])
)
CREATE TABLE [Playlist]
(
    [PlaylistId] INTEGER  NOT NULL,
    [Name] NVARCHAR(120),
    CONSTRAINT [PK_Playlist] PRIMARY KEY  ([PlaylistId])
)
CREATE TABLE [PlaylistTrack]
(
    [PlaylistId] INTEGER  NOT NULL,
    [TrackId] INTEGER  NOT NULL,
    CONSTRAINT [PK_PlaylistTrack] PRIMARY KEY  ([PlaylistId], [TrackId]),
    FOREIGN KEY ([PlaylistId]) REFERENCES [Playlist] ([PlaylistId]) 
		ON DELETE NO ACTION ON UPDATE NO ACTION,
    FOREIGN KEY ([TrackId]) REFERENCES [Track] ([TrackId]) 
		ON DELETE NO ACTION ON UPDATE NO ACTION
)
CREATE TABLE [Track]
(
    [TrackId] INTEGER  NOT NULL,
    [Name] NVARCHAR(200)  NOT NULL,
    [AlbumId] INTEGER,
    [MediaTypeId] INTEGER  NOT NULL,
    [GenreId] INTEGER,
    [Composer] NVARCHAR(220),
    [Milliseconds] INTEGER  NOT NULL,
    [Bytes] INTEGER,
    [UnitPrice] NUMERIC(10,2)  NOT NULL,
    CONSTRAINT [PK_Track] PRIMARY KEY  ([TrackId]),
    FOREIGN KEY ([AlbumId]) REFERENCES [Album] ([AlbumId]) 
		ON DELETE NO ACTION ON UPDATE NO ACTION,
    FOREIGN KEY ([GenreId]) REFERENCES [Genre] ([GenreId]) 
		ON DELETE NO ACTION ON UPDATE NO ACTION,
    FOREIGN KEY ([MediaTypeId]) REFERENCES [MediaType] ([MediaTypeId]) 
		ON DELETE NO ACTION ON UPDATE NO ACTION
)
""" 

# Test the answer generation function

In [65]:
%%time
returned_data = """
TrackId | Name                                    | AlbumId | MediaTypeId | GenreId | Composer                                       | Milliseconds | Bytes     | UnitPrice
1       | For Those About To Rock (We Salute You) | 1       | 1           | 1       | Angus Young, Malcolm Young, Brian Johnson   | 343719       | 11170334  | 0.99
2       | Balls to the Wall                       | 2       | 2           | 1       | None                                          | 342562       | 5510424   | 0.99
3       | Fast As a Shark                         | 3       | 2           | 1       | F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman | 230619 | 3990994   | 1.26

"""

question = "What are best sellers tracks? Give me the name and its unit prices."

print(generate_answer_with_context(question=question, returned_data=returned_data))


The best sellers tracks are For Those About To Rock (We Salute You) with a unit price of $0.99, Balls to the Wall with a unit price of $0.99, and Fast As a Shark with a unit price of $1.26.

CPU times: user 3.99 s, sys: 30.6 ms, total: 4.02 s
Wall time: 4.02 s
