In [1]:
import torch
import re
import pandas as pd
from transformers import AutoTokenizer, BitsAndBytesConfig
from transformers import AutoModelForCausalLM
from peft import PeftModel
from torch import cuda
from sql_metadata import Parser
from tqdm import tqdm
from vllm import LLM, SamplingParams

In [2]:
model_name = "qwen/CodeQwen1.5-7B-Chat"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
from transformers import StoppingCriteria
class EosListStoppingCriteria(StoppingCriteria):
    def __init__(self, eos_sequence = [6203]):
        self.eos_sequence = eos_sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_ids = input_ids[:,-len(self.eos_sequence):].tolist()
        return self.eos_sequence in last_ids
    
def append_string_to_file(text, file_path):
  with open(file_path, 'a') as file:
      file.write(text + '\n')

def remove_spaces(text):
  return re.sub(r'\s+', ' ', text)



In [4]:
import pandas as pd
from tqdm import tqdm
import sqlite3
from sql_metadata import Parser
from transformers import AutoTokenizer
from difflib import get_close_matches
from utils.database_formatter import get_table_schema_with_samples, get_all_table_names
from utils.sql_regularizator import format_and_lowercase_sql_query
from utils.prompts import schema_linking_prompt_token_counter

BASE_DATABASES_DIR = "./spider/database"
def get_closest_table_name(cursor, table_name):
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    available_tables = [row[0] for row in cursor.fetchall()]
    closest_matches = get_close_matches(table_name, available_tables, n=1, cutoff=0.6)
    return closest_matches[0] if closest_matches else None

def create_sql_generation_correct_tables(dataset, question, query, db_uri, correct_tables):
    connection = sqlite3.connect(db_uri)
    cursor = connection.cursor()
    correct_columns = Parser(query).columns
    database_schema_filtered = ""
    all_tables = get_all_table_names(db_uri)
    for table in reversed(correct_tables.split(",")):
        closest_table_name = get_closest_table_name(cursor, table)
        if closest_table_name:
            database_schema_filtered += get_table_schema_with_samples(db_uri, closest_table_name) + "\n"
    database_schema = ""
    for table in all_tables:
        database_schema += get_table_schema_with_samples(db_uri, table) + "\n"
    if schema_linking_prompt_token_counter(question, database_schema, correct_tables, correct_columns, tokenizer) <= CONTEXT_WINDOW:
        dataset.append({
            "db_id": db_uri.split("/")[-1].split(".")[0],
            "question": question,
            "query": query,
            "filtered_database_schema": database_schema_filtered,
            "database_schema": database_schema,
            "correct_tables": ", ".join(correct_tables),
            "correct_columns": ", ".join(correct_columns)
        })
    connection.close()
    return dataset

In [5]:
import pandas as pd

def evaluate_model_performance(new_df):
    total_samples = len(new_df)
    total_accuracy = 0
    filtered_accuracy = 0
    total_precision = 0
    total_recall = 0

    for index, row in new_df.iterrows():
        if not row['predicted_tables'] or pd.isna(row['predicted_tables']):
            continue
        predicted_tables = row['predicted_tables'].split(", ")
        reference_tables = row['reference_tables'].split(", ")
        
        # Convert to lowercase and strip whitespace for comparison
        predicted_tables = [x.lower().replace("--","").replace("**","").strip() for x in predicted_tables]
        reference_tables = [x.lower().strip() for x in reference_tables]
        
        # Calculate accuracy
        if set(predicted_tables) == set(reference_tables):
            total_accuracy += 1
        
        # Calculate precision and recall
        true_positives = len(set(predicted_tables) & set(reference_tables))
        false_positives = len(set(predicted_tables) - set(reference_tables))
        false_negatives = len(set(reference_tables) - set(predicted_tables))

        precision = recall = 0  # Initialize to avoid undefined variable if no true_positives
        if true_positives:
            precision = true_positives / (true_positives + false_positives)
            recall = true_positives / (true_positives + false_negatives)
        
        total_precision += precision
        total_recall += recall

    # Calculate average precision and recall if there are valid samples
    avg_precision = total_precision / total_samples if total_samples > 0 else 0
    avg_recall = total_recall / total_samples if total_samples > 0 else 0

    # Calculate total and filtered accuracy
    accuracy = total_accuracy / total_samples if total_samples > 0 else 0
    filtered_accuracy = filtered_accuracy / total_samples if total_samples > 0 else 0

    print("Total Accuracy:", accuracy)
    print("Filtered Accuracy:", filtered_accuracy)
    print("Average Precision:", avg_precision)
    print("Average Recall:", avg_recall)


