In [1]:
# Check which GPU is visible and how much memory is free before loading the 4-bit model below.
!nvidia-smi

Tue Sep 19 14:27:39 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    On   | 00000000:4F:00.0 Off |                  Off |
| 30%   25C    P8    18W / 300W |      1MiB / 48685MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import sqlite3
from pathlib import Path
from typing import *

from infection.database import get_schemas
from infection.models import get_model
from infection.postprocess import format_sql_execution
from infection.prompt import (SQL_QUERY_PROMPT_TEMPLATE, ANSWER_GENERATION_PROMPT_TEMPLATE, generate_prompt)

def get_model_response(model, prompt_template: str, *, extract_sql: bool = True, **kwargs) -> str:
    """
    Render a prompt, run a single-beam generation and post-process the first output.

    Parameters:
        model: Model wrapper returned by ``get_model``. Its ``generate(prompt, num_beams=1)``
            is expected to return an indexable sequence of decoded strings; only the
            first one is used.
        prompt_template (str): Template passed to ``generate_prompt``.
        extract_sql (bool): Keyword-only. If True (the default, and the original
            behaviour), treat the output as SQL: keep the text after the last
            ```` ```sql ```` fence, cut it at the next closing fence, keep only the first
            statement and terminate it with ``;``. If False, return the stripped raw
            output unchanged. Use this for free-text answers, which the SQL processing
            would truncate at the first ``;`` and suffix with a stray ``;``.
            NOTE: this name is therefore not available as a template variable.
        **kwargs: Template variables forwarded to ``generate_prompt``.

    Returns:
        str: The extracted single SQL statement (ending in ``;``), or the stripped
        free-text response when ``extract_sql`` is False.
    """
    prompt = generate_prompt(prompt_template, **kwargs)
    outputs = model.generate(prompt, num_beams=1)
    text = outputs[0]

    if not extract_sql:
        # NOTE(review): assumes generate() returns only the completion, not the prompt
        # echoed back; the answer printed in this notebook suggests so, but confirm.
        return text.strip()

    # The SQL is whatever follows the last ```sql fence, up to the closing fence.
    sql = text.split("```sql")[-1].split("```")[0]
    # Keep only the first statement: sqlite3's Cursor.execute runs one statement at a time.
    return sql.split(";")[0].strip() + ";"

In [3]:
# SQLCoder is loaded once, quantised to 4 bits; the same instance is reused
# to phrase the final natural-language answer from the query result.
MODEL_CACHE_DIR = '/home/mpham/workspace/huawei-arena-2023/.cache'
sql_model = get_model('sqlcoder', cache_dir=MODEL_CACHE_DIR, load_in_4bit=True)
answer_model = sql_model

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
def connect_fun(database_name: str, read_only: bool = False) -> Optional[sqlite3.Connection]:
    """
    Connect to an SQLite database and return a connection object.

    Parameters:
        database_name (str): The name (or path) of the SQLite database file to connect to.
        read_only (bool): If True, open the file through an SQLite URI with ``mode=ro``.
            The file must already exist, and any statement that writes to the database
            fails with ``sqlite3.OperationalError``. This is useful when the SQL executed
            on the connection is generated by a model. It requires a real file path
            (not ``":memory:"``). Defaults to False, i.e. a plain ``sqlite3.connect``,
            which creates a new, empty database file when ``database_name`` does not exist.

    Returns:
        sqlite3.Connection or None: A connection object if the connection is successful,
        or None if there is an error (the error is printed).

    Example usage:
        db_name = 'your_database_name.db'
        connection = connect_fun(db_name, read_only=True)

        if connection:
            print(f"Connected to {db_name}")
            # You can now use 'connection' to interact with the database.
        else:
            print("Connection failed.")
    """
    try:
        if read_only:
            # as_uri() needs an absolute path and percent-encodes characters such as
            # '?' or '#' that SQLite would otherwise parse as URI syntax.
            uri = f"{Path(database_name).resolve().as_uri()}?mode=ro"
            return sqlite3.connect(uri, uri=True)
        return sqlite3.connect(database_name)
    except sqlite3.Error as e:
        print(f"Error connecting to the database: {e}")
        return None


