In [1]:
pip install transformers accelerate bitsandbytes sentencepiece pandas datasets huggingface_hub tqdm

Note: you may need to restart the kernel to use updated packages.


In [2]:
  import ipywidgets
  print(f"ipywidgets version: {ipywidgets.__version__}")
  print(f"ipywidgets location: {ipywidgets.__file__}")

  import tqdm
  print(f"tqdm version: {tqdm.__version__}")
  print(f"tqdm location: {tqdm.__file__}")

ipywidgets version: 8.1.5
ipywidgets location: /raid/infolab/gaurav/Llama_Spider_A100_Project/miniconda3/envs/llama_spider_env/lib/python3.10/site-packages/ipywidgets/__init__.py
tqdm version: 4.67.1
tqdm location: /raid/infolab/gaurav/Llama_Spider_A100_Project/miniconda3/envs/llama_spider_env/lib/python3.10/site-packages/tqdm/__init__.py


In [3]:
from tqdm.auto import tqdm
import time

print("tqdm imported successfully from .auto")
my_list = list(range(3))
for i in tqdm(my_list, desc="Minimal Auto Test"):
    time.sleep(0.2)
print("Simple tqdm .auto loop completed")

tqdm imported successfully from .auto


Minimal Auto Test:   0%|          | 0/3 [00:00<?, ?it/s]

Simple tqdm .auto loop completed


In [5]:
# --- Standard Library Imports ---
# --- Third-party Library Imports ---
# --- Third-party Library Imports ---
import torch
from tqdm.auto import tqdm
import time
from huggingface_hub import login
import transformers # <--- ADD THIS LINE
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# --- Third-party Library Imports ---
import torch
from tqdm.auto import tqdm # For progress bars
from huggingface_hub import login # For Hugging Face Hub authentication
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

print("--- Cell 1: Imports and Initial Configuration Complete ---")
print(f"PyTorch Version: {torch.__version__}")
print(f"Transformers Version: {transformers.__version__}")

--- Cell 1: Imports and Initial Configuration Complete ---
PyTorch Version: 2.2.0
Transformers Version: 4.52.4


In [6]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version PyTorch compiled with: {torch.version.cuda}")
    print(f"Number of GPUs available to PyTorch: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    print("ERROR: PyTorch cannot see the GPUs! Check installation and CUDA compatibility.")

PyTorch version: 2.2.0
CUDA available: True
CUDA version PyTorch compiled with: 11.8
Number of GPUs available to PyTorch: 8
  GPU 0: NVIDIA A100-SXM4-80GB
  GPU 1: NVIDIA A100-SXM4-80GB
  GPU 2: NVIDIA A100-SXM4-80GB
  GPU 3: NVIDIA A100-SXM4-80GB
  GPU 4: NVIDIA A100-SXM4-80GB
  GPU 5: NVIDIA A100-SXM4-80GB
  GPU 6: NVIDIA A100-SXM4-80GB
  GPU 7: NVIDIA A100-SXM4-80GB


In [7]:
# --- Standard Library Imports ---
# --- Third-party Library Imports ---
# --- Third-party Library Imports ---
import torch
from tqdm.auto import tqdm
import time
from huggingface_hub import login
import transformers # <--- ADD THIS LINE
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# --- Third-party Library Imports ---
import torch
from tqdm.auto import tqdm # For progress bars
from huggingface_hub import login # For Hugging Face Hub authentication
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

print("--- Cell 1: Imports and Initial Configuration Complete ---")
print(f"PyTorch Version: {torch.__version__}")
print(f"Transformers Version: {transformers.__version__}")

--- Cell 1: Imports and Initial Configuration Complete ---
PyTorch Version: 2.2.0
Transformers Version: 4.52.4


In [9]:
# --- Hugging Face Hub Authentication ---
# You MUST have requested access to Llama 2 models via Meta's form on Hugging Face
# AND have your request approved.

# Option 1: If you've stored your token as an environment variable on the server
# HF_TOKEN = os.environ.get("HF_TOKEN")
# if HF_TOKEN:
#     print("Logging into Hugging Face Hub using token from environment variable...")
#     login(token=HF_TOKEN)
# else:
#     print("HF_TOKEN environment variable not set. Attempting widget login if in interactive environment, or manual CLI login might be needed.")
#     login() # Will prompt if in an environment that supports it

# Option 2: Paste token directly (less secure, use with caution)
# HF_TOKEN = "YOUR_HF_READ_TOKEN_HERE"
# login(token=HF_TOKEN)

# Option 3: Use huggingface-cli login in a server terminal beforehand (Recommended)
# If already logged in via CLI, this cell might not be strictly necessary,
# but running login() can confirm status or refresh credentials.
try:
    login() # Will use cached token or prompt if needed
    print("Hugging Face login successful or already authenticated.")
except Exception as e:
    print(f"Hugging Face login failed: {e}. Ensure you are authenticated to download Llama 2.")

print("\n--- Cell 2: Hugging Face Login Attempt Complete ---")

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

Hugging Face login successful or already authenticated.

--- Cell 2: Hugging Face Login Attempt Complete ---


In [8]:
# --- Model and Tokenizer Configuration ---
import os

# 3.1. Specify the Llama 2 70B Chat Model
MODEL_NAME = "meta-llama/Llama-2-70b-chat-hf"
print(f"Target Model: {MODEL_NAME}")

# 3.2. Configure 4-bit Quantization (essential for 70B, even on A100s for single/few GPU use)
# A100s support bfloat16, which is excellent for mixed-precision.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",        # nf4 is a good default
    bnb_4bit_compute_dtype=torch.bfloat16, # Use bfloat16 for computation on A100s
    bnb_4bit_use_double_quant=True,   # Can save a bit more memory
)
print(f"BitsAndBytesConfig: load_in_4bit={bnb_config.load_in_4bit}, compute_dtype={bnb_config.bnb_4bit_compute_dtype}")

# 3.3. Define Prompt Templates
SYSTEM_PROMPT = (
    "You are an expert data analyst. Your task is to determine if a given natural language query "
    "can be answered *solely* based on the provided database schema. "
    "Do not attempt to answer the query itself. Your entire response must be only the word 'Yes' or the word 'No'."
)

USER_PROMPT_TEMPLATE = """Database Schema:
---
{schema_string}
---
Natural Language Query: "{nl_query}"
---
Can the query be answered using *only* the provided schema and its potential contents? Answer with either "Yes" or "No".
"""
print("System and User prompt templates defined.")

# 3.4. Define Cache Directory for Hugging Face downloads (optional, but good for managing large models)
# Create it within your project directory on the A100 server.
HF_MODEL_CACHE_DIR = os.path.join(os.getcwd(), ".hf_model_cache_70b") # Assumes current dir is project root
os.makedirs(HF_MODEL_CACHE_DIR, exist_ok=True)
print(f"Hugging Face model cache directory set to: {HF_MODEL_CACHE_DIR}")

print("\n--- Cell 3: Model and Prompt Configuration Complete ---")

Target Model: meta-llama/Llama-2-70b-chat-hf
BitsAndBytesConfig: load_in_4bit=True, compute_dtype=torch.bfloat16
System and User prompt templates defined.
Hugging Face model cache directory set to: /raid/infolab/gaurav/Llama_Spider_A100_Project/experiments_70b_llama/.hf_model_cache_70b

--- Cell 3: Model and Prompt Configuration Complete ---


In [9]:
# --- Load Tokenizer and Define Yes/No Token Logic ---

# 4.1. Load Tokenizer
print(f"Loading tokenizer for {MODEL_NAME}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=HF_MODEL_CACHE_DIR)
    # Set pad token if not already set (Llama tokenizers often don't have one)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print(f"Set tokenizer.pad_token to tokenizer.eos_token ('{tokenizer.eos_token}')")
    print("Tokenizer loaded successfully.")
except Exception as e:
    raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}: {e}")


# 4.2. Define Helper Function to get Yes/No Token IDs
def get_yes_no_token_ids(tokenizer_arg):
    """Determines token IDs for 'Yes'/'No', preferring those with a leading space."""
    # Try with leading space first for chat models
    yes_variants = [" Yes", "Yes"]
    no_variants = [" No", "No"]
    
    final_yes_id = None
    final_no_id = None

    for variant in yes_variants:
        token_ids = tokenizer_arg.encode(variant, add_special_tokens=False)
        if len(token_ids) == 1:
            final_yes_id = token_ids[0]
            print(f"Found single token for '{variant}': ID {final_yes_id}")
            break
            
    for variant in no_variants:
        token_ids = tokenizer_arg.encode(variant, add_special_tokens=False)
        if len(token_ids) == 1:
            final_no_id = token_ids[0]
            print(f"Found single token for '{variant}': ID {final_no_id}")
            break

    if final_yes_id is None or final_no_id is None:
        print(f"ERROR: Could not determine reliable single token IDs for 'Yes'/'No' or variants.")
        # You might want to print detailed tokenization attempts here if this error occurs
        raise ValueError("Unstable tokenization for 'Yes'/'No'. Cannot proceed.")
    
    return final_yes_id, final_no_id