In [6]:

import pandas as pd
from tqdm import tqdm
import re
import sqlite3
from difflib import get_close_matches
from sql_metadata import Parser
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import time
# record time
start_time = time.time()




def batch_generate_with_vllm(prompts, llm, sampling_params):
    outputs = llm.generate(prompts, sampling_params=sampling_params)
    return [output.outputs[0].text.strip() for output in outputs]
# Database schema transformation and filtering functions
def get_closest_table_name(cursor, table_name):
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    available_tables = [row[0] for row in cursor.fetchall()]
    closest_matches = get_close_matches(table_name, available_tables, n=2, cutoff=0.6)
    return closest_matches[0] if closest_matches else None

def create_sql_generation_correct_tables(dataset, question, query, db_uri, correct_tables):
    connection = sqlite3.connect(db_uri)
    cursor = connection.cursor()
    
    correct_columns = Parser(query).columns
    database_schema_filtered = ""
    all_tables = get_all_table_names(db_uri)
    for table in reversed(correct_tables.split(",")):
        closest_table_name = get_closest_table_name(cursor, table)
        if closest_table_name:
            database_schema_filtered += get_table_schema_with_samples(db_uri, closest_table_name)
            database_schema_filtered += "\n"
        else:
            print(f"Warning: No close match found for table {table} in database.")
    database_schema = ""
    for table in all_tables:
        database_schema += get_table_schema_with_samples(db_uri, table)
        database_schema += "\n"
    
    dataset.append({
        "db_id": db_uri.split("/")[-1].split(".")[0],
        "question": question,
        "query": query,
        "filtered_database_schema": database_schema_filtered,
        "database_schema": database_schema,
        "correct_tables": ", ".join(correct_tables.split(",")),
        "correct_columns": ", ".join(correct_columns),
    })
    connection.close()
    return dataset


def main_stage1():
    df = pd.read_csv("./validation/spider_syn_dataset.csv")
    #Model path
    llm = LLM(model="./stage1", gpu_memory_utilization=0.5)
    sampling_params = SamplingParams(temperature=0.1, top_p=1.0, repetition_penalty=1.05, max_tokens=256)

    # Generate a list of all the inputs
    prompts_stage1 = [
        f""" Given the following SQL tables, your job is to determine the columns and tables that the question is referring to.
{row['database_schema']}
###
Question: {row['question']}
"""
        for index, row in df.iterrows()
    ]

    # Performing batch inference
    responses = batch_generate_with_vllm(prompts_stage1, llm, sampling_params)

    
    BASE_DATABASES_DIR = "./spider/test_database"
    filtered_dataset = []
    for response, row in zip(responses, df.itertuples()):
        response = response.split("Tables:")[1].split(";")[0] if "Tables:" in response else ""
        response = re.sub(r'\s+', ' ', response).strip()
        db_uri = f"{BASE_DATABASES_DIR}/{row.db_id}/{row.db_id}.sqlite"
        filtered_dataset = create_sql_generation_correct_tables(
            filtered_dataset, row.question, row.query, db_uri, response
        )

    # Save the processed data set
    pd.DataFrame(filtered_dataset).to_csv('useful_val_dataset3.csv', index=False)
    

    print("Stage 1 completed and data is saved.")

if __name__ == "__main__":
    main_stage1()


INFO 04-25 13:21:08 llm_engine.py:74] Initializing an LLM engine (v0.4.0.post1) with config: model='/hpc2hdd/home/jzhao815/6000E_code/stage1', tokenizer='/hpc2hdd/home/jzhao815/6000E_code/stage1', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=65536, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 04-25 13:21:09 selector.py:16] Using FlashAttention backend.
INFO 04-25 13:21:20 model_runner.py:104] Loading model weights took 13.5516 GB
INFO 04-25 13:21:24 gpu_executor.py:94] # GPU blocks: 18725, # CPU blocks: 4096
INFO 04-25 13:21:27 model_runner.py:791] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 04-25 13:21:27 model_ru

