# COGEX Model Validation

Post-training validation notebook to verify the finetuned model:

1. Load the base model + LoRA adapter
2. Test generation on sample prompts
3. Check output format compliance (COGEX structure)
4. Compute validation perplexity

In [None]:
import os
import json
from pathlib import Path

import torch
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# load_dotenv() runs first so MODELS_PATH may come from either the process
# environment or a .env file (if python-dotenv finds one).
load_dotenv()
MODELS_PATH = os.environ.get('MODELS_PATH')  # None when the variable is unset
print(f"MODELS_PATH: {MODELS_PATH}")

## Configuration

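Update these paths to point to your trained model. `MODELS_PATH` must be set in the environment or in a `.env` file, because the base model directory is resolved relative to it.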

In [None]:
# CONFIGURE THESE
BASE_MODEL = "Llama-3.1-8B"  # Base model directory name, resolved under MODELS_PATH
ADAPTER_PATH = "outputs/cogex-Llama-3.1-8B-YYYYMMDD_HHMMSS"  # Path to trained adapter

# Fail early with an actionable message: Path(None) would otherwise raise an
# opaque TypeError when MODELS_PATH is missing from the environment.
if MODELS_PATH is None:
    raise RuntimeError(
        "MODELS_PATH is not set; define it in the environment or in a .env file."
    )

# Paths
BASE_MODEL_PATH = Path(MODELS_PATH) / BASE_MODEL
VAL_DATA_PATH = "data/val.json"

# Echo the resolved locations and whether they exist, so a wrong path is
# visible before the (slow) model load starts.
print(f"Base model: {BASE_MODEL_PATH}")
print(f"Adapter: {ADAPTER_PATH}")
print(f"\nBase model exists: {BASE_MODEL_PATH.exists()}")
print(f"Adapter exists: {Path(ADAPTER_PATH).exists()}")

## 1. Load Model

In [None]:
# Step 1: base weights, read from the local directory only, in bfloat16.
print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH, local_files_only=True, torch_dtype=torch.bfloat16, device_map="auto"
)

# Step 2: tokenizer from the same directory; reuse EOS as the padding token
# when the tokenizer does not define one.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH, local_files_only=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Step 3: wrap the base model with the trained LoRA adapter and switch to
# evaluation mode.
print("Loading LoRA adapter...")
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
model.eval()

print("\nModel loaded successfully!")
# Device of the first parameter only; with device_map="auto" other
# parameters may live elsewhere.
print(f"Device: {next(model.parameters()).device}")

## 2. Generation Function

In [None]:
def generate_cogex_output(
    instruction: str,
    input_text: str = "",
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    do_sample: bool = True,
    include_prompt: bool = True,
) -> str:
    """Generate COGEX-style code and execution output for one prompt.

    Uses the notebook-level ``model`` and ``tokenizer``.

    Args:
        instruction: Task description placed under ``### Instruction``.
        input_text: Task input placed under ``### Input`` (may be empty).
        max_new_tokens: Upper bound on the number of newly generated tokens.
        temperature: Sampling temperature; only forwarded to ``generate``
            when ``do_sample`` is true.
        do_sample: Sample when true; otherwise no sampling is requested.
        include_prompt: When true (the default), the decoded text covers the
            whole returned sequence, prompt included. When false, only the
            tokens after the prompt are decoded.

    Returns:
        The decoded text with special tokens removed.
    """
    # Prompt layout (the original note says it matches the training format).
    prompt = f"### Instruction\n{instruction}\n### Input\n{input_text}\n"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    # temperature is a sampling-only setting; passing it without sampling has
    # no effect and makes transformers emit a configuration warning.
    if do_sample:
        generation_kwargs["temperature"] = temperature

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    sequence = outputs[0]
    if not include_prompt:
        # The returned sequence starts with the prompt tokens; drop them.
        sequence = sequence[inputs["input_ids"].shape[1]:]
    return tokenizer.decode(sequence, skip_special_tokens=True)

## 3. Test Generation

In [None]:
# Test cases, written as (instruction, input) pairs; an empty input string
# means the prompt carries no separate input.
test_cases = [
    {"instruction": instruction, "input": task_input}
    for instruction, task_input in (
        ("Give three tips for staying healthy.", ""),
        ("Identify the odd one out.", '["Twitter", "Instagram", "Telegram"]'),
        ("Pluralize the given word.", "Corpus"),
        ("What are the three primary colors?", ""),
    )
]

In [None]:
# Run tests: one generation per case, printed with its prompt for manual review.
for case_number, case in enumerate(test_cases, start=1):
    instruction, task_input = case['instruction'], case['input']

    print("=" * 70)
    print(f"Test Case {case_number}")
    print("=" * 70)
    print(f"Instruction: {instruction}")
    # An empty input is shown as a placeholder instead of a blank line.
    print(f"Input: {task_input or '(empty)'}")
    print("\n--- Generated Output ---\n")

    output = generate_cogex_output(instruction, task_input)
    print(output)
    print()

## 4. Format Compliance Check

In [None]:
def check_output_format(output: str) -> dict:
    """Check whether generated text follows the COGEX format.

    All checks are plain substring / regex heuristics on ``output``.

    Args:
        output: Text to inspect.

    Returns:
        A dict with, in this order: the boolean substring checks
        (``has_code_section``, ``has_output_section``, ``has_function_def``,
        ``has_function_call``, ``has_return_dict``, ``has_docstring``),
        ``has_undefined_functions`` (bool), ``undefined_functions`` (sorted
        list of names) and ``all_passed`` (true when the first five substring
        checks hold; the docstring and undefined-function checks are
        informational only).
    """
    import builtins
    import re

    checks = {
        'has_code_section': '### Code' in output,
        'has_output_section': '### Output' in output,
        'has_function_def': 'def ' in output,
        'has_function_call': '>>>' in output,
        'has_return_dict': "'answer'" in output or '"answer"' in output,
        'has_docstring': '"""' in output,
        # Placeholder so this key keeps its position ahead of the keys added
        # below; the real value is filled in after the regex scan.
        'has_undefined_functions': False,
    }

    # Names called on the right-hand side of an assignment ("x = name(")
    # versus names introduced by a "def name(" in the same text.
    func_calls = set(re.findall(r'= (\w+)\(', output))
    func_defs = set(re.findall(r'def (\w+)\(', output))
    # Exclude every Python builtin (not just a hand-picked few) so calls such
    # as sorted()/sum()/max() are not reported, and sort the result so the
    # list is deterministic across runs.
    undefined = sorted(func_calls - func_defs - set(dir(builtins)))
    checks['has_undefined_functions'] = bool(undefined)
    checks['undefined_functions'] = undefined

    required = (
        'has_code_section',
        'has_output_section',
        'has_function_def',
        'has_function_call',
        'has_return_dict',
    )
    checks['all_passed'] = all(checks[name] for name in required)

    return checks

In [None]:
# Run format compliance on test cases. Each case is generated again here, so
# the text being checked is a fresh generation.
print("Format Compliance Results")
print("=" * 70)

format_results = []
for test_number, case in enumerate(test_cases, start=1):
    output = generate_cogex_output(case['instruction'], case['input'])
    result = check_output_format(output)
    format_results.append(result)

    status = "PASS" if result['all_passed'] else "FAIL"
    print(f"\nTest {test_number}: {status}")
    print(f"  Instruction: {case['instruction'][:50]}...")
    for key, value in result.items():
        # The overall verdict and the name list are reported on their own lines.
        if key in ('all_passed', 'undefined_functions'):
            continue
        symbol = "✓" if value else "✗"
        print(f"  {symbol} {key}")
    if result['undefined_functions']:
        print(f"  Undefined functions: {result['undefined_functions']}")

# Summary: number of cases whose required checks all held.
n_passed = len([r for r in format_results if r['all_passed']])
print(f"\n{'=' * 70}")
print(f"Format compliance: {n_passed}/{len(format_results)} passed")

## 5. Validation Perplexity

In [None]:
# Load validation data
print(f"Loading validation data from {VAL_DATA_PATH}...")
# Read as UTF-8 explicitly so parsing does not depend on the platform's
# default locale encoding.
with open(VAL_DATA_PATH, 'r', encoding='utf-8') as f:
    val_data = json.load(f)
print(f"Loaded {len(val_data)} examples")

In [None]:
def compute_perplexity(examples: list, max_examples: int = 100) -> "tuple[float, float]":
    """Compute token-weighted perplexity on validation examples.

    Uses the notebook-level ``model`` and ``tokenizer``. Each example is
    scored on its own (no batching), truncated to 2048 tokens.

    Args:
        examples: Sequence of mappings; each is read via ``ex['text']``.
        max_examples: Only the first ``max_examples`` items are scored.

    Returns:
        ``(perplexity, avg_loss)`` where ``avg_loss`` is the mean loss per
        predicted token and ``perplexity`` is ``exp(avg_loss)``.

    Raises:
        ValueError: If there is nothing to score (no examples, or every
            example tokenizes to fewer than two tokens).
    """
    import math

    subset = examples[:max_examples]
    total_loss = 0.0
    total_tokens = 0

    model.eval()

    for i, ex in enumerate(subset):
        if i % 20 == 0:
            print(f"  Processing {i}/{len(subset)}...")

        inputs = tokenizer(
            ex['text'],
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # With labels=input_ids the causal-LM loss is the mean over next-token
        # predictions, and a sequence of T tokens yields T - 1 of them. Weight
        # by that count so the average across examples is per predicted token.
        n_predicted = inputs['input_ids'].shape[1] - 1
        if n_predicted < 1:
            # Nothing to predict; the loss would not be a meaningful number.
            continue

        with torch.no_grad():
            loss = model(**inputs, labels=inputs['input_ids']).loss.item()

        total_loss += loss * n_predicted
        total_tokens += n_predicted

    if total_tokens == 0:
        raise ValueError(
            "No scorable tokens: `examples` is empty or every example has "
            "fewer than two tokens."
        )

    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)

    return perplexity, avg_loss

In [None]:
print("Computing validation perplexity (first 100 examples)...")
ppl, avg_loss = compute_perplexity(val_data, max_examples=100)

# Print the two metrics framed by horizontal rules, one line per print call.
report_lines = [
    f"\n{'=' * 50}",
    "Validation Metrics",
    f"{'=' * 50}",
    f"Average Loss: {avg_loss:.4f}",
    f"Perplexity: {ppl:.2f}",
    f"{'=' * 50}",
]
for report_line in report_lines:
    print(report_line)

## 6. Compare with Base Model (Optional)

In [None]:
# Uncomment to compare with base model (without LoRA)
# This helps verify the finetuning improved performance

# print("\nLoading base model WITHOUT adapter for comparison...")
# base_only = AutoModelForCausalLM.from_pretrained(
#     BASE_MODEL_PATH,
#     local_files_only=True,
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
# )
# base_only.eval()

# # Temporarily swap model for perplexity computation
# _model_backup = model
# model = base_only

# print("Computing base model perplexity...")
# base_ppl, base_loss = compute_perplexity(val_data, max_examples=100)

# model = _model_backup
# del base_only
# torch.cuda.empty_cache()

# print(f"\nComparison:")
# print(f"  Base model PPL: {base_ppl:.2f}")
# print(f"  Finetuned PPL:  {ppl:.2f}")
# print(f"  Improvement:    {(base_ppl - ppl) / base_ppl * 100:.1f}%")

## 7. Summary

In [None]:
# Checklist marks are decided up front: perplexity gets a "?" rather than a
# failure mark when it is not below 50; format compliance is pass/fail.
perplexity_mark = '✓' if ppl < 50 else '?'
format_mark = '✓' if n_passed == len(format_results) else '✗'

print("\n" + "=" * 70)
print("VALIDATION SUMMARY")
print("=" * 70)
print(f"\nModel: {BASE_MODEL}")
print(f"Adapter: {ADAPTER_PATH}")
print("\nMetrics:")
print(f"  Validation Perplexity: {ppl:.2f}")
print(f"  Format Compliance: {n_passed}/{len(format_results)} passed")
print("\nChecklist:")
print(f"  [{perplexity_mark}] Perplexity reasonable (< 50 expected for finetuned model)")
print(f"  [{format_mark}] All format checks pass")
# Left unchecked on purpose: this item is for the human reviewer.
print("  [ ] Manual review of generated outputs looks correct")
print("\n" + "=" * 70)