# 4.3. Define Global YES_TOKEN_ID and NO_TOKEN_ID
try:
    YES_TOKEN_ID, NO_TOKEN_ID = get_yes_no_token_ids(tokenizer)
    print(f"GLOBAL YES_TOKEN_ID: {YES_TOKEN_ID} ('{tokenizer.decode([YES_TOKEN_ID]).strip()}')")
    print(f"GLOBAL NO_TOKEN_ID: {NO_TOKEN_ID} ('{tokenizer.decode([NO_TOKEN_ID]).strip()}')")
except ValueError as e:
    raise RuntimeError(f"Failed to set YES/NO token IDs: {e}")

print("\n--- Cell 4: Tokenizer Loading and Yes/No Token ID Setup Complete ---")

Loading tokenizer for meta-llama/Llama-2-70b-chat-hf...
Set tokenizer.pad_token to tokenizer.eos_token ('</s>')
Tokenizer loaded successfully.
Found single token for 'Yes': ID 3869
Found single token for 'No': ID 1939
GLOBAL YES_TOKEN_ID: 3869 ('Yes')
GLOBAL NO_TOKEN_ID: 1939 ('No')

--- Cell 4: Tokenizer Loading and Yes/No Token ID Setup Complete ---


In [10]:
# --- Load the Llama 2 70B Model ---
# This is a memory-intensive step. `device_map="auto"` will attempt to distribute
# the model across available GPUs if one is insufficient.
# Ensure CUDA_VISIBLE_DEVICES is set in your shell if you want to restrict which GPUs are used.
import gc
print(f"Loading model: {MODEL_NAME} with 4-bit quantization. This will take significant time and memory...")
model_load_start_time = time.time()
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,    # Apply 4-bit quantization
        torch_dtype=torch.bfloat16,        # Use bfloat16 on A100s
        device_map="auto",                 # Distribute model across available GPUs automatically
        trust_remote_code=True,            # Often needed for newer models
        cache_dir=HF_MODEL_CACHE_DIR
    )
    model_load_end_time = time.time()
    print("\nModel loaded successfully!")
    print(f"Time taken to load model: {model_load_end_time - model_load_start_time:.2f} seconds.")
    print(f"Model device map: {model.hf_device_map}") # Shows how layers are distributed
    # For a 70B model, this should show parts on different GPUs if more than one is used.
    
    # Perform a quick memory cleanup after loading large model
    torch.cuda.empty_cache()
    gc.collect()
    print("Performed memory cleanup (torch.cuda.empty_cache(), gc.collect())")

except Exception as e:
    import traceback
    traceback.print_exc()
    raise RuntimeError(f"Failed to load model {MODEL_NAME}: {e}. Check VRAM, CUDA setup, and Hugging Face authentication.")

print("\n--- Cell 5: Llama 2 70B Model Loading Complete ---")

print("Model max_position_embeddings:", model.config.max_position_embeddings)
print("Tokenizer model_max_length:", tokenizer.model_max_length)

Loading model: meta-llama/Llama-2-70b-chat-hf with 4-bit quantization. This will take significant time and memory...


Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]


