In [4]:
# Use %pip (not !pip) so the packages are installed into the environment of the running kernel.
%pip install torch transformers accelerate bitsandbytes scipy

Collecting scipy
  Obtaining dependency information for scipy from https://files.pythonhosted.org/packages/a3/d3/f88285098505c8e5d141678a24bb9620d902c683f11edc1eb9532b02624e/scipy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Downloading scipy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (59 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
Downloading scipy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.5 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.5/36.5 MB[0m [31m49.5 MB/s[0m eta [36m0:00:00[0m0m eta [36m0:00:01[0m[36m0:00:01[0m
[?25hInstalling collected packages: scipy
Successfully installed scipy-1.11.2


In [1]:
import torch

# Later cells load a 4-bit quantised model with device_map="auto" and run generation on the
# GPU, so this should display True before you continue.
cuda_available = torch.cuda.is_available()
cuda_available

True

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline

# Load the model

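Load the model in 4-bit precision (`load_in_4bit=True`), which brings it down to about 12 GB; otherwise the model is too big to fit in GPU memory.
If you work on an A100 machine, you can instead load it in bfloat16 by setting `torch_dtype=torch.bfloat16` (and dropping the 4-bit option).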

In [5]:
model_name = "defog/sqlcoder"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 4-bit quantisation (bitsandbytes) keeps the model small enough for a single GPU.
# On an A100, drop quantization_config and pass torch_dtype=torch.bfloat16 instead.
# BitsAndBytesConfig replaces the bare `load_in_4bit=True` keyword of from_pretrained,
# which newer transformers releases deprecate in favour of `quantization_config`.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    quantization_config=quantization_config,
    device_map="auto",
    use_cache=True,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

# Prompt Engineering Here!

In [6]:
# Keep the question on a single line: it is substituted verbatim between backticks in the
# prompt, so leading/trailing newlines and indentation would end up inside the model input.
question = (
    "What product has the biggest fall in sales in 2022 compared to 2021? "
    "Give me the product name, the sales amount in both years, and the difference."
)

In [7]:
# Prompt template for the sqlcoder model: instructions, the database schema as DDL with
# join hints, and a response section that ends with an opened ```sql fence so the model
# continues directly with SQL (generation is later stopped at the closing ``` fence).
# `{question}` appears twice and is filled in by str.format below.
# NOTE(review): this prompt asks for Postgres SQL, while the helper functions further down
# execute queries on SQLite — confirm the generated SQL is SQLite-compatible before running it.
prompt = """### Instructions:
Your task is to convert a question into a SQL query, given a Postgres database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float

### Input:
Generate a SQL query that answers the question `{question}`.
This query will run on a database whose schema is represented in this string:
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id

### Response:
Based on your instructions, here is the SQL query I have generated to answer the question `{question}`:
```sql
""".format(question=question)

In [8]:
# Token id of the closing code fence "```": generation stops there, right after the SQL.
eos_token_id = tokenizer.convert_tokens_to_ids("```")
# If "```" is not a single vocabulary token, convert_tokens_to_ids falls back to the unknown
# token id (or None), and generation would silently stop stopping at the end of the SQL block.
if eos_token_id is None or eos_token_id == tokenizer.unk_token_id:
    raise ValueError('"```" is not a single token for this tokenizer; choose another stop token')

# Generate an SQL query

In [11]:
# Move the inputs to wherever device_map="auto" placed the model, instead of hard-coding "cuda".
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
generated_ids = model.generate(
    **inputs,
    num_return_sequences=1,
    # Stop once the model closes the ```sql fence that the prompt leaves open.
    eos_token_id=eos_token_id,
    pad_token_id=eos_token_id,
    max_new_tokens=400,
    # Deterministic beam search with 5 beams rather than sampling.
    do_sample=False,
    num_beams=5,
)



In [12]:
# Turn each generated sequence (prompt + completion) back into text, dropping special tokens.
outputs = tokenizer.batch_decode(sequences=generated_ids, skip_special_tokens=True)

In [13]:
# Release cached-but-unused GPU memory held by PyTorch's caching allocator, then wait for all
# queued CUDA work to finish. Tensors still referenced by variables (e.g. `inputs`,
# `generated_ids`) are NOT freed by this; `del` them first if memory is tight.
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Intended to reduce out-of-memory crashes when generating more results in the same session —
# particularly important on Colab; memory management is much more straightforward
# when running on an inference service.

In [14]:
# Extract the SQL from the decoded text:
# - keep what follows the last "```sql" fence,
# - cut at the closing "```",
# - keep only the first statement, then re-append the terminating ";".
completion = outputs[0]
fenced_sql = completion.split("```sql")[-1].split("```")[0]
first_statement = fenced_sql.split(";")[0].strip()
print(first_statement + ";")
# NOTE(review): the recorded output of this cell orders by `difference DESC`, i.e. the largest
# *increase* first. The question asks for the biggest *fall*, which needs ascending order
# (and a LIMIT 1). Review the generated SQL before relying on it.

WITH sales_2021 AS (
  SELECT sales.product_id,
         sum(sales.quantity) AS sales_2021
  FROM   sales
  WHERE  sales.sale_date >= '2021-01-01'
     AND sales.sale_date <= '2021-12-31'
  GROUP BY sales.product_id
), sales_2022 AS (
  SELECT sales.product_id,
         sum(sales.quantity) AS sales_2022
  FROM   sales
  WHERE  sales.sale_date >= '2022-01-01'
     AND sales.sale_date <= '2022-12-31'
  GROUP BY sales.product_id
)
SELECT products.name,
       sales_2021.sales_2021,
       sales_2022.sales_2022,
       sales_2022.sales_2022 - sales_2021.sales_2021 AS difference
FROM   products
    LEFT JOIN sales_2021 ON products.product_id = sales_2021.product_id
    LEFT JOIN sales_2022 ON products.product_id = sales_2022.product_id
ORDER BY difference DESC NULLS LAST;


# Main functions required by the Huawei organisers

### Function to connect to database and return a connection object

In [None]:
import sqlite3
from pathlib import Path
from typing import Optional

def connect_fun(database_name: str, create_if_missing: bool = True) -> Optional[sqlite3.Connection]:
    """
    Connect to an SQLite database and return a connection object.

    Parameters:
        database_name (str): Path of the SQLite database file to connect to (or ":memory:").
        create_if_missing (bool): When True (the default, same as sqlite3.connect), a missing
            file is created as a new, empty database. When False, a missing file is reported
            as a connection error instead. This catches a mistyped file name immediately,
            rather than later as "no such table" errors.

    Returns:
        Optional[sqlite3.Connection]: A connection object if the connection is successful,
        or None if there is an error. The error message is printed.

    Example usage:
        db_name = 'your_database_name.db'
        connection = connect_fun(db_name, create_if_missing=False)

        if connection:
            print(f"Connected to {db_name}")
            # You can now use 'connection' to interact with the database.
        else:
            print("Connection failed.")
    """
    try:
        if create_if_missing or database_name == ":memory:":
            return sqlite3.connect(database_name)
        # mode=rw opens an existing database read/write but never creates the file.
        # as_uri() needs an absolute path and percent-encodes characters such as '?' and '#'.
        uri = Path(database_name).resolve().as_uri() + "?mode=rw"
        return sqlite3.connect(uri, uri=True)
    except sqlite3.Error as e:
        print(f"Error connecting to the database: {e}")
        return None

### Function to query a database and return an answer based on the context returned by the SQL query

Example: 

    Question: What are the months with the highest sales in the year?
    Answer:   The highest sales happened in November, December and January 
    Prompt flow: 
        Question (str) -> **LLM SQL query model** -> SQL query -> query results (column names and rows) -> **LLM to generate an answer with context based on the query results**
        
    

In [16]:
from contextlib import closing
from typing import Any, List, Optional, Sequence
import sqlite3

def query_fun(question: str, tables_hints: List[str], conn: sqlite3.Connection) -> str:
    """
    Generate an answer to a question based on an SQLite database and question context.

    Pipeline:
        question -> SQL (generate_sql_query) -> executed on `conn`
        -> column names + result rows -> answer text (generate_answer_with_context)

    Parameters:
        question (str): The user's question.
        tables_hints (List[str]): List of table names to consider in the query.
        conn (sqlite3.Connection): A connection to the SQLite database.

    Returns:
        str: The answer to the question. If any step fails, a short error message is
        returned instead and the underlying error is printed (nothing is raised).

    Example usage:
        question = "How many customers are there in the database?"
        table_hints = ["customers"]
        connection = sqlite3.connect("your_database.db")
        answer = query_fun(question, table_hints, connection)
        print(answer)
    """
    try:
        # Step 1: Generate an SQL query based on the question and table hints.
        sql_query = generate_sql_query(question, tables_hints)

        # Step 2: Execute the SQL query and fetch the results; the cursor is always closed.
        # NOTE(review): the SQL is model-generated and executed as-is. Pass a read-only
        # connection (e.g. a URI opened with "?mode=ro") if the database must not be modified.
        with closing(conn.cursor()) as cursor:
            cursor.execute(sql_query)
            # Step 3: Column names come from cursor.description.
            # It is None for statements that return no rows.
            schema = [desc[0] for desc in (cursor.description or ())]
            rows = cursor.fetchall()

        # Step 4: Turn the column names and the actual result rows into an answer.
        answer = generate_answer_with_context(question, schema, rows)

        return answer

    except sqlite3.Error as e:
        print(f"SQLite Error: {e}")
        return "An error occurred while processing the query."
    except Exception as e:
        print(f"Error: {e}")
        return "An error occurred."

def generate_sql_query(question: str, tables_hints: List[str]) -> str:
    """
    Translate `question` into a single SQLite statement.

    Placeholder: the LLM call (for example the sqlcoder prompt and generation shown earlier
    in this notebook) still has to be wired in for the hinted case.

    Parameters:
        question (str): The user's question.
        tables_hints (List[str]): Table names the query should use; may be empty.

    Returns:
        str: The SQL text to execute.

    Raises:
        NotImplementedError: When table hints are given, because LLM-based SQL generation
            is not implemented yet. Raising here gives a clear error instead of returning
            None, which would only fail later inside cursor.execute.
    """
    if not tables_hints:
        # No hints: fall back to listing the tables defined in the database.
        return "SELECT * FROM sqlite_master WHERE type='table'"

    # TODO: call the LLM here, e.g. "How many customers are there in the database?" with
    # hints ["customers"] should produce "SELECT COUNT(*) FROM customers".
    raise NotImplementedError(
        "generate_sql_query: LLM-based SQL generation is not implemented yet"
    )

def generate_answer_with_context(
    question: str,
    schema: List[str],
    rows: Optional[Sequence[Sequence[Any]]] = None,
) -> str:
    """
    Produce an answer to `question` from the executed query's result.

    Placeholder until an LLM is wired in, so that `query_fun` always returns a str. It
    returns a plain-text table: a header line of column names joined by " | ", then one
    line per row.

    Parameters:
        question (str): The user's question (not used by the placeholder).
        schema (List[str]): Column names of the result set.
        rows (Optional[Sequence[Sequence[Any]]]): Result rows, e.g. from cursor.fetchall().
            None is treated as "no rows".

    Returns:
        str: The answer text.
    """
    # TODO: replace with an LLM call that phrases the answer using the schema and rows,
    # e.g. "There are 100 customers in the database."
    if not rows:
        return "The query returned no rows."
    header = " | ".join(schema)
    body = "\n".join(" | ".join(str(value) for value in row) for row in rows)
    return f"{header}\n{body}"
