In [1]:
pip install transformers accelerate bitsandbytes sentencepiece pandas datasets huggingface_hub tqdm

Note: you may need to restart the kernel to use updated packages.


In [3]:
# --- Standard Library Imports ---
# --- Third-party Library Imports ---
# --- Third-party Library Imports ---
import torch

from tqdm.auto import tqdm
import time
from huggingface_hub import login
import transformers # <--- ADD THIS LINE
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# --- Third-party Library Imports ---
import torch
from tqdm.auto import tqdm # For progress bars
from huggingface_hub import login # For Hugging Face Hub authentication
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

print("--- Cell 1: Imports and Initial Configuration Complete ---")
print(f"PyTorch Version: {torch.__version__}")
print(f"Transformers Version: {transformers.__version__}")

--- Cell 1: Imports and Initial Configuration Complete ---
PyTorch Version: 2.2.0
Transformers Version: 4.52.4


In [4]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version PyTorch compiled with: {torch.version.cuda}")
    print(f"Number of GPUs available to PyTorch: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    print("ERROR: PyTorch cannot see the GPUs! Check installation and CUDA compatibility.")

PyTorch version: 2.2.0
CUDA available: True
CUDA version PyTorch compiled with: 11.8
Number of GPUs available to PyTorch: 8
  GPU 0: NVIDIA A100-SXM4-80GB
  GPU 1: NVIDIA A100-SXM4-80GB
  GPU 2: NVIDIA A100-SXM4-80GB
  GPU 3: NVIDIA A100-SXM4-80GB
  GPU 4: NVIDIA A100-SXM4-80GB
  GPU 5: NVIDIA A100-SXM4-80GB
  GPU 6: NVIDIA A100-SXM4-80GB
  GPU 7: NVIDIA A100-SXM4-80GB


In [5]:
# --- Standard Library Imports ---
# --- Third-party Library Imports ---
# --- Third-party Library Imports ---
import torch
from tqdm.auto import tqdm
import time
from huggingface_hub import login
import transformers # <--- ADD THIS LINE
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# --- Third-party Library Imports ---
import torch
from tqdm.auto import tqdm # For progress bars
from huggingface_hub import login # For Hugging Face Hub authentication
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

print("--- Cell 1: Imports and Initial Configuration Complete ---")
print(f"PyTorch Version: {torch.__version__}")
print(f"Transformers Version: {transformers.__version__}")

--- Cell 1: Imports and Initial Configuration Complete ---
PyTorch Version: 2.2.0
Transformers Version: 4.52.4


In [6]:
# --- Model and Tokenizer Configuration ---
import os

# 3.1. Specify the Llama 2 70B Chat Model
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
print(f"Target Model: {MODEL_NAME}")

# 3.2. Configure 4-bit Quantization (essential for 70B, even on A100s for single/few GPU use)
# A100s support bfloat16, which is excellent for mixed-precision.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",        # nf4 is a good default
    bnb_4bit_compute_dtype=torch.bfloat16, # Use bfloat16 for computation on A100s
    bnb_4bit_use_double_quant=True,   # Can save a bit more memory
)
print(f"BitsAndBytesConfig: load_in_4bit={bnb_config.load_in_4bit}, compute_dtype={bnb_config.bnb_4bit_compute_dtype}")

# 3.4. Define Cache Directory for Hugging Face downloads (optional, but good for managing large models)
# Create it within your project directory on the A100 server.
HF_MODEL_CACHE_DIR = os.path.join(os.getcwd(), ".hf_model_cache_70b") # Assumes current dir is project root
os.makedirs(HF_MODEL_CACHE_DIR, exist_ok=True)
print(f"Hugging Face model cache directory set to: {HF_MODEL_CACHE_DIR}")

print("\n--- Cell 3: Model and Prompt Configuration Complete ---")

Target Model: meta-llama/Llama-3.1-8B-Instruct
BitsAndBytesConfig: load_in_4bit=True, compute_dtype=torch.bfloat16
Hugging Face model cache directory set to: /raid/infolab/gaurav/Llama_Spider_A100_Project/experiments_70b_llama/.hf_model_cache_70b

--- Cell 3: Model and Prompt Configuration Complete ---


In [7]:
# --- Load the Tokenizer ---
# The tokenizer converts text into numerical IDs that the model understands, and vice-versa.
# It's crucial that the tokenizer matches the model it was trained with.
print(f"Loading tokenizer for {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    # token=HF_TOKEN # For recent versions of transformers, login() handles global auth.
                     # You might need this for older versions or specific configurations.
    trust_remote_code=True # Some models require this if they have custom code. Llama 2 generally doesn't, but good to be aware of.
)

# Llama models often don't have a pad token defined by default.
# We set it to the EOS (End Of Sentence) token if it's not present.
# This is important for batching inputs of different lengths, though for our P(Yes)
# extraction (one prompt at a time), it's less critical but good practice.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print(f"Tokenizer pad_token was None, set to eos_token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")

print("Tokenizer loaded successfully.")
print(f"Tokenizer pad token ID: {tokenizer.pad_token_id}")
print(f"Tokenizer EOS token ID: {tokenizer.eos_token_id}")
print(f"Tokenizer BOS token ID: {tokenizer.bos_token_id}")

Loading tokenizer for meta-llama/Llama-3.1-8B-Instruct...
Tokenizer pad_token was None, set to eos_token: <|eot_id|> (ID: 128009)
Tokenizer loaded successfully.
Tokenizer pad token ID: 128009
Tokenizer EOS token ID: 128009
Tokenizer BOS token ID: 128000


In [8]:
import gc
import time
from transformers import AutoModelForCausalLM

print(f"Loading model: {MODEL_NAME} with 4-bit quantization on GPU 7... This will take significant time and memory...")
model_load_start_time = time.time()

