In [1]:
#!pip install -U accelerate bitsandbytes peft transformers datasets trl git-lfs wandb flash-attn sql-metadata scipy

In [1]:
import torch
import re
import pandas as pd
from transformers import AutoTokenizer, BitsAndBytesConfig
from transformers import AutoModelForCausalLM
from peft import PeftModel
from torch import cuda
from sql_metadata import Parser
from tqdm import tqdm
from modelscope import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm
2024-04-26 08:27:45,389 - modelscope - INFO - PyTorch version 2.2.2 Found.
2024-04-26 08:27:45,391 - modelscope - INFO - Loading ast index from /hpc2hdd/home/jzhao815/.cache/modelscope/ast_indexer
2024-04-26 08:27:46,341 - modelscope - INFO - Loading done! Current index file version is 1.13.3, with md5 2bbcf3979e8ce95c88f5a68d5ba83352 and a total number of 972 components indexed


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use a local model path or a hub id (Hugging Face / ModelScope).
model_name = "qwen/CodeQwen1.5-7B-Chat"
# model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Set to True when low on GPU memory: loads the base weights in 8-bit via bitsandbytes.
USE_QUANTIZATION = False
# NOTE: the bnb_4bit_* options (quant type, compute dtype, double quant) only take
# effect with load_in_4bit=True. Combined with load_in_8bit=True they were silently
# ignored, so they are omitted here. For QLoRA-style NF4 loading switch to
# load_in_4bit=True and add bnb_4bit_quant_type="nf4",
# bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

model_kwargs = {
    # flash_attention_2 needs an Ampere-or-newer GPU and the flash-attn package.
    "attn_implementation": "flash_attention_2",
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}
if USE_QUANTIZATION:
    model_kwargs["quantization_config"] = bnb_config
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

: 

In [4]:
# Attach the fine-tuned adapter from ./final_checkpoint_part1 and merge its
# weights into the base model.
model = PeftModel.from_pretrained(model, "./final_checkpoint_part1", torch_dtype=torch.bfloat16)
model = model.merge_and_unload()
# from_pretrained(device_map="auto") has already placed the weights. Calling .to()
# on an accelerate-dispatched (or bitsandbytes-quantized) model can fail or collapse
# a multi-device placement, so only move the model when it was not dispatched.
# Inputs are later sent to model.device, so either way they land on the right device.
if getattr(model, "hf_device_map", None) is None:
    model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pad with EOS (generation below also passes pad_token_id=eos_token_id).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

In [5]:
# Inspect the token id(s) this tokenizer produces for " ;" -- presumably to pick
# the stop id used by EosListStoppingCriteria.
# NOTE(review): the recorded output is [5106], but EosListStoppingCriteria
# defaults to [6203]; confirm which id is intended for this tokenizer.
tokenizer.encode(' ;')

[5106]

