# Text2SQL: Evaluating the Fine-tuned V2 Model with Execution Accuracy

In [1]:
!pip install torch transformers bitsandbytes accelerate sqlparse

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.2-py3-none-manylinux_2_24_x86_64.whl.metadata (5.8 kB)
Downloading bitsandbytes-0.45.2-py3-none-manylinux_2_24_x86_64.whl (69.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 MB[0m [31m24.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.45.2


In [2]:
import json
import os
import re
import sqlite3
from collections import Counter
from contextlib import closing
from pathlib import Path

import sqlparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

In [3]:
# Sanity check: later cells query device 0 and run generation on the GPU.
cuda_available = torch.cuda.is_available()
cuda_available

True

In [4]:
# NOTE: this is the device's *total* memory in bytes, not what is currently free.
gpu_properties = torch.cuda.get_device_properties(0)
available_memory = gpu_properties.total_memory
available_memory

15828320256

# Load our model and its tokenizer from Hugging Face

In [5]:
model_name = "khalifa1/Sql_LLama2_V3"
# GPU memory (bytes) above which the checkpoint is loaded in float16;
# below it the 4-bit quantised fallback is used.
FP16_MIN_GPU_MEMORY = 15e9

tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): trust_remote_code=True executes Python shipped in the model
# repository; keep it only if this checkpoint really needs custom modelling code.
if available_memory > FP16_MIN_GPU_MEMORY:
    # at least ~15 GB of GPU memory: load the weights in float16
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        use_cache=True,
    )
else:
    # otherwise quantise to 4 bits – this is slower and less accurate.
    # Passing load_in_4bit=True straight to from_pretrained is deprecated in
    # recent transformers releases; a BitsAndBytesConfig is the supported route.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        device_map="auto",
        use_cache=True,
    )

tokenizer_config.json:   0%|          | 0.00/1.93k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.62M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/579 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/720 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/3.59G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

## Load the Spider dataset and its databases

In [19]:
# Directory holding one sub-folder per Spider database (<db_id>/<db_id>.sqlite).
spider_db_path = "/kaggle/input/yale-universitys-spider-10-nlp-dataset/spider/database"
# NOTE(review): train_others.json appears to hold training examples from the
# non-Spider databases (the first record uses db_id "geo"), not the dev split.
spider_eval_path = "/kaggle/input/yale-universitys-spider-10-nlp-dataset/spider/train_others.json"

with open(spider_eval_path) as eval_file:
    data = json.load(eval_file)

data[0]

