In [27]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [28]:
# Install just the 'evaluate' package; don’t upgrade or reinstall any of Kaggle’s other packages
!pip install -q --no-deps evaluate


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [29]:
# =============================================================================
# CELL 2: Import Libraries and Setup
# =============================================================================

import torch
import pandas as pd
import numpy as np
import sqlparse
import re
from datasets import Dataset, load_dataset
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM, 
    TrainingArguments, Trainer, DataCollatorForSeq2Seq
)
from sklearn.model_selection import train_test_split
from evaluate import load
import warnings
warnings.filterwarnings('ignore')

In [30]:
# =============================================================================
# CELL 3: Load and Filter Dataset
# =============================================================================

# Load the Gretel synthetic text-to-SQL dataset using your working method
print("Loading dataset...")
# Single "train" split on Hugging Face
raw_ds = load_dataset("gretelai/synthetic_text_to_sql", split="train")

# Convert to pandas for easy manipulation
df_raw = raw_ds.to_pandas()
print("Raw dataset shape:", df_raw.shape)
print("Columns:", df_raw.columns.tolist())

print("\nMissing (NaN) values per column:")
print(df_raw.isna().sum())

print("\nEmpty-string values per column:")
print((df_raw == "").sum())

mask_raw = df_raw.isna().any(axis=1) | (df_raw == "").any(axis=1)
print(f"\nRows with ≥1 missing/empty field: {mask_raw.sum()}/{len(df_raw)}")

# Unique domains
print("\nUnique domains:", df_raw['domain'].unique())

# For each domain, list the SQL task types ("roles")
domain_roles = df_raw.groupby('domain')['sql_task_type'].unique()
print("\nSQL task types by domain:")
for dom, roles in domain_roles.items():
    print(f"  {dom}: {roles}")

# Display sample row to understand structure
print("\nSample row:")
print(df_raw.iloc[0])

# Filter for multiple domains: financial services, insurance, healthcare, and finance
target_domains = ['financial services', 'insurance', 'healthcare', 'finance']
df_filtered = df_raw[df_raw['domain'].isin(target_domains)].copy()

print(f"\nFiltered dataset size (multiple domains): {len(df_filtered)}")

# Display distribution of selected domains
print("\nSelected domains distribution:")
print(df_filtered['domain'].value_counts())

# Display distribution of all domains (for reference)
print("\nAll domains distribution (top 10):")
print(df_raw['domain'].value_counts().head(10))

Loading dataset...
Raw dataset shape: (100000, 11)
Columns: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation']

Missing (NaN) values per column:
id                            0
domain                        0
domain_description            0
sql_complexity                0
sql_complexity_description    0
sql_task_type                 0
sql_task_type_description     0
sql_prompt                    0
sql_context                   0
sql                           0
sql_explanation               0
dtype: int64

Empty-string values per column:
id                            0
domain                        0
domain_description            0
sql_complexity                0
sql_complexity_description    0
sql_task_type                 0
sql_task_type_description     0
sql_prompt                    0
sql_context                   0
sql                           0


In [31]:
# =============================================================================
# CELL 4: Data Preprocessing and Format Preparation (BETTER PROMPTING)
# =============================================================================

def prepare_input_output(row):
    """
    Prepare input and output format for the models
    Input: schema + natural language query (with better prompting)
    Output: SQL + explanation
    """
    # Create input text with better prompting
    input_text = f"""Convert this natural language question to SQL using the given database schema.

Database Schema:
{row['sql_context']}

Question: {row['sql_prompt']}

Please generate a SQL query with explanation:"""
    
    # Create output text (SQL + explanation) - unchanged
    sql_query = row['sql']
    explanation = row.get('sql_explanation', 'No explanation provided')
    output_text = f"SQL: {sql_query}\nExplanation: {explanation}"
    
    return input_text, output_text

# Prepare the data
print("Preparing input-output pairs...")
inputs = []
outputs = []

for idx, row in df_filtered.iterrows():
    input_text, output_text = prepare_input_output(row)
    inputs.append(input_text)
    outputs.append(output_text)

# Create DataFrame with processed data
processed_df = pd.DataFrame({
    'input': inputs,
    'output': outputs
})

print(f"Processed {len(processed_df)} examples")
print("\nSample input:")
print(processed_df['input'].iloc[0][:400] + "...")
print("\nSample output:")
print(processed_df['output'].iloc[0][:300] + "...")

Preparing input-output pairs...
Processed 4318 examples

Sample input:
Convert this natural language question to SQL using the given database schema.

Database Schema:
CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);

Question: What is the total trade value and average price for each trader and stock in the trade_history table?

Please generate a SQL query with explanation:...

Sample output:
SQL: SELECT trader_id, stock, SUM(price * quantity) as total_trade_value, AVG(price) as avg_price FROM trade_history GROUP BY trader_id, stock;
Explanation: This query calculates the total trade value and average price for each trader and stock by grouping the trade_history table by the trader_id an...


In [32]:
# =============================================================================
# CELL 5: Train-Test Split
# =============================================================================

# Split the data
train_df, test_df = train_test_split(processed_df, test_size=0.2, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=0.125, random_state=42)  # 0.125 * 0.8 = 0.1 of total

print(f"Training set: {len(train_df)} examples")
print(f"Validation set: {len(val_df)} examples") 
print(f"Test set: {len(test_df)} examples")

# Convert to HuggingFace datasets
train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
test_dataset = Dataset.from_pandas(test_df)

Training set: 3022 examples
Validation set: 432 examples
Test set: 864 examples


In [33]:
# =============================================================================
# CELL 6: Model and Tokenizer Setup
# =============================================================================

# Model configurations
models_config = {
    'codet5-small': 'Salesforce/codet5-small',
    # 't5-small': 't5-small'
    'flan-t5-small': 'google/flan-t5-small' 
}

