In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so that edits to imported code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
# Project root; the notebook changes into this directory below, so relative
# paths used later (e.g. "./response_template.csv") are resolved against it.
# NOTE(review): both paths are absolute and machine-specific — adjust them
# (or derive them from an environment variable) when running elsewhere.
MAIN_PATH = "/home/ma-user/work/huawei-arena-2023"
# Path to the cache folder where models will be downloaded to
CACHE_DIR = '/home/ma-user/work/huawei-arena-2023/huawei2023cache'

# Make MAIN_PATH the working directory for the rest of the notebook.
%cd $MAIN_PATH

/home/ma-user/work/huawei-arena-2023


In [2]:
import torch
# NOTE(review): with IPython's default settings only the last expression of a
# cell is displayed, so this bare expression shows nothing here; wrap it in
# print() or move it to the end of a cell if the version should be recorded.
torch.__version__

import sqlite3
from typing import *
import infection
from infection.display import print_answer
from infection.databases import (
    SQL3Database, 
    format_sql_execution,
    format_df_to_table
)
from infection.models import (
    get_model, 
    get_model_response
)
from infection.prompt import (
    generate_prompt,
    SQL_QUERY_PROMPT_TEMPLATE, ANSWER_GENERATION_PROMPT_TEMPLATE, 
    CHART_GENERATION_PROMPT_TEMPLATE, SQL_SAFETY_PROMPT_TEMPLATE,
    LLAMA2_ERROR_PROMPT_TEMPLATE, 
    LLAMA2_CHART_GENERATION_PROMPT_TEMPLATE, LLAMA2_ANSWER_GENERATION_PROMPT_TEMPLATE,
    NSQL350_QUERY_PROMPT_TEMPLATE
)
from infection.trustworthiness import (
    plot_sql_chart, 
    suggest_plot,
    fix_sql_hallucination, 
    check_sql_hallucination
)

from IPython.display import display, Markdown
from infection.safety import InjectionDetector
import pandas as pd
import numpy as np

# CSV holding the canned answer phrasings. query_fun reads its 'template'
# column and replaces the '[data]' placeholder with the formatted SQL result.
# The path is relative, so it depends on the working directory set by `%cd`.
answer_templates = "./response_template.csv"
answer_template_df = pd.read_csv(answer_templates)



In [7]:
# Presumably copies the nsql-350M model files from the OBS bucket into a
# local folder (behaviour of `download` is defined in infection.modelarts).
# NOTE(review): the target is CACHE_DIR + '/test', whereas get_model() in the
# next cell is given cache_dir=CACHE_DIR — confirm the model is really loaded
# from the files fetched here.
from infection.modelarts.download import download
download("obs://infection/nsql350/models--NumbersStation--nsql-350M", CACHE_DIR+'/test')

100%|██████████| 21/21 [00:11<00:00,  1.77it/s]


In [8]:
# Model used by query_fun to turn a question into an SQL query. It is read
# there as a module-level global, so this cell must run before query_fun is
# called. The model is requested on CPU with CACHE_DIR as its cache folder.
sql_model = get_model('nsql350', cache_dir=CACHE_DIR, device='cpu')

### 2.2. Main functions for the challenge

In [5]:
def connect_fun(database_name: str) -> Optional[SQL3Database]:
    """
    Open an SQLite database through the project's ``SQL3Database`` wrapper.

    Parameters:
        database_name (str): The name (or path) of the SQLite database file to connect to.

    Returns:
        SQL3Database or None: The wrapper object if it could be constructed,
        or None if constructing it raised ``sqlite3.Error`` (the error is
        printed to stdout in that case). Other exception types propagate.

    Example usage:
        db_name = 'your_database_name.db'
        connection = connect_fun(db_name)

        if connection:
            print(f"Connected to {db_name}")
            # 'connection' can now be passed to query_fun.
        else:
            print("Connection failed.")
    """
    # NOTE(review): plain sqlite3.connect() creates an empty database when the
    # file does not exist instead of failing; whether SQL3Database guards
    # against that is not visible from here — confirm.
    try:
        return SQL3Database(database_name)
    except sqlite3.Error as error:
        print(f"Error connecting to the database: {error}")
        return None