In [6]:
from transformers import StoppingCriteria
class EosListStoppingCriteria(StoppingCriteria):
    """Stop generation once a sequence ends with a given run of token ids.

    Args:
        eos_sequence: token ids that, when they are the final tokens of a row in
            ``input_ids``, end generation. Any iterable of ints is accepted;
            defaults to ``[6203]``.
            NOTE(review): ``tokenizer.encode(' ;')`` printed ``[5106]`` for the
            tokenizer used in this notebook -- confirm 6203 is the intended id.
    """

    def __init__(self, eos_sequence=None):
        # Copy into a fresh list: avoids a shared mutable default, and lets tuples
        # match, because the membership test in __call__ compares against lists.
        self.eos_sequence = [6203] if eos_sequence is None else list(eos_sequence)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if any row of ``input_ids`` currently ends with ``eos_sequence``."""
        if not self.eos_sequence:
            # An empty stop sequence never triggers (a -0 slice would take the whole row).
            return False
        # Trailing window of each row as a list of lists; `in` compares whole windows.
        last_ids = input_ids[:, -len(self.eos_sequence):].tolist()
        return self.eos_sequence in last_ids
    
def append_string_to_file(text, file_path):
    """Append ``text`` followed by a newline to ``file_path``, creating the file if needed."""
    with open(file_path, mode='a') as handle:
        handle.write(text + '\n')

def remove_spaces(text):
    """Collapse every run of whitespace (spaces, tabs, newlines) into one space.

    Leading/trailing whitespace is collapsed too, not stripped.
    """
    whitespace_run = re.compile(r'\s+')
    return whitespace_run.sub(' ', text)

def call_mistral(inputs, max_new_tokens=250):
    """Greedily generate a completion for a tokenized prompt and return only the new text.

    Despite the name, this uses whichever global ``model`` and ``tokenizer`` were
    loaded in earlier cells.

    Args:
        inputs: prompt token ids, e.g. the tensor returned by
            ``tokenizer.apply_chat_template(..., return_tensors="pt")``; only the
            first row is decoded.
        max_new_tokens: generation budget (defaults to the previously hard-coded 250).

    Returns:
        The decoded continuation, with the prompt tokens removed and special tokens skipped.
    """
    output_tokens = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding so evaluation runs are reproducible
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        # Also stop as soon as the configured stop token(s) are generated.
        stopping_criteria=[EosListStoppingCriteria()],
    )
    prompt_length = len(inputs[0])
    return tokenizer.decode(output_tokens[0][prompt_length:], skip_special_tokens=True)

In [None]:
df = pd.read_csv("./validation/spider_syn_dataset.csv")
results = []
for _, row in tqdm(df.iterrows(), total=len(df)):
    question = row['question']
    query = row['query']
    database_schema = row['database_schema']
    db_id = row['db_id']
    # Parse the reference tables first: rows whose gold SQL sql_metadata cannot
    # parse are skipped anyway, so don't spend a generation call on them.
    try:
        ref_tables = ", ".join(Parser(query).tables)
    except Exception:
        continue
    user_message = f"""Given the following SQL tables, your job is to determine the columns and tables that the question is referring to.
{database_schema}
###
Question: {question}
"""
    messages = [
        {"role": "user", "content": user_message.strip()}
    ]
    inputs = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True, tokenize=True
    ).to(model.device)
    response = call_mistral(inputs)
    # Keep the text before the first ";" and after the "Tables:" label. The label
    # is stripped even when generation ended without emitting a ";".
    response = response.split(";")[0]
    if "Tables:" in response:
        response = response.split("Tables:")[1]
    response = remove_spaces(response).strip()
    # tqdm.write keeps the progress bar intact while logging each prediction.
    tqdm.write(f"\n\n{response}\n{ref_tables}\n============================")
    results.append([response, ref_tables, query, question, db_id])

# Build the frame once after the loop (not on every iteration), so it also exists
# when no row succeeds. After an interrupted run, `results` still holds the
# partial predictions.
new_df = pd.DataFrame(results, columns=['predicted_tables', 'reference_tables', 'query', 'question', 'db_id'])

In [None]:
def _table_set(cell, strip_markup=False):
    """Split a ", "-joined table list into a set of lowercase, whitespace-trimmed names.

    Args:
        cell: value from a ``*_tables`` column; non-strings (e.g. NaN after a CSV
            round-trip) yield an empty set.
        strip_markup: also remove "--" and "**" markers that model output may contain.
    """
    if not isinstance(cell, str):
        return set()
    names = cell.split(", ")
    if strip_markup:
        names = [name.replace("--", "").replace("**", "") for name in names]
    return {name.lower().strip() for name in names}


total_samples = len(new_df)
exact_matches = 0
full_coverage_matches = 0
total_precision = 0.0
total_recall = 0.0

for _, row in new_df.iterrows():
    # Empty/missing predictions still count in total_samples, so they score 0
    # on every metric.
    if not row['predicted_tables'] or pd.isna(row['predicted_tables']):
        continue
    predicted_tables = _table_set(row['predicted_tables'], strip_markup=True)
    reference_tables = _table_set(row['reference_tables'])

    # Exact match: predicted set equals the reference set.
    if predicted_tables == reference_tables:
        exact_matches += 1
    # "Filtered" accuracy: every reference table was predicted (extra predictions
    # allowed). Compared on sets, so duplicate reference names cannot break it.
    if reference_tables and reference_tables <= predicted_tables:
        full_coverage_matches += 1

    # precision = TP / (TP + FP) = TP / |predicted|; recall = TP / (TP + FN) = TP / |reference|.
    # Computed fresh for every row, so a skipped branch can never reuse a previous row's value.
    true_positives = len(predicted_tables & reference_tables)
    total_precision += true_positives / len(predicted_tables) if predicted_tables else 0.0
    total_recall += true_positives / len(reference_tables) if reference_tables else 0.0

# Average over all rows; max(..., 1) avoids ZeroDivisionError on an empty frame.
denominator = max(total_samples, 1)
avg_precision = total_precision / denominator
avg_recall = total_recall / denominator
accuracy = exact_matches / denominator
filtered_accuracy = full_coverage_matches / denominator

print("Total Accuracy:", accuracy)
print("Filtered Accuracy:", filtered_accuracy)
print("Average Precision:", avg_precision)
print("Average Recall:", avg_recall)

new_df.to_csv("generated_test_schema_links.csv", index=False)

In [2]:
import pandas as pd
from tqdm import tqdm
from utils.database_formatter import get_table_schema_with_samples, get_all_table_names
from utils.sql_regularizator import format_and_lowercase_sql_query
from utils.prompts import (
    sql_generation_prompt_token_counter,
    schema_linking_prompt_token_counter,
)
from transformers import AutoTokenizer
from sql_metadata import Parser
import sqlite3
BASE_DATABASES_DIR = "./spider/test_database"
# Token budget for the schema-linking prompt; longer examples are not added to the dataset.
CONTEXT_WINDOW = 3000
# MODEL_NAME is not defined in any visible cell, so this line raised NameError on a
# fresh kernel. Keep an existing value if present; otherwise fall back to the
# model_name from the model-loading cell, or its hub id when this cell runs alone.
MODEL_NAME = globals().get("MODEL_NAME", globals().get("model_name", "qwen/CodeQwen1.5-7B-Chat"))
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


from difflib import get_close_matches

def get_closest_table_name(cursor, table_name):
    """Return the table in the connected SQLite database whose name best matches ``table_name``.

    Uses difflib fuzzy matching with a 0.6 similarity cutoff and returns None when
    no table name reaches the cutoff.
    """
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [name for (name,) in cursor.fetchall()]
    best = get_close_matches(table_name, table_names, n=1, cutoff=0.6)
    if not best:
        return None
    return best[0]

def create_sql_generation_correct_tables(dataset, question, query, db_uri, correct_tables):
    """Append one schema-linking example to ``dataset`` if its prompt fits CONTEXT_WINDOW.

    Args:
        dataset: list of dicts; it is appended to in place and also returned.
        question: natural-language question.
        query: SQL whose columns (via sql_metadata's Parser) are recorded.
        db_uri: path to the SQLite database file; its file stem is stored as ``db_id``.
        correct_tables: comma-separated table names (e.g. ``"singer,concert"``).
            Each name is fuzzily matched to a real table to build the filtered schema.

    Returns:
        ``dataset`` (the same list object).
    """
    correct_columns = Parser(query).columns
    # Split once and trim, so fuzzy matching and the stored field see clean names.
    table_names = [name.strip() for name in correct_tables.split(",") if name.strip()]

    connection = sqlite3.connect(db_uri)
    try:
        cursor = connection.cursor()
        filtered_parts = []
        # Reverse order is kept from the original implementation.
        for table in reversed(table_names):
            closest_table_name = get_closest_table_name(cursor, table)
            if closest_table_name:
                filtered_parts.append(get_table_schema_with_samples(db_uri, closest_table_name) + "\n")
            else:
                print(f"Warning: No close match found for table {table} in database.")
    finally:
        # Close even if a helper raises, so the handle is not leaked.
        connection.close()
    database_schema_filtered = "".join(filtered_parts)

    all_tables = get_all_table_names(db_uri)
    database_schema = "".join(
        get_table_schema_with_samples(db_uri, table) + "\n" for table in all_tables
    )

    prompt_tokens = schema_linking_prompt_token_counter(
        question, database_schema, correct_tables, correct_columns, tokenizer
    )
    if prompt_tokens <= CONTEXT_WINDOW:
        dataset.append(
            {
                "db_id": db_uri.split("/")[-1].split(".")[0],
                "question": question,
                "query": query,
                "filtered_database_schema": database_schema_filtered,
                "database_schema": database_schema,
                # Join the parsed names; joining the raw string produced "s, i, n, g, ...".
                "correct_tables": ", ".join(table_names),
                "correct_columns": ", ".join(correct_columns),
            }
        )
    return dataset
def load_spider_dev_set(path="generated_test_schema_links_test2000.csv"):
    """Load generated schema links and remove every space from the first column.

    The first column is presumably ``predicted_tables`` (comma-separated table
    names), e.g. ``"singer, concert"`` becomes ``"singer,concert"``.

    Args:
        path: CSV file to read; defaults to the previously hard-coded file name.

    Returns:
        The loaded DataFrame with spaces stripped from column 0; missing values stay missing.
    """
    df = pd.read_csv(path)
    first_column = df.iloc[:, 0]
    # An all-missing column is read as numeric and has nothing to strip (and no .str accessor).
    if not pd.api.types.is_numeric_dtype(first_column):
        # Vectorized replacement instead of a row-wise apply(axis=1).
        df.iloc[:, 0] = first_column.str.replace(" ", "", regex=False)
    return df

if __name__ == "__main__":
    # Build the SQL-generation validation set from the predicted schema links.
    df = load_spider_dev_set()
    filtered_finetuning_dataset = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        db_id = row["db_id"]
        question = row["question"]
        query = row["query"]
        # Predicted tables stand in for the gold tables when filtering the schema.
        correct_tabs = row["predicted_tables"]
        formatted_query = format_and_lowercase_sql_query(query)
        db_uri = f"{BASE_DATABASES_DIR}/{db_id}/{db_id}.sqlite"
        # Appends to filtered_finetuning_dataset in place.
        create_sql_generation_correct_tables(
            filtered_finetuning_dataset, question, formatted_query, db_uri, correct_tabs
        )
    # Build the frame from the accumulated list after the loop, so an input with
    # no rows yields an empty CSV instead of a NameError.
    filtered_validation_dataset = pd.DataFrame(filtered_finetuning_dataset)
    filtered_validation_dataset.to_csv('useful_val_dataset4.csv')


  0%|          | 0/2147 [00:00<?, ?it/s]
No chat template is defined for this tokenizer - using the default template for the GPT2TokenizerFast class. If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.

 22%|██▏       | 481/2147 [00:04<00:21, 77.59it/s] 



 31%|███▏      | 676/2147 [00:06<00:11, 129.80it/s]



100%|██████████| 2147/2147 [00:16<00:00, 131.34it/s]
