# Test Notebook 04: Model Loading

**Purpose**: Verify that the model loads and can run inference

**Tests**:
1. Load Llama-3.1-8B-Instruct model
2. Check GPU memory usage
3. Test inference on 1-2 sample prompts
4. Verify output format (a|b|c|d)
5. Test with different temperature settings
6. Measure inference speed


In [1]:
import sys
sys.path.append('..')

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from src.agentic.prompts import get_prompt, format_for_llama, parse_model_output, letter_to_label
import yaml
from pathlib import Path
import time


  from .autonotebook import tqdm as notebook_tqdm


## 1. Check GPU and Configuration


In [2]:
# Check GPU: report the CUDA device, its total memory and the memory currently
# allocated, or print a warning when CUDA is not available.
BYTES_PER_GIB = 1024**3

print("GPU Availability:")
cuda_ready = torch.cuda.is_available()
print(f"  CUDA available: {cuda_ready}")
if not cuda_ready:
    print("  WARNING: No GPU available! Model loading will fail or be very slow.")
else:
    device_name = torch.cuda.get_device_name(0)
    print(f"  Device: {device_name}")
    total_gib = torch.cuda.get_device_properties(0).total_memory / BYTES_PER_GIB
    print(f"  Memory: {total_gib:.1f} GB")
    allocated_gib = torch.cuda.memory_allocated(0) / BYTES_PER_GIB
    print(f"  Current allocated: {allocated_gib:.2f} GB")