{'db_id': 'geo',
 'query': 'SELECT city_name FROM city WHERE population  =  ( SELECT MAX ( population ) FROM city WHERE state_name  =  "wyoming" ) AND state_name  =  "wyoming";',
 'query_toks': ['SELECT',
  'city_name',
  'FROM',
  'city',
  'WHERE',
  'population',
  '=',
  '(',
  'SELECT',
  'MAX',
  '(',
  'population',
  ')',
  'FROM',
  'city',
  'WHERE',
  'state_name',
  '=',
  '``',
  'wyoming',
  "''",
  ')',
  'AND',
  'state_name',
  '=',
  '``',
  'wyoming',
  "''",
  ';'],
 'query_toks_no_value': ['select',
  'city_name',
  'from',
  'city',
  'where',
  'population',
  '=',
  '(',
  'select',
  'max',
  '(',
  'population',
  ')',
  'from',
  'city',
  'where',
  'state_name',
  '=',
  'value',
  ')',
  'and',
  'state_name',
  '=',
  'value'],
 'question': 'what is the biggest city in wyoming',
 'question_toks': ['what', 'is', 'the', 'biggest', 'city', 'in', 'wyoming'],
 'sql': {'except': None,
  'from': {'conds': [], 'table_units': [['table_unit', 1]]},
  'groupBy': [],

## Database connection setup

In [20]:
def get_conn(db_id: str, db_root=None):
    """Open a read-only connection to the Spider database ``db_id``.

    The database file is expected at ``<db_root>/<db_id>/<db_id>.sqlite``.

    Args:
        db_id: Spider database identifier, e.g. ``"geo"``.
        db_root: Directory containing the per-database folders. Defaults to
            the notebook-level ``spider_db_path``.

    Returns:
        An open ``sqlite3.Connection``; the caller is responsible for closing it.

    Raises:
        sqlite3.OperationalError: if the database file cannot be opened
            (for example, it does not exist).
    """
    if db_root is None:
        db_root = spider_db_path
    db_path = os.path.join(db_root, db_id, f"{db_id}.sqlite")
    # A plain sqlite3.connect() silently creates an empty database when the
    # file is missing, and would let model-generated SQL (DELETE/DROP/...)
    # modify the data every later comparison runs against. mode=ro prevents both.
    uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
    return sqlite3.connect(uri, uri=True)

def get_db_schema(conn) -> str:
    """Concatenate every stored CREATE statement from ``sqlite_master``.

    Each non-NULL ``sql`` entry is emitted followed by a blank line; rows whose
    ``sql`` is NULL (e.g. SQLite's automatic indexes) are skipped.
    """
    master_rows = conn.execute("SELECT * FROM sqlite_master").fetchall()
    # `sql` is the last column of sqlite_master.
    return "".join(
        f"{row[-1]}\n\n" for row in master_rows if row[-1] is not None
    )

## Inference: generate a SQL query from a question and the database schema

In [29]:
def inference(question: str, schema: str, max_new_tokens: int = 400) -> str:
    """Ask the model for a SQLite query answering ``question`` over ``schema``.

    Uses the notebook-level ``tokenizer`` and ``model`` with ``do_sample=False``.

    Args:
        question: Natural-language question.
        schema: CREATE statements describing the database.
        max_new_tokens: Generation budget for the completion.

    Returns:
        The first SQL statement the model produced, passed through
        ``postgres_to_sqlite``. Empty string if nothing usable was generated.
    """
    prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
given a sqlite database schema.
### Instructions:
Adhere to these rules:
- **Deliberately analyze the question and schema**
- **Use Table Aliases** to prevent ambiguity
- **Cast numerators as float for ratios**
- **Use only one column per query**

### Database Schema
The query will run on a database with the following schema:
{schema}

### Response:The following SQL query best answers the question `{question}`:
```sql
""".format(question=question, schema=schema)

    # Send inputs to wherever the (device_map-placed) model lives instead of
    # hard-coding "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
    # Decode only the newly generated tokens; the prompt is not part of the answer.
    prompt_len = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(generated_ids[0, prompt_len:], skip_special_tokens=True)
    return postgres_to_sqlite(_extract_sql(completion))


def _extract_sql(completion: str) -> str:
    """Return the first SQL statement from a completion of the ```sql block."""
    # The prompt already opened a ```sql fence, so the query ends at the first
    # closing fence; explanations or further blocks after it are dropped.
    # (rstrip("```") only removed trailing backtick characters, leaving such text in.)
    sql = completion.split("```", 1)[0]
    # sqlite3 executes one statement at a time; keep only the first.
    statements = [stmt for stmt in sqlparse.split(sql) if stmt.strip()]
    return statements[0].strip() if statements else ""

def postgres_to_sqlite(query: str) -> str:
    """Rewrite a few PostgreSQL-only constructs into SQLite-compatible text.

    Matching is case-insensitive. Only whole words are rewritten, so identifiers
    that merely contain a keyword (e.g. a column named ``philike``) are left alone.

    Args:
        query: SQL text, possibly using PostgreSQL syntax.

    Returns:
        The rewritten SQL text.
    """
    substitutions = [
        (r'\bilike\b', 'LIKE'),
        # NOTE(review): the two rules below target DDL-style text (column types,
        # sequence options) rather than SELECT queries.
        (r'\bserial\s*$', 'INTEGER PRIMARY KEY AUTOINCREMENT'),
        (r'\bstart\s+with\s+(\d+)', 'CHECK (id >= \\1)'),
    ]
    for pattern, replacement in substitutions:
        query = re.sub(pattern, replacement, query, flags=re.IGNORECASE)
    return query

## Evaluation script

In [30]:
def compare_sql(gold: str, gen: str, conn, is_ordered=False):
    """Score a generated query against the gold query by executing both.

    Both queries run on ``conn`` and their result rows are compared as tuples
    (column names are ignored). The score is the number of matching rows divided
    by the size of the *larger* result, so both missing and extra rows cost points.

    Args:
        gold: Reference SQL query.
        gen: Model-generated SQL query.
        conn: Open ``sqlite3.Connection`` to the example's database.
        is_ordered: If True, row ``i`` must equal row ``i`` of the other result;
            otherwise the results are compared as multisets (duplicates count).

    Returns:
        A score in ``[0, 1]``:
        - 1 if both results are empty.
        - 0 if the generated query fails to execute.
        - 1 if the gold query itself fails. Such examples are unscorable; they
          are kept as 1 so callers that sum scores keep working.
    """
    # Rows are fetched through the DB-API directly. The previous pd.read_sql
    # path raised NameError (pandas was never imported), which the gold branch
    # swallowed, so every example silently scored 1.
    try:
        gold_rows = conn.execute(gold).fetchall()
    except Exception as err:
        # NOTE(review): crediting unscorable examples inflates the final accuracy.
        print("[Gold Fail]", gold, err)
        return 1

    try:
        gen_rows = conn.execute(gen).fetchall()
    except Exception as _:
        print("[Gen Fail]", gen)
        return 0

    denominator = max(len(gold_rows), len(gen_rows))
    if denominator == 0:
        return 1  # both queries returned no rows

    if is_ordered:
        matches = sum(gold_row == gen_row for gold_row, gen_row in zip(gold_rows, gen_rows))
    else:
        # Whole-row multiset intersection: a row only matches an identical row,
        # and each generated row can satisfy at most one gold row.
        matches = sum((Counter(gold_rows) & Counter(gen_rows)).values())
    return matches / denominator
 


## Main evaluation loop

In [33]:
NUM_EVAL = 300
eval_examples = data[:NUM_EVAL]
total_accuracy = 0
schema_cache = {}  # db_id -> schema text; many examples share one database

for i, example in enumerate(eval_examples):
    db_id = example['db_id']
    question = example['question']
    gold = example['query']

    # closing() releases each per-example connection even if a cell step raises.
    with closing(get_conn(db_id)) as conn:
        if db_id not in schema_cache:
            schema_cache[db_id] = get_db_schema(conn)
        gen = inference(question, schema_cache[db_id])
        sub_accuracy = compare_sql(gold, gen, conn)
    total_accuracy += sub_accuracy

    print(f"[{i}]")
    print(f"Question: {question}")
    # sqlparse's option is `reindent`; the misspelt `reident` was silently ignored.
    print(f"Gold SQL: {sqlparse.format(gold, reindent=True)}")
    print(f"Generated SQL: {sqlparse.format(gen, reindent=True)}\n")

# Divide by the number of examples actually evaluated (fewer than NUM_EVAL
# when the split is smaller).
final_accuracy = total_accuracy / len(eval_examples) if eval_examples else 0.0
print(f"Final Execution Accuracy: {final_accuracy * 100:.2f}%")

[0]
Question: what is the adjacent state of kentucky
Gold SQL: SELECT border FROM border_info WHERE state_name  =  "kentucky";
Generated SQL: 
 SELECT s.state_name FROM border_info bi JOIN state s ON bi.border = s.state_name WHERE bi.state_name = 'Kentucky';

[1]
Question: name all the rivers in illinois
Gold SQL: SELECT river_name FROM river WHERE traverse  =  "illinois";
Generated SQL: 
 SELECT r.river_name FROM river r WHERE r.traverse = 'Illinois';

[2]
Question: rivers in illinois
Gold SQL: SELECT river_name FROM river WHERE traverse  =  "illinois";
Generated SQL: 
 SELECT r.river_name FROM river r WHERE r.traverse = 'Illinois';

[3]
Question: what are all the rivers in illinois
Gold SQL: SELECT river_name FROM river WHERE traverse  =  "illinois";
Generated SQL: 
 SELECT r.river_name FROM river r WHERE r.traverse = 'Illinois';

[4]
Question: what are the rivers in illinois
Gold SQL: SELECT river_name FROM river WHERE traverse  =  "illinois";
Generated SQL: 
 SELECT r.river_name FR