try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,       # Apply 4-bit quantization
        torch_dtype=torch.bfloat16,           # Use bfloat16 on A100s
        device_map={"": 7},                   # 🔧 Manually assign everything to GPU 1
        trust_remote_code=True,               # Required for some models
        cache_dir=HF_MODEL_CACHE_DIR
    )
    model_load_end_time = time.time()
    print("\nModel loaded successfully on GPU 1!")
    print(f"Time taken: {model_load_end_time - model_load_start_time:.2f} seconds.")
    print(f"Model device map: {model.hf_device_map}")  # Should show everything on device 1

    # Optional: Clean up memory
    torch.cuda.empty_cache()
    gc.collect()
    print("Performed memory cleanup (torch.cuda.empty_cache(), gc.collect())")

except Exception as e:
    import traceback
    traceback.print_exc()
    raise RuntimeError(f"Failed to load model {MODEL_NAME} on GPU 7: {e}. Check VRAM, CUDA setup, and Hugging Face authentication.")

print("\n--- Cell 5: Llama 3.1 8B Instruct Model Loading Complete ---")

print("Model max_position_embeddings:", model.config.max_position_embeddings)
print("Tokenizer model_max_length:", tokenizer.model_max_length)


Loading model: meta-llama/Llama-3.1-8B-Instruct with 4-bit quantization on GPU 7... This will take significant time and memory...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]


Model loaded successfully on GPU 1!
Time taken: 13.74 seconds.
Model device map: {'': 7}
Performed memory cleanup (torch.cuda.empty_cache(), gc.collect())

--- Cell 5: Llama 3.1 8B Instruct Model Loading Complete ---
Model max_position_embeddings: 131072
Tokenizer model_max_length: 131072


In [9]:
import zipfile
import os

SERVER_ZIP_FILE_PATH = '/raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data.zip'
EXTRACTION_DESTINATION_DIR_ON_SERVER = '/raid/infolab/gaurav/Llama_Spider_A100_Project/'

# Initialize all paths to None
DEV_JSON_PATH = None
TABLES_JSON_PATH = None
TRAIN_SPIDER_JSON_PATH = None
TRAIN_OTHERS_JSON_PATH = None

def unzip_data(zip_filepath, dest_dir):
    """Unzips a zip file to a specified destination directory."""
    print(f"Attempting to unzip {zip_filepath} to {dest_dir}...")
    try:
        with zipfile.ZipFile(zip_filepath, 'r') as zip_ref:
            zip_ref.extractall(dest_dir)
        print(f"Successfully unzipped files to {dest_dir}")
        return True
    except Exception as e:
        print(f"An unexpected error occurred during unzipping: {e}")
        return False

print(f"Script started. Looking for zip file at: {SERVER_ZIP_FILE_PATH}")

if os.path.exists(SERVER_ZIP_FILE_PATH):
    print(f"Zip file found at {SERVER_ZIP_FILE_PATH}.")
    if unzip_data(SERVER_ZIP_FILE_PATH, EXTRACTION_DESTINATION_DIR_ON_SERVER):
        
        EXPECTED_EXTRACTED_FOLDER_NAME = 'spider_subset_data'
        data_folder_path = os.path.join(EXTRACTION_DESTINATION_DIR_ON_SERVER, EXPECTED_EXTRACTED_FOLDER_NAME)

        # Define paths for all four critical files
        DEV_JSON_PATH = os.path.join(data_folder_path, 'dev.json')
        TABLES_JSON_PATH = os.path.join(data_folder_path, 'tables.json')
        TRAIN_SPIDER_JSON_PATH = os.path.join(data_folder_path, 'train_spider.json')
        TRAIN_OTHERS_JSON_PATH = os.path.join(data_folder_path, 'train_others.json')

        # Verify all paths
        print("\nVerifying extracted file paths...")
        all_paths_valid = True
        for name, path in [("dev.json", DEV_JSON_PATH), 
                           ("tables.json", TABLES_JSON_PATH),
                           ("train_spider.json", TRAIN_SPIDER_JSON_PATH),
                           ("train_others.json", TRAIN_OTHERS_JSON_PATH)]:
            if os.path.exists(path):
                print(f"SUCCESS: {name} path is valid: {path}")
            else:
                print(f"ERROR: {name} NOT FOUND at expected path: {path}")
                all_paths_valid = False

        if all_paths_valid:
            print("\n--- All required data paths are set up. Ready to load data. ---")
        else:
            print("\n--- ERROR: One or more data paths are invalid. Cannot proceed. ---")
    else:
        print("Unzipping failed on the server. Cannot define data paths.")
else:
    print(f"ERROR: Zip file NOT FOUND at {SERVER_ZIP_FILE_PATH} on the server.")

Script started. Looking for zip file at: /raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data.zip
Zip file found at /raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data.zip.
Attempting to unzip /raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data.zip to /raid/infolab/gaurav/Llama_Spider_A100_Project/...
Successfully unzipped files to /raid/infolab/gaurav/Llama_Spider_A100_Project/

Verifying extracted file paths...
SUCCESS: dev.json path is valid: /raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data/dev.json
SUCCESS: tables.json path is valid: /raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data/tables.json
SUCCESS: train_spider.json path is valid: /raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data/train_spider.json
SUCCESS: train_others.json path is valid: /raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data/train_others.json

--- All required data paths are set up. Ready to load data. ---


In [10]:
import json

def load_json_data(file_path):
    if os.path.exists(file_path):
        with open(file_path, 'r') as f:
            return json.load(f)
    else:
        print(f"ERROR: File not found at {file_path}")
        return None

dev_data = load_json_data(DEV_JSON_PATH)
tables_data = load_json_data(TABLES_JSON_PATH)
train_spider_data = load_json_data(TRAIN_SPIDER_JSON_PATH)
train_others_data = load_json_data(TRAIN_OTHERS_JSON_PATH)

if dev_data and tables_data and train_spider_data and train_others_data:
    print(f"Loaded {len(dev_data)} queries from dev.json")
    print(f"Loaded {len(tables_data)} database schemas from tables.json")
    print(f"Loaded {len(train_spider_data)} queries from train_spider_data.json")
    print(f"Loaded {len(train_others_data)} queries from train_others_data.json")