def setup_model_and_tokenizer(model_name):
    """Setup model and tokenizer"""
    print(f"Loading {model_name}...")
    
    tokenizer = AutoTokenizer.from_pretrained(models_config[model_name])
    model = AutoModelForSeq2SeqLM.from_pretrained(models_config[model_name])
    
    # Add special tokens if needed
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    model.to(device)
    return model, tokenizer

In [34]:
# =============================================================================
# CELL 7: Data Tokenization Function
# =============================================================================

def tokenize_data(examples, tokenizer, max_input_length=512, max_output_length=256):
    """Tokenize the input-output pairs"""
    
    # Tokenize inputs
    model_inputs = tokenizer(
        examples['input'],
        max_length=max_input_length,
        truncation=True,
        padding=True,
        return_tensors="pt"
    )
    
    # Tokenize outputs
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples['output'],
            max_length=max_output_length,
            truncation=True,
            padding=True,
            return_tensors="pt"
        )
    
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [35]:
# =============================================================================
# CELL 8: Evaluation Functions
# =============================================================================

def extract_sql_components(sql_text):
    """Extract structural components from SQL"""
    try:
        # Parse SQL
        parsed = sqlparse.parse(sql_text)[0]
        
        components = {
            'tables': set(),
            'columns': set(),
            'conditions': set(),
            'functions': set()
        }
        
        # Convert to string and extract components using regex
        sql_str = str(parsed).upper()
        
        # Extract table names (simple heuristic)
        from_match = re.search(r'FROM\s+([^\s\(\)]+)', sql_str)
        if from_match:
            components['tables'].add(from_match.group(1))
        
        join_matches = re.findall(r'JOIN\s+([^\s\(\)]+)', sql_str)
        components['tables'].update(join_matches)
        
        # Extract common SQL functions
        functions = re.findall(r'(COUNT|SUM|AVG|MIN|MAX|GROUP BY|ORDER BY)\s*\(', sql_str)
        components['functions'].update(functions)
        
        # Extract WHERE conditions (simplified)
        where_match = re.search(r'WHERE\s+(.+?)(?:GROUP|ORDER|$)', sql_str)
        if where_match:
            conditions = where_match.group(1).strip()
            components['conditions'].add(conditions)
        
        return components
        
    except Exception as e:
        return {'tables': set(), 'columns': set(), 'conditions': set(), 'functions': set()}

# def can_parse_sql(sql_text):
#     """Check if SQL can be parsed (syntax correctness)"""
#     try:
#         sqlparse.parse(sql_text)
#         return 1.0
#     except:
#         return 0.0

def can_parse_sql(sql_text):
    try:
        parsed = sqlparse.parse(sql_text)[0]
        # Check if it actually looks like a SQL statement
        sql_upper = str(parsed).upper().strip()
        
        # Must start with SQL keywords
        valid_starts = ('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER')
        if not any(sql_upper.startswith(keyword) for keyword in valid_starts):
            return 0.0
            
        # Must have basic SQL structure
        if 'SELECT' in sql_upper and 'FROM' not in sql_upper:
            return 0.0  # SELECT without FROM is usually invalid
            
        return 1.0
    except:
        return 0.0

def structural_similarity(ref_sql, gen_sql):
    """Calculate structural similarity between reference and generated SQL"""
    ref_components = extract_sql_components(ref_sql)
    gen_components = extract_sql_components(gen_sql)
    
    similarities = []
    
    for component_type in ['tables', 'functions', 'conditions']:
        ref_set = ref_components[component_type]
        gen_set = gen_components[component_type]
        
        if len(ref_set) == 0 and len(gen_set) == 0:
            similarities.append(1.0)
        elif len(ref_set) == 0 or len(gen_set) == 0:
            similarities.append(0.0)
        else:
            # Jaccard similarity
            intersection = len(ref_set & gen_set)
            union = len(ref_set | gen_set)
            similarities.append(intersection / union if union > 0 else 0.0)
    
    return np.mean(similarities)

def exact_match_score(ref_text, gen_text):
    """Calculate exact match after normalization"""
    # Simple normalization
    ref_normalized = re.sub(r'\s+', ' ', ref_text.strip().lower())
    gen_normalized = re.sub(r'\s+', ' ', gen_text.strip().lower())
    return 1.0 if ref_normalized == gen_normalized else 0.0

def split_sql_explanation(text):
    """Split combined SQL+explanation text"""
    # Look for SQL: and Explanation: markers
    sql_match = re.search(r'SQL:\s*(.*?)(?=\nExplanation:|$)', text, re.DOTALL)
    exp_match = re.search(r'Explanation:\s*(.*?)$', text, re.DOTALL)
    
    sql = sql_match.group(1).strip() if sql_match else text
    explanation = exp_match.group(1).strip() if exp_match else ""
    
    return sql, explanation

def evaluate_predictions(references, predictions):
    """Comprehensive evaluation of predictions"""
    
    # Load BLEU metric
    bleu_metric = load("bleu")
    
    results = {
        'syntax_scores': [],
        'structural_scores': [],
        'exact_match_scores': [],
        'explanation_bleu_scores': [],
        'combined_scores': []
    }
    
    for ref, pred in zip(references, predictions):
        # Split SQL and explanations
        ref_sql, ref_exp = split_sql_explanation(ref)
        pred_sql, pred_exp = split_sql_explanation(pred)
        
        # Calculate individual scores
        syntax_score = can_parse_sql(pred_sql)
        structural_score = structural_similarity(ref_sql, pred_sql)
        exact_score = exact_match_score(ref_sql, pred_sql)
        
        # BLEU for explanations
        if ref_exp and pred_exp:
            bleu_score = bleu_metric.compute(
                predictions=[pred_exp], 
                references=[[ref_exp]]
            )['bleu']
        else:
            bleu_score = 0.0
        
        # Combined score
        combined_score = (0.2 * syntax_score + 
                         0.4 * structural_score + 
                         0.1 * exact_score + 
                         0.3 * bleu_score)
        
        # Store results
        results['syntax_scores'].append(syntax_score)
        results['structural_scores'].append(structural_score)
        results['exact_match_scores'].append(exact_score)
        results['explanation_bleu_scores'].append(bleu_score)
        results['combined_scores'].append(combined_score)
    
    # Calculate averages
    avg_results = {key: np.mean(values) for key, values in results.items()}
    
    return avg_results, results