Model loaded successfully!
Time taken to load model: 47.00 seconds.
Model device map: {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 1, 'model.layers.9': 1, 'model.layers.10': 1, 'model.layers.11': 1, 'model.layers.12': 1, 'model.layers.13': 1, 'model.layers.14': 1, 'model.layers.15': 1, 'model.layers.16': 1, 'model.layers.17': 1, 'model.layers.18': 2, 'model.layers.19': 2, 'model.layers.20': 2, 'model.layers.21': 2, 'model.layers.22': 2, 'model.layers.23': 2, 'model.layers.24': 2, 'model.layers.25': 2, 'model.layers.26': 2, 'model.layers.27': 2, 'model.layers.28': 3, 'model.layers.29': 3, 'model.layers.30': 3, 'model.layers.31': 3, 'model.layers.32': 3, 'model.layers.33': 3, 'model.layers.34': 3, 'model.layers.35': 3, 'model.layers.36': 3, 'model.layers.37': 3, 'model.layers.38': 4, 'model.layers.39': 4, 'model.layers.40'

In [12]:
import zipfile
import os

SERVER_ZIP_FILE_PATH = '/raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data.zip'
EXTRACTION_DESTINATION_DIR_ON_SERVER = '/raid/infolab/gaurav/Llama_Spider_A100_Project/'

DEV_JSON_PATH = None
TABLES_JSON_PATH = None

def unzip_data(zip_filepath, dest_dir):
    """
    Unzips a zip file to a specified destination directory.
    """
    print(f"Attempting to unzip {zip_filepath} to {dest_dir}...")
    try:
        
        with zipfile.ZipFile(zip_filepath, 'r') as zip_ref:
            zip_ref.extractall(dest_dir)
        print(f"Successfully unzipped files to {dest_dir}")

        print(f"Contents of {dest_dir}:")
        for item in os.listdir(dest_dir):
            print(f"  - {item}")
        return True
    except zipfile.BadZipFile:
        print(f"Error: {zip_filepath} is not a valid zip file or is corrupted.")
        return False
    except FileNotFoundError:
        print(f"Error: Zip file not found at {zip_filepath}. Please ensure the path is correct.")
        return False
    except PermissionError:
        print(f"Error: Permission denied to write to {dest_dir} or read {zip_filepath}.")
        return False
    except Exception as e:
        print(f"An unexpected error occurred during unzipping: {e}")
        return False

print(f"Script started. Looking for zip file at: {SERVER_ZIP_FILE_PATH}")

if os.path.exists(SERVER_ZIP_FILE_PATH):
    print(f"Zip file found at {SERVER_ZIP_FILE_PATH}.")
    if unzip_data(SERVER_ZIP_FILE_PATH, EXTRACTION_DESTINATION_DIR_ON_SERVER):
        
        EXPECTED_EXTRACTED_FOLDER_NAME = 'spider_subset_data' # This is the folder INSIDE the zip

        DEV_JSON_PATH = os.path.join(EXTRACTION_DESTINATION_DIR_ON_SERVER, EXPECTED_EXTRACTED_FOLDER_NAME, 'dev.json')
        TABLES_JSON_PATH = os.path.join(EXTRACTION_DESTINATION_DIR_ON_SERVER, EXPECTED_EXTRACTED_FOLDER_NAME, 'tables.json')

        print("\nVerifying extracted file paths...")
        if os.path.exists(DEV_JSON_PATH):
            print(f"SUCCESS: dev.json path is valid: {DEV_JSON_PATH}")
        else:
            print(f"ERROR: dev.json NOT FOUND at expected path: {DEV_JSON_PATH}")
            print(f"Please check the contents of {os.path.join(EXTRACTION_DESTINATION_DIR_ON_SERVER, EXPECTED_EXTRACTED_FOLDER_NAME)}")


        if os.path.exists(TABLES_JSON_PATH):
            print(f"SUCCESS: tables.json path is valid: {TABLES_JSON_PATH}")
        else:
            print(f"ERROR: tables.json NOT FOUND at expected path: {TABLES_JSON_PATH}")
            print(f"Please check the contents of {os.path.join(EXTRACTION_DESTINATION_DIR_ON_SERVER, EXPECTED_EXTRACTED_FOLDER_NAME)}")

    else:
        print("Unzipping failed on the server. Cannot define data paths.")
else:
    print(f"ERROR: Zip file NOT FOUND at {SERVER_ZIP_FILE_PATH} on the server.")
    print("Please ensure the 'scp' command was successful and the path is correct.")


if DEV_JSON_PATH and TABLES_JSON_PATH and os.path.exists(DEV_JSON_PATH) and os.path.exists(TABLES_JSON_PATH):
    print("\n--- Ready to load data ---")
    print(f"Path to dev.json: {DEV_JSON_PATH}")
    print(f"Path to tables.json: {TABLES_JSON_PATH}")
    
else:
    print("\n--- Data paths are not correctly set up. Cannot proceed with data loading. ---")

Script started. Looking for zip file at: /raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data.zip
Zip file found at /raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data.zip.
Attempting to unzip /raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data.zip to /raid/infolab/gaurav/Llama_Spider_A100_Project/...
Successfully unzipped files to /raid/infolab/gaurav/Llama_Spider_A100_Project/
Contents of /raid/infolab/gaurav/Llama_Spider_A100_Project/:
  - experiments_70b_llama
  - .gitignore
  - backup_to_github.sh
  - Miniconda3-latest-Linux-x86_64.sh
  - spider_subset_data.zip
  - randomQ_allDBs_run1
  - .ipynb_checkpoints
  - .git
  - miniconda3
  - 100_queries.txt
  - spider_subset_data
  - __MACOSX

Verifying extracted file paths...
SUCCESS: dev.json path is valid: /raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data/dev.json
SUCCESS: tables.json path is valid: /raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data/tables.json

-

In [13]:
import json

def load_json_data(file_path):
    if os.path.exists(file_path):
        with open(file_path, 'r') as f:
            return json.load(f)
    else:
        print(f"ERROR: File not found at {file_path}")
        return None

dev_data = load_json_data(DEV_JSON_PATH)
tables_data = load_json_data(TABLES_JSON_PATH)

if dev_data and tables_data:
    print(f"Loaded {len(dev_data)} queries from dev.json")
    print(f"Loaded {len(tables_data)} database schemas from tables.json")
else:
    print("Failed to load Spider data. Please check paths and upload.")

Loaded 1034 queries from dev.json
Loaded 166 database schemas from tables.json


In [92]:
# earlier
import json
# import os # Not strictly needed for this dictionary creation unless used in paths
# import traceback # Only needed if you keep the full traceback print in except

# --- Helper Functions (These are the same as you provided) ---
def load_schemas(tables_json_path):
    """Loads schemas from tables.json into a dictionary keyed by db_id."""
    with open(tables_json_path, 'r') as f:
        schemas_list = json.load(f)
    schemas_dict = {db_info['db_id']: db_info for db_info in schemas_list}
    return schemas_dict

def map_spider_type_to_sql_type(spider_type, is_pk_or_fk=False):
    """Maps Spider's generic types to SQLite data types."""
    spider_type = spider_type.lower()
    if spider_type == "text":
        return "TEXT"
    elif spider_type == "number":
        return "INTEGER" if is_pk_or_fk else "REAL"
    elif spider_type == "time":
        return "DATETIME"
    elif spider_type == "boolean":
        return "BOOLEAN"
    elif spider_type == "others":
        return "BLOB"
    else:
        return "TEXT"

def escape_sql_identifier(name):
    """Escapes SQL identifiers (table/column names) if they contain spaces or are keywords."""
    if " " in name or name.lower() in {"select", "from", "where", "table", "primary", "key", "foreign", "index", "order", "group"}:
        return f'"{name}"'
    return name

def generate_create_table_sql_for_db(db_id, all_schemas_data): # Parameter name changed for consistency
    """
    Generates SQL CREATE TABLE statements for a given db_id from the Spider schema.
    'all_schemas_data' is the dictionary produced by load_schemas.
    """
    if db_id not in all_schemas_data:
        return f"-- Database ID '{db_id}' not found in schemas."

    db_schema = all_schemas_data[db_id] # Get the specific schema info for this db_id
    sql_statements = []
    column_info_by_index = {}
    for i, (table_idx, col_name_original) in enumerate(db_schema['column_names_original']):
        if col_name_original == "*":
            continue
        column_info_by_index[i] = {
            "original_name": col_name_original,
            "table_index": table_idx,
            "original_table_name": db_schema['table_names_original'][table_idx],
            "type": db_schema['column_types'][i]
        }
    for table_idx, table_name_original in enumerate(db_schema['table_names_original']):
        escaped_table_name = escape_sql_identifier(table_name_original)
        column_definitions = []
        table_constraints = []
        current_table_columns = []
        for col_global_idx, (tbl_idx_for_col, col_name_orig) in enumerate(db_schema['column_names_original']):
            if col_name_orig == "*":
                continue
            if tbl_idx_for_col == table_idx:
                current_table_columns.append({
                    "global_idx": col_global_idx,
                    "name": col_name_orig,
                    "type": db_schema['column_types'][col_global_idx]
                })
        pk_column_indices_for_table = [
            pk_idx for pk_idx in db_schema['primary_keys']
            if column_info_by_index.get(pk_idx) and column_info_by_index[pk_idx]['table_index'] == table_idx
        ]
        pk_column_names_for_table = [column_info_by_index[idx]['original_name'] for idx in pk_column_indices_for_table]
        for col_data in current_table_columns:
            col_name_original = col_data['name']
            spider_type = col_data['type']
            col_global_idx = col_data['global_idx']
            is_pk_col = col_global_idx in pk_column_indices_for_table
            is_fk_col = any(fk_pair[0] == col_global_idx for fk_pair in db_schema['foreign_keys'])
            sql_type = map_spider_type_to_sql_type(spider_type, is_pk_or_fk=(is_pk_col or is_fk_col))
            escaped_col_name = escape_sql_identifier(col_name_original)
            col_def_str = f"{escaped_col_name} {sql_type}"
            if is_pk_col and len(pk_column_names_for_table) == 1:
                col_def_str += " PRIMARY KEY"
            column_definitions.append(col_def_str)
        if len(pk_column_names_for_table) > 1:
            escaped_pk_cols = [escape_sql_identifier(name) for name in pk_column_names_for_table]
            table_constraints.append(f"PRIMARY KEY ({', '.join(escaped_pk_cols)})")
        for fk_col_idx, referenced_col_idx in db_schema['foreign_keys']:
            if column_info_by_index.get(fk_col_idx) and \
               column_info_by_index.get(referenced_col_idx) and \
               column_info_by_index[fk_col_idx]['table_index'] == table_idx:
                fk_column_name = column_info_by_index[fk_col_idx]['original_name']
                referenced_table_name = column_info_by_index[referenced_col_idx]['original_table_name']
                referenced_column_name = column_info_by_index[referenced_col_idx]['original_name']
                escaped_fk_col = escape_sql_identifier(fk_column_name)
                escaped_ref_table = escape_sql_identifier(referenced_table_name)
                escaped_ref_col = escape_sql_identifier(referenced_column_name)
                table_constraints.append(
                    f"FOREIGN KEY ({escaped_fk_col}) REFERENCES {escaped_ref_table} ({escaped_ref_col})"
                )
        all_parts = column_definitions + table_constraints
        create_table_statement = f"CREATE TABLE {escaped_table_name} (\n  "
        create_table_statement += ",\n  ".join(all_parts)
        create_table_statement += "\n);"
        sql_statements.append(create_table_statement)
    return "\n\n".join(sql_statements)
# --- End of Helper Functions ---


# --- MODIFIED "Main Execution" for "Cell 1" to produce the dictionary ---
# This code will be run when you execute the Jupyter cell.
# The output variable needed by your experiment is `all_db_schemas_sql_strings`.

all_db_schemas_sql_strings = {} # This is the dictionary your experiment needs

# Define the path to your tables.json
spider_tables_json_path = '/raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data/tables.json'

print("--- Cell 1: Preparing Database Schema SQL Strings (Dictionary Output) ---")
try:
    # 1. Load all schema structures from tables.json
    # `all_db_schemas_data_loaded` will be a dictionary: {db_id: schema_info_dict, ...}
    all_db_schemas_data_loaded = load_schemas(spider_tables_json_path) # Renamed to avoid confusion with function parameter
    print(f"Loaded schema data for {len(all_db_schemas_data_loaded)} databases from '{spider_tables_json_path}'.")

    # 2. Iterate through each loaded schema and generate its SQL string, storing it in the dictionary
    if all_db_schemas_data_loaded:
        for db_id in all_db_schemas_data_loaded: # Iterate through keys (db_ids)
            # Call generate_create_table_sql_for_db, passing the full loaded data
            # and the current db_id.
            sql_string_for_db = generate_create_table_sql_for_db(db_id, all_db_schemas_data_loaded)

            # Store the raw SQL string in the dictionary.
            # We only store it if it's a successful generation (doesn't start with the error message)
            if sql_string_for_db and not sql_string_for_db.startswith("-- Database ID"):
                all_db_schemas_sql_strings[db_id] = sql_string_for_db
            elif sql_string_for_db.startswith("-- Database ID"):
                print(f"Warning: Schema for {db_id} reported as not found by generate_create_table_sql_for_db.")
            else:
                print(f"Warning: SQL generation returned empty or unexpected for {db_id} (Result: '{sql_string_for_db[:50]}...')")

        print(f"Successfully populated `all_db_schemas_sql_strings` dictionary with {len(all_db_schemas_sql_strings)} entries.")
    else:
        print("No schema data loaded from tables.json, so `all_db_schemas_sql_strings` will be empty.")

except FileNotFoundError:
    print(f"FATAL ERROR: The file '{spider_tables_json_path}' was not found.")
    all_db_schemas_sql_strings = {} # Ensure it's defined as empty on error
except json.JSONDecodeError:
    print(f"FATAL ERROR: Could not decode JSON from '{spider_tables_json_path}'. Check if it's a valid JSON file.")
    all_db_schemas_sql_strings = {}
except Exception as e:
    print(f"FATAL ERROR during schema preparation: {e}")
    # import traceback # Uncomment if you need the full traceback here
    # traceback.print_exc()
    all_db_schemas_sql_strings = {}

# --- Verification (you can add this to your cell to check after it runs) ---
print(f"\n--- Verification of all_db_schemas_sql_strings ---")
print(f"Type: {type(all_db_schemas_sql_strings)}")
print(f"Number of schemas processed: {len(all_db_schemas_sql_strings)}")
if all_db_schemas_sql_strings:
    # Print a sample to verify content
    sample_db_id = list(all_db_schemas_sql_strings.keys())[1]
    print(f"Sample - DB ID: {sample_db_id}")
    print(f"Sample - SQL String :\n{all_db_schemas_sql_strings[sample_db_id]}")
else:
    print("`all_db_schemas_sql_strings` is empty. Review errors above.")
# --- End of Cell 1 Logic ---

--- Cell 1: Preparing Database Schema SQL Strings (Dictionary Output) ---
Loaded schema data for 166 databases from '/raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data/tables.json'.
Successfully populated `all_db_schemas_sql_strings` dictionary with 166 entries.

--- Verification of all_db_schemas_sql_strings ---
Type: <class 'dict'>
Number of schemas processed: 166
Sample - DB ID: college_2
Sample - SQL String :
CREATE TABLE classroom (
  building TEXT PRIMARY KEY,
  room_number TEXT,
  capacity REAL
);

CREATE TABLE department (
  dept_name TEXT PRIMARY KEY,
  building TEXT,
  budget REAL
);

CREATE TABLE course (
  course_id TEXT PRIMARY KEY,
  title TEXT,
  dept_name TEXT,
  credits REAL,
  FOREIGN KEY (dept_name) REFERENCES department (dept_name)
);

CREATE TABLE instructor (
  ID TEXT PRIMARY KEY,
  name TEXT,
  dept_name TEXT,
  salary REAL,
  FOREIGN KEY (dept_name) REFERENCES department (dept_name)
);

CREATE TABLE section (
  course_id TEXT PRIMARY KEY,
  sec_i

In [None]:
# Random Sampling

# This cell defines parameters for running the experiment.
# It will now randomly select queries and always use ALL database schemas as candidates.

import random # Ensure random is imported at the top of your notebook or this cell
# import os # Ensure os is imported (likely already done for path joining)
# import json # Ensure json is imported (likely already done for loading)

# --- 2.1. Experiment Parameters ---
# Number of NL queries to RANDOMLY select from dev.json to process.
# For initial testing in Colab, use a small subset. For a more thorough run, increase this.
NUM_RANDOM_QUERIES_TO_TEST = 100 # For example, test 5 random queries

# This will now effectively always be True based on your requirement.
# The logic will be set up to use all schemas from all_db_schemas_sql_strings.
# We can keep the variable for clarity or remove it if it's always all DBs.
# For this implementation, let's explicitly aim for all DBs.
print("INFO: This experiment configuration will test each randomly selected query against ALL available Spider database schemas.")


# --- 2.2. Randomly Select NL Queries for the Experiment ---
# We will randomly sample NUM_RANDOM_QUERIES_TO_TEST queries from the loaded dev_data.
if not dev_data: # dev_data should have been loaded in Cell 1
    raise ValueError("dev_data is not loaded (from dev.json). Cannot select queries. Please run Cell 1 first.")

if len(dev_data) == 0:
    raise ValueError("dev_data is empty. No queries to select.")

actual_num_queries_to_select = min(NUM_RANDOM_QUERIES_TO_TEST, len(dev_data))
# Using min ensures we don't try to sample more queries than available.

if actual_num_queries_to_select < NUM_RANDOM_QUERIES_TO_TEST:
    print(f"Warning: Requested {NUM_RANDOM_QUERIES_TO_TEST} random queries, but only {len(dev_data)} are available. Using all {len(dev_data)} queries.")

# Randomly sample without replacement
selected_nl_queries = random.sample(dev_data, actual_num_queries_to_select)

print(f"\nRandomly selected {len(selected_nl_queries)} NL queries for the experiment:")
for i, q_info in enumerate(selected_nl_queries):
    print(f"  Test Query {i+1}: '{q_info['question']}' (True DB: {q_info['db_id']})")


# --- 2.3. Determine Candidate Database Schemas for Each Query ---
# For this experiment design, we ALWAYS use ALL available database schemas.
# all_db_schemas_sql_strings should have been populated in Cell 1.
if not all_sql_output: # Populated in Cell 1
    raise ValueError("all_sql_output is empty. Schemas were not converted in Cell 1. Cannot proceed.")

candidate_schemas_for_evaluation = all_db_schemas_sql_strings # Use all converted schemas
print(f"\nEach of the {len(selected_nl_queries)} selected queries will be evaluated against all {len(candidate_schemas_for_evaluation)} available Spider database schemas.")

if not candidate_schemas_for_evaluation: # Should not happen if all_db_schemas_sql_strings was populated
    raise ValueError("No candidate schemas available for evaluation. This indicates an issue with schema loading or conversion in Cell 1.")

In [32]:
import json

# --- Helper Functions (modified to output schema with PK and FK relationships) ---

def load_schemas(tables_json_path):
    """Loads schemas from tables.json into a dictionary keyed by db_id."""
    with open(tables_json_path, 'r') as f:
        schemas_list = json.load(f)
    schemas_dict = {db_info['db_id']: db_info for db_info in schemas_list}
    return schemas_dict

def map_spider_type_to_sql_type(spider_type):
    """Maps Spider's generic types to a concise SQL‐style type string."""
    spider_type = spider_type.lower()
    if spider_type == "text":
        return "TEXT"
    elif spider_type == "number":
        return "REAL"
    elif spider_type == "time":
        return "DATETIME"
    elif spider_type == "boolean":
        return "BOOLEAN"
    elif spider_type == "others":
        return "BLOB"
    else:
        return "TEXT"

def escape_sql_identifier(name):
    """Escapes identifiers if they contain spaces or are SQL keywords."""
    if " " in name or name.lower() in {
        "select", "from", "where", "table", "primary", "key",
        "foreign", "index", "order", "group"
    }:
        return f'"{name}"'
    return name

def generate_create_table_sql_for_db(db_id, all_schemas_data):
    """
    Instead of producing CREATE TABLE statements, this returns a schema description string:
    - For each table: "table_name: col1(type)[PK], col2(type)[FK->refTable.refCol], col3(type), ..."
    It uses 'primary_keys' and 'foreign_keys' arrays from the Spider JSON.
    """
    if db_id not in all_schemas_data:
        return f"-- Database ID '{db_id}' not found in schemas."
    
    db_schema = all_schemas_data[db_id]
    table_names = db_schema['table_names_original']
    column_names = db_schema['column_names_original']
    column_types = db_schema['column_types']
    pk_indices = set(db_schema.get('primary_keys', []))
    fk_pairs = db_schema.get('foreign_keys', [])

    # Build a lookup: column_index -> (table_idx, column_name, column_type)
    col_info = {}
    for idx, (tbl_idx, col_name) in enumerate(column_names):
        if col_name == "*":
            continue
        col_info[idx] = {
            "table_index": tbl_idx,
            "column_name": col_name,
            "column_type": column_types[idx]
        }

    # Initialize a structure: table_idx -> list of column descriptors
    tables_columns = {i: [] for i in range(len(table_names))}

    # For each column, determine if it's PK or FK or normal
    for col_idx, info in col_info.items():
        tbl_idx = info['table_index']
        col_name = escape_sql_identifier(info['column_name'])
        col_type = map_spider_type_to_sql_type(info['column_type'])

        is_pk = (col_idx in pk_indices)
        # Check if this column is a foreign key, and if so, note the referenced table/column
        fk_description = ""
        for (fk_col_idx, ref_col_idx) in fk_pairs:
            if fk_col_idx == col_idx:
                ref_info = col_info[ref_col_idx]
                ref_table = escape_sql_identifier(ref_info['original_table_name'] if 'original_table_name' in ref_info else table_names[ref_info['table_index']])
                ref_col_name = escape_sql_identifier(ref_info['column_name'])
                fk_description = f"[FK→{ref_table}.{ref_col_name}]"
                break

        pk_tag = "[PK]" if is_pk else ""
        tables_columns[tbl_idx].append(f"{col_name}({col_type}){pk_tag}{fk_description}")

    # Build the final schema string, one segment per table
    segments = []
    for tbl_idx, tbl_name in enumerate(table_names):
        escaped_tbl = escape_sql_identifier(tbl_name)
        col_list = tables_columns.get(tbl_idx, [])
        cols_joined = ", ".join(col_list)
        segments.append(f"{escaped_tbl}: {cols_joined}")

    # Join table segments with semicolons
    return "; ".join(segments)

# --- Main Execution (Cell 1) producing the dictionary ---
all_db_schemas_sql_strings = {}  # Dictionary to hold schema descriptions

# Path to tables.json
spider_tables_json_path = '/raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data/tables.json'

print("--- Cell 1: Preparing Database Schemas with PK/FK Annotations ---")
try:
    # 1. Load all schema structures
    all_db_schemas_data_loaded = load_schemas(spider_tables_json_path)
    print(f"Loaded schema data for {len(all_db_schemas_data_loaded)} databases from '{spider_tables_json_path}'.")

    # 2. Iterate and generate schema descriptions
    for db_id in all_db_schemas_data_loaded:
        schema_desc = generate_create_table_sql_for_db(db_id, all_db_schemas_data_loaded)
        if schema_desc and not schema_desc.startswith("-- Database ID"):
            all_db_schemas_sql_strings[db_id] = schema_desc
        else:
            print(f"Warning: Could not generate schema for {db_id}")

    print(f"Populated `all_db_schemas_sql_strings` with {len(all_db_schemas_sql_strings)} entries.")

except FileNotFoundError:
    print(f"FATAL ERROR: File '{spider_tables_json_path}' not found.")
    all_db_schemas_sql_strings = {}
except json.JSONDecodeError:
    print(f"FATAL ERROR: Could not decode JSON from '{spider_tables_json_path}'.")
    all_db_schemas_sql_strings = {}
except Exception as e:
    print(f"FATAL ERROR during schema preparation: {e}")
    all_db_schemas_sql_strings = {}

# --- Verification ---
print(f"\n--- Verification of all_db_schemas_sql_strings ---")
print(f"Type: {type(all_db_schemas_sql_strings)}")
print(f"Number of schemas processed: {len(all_db_schemas_sql_strings)}")
if all_db_schemas_sql_strings:
    sample_db_id = list(all_db_schemas_sql_strings.keys())[125]
    print(f"Sample - DB ID: {sample_db_id}")
    print(f"Sample - Schema Description:\n{all_db_schemas_sql_strings[sample_db_id]}")
else:
    print("`all_db_schemas_sql_strings` is empty. Check for errors above.")


--- Cell 1: Preparing Database Schemas with PK/FK Annotations ---
Loaded schema data for 166 databases from '/raid/infolab/gaurav/Llama_Spider_A100_Project/spider_subset_data/tables.json'.
Populated `all_db_schemas_sql_strings` with 166 entries.

--- Verification of all_db_schemas_sql_strings ---
Type: <class 'dict'>
Number of schemas processed: 166
Sample - DB ID: yelp
Sample - Schema Description:
business: bid(REAL)[PK], business_id(TEXT), name(TEXT), full_address(TEXT), city(TEXT), latitude(TEXT), longitude(TEXT), review_count(REAL), is_open(REAL), rating(REAL), state(TEXT); category: id(REAL)[PK], business_id(TEXT)[FK→business.business_id], category_name(TEXT); user: uid(REAL)[PK], user_id(TEXT), name(TEXT); checkin: cid(REAL)[PK], business_id(TEXT)[FK→business.business_id], count(REAL), day(TEXT); neighbourhood: id(REAL)[PK], business_id(TEXT)[FK→business.business_id], neighbourhood_name(TEXT); review: rid(REAL)[PK], business_id(TEXT)[FK→business.business_id], user_id(TEXT)[FK→use

In [93]:
# This cell defines parameters for running the experiment.
# It will now randomly select queries and always use ALL database schemas as candidates.

import random # Ensure random is imported at the top of your notebook or this cell
# import os # Ensure os is imported (likely already done for path joining)
# import json # Ensure json is imported (likely already done for loading)

# --- 2.1. Experiment Parameters ---
# Number of NL queries to RANDOMLY select from dev.json to process.
# For initial testing in Colab, use a small subset. For a more thorough run, increase this.
NUM_RANDOM_QUERIES_TO_TEST = 100 # For example, test 5 random queries

# This will now effectively always be True based on your requirement.
# The logic will be set up to use all schemas from all_db_schemas_sql_strings.
# We can keep the variable for clarity or remove it if it's always all DBs.
# For this implementation, let's explicitly aim for all DBs.
print("INFO: This experiment configuration will test each randomly selected query against ALL available Spider database schemas.")


# --- 2.2. Randomly Select NL Queries for the Experiment ---
# We will randomly sample NUM_RANDOM_QUERIES_TO_TEST queries from the loaded dev_data.
if not dev_data: # dev_data should have been loaded in Cell 1
    raise ValueError("dev_data is not loaded (from dev.json). Cannot select queries. Please run Cell 1 first.")

if len(dev_data) == 0:
    raise ValueError("dev_data is empty. No queries to select.")

actual_num_queries_to_select = min(NUM_RANDOM_QUERIES_TO_TEST, len(dev_data))
# Using min ensures we don't try to sample more queries than available.

if actual_num_queries_to_select < NUM_RANDOM_QUERIES_TO_TEST:
    print(f"Warning: Requested {NUM_RANDOM_QUERIES_TO_TEST} random queries, but only {len(dev_data)} are available. Using all {len(dev_data)} queries.")

# Randomly sample without replacementselected_nl_queries = random.sample(dev_data, actual_num_queries_to_select)

import re
import os

# Path to your text file containing lines like:
#   Test Query 1: 'What are the names and release years for all the songs of the youngest singer?' (True DB: concert_singer)
TEXT_QUERIES_FILE = "/raid/infolab/gaurav/Llama_Spider_A100_Project/100_queries.txt"

if not os.path.exists(TEXT_QUERIES_FILE):
    raise FileNotFoundError(f"Cannot find '{TEXT_QUERIES_FILE}' – make sure it’s in your working directory or update the path.")

selected_nl_queries = []
pattern = re.compile(r"Test Query\s+\d+:\s+'(.+)'\s+\(True DB:\s*([^)]+)\)")

with open(TEXT_QUERIES_FILE, "r") as f_in:
    for line in f_in:
        line = line.strip()
        # Skip any header or non‐“Test Query” lines
        if not line.startswith("Test Query"):
            continue

        m = pattern.match(line)
        if not m:
            print(f"Warning: could not parse line:\n  {line}")
            continue

        question_text = m.group(1)
        true_db_id    = m.group(2)

        # Build the same dict‐structure downstream code expects
        selected_nl_queries.append({
            "question": question_text,
            "db_id":    true_db_id
        })

if len(selected_nl_queries) == 0:
    raise ValueError(f"No queries were parsed from '{TEXT_QUERIES_FILE}'. Check your file’s format.")

print(f"Loaded {len(selected_nl_queries)} queries from '{TEXT_QUERIES_FILE}':")
for i, q in enumerate(selected_nl_queries, 1):
    print(f"  Query {i}: '{q['question']}' (True DB: {q['db_id']})")

# --- 2.3. Determine Candidate Database Schemas for Each Query ---
# For this experiment design, we ALWAYS use ALL available database schemas.
# all_db_schemas_sql_strings should have been populated in Cell 1.
# if not all_sql_output: # Populated in Cell 1
#     raise ValueError("all_sql_output is empty. Schemas were not converted in Cell 1. Cannot proceed.")

candidate_schemas_for_evaluation = all_db_schemas_sql_strings # Use all converted schemas
print(f"\nEach of the {len(selected_nl_queries)} selected queries will be evaluated against all {len(candidate_schemas_for_evaluation)} available Spider database schemas.")

if not candidate_schemas_for_evaluation: # Should not happen if all_db_schemas_sql_strings was populated
    raise ValueError("No candidate schemas available for evaluation. This indicates an issue with schema loading or conversion in Cell 1.")

INFO: This experiment configuration will test each randomly selected query against ALL available Spider database schemas.
Loaded 20 queries from '/raid/infolab/gaurav/Llama_Spider_A100_Project/100_queries.txt':
  Query 1: 'What are the names and release years for all the songs of the youngest singer?' (True DB: concert_singer)
  Query 2: 'What are names of countries with the top 3 largest population?' (True DB: world_1)
  Query 3: 'What are the names and birth dates of people, ordered by their names in alphabetical order?' (True DB: poker_player)
  Query 4: 'How many different store locations are there?' (True DB: employee_hire_evaluation)
  Query 5: 'How many different nationalities do conductors have?' (True DB: orchestra)
  Query 6: 'How many states are there?' (True DB: voter_1)
  Query 7: 'What are the codes of template types that have fewer than 3 templates?' (True DB: cre_Doc_Template_Mgt)
  Query 8: 'How many dogs have not gone through any treatment?' (True DB: dog_kennels)
  Q

In [94]:
import os
import json 
LOCAL_EXPERIMENT_BASE_DIR = "/raid/infolab/gaurav/Llama_Spider_A100_Project/"


EXPERIMENT_RUN_NAME = "randomQ_allDBs_run1" 
EXPERIMENT_PROJECT_DIR = os.path.join(LOCAL_EXPERIMENT_BASE_DIR, EXPERIMENT_RUN_NAME)

try:
    os.makedirs(EXPERIMENT_PROJECT_DIR, exist_ok=True)
    print(f"Ensured experiment project directory exists: '{EXPERIMENT_PROJECT_DIR}'")
except OSError as e:
    print(f"Error creating directory {EXPERIMENT_PROJECT_DIR}: {e}")
    EXPERIMENT_PROJECT_DIR = "." 


RESULTS_FILENAME = "spider_random_query_all_db_scores_prompt_llama-2.json"
EXPERIMENT_RESULTS_FILE = os.path.join(EXPERIMENT_PROJECT_DIR, RESULTS_FILENAME)

print(f"Experiment results will be saved to: {EXPERIMENT_RESULTS_FILE}")

Ensured experiment project directory exists: '/raid/infolab/gaurav/Llama_Spider_A100_Project/randomQ_allDBs_run1'
Experiment results will be saved to: /raid/infolab/gaurav/Llama_Spider_A100_Project/randomQ_allDBs_run1/spider_random_query_all_db_scores_prompt_llama-2.json


In [105]:
# Cell defining get_yes_no_token_ids (CORRECTED)
def get_yes_no_token_ids(tokenizer_arg):
    """
    Determines the token IDs for 'Yes' and 'No', accounting for potential leading spaces.
    Llama-2-chat tends to produce " Yes" or " No" as single tokens after the prompt.
    """
    # Try with leading space first, as it's common for chat models
    yes_token_id_with_space = tokenizer_arg.encode(" Yes", add_special_tokens=False)
    no_token_id_with_space = tokenizer_arg.encode(" No", add_special_tokens=False)

    if len(yes_token_id_with_space) == 1 and len(no_token_id_with_space) == 1:
        print("Using ' Yes' and ' No' (with leading space) for Yes/No token IDs.")
        return yes_token_id_with_space[0], no_token_id_with_space[0] # Explicit return
    else:
        # Fallback to "Yes" and "No" without leading space
        yes_token_id_no_space = tokenizer_arg.encode("Yes", add_special_tokens=False)
        no_token_id_no_space = tokenizer_arg.encode("No", add_special_tokens=False)
        if len(yes_token_id_no_space) == 1 and len(no_token_id_no_space) == 1:
            print("Warning: Using 'Yes' and 'No' (no leading space) for Yes/No token IDs. This might be suboptimal for chat models.")
            return yes_token_id_no_space[0], no_token_id_no_space[0] # Explicit return
        else:
            # This case is problematic.
            print(f"ERROR: Could not determine reliable single token IDs for 'Yes'/'No' or ' Yes'/' No'.")
            print(f"Tokenization of ' Yes': {yes_token_id_with_space} (decoded: {[tokenizer_arg.decode(t) for t in yes_token_id_with_space]})")
            print(f"Tokenization of ' No': {no_token_id_with_space} (decoded: {[tokenizer_arg.decode(t) for t in no_token_id_with_space]})")
            print(f"Tokenization of 'Yes': {yes_token_id_no_space} (decoded: {[tokenizer_arg.decode(t) for t in yes_token_id_no_space]})")
            print(f"Tokenization of 'No': {no_token_id_no_space} (decoded: {[tokenizer_arg.decode(t) for t in no_token_id_no_space]})")
            # It's better to raise an error here so the problem is immediately obvious
            # rather than returning None and causing a TypeError later.
            raise ValueError("Unstable tokenization for 'Yes'/'No'. Review tokenization outputs above. Cannot proceed without reliable Yes/No token IDs.")

print("Helper function 'get_yes_no_token_ids' defined (with actual logic).")

Helper function 'get_yes_no_token_ids' defined (with actual logic).


In [106]:
if 'tokenizer' in globals() and tokenizer is not None:
    try:
        YES_TOKEN_ID, NO_TOKEN_ID = get_yes_no_token_ids(tokenizer)
        print(f"YES_TOKEN_ID: {YES_TOKEN_ID} ('{tokenizer.decode([YES_TOKEN_ID])}')")
        print(f"NO_TOKEN_ID: {NO_TOKEN_ID} ('{tokenizer.decode([NO_TOKEN_ID])}')")
    except ValueError as e:
        print(f"Error defining YES/NO token IDs: {e}")
else:
    print("ERROR: 'tokenizer' is not defined. Cannot define YES_TOKEN_ID and NO_TOKEN_ID.")

YES_TOKEN_ID: 3869 ('Yes')
NO_TOKEN_ID: 1939 ('No')


In [107]:
def get_yes_probability(model_arg, tokenizer_arg, system_prompt_arg, user_prompt_content_arg, yes_token_id_arg, no_token_id_arg, max_length=model.config.max_position_embeddings):
    """
    Gets the probability of the model answering "Yes" to the given query and schema.
    """
    messages = [
        {"role": "system", "content": system_prompt_arg},
        {"role": "user", "content": user_prompt_content_arg}
    ]

    # Apply the chat template and capture it in a variable
    prompt_for_model = tokenizer_arg.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # <-- Insert print here to inspect the fully-formatted prompt
    # print("=== Prompt After Chat Template ===")
    # print(prompt_for_model)
    # print("=== End of Prompt ===\n")

    inputs = tokenizer_arg(
        prompt_for_model,
        return_tensors="pt",
        truncation=True,
        max_length=max_length - 10
    )
    inputs = {k: v.to(model_arg.device) for k, v in inputs.items()}

    if inputs['input_ids'].shape[1] >= max_length - 10:
        print(f"Warning: Prompt for query was truncated. Length: {inputs['input_ids'].shape[1]}")

    with torch.no_grad():
        outputs = model_arg(**inputs)
        logits = outputs.logits
        next_token_logits = logits[:, -1, :]
        logit_yes = next_token_logits[:, yes_token_id_arg].item()
        logit_no = next_token_logits[:, no_token_id_arg].item()

    max_logit = max(logit_yes, logit_no)
    exp_yes = torch.exp(torch.tensor(logit_yes - max_logit, device=model_arg.device))
    exp_no = torch.exp(torch.tensor(logit_no - max_logit, device=model_arg.device))

    prob_yes = exp_yes / (exp_yes + exp_no)
    return prob_yes.item()


print("Core function 'get_yes_probability' defined.")  # Add a print statement to confirm execution


Core function 'get_yes_probability' defined.


In [108]:
# --- Condensed Prompt Configuration (fits in a 2K‐token window) ---

SYSTEM_PROMPT = """
You are an expert analyst. Decide if a natural‐language question can be answered *only* from the given schema. 
If all required tables, columns, and join‐paths exist, respond with exactly “Yes”. Otherwise, respond with exactly “No”.
"""

USER_PROMPT_TEMPLATE = """
# Few‐Shot Examples (Spider style)

[Schema: student(student_id, student_name); course(course_id, course_name); enrollment(student_id→student, course_id→course)]
Q: List the names of all students enrolled in the 'Math' course.
Reasoning: enrollment links student↔course; course_name exists; student_name exists → SQL possible.
A: Yes

[Schema: orders(order_id, customer_id, amount); customer(customer_id, customer_name)]
Q: Find the total number of orders placed in 2019.
Reasoning: no order_date or year column → cannot filter by 2019.
A: No

# Now Evaluate

[Schema: {schema_string}]
Q: {nl_query}
A:
"""

In [99]:
print("Testing get_yes_probability directly...")
try:
    # Construct a very simple schema and query for testing
    test_schema = "CREATE TABLE TestTable (id INT, name TEXT, salary INT);"
    test_nl_query = "random prompt score"
    sample_user_prompt_content = USER_PROMPT_TEMPLATE.format(
        schema_string=test_schema,
        nl_query=test_nl_query
    )
    # print(f"Test User Prompt: {sample_user_prompt_content}")

    # Make sure all these variables are defined and loaded:
    # model, tokenizer, SYSTEM_PROMPT, YES_TOKEN_ID, NO_TOKEN_ID
    prob = get_yes_probability(
        model,
        tokenizer,
        SYSTEM_PROMPT,
        sample_user_prompt_content,
        YES_TOKEN_ID,
        NO_TOKEN_ID
    )
    print(f"get_yes_probability returned: {prob}")
except Exception as e:
    import traceback
    print("Error during direct call to get_yes_probability:")
    traceback.print_exc()

Testing get_yes_probability directly...
get_yes_probability returned: 0.9724147915840149


In [121]:
import torch
# --- 1. Create a Sample Input (Schema + Query) ---
# You can use any example; this is just for demonstration
test_schema_string = """
CREATE TABLE department (did INTEGER, dname TEXT, budget REAL, num_employees INTEGER);
CREATE TABLE employee (eid INTEGER, ename TEXT, age INTEGER, department_did INTEGER, FOREIGN KEY(department_did) REFERENCES department(did));
CREATE TABLE project (pid INTEGER, pname TEXT, lead_eid INTEGER, FOREIGN KEY(lead_eid) REFERENCES employee(eid));
CREATE TABLE works_on (employee_eid INTEGER, project_pid INTEGER, hours INTEGER, FOREIGN KEY(employee_eid) REFERENCES employee(eid), FOREIGN KEY(project_pid) REFERENCES project(pid));
"""
test_nl_query = "Name the latest Iphone launched by APPLE"

# Use the USER_PROMPT_TEMPLATE from cell 3dfe1828-...
# This template includes few-shot examples and placeholders for the current schema/query.
sample_user_prompt_content = USER_PROMPT_TEMPLATE.format(
    schema_string=test_schema_string,
    nl_query=test_nl_query
)

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": sample_user_prompt_content}
]

# --- 2. Prepare Inputs for the Model ---

_messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": USER_PROMPT_TEMPLATE.format(schema_string=test_schema_string, nl_query=test_nl_query)}
]


tokenized_input_string = tokenizer.apply_chat_template(
    _messages,
    tokenize=False,
    add_generation_prompt=True # This is important!
)

inputs = tokenizer(tokenized_input_string, return_tensors="pt", truncation=True, max_length=2048-10).to(model.device)


# --- 3. Generate with SAMPLING (as in your original cell, but fixed) ---
print("\n--- Generating with SAMPLING (do_sample=True) ---")
max_new_tokens_sample = 10 # We expect "Yes" or "No", so few tokens are needed.
with torch.no_grad():
    outputs_sample = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens_sample,
        do_sample=True,
        num_beams=1,  # Explicitly state you are not using beam search
        temperature=0.1,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    )
    