In [6]:
def query_fun(question: str, tables_hints: List[str], conn) -> str:
    """
    Answer a natural-language question by generating and running an SQL query.

    Pipeline:
        0. Read the schemas for ``tables_hints`` from ``conn``.
        1. Ask the module-level ``sql_model`` for an SQL query.
        2. Check the query for names that do not exist in the schemas; abort
           if some cannot be resolved, otherwise apply the suggested mapping.
        3. Execute the query, keep at most 10 distinct rows and format them
           as a text table.
        4. Insert the table into a randomly chosen phrasing taken from the
           module-level ``answer_template_df``.

    Parameters:
        question (str): The user's question.
        tables_hints (List[str]): List of table names to consider in the query.
        conn: Database wrapper as returned by ``connect_fun``. It must provide
            ``get_schemas``, ``format_schemas`` and ``execute_sql``; a raw
            ``sqlite3.Connection`` does not have these methods.

    Returns:
        str: The templated answer, or a fixed apology sentence if any step
        fails. Failures never propagate to the caller; they are logged with
        their traceback instead.

    Example usage:
        question = "How many customers are there in the database?"
        table_hints = ["customers"]
        connection = connect_fun("your_database.db")
        answer = query_fun(question, table_hints, connection)
        print(answer)
    """
    # Imported locally so the function stays usable without touching the
    # notebook's import cell.
    import logging

    try:
        # Step 0: Get related tables based on all schemas and table hints
        schemas = conn.get_schemas(tables_hints)
        formatted_schemas = conn.format_schemas(tables_hints, add_examples=2)

        # Step 1: Generate an SQL query based on the question and table hints.
        sql_query = get_model_response(
            sql_model, NSQL350_QUERY_PROMPT_TEMPLATE,
            question=question,
            db_schema=formatted_schemas,
            tables_hints=tables_hints,
            catchphrase='```sql'
        )

        # Step 2: Check for hallucinated names and attempt to fix the query.
        mapping_dict, not_existing_query_names, sql_query = check_sql_hallucination(schemas, sql_query)
        if len(not_existing_query_names) > 0:
            raise Exception(f"These columns does not exists {not_existing_query_names}, rephrase your prompt")

        if len(mapping_dict) > 0:
            sql_query = fix_sql_hallucination(mapping_dict, sql_query)

        # Step 3: Execute the SQL query and turn the records into a dataframe.
        records, response_schema = conn.execute_sql(sql_query)
        response_df = format_sql_execution(records, response_schema, format='dataframe')

        if response_df is None:
            raise Exception(f"Could not build a dataframe from the result of: {sql_query}")

        # Cap the answer at 10 distinct rows to keep the reply short.
        response_df = response_df.drop_duplicates().head(10)
        sql_results = format_df_to_table(response_df, response_schema)

        # Step 4: Wrap the table in a canned phrasing. (LLM-based answer
        # generation is not used here; a random template is picked instead.)
        templates = answer_template_df['template'].values.tolist()
        random_template = np.random.choice(templates, 1).item()
        answer = random_template.replace('[data]', '\n' + sql_results)

        return str(answer)

    except Exception:
        # Deliberate best-effort: the caller always receives a string. The
        # failure is logged so that a wrong or failing query can be diagnosed.
        logging.getLogger(__name__).exception("query_fun failed for question %r", question)
        return "I’m sorry but I cannot fulfill your request."

In [7]:
# Open the example COVID-vaccinations database for the demo query below.
# connect_fun returns None if the database could not be opened.
connection = connect_fun(MAIN_PATH+'/data/example-data/example-covid-vaccinations.sqlite3')

In [10]:
%%time

# Smoke test: run one question through query_fun without table hints and
# print the templated answer.
# NOTE(review): in the recorded run the generated SQL takes MAX over column
# C03898V04649 and returns a UUID-like value, while the reference query in the
# test output uses max(VALUE) — the model picked the wrong column here.
answer = query_fun(
    question="What was the biggest vaccination rate achieved?",
    tables_hints=[],
    conn=connection,
)
print(answer)

SELECT MAX(C03898V04649) FROM covid_vaccinations;
Here is the information retrieved for you: 
|    | MAX(c03898v04649)                    |
|----|--------------------------------------|
|  0 | 2ae19629-3f77-13a3-e055-000000000001 |.
CPU times: user 42.7 s, sys: 403 ms, total: 43.1 s
Wall time: 5.39 s


In [11]:
%%time
# Hand both challenge entry points to the testing2 harness; the recorded
# output shows it printing an example query/result next to the model's
# answer for each case.
import testing2
testing2.run_all_tests(connect_fun=connect_fun, query_fun=query_fun)

Example query: select distinct STATISTIC_CODE from covid_vaccinations WHERE `Statistic_Label` = 'Fully Vaccinated'
Example result:
  STATISTIC_CODE
0       CDC45C01
SELECT STATISTIC_CODE FROM covid_vaccinations WHERE Statistic_Label = 'fully vaccinated';
Model result:
I’m sorry but I cannot fulfill your request.
------------------------------------------------------------
Example query: select count(distinct `Age Group`) from covid_vaccinations
Example result:
   count(distinct `Age Group`)
0                            2
SELECT COUNT(DISTINCT Age Group) FROM covid_vaccinations;
Model result:
Below is the data retrieved for you: 
|    |   COUNT(DISTINCT `age group`) |
|----|-------------------------------|
|  0 |                             2 |.
------------------------------------------------------------
Example query: select max(VALUE) from covid_vaccinations
Example result:
   max(VALUE)
0        99.4
SELECT MAX(C03898V04649) FROM covid_vaccinations;
Model result:
The data you need i