def query_fun(question: str, conn: sqlite3.Connection, tables_hints: Optional[List[str]] = None, debug: bool = False) -> Tuple[Optional[str], str]:
    """
    Answer a natural-language question about an SQLite database with the LLM pipeline.

    Steps: collect the relevant table schemas, ask ``sql_model`` for an SQL query,
    execute it on ``conn``, then ask ``answer_model`` to phrase the result as text.

    Parameters:
        question (str): The user's question.
        conn (sqlite3.Connection): A connection to the SQLite database.
        tables_hints (List[str], optional): Table names to consider. They are forwarded
            to ``get_schemas`` and to the SQL prompt. ``None`` is passed through
            unchanged; ``get_schemas(cursor, table_hints=None)`` returns every table.
        debug (bool): If True, print the schemas, the generated SQL and its result.

    Returns:
        Tuple[Optional[str], str]: ``(sql_query, answer)``. On failure the second item
        is an error message and the first is the generated SQL (useful for diagnosing
        a bad query), or ``None`` if the failure happened before a query was generated.
        Callers can therefore always unpack two values.

    Example usage:
        question = "How many customers are there in the database?"
        table_hints = ["Customer"]
        connection = sqlite3.connect("your_database.db")
        sql_query, answer = query_fun(question, connection, tables_hints=table_hints)
        print(answer)
    """
    cursor = None
    sql_query = None
    try:
        cursor = conn.cursor()

        # Step 0: Get related tables based on all schemas and table hints.
        related_schemas = get_schemas(cursor, tables_hints)

        if debug:
            print("Related schemas: \n", related_schemas)
            print('-'*30)

        # Step 1: Generate an SQL query based on the question and table hints.
        sql_query = get_model_response(
            sql_model, SQL_QUERY_PROMPT_TEMPLATE,
            question=question,
            db_schema=related_schemas,
            tables_hints=tables_hints
        )

        if debug:
            print("SQL query: \n", sql_query)
            print('-'*30)

        # Step 2: Execute the SQL query and fetch the results.
        # NOTE(review): this runs model-generated SQL verbatim; open `conn` read-only
        # (SQLite URI with mode=ro) if the model must not be able to modify data.
        cursor.execute(sql_query)
        records = cursor.fetchall()

        # Step 3: Column names come from cursor.description, which is None for
        # statements that return no rows; treat that case as "no columns".
        column_names = [desc[0] for desc in (cursor.description or ())]
        sql_response = format_sql_execution(records, column_names)

        if debug:
            print("SQL execution response: \n", sql_response)
            print('-'*30)

        # Step 4: Process the query result and generate an answer with context using LLM.
        answer = get_model_response(
            answer_model, ANSWER_GENERATION_PROMPT_TEMPLATE,
            question=question,
            returned_schema=sql_response
        )
        # get_model_response post-processes its output as SQL and appends a ';'
        # terminator; that character is meaningless on a prose answer, so drop it.
        if answer.endswith(";"):
            answer = answer[:-1].rstrip()

        return sql_query, answer

    except sqlite3.Error as e:
        print(f"SQLite Error: {e}")
        return sql_query, "An error occurred while processing the query."
    except Exception as e:
        print(f"Error: {e}")
        return sql_query, "An error occurred."
    finally:
        # Release the cursor on every path; the connection itself stays owned by the caller.
        if cursor is not None:
            cursor.close()

In [5]:
# sqlite3.connect (used by connect_fun) silently creates a new, empty database when the
# path is wrong, which would leave the pipeline with no real tables — so fail fast here.
DB_PATH = Path('/home/mpham/workspace/huawei-arena-2023/data/chinook/Chinook_Sqlite.sqlite')
assert DB_PATH.is_file(), f"Chinook database not found at {DB_PATH}"
connection = connect_fun(str(DB_PATH))
assert connection is not None, f"Could not open the database at {DB_PATH}"

In [6]:
# Inspect the schemas the SQL model will see. table_hints=None applies no filter,
# so the output below lists every table of the database.
# Each entry appears to be a PRAGMA table_info row: (cid, name, type, notnull, default, pk) — TODO confirm in get_schemas.
cursor = connection.cursor()
related_schemas = get_schemas(cursor, table_hints=None)
related_schemas

{'Album': [(0, 'AlbumId', 'INTEGER', 1, None, 1),
  (1, 'Title', 'NVARCHAR(160)', 1, None, 0),
  (2, 'ArtistId', 'INTEGER', 1, None, 0)],
 'Artist': [(0, 'ArtistId', 'INTEGER', 1, None, 1),
  (1, 'Name', 'NVARCHAR(120)', 0, None, 0)],
 'Customer': [(0, 'CustomerId', 'INTEGER', 1, None, 1),
  (1, 'FirstName', 'NVARCHAR(40)', 1, None, 0),
  (2, 'LastName', 'NVARCHAR(20)', 1, None, 0),
  (3, 'Company', 'NVARCHAR(80)', 0, None, 0),
  (4, 'Address', 'NVARCHAR(70)', 0, None, 0),
  (5, 'City', 'NVARCHAR(40)', 0, None, 0),
  (6, 'State', 'NVARCHAR(40)', 0, None, 0),
  (7, 'Country', 'NVARCHAR(40)', 0, None, 0),
  (8, 'PostalCode', 'NVARCHAR(10)', 0, None, 0),
  (9, 'Phone', 'NVARCHAR(24)', 0, None, 0),
  (10, 'Fax', 'NVARCHAR(24)', 0, None, 0),
  (11, 'Email', 'NVARCHAR(60)', 1, None, 0),
  (12, 'SupportRepId', 'INTEGER', 0, None, 0)],
 'Employee': [(0, 'EmployeeId', 'INTEGER', 1, None, 1),
  (1, 'LastName', 'NVARCHAR(20)', 1, None, 0),
  (2, 'FirstName', 'NVARCHAR(20)', 1, None, 0),
  (3, 'Ti

In [11]:
# Sample questions about the Chinook database; the cells below pick one by index.
questions = [
    "Find me 5 random song track names",
    "What is the mean price of all the song tracks ?",
    "How many employees are there ?",
    "What is the nationality that has the most number of our customers ?",
]

In [15]:
%%time
sql_command, text_response = query_fun(
    question=questions[3],
    conn=connection,
    debug=False
)

CPU times: user 4.38 s, sys: 1.43 s, total: 5.81 s
Wall time: 5.84 s


In [16]:
# SQL generated by sql_model; print() renders its line breaks instead of the string repr.
print(sql_command)

SELECT country
  FROM customer
  GROUP BY country
  ORDER BY count(*) desc
  limit 1;


In [17]:
# Natural-language answer produced by answer_model from the SQL execution result.
print(text_response)

The most common nationality among our customers is USA. This is because our customers are mostly from the United States.;