else:
    print("Failed to load Spider data. Please check paths and upload.")

Loaded 1034 queries from dev.json
Loaded 166 database schemas from tables.json
Loaded 7000 queries from train_spider_data.json
Loaded 1659 queries from train_others_data.json


In [32]:
import re
import os

TEXT_QUERIES_FILE = "/raid/infolab/gaurav/Llama_Spider_A100_Project/experiments_70b_llama/all_dev_nl_queries.txt"

if not os.path.exists(TEXT_QUERIES_FILE):
    raise FileNotFoundError(f"Cannot find '{TEXT_QUERIES_FILE}' – make sure it’s in your working directory or update the path.")

selected_nl_queries = []

# --- CORRECTED REGEX PATTERN ---
# Using a raw string (r"...") with single backslashes for special sequences like \s and \d.
# The parenthesis are also escaped with a single backslash.
pattern = re.compile(r"Test Query\s+(\d+):\s+'(.+)'\s+\(True DB:\s*([^)]+)\)")

with open(TEXT_QUERIES_FILE, "r") as f_in:
    for line in f_in:
        line = line.strip()
        # We don't need the startswith check, the regex will handle it.
        m = pattern.match(line)
        if not m:
            # This warning will now only trigger for genuinely malformed lines.
            print(f"Warning: could not parse line:\\n  {line}")
            continue

        # Groups are now: 1: number, 2: question, 3: db_id
        question_text = m.group(2)
        true_db_id    = m.group(3)

        selected_nl_queries.append({
            "question": question_text,
            "db_id":    true_db_id
        })

if len(selected_nl_queries) == 0:
    raise ValueError(f"No queries were parsed from '{TEXT_QUERIES_FILE}'. Check your file’s format and the regex pattern.")

print(f"Loaded {len(selected_nl_queries)} queries from '{TEXT_QUERIES_FILE}':")
for i, q in enumerate(selected_nl_queries[:5], 1): # Print first 5 as a sample
    print(f"  Query {i}: '{q['question']}' (True DB: {q['db_id']})")


# --- Create a map from DB ID to a list of its real questions ---
# This will be used for dynamic few-shot example selection.
db_id_to_all_real_questions_map = {}
for query_info in selected_nl_queries:
    db_id = query_info['db_id']
    question = query_info['question']
    if db_id not in db_id_to_all_real_questions_map:
        db_id_to_all_real_questions_map[db_id] = []
    db_id_to_all_real_questions_map[db_id].append(question)

print(f"\nCreated a mapping for {len(db_id_to_all_real_questions_map)} DB IDs to their corresponding real questions.")
print(f"Example: DB 'dog_kennels' now has {len(db_id_to_all_real_questions_map.get('academic', []))} associated real questions.")

Loaded 1034 queries from '/raid/infolab/gaurav/Llama_Spider_A100_Project/experiments_70b_llama/all_dev_nl_queries.txt':
  Query 1: 'How many 'United Airlines' flights go to Airport 'ASY'?' (True DB: flight_2)
  Query 2: 'What are the name of the countries where there is not a single car maker?' (True DB: car_1)
  Query 3: 'What are the date and the operating professional's first name of each treatment?' (True DB: dog_kennels)
  Query 4: 'List each owner's first name, last name, and the size of his for her dog.' (True DB: dog_kennels)
  Query 5: 'Find the first name and age of students who have a dog but do not have a cat as a pet.' (True DB: pets_1)

Created a mapping for 20 DB IDs to their corresponding real questions.
Example: DB 'dog_kennels' now has 0 associated real questions.


In [33]:
import json
import os

# --- Helper function to load JSON safely ---
def load_json_data(file_path):
    if not file_path or not os.path.exists(file_path):
        print(f"ERROR: File not found at {file_path}. Cannot load.")
        return None
    with open(file_path, 'r') as f:
        return json.load(f)

# --- Load all data sources ---
print("\n--- Loading all Spider data sources for example pool ---")
dev_data = load_json_data(DEV_JSON_PATH)
train_spider_data = load_json_data(TRAIN_SPIDER_JSON_PATH)
train_others_data = load_json_data(TRAIN_OTHERS_JSON_PATH)
tables_data = load_json_data(TABLES_JSON_PATH)

# Proceed only if all data was loaded successfully
if all([dev_data, train_spider_data, train_others_data, tables_data]):
    
    # --- Combine all question sources into one large pool ---
    # This is the key step: creating a rich source for our few-shot examples.
    all_spider_queries_data = dev_data + train_spider_data + train_others_data
    
    print(f"\nLoaded {len(dev_data)} queries from dev.json")
    print(f"Loaded {len(train_spider_data)} queries from train_spider.json")
    print(f"Loaded {len(train_others_data)} queries from train_others.json")
    print(f"-> Total queries in example pool: {len(all_spider_queries_data)}")
    print(f"Loaded {len(tables_data)} database schemas from tables.json")

    # --- Create the comprehensive map from DB ID to all its real questions ---
    db_id_to_all_real_questions_map = {}
    for query_info in all_spider_queries_data:
        db_id = query_info['db_id']
        question = query_info['question']
        if db_id not in db_id_to_all_real_questions_map:
            db_id_to_all_real_questions_map[db_id] = []
        db_id_to_all_real_questions_map[db_id].append(question)

    print(f"\nSuccessfully created a mapping for {len(db_id_to_all_real_questions_map)} DB IDs to their corresponding real questions.")
    print(f"This is a significant increase from using only the dev set.")
    print(f"Example: DB 'academic' now has {len(db_id_to_all_real_questions_map.get('academic', []))} associated real questions.")
    print(f"Example: DB 'sakila_1' (from train_others) now has {len(db_id_to_all_real_questions_map.get('sakila_1', []))} associated real questions.")

