In [1]:
import pandas as pd
# question variants
question = pd.read_csv('test_data.csv')

# original questions
demonstration = question.iloc[list(range(4,500,5))][['question','true_query']].reset_index(drop=True)

In [7]:
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# Compute embedding for both lists
embeddings1 = model.encode(question['question'].tolist(), convert_to_tensor=True)
embeddings2 = model.encode(demonstration['question'].tolist(), convert_to_tensor=True)

# Compute cosine-similarities
cosine_scores = util.cos_sim(embeddings1, embeddings2)

In [14]:
import numpy as np
few_n = 3
top_n_idx = np.argsort(cosine_scores.tolist(), axis=1)[:,::-1][:,1:1+few_n]   # skip the most relevant one

# Prompt

In [16]:
prompt = """table catastici , columns = [ catastici.Owner_First_Name ( text ) , catastici.Owner_Family_Name ( text ) , catastici.Property_Type ( text ) , catastici.Rent_Income ( integer ) , catastici.Property_Location ( text )]
Owner_First_Name -- First name of the owner of the property ; Owner_Family_Name -- Family name of the owner of the property ; Property_Type -- Specific type of the property given in Italian. For example, "casa", "bottega da barbier", "bottega da fruttariol". ; Rent_Income -- Rent price of the property that the owner receives as income, given in Venice ancient gold coin ducato. ; Property_Location -- Ancient spproximate toponym of the property given in Italian.
{few_shot}
{question}
"""

In [17]:
in_prompt = []
for idx, val in enumerate(top_n_idx):
    few_shot = '\n'.join([demonstration.iloc[i]['question']+'\n'+demonstration.iloc[i]['true_query'].replace('\n','') for i in val])
    q = question.iloc[idx]['question']
    in_prompt.append(prompt.format(few_shot=few_shot, question=q))

# Inference

In [21]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = 'seeklhy/codes-7b'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, 
    device_map = "auto", 
    torch_dtype = torch.float16
)

# update eos token id of the tokenizer and the model to support early stop SQL generation
token_ids_of_example_sql = tokenizer("SELECT * FROM tables ;")["input_ids"]
print(token_ids_of_example_sql)
if token_ids_of_example_sql[-1] == tokenizer.eos_token_id:
    new_eos_token_id = token_ids_of_example_sql[-2]
else:
    new_eos_token_id = token_ids_of_example_sql[-1]
model.config.eos_token_id = new_eos_token_id
tokenizer.eos_token_id = new_eos_token_id
print("new_eos_token_id:", new_eos_token_id)
print("tokenizer.decode(new_eos_token_id): '{}'".format(tokenizer.decode(new_eos_token_id)))

Downloading shards: 100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [03:05<00:00, 61.84s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████| 3/3 [00:10<00:00,  3.49s/it]


[4620, 319, 3753, 10343, 2082]
new_eos_token_id: 2082
tokenizer.decode(new_eos_token_id): ' ;'


In [43]:
import sqlparse

def prepare_input_ids_and_attention_mask(tokenizer, input_seq, max_input_length, device):
    input_ids = tokenizer(input_seq , truncation = False)["input_ids"]

    if len(input_ids) <= max_input_length:
        input_ids = input_ids
        attention_mask = [1] * len(input_ids)
    else:
        if tokenizer.name_or_path == "THUDM/codegeex2-6b":
            input_ids = [64790, 64792] + input_ids[-(max_input_length-2):]
        else:
            input_ids = [tokenizer.bos_token_id] + input_ids[-(max_input_length-1):]

        attention_mask = [1] * max_input_length
    
    # print("len(input_ids):", len(input_ids))
 
    return {
        "input_ids": torch.tensor([input_ids]).to(device), # torch.int64
        "attention_mask": torch.tensor([attention_mask]).to(device) # torch.int64
    }

def text2sql_func(model, text2sql_input_seq, tokenizer, eos_token_id, max_tokens=8192, max_new_tokens=256):
    inputs = prepare_input_ids_and_attention_mask(
        tokenizer, 
        text2sql_input_seq, 
        max_tokens - max_new_tokens,
        model.device
    )

    input_length = inputs["input_ids"].shape[1]

    # check_tokenizer(tokenizer, inputs["input_ids"])

    with torch.no_grad():
        generate_ids = model.generate(
            **inputs,
            max_new_tokens = max_new_tokens,
            num_beams = 4,
            num_return_sequences = 4,
            use_cache = True,
            eos_token_id = eos_token_id
        )

    generated_sqls = tokenizer.batch_decode(generate_ids[:, input_length:], skip_special_tokens = True, clean_up_tokenization_spaces = False)
    generated_sqls = [generated_sqls[i].split('\n')[0] for i in range(4)]

    return generated_sqls

In [36]:
print(in_prompt[0])

table catastici , columns = [ catastici.Owner_First_Name ( text ) , catastici.Owner_Family_Name ( text ) , catastici.Property_Type ( text ) , catastici.Rent_Income ( integer ) , catastici.Property_Location ( text )]
Owner_First_Name -- First name of the owner of the property ; Owner_Family_Name -- Family name of the owner of the property ; Property_Type -- Specific type of the property given in Italian. For example, "casa", "bottega da barbier", "bottega da fruttariol". ; Rent_Income -- Rent price of the property that the owner receives as income, given in Venice ancient gold coin ducato. ; Property_Location -- Ancient spproximate toponym of the property given in Italian.
What percentage of properties are located in "fondamenta de carmini"?
SELECT COUNT(*) AS total_properties,        (COUNT(*) * 100.0 / (SELECT COUNT(*) FROM catastici)) AS percentageFROM catasticiWHERE Property_Location = 'fondamenta de carmini'
Which property types are present in "calle di santa cattarina principia al

In [50]:
out = text2sql_func(model, in_prompt[0], tokenizer, new_eos_token_id)

Setting `pad_token_id` to `eos_token_id`:2082 for open-end generation.


In [47]:
from langchain.chains import create_sql_query_chain
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///catastici.db")

# test DB
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM catastici LIMIT 1;")

sqlite
['catastici']


"[('liberal', 'campi', 'casa e bottega da barbier', 70, 'campo vicino alla chiesa')]"

In [57]:
def check_sql_executability(query, db):
    try:
        return db.run(query)
    except:
        return "ERROR"

In [None]:
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

final_query = []
answers = []

for input_seq in tqdm(in_prompt):
    output = text2sql_func(model, input_seq, tokenizer, new_eos_token_id)
    final_out = None
    for out in output:
        answer = check_sql_executability(out, db)
        if answer != "ERROR":
            final_out = out
            break
    if final_out == None:
        final_out = '\n'.join(output)
        answer = "ERROR"
    final_query.append(final_out)
    answers.append(answer)

  0%|                                                                                                              | 0/500 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:2082 for open-end generation.
  0%|▏                                                                                                   | 1/500 [00:10<1:27:17, 10.50s/it]Setting `pad_token_id` to `eos_token_id`:2082 for open-end generation.
  0%|▍                                                                                                   | 2/500 [00:20<1:26:48, 10.46s/it]Setting `pad_token_id` to `eos_token_id`:2082 for open-end generation.
  1%|▌                                                                                                   | 3/500 [00:31<1:26:35, 10.45s/it]Setting `pad_token_id` to `eos_token_id`:2082 for open-end generation.
  1%|▊                                                                                                   | 4/500 [00:41<1:25:40, 10.36s/it]Setting `pad_token_id

In [None]:
question['generated_query'] = final_query
question['generated_answer'] = answers

In [None]:
question.to_csv('test_data_generated.csv',index=False)