In [36]:
# =============================================================================
# CELL 9: Zero-Shot Evaluation Function
# =============================================================================

def zero_shot_evaluation(model, tokenizer, test_dataset, model_name):
    """Perform zero-shot evaluation"""
    print(f"\n=== Zero-Shot Evaluation for {model_name} ===")
    
    model.eval()
    predictions = []
    references = []
    
    # Generate predictions
    for i, example in enumerate(test_dataset):
        if i % 20 == 0:
            print(f"Processing example {i+1}/{len(test_dataset)}")
        
        input_text = example['input']
        reference = example['output']
        
        # Tokenize input
        inputs = tokenizer(
            input_text, 
            return_tensors="pt", 
            max_length=512, 
            truncation=True
        ).to(device)
        
        # Generate prediction
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=256,
                num_beams=2,
                early_stopping=True
            )
        
        # Decode prediction
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        predictions.append(prediction)
        references.append(reference)
    
    # Evaluate
    avg_results, detailed_results = evaluate_predictions(references, predictions)
    
    print(f"\nZero-Shot Results for {model_name}:")
    print(f"Syntax Correctness: {avg_results['syntax_scores']:.3f}")
    print(f"Structural Similarity: {avg_results['structural_scores']:.3f}")
    print(f"Exact Match: {avg_results['exact_match_scores']:.3f}")
    print(f"Explanation BLEU: {avg_results['explanation_bleu_scores']:.3f}")
    print(f"Combined Score: {avg_results['combined_scores']:.3f}")
    
    return avg_results, predictions, references

In [37]:
# =============================================================================
# CELL 10: Fine-tuning Function
# =============================================================================

def fine_tune_model(model, tokenizer, train_dataset, val_dataset, model_name):
    """Fine-tune the model"""
    print(f"\n=== Fine-tuning {model_name} ===")

      # Disable wandb
    import os
    os.environ["WANDB_DISABLED"] = "true"
    
    # Tokenize datasets
    def tokenize_function(examples):
        return tokenize_data(examples, tokenizer)
    
    tokenized_train = train_dataset.map(tokenize_function, batched=True)
    tokenized_val = val_dataset.map(tokenize_function, batched=True)
    
    # Data collator
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    
    # Training arguments
    training_args = TrainingArguments(
        output_dir=f'./results_{model_name}',
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=50,
        weight_decay=0.01,
        logging_dir=f'./logs_{model_name}',
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        save_total_limit=2,
        gradient_accumulation_steps=2,
        fp16=True if torch.cuda.is_available() else False,
    )

    # training_args = TrainingArguments(
    #     # Output and logging
    #     output_dir=f'./results_{model_name}',
    #     logging_dir=f'./logs_{model_name}',
    #     logging_steps=25,                      # Log less frequently
        
    #     # Training schedule
    #     num_train_epochs=3,
    #     per_device_train_batch_size=6,         # Increased from 4 (better GPU utilization)
    #     per_device_eval_batch_size=8,          # Higher for evaluation (no gradients)
    #     gradient_accumulation_steps=1,         # Reduced since batch size increased
        
    #     # Learning rate and scheduling  
    #     learning_rate=3e-4,                    # Slightly higher for better convergence
    #     lr_scheduler_type="cosine_with_restarts", # Better than pure cosine
    #     warmup_ratio=0.06,                     # 6% warmup (shorter for small dataset)
        
    #     # Regularization and stability
    #     weight_decay=0.01,                     # Light regularization
    #     max_grad_norm=1.0,                     # Gradient clipping
        
    #     # Evaluation and saving
    #     eval_strategy="steps",
    #     eval_steps=100,                        # Less frequent evaluation
    #     save_strategy="steps", 
    #     save_steps=200,                        # Save less frequently
    #     load_best_model_at_end=True,
    #     metric_for_best_model="eval_loss",
    #     greater_is_better=False,
    #     save_total_limit=1,                    # Keep only best model (save space)
        
    #     # Memory and performance optimization
    #     fp16=True,                             # Half precision (faster + less memory)
    #     dataloader_num_workers=2,              # Faster data loading
    #     remove_unused_columns=False,           # Sometimes helps with seq2seq
    #     prediction_loss_only=True,             # Focus on loss during evaluation
        
    #     # Kaggle-specific optimizations
    #     dataloader_pin_memory=True,            # Faster GPU transfers
    #     skip_memory_metrics=True,              # Reduce logging overhead
    #     report_to=None,                        # Disable all experiment tracking
    # )
    
    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    
    # Train
    print("Starting training...")
    trainer.train()
    
    print(f"Fine-tuning completed for {model_name}")
    return trainer

In [38]:
# =============================================================================
# CELL 11: Post Fine-tuning Evaluation Function
# =============================================================================

def post_finetune_evaluation(model, tokenizer, test_dataset, model_name):
    """Evaluate model after fine-tuning"""
    print(f"\n=== Post Fine-tuning Evaluation for {model_name} ===")
    
    return zero_shot_evaluation(model, tokenizer, test_dataset, f"{model_name}_finetuned")

In [39]:
# =============================================================================
# CELL 11.5: Device Setup
# =============================================================================