else:
    print("\n--- ERROR: Failed to load one or more Spider data files. Please check paths and file integrity. ---")


--- Loading all Spider data sources for example pool ---

Loaded 1034 queries from dev.json
Loaded 7000 queries from train_spider.json
Loaded 1659 queries from train_others.json
-> Total queries in example pool: 9693
Loaded 166 database schemas from tables.json

Successfully created a mapping for 166 DB IDs to their corresponding real questions.
This is a significant increase from using only the dev set.
Example: DB 'academic' now has 181 associated real questions.
Example: DB 'sakila_1' (from train_others) now has 82 associated real questions.


In [34]:
print("\n--- Listing All Questions Associated with Each Database ---")

sorted_db_ids = sorted(db_id_to_all_real_questions_map.keys())
print(len(sorted_db_ids))

for db_id in sorted_db_ids:
    questions = db_id_to_all_real_questions_map[db_id]
    print(f"\nDatabase: {db_id} ({len(questions)} questions)")
    for i, question in enumerate(questions, 1):
        print(f"  {i}. {question}")


--- Listing All Questions Associated with Each Database ---
166

Database: academic (181 questions)
  1. return me the homepage of PVLDB .
  2. return me the homepage of " H. V. Jagadish " .
  3. return me the abstract of " Making database systems usable " .
  4. return me the year of " Making database systems usable "
  5. return me the year of " Making database systems usable " .
  6. return me the papers after 2000 .
  7. return me the homepage of the VLDB conference .
  8. return me all the keywords .
  9. return me all the organizations .
  10. return me all the organizations in " North America " .
  11. return me the homepage of " University of Michigan " .
  12. return me the number of references of " Making database systems usable " .
  13. return me the references of " Making database systems usable " .
  14. return me the number of citations of " Making database systems usable " .
  15. return me the citations of " Making database systems usable " .
  16. return me the paper

In [35]:
import json
import os

# Define data directory and file paths
SPIDER_DATA_DIR = '/raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data'
spider_tables_json_path = os.path.join(SPIDER_DATA_DIR, 'tables.json')
# llm_examples_path = os.path.join(SPIDER_DATA_DIR, 'llm_generated_schema_examples.json')

# Load the LLM-generated examples map
# with open(llm_examples_path, 'r') as f:
#     db_id_to_questions_map = json.load(f)

# Load the base schema structures
with open(spider_tables_json_path, 'r') as f:
    raw_schemas = json.load(f)
# Map by database ID for easy lookup
all_db_schemas_data_loaded = {db_info['db_id']: db_info for db_info in raw_schemas}

# Initialize the container for enriched SQL strings
all_db_schemas_sql_strings = {}


In [36]:
# --- Integration of CODES Prompt Construction (Now for BASE Schemas) ---

from tqdm.auto import tqdm # Ensure tqdm is imported for the progress bar
import json
import os
import sqlite3 # <-- Import the sqlite3 library

print("--- Building Base Schema Prompts (Paper-Exact Column Format, NO examples) ---")

# --- Helper Functions for Prompt Construction (Unchanged) ---

def map_spider_type_to_sql_type(spider_type, is_pk_or_fk=False):
    spider_type = spider_type.lower()
    if spider_type == "text": return "text"
    if spider_type == "number": return "integer" if is_pk_or_fk else "real"
    if spider_type == "time": return "datetime"
    if spider_type == "boolean": return "boolean"
    return "text"

def get_representative_values(cursor, table_name, column_name):
    try:
        query = f'SELECT DISTINCT "{column_name}" FROM "{table_name}" WHERE "{column_name}" IS NOT NULL LIMIT 2'
        cursor.execute(query)
        rows = cursor.fetchall()
        values = [str(row[0]) for row in rows]
        return ", ".join(values) if values else "N/A"
    except sqlite3.OperationalError:
        return "N/A"

def schema_filter_placeholder(db_schema):
    return db_schema['table_names_original']

def value_retriever_placeholder(nl_query, db_id):
    return {}

# --- MODIFIED: This function now only builds the schema structure ---
def construct_base_schema_prompt(db_id, all_schemas_data, db_dir):
    """
    Constructs a database prompt string with tables, columns, and foreign keys.
    It does NOT include few-shot examples.
    """
    if db_id not in all_schemas_data:
        return f"-- Database ID '{db_id}' not found."

    db_path = os.path.join(db_dir, db_id, f"{db_id}.sqlite")
    if not os.path.exists(db_path):
        return f"-- Database file not found at: {db_path}"

    db_schema = all_schemas_data[db_id]
    prompt_parts = []
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        column_info_map = {
            i: {"name": c_name, "table_index": t_idx, "type": db_schema['column_types'][i]}
            for i, (t_idx, c_name) in enumerate(db_schema['column_names_original']) if c_name != "*"
        }
        
        relevant_tables = schema_filter_placeholder(db_schema)

        for table_idx, table_name in enumerate(db_schema['table_names_original']):
            if table_name not in relevant_tables:
                continue

            column_defs = []
            for col_idx, col_info in column_info_map.items():
                if col_info['table_index'] == table_idx:
                    prefixed_col_name = f"{table_name}.{col_info['name']}"
                    col_parts_inside_parentheses = []
                    is_pk_or_fk = col_idx in db_schema['primary_keys'] or any(fk[0] == col_idx for fk in db_schema['foreign_keys'])
                    col_parts_inside_parentheses.append(map_spider_type_to_sql_type(col_info['type'], is_pk_or_fk))
                    if col_idx in db_schema['primary_keys']:
                        col_parts_inside_parentheses.append("primary key")
                    rep_values = get_representative_values(cursor, table_name, col_info['name'])
                    col_parts_inside_parentheses.append(f"values: {rep_values}")
                    final_column_string = f"{prefixed_col_name} ( { ' | '.join(col_parts_inside_parentheses)} )"
                    column_defs.append(final_column_string)
            
            prompt_parts.append(f"table {table_name}, columns = [ {', '.join(column_defs)} ]")

        if db_schema['foreign_keys']:
            prompt_parts.append("foreign keys:")
            table_info_map = {i: name for i, name in enumerate(db_schema['table_names_original'])}
            for fk_col_idx, ref_col_idx in db_schema['foreign_keys']:
                fk_table_name = table_info_map[column_info_map[fk_col_idx]['table_index']]
                fk_col_name = column_info_map[fk_col_idx]['name']
                ref_table_name = table_info_map[column_info_map[ref_col_idx]['table_index']]
                ref_col_name = column_info_map[ref_col_idx]['name']
                prompt_parts.append(f"{fk_table_name}.{fk_col_name} = {ref_table_name}.{ref_col_name}")

    except Exception as e:
        print(f"ERROR processing db '{db_id}': {e}")
        return f"-- Error generating prompt for db {db_id}."
    finally:
        if conn:
            conn.close()

    return "\n".join(prompt_parts)

# --- Generate the new BASE prompts for all databases ---
SPIDER_DATA_DIR = '/raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data'
DATABASE_DIR = os.path.join(SPIDER_DATA_DIR, 'database')

all_db_schemas_base_prompts = {}
if 'all_db_schemas_data_loaded' in globals() and all_db_schemas_data_loaded:
    print(f"Found prerequisites. Generating base prompts using databases from: {DATABASE_DIR}")
    for db_id in tqdm(all_db_schemas_data_loaded.keys(), desc="Generating Base Schema Prompts"):
        all_db_schemas_base_prompts[db_id] = construct_base_schema_prompt(
            db_id, all_db_schemas_data_loaded, DATABASE_DIR
        )
    print(f"\nSuccessfully generated {len(all_db_schemas_base_prompts)} base schema prompts.")
    
    # # --- Verification Step ---
    # print("\n--- Verification of a Base Schema Prompt ---")
    # db_to_verify = 'perpetrator'
    # if db_to_verify in all_db_schemas_base_prompts:
    #     print(f"Generated base prompt for '{db_to_verify}':")
    #     print(all_db_schemas_base_prompts[db_to_verify])
    # else:
    #     print(f"Could not find schema for '{db_to_verify}' to verify.")

else:
    print("ERROR: Prerequisite data ('all_db_schemas_data_loaded') not found. Please run the previous cells.")

--- Building Base Schema Prompts (Paper-Exact Column Format, NO examples) ---
Found prerequisites. Generating base prompts using databases from: /raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data/database


Generating Base Schema Prompts:   0%|          | 0/166 [00:00<?, ?it/s]


Successfully generated 166 base schema prompts.


In [37]:
import os
import json 
LOCAL_EXPERIMENT_BASE_DIR = "/raid/infolab/gaurav/Llama_Spider_A100_Project/"

EXPERIMENT_RUN_NAME = "randomQ_allDBs_run1" 
EXPERIMENT_PROJECT_DIR = os.path.join(LOCAL_EXPERIMENT_BASE_DIR, EXPERIMENT_RUN_NAME)

try:
    os.makedirs(EXPERIMENT_PROJECT_DIR, exist_ok=True)
    print(f"Ensured experiment project directory exists: '{EXPERIMENT_PROJECT_DIR}'")
except OSError as e:
    print(f"Error creating directory {EXPERIMENT_PROJECT_DIR}: {e}")
    EXPERIMENT_PROJECT_DIR = "." 


RESULTS_FILENAME = "spider_queries_llama3.1_8B-instruct-prompt_codeS_real_examples_all_db.json"
EXPERIMENT_RESULTS_FILE = os.path.join(EXPERIMENT_PROJECT_DIR, RESULTS_FILENAME)

print(f"Experiment results will be saved to: {EXPERIMENT_RESULTS_FILE}")

Ensured experiment project directory exists: '/raid/infolab/gaurav/Llama_Spider_A100_Project/randomQ_allDBs_run1'
Experiment results will be saved to: /raid/infolab/gaurav/Llama_Spider_A100_Project/randomQ_allDBs_run1/spider_queries_llama3.1_8B-instruct-prompt_codeS_real_examples_all_db.json


In [38]:
# Cell defining get_one_zero_token_ids

def get_one_zero_token_ids(tokenizer_arg):
    """
    Determines the single token IDs for the characters '1' and '0'.
    It's crucial these are single tokens for the logit logic to work.
    """
    # Llama 3 tokenizer handles single digits with leading spaces well.
    one_token_id = tokenizer_arg.encode(" 1", add_special_tokens=False)
    zero_token_id = tokenizer_arg.encode(" 0", add_special_tokens=False)

    if len(one_token_id) == 1 and len(zero_token_id) == 1:
        print("Using ' 1' and ' 0' (with leading space) for token IDs.")
        return one_token_id[0], zero_token_id[0]
    
    # Fallback in case the model prefers no space (less common for instruction models)
    one_token_id_no_space = tokenizer_arg.encode("1", add_special_tokens=False)
    zero_token_id_no_space = tokenizer_arg.encode("0", add_special_tokens=False)

    if len(one_token_id_no_space) == 1 and len(zero_token_id_no_space) == 1:
        print("Using '1' and '0' (no leading space) for token IDs.")
        return one_token_id_no_space[0], zero_token_id_no_space[0]
        
    else:
        # If neither works, there's a problem with the tokenizer for this task.
        print(f"ERROR: Could not determine reliable single token IDs for '1' or '0'.")
        print(f"Tokenization of ' 1': {one_token_id}")
        print(f"Tokenization of ' 0': {zero_token_id}")
        print(f"Tokenization of '1': {one_token_id_no_space}")
        print(f"Tokenization of '0': {zero_token_id_no_space}")
        raise ValueError("Unstable tokenization for '1'/'0'. Cannot proceed.")

print("Helper function 'get_one_zero_token_ids' defined.")

Helper function 'get_one_zero_token_ids' defined.