# Load config: the baseline YAML is read relative to the notebook directory and
# kept in `config` for the cells that follow.
with open('../configs/baseline.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Echo the settings this notebook relies on (display label -> config key).
print("\nBaseline Configuration:")
for label, key in (
    ("Model", 'model_name'),
    ("dtype", 'dtype'),
    ("Load in 4-bit", 'load_in_4bit'),
    ("Temperature", 'temperature'),
    ("Max new tokens", 'max_new_tokens'),
):
    print(f"  {label}: {config[key]}")


GPU Availability:
  CUDA available: True
  Device: NVIDIA GeForce RTX 4090
  Memory: 23.5 GB
  Current allocated: 0.00 GB

Baseline Configuration:
  Model: meta-llama/Meta-Llama-3.1-8B-Instruct
  dtype: bf16
  Load in 4-bit: False
  Temperature: 0.1
  Max new tokens: 8


## 2. Load Model and Tokenizer

Loading takes a few minutes and needs ~16 GB of GPU memory in bf16. The first run also downloads the model shards, which can take considerably longer.


In [3]:
# Load the tokenizer and model named in the config, time the model load and
# report GPU memory usage afterwards.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config['model_name'])

print("Loading model... (this may take a few minutes)")
load_started = time.time()

# Keyword arguments common to both loading modes; the branches below add the
# dtype (and the quantization settings when 4-bit loading is requested).
load_kwargs = {"device_map": "auto"}
if config['load_in_4bit']:
    # BitsAndBytesConfig is only imported when 4-bit loading is requested.
    from transformers import BitsAndBytesConfig
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    load_kwargs["quantization_config"] = quantization_config
    # NOTE(review): the 4-bit path always uses float16 and ignores
    # config['dtype'] -- confirm this is intended.
    load_kwargs["torch_dtype"] = torch.float16
elif config['dtype'] == 'bf16':
    load_kwargs["torch_dtype"] = torch.bfloat16
else:
    # Any dtype value other than 'bf16' falls back to float16.
    load_kwargs["torch_dtype"] = torch.float16

model = AutoModelForCausalLM.from_pretrained(config['model_name'], **load_kwargs)

load_time = time.time() - load_started

print(f"\n✅ Model loaded in {load_time:.1f}s")
if torch.cuda.is_available():
    allocated_gib = torch.cuda.memory_allocated(0) / 1024**3
    reserved_gib = torch.cuda.memory_reserved(0) / 1024**3
    print(f"GPU Memory allocated: {allocated_gib:.2f} GB")
    print(f"GPU Memory reserved: {reserved_gib:.2f} GB")


Loading tokenizer...
Loading model... (this may take a few minutes)


Downloading shards:  25%|██▌       | 1/4 [07:53<23:40, 473.51s/it]


KeyboardInterrupt: 

## 3. Test Inference on Sample Prompts

Create sample prompts and test model inference


In [None]:
# Sample test cases: each entry pairs a clinical-note snippet with the trigger
# term to classify and the label expected for it.
test_cases = [
    dict(
        note="Social History: Tob (-), EtOH - a glass of wine 1-2x/month, IVDU (-), lives with her husband.",
        trigger="IVDU",
        expected="none",
    ),
    dict(
        note="Patient admits to daily heroin use by injection. Currently using 2-3 bags per day.",
        trigger="heroin",
        expected="current",
    ),
    dict(
        note="History of cocaine abuse in 1990s, has been clean for 20 years.",
        trigger="cocaine",
        expected="past",
    ),
]

# Create prompts: build the formatted prompt for every case and print a short
# preview (no model call happens in this cell).
separator = "=" * 80
for case_number, sample in enumerate(test_cases, start=1):
    note_text = sample["note"]
    trigger_term = sample["trigger"]

    # get_prompt returns an object with "system" and "user" entries, which
    # format_for_llama combines (presumably into one prompt string).
    prompt_parts = get_prompt("status_v1", note=note_text, trigger=trigger_term)
    llama_prompt = format_for_llama(prompt_parts["system"], prompt_parts["user"])

    print(f"\nTest Case {case_number}:")
    print(separator)
    print(f"Note: {note_text[:100]}...")
    print(f"Trigger: '{trigger_term}'")
    print(f"Expected: {sample['expected']}")
    print(f"\nFormatted prompt length: {len(llama_prompt)} chars")
    print(separator)


In [None]:
# Run inference on test cases
#
# For every sample case: build the formatted prompt, generate a short
# completion, parse the answer letter, map it to a label and compare that label
# with the expected one. The wall-clock time of each generate() call is printed.
print("Running inference...")
print("=" * 80)

# Generation settings do not change between cases, so build them once.
# temperature/top_p only apply when sampling; for temperature == 0 decoding is
# greedy and the sampling parameters are not passed at all.
do_sample = config['temperature'] > 0
generation_kwargs = {
    "max_new_tokens": config['max_new_tokens'],
    "do_sample": do_sample,
}
if do_sample:
    generation_kwargs["temperature"] = config['temperature']
    generation_kwargs["top_p"] = config['top_p']

for i, case in enumerate(test_cases):
    prompt_dict = get_prompt("status_v1", note=case["note"], trigger=case["trigger"])
    formatted = format_for_llama(prompt_dict["system"], prompt_dict["user"])

    # Tokenize
    # NOTE(review): if format_for_llama already prepends a begin-of-text token,
    # the tokenizer's default special-token handling may add a second one --
    # confirm, and pass add_special_tokens=False if that is the case.
    inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    # Generate
    start_time = time.time()
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)
    inference_time = time.time() - start_time

    # Decode only the newly generated tokens. The returned sequence starts with
    # the prompt tokens, so slicing at the prompt's token length isolates the
    # completion. Slicing the decoded string by len(formatted) is unreliable:
    # skip_special_tokens=True drops special tokens, so the decoded prompt can
    # be shorter than `formatted` and the slice would cut into (or remove) the
    # response.
    new_tokens = outputs[0][prompt_token_count:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Parse output; a falsy letter means the answer could not be parsed.
    letter = parse_model_output(response)
    predicted_label = letter_to_label(letter) if letter else "PARSE_ERROR"
    is_correct = predicted_label == case['expected']

    # Display results
    print(f"\nTest Case {i+1}:")
    print(f"  Trigger: '{case['trigger']}'")
    print(f"  Expected: {case['expected']}")
    print(f"  Model output: '{response[:50]}'")
    print(f"  Parsed letter: {letter}")
    print(f"  Predicted label: {predicted_label}")
    print("  ✓ CORRECT" if is_correct else "  ✗ INCORRECT")
    print(f"  Inference time: {inference_time:.3f}s")
    print("-" * 80)

print("\n" + "=" * 80)


## ✅ Validation Checklist

**Check before proceeding:**

- Model loads successfully on GPU
- GPU memory usage reasonable (<20GB for bf16)
- Inference runs without errors on sample data
- Output parsing (a|b|c|d → labels) works correctly
- Model produces sensible outputs for clinical notes
- Inference speed acceptable (~0.2-0.5s per sample)

If all checks pass, proceed to implementing the full baseline inference engine!