# Set up device (GPU if available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# If using GPU, print GPU info
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    print("GPU not available, using CPU")

Using device: cuda
GPU: Tesla T4
Memory: 14.7 GB


In [40]:
# =============================================================================
# CELL 12: Run Complete Experiment - CodeT5-Small
# =============================================================================

print("="*60)
print("STARTING EXPERIMENT WITH CODET5-SMALL")
print("="*60)

# Setup CodeT5-Small
codet5_model, codet5_tokenizer = setup_model_and_tokenizer('codet5-small')

# Zero-shot evaluation
codet5_zero_results, codet5_zero_preds, codet5_zero_refs = zero_shot_evaluation(
    codet5_model, codet5_tokenizer, test_dataset, 'CodeT5-Small'
)

# Fine-tune CodeT5-Small
codet5_trainer = fine_tune_model(
    codet5_model, codet5_tokenizer, train_dataset, val_dataset, 'codet5-small'
)

# Post fine-tuning evaluation
codet5_ft_results, codet5_ft_preds, codet5_ft_refs = post_finetune_evaluation(
    codet5_model, codet5_tokenizer, test_dataset, 'CodeT5-Small'
)

STARTING EXPERIMENT WITH CODET5-SMALL
Loading codet5-small...

=== Zero-Shot Evaluation for CodeT5-Small ===
Processing example 1/864
Processing example 21/864
Processing example 41/864
Processing example 61/864
Processing example 81/864
Processing example 101/864
Processing example 121/864
Processing example 141/864
Processing example 161/864
Processing example 181/864
Processing example 201/864
Processing example 221/864
Processing example 241/864
Processing example 261/864
Processing example 281/864
Processing example 301/864
Processing example 321/864
Processing example 341/864
Processing example 361/864
Processing example 381/864
Processing example 401/864
Processing example 421/864
Processing example 441/864
Processing example 461/864
Processing example 481/864
Processing example 501/864
Processing example 521/864
Processing example 541/864
Processing example 561/864
Processing example 581/864
Processing example 601/864
Processing example 621/864
Processing example 641/864
Proces

Map:   0%|          | 0/3022 [00:00<?, ? examples/s]

Map:   0%|          | 0/432 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Starting training...


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss,Validation Loss
50,0.7968,0.64037
100,0.6198,0.453571
150,0.5582,0.400075
200,0.4583,0.374543
250,0.4334,0.349339
300,0.4278,0.33748
350,0.423,0.328616
400,0.388,0.32061
450,0.3783,0.317735
500,0.3547,0.313231


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


Fine-tuning completed for codet5-small

=== Post Fine-tuning Evaluation for CodeT5-Small ===

=== Zero-Shot Evaluation for CodeT5-Small_finetuned ===
Processing example 1/864
Processing example 21/864
Processing example 41/864
Processing example 61/864
Processing example 81/864
Processing example 101/864
Processing example 121/864
Processing example 141/864
Processing example 161/864
Processing example 181/864
Processing example 201/864
Processing example 221/864
Processing example 241/864
Processing example 261/864
Processing example 281/864
Processing example 301/864
Processing example 321/864
Processing example 341/864
Processing example 361/864
Processing example 381/864
Processing example 401/864
Processing example 421/864
Processing example 441/864
Processing example 461/864
Processing example 481/864
Processing example 501/864
Processing example 521/864
Processing example 541/864
Processing example 561/864
Processing example 581/864
Processing example 601/864
Processing example 

In [41]:
# =============================================================================
# CELL 13: Run Complete Experiment - T5-Small
# =============================================================================

print("="*60)
print("STARTING EXPERIMENT WITH T5-SMALL")
print("="*60)

# Setup T5-Small
# t5_model, t5_tokenizer = setup_model_and_tokenizer('t5-small')
t5_model, t5_tokenizer = setup_model_and_tokenizer('flan-t5-small')

# Zero-shot evaluation
t5_zero_results, t5_zero_preds, t5_zero_refs = zero_shot_evaluation(
    t5_model, t5_tokenizer, test_dataset, 'T5-Small'
)

# Fine-tune T5-Small
t5_trainer = fine_tune_model(
    t5_model, t5_tokenizer, train_dataset, val_dataset, 't5-small'
)

# Post fine-tuning evaluation
t5_ft_results, t5_ft_preds, t5_ft_refs = post_finetune_evaluation(
    t5_model, t5_tokenizer, test_dataset, 'T5-Small'
)

STARTING EXPERIMENT WITH T5-SMALL
Loading flan-t5-small...


tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]


=== Zero-Shot Evaluation for T5-Small ===
Processing example 1/864
Processing example 21/864
Processing example 41/864
Processing example 61/864
Processing example 81/864
Processing example 101/864
Processing example 121/864
Processing example 141/864
Processing example 161/864
Processing example 181/864
Processing example 201/864
Processing example 221/864
Processing example 241/864
Processing example 261/864
Processing example 281/864
Processing example 301/864
Processing example 321/864
Processing example 341/864
Processing example 361/864
Processing example 381/864
Processing example 401/864
Processing example 421/864
Processing example 441/864
Processing example 461/864
Processing example 481/864
Processing example 501/864
Processing example 521/864
Processing example 541/864
Processing example 561/864
Processing example 581/864
Processing example 601/864
Processing example 621/864
Processing example 641/864
Processing example 661/864
Processing example 681/864
Processing example

Map:   0%|          | 0/3022 [00:00<?, ? examples/s]

Map:   0%|          | 0/432 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Starting training...


Step,Training Loss,Validation Loss
50,14.9579,10.292455
100,3.9146,3.371082
150,2.8108,2.33274
200,2.1619,1.644126
250,1.6801,1.173841
300,1.3752,0.90714
350,1.1605,0.769701
400,1.0096,0.692351
450,0.9346,0.647684
500,0.8416,0.623924


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Fine-tuning completed for t5-small