Processed prompts: 100%|██████████| 1034/1034 [02:09<00:00,  8.00it/s]


Stage 1 completed and data is saved.


In [2]:
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import time



def batch_generate_with_vllm(prompts, llm, sampling_params):
    outputs = llm.generate(prompts, sampling_params=sampling_params)
    return [output.outputs[0].text.strip() for output in outputs]

def main_stage2():
    df = pd.read_csv("./useful_val_dataset4.csv")
    #modelpath
    llm = LLM(model="./stage2", gpu_memory_utilization=0.4)
    sampling_params = SamplingParams(temperature=0.1, top_p=1.0, repetition_penalty=1.05, max_tokens=256)

    prompts_stage2 = [
        f"""Given the following SQL tables, your job is to generate the Sqlite SQL query given the user's question. Please pay special attention to the choice of table names and column names.
Put your answer inside the ```sql and ``` tags.
{row['filtered_database_schema']}
###
Question: {row['question']}
"""
        for index, row in df.iterrows()
    ]

    # Performing batch inference
    responses = batch_generate_with_vllm(prompts_stage2, llm, sampling_params)

    results = []
    for response, row in zip(responses, df.itertuples()):
        if ";" in response:
            response = response.split(";")[0]
        if "```sql" in response:
            response = response.split("```sql")[1]
        response = re.sub(r'\s+', ' ', response).strip()

        # print("\n")
        # print(response)
        # print(row.query)
        # print("============================")
        results.append([response, row.query, row.question, row.db_id])

    new_df = pd.DataFrame(results, columns=['generated_query', 'reference_query', 'question', 'db_id'])

    # The generated query is written to a SQL file for evaluation
    with open("Predicted.sql", "w") as f:
        for query in new_df["generated_query"]:
            f.write(query + ";\n")

    print("The generated query has been written to the Predicted.sql file.")

if __name__ == "__main__":
    main_stage2()
    
#     end_time = time.time()

#     total_time = end_time - start_time
#     print(f"run time：{total_time}秒")


INFO 04-26 08:28:53 llm_engine.py:74] Initializing an LLM engine (v0.4.0.post1) with config: model='/hpc2hdd/home/jzhao815/6000E_code/stage2', tokenizer='/hpc2hdd/home/jzhao815/6000E_code/stage2', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=65536, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 04-26 08:28:54 selector.py:16] Using FlashAttention backend.
INFO 04-26 08:29:03 model_runner.py:104] Loading model weights took 13.5516 GB
INFO 04-26 08:29:08 gpu_executor.py:94] # GPU blocks: 10617, # CPU blocks: 4096
INFO 04-26 08:29:11 model_runner.py:791] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 04-26 08:29:11 model_ru

Processed prompts: 100%|██████████| 2147/2147 [03:32<00:00, 10.10it/s]

The generated query has been written to the Predicted.sql file.





Evaluation 

In [3]:
import subprocess

script_path = 'eval/evaluation.py'
sql_file_path = 'Predicted.sql'


command = ['python', script_path, '--input', sql_file_path]

# run eval
result = subprocess.run(command, capture_output=True, text=True)

print("STDOUT:", result.stdout)
print("STDERR:", result.stderr)

STDOUT: params as fllows 
 Namespace(input='Predicted.sql', gold='eval/eval_data/gold.txt', db='spider/database', table='eval/eval_data/tables.json', etype='all', plug_value=False, keep_distinct=False, progress_bar_for_each_datapoint=False, natsql=False, pred='Predicted.sql')
gseq_one length  1034
pseq_one length 2147
compare pred idx 0
easy pred: SELECT COUNT(*) FROM club;
easy gold: SELECT count(*) FROM singer

compare pred idx 1
easy pred: SELECT COUNT(*) FROM club;
easy gold: SELECT count(*) FROM singer

compare pred idx 2
medium pred: SELECT name FROM club ORDER BY name ASC;
medium gold: SELECT name ,  country ,  age FROM singer ORDER BY age DESC

compare pred idx 3
medium pred: SELECT name FROM club ORDER BY name ASC;
medium gold: SELECT name ,  country ,  age FROM singer ORDER BY age DESC

compare pred idx 4
medium pred: SELECT manager , captain FROM club;
medium gold: SELECT avg(age) ,  min(age) ,  max(age) FROM singer WHERE country  =  'France'

compare pred idx 5
medium pred: