In [None]:
!pip install -r requirements.txt

In [None]:
!pip install transformers==4.31.0

In [15]:
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Fri Sep 22 05:03:46 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   42C    P8    17W / 350W |  13329MiB / 24268MiB |      0%      Default |
|                               |   

In [1]:
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import numpy as np
from schema_item_filter import SchemaItemClassifierInference, filter_schema
import torch
from tqdm import tqdm
from transformers.trainer_utils import set_seed

import sys

from utils.db_utils import check_sql_executability, get_db_schema_sequence, get_matched_content_sequence, detect_special_char
from typing import List

# Main functions

In [7]:
max_tokens = 2048
max_new_tokens = 256
model_name = "seeklhy/codes-3b"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    # torch_dtype=torch.bfloat16,
    # load_in_8bit=True,
    # load_in_4bit=True,
    device_map="auto",
    use_cache=True,
)
model.eval()


Downloading (…)okenizer_config.json:   0%|          | 0.00/717 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/777k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/442k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.06M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/564 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/32.7k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/10.0G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00002.bin:   0%|          | 0.00/2.18G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

GPTBigCodeForCausalLM(
  (transformer): GPTBigCodeModel(
    (wte): Embedding(49152, 2816)
    (wpe): Embedding(8192, 2816)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPTBigCodeBlock(
        (ln_1): LayerNorm((2816,), eps=1e-05, elementwise_affine=True)
        (attn): GPTBigCodeAttention(
          (c_attn): Linear(in_features=2816, out_features=3072, bias=True)
          (c_proj): Linear(in_features=2816, out_features=2816, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((2816,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTBigCodeMLP(
          (c_fc): Linear(in_features=2816, out_features=11264, bias=True)
          (c_proj): Linear(in_features=11264, out_features=2816, bias=True)
          (act): PytorchGELUTanh()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((2816,), eps=

### Prepare SQL query prompt format

In [8]:
SQL_QUERY_PROMPT_TEMPLATE = """
### Instructions:
Your task is to convert a question into a SQL query, given a SQLlite database schema
Adhere to these rules:
- ****
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float.
### Input:
Generate a SQL query that answers the question `{question}`.
This query will run on a database whose schema is represented in this string:
{db_schema}

### Response:
Based on your instructions, here is the SQL query I have generated to answer the question `{question}`:
```sql
"""

In [9]:
def generate_sql_query_generation_prompt(question, db_schema):
    return SQL_QUERY_PROMPT_TEMPLATE.format(question=question, db_schema=db_schema)

### SQL Generation Function

In [16]:
def generate_sql_query(question: str, db_schema: str, tables_hints: List[str], num_beams=5) -> str:
    # Implement logic to generate an SQL query based on the question and table hints.
    # Replace the "pass" with a calling function to LLM
    
    # Handle the case when table hints are empty or invalid.
    if not tables_hints:
        # Default behavior: Query all tables
        pass
    
    # Handle the general case
    # Example: "SELECT COUNT(*) FROM customers"
    prompt = generate_sql_query_generation_prompt(question, db_schema)
    eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda") # or to("cpu")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=eos_token_id,
        pad_token_id=eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=num_beams
    )
    
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    sql_query = outputs[0].split("```sql")[-1].split("```")[0].split(";")[0].strip() + ";"
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    
    return sql_query


### Test the function

In [17]:
sales_db_schema = """
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id
"""


In [18]:
question = "What product has the biggest fall in sales in 2022 compared to 2021? \
            Give me the product name, the sales amount in both years, and the difference."

In [19]:
%%time
# num_beam = 5
print(generate_sql_query(question=question, db_schema=sales_db_schema, tables_hints=None, num_beams=5))

SELECT
  products.name AS product_name,
  sales_2021.sales_amount AS sales_amount_2021,
  sales_2022.sales_amount AS sales_amount_2022,
  sales_2022.sales_amount - sales_2021.sales_amount AS difference
FROM products
LEFT JOIN (
  SELECT
    product_id,
    SUM(quantity * price) AS sales_amount
  FROM sales
  WHERE sale_date BETWEEN '2021-01-01' AND '2021-12-31'
  GROUP BY product_id
) AS sales_2021 ON products.product_id = sales_2021.product_id
LEFT JOIN (
  SELECT
    product_id,
    SUM(quantity * price) AS sales_amount
  FROM sales
  WHERE sale_date BETWEEN '2022-01-01' AND '2022-12-31'
  GROUP BY product_id
) AS sales_2022 ON products.product_id = sales_2022.product_id
ORDER BY difference DESC
LIMIT 1;
CPU times: user 7.9 s, sys: 0 ns, total: 7.9 s
Wall time: 7.89 s


# Helper function

In [2]:
def post_process(sql, schema_items):
    sql = sql.replace("\n", " ")
    for table in schema_items:
        for column_name in table["column_names"]:
            special_char_in_column_name = detect_special_char(column_name)
            if special_char_in_column_name and column_name in sql and "`"+column_name+"`" not in sql:
                sql = sql.replace(column_name, "`"+column_name+"`")
    sql = sql.replace(" order ", " `order` ")
    return sql


In [3]:
# extract the skeleton of the input text
def extract_skeleton(text):
    tokens_and_tags = nltk.pos_tag(nltk.word_tokenize(text))

    output_tokens = []
    for token, tag in tokens_and_tags:
        if tag in ['NN', 'NNP', 'NNS', 'NNPS', 'CD', 'SYM', 'FW', 'IN']:
            output_tokens.append("_")
        elif token in ['$', "''", '(', ')', ',', '--', '.', ':']:
            pass
        else:
            output_tokens.append(token)
    
    text_skeleton = " ".join(output_tokens)
    text_skeleton = text_skeleton.replace("_ 's", "_")
    text_skeleton = text_skeleton.replace(" 's", "'s")

    while("_ _" in text_skeleton):
        text_skeleton = text_skeleton.replace("_ _", "_")
    while("_ , _" in text_skeleton):
        text_skeleton = text_skeleton.replace("_ , _", "_")
    
    if text_skeleton.startswith("_ "):
        text_skeleton = text_skeleton[2:]
    
    return text_skeleton

In [4]:
def prepare_input_ids_and_attention_mask(tokenizer, input_seq, max_input_length, device):
    input_ids = tokenizer(input_seq , truncation = False)["input_ids"]

    if len(input_ids) <= max_input_length:
        input_ids = input_ids
        attention_mask = [1] * len(input_ids)
    else:
        if tokenizer.name_or_path == "THUDM/codegeex2-6b":
            input_ids = [64790, 64792] + input_ids[-(max_input_length-2):]
        else:
            input_ids = [tokenizer.bos_token_id] + input_ids[-(max_input_length-1):]

        attention_mask = [1] * max_input_length
    
    print("len(input_ids):", len(input_ids))
 
    return {
        "input_ids": torch.tensor([input_ids]).to(device), # torch.int64
        "attention_mask": torch.tensor([attention_mask]).to(device) # torch.int64
    }

In [5]:
def prepare_cross_domain_input_seq(opt, eval_data, demonstration_set, similarity):
    top_k_indices = sorted(range(len(similarity)), key = lambda x: similarity[x], reverse = True)[:opt.num_of_demonstrations]
    # top_k_indices = list(reversed(top_k_indices))
    # top_k_indices = random.sample(range(len(similarity)), opt.num_of_demonstrations)
    print(top_k_indices)
    print(similarity[top_k_indices])

    input_seq = ""
    for idx in top_k_indices:
        demonstration_sql = demonstration_set[idx]["sql"]
        if demonstration_sql.endswith(";"):
            demonstration_sql = demonstration_sql[:-1].strip() + " ;"
        else:
            demonstration_sql = demonstration_sql.strip() + " ;"

        input_seq += demonstration_set[idx]["schema_sequence"] + "\n" + demonstration_set[idx]["content_sequence"] + "\n" + \
            demonstration_set[idx]["text"] + "\n" + demonstration_sql + "\n\n"

    input_seq += eval_data["schema_sequence"] + "\n" + eval_data["content_sequence"] + "\n" + eval_data["text"] + "\n"
    # print(input_seq)
    # print("-"*30)

    return input_seq


In [6]:
def text2sql_func(model, text2sql_input_seq, tokenizer, max_tokens, max_new_tokens):
    inputs = prepare_input_ids_and_attention_mask(
        tokenizer, 
        text2sql_input_seq, 
        max_tokens - max_new_tokens,
        model.device
    )

    input_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        generate_ids = model.generate(
            **inputs,
            max_new_tokens = max_new_tokens,
            num_beams = 4,
            num_return_sequences = 4,
            use_cache = True
        )

    generated_sqls = tokenizer.batch_decode(generate_ids[:, input_length:], skip_special_tokens = True, clean_up_tokenization_spaces = False)

    return generated_sqls