=== Post Fine-tuning Evaluation for T5-Small ===

=== Zero-Shot Evaluation for T5-Small_finetuned ===
Processing example 1/864
Processing example 21/864
Processing example 41/864
Processing example 61/864
Processing example 81/864
Processing example 101/864
Processing example 121/864
Processing example 141/864
Processing example 161/864
Processing example 181/864
Processing example 201/864
Processing example 221/864
Processing example 241/864
Processing example 261/864
Processing example 281/864
Processing example 301/864
Processing example 321/864
Processing example 341/864
Processing example 361/864
Processing example 381/864
Processing example 401/864
Processing example 421/864
Processing example 441/864
Processing example 461/864
Processing example 481/864
Processing example 501/864
Processing example 521/864
Processing example 541/864
Processing example 561/864
Processing example 581/864
Processing example 601/864
Processing example 621/864
Proc

In [42]:
# Debug: Check actual zero-shot predictions
print("=== DEBUGGING ZERO-SHOT OUTPUTS ===")
print(f"First CodeT5 prediction: {codet5_zero_preds[0]}")
print(f"First T5 prediction: {t5_zero_preds[0]}")
print(f"Are they identical? {codet5_zero_preds[0] == t5_zero_preds[0]}")

print(f"\nSecond CodeT5 prediction: {codet5_zero_preds[1]}")
print(f"Second T5 prediction: {t5_zero_preds[1]}")
print(f"Are they identical? {codet5_zero_preds[1] == t5_zero_preds[1]}")


=== DEBUGGING ZERO-SHOT OUTPUTS ===
First CodeT5 prediction: bank_namebank_namebank_namebank_name(bank_name)bank_namebank_name(bank_name)(bank_name)(bank_name)(bank_name)(bank_name)(bank_name)
First T5 prediction: What is the total amount of socially responsible loans issued by each bank?
Are they identical? False