# Decode and clean the output
input_length = inputs.input_ids.shape[1]
generated_tokens_sample = outputs_sample[0][input_length:]
response_text_sample = tokenizer.decode(generated_tokens_sample, skip_special_tokens=True).strip()
print(f"Sampled Model Response: '{response_text_sample}'")


# --- 4. Generate with GREEDY DECODING ---
print("\n--- Generating with GREEDY DECODING (do_sample=False) ---")
max_new_tokens_greedy = 100 # Expecting "Yes" or "No"
with torch.no_grad():
    outputs_greedy = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens_greedy,
        do_sample=False,  # Key for greedy decoding
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    )

# Decode and clean the output
generated_tokens_greedy = outputs_greedy[0][input_length:] # input_length is the same
response_text_greedy = tokenizer.decode(generated_tokens_greedy, skip_special_tokens=True).strip()
print(f"Greedy Model Response: '{response_text_greedy}'")

# --- 5. (Optional) Compare with get_yes_probability ---
print("\n--- For reference: get_yes_probability output ---")
# Use the 'user' part of the _messages, which is the fully formatted user prompt
user_prompt_content_for_prob = _messages[1]['content']
try:
    prob_yes = get_yes_probability(
        model,
        tokenizer,
        SYSTEM_PROMPT, # from cell 3dfe1828-...
        user_prompt_content_for_prob, # This is already formatted with schema, query, and "A:"
        YES_TOKEN_ID, # from cell abb8b6cb-...
        NO_TOKEN_ID   # from cell abb8b6cb-...
    )
    print(f"P(Yes) from get_yes_probability: {prob_yes:.4f}")
    if prob_yes > 0.5:
        print("Based on get_yes_probability, 'Yes' is more likely.")
    else:
        print("Based on get_yes_probability, 'No' is more likely.")
