# Test Trained Model from SageMaker

This notebook demonstrates how to:
1. Download the trained model from S3
2. Load the fine-tuned LoRA weights
3. Test SQL generation using the same chat prompt format as training
4. Run a quick sanity-check evaluation on sample queries

**Latest Successful Training Job**: gl-rl-gpu-20250923-033651

## 1. Setup Environment

In [None]:
# Install required packages
!pip install torch transformers peft datasets accelerate boto3 -q

In [None]:
import json
import os
import re
import tarfile
import time
from pathlib import Path

import boto3
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

# Seed torch so sampled generations (do_sample=True) are reproducible under
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# AWS Configuration.
# Profile and account are read from the environment so personal values need not
# be committed; the defaults keep the original behaviour when nothing is set.
AWS_PROFILE = os.environ.get("AWS_PROFILE", "personal-yahoo")
REGION = "us-east-1"
ACCOUNT_ID = os.environ.get("AWS_ACCOUNT_ID", "340350204194")
BUCKET = f"gl-rl-model-sagemaker-{ACCOUNT_ID}-{REGION}"

# Initialize boto3 session
session = boto3.Session(profile_name=AWS_PROFILE, region_name=REGION)
s3_client = session.client('s3')

print(f"\nAWS Configuration:")
print(f"  Profile: {AWS_PROFILE}")
print(f"  Region: {REGION}")
print(f"  Bucket: {BUCKET}")

## 2. Download Model from S3

In [None]:
# Specify the training job name
JOB_NAME = "gl-rl-gpu-20250923-033651"  # Replace with your job name

# Build the object key once and derive the URI from it, so the printed path and
# the key actually downloaded cannot drift apart.
MODEL_S3_KEY = f"output/{JOB_NAME}/output/model.tar.gz"
MODEL_S3_PATH = f"s3://{BUCKET}/{MODEL_S3_KEY}"

print(f"Model S3 path: {MODEL_S3_PATH}")

# Create local directory
model_dir = Path(f"./models/{JOB_NAME}")
model_dir.mkdir(parents=True, exist_ok=True)
model_tar_path = model_dir / "model.tar.gz"

# Download model, unless an earlier run already fetched it. This keeps
# Restart & Run All from re-downloading the archive every time.
if model_tar_path.exists():
    print(f"\nUsing cached archive: {model_tar_path}")
else:
    print("\nDownloading model from S3...")
    s3_client.download_file(
        Bucket=BUCKET,
        Key=MODEL_S3_KEY,
        Filename=str(model_tar_path),
    )
    print(f"Downloaded model to: {model_tar_path}")

# Extract model.
# A plain extractall() trusts member names, so an archive entry such as
# "../../x" or "/etc/x" could be written outside model_dir. Use the stdlib
# "data" filter where available (Python 3.12+, and newer 3.8-3.11 patch
# releases). Otherwise fall back to a minimal check on member paths.
print("Extracting model...")
with tarfile.open(model_tar_path, 'r:gz') as tar:
    if hasattr(tarfile, "data_filter"):
        tar.extractall(model_dir, filter="data")
    else:
        root = model_dir.resolve()
        for member in tar.getmembers():
            target = (root / member.name).resolve()
            if target != root and root not in target.parents:
                raise ValueError(f"Refusing to extract {member.name!r} outside {root}")
        tar.extractall(model_dir)

print(f"✅ Model extracted to: {model_dir}")

# List extracted files recursively, in case the archive nests files in
# subdirectories.
print("\nExtracted files:")
for file in sorted(model_dir.rglob("*")):
    if file.is_file() and file != model_tar_path:
        print(f"  - {file.relative_to(model_dir)}")

## 3. Load Fine-tuned Model

In [None]:
# Base model name (same as used in training)
BASE_MODEL = "Qwen/Qwen2.5-Coder-1.5B-Instruct"

print(f"Loading base model: {BASE_MODEL}")
print("This may take a few minutes...")

# Load tokenizer (left padding, as recommended for decoder-only generation)
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    padding_side='left'
)

# Set padding token if not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("✅ Tokenizer loaded")

# Load base model: fp16 with accelerate auto-placement on GPU, fp32 on CPU.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None
)

print("✅ Base model loaded")

# Load the LoRA config and adapter weights from the extracted artifact.
# LORA_LOADED records whether later results come from the fine-tuned model or
# from the base-model fallback below.
print("\nLoading LoRA weights...")
LORA_LOADED = False
try:
    # Load PEFT config
    peft_config = PeftConfig.from_pretrained(str(model_dir))
    print(f"LoRA Config: r={peft_config.r}, alpha={peft_config.lora_alpha}")

    # Load fine-tuned model
    model = PeftModel.from_pretrained(base_model, str(model_dir))

    # Optionally merge LoRA weights for faster inference
    # model = model.merge_and_unload()

    LORA_LOADED = True
    print("✅ Fine-tuned model loaded successfully!")
except Exception as e:
    # Deliberate best-effort fallback so the notebook still runs end to end.
    # Check LORA_LOADED before trusting the evaluation numbers.
    print(f"⚠️ Could not load LoRA weights: {e}")
    print("Using base model instead")
    model = base_model

# Move to device only if accelerate has not already placed the weights.
# from_pretrained(device_map=...) records the placement in `hf_device_map`.
# Calling .cuda() on such a model is redundant on one GPU and can conflict with
# a multi-GPU split.
if device.type == "cuda" and getattr(base_model, "hf_device_map", None) is None:
    model = model.cuda()

# Set to eval mode (disables dropout)
model.eval()

print(f"\nModel ready on {device}")
print(f"LoRA adapter loaded: {LORA_LOADED}")
print(f"Model size: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B parameters")

## 4. Define Generation Function with Proper Prompt Format

In [None]:
def generate_sql(query, context="", max_new_tokens=150, temperature=0.7, verbose=False,
                 model=None, tokenizer=None):
    """
    Generate SQL from a natural-language query using the training prompt format.

    Args:
        query: Natural-language question.
        context: Schema/context text inserted into the system message.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature. A value <= 0 selects greedy decoding
            instead, because sampling with temperature 0 would raise inside
            ``generate``. This is useful for reproducible evaluation.
        verbose: If True, print the prompt, the full decoded output and the
            extracted SQL.
        model: Model to generate with. Defaults to the notebook-level ``model``.
        tokenizer: Tokenizer to use. Defaults to the notebook-level ``tokenizer``.

    Returns:
        Tuple ``(sql, generation_time)``: the extracted SQL string and the
        wall-clock seconds spent inside ``model.generate``.
    """
    # Fall back to the objects loaded in section 3 when none are passed in.
    if model is None:
        model = globals()["model"]
    if tokenizer is None:
        tokenizer = globals()["tokenizer"]

    # Use the EXACT format from training
    prompt = f"""<|im_start|>system
You are a SQL expert. Generate SQL queries based on natural language questions.
Context: {context}<|im_end|>
<|im_start|>user
{query}<|im_end|>
<|im_start|>assistant"""

    if verbose:
        print("=" * 60)
        print("PROMPT:")
        print(prompt)
        print("=" * 60)

    # Tokenize.
    # NOTE(review): the tokenizer's default truncation side is 'right'. A prompt
    # longer than 512 tokens would therefore lose its trailing assistant header.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Put the inputs where the model's (first) weights live. With
    # device_map="auto" this can differ from the notebook-level `device`.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Sampling parameters are only meaningful (and only valid) when sampling.
    do_sample = temperature > 0
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        num_return_sequences=1,
    )
    if do_sample:
        gen_kwargs.update(temperature=temperature, top_k=50, top_p=0.95)

    # Generate
    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    generation_time = time.perf_counter() - start_time

    # Decode only the newly generated tokens. Splitting the full text on
    # "assistant" breaks when the SQL itself contains that word (e.g. a table
    # named `assistants`) or when the model starts another turn.
    prompt_length = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=False)

    # Keep only the first assistant turn: cut at the first turn/text boundary
    # marker, then trim the newline that follows the assistant header.
    sql = completion
    for stop_marker in ("<|im_end|>", "<|im_start|>", "<|endoftext|>"):
        sql = sql.split(stop_marker, 1)[0]
    sql = sql.strip()

    if verbose:
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
        print("\nFULL RESPONSE:")
        print(full_response)
        print("\nEXTRACTED SQL:")
        print(sql)
        print("=" * 60)

    return sql, generation_time

print("✅ SQL generation function ready")

## 5. Test SQL Generation

In [None]:
# Smoke test: a single easy query in verbose mode, so the exact prompt and the
# raw model response are printed next to the extracted SQL.
test_context = "customers(id, name, email, created_at)"
test_query = "Show me all customers"

print("Test Query: " + test_query)
print("Context: " + test_context)
print()

sql, gen_time = generate_sql(test_query, test_context, verbose=True)

print("\n📊 Results:")
print("Generated SQL: " + sql)
print("Generation time: {:.2f} seconds".format(gen_time))

## 6. Test Multiple Queries

In [None]:
# Test queries matching training data format
test_queries = [
    {
        "query": "Show all customers",
        "context": "customers(id, name, email, created_at)",
        "expected": "SELECT * FROM customers;"
    },
    {
        "query": "Calculate average order value",
        "context": "orders(id, customer_id, total_amount, order_date)",
        "expected": "SELECT AVG(total_amount) as avg_order_value FROM orders;"
    },
    {
        "query": "Count orders by customer",
        "context": "orders(id, customer_id, total_amount, order_date), customers(id, name)",
        "expected": "SELECT customer_id, COUNT(*) as order_count FROM orders GROUP BY customer_id;"
    },
    {
        "query": "Find top 5 products by revenue",
        "context": "products(id, name, category), sales(id, product_id, quantity, price)",
        "expected": "SELECT p.name, SUM(s.quantity * s.price) as revenue FROM products p JOIN sales s ON p.id = s.product_id GROUP BY p.id, p.name ORDER BY revenue DESC LIMIT 5;"
    },
    {
        "query": "Get employees hired last month",
        "context": "employees(id, name, department, hire_date, salary)",
        "expected": "SELECT * FROM employees WHERE hire_date >= DATE_SUB(CURDATE(), INTERVAL 1 MONTH);"
    }
]

# An optional surrounding Markdown fence such as ```sql ... ```.
_CODE_FENCE = re.compile(r"^```[A-Za-z]*\s*|\s*```$")
# A generated answer counts as SQL only if it STARTS with a statement keyword.
# Checking for a keyword anywhere would accept prose such as "rows from ...".
_SQL_START = re.compile(r"(?:SELECT|WITH|INSERT|UPDATE|DELETE)\b", re.IGNORECASE)


def strip_code_fence(text):
    """Return ``text`` stripped of whitespace and of a surrounding Markdown code fence."""
    return _CODE_FENCE.sub("", text.strip())


def looks_like_sql(text):
    """Heuristic validity check: the unfenced answer begins with a SQL statement keyword."""
    return bool(_SQL_START.match(strip_code_fence(text)))


def normalize_sql(text):
    """Canonical form for exact-match comparison.

    Removes the fence, collapses whitespace, drops a trailing ';' and lower-cases.
    Lower-casing also affects string literals, so this is a lenient heuristic.
    """
    return " ".join(strip_code_fence(text).rstrip().rstrip(";").split()).lower()


def percent(count, total):
    """``count`` as a percentage of ``total``; 0.0 when ``total`` is 0."""
    return count / total * 100 if total else 0.0


print("Testing Multiple Queries")
print("=" * 80)

results = []
for i, test in enumerate(test_queries, 1):
    print(f"\n📝 Query {i}/{len(test_queries)}")
    print(f"Question: {test['query']}")
    print(f"Schema: {test['context']}")
    print("-" * 40)

    sql, gen_time = generate_sql(test['query'], test['context'], verbose=False)

    print(f"Expected: {test['expected']}")
    print(f"Generated: {sql}")
    print(f"Time: {gen_time:.2f}s")

    is_valid = looks_like_sql(sql)
    is_exact = normalize_sql(sql) == normalize_sql(test['expected'])
    print(f"Valid SQL: {'✅' if is_valid else '❌'}")
    print(f"Exact match: {'✅' if is_exact else '❌'}")

    results.append({
        'query': test['query'],
        'generated': sql,
        'expected': test['expected'],
        'time': gen_time,
        'valid': is_valid,
        'exact_match': is_exact
    })

# Summary (guarded so an empty query list does not divide by zero)
print("\n" + "=" * 80)
print("📊 SUMMARY")
print("=" * 80)
valid_count = sum(1 for r in results if r['valid'])
exact_count = sum(1 for r in results if r['exact_match'])
avg_time = sum(r['time'] for r in results) / len(results) if results else 0.0

print(f"Total queries: {len(results)}")
print(f"Valid SQL generated: {valid_count}/{len(results)} ({percent(valid_count, len(results)):.1f}%)")
print(f"Exact matches: {exact_count}/{len(results)} ({percent(exact_count, len(results)):.1f}%)")
print(f"Average generation time: {avg_time:.2f} seconds")

## 7. Test with Training Data Format

In [None]:
# Load actual training data to test.
# NOTE(review): this path is relative to the notebook's working directory.
training_data_path = "../../gl_rl_model/data/training/query_pairs.jsonl"
N_TRAINING_EXAMPLES = 5  # number of records from the head of the file to test


def load_jsonl_head(path, n):
    """Return the first ``n`` records of a JSON-Lines file, skipping blank lines.

    The file is streamed, so only the first ``n`` non-blank lines are read and parsed.
    """
    records = []
    if n <= 0:
        return records
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank/trailing lines, which json.loads rejects
            records.append(json.loads(line))
            if len(records) >= n:
                break
    return records


if Path(training_data_path).exists():
    print("Testing with actual training data format...\n")

    training_examples = load_jsonl_head(training_data_path, N_TRAINING_EXAMPLES)

    for i, example in enumerate(training_examples, 1):
        print(f"\nExample {i}:")
        print(f"Query: {example['query']}")

        # Use 'context' or 'reasoning' field
        context = example.get('context', example.get('reasoning', ''))
        print(f"Context: {context}")
        print(f"Expected SQL: {example['sql']}")

        sql, _ = generate_sql(example['query'], context)
        print(f"Generated SQL: {sql}")
        print("-" * 60)
else:
    print(f"Training data not found at {training_data_path}")

## 8. Save Results and Analysis

In [None]:
# Persist this evaluation run (per-query results plus a summary) as JSON, so
# runs from different training jobs can be compared later.
results_file = f"test_results_{JOB_NAME}.json"

run_summary = {
    'total_queries': len(results),
    'valid_sql': valid_count,
    'success_rate': valid_count/len(results)*100,
    'avg_generation_time': avg_time
}
run_report = {
    'job_name': JOB_NAME,
    'model': BASE_MODEL,
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    'device': str(device),
    'results': results,
    'summary': run_summary,
}

with open(results_file, 'w') as out_file:
    json.dump(run_report, out_file, indent=2)

print(f"\n✅ Results saved to: {results_file}")

# Show the training-job metrics when the extracted artifact contains metrics.json.
metrics_path = model_dir / "metrics.json"
if metrics_path.exists():
    metrics = json.loads(metrics_path.read_text())
    print("\n📈 Training Metrics:")
    for metric_name, metric_value in metrics.items():
        print(f"  {metric_name}: {metric_value}")

## 9. Troubleshooting Guide

### If the model generates poor SQL:

1. **Check the prompt format** - Must match training exactly
2. **Verify model weights loaded** - Check for LoRA config files
3. **Adjust temperature** - Lower (0.3-0.5) for more deterministic output
4. **Increase max_new_tokens** - Some queries need more tokens
5. **Check training loss** - Should have decreased during training

### Common Issues:

- **Model outputs conversation instead of SQL**: Prompt format mismatch
- **Model outputs generic text**: LoRA weights not loaded properly
- **Model outputs incomplete SQL**: Increase max_new_tokens
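- **Model outputs garbage**: Check whether training actually worked. The training loss should have decreased over the run and never become NaN/inf; a loss that is merely > 0 tells you nothing.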

### Next Steps:

1. **More Training**: Train for more epochs (5-10)
2. **More Data**: Add more diverse training examples
3. **Hyperparameter Tuning**: Adjust learning rate, LoRA rank
4. **Larger Model**: Try Qwen2.5-Coder-7B for better performance