Second CodeT5 prediction: () {


Question:the amount ofthe amount ofthe amount ofthe customer_id, amount, tx_date, countrythe amount ofthe customer_id, amount, tx_date, countrythe amount ofthe customer_id, amount, tx_date, country,the customer_id, amount, tx_date, country,the amount ofthe customer_id, amount, tx_
Second T5 prediction: Count the transaction dates and the total transaction amount for transactions made by customers in India.
Are they identical? False


In [43]:
# Debug: Check actual fine-tuned predictions
print("=== DEBUGGING FINE-TUNED OUTPUTS ===")
print(f"First CodeT5 fine-tuned prediction: {codet5_ft_preds[0]}")
print(f"First T5 fine-tuned prediction: {t5_ft_preds[0]}")
print(f"Are they identical? {codet5_ft_preds[0] == t5_ft_preds[0]}")

print(f"\nSecond CodeT5 fine-tuned prediction: {codet5_ft_preds[1]}")
print(f"Second T5 fine-tuned prediction: {t5_ft_preds[1]}")
print(f"Are they identical? {codet5_ft_preds[1] == t5_ft_preds[1]}")

=== DEBUGGING FINE-TUNED OUTPUTS ===
First CodeT5 fine-tuned prediction: SQL: SELECT bank_name, SUM(loan_amount) FROM socially_responsible_loans WHERE loan_date BETWEEN DATEADD(month, -1, GETDATE());
Explanation: This query calculates the total amount of socially responsible loans issued by each bank by summing up the loan_amount column and filtering the results by bank_name and the loan_date column.
First T5 fine-tuned prediction: SQL: SELECT bank_name, loan_amount FROM socially_responsible_loans WHERE bank_name = 'socially_responsible_loans'; Explanation: This query calculates the total amount of socially responsible loans issued by each bank.
Are they identical? False

Second CodeT5 fine-tuned prediction: SQL: SELECT customer_id, tx_date, SUM(amount) FROM transactions_4 WHERE country = 'India';
Explanation: This query finds the transaction dates and the total transaction amount for transactions made by customers residing in India. It does this by using the SUM function on the amount

In [44]:
# =============================================================================
# CELL 14: Final Results Comparison
# =============================================================================

print("\n" + "="*80)
print("FINAL RESULTS COMPARISON")
print("="*80)

# Create comparison table
results_df = pd.DataFrame({
    'Model': ['CodeT5-Small (Zero-shot)', 'CodeT5-Small (Fine-tuned)', 
              'T5-Small (Zero-shot)', 'T5-Small (Fine-tuned)'],
    'Syntax Correctness': [
        codet5_zero_results['syntax_scores'],
        codet5_ft_results['syntax_scores'],
        t5_zero_results['syntax_scores'],
        t5_ft_results['syntax_scores']
    ],
    'Structural Similarity': [
        codet5_zero_results['structural_scores'],
        codet5_ft_results['structural_scores'],
        t5_zero_results['structural_scores'],
        t5_ft_results['structural_scores']
    ],
    'Exact Match': [
        codet5_zero_results['exact_match_scores'],
        codet5_ft_results['exact_match_scores'],
        t5_zero_results['exact_match_scores'],
        t5_ft_results['exact_match_scores']
    ],
    'Explanation BLEU': [
        codet5_zero_results['explanation_bleu_scores'],
        codet5_ft_results['explanation_bleu_scores'],
        t5_zero_results['explanation_bleu_scores'],
        t5_ft_results['explanation_bleu_scores']
    ],
    'Combined Score': [
        codet5_zero_results['combined_scores'],
        codet5_ft_results['combined_scores'],
        t5_zero_results['combined_scores'],
        t5_ft_results['combined_scores']
    ]
})

print(results_df.round(3))

# Calculate improvements
print("\n" + "="*50)
print("IMPROVEMENT ANALYSIS")
print("="*50)

codet5_improvement = codet5_ft_results['combined_scores'] - codet5_zero_results['combined_scores']
t5_improvement = t5_ft_results['combined_scores'] - t5_zero_results['combined_scores']

print(f"CodeT5-Small improvement: {codet5_improvement:.3f}")
print(f"T5-Small improvement: {t5_improvement:.3f}")

print(f"\nBest performing model: ", end="")
best_score = max(codet5_ft_results['combined_scores'], t5_ft_results['combined_scores'])
if codet5_ft_results['combined_scores'] == best_score:
    print("CodeT5-Small (Fine-tuned)")
else:
    print("T5-Small (Fine-tuned)")


FINAL RESULTS COMPARISON
                       Model  Syntax Correctness  Structural Similarity  \
0   CodeT5-Small (Zero-shot)               0.005                  0.203   
1  CodeT5-Small (Fine-tuned)               0.993                  0.638   
2       T5-Small (Zero-shot)               0.049                  0.203   
3      T5-Small (Fine-tuned)               0.889                  0.349   

   Exact Match  Explanation BLEU  Combined Score  
0          0.0             0.000           0.082  
1          0.1             0.286           0.549  
2          0.0             0.000           0.091  
3          0.0             0.208           0.380  

IMPROVEMENT ANALYSIS
CodeT5-Small improvement: 0.467
T5-Small improvement: 0.289

Best performing model: CodeT5-Small (Fine-tuned)


In [45]:
# =============================================================================
# CELL 15: Sample Predictions Analysis
# =============================================================================

print("\n" + "="*60)
print("SAMPLE PREDICTIONS ANALYSIS")
print("="*60)

# Show a few sample predictions
for i in range(min(3, len(test_dataset))):
    print(f"\n--- Example {i+1} ---")
    print(f"Input: {test_dataset[i]['input'][:200]}...")
    print(f"\nReference: {test_dataset[i]['output'][:200]}...")
    print(f"\nCodeT5 Zero-shot: {codet5_zero_preds[i][:200]}...")
    print(f"\nCodeT5 Fine-tuned: {codet5_ft_preds[i][:200]}...")
    print(f"\nT5 Zero-shot: {t5_zero_preds[i][:200]}...")
    print(f"\nT5 Fine-tuned: {t5_ft_preds[i][:200]}...")
    print("-" * 60)

print("\n🎉 EXPERIMENT COMPLETED SUCCESSFULLY! 🎉")
print("Check the results above to analyze model performance.")


SAMPLE PREDICTIONS ANALYSIS

--- Example 1 ---
Input: Convert this natural language question to SQL using the given database schema.

Database Schema:
CREATE TABLE socially_responsible_loans (bank_name VARCHAR(255), loan_amount DECIMAL(10,2), loan_date D...

Reference: SQL: SELECT bank_name, SUM(loan_amount) FROM socially_responsible_loans GROUP BY bank_name;
Explanation: This query calculates the total amount of socially responsible loans issued by each bank by sum...

CodeT5 Zero-shot: bank_namebank_namebank_namebank_name(bank_name)bank_namebank_name(bank_name)(bank_name)(bank_name)(bank_name)(bank_name)(bank_name)...

CodeT5 Fine-tuned: SQL: SELECT bank_name, SUM(loan_amount) FROM socially_responsible_loans WHERE loan_date BETWEEN DATEADD(month, -1, GETDATE());
Explanation: This query calculates the total amount of socially responsib...

T5 Zero-shot: What is the total amount of socially responsible loans issued by each bank?...

T5 Fine-tuned: SQL: SELECT bank_name, loan_amount FR

In [46]:
# # =============================================================================
# # MANUAL TESTING: New Task for Fine-Tuned Models
# # =============================================================================

# # Define your custom test case
# custom_schema = """
# CREATE TABLE employees (emp_id INT, name VARCHAR(100), department VARCHAR(50), salary DECIMAL(10,2), hire_date DATE);
# CREATE TABLE projects (project_id INT, project_name VARCHAR(100), budget DECIMAL(12,2), start_date DATE);
# CREATE TABLE assignments (emp_id INT, project_id INT, hours_worked DECIMAL(5,2));
# """

# custom_question = "Find all employees in the Engineering department who have worked more than 100 hours total across all projects, and show their names and total hours worked."

# # Format the input (same as training format)
# custom_input = f"Schema: {custom_schema}\nQuestion: {custom_question}"

# print("=== CUSTOM TEST CASE ===")
# print("Input:")
# print(custom_input)
# print("\n" + "="*60)

# # Test both fine-tuned models
# def test_custom_query(model, tokenizer, input_text, model_name):
#     inputs = tokenizer(
#         input_text,
#         return_tensors="pt",
#         max_length=512,
#         truncation=True
#     ).to(device)
    
#     with torch.no_grad():
#         outputs = model.generate(
#             **inputs,
#             max_length=256,
#             num_beams=2,
#             early_stopping=True
#         )
    
#     prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
#     print(f"\n{model_name} Prediction:")
#     print(prediction)
#     return prediction

# # Test both models
# codet5_custom = test_custom_query(codet5_model, codet5_tokenizer, custom_input, "CodeT5-Small Fine-tuned")
# t5_custom = test_custom_query(t5_model, t5_tokenizer, custom_input, "T5-Small Fine-tuned")

In [47]:
# =============================================================================
# MANUAL TESTING: Easier Task for Fine-Tuned Models
# =============================================================================

# Define easier test cases (progressive difficulty)
print("=== EASIER CUSTOM TEST CASES ===")

# Test Case 1: Single table, simple filter
custom_schema_1 = """
CREATE TABLE employees (emp_id INT, name VARCHAR(100), department VARCHAR(50), salary DECIMAL(10,2));
"""

custom_question_1 = "Find all employees in the Engineering department and show their names and salaries."

# Test Case 2: Single table with aggregation
custom_schema_2 = """
CREATE TABLE sales (sale_id INT, product VARCHAR(100), amount DECIMAL(10,2), region VARCHAR(50));
"""

custom_question_2 = "What is the total sales amount for each region?"

# Test Case 3: Simple join (2 tables)
custom_schema_3 = """
CREATE TABLE customers (customer_id INT, name VARCHAR(100), city VARCHAR(50));
CREATE TABLE orders (order_id INT, customer_id INT, amount DECIMAL(10,2));
"""

custom_question_3 = "Show customer names and their order amounts."

# Test all cases
test_cases = [
    (custom_schema_1, custom_question_1, "Single Table Filter"),
    (custom_schema_2, custom_question_2, "Single Table Aggregation"), 
    (custom_schema_3, custom_question_3, "Simple Two-Table Join")
]

def test_custom_query(model, tokenizer, input_text, model_name):
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True
    ).to(device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=256,
            num_beams=2,
            early_stopping=True
        )
    
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return prediction

# Test all cases with both models
for i, (schema, question, test_name) in enumerate(test_cases, 1):
    # Format input using better prompting
    custom_input = f"""Convert this natural language question to SQL using the given database schema.

Database Schema:
{schema}

Question: {question}

Please generate a SQL query with explanation:"""
    
    print(f"\n{'='*60}")
    print(f"TEST CASE {i}: {test_name}")
    print(f"{'='*60}")
    print("Question:", question)
    print("Schema:", schema.strip())
    
    # Test CodeT5
    codet5_result = test_custom_query(codet5_model, codet5_tokenizer, custom_input, "CodeT5")
    print(f"\n🤖 CodeT5 Result:")
    print(codet5_result)
    
    # Test FLAN-T5  
    flan_t5_result = test_custom_query(t5_model, t5_tokenizer, custom_input, "FLAN-T5")
    print(f"\n🤖 FLAN-T5 Result:")
    print(flan_t5_result)
    
    print(f"\n{'='*60}")

print("\n🎯 ANALYSIS:")
print("Expected SQL queries:")
print("1. SELECT name, salary FROM employees WHERE department = 'Engineering';")  
print("2. SELECT region, SUM(amount) FROM sales GROUP BY region;")
print("3. SELECT c.name, o.amount FROM customers c JOIN orders o ON c.customer_id = o.customer_id;")

=== EASIER CUSTOM TEST CASES ===

TEST CASE 1: Single Table Filter
Question: Find all employees in the Engineering department and show their names and salaries.
Schema: CREATE TABLE employees (emp_id INT, name VARCHAR(100), department VARCHAR(50), salary DECIMAL(10,2));

🤖 CodeT5 Result:
SQL: SELECT employees.emp_id, employees.name, employees.salary FROM employees WHERE employees.department = 'Engineering';
Explanation: This query finds all employees in the Engineering department and show their names and salaries.

🤖 FLAN-T5 Result:
SQL: SELECT em_id FROM employees WHERE department = 'Engineering'; Explanation: This query finds all employees in the Engineering department and shows their names and salaries.


TEST CASE 2: Single Table Aggregation
Question: What is the total sales amount for each region?
Schema: CREATE TABLE sales (sale_id INT, product VARCHAR(100), amount DECIMAL(10,2), region VARCHAR(50));

🤖 CodeT5 Result:
SQL: SELECT region, SUM(amount) FROM sales WHERE region = 'US'

In [49]:
# Quick install and import
!pip install gradio
import gradio as gr
print("✅ Gradio installed successfully!")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting gradio
  Downloading gradio-5.35.0-py3-none-any.whl.metadata (16 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.14-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.6.0-py3-none-any.whl.metadata (2.9 kB)
Collecting gradio-client==1.10.4 (from gradio)
  Downloading gradio_client-1.10.4-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.12.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<0.2.0,>=0.1.6 (from gradio)
  Downloading safehttpx-0.1.6-py3-none-any.whl.metadata (4.2 kB)
Collecting semantic-version~=2.0 (from gradio)
  Downloading semantic_version-2.10.0-py2.py3-none-any.whl.metadata (9.7 kB)
Col

In [None]:
# =============================================================================
# GRADIO INTERFACE FOR SQL GENERATION
# =============================================================================

import gradio as gr
import torch

# Install gradio first if not already installed
!pip install gradio

def generate_sql_query(schema, question, model_choice):
    """Generate SQL using the selected fine-tuned model"""
    
    # Format input using the better prompting style
    input_text = f"""Convert this natural language question to SQL using the given database schema.

Database Schema:
{schema}

Question: {question}

Please generate a SQL query with explanation:"""
    
    try:
        # Select model and tokenizer based on choice
        if model_choice == "CodeT5-Small (Fine-tuned)":
            model = codet5_model
            tokenizer = codet5_tokenizer
        else:  # FLAN-T5-Small (Fine-tuned)
            model = t5_model  # This should be your FLAN-T5 model
            tokenizer = t5_tokenizer
        
        # Tokenize input
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to(device)
        
        # Generate prediction
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=256,
                num_beams=3,
                early_stopping=True,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                temperature=0.7
            )
        
        # Decode prediction
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return prediction
        
    except Exception as e:
        return f"Error generating SQL: {str(e)}"

def create_sql_interface():
    """Create the Gradio interface"""
    
    # Custom CSS for better styling
    css = """
    .gradio-container {
        max-width: 1200px !important;
    }
    .output-text {
        font-family: 'Courier New', monospace;
        background-color: #f8f9fa;
        padding: 10px;
        border-radius: 5px;
        border: 1px solid #dee2e6;
    }
    """
    
    with gr.Blocks(css=css, title="Text-to-SQL Generator", theme=gr.themes.Soft()) as demo:
        
        # Header
        gr.Markdown("""
        # 🚀 Text-to-SQL Generator
        ### Convert natural language questions to SQL queries using fine-tuned models
        
        **Models Available:**
        - **CodeT5-Small**: Pre-trained on code, optimized for SQL generation
        - **FLAN-T5-Small**: Instruction-tuned general language model
        """)
        
        with gr.Row():
            with gr.Column(scale=1):
                # Input Section
                gr.Markdown("## 📝 Input")
                
                schema_input = gr.Textbox(
                    label="Database Schema",
                    placeholder="CREATE TABLE employees (id INT, name VARCHAR(100), department VARCHAR(50), salary DECIMAL(10,2));\nCREATE TABLE projects (id INT, name VARCHAR(100), budget DECIMAL(12,2));",
                    lines=8,
                    value="""CREATE TABLE employees (emp_id INT, name VARCHAR(100), department VARCHAR(50), salary DECIMAL(10,2));
CREATE TABLE projects (project_id INT, project_name VARCHAR(100), budget DECIMAL(12,2));""",
                    info="Enter your database schema with CREATE TABLE statements"
                )
                
                question_input = gr.Textbox(
                    label="Natural Language Question",
                    placeholder="Find all employees with salary greater than 50000",
                    lines=3,
                    value="Find all employees in the Engineering department",
                    info="Enter your question in plain English"
                )
                
                model_choice = gr.Radio(
                    choices=["CodeT5-Small (Fine-tuned)", "FLAN-T5-Small (Fine-tuned)"],
                    label="Select Model",
                    value="CodeT5-Small (Fine-tuned)",
                    info="Choose which fine-tuned model to use"
                )
                
                generate_btn = gr.Button("🚀 Generate SQL", variant="primary", size="lg")
                
                # Model Performance Info
                gr.Markdown("""
                **Model Performance Summary:**
                - **CodeT5**: Combined Score 0.658, 10% Exact Match
                - **FLAN-T5**: Combined Score 0.546, 0% Exact Match
                """)
            
            with gr.Column(scale=1):
                # Output Section
                gr.Markdown("## 🎯 Generated SQL")
                
                output = gr.Textbox(
                    label="SQL Query + Explanation",
                    lines=15,
                    max_lines=25,
                    elem_classes=["output-text"],
                    show_copy_button=True,
                    info="Generated SQL query with explanation"
                )
                
                # Quick Analysis
                analysis_output = gr.Textbox(
                    label="Quick Analysis",
                    lines=3,
                    visible=False,  # Can be made visible for advanced features
                    info="Automatic analysis of the generated SQL"
                )
        
        # Example Queries Section
        gr.Markdown("## 📚 Example Queries")
        
        examples = gr.Examples(
            examples=[
                [
                    """CREATE TABLE customers (customer_id INT, name VARCHAR(100), city VARCHAR(50));
CREATE TABLE orders (order_id INT, customer_id INT, amount DECIMAL(10,2), order_date DATE);""",
                    "Show customer names and their total order amounts",
                    "CodeT5-Small (Fine-tuned)"
                ],
                [
                    """CREATE TABLE employees (emp_id INT, name VARCHAR(100), department VARCHAR(50), salary DECIMAL(10,2));""",
                    "What is the average salary by department?",
                    "CodeT5-Small (Fine-tuned)"
                ],
                [
                    """CREATE TABLE products (product_id INT, name VARCHAR(100), category VARCHAR(50), price DECIMAL(10,2));
CREATE TABLE sales (sale_id INT, product_id INT, quantity INT, sale_date DATE);""",
                    "Find the top 5 selling products by total quantity sold",
                    "CodeT5-Small (Fine-tuned)"
                ],
                [
                    """CREATE TABLE students (student_id INT, name VARCHAR(100), major VARCHAR(50));
CREATE TABLE courses (course_id INT, course_name VARCHAR(100), credits INT);
CREATE TABLE enrollments (student_id INT, course_id INT, grade VARCHAR(2));""",
                    "List all students majoring in Computer Science with their enrolled courses",
                    "FLAN-T5-Small (Fine-tuned)"
                ]
            ],
            inputs=[schema_input, question_input, model_choice],
            label="Try these examples:"
        )
        
        # Instructions
        gr.Markdown("""
        ## 📋 Instructions
        1. **Enter your database schema** in the first text box (use CREATE TABLE statements)
        2. **Ask your question** in natural language
        3. **Select a model** (CodeT5 generally performs better for SQL tasks)  
        4. **Click Generate SQL** to get your query
        5. **Copy the result** using the copy button in the output box
        
        ## ⚠️ Important Notes
        - Models work best with **clear, specific questions**
        - **CodeT5** generally produces more accurate SQL than FLAN-T5
        - Results may vary for **complex multi-table queries**
        - Always **validate generated SQL** before using in production
        """)
        
        # Connect the generate button to the function
        generate_btn.click(
            fn=generate_sql_query,
            inputs=[schema_input, question_input, model_choice],
            outputs=output,
            api_name="generate_sql"
        )
    
    return demo

# Launch the interface
print("🚀 Creating SQL Generation Interface...")
print("📊 Make sure your models are loaded (codet5_model, t5_model, etc.)")

# Create and launch the demo
demo = create_sql_interface()
demo.launch(
    share=True,          # Creates public link
    debug=True,          # Shows errors in console
    server_name="0.0.0.0",  # Allows external access
    server_port=7860,    # Default Gradio port
    show_error=True      # Shows errors in interface
)

print("✅ Interface launched! Check the output above for the public URL.")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


🚀 Creating SQL Generation Interface...
📊 Make sure your models are loaded (codet5_model, t5_model, etc.)
* Running on local URL:  http://0.0.0.0:7860
* Running on public URL: https://d3766dd33907c53ba6.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