except Exception as e:
    print(f"Error running get_yes_probability for comparison: {e}")
    import traceback
    traceback.print_exc()


--- Generating with SAMPLING (do_sample=True) ---


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Sampled Model Response: 'No, the given schema does not contain any'

--- Generating with GREEDY DECODING (do_sample=False) ---
Greedy Model Response: 'No, the given schema does not contain any information about iPhones or their launch dates. The schema is focused on employee and project data, with no mention of products or launch dates. Therefore, it is not possible to answer the question "Name the latest iPhone launched by APPLE" based solely on the given schema.

Answer: No'

--- For reference: get_yes_probability output ---
P(Yes) from get_yes_probability: 0.9846
Based on get_yes_probability, 'Yes' is more likely.


In [88]:
# --- Ensure these imports are at the top of your script/notebook ---
import json
import os
import traceback
from tqdm.auto import tqdm # Use .auto or .notebook for Jupyter

# --- Prerequisites (must be defined and populated from Cell 1 and Cell 2): ---
# model, tokenizer, SYSTEM_PROMPT, USER_PROMPT_TEMPLATE,
# YES_TOKEN_ID, NO_TOKEN_ID, get_yes_probability,
# selected_nl_queries, candidate_schemas_for_evaluation, EXPERIMENT_RESULTS_FILE
# --- (Assume these are correctly defined above this cell) ---

# --- 3.1. Initialize Results Storage ---
experiment_all_query_results = []

# --- 3.2. Start the Loop ---
# This initial print is fine as it's before any tqdm loops start for this cell's main logic
print(f"\n--- Starting Experiment: {len(selected_nl_queries)} Random Queries vs. {len(candidate_schemas_for_evaluation)} Total DB Schemas ---")

# Outer loop: Iterate through each randomly selected NL query
for query_idx, nl_query_info in enumerate(tqdm(selected_nl_queries, desc="Processing NL Queries")):
    current_nl_query_text = nl_query_info['question']
    true_db_id_for_query = nl_query_info['db_id']
    experiment_query_id = f"spider_dev_q{query_idx}_{nl_query_info.get('query_id', 'idx'+str(query_idx))}"

    # Use tqdm.write for status updates related to the outer loop's progress
    # The '\n' at the beginning helps separate entries for each query visually.
    tqdm.write(f"\nProcessing Query {query_idx + 1}/{len(selected_nl_queries)} (ID: {experiment_query_id}): '{current_nl_query_text}' (True DB: {true_db_id_for_query})")

    scores_for_current_query = []

    # --- Optional: For debugging, print scores for the VERY FIRST query only ---
    # print_debug_scores_for_first_query_only = True
    # if print_debug_scores_for_first_query_only and query_idx == 0:
    #     tqdm.write(f"  --- Incremental Scores for First Query: '{current_nl_query_text}' ---")
    # --- End Optional Debug Print Setup ---

    # Inner loop: Iterate through each candidate database schema
    for candidate_db_id, candidate_schema_sql in tqdm(
        candidate_schemas_for_evaluation.items(),
        desc=f"  DBs for Q:{experiment_query_id[:20]}", # Description for the inner bar
        leave=False  # Inner bar will be removed upon completion of its loop
    ):
        user_prompt_content = USER_PROMPT_TEMPLATE.format(
            schema_string=candidate_schema_sql,
            nl_query=current_nl_query_text
        )
        p_yes_score = -1.0

        try:
            p_yes_score = get_yes_probability(
                model, tokenizer, SYSTEM_PROMPT, user_prompt_content, YES_TOKEN_ID, NO_TOKEN_ID
            )
        except Exception as e:
            # Use tqdm.write for error messages occurring inside the inner loop
            tqdm.write(f"    ERROR: Exception in get_yes_probability for Query ID '{experiment_query_id}' with DB '{candidate_db_id}'.")
            tqdm.write(f"    Exception type: {type(e).__name__}, Message: {e}")
            # if you need full traceback for debugging, tqdm.write(traceback.format_exc()) might work,
            # but it can be very verbose. Printing to a log file is better for extensive tracebacks.
            # traceback.print_exc() # This will print to stderr and might still mess with tqdm display

        scores_for_current_query.append({
            'candidate_db_id': candidate_db_id,
            'p_yes_score': p_yes_score
        })

        # --- Optional: For debugging, print scores for the VERY FIRST query only ---
        # if print_debug_scores_for_first_query_only and query_idx == 0:
        #     tqdm.write(f"    DB: {candidate_db_id}, Score: {p_yes_score:.4f}") # Incremental print with tqdm.write
        # --- End Optional Debug Print ---

    ranked_databases_for_query = sorted(scores_for_current_query, key=lambda x: x['p_yes_score'], reverse=True)

    # --- Optional: For debugging, print sorted scores for the VERY FIRST query only ---
    # if print_debug_scores_for_first_query_only and query_idx == 0:
    #     tqdm.write(f"  --- Sorted Ranked Databases for First Query: '{current_nl_query_text}' (Top 10) ---")
    #     for rank_info in ranked_databases_for_query[:10]:
    #         tqdm.write(f"    Ranked DB: {rank_info['candidate_db_id']}, Score: {rank_info['p_yes_score']:.4f}")
    # --- End Optional Debug Print ---

    experiment_all_query_results.append({
        'experiment_query_id': experiment_query_id,
        'nl_query_text': current_nl_query_text,
        'true_db_id': true_db_id_for_query,
        'ranked_databases_with_scores': ranked_databases_for_query
    })

    # --- 3.3. Periodic Saving of Results ---
    if (query_idx + 1) % 1 == 0 or (query_idx + 1) == len(selected_nl_queries):
        try:
            with open(EXPERIMENT_RESULTS_FILE, 'w') as f_out:
                json.dump(experiment_all_query_results, f_out, indent=2)
            # Use tqdm.write for save messages that occur between outer loop iterations
            tqdm.write(f"  Successfully saved intermediate results for {len(experiment_all_query_results)} queries to {EXPERIMENT_RESULTS_FILE}")
        except Exception as e:
            tqdm.write(f"  ERROR: Could not save intermediate results: {e}")

# --- 3.4. Experiment Loop Completion ---
# These final prints are after all tqdm loops are done, so standard print is fine.
print("\n--- Experiment Loop Finished ---")
if experiment_all_query_results:
    print(f"Processed {len(experiment_all_query_results)} queries in total.")
    try:
        with open(EXPERIMENT_RESULTS_FILE, 'w') as f_out:
            json.dump(experiment_all_query_results, f_out, indent=2)
        print(f"Final results comprehensively saved to {EXPERIMENT_RESULTS_FILE}")
    except Exception as e:
        print(f"ERROR: Could not save final results: {e}")
else:
    print("No results were generated from the experiment. Check logs for errors.")


--- Starting Experiment: 20 Random Queries vs. 166 Total DB Schemas ---


Processing NL Queries:   0%|          | 0/20 [00:00<?, ?it/s]


Processing Query 1/20 (ID: spider_dev_q0_idx0): 'What are the names and release years for all the songs of the youngest singer?' (True DB: concert_singer)


  DBs for Q:spider_dev_q0_idx0:   0%|          | 0/166 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [85]:
import os
import json

# Path where the evaluation summary (Recall@K results) will be saved
EVAL_RESULTS_SAVE_PATH = "recall_k_results_context_4096_prompt_changed.json"

# --- 4.1. Define Recall@K Calculation Function ---
def calculate_recall_at_k_metric(all_query_results_list, k_values_list):
    """
    Calculates Recall@K for a list of K values.
    Each item in all_query_results_list should be a dictionary with:
        'true_db_id': The ground truth database ID for the query.
        'ranked_databases_with_scores': A list of {'candidate_db_id': id, 'p_yes_score': score},
                                         sorted by score in descending order.
    """
    recall_counts = {k: 0 for k in k_values_list}  # Stores how many times true_db was in top K
    total_valid_queries = 0  # Queries for which we have a true_db_id

    if not all_query_results_list:
        return {k: 0.0 for k in k_values_list}, 0

    for query_result in all_query_results_list:
        true_db = query_result.get('true_db_id')
        ranked_dbs_info = query_result.get('ranked_databases_with_scores')

        if true_db is None or ranked_dbs_info is None:
            print(f"Warning: Skipping query result due to missing 'true_db_id' or 'ranked_databases_with_scores': "
                  f"{query_result.get('experiment_query_id', 'Unknown Query')}")
            continue  # Skip if essential information is missing

        total_valid_queries += 1
        # Extract just the DB IDs from the ranked list
        ranked_db_ids_only = [item['candidate_db_id'] for item in ranked_dbs_info]

        for k in k_values_list:
            # Get the top K predicted database IDs
            top_k_predicted_dbs = ranked_db_ids_only[:k]
            if true_db in top_k_predicted_dbs:
                recall_counts[k] += 1

    # Calculate final recall percentages
    recall_percentages = {}
    if total_valid_queries > 0:
        for k in k_values_list:
            recall_percentages[k] = (recall_counts[k] / total_valid_queries) * 100.0  # As percentage
    else:
        recall_percentages = {k: 0.0 for k in k_values_list}

    return recall_percentages, total_valid_queries