In [39]:
# This cell defines the global token IDs for '1' and '0'
if 'tokenizer' in globals() and tokenizer is not None:
    try:
        ONE_TOKEN_ID, ZERO_TOKEN_ID = get_one_zero_token_ids(tokenizer)
        print(f"ONE_TOKEN_ID: {ONE_TOKEN_ID} ('{tokenizer.decode([ONE_TOKEN_ID])}')")
        print(f"ZERO_TOKEN_ID: {ZERO_TOKEN_ID} ('{tokenizer.decode([ZERO_TOKEN_ID])}')")
    except ValueError as e:
        print(f"Error defining 1/0 token IDs: {e}")
else:
    print("ERROR: 'tokenizer' is not defined. Cannot define token IDs.")

Using '1' and '0' (no leading space) for token IDs.
ONE_TOKEN_ID: 16 ('1')
ZERO_TOKEN_ID: 15 ('0')


In [28]:
import torch # Ensure torch is imported
import os    # <-- Add this import for file operations

# --- Core function to get P(1) and the binary decision ---
# <-- Add query_id and db_id as optional arguments for file naming
def get_schema_match_prediction(model_arg, tokenizer_arg, system_prompt_arg, user_prompt_content_arg, one_token_id_arg, zero_token_id_arg, query_id=None, db_id=None, max_length=model.config.max_position_embeddings):
    """
    Gets the model's prediction (1 or 0) and the probability score for that decision.
    Saves the full prompt to a file if query_id and db_id are provided.
    Returns a tuple: (binary_decision, probability_of_one)
    """
    messages = [
        {"role": "system", "content": system_prompt_arg},
        {"role": "user", "content": user_prompt_content_arg}
    ]

    prompt_for_model = tokenizer_arg.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # --- ADDED: Logic to save the prompt to a file ---
    # This block will only run if the necessary IDs are passed from the loop
    if query_id and db_id:
        PROMPT_LOG_DIR = "prompt_logs_all_db_examples"  # A dedicated folder for all prompts
        os.makedirs(PROMPT_LOG_DIR, exist_ok=True) # Create the folder if it doesn't exist

        # Sanitize db_id in case it contains characters invalid for filenames
        safe_db_id = db_id.replace('/', '_')
        
        filename = f"prompt_{query_id}_vs_{safe_db_id}.txt"
        filepath = os.path.join(PROMPT_LOG_DIR, filename)
        
        try:
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(prompt_for_model)
        except Exception as e:
            # This prevents a file error from crashing the whole experiment
            print(f"  WARNING: Could not write prompt to file {filepath}. Error: {e}")
    # --- END OF ADDED LOGIC ---

    # print(prompt_for_model) # The original print statement is kept

    inputs = tokenizer_arg(
        prompt_for_model,
        return_tensors="pt",
        truncation=True,
        max_length=max_length - 10
    )
    inputs = {k: v.to(model_arg.device) for k, v in inputs.items()}

    if inputs['input_ids'].shape[1] >= max_length - 10:
         print(f"Warning: Prompt for query was truncated. Length: {inputs['input_ids'].shape[1]}")

    with torch.no_grad():
        outputs = model_arg(**inputs)
        logits = outputs.logits
        # Get the logits for the very next token to be generated
        next_token_logits = logits[:, -1, :]
        
        # Get the specific logits for the '1' and '0' tokens
        logit_one = next_token_logits[:, one_token_id_arg].item()
        logit_zero = next_token_logits[:, zero_token_id_arg].item()

    # --- Make the decision based on which logit is higher ---
    binary_decision = 1 if logit_one > logit_zero else 0

    # --- Calculate the probability using the softmax function on the two logits ---
    # This correctly converts the logits into a probability score for '1'
    max_logit = max(logit_one, logit_zero)
    exp_one = torch.exp(torch.tensor(logit_one - max_logit))
    exp_zero = torch.exp(torch.tensor(logit_zero - max_logit))
    
    prob_one = exp_one / (exp_one + exp_zero)
    
    return (binary_decision, prob_one.item())

print("Core function 'get_schema_match_prediction' defined.")

Core function 'get_schema_match_prediction' defined.


In [29]:
# --- Prompt Configuration for Binary (1/0) Output (Aesthetic & Well-Structured) ---

SYSTEM_PROMPT = """
You are an expert system that determines if a natural language question can be answered using ONLY the provided database schema.
The schema includes sample values for some columns; treat these as HINTS to understand the column's content, not as a complete list.
Your task is to respond with a single character: '1' if the question is answerable, or '0' if it is not.
Do not provide any explanations or other text. Just '1' or '0'.
"""

# --- MODIFIED AESTHETIC TEMPLATE ---
# We now have two distinct placeholders: {schema_string} and {examples_section}.
# This ensures the examples are clearly separated from the schema definition for a clean, aesthetic look.
USER_PROMPT_TEMPLATE = """
# You are provided with a database schema.
[Schema:
{schema_string}
]
{examples_section}
# Task: Can the following question be answered using ONLY the schema above? Respond with 1 (Yes) or 0 (No).
Q: {nl_query}
A:
"""

print("SYSTEM_PROMPT and USER_PROMPT_TEMPLATE have been updated for a more aesthetic and structured format.")

SYSTEM_PROMPT and USER_PROMPT_TEMPLATE have been updated for a more aesthetic and structured format.


In [30]:
import json
import os
from tqdm.auto import tqdm

# This dictionary will hold all results.
experiment_all_query_results = []

# --- 1. Resume from Previous Run (if applicable) ---
if os.path.exists(EXPERIMENT_RESULTS_FILE):
    print(f"INFO: Found existing results file. Loading progress from '{EXPERIMENT_RESULTS_FILE}'")
    try:
        with open(EXPERIMENT_RESULTS_FILE, 'r') as f:
            experiment_all_query_results = json.load(f)
        print(f"Loaded results for {len(experiment_all_query_results)} queries. Resuming...")
    except json.JSONDecodeError:
        print(f"WARNING: Results file '{EXPERIMENT_RESULTS_FILE}' is corrupted. Starting from scratch.")
        experiment_all_query_results = []