# --- 4.2. Perform Evaluation ---
# Load results if this cell is run in a new session and experiment_all_query_results isn't in memory
# (assuming results were saved to EXPERIMENT_RESULTS_FILE)
loaded_results_for_eval = None
if 'experiment_all_query_results' in globals() and experiment_all_query_results:
    print("Using in-memory experiment_all_query_results for evaluation.")
    loaded_results_for_eval = experiment_all_query_results
elif os.path.exists(EXPERIMENT_RESULTS_FILE):
    print(f"Loading results from {EXPERIMENT_RESULTS_FILE} for evaluation...")
    try:
        with open(EXPERIMENT_RESULTS_FILE, 'r') as f_in:
            loaded_results_for_eval = json.load(f_in)
        print(f"Successfully loaded {len(loaded_results_for_eval)} results from file.")
    except Exception as e:
        print(f"Error loading results from file for evaluation: {e}")
else:
    print("No results available in memory or in the specified results file for evaluation.")

if loaded_results_for_eval:
    K_VALUES_TO_EVALUATE = [1, 3, 5, 10]  # Define the K values you care about
    recall_scores_map, num_queries_evaluated = calculate_recall_at_k_metric(
        loaded_results_for_eval, K_VALUES_TO_EVALUATE
    )

    print("\n--- Evaluation: Recall@K ---")
    print(f"Evaluated on {num_queries_evaluated} queries.")
    for k_val, recall_val in recall_scores_map.items():
        print(f"Recall@{k_val}: {recall_val:.2f}%")

    # --- 4.2.1. Save evaluation results to a JSON file ---
    try:
        eval_summary = {
            "num_queries_evaluated": num_queries_evaluated,
            "recall_scores": recall_scores_map
        }
        with open(EVAL_RESULTS_SAVE_PATH, 'w') as fout:
            json.dump(eval_summary, fout, indent=2)
        print(f"Saved evaluation results to '{EVAL_RESULTS_SAVE_PATH}'")
    except Exception as save_err:
        print(f"Error saving evaluation results: {save_err}")

    # --- 4.3. Optional: Print Detailed Results for a Few Queries ---
    print("\n--- Sample Detailed Query Results (Top 5 Queries) ---")
    for i, res in enumerate(loaded_results_for_eval[:5]):  # Show for first 5 queries
        print(f"\nQuery {i+1}: '{res.get('nl_query_text', '<no text>')}' (True DB: {res.get('true_db_id')})")
        print("  Top Ranked Databases (with P(Yes) scores):")
        for rank, db_info in enumerate(res.get('ranked_databases_with_scores', [])[:5]):  # Show top 5 ranked DBs
            is_true_db_char = "*" if db_info['candidate_db_id'] == res['true_db_id'] else " "
            print(f"    {rank+1}. {db_info['candidate_db_id']}{is_true_db_char} "
                  f"(Score: {db_info['p_yes_score']:.4f})")
else:
    print("Cannot perform evaluation as no results were loaded or generated.")


Using in-memory experiment_all_query_results for evaluation.

--- Evaluation: Recall@K ---
Evaluated on 20 queries.
Recall@1: 0.00%
Recall@3: 0.00%
Recall@5: 0.00%
Recall@10: 0.00%
Saved evaluation results to 'recall_k_results_context_4096_prompt_changed.json'

--- Sample Detailed Query Results (Top 5 Queries) ---

Query 1: 'What are the names and release years for all the songs of the youngest singer?' (True DB: concert_singer)
  Top Ranked Databases (with P(Yes) scores):
    1. real_estate_properties  (Score: 0.9931)
    2. student_assessment  (Score: 0.9909)
    3. driving_school  (Score: 0.9909)
    4. college_1  (Score: 0.9909)
    5. academic  (Score: 0.9900)

Query 2: 'What are names of countries with the top 3 largest population?' (True DB: world_1)
  Top Ranked Databases (with P(Yes) scores):
    1. real_estate_properties  (Score: 0.9933)
    2. insurance_and_eClaims  (Score: 0.9897)
    3. student_assessment  (Score: 0.9893)
    4. solvency_ii  (Score: 0.9883)
    5. academic