completed_query_ids = {res['experiment_query_id'] for res in experiment_all_query_results}

# --- 2. Define the Schemas and Questions to be Used in the Experiment ---
# Use the base prompts; we will add examples dynamically.
if 'all_db_schemas_base_prompts' not in globals() or not all_db_schemas_base_prompts:
    raise NameError("The 'all_db_schemas_base_prompts' variable is not defined. Please run the base prompt generation cell first.")
if 'db_id_to_all_real_questions_map' not in globals():
    raise NameError("The 'db_id_to_all_real_questions_map' is not defined. Please run the query loading cell.")
    
candidate_schemas_for_evaluation = all_db_schemas_base_prompts

# --- 3. Start the Main Experiment Loop ---
print(f"\n--- Starting Experiment: {len(selected_nl_queries)} Queries vs. {len(candidate_schemas_for_evaluation)} Schemas (with Dynamic Examples)---")

# Outer loop: Iterate through each NL query
for query_idx, nl_query_info in enumerate(tqdm(selected_nl_queries, desc="Processing NL Queries")):
    current_nl_query_text = nl_query_info['question']
    true_db_id_for_query = nl_query_info['db_id']
    experiment_query_id = f"spider_dev_q{query_idx}_idx{query_idx}"

    if experiment_query_id in completed_query_ids:
        continue

    print(f"\nProcessing Query {query_idx + 1}/{len(selected_nl_queries)} (ID: {experiment_query_id}): '{current_nl_query_text}' (True DB: {true_db_id_for_query})")

    predictions_for_current_query = []

    # Inner loop: Iterate through each candidate database schema
    for candidate_db_id, base_schema_prompt in tqdm(candidate_schemas_for_evaluation.items(), desc=f"  DBs for Q:{experiment_query_id[:20]}", leave=False):
        
        # --- DYNAMIC FEW-SHOT EXAMPLE SELECTION ---
        # Find relevant, realistic examples for the current candidate schema
        relevant_examples = db_id_to_all_real_questions_map.get(candidate_db_id, [])
        
        # Mask the current test query so the model doesn't see the answer in the examples
        examples_to_use = [q for q in relevant_examples if q != current_nl_query_text]
        
        # Select up to 5 examples
        final_examples = examples_to_use[:5]
        
        # --- CONSTRUCT THE FINAL PROMPT WITH DYNAMIC EXAMPLES ---
        examples_section_string = ""
        final_schema_prompt_with_examples = base_schema_prompt
        if final_examples:
            examples_string = "\n".join([f"--{q}" for q in final_examples])
            # The heading and formatting are now part of this self-contained, aesthetic block
            examples_section_string = (
                f"\n#Here are some example questions that CAN be answered by the schema above:\n"
                f"{examples_string}\n"
            )
        
        # --- AESTHETIC CHANGE: Format the new template with separate placeholders ---
        user_prompt_content = USER_PROMPT_TEMPLATE.format(
            schema_string=base_schema_prompt,
            examples_section=examples_section_string,
            nl_query=current_nl_query_text
        )
        
        binary_decision = -1 
        p_one_score = -1.0

        try:
             binary_decision, p_one_score = get_schema_match_prediction(
                model,
                tokenizer,
                SYSTEM_PROMPT,
                user_prompt_content,
                ONE_TOKEN_ID,
                ZERO_TOKEN_ID,
                query_id=experiment_query_id,   # Pass the query ID for logging
                db_id=candidate_db_id         # Pass the DB ID for logging
            )
        except Exception as e:
            import traceback
            print(f"    ERROR: Exception during model inference for Query ID '{experiment_query_id}' with DB '{candidate_db_id}'.")
            print(f"    Exception type: {type(e).__name__}, Message: {e}")

        predictions_for_current_query.append({
            'candidate_db_id': candidate_db_id,
            'decision': binary_decision,
            'p_one_score': p_one_score
        })

    ranked_databases_for_query = sorted(predictions_for_current_query, key=lambda x: x['p_one_score'], reverse=True)

    experiment_all_query_results.append({
        'experiment_query_id': experiment_query_id,
        'nl_query_text': current_nl_query_text,
        'true_db_id': true_db_id_for_query,
        'ranked_databases_with_predictions': ranked_databases_for_query 
    })

    # --- 4. Periodic Saving of Results ---
    try:
        with open(EXPERIMENT_RESULTS_FILE, 'w') as f_out:
            json.dump(experiment_all_query_results, f_out, indent=2)
    except Exception as e:
        print(f"  ERROR: Could not save intermediate results: {e}")

# --- 5. Final Save After Loop Completion ---
print("\n--- Experiment Loop Finished ---\n")
if experiment_all_query_results:
    print(f"Processed a total of {len(experiment_all_query_results)} unique queries.")
    try:
        with open(EXPERIMENT_RESULTS_FILE, 'w') as f_out:
            json.dump(experiment_all_query_results, f_out, indent=2)
        print(f"Final results successfully saved to {EXPERIMENT_RESULTS_FILE}")
    except Exception as e:
        print(f"ERROR: Could not save the final results: {e}")
else:
    print("No results were generated. Check logs for errors.")

INFO: Found existing results file. Loading progress from '/raid/infolab/gaurav/Llama_Spider_A100_Project/randomQ_allDBs_run1/spider_queries_llama3.1_8B-instruct-prompt_codeS_real_examples_all_db.json'
Loaded results for 122 queries. Resuming...


NameError: name 'selected_nl_queries' is not defined

In [20]:
import os
import json

# Path where the evaluation summary (Recall@K results) will be saved
EVAL_RESULTS_SAVE_PATH = "recall_k_results_context_lamma-3.1-8B-instruct-prompt-codeS_real_examples_all_db.json"

# --- 4.1. Define Recall@K Calculation Function ---
def calculate_recall_at_k_metric(all_query_results_list, k_values_list):
    """
    Calculates Recall@K for a list of K values.
    Each item in all_query_results_list should be a dictionary with:
        'true_db_id': The ground truth database ID for the query.
        'ranked_databases_with_scores': A list of {'candidate_db_id': id, 'p_yes_score': score},
                                         sorted by score in descending order.
    """
    recall_counts = {k: 0 for k in k_values_list}  # Stores how many times true_db was in top K
    total_valid_queries = 0  # Queries for which we have a true_db_id

    if not all_query_results_list:
        return {k: 0.0 for k in k_values_list}, 0

    for query_result in all_query_results_list:
        true_db = query_result.get('true_db_id')
        ranked_dbs_info = query_result.get('ranked_databases_with_predictions')

        if true_db is None or ranked_dbs_info is None:
            print(f"Warning: Skipping query result due to missing 'true_db_id' or 'ranked_databases_with_scores': "
                  f"{query_result.get('experiment_query_id', 'Unknown Query')}")
            continue  # Skip if essential information is missing

        total_valid_queries += 1
        # Extract just the DB IDs from the ranked list
        ranked_db_ids_only = [item['candidate_db_id'] for item in ranked_dbs_info]

        for k in k_values_list:
            # Get the top K predicted database IDs
            top_k_predicted_dbs = ranked_db_ids_only[:k]
            if true_db in top_k_predicted_dbs:
                recall_counts[k] += 1

    # Calculate final recall percentages
    recall_percentages = {}
    if total_valid_queries > 0:
        for k in k_values_list:
            recall_percentages[k] = (recall_counts[k] / total_valid_queries) * 100.0  # As percentage
    else:
        recall_percentages = {k: 0.0 for k in k_values_list}

    return recall_percentages, total_valid_queries


# --- 4.2. Perform Evaluation ---
# Load results if this cell is run in a new session and experiment_all_query_results isn't in memory
# (assuming results were saved to EXPERIMENT_RESULTS_FILE)
loaded_results_for_eval = None
if 'experiment_all_query_results' in globals() and experiment_all_query_results:
    print("Using in-memory experiment_all_query_results for evaluation.")
    loaded_results_for_eval = experiment_all_query_results
elif os.path.exists(EXPERIMENT_RESULTS_FILE):
    print(f"Loading results from {EXPERIMENT_RESULTS_FILE} for evaluation...")
    try:
        with open(EXPERIMENT_RESULTS_FILE, 'r') as f_in:
            loaded_results_for_eval = json.load(f_in)
        print(f"Successfully loaded {len(loaded_results_for_eval)} results from file.")
    except Exception as e:
        print(f"Error loading results from file for evaluation: {e}")
else:
    print("No results available in memory or in the specified results file for evaluation.")

if loaded_results_for_eval:
    K_VALUES_TO_EVALUATE = [1, 3, 5, 10]  # Define the K values you care about
    recall_scores_map, num_queries_evaluated = calculate_recall_at_k_metric(
        loaded_results_for_eval, K_VALUES_TO_EVALUATE
    )

    print("\n--- Evaluation: Recall@K ---")
    print(f"Evaluated on {num_queries_evaluated} queries.")
    for k_val, recall_val in recall_scores_map.items():
        print(f"Recall@{k_val}: {recall_val:.2f}%")

    # --- 4.2.1. Save evaluation results to a JSON file ---
    try:
        eval_summary = {
            "num_queries_evaluated": num_queries_evaluated,
            "recall_scores": recall_scores_map
        }
        with open(EVAL_RESULTS_SAVE_PATH, 'w') as fout:
            json.dump(eval_summary, fout, indent=2)
        print(f"Saved evaluation results to '{EVAL_RESULTS_SAVE_PATH}'")
    except Exception as save_err:
        print(f"Error saving evaluation results: {save_err}")

    # --- 4.3. Optional: Print Detailed Results for a Few Queries ---
    print("\n--- Sample Detailed Query Results (Top 5 Queries) ---")
    for i, res in enumerate(loaded_results_for_eval[:5]):  # Show for first 5 queries
        print(f"\nQuery {i+1}: '{res.get('nl_query_text', '<no text>')}' (True DB: {res.get('true_db_id')})")
        print("  Top Ranked Databases (with P(Yes) scores):")
        for rank, db_info in enumerate(res.get('ranked_databases_with_scores', [])[:5]):  # Show top 5 ranked DBs
            is_true_db_char = "*" if db_info['candidate_db_id'] == res['true_db_id'] else " "
            print(f"    {rank+1}. {db_info['candidate_db_id']}{is_true_db_char} "
                  f"(Score: {db_info['p_yes_score']:.4f})")
else:
    print("Cannot perform evaluation as no results were loaded or generated.")


Loading results from /raid/infolab/gaurav/Llama_Spider_A100_Project/randomQ_allDBs_run1/spider_queries_llama3.1_8B-instruct-prompt_codeS_real_examples_all_db.json for evaluation...
Successfully loaded 120 results from file.

--- Evaluation: Recall@K ---
Evaluated on 120 queries.
Recall@1: 75.83%
Recall@3: 90.83%
Recall@5: 95.00%
Recall@10: 96.67%
Saved evaluation results to 'recall_k_results_context_lamma-3.1-8B-instruct-prompt-codeS_real_examples_all_db.json'

--- Sample Detailed Query Results (Top 5 Queries) ---

Query 1: 'How many 'United Airlines' flights go to Airport 'ASY'?' (True DB: flight_2)
  Top Ranked Databases (with P(Yes) scores):

Query 2: 'What are the name of the countries where there is not a single car maker?' (True DB: car_1)
  Top Ranked Databases (with P(Yes) scores):

Query 3: 'What are the date and the operating professional's first name of each treatment?' (True DB: dog_kennels)
  Top Ranked Databases (with P(Yes) scores):

Query 4: 'List each owner's first nam