# Fine‑tuning **Gemma 3 27B IT** with QLoRA on Google Colab A100  

This end‑to‑end notebook shows how to:
1. Install **[Unsloth](https://github.com/unslothai/unsloth)** for fast 4‑bit loading & LoRA training.
2. Convert our **Teach** human‑labelled CSV + framework JSON into **ChatML `.jsonl`**.
3. Fine‑tune with **Transformers ✕ PEFT ✕ bitsandbytes** (QLoRA under the hood).
4. **Advanced Checkpointing**: Save every 50 steps, keep 5 checkpoints, auto-resume training.
5. Push the LoRA adapter **and** a merged BF16 model to the Hugging Face Hub.

> **Colab requirements** · A100‑40 GB tier (Pro+).  
> The Gemma 3 27B model loads in ∼19 GB VRAM at 4‑bit, leaving room for LoRA params.

**Key Features:**
- ✅ **Built-in Checkpoint Support** with automatic resume
- ✅ **Connection Keepalive** to prevent disconnections
- ✅ **Optimized for A100** with BF16 precision
- ✅ **Memory Efficient** with gradient checkpointing

References: Unsloth documentation, HuggingFace Transformers guide, Google Gemma 3 model card.

In [None]:
# 🔌 0. Setup Connection Keepalive (Run this first!)
from IPython.display import display, HTML, JavaScript

# JavaScript to prevent Colab disconnection
keepalive_js = """
function ClickConnect(){
    console.log("Colab keepalive: Clicking connect button"); 
    // Look up the connect button once and click only if every piece exists
    var button = document.querySelector("colab-connect-button");
    if(button && button.shadowRoot){
        var connectBtn = button.shadowRoot.querySelector("#connect");
        if(connectBtn) connectBtn.click();
    }
}

// Click every 60 seconds (60000 ms)
var keepAliveInterval = setInterval(ClickConnect, 60000);
console.log("Colab keepalive started - will click connect every 60 seconds");

// Function to stop keepalive
function stopKeepAlive(){
    clearInterval(keepAliveInterval);
    console.log("Colab keepalive stopped");
}
"""

display(JavaScript(keepalive_js))
print("✅ Connection keepalive activated - the page will click Connect every 60 seconds")
print("💡 To stop keepalive later, run: stopKeepAlive() in browser console")

In [None]:
# ⚙️ 1. Environment check
import os, subprocess, json, torch, sys
print(f"Torch version: {torch.__version__}")
!nvidia-smi -L
!nvidia-smi --query-gpu=name,memory.total --format=csv

# Verify A100 GPU
gpu_info = !nvidia-smi --query-gpu=name --format=csv,noheader,nounits
if 'A100' not in str(gpu_info):
    print("⚠️ WARNING: A100 GPU not detected. This notebook is optimized for A100.")
    print(f"Current GPU: {gpu_info}")
else:
    print("✅ A100 GPU detected - ready for Gemma 3 27B training!")

Torch version: 2.6.0+cu124
GPU 0: NVIDIA A100-SXM4-40GB (UUID: GPU-192476bc-ac85-af36-d7bd-055f87277638)
name, memory.total [MiB]
NVIDIA A100-SXM4-40GB, 40960 MiB


In [None]:
%%bash
set -e  # stop on first failure

# 0️⃣  Clean out leftovers that would confuse the resolver
pip uninstall -y unsloth torch torchvision torchaudio fsspec protobuf gcsfs || true
pip cache purge

# 1️⃣  Core PyTorch stack for CUDA 12.4 (current Colab backend, 2025-06)
pip install --upgrade pip
pip install --extra-index-url https://download.pytorch.org/whl/cu124 \
           torch==2.6.0 torchvision==0.21.0+cu124 torchaudio==2.6.0+cu124 triton

# 2️⃣  Unsloth compiled **for Torch 2.6 + CUDA 12.4**
pip install "unsloth[cu124-torch260] @ git+https://github.com/unslothai/unsloth.git"

# 3️⃣  Your usual helpers, but tell pip **not to touch Torch again**
pip install --no-deps peft transformers bitsandbytes datasets trl accelerate huggingface_hub

# 4️⃣  Bring the remaining stragglers in line
pip install fsspec==2025.3.2 protobuf==5.29.1 gcsfs==2025.3.2 tensorflow-metadata -U

# 5️⃣  Install additional dependencies for advanced training
pip install wandb psutil humanize GPUtil

In [None]:
# 🔑 2. Login to the Hugging Face Hub (one‑time per session)
from huggingface_hub import notebook_login
notebook_login()

## 3 Data upload / mounting
Place the following files in **Google Drive** → `/MyDrive/teach_data/`:
```
high_Teach_1.json                # framework spec
peru_cleaned_transcripts.csv     # human labels + transcripts
```
Alternatively, upload them directly to the Colab *Files* pane and adjust paths below.

In [None]:
# ▶️ 3a (Optional) Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# 📂 3b Define paths
DATA_DIR = "/content/drive/MyDrive/teach_data"   # change if not using Drive
CSV_PATH = f"{DATA_DIR}/peru_cleaned_transcripts.csv"
FRAME_PATH = f"{DATA_DIR}/high_Teach_1.json"

assert os.path.exists(CSV_PATH), "❌ CSV not found – check path!"
assert os.path.exists(FRAME_PATH), "❌ Framework JSON not found!"
print("✅ Data files found successfully!")

## 4 Materialise ChatML `.jsonl`
Every training example is a triple of *system / user / assistant* messages, stored as one JSON object per line.  
The prompt builder below mirrors the one used in the evaluation pipeline.
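
As a small illustration of the record layout, the sketch below builds one example and checks that it survives a JSONL round trip. The component name and transcript are made up for the demonstration, and `build_example` is a hypothetical helper, not part of the evaluation repo.

```python
import json

def build_example(system: str, user: str, assistant: str) -> dict:
    """Wrap one (system, user, assistant) triple in the messages layout."""
    return {
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ]
    }

# Hypothetical component and transcript, for illustration only.
example = build_example(
    system="You are an expert Teach framework scorer.",
    user="### Component: Example Component\n### Allowed labels: Y, N\n\nTranscript:\nBuenos días, niños.\n",
    assistant=json.dumps({"score": "Y"}),
)

# One JSON object per line; ensure_ascii=False keeps accented characters readable.
line = json.dumps(example, ensure_ascii=False)
assert "\n" not in line  # embedded newlines are escaped, so the record stays on one line

# Round-trip check: the line parses back to the same structure.
restored = json.loads(line)
roles = [m["role"] for m in restored["messages"]]
print(roles)
```

Keeping the assistant content as a JSON string means the training target is exactly the text the model is expected to emit.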

In [None]:
# 🛠️ 4. Build train/val JSONL
import pandas as pd
import json
import random
import math
import tqdm
import numpy as np

RNG = random.Random(42)
TRAIN_OUT = "train_chatml.jsonl"
VAL_OUT   = "val_chatml.jsonl"

df = pd.read_csv(CSV_PATH, dtype=str)
with open(FRAME_PATH, encoding="utf-8") as fh:  # close the file handle after reading
    framework = json.load(fh)

print(f"📊 Loaded {len(df)} rows from CSV")
print(f"🎯 Framework has {len(framework['structure']['domains'])} domains")

# Extract IDs
clip_info = df['School_Clip'].str.extract(r'(?P<base_id>\d{6,7})\s*Clip\s*(?P<clip>[12])')
df['base_id'] = clip_info['base_id']
df['clip_number'] = clip_info['clip'].map({'1':'first','2':'last'})

# Helper: prompt builder mirroring evaluation pipeline
def make_prompt(comp, transcript):
    system = "You are an expert Teach framework scorer. Reply ONLY with valid JSON: {\"score\":<label>,\"analysis\":<text>}"
    user = (
        f"### Component: {comp['name']}\n"
        f"### Allowed labels: {', '.join(map(str, comp.get('scoreList',['Y','N'])))}\n\n"
        f"Transcript:\n{transcript}\n"
    )
    return system, user

rows_train, rows_val = [], []

for _, row in tqdm.tqdm(df.iterrows(), total=len(df), desc="Processing examples"):
    transcript = (
        row.get('First Audio Transcript Text', '') if row['clip_number']=='first'
        else row.get('Last Audio Transcript Text', '')
    )
    if not isinstance(transcript, str) or len(transcript.strip())==0:
        continue  # skip blank transcripts
    for domain in framework['structure']['domains']:
        for comp in domain['components']:
            cname = comp['name']
            if cname not in row or pd.isna(row[cname]):
                continue
            label = str(row[cname]).strip()
            sys_msg, user_msg = make_prompt(comp, transcript)
            assistant_msg = json.dumps({"score": label }, ensure_ascii=False)
            example = {
                "messages": [
                    {"role": "system", "content": sys_msg},
                    {"role": "user", "content": user_msg},
                    {"role": "assistant", "content": assistant_msg}
                ]
            }
            # 80‑20 split on the fly
            (rows_val if RNG.random()<0.2 else rows_train).append(example)

# Write files
for path, rows in [(TRAIN_OUT, rows_train),(VAL_OUT,rows_val)]:
    with open(path,'w',encoding='utf-8') as f:
        for ex in rows:
            f.write(json.dumps(ex, ensure_ascii=False)+'\n')
    print(f"✅ Wrote {len(rows):,} examples → {path}")

print(f"\n📈 Dataset summary:")
print(f"  • Training examples: {len(rows_train):,}")
print(f"  • Validation examples: {len(rows_val):,}")
print(f"  • Total examples: {len(rows_train + rows_val):,}")
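
The split above draws a random number per example, so different components scored on the same transcript can end up in both files. If that overlap matters for your evaluation, one alternative is to assign whole transcripts to a split. The sketch below is an assumption-laden illustration: `assign_split` is a hypothetical helper, and the group key is assumed to be something like `base_id` plus `clip_number`.

```python
import hashlib

def assign_split(group_id: str, val_fraction: float = 0.2, salt: str = "42") -> str:
    """Deterministically map a group ID to 'train' or 'val'.

    All examples that share group_id receive the same answer, so one
    transcript never contributes to both files.
    """
    digest = hashlib.sha256(f"{salt}:{group_id}".encode("utf-8")).digest()
    # Interpret the first 8 bytes as an integer in [0, 2**64) and scale to [0, 1).
    bucket = int.from_bytes(digest[:8], "big") / 2**64
    return "val" if bucket < val_fraction else "train"

# Hypothetical IDs in a "<base_id>_<clip>" shape, for illustration only.
groups = [f"{100000 + i}_{clip}" for i in range(500) for clip in ("first", "last")]
splits = {g: assign_split(g) for g in groups}
val_share = sum(s == "val" for s in splits.values()) / len(splits)
print(f"validation share: {val_share:.2%}")
```

Because the assignment depends only on the ID and the salt, rerunning the notebook reproduces the same split without storing any state.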

## 5 Model Loading with Advanced Configuration
* **Gemma 3 27B IT** - Google's instruction-tuned, multimodal 27B model  
* **BF16** compute precision, supported natively on the A100  
* **4-bit quantization** for memory efficiency  
* **LoRA r=64, α=16** → good trade‑off for 27B models
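
To make the rank and alpha choice concrete, here is a rough sketch of the arithmetic behind a single LoRA pair. The layer shape is hypothetical and only illustrates the formula; the real sizes come from the model's config. The `alpha / r` scaling is the standard LoRA convention.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Parameters added by one LoRA pair: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)

r, alpha = 64, 16

# Standard LoRA scales the low-rank update by alpha / r.
scaling = alpha / r
print(f"scaling = {scaling}")  # scaling = 0.25

# Hypothetical layer shape for illustration; read the real sizes from model.config.
d_in, d_out = 4096, 4096
added = lora_param_count(d_in, d_out, r)
full = d_in * d_out
print(f"LoRA adds {added:,} params next to {full:,} frozen weights")
```

With α fixed at 16, raising the rank lowers the scaling factor, so rank and alpha are usually tuned together rather than independently.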

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
import torch
import gc

# 🚀 Model configuration - Gemma 3 27B IT
MODEL_NAME = "google/gemma-3-27b-it"
print(f"🎯 Loading model: {MODEL_NAME}")

# BitsAndBytes configuration for 4-bit quantization
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # BF16 for A100
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

print("📥 Loading model in 4-bit precision...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_cfg,
    trust_remote_code=True,
    attn_implementation="eager"  # The setup cell does not install flash-attn; eager attention needs no extra package
)

model = prepare_model_for_kbit_training(model)

# LoRA configuration for 27B model
peft_cfg = LoraConfig(
    r=64,                    # Higher rank for 27B model
    lora_alpha=16,           # Conservative alpha
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","down_proj","up_proj"]
)
model = get_peft_model(model, peft_cfg)

# Tokenizer setup
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:  # Gemma ships its own pad token; fall back to EOS only if it is missing
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Recommended for training

# Model statistics
total_params = model.num_parameters()
trainable_params = model.num_parameters(only_trainable=True)
print(f"\n📊 Model Statistics:")
print(f"  • Total parameters: {total_params:,}")
print(f"  • Trainable parameters: {trainable_params:,}")
print(f"  • Trainable %: {100*trainable_params/total_params:.4f}%")

# Memory cleanup
gc.collect()
torch.cuda.empty_cache()

print("✅ Model loaded successfully with LoRA adapters!")

In [None]:
# 📚 6. Load dataset into 🤗 Datasets
from datasets import load_dataset

train_ds = load_dataset('json', data_files=TRAIN_OUT, split='train')
val_ds   = load_dataset('json', data_files=VAL_OUT,   split='train')  # a single file is always exposed as the 'train' split

print(f"📊 Dataset loaded:")
print(f"  • Training examples: {len(train_ds):,}")
print(f"  • Validation examples: {len(val_ds):,}")
print(f"  • Example message structure: {list(train_ds[0]['messages'][0].keys())}")

# Show a sample
print(f"\n📝 Sample training example:")
sample = train_ds[0]['messages']
print(f"System: {sample[0]['content'][:100]}...")
print(f"User: {sample[1]['content'][:100]}...")
print(f"Assistant: {sample[2]['content']}")

## 7 Advanced Training Configuration

### 🎯 Key Features:
- **Advanced Checkpointing**: Save every 50 steps, keep 5 most recent
- **Auto-Resume**: Automatically resume from last checkpoint
- **Memory Optimization**: Gradient checkpointing + BF16
- **A100 Optimized**: Settings sized for a 40 GB A100 GPU
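
The checkpoint rotation described above can be pictured with a few lines of plain Python. This is a simplified model, not the Trainer's implementation: `kept_checkpoints` is a hypothetical helper, and it ignores the extra retention of the best checkpoint when `load_best_model_at_end` is set.

```python
def kept_checkpoints(current_step: int, save_steps: int = 50, limit: int = 5) -> list[str]:
    """Names of the checkpoint folders left on disk after `current_step`.

    Simplified model: a checkpoint is written every `save_steps` steps and
    only the `limit` most recent ones are kept.
    """
    saved = range(save_steps, current_step + 1, save_steps)
    return [f"checkpoint-{s}" for s in list(saved)[-limit:]]

def steps_compatible(save_steps: int, eval_steps: int) -> bool:
    """load_best_model_at_end expects save_steps to be a multiple of eval_steps."""
    return save_steps % eval_steps == 0

print(kept_checkpoints(420))
# ['checkpoint-200', 'checkpoint-250', 'checkpoint-300', 'checkpoint-350', 'checkpoint-400']
print(steps_compatible(50, 100), steps_compatible(50, 50))  # False True
```

With these settings, a disconnect costs at most 49 optimizer steps of work, at the price of more frequent disk writes.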

In [None]:
# 🚀 7. Configure Advanced SFTTrainer
from transformers import TrainingArguments
from trl import SFTTrainer, SFTConfig
import os

# Training configuration
MAX_LEN = 2048  # Sequence length
OUTPUT_DIR = "./gemma3-27b-teach-checkpoints"
HF_REPO = "your-username/gemma3-27b-teach-lora"  # Change this!

# Ensure output directory exists
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Advanced training arguments with checkpointing.
# SFTConfig extends TrainingArguments, so the SFT-specific options
# (sequence length, packing) are set on the same object.
from dataclasses import fields

# TRL renamed the sequence-length field from `max_seq_length` to `max_length`;
# pick whichever name the installed release defines.
LEN_FIELD = "max_length" if "max_length" in {f.name for f in fields(SFTConfig)} else "max_seq_length"

training_args = SFTConfig(
    # === Output & Checkpointing ===
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=False,  # Don't overwrite existing checkpoints
    
    # === Advanced Checkpointing Configuration ===
    save_strategy="steps",           # Save based on steps, not epochs
    save_steps=50,                   # Save every 50 steps (frequent for 6k examples)
    save_total_limit=5,              # Keep only 5 most recent checkpoints
    # Resuming is requested in trainer.train(), not through these arguments
    
    # === Training Schedule ===
    num_train_epochs=3,
    max_steps=-1,                    # Use epochs instead of max_steps
    
    # === Learning Configuration ===
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    
    # === Memory & Performance Optimization ===
    per_device_train_batch_size=1,   # Conservative for 27B model
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,    # Effective batch size: 8
    gradient_checkpointing=True,      # Essential for memory efficiency
    dataloader_num_workers=2,         # Parallel data loading
    dataloader_pin_memory=True,       # Speed up data transfer
    
    # === Precision & Hardware ===
    bf16=True,                        # BF16 for A100 (better than FP16)
    fp16=False,                       # Disable FP16 when using BF16
    tf32=True,                        # Enable TF32 for A100
    
    # === Evaluation & Logging ===
    eval_strategy="steps",
    eval_steps=50,                    # Same as save_steps: load_best_model_at_end needs save_steps to be a multiple of eval_steps
    logging_strategy="steps",
    logging_steps=10,                 # Log every 10 steps
    
    # === Model Selection ===
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    
    # === Regularization ===
    weight_decay=0.01,
    max_grad_norm=1.0,
    
    # === Hub Integration ===
    push_to_hub=False,                # We'll push manually later
    report_to="tensorboard",          # Logging backend
    
    # === Stability ===
    seed=42,
    data_seed=42,
    
    # === Advanced Features ===
    remove_unused_columns=False,      # Keep all columns for SFT
    group_by_length=True,             # Group similar lengths for efficiency
    ddp_find_unused_parameters=False, # DDP optimization
    
    # === SFT-specific ===
    # The "messages" column is conversational, so TRL renders it with the
    # tokenizer's chat template; no dataset_text_field is set.
    packing=True,                     # Pack multiple conversations
    dataset_num_proc=2,               # Parallel dataset processing
    **{LEN_FIELD: MAX_LEN},           # Sequence length cap
)

print(f"🎯 Training Configuration:")
print(f"  • Model: {MODEL_NAME}")
print(f"  • Output dir: {OUTPUT_DIR}")
print(f"  • Max sequence length: {MAX_LEN}")
print(f"  • Batch size (effective): {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  • Learning rate: {training_args.learning_rate}")
print(f"  • Epochs: {training_args.num_train_epochs}")
print(f"  • Save every: {training_args.save_steps} steps")
print(f"  • Keep checkpoints: {training_args.save_total_limit}")
print(f"  • Precision: {'BF16' if training_args.bf16 else 'FP16' if training_args.fp16 else 'FP32'}")
print(f"  • Resume: latest checkpoint in the output dir, requested in trainer.train()")

In [None]:
# 🏋️ 8. Initialize SFTTrainer with Advanced Configuration
print("🚀 Initializing SFTTrainer with advanced checkpointing...")

trainer = SFTTrainer(
    model=model,
    args=training_args,               # Carries sequence length, packing and dataset_num_proc
    train_dataset=train_ds,
    eval_dataset=val_ds,
    processing_class=tokenizer,
)

# Check for existing checkpoints
checkpoints = [d for d in os.listdir(OUTPUT_DIR) if d.startswith('checkpoint-')] if os.path.exists(OUTPUT_DIR) else []
if checkpoints:
    latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('-')[1]))
    checkpoint_path = os.path.join(OUTPUT_DIR, latest_checkpoint)
    print(f"📂 Found existing checkpoints: {len(checkpoints)}")
    print(f"📍 Latest checkpoint: {latest_checkpoint}")
    print(f"🔄 Training will resume from: {checkpoint_path}")
else:
    print(f"🆕 No existing checkpoints found - starting fresh training")

print("✅ SFTTrainer initialized successfully!")
print(f"\n📊 Training Overview:")
print(f"  • Total training examples: {len(train_ds):,}")
print(f"  • Total validation examples: {len(val_ds):,}")
print(f"  • Estimated training steps (upper bound, before packing): {int(len(train_ds) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * training_args.num_train_epochs):,}")
print(f"  • Checkpoints will be saved every: {training_args.save_steps} steps")
print(f"  • Evaluation every: {training_args.eval_steps} steps")

## 8 Training with Advanced Monitoring

### 🚀 Ready to train with:
- ✅ **Auto-resume** from checkpoints
- ✅ **Memory optimization** for 27B model
- ✅ **A100-optimized** settings
- ✅ **Connection keepalive** active

**Training will automatically resume if interrupted!**
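
Resuming depends on locating the newest `checkpoint-<step>` folder, and the step has to be compared as a number: sorted as text, `checkpoint-950` would rank above `checkpoint-1000`. The sketch below shows one way to do the lookup with the standard library; `latest_checkpoint` is a hypothetical helper. Its result, a path or `None`, is the kind of value `trainer.train(resume_from_checkpoint=...)` accepts, and `transformers.trainer_utils.get_last_checkpoint` offers a comparable lookup.

```python
import os
import re
import tempfile
from typing import Optional

_CKPT = re.compile(r"^checkpoint-(\d+)$")

def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the highest-numbered checkpoint folder, or None."""
    if not os.path.isdir(output_dir):
        return None
    steps = []
    for name in os.listdir(output_dir):
        match = _CKPT.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return os.path.join(output_dir, f"checkpoint-{max(steps)}")

# Demonstration in a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    first_run = latest_checkpoint(tmp)          # nothing saved yet
    for step in (50, 100, 950, 1000):
        os.makedirs(os.path.join(tmp, f"checkpoint-{step}"))
    resumed = os.path.basename(latest_checkpoint(tmp))

print(first_run, resumed)  # None checkpoint-1000
```

Returning `None` on a fresh run lets the same cell serve both the first launch and every later restart.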

In [None]:
# 🏋️ 9. Start/Resume Training with Advanced Monitoring
import time
import psutil
import GPUtil
from datetime import datetime

def print_system_stats():
    """Print current system and GPU statistics"""
    # GPU stats
    gpus = GPUtil.getGPUs()
    if gpus:
        gpu = gpus[0]
        print(f"🖥️  GPU: {gpu.memoryUsed:.0f}MB / {gpu.memoryTotal:.0f}MB ({gpu.memoryUtil*100:.1f}%) | Util: {gpu.load*100:.1f}%")
    
    # RAM stats
    ram = psutil.virtual_memory()
    print(f"💾 RAM: {ram.used/1024**3:.1f}GB / {ram.total/1024**3:.1f}GB ({ram.percent:.1f}%)")

print(f"🚀 Starting training at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"🎯 Model: {MODEL_NAME}")
print(f"📊 Dataset: {len(train_ds):,} train, {len(val_ds):,} val examples")
print(f"⚙️  Configuration: BF16, LoRA r={peft_cfg.r}, α={peft_cfg.lora_alpha}")
print_system_stats()
print("\n" + "="*60)
print("🏃‍♂️ TRAINING STARTED - Auto-resume enabled")
print("🔄 Checkpoints saved every 50 steps to:", OUTPUT_DIR)
print("💾 Only 5 most recent checkpoints will be kept")
print("📊 Tensorboard logs available in: runs/")
print("="*60 + "\n")

try:
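    # Train with automatic resume: continue from the newest checkpoint if one exists,
    # otherwise start fresh (resume_from_checkpoint=True fails when none is found).
    from transformers.trainer_utils import get_last_checkpoint
    last_checkpoint = get_last_checkpoint(OUTPUT_DIR)  # None when no checkpoint-* folder exists
    trainer.train(resume_from_checkpoint=last_checkpoint)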
    
    print("\n" + "="*60)
    print("🎉 TRAINING COMPLETED SUCCESSFULLY!")
    print(f"⏰ Finished at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print_system_stats()
    print("="*60)
    
except KeyboardInterrupt:
    print("\n⚠️ Training interrupted by user")
    print("💾 Latest checkpoint saved - you can resume training later")
    
except Exception as e:
    print(f"\n❌ Training error: {e}")
    print("💾 Check for saved checkpoints in:", OUTPUT_DIR)
    raise

# Save final adapter
ADAPTER_DIR = "gemma3-27b-teach-lora-final"
print(f"\n💾 Saving final LoRA adapter to: {ADAPTER_DIR}")
trainer.model.save_pretrained(ADAPTER_DIR)
tokenizer.save_pretrained(ADAPTER_DIR)
print(f"✅ Final adapter saved successfully!")

# Memory cleanup
del trainer
gc.collect()
torch.cuda.empty_cache()
print("🧹 Memory cleanup completed")

## 9 Model Upload & Deployment

Upload both the LoRA adapter and merged model to Hugging Face Hub.
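
Before merging, it can help to check that the local disk has room for the merged weights. The sketch below is a back-of-the-envelope estimate under stated assumptions: a nominal 27e9 parameters stored at 2 bytes each, and a 20% safety margin. Both helper functions are hypothetical.

```python
import shutil

def merged_size_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Approximate size of the saved weights in GB (10**9 bytes)."""
    return num_params * bytes_per_param / 1e9

def has_room(path: str, needed_gb: float, margin: float = 1.2) -> bool:
    """True if the filesystem holding `path` has needed_gb * margin free."""
    free_gb = shutil.disk_usage(path).free / 1e9
    return free_gb >= needed_gb * margin

# Assumption: a nominal 27e9 parameters stored as 16-bit values (2 bytes each).
needed = merged_size_gb(27e9)
print(f"merged weights: ~{needed:.0f} GB")   # merged weights: ~54 GB
print("enough free disk:", has_room(".", needed))
```

The LoRA adapter alone is far smaller than the merged model, which is one reason to upload it first.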

In [None]:
# ☁️ 10a. Push LoRA Adapter to Hub
from huggingface_hub import HfApi, create_repo

# Update this with your username!
HF_USERNAME = "your-username"  # CHANGE THIS!
LORA_REPO = f"{HF_USERNAME}/gemma3-27b-teach-lora"

print(f"☁️ Uploading LoRA adapter to: {LORA_REPO}")

try:
    # Create repository
    create_repo(repo_id=LORA_REPO, exist_ok=True, repo_type="model")
    
    # Upload adapter
    api = HfApi()
    api.upload_folder(
        folder_path=ADAPTER_DIR,
        repo_id=LORA_REPO,
        commit_message="Add Gemma 3 27B LoRA adapter for Teach framework"
    )
    
    print(f"✅ LoRA adapter uploaded successfully!")
    print(f"🔗 Model available at: https://huggingface.co/{LORA_REPO}")
    
except Exception as e:
    print(f"❌ Upload failed: {e}")
    print(f"💡 Make sure to update HF_USERNAME and check your permissions")

In [None]:
# ➕ 10b. (Optional) Merge LoRA and Upload Full Model
# WARNING: This creates a large model file (~50GB) - ensure you have space!

MERGE_MODEL = input("🤔 Do you want to merge and upload the full model? (y/N): ").lower().strip()

if MERGE_MODEL in ['y', 'yes']:
    print("🔄 Loading and merging LoRA with base model...")
    print("⚠️ This will take several minutes and use significant memory")
    
    try:
        # Load the trained model
        from peft import PeftModel
        
        # Load the unquantized base model in BF16 for merging
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            torch_dtype=torch.bfloat16
        )
        
        # Load and merge LoRA
        peft_model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
        merged_model = peft_model.merge_and_unload()
        
        # Save merged model
        MERGED_DIR = "gemma3-27b-teach-merged"
        print(f"💾 Saving merged model to: {MERGED_DIR}")
        merged_model.save_pretrained(MERGED_DIR, safe_serialization=True)  # write .safetensors files
        tokenizer.save_pretrained(MERGED_DIR)
        
        # Upload merged model
        MERGED_REPO = f"{HF_USERNAME}/gemma3-27b-teach-merged"
        print(f"☁️ Uploading merged model to: {MERGED_REPO}")
        
        create_repo(repo_id=MERGED_REPO, exist_ok=True, repo_type="model")
        api.upload_folder(
            folder_path=MERGED_DIR,
            repo_id=MERGED_REPO,
            commit_message="Add merged Gemma 3 27B model fine-tuned for Teach framework"
        )
        
        print(f"✅ Merged model uploaded successfully!")
        print(f"🔗 Model available at: https://huggingface.co/{MERGED_REPO}")
        
        # Cleanup
        del base_model, peft_model, merged_model
        gc.collect()
        torch.cuda.empty_cache()
        
    except Exception as e:
        print(f"❌ Merge/upload failed: {e}")
        print(f"💡 You can still use the LoRA adapter with the base model")
        
else:
    print("⏭️ Skipping model merge - LoRA adapter is sufficient for most use cases")

## 10 Model Testing & Inference

Test the fine-tuned model with some sample inputs.
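
Model replies do not always arrive as clean JSON, so a tolerant parser is useful when checking outputs. The sketch below is one possible approach, not the evaluation pipeline's parser: `extract_score` is a hypothetical helper that pulls the first JSON object out of the text and accepts the score only if it is one of the allowed labels.

```python
import json
import re
from typing import Optional

def extract_score(text: str, allowed: set[str]) -> Optional[str]:
    """Return the 'score' value from the JSON object in `text`.

    Returns None when no object parses or the score is not an allowed label.
    """
    # Grab the span from the first "{" to the last "}"; tolerates code fences or chatter around it.
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if match is None:
        return None
    try:
        payload = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    if not isinstance(payload, dict):
        return None
    score = str(payload.get("score", "")).strip()
    return score if score in allowed else None

labels = {"Y", "N", "N/A"}
print(extract_score('{"score": "Y"}', labels))                      # Y
print(extract_score('```json\n{"score": "N/A"}\n```', labels))      # N/A
print(extract_score('The answer is yes.', labels))                  # None
print(extract_score('{"score": "Maybe"}', labels))                  # None
```

Rejecting labels outside the allowed set keeps malformed generations from being counted as valid scores.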

In [None]:
# 🔬 11. Quick Inference Test
from transformers import pipeline
from peft import PeftModel
import json

print("🧪 Testing the fine-tuned model...")

try:
    # Load the fine-tuned model for inference
    base_model_inference = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_cfg  # Already requests 4-bit; also passing load_in_4bit raises an error
    )
    
    # Load LoRA adapter
    model_with_lora = PeftModel.from_pretrained(base_model_inference, ADAPTER_DIR)
    
    # Create pipeline (the model is already loaded and placed, so no device or dtype arguments)
    pipe = pipeline(
        "text-generation",
        model=model_with_lora,
        tokenizer=tokenizer,
    )
    
    # Test prompts
    test_cases = [
        {
            "component": "Supportive Learning Environment",
            "labels": "Y, N, N/A",
            "transcript": "Teacher greets each student by name and asks how they feel about their homework. She provides encouraging feedback and helps struggling students."
        },
        {
            "component": "Clear Learning Objectives",
            "labels": "Y, N, N/A",
            "transcript": "Today we will learn about fractions. By the end of this lesson, you should be able to add and subtract fractions with the same denominator."
        },
        {
            "component": "Student Engagement",
            "labels": "Y, N, N/A",
            "transcript": "Students are talking among themselves and not paying attention to the teacher's instructions."
        }
    ]
    
    print("\n" + "="*80)
    print("🎯 INFERENCE TESTS")
    print("="*80)
    
    for i, case in enumerate(test_cases, 1):
        # Format the prompt
        system_msg = "You are an expert Teach framework scorer. Reply ONLY with valid JSON: {\"score\":<label>,\"analysis\":<text>}"
        user_msg = f"### Component: {case['component']}\n### Allowed labels: {case['labels']}\n\nTranscript:\n{case['transcript']}\n"
        
        # Create conversation format
        conversation = [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg}
        ]
        
        # Apply chat template
        prompt = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )
        
        print(f"\n🧪 Test {i}: {case['component']}")
        print(f"📝 Transcript: {case['transcript'][:100]}...")
        
        # Generate response (greedy decoding; return only the newly generated text)
        response = pipe(
            prompt,
            max_new_tokens=150,
            do_sample=False,
            return_full_text=False,
            pad_token_id=tokenizer.eos_token_id
        )[0]['generated_text']
        
        # Gemma's chat template does not use ChatML markers, so there is nothing to split on
        assistant_response = response.strip()
        
        print(f"🤖 Response: {assistant_response}")
        
        # Try to parse as JSON
        try:
            parsed = json.loads(assistant_response)
            print(f"✅ Valid JSON - Score: {parsed.get('score', 'N/A')}")
        except (json.JSONDecodeError, AttributeError):  # not JSON, or JSON that is not an object
            print(f"⚠️ Invalid JSON format")
        
        print("-" * 60)
    
    print("\n✅ Inference testing completed!")
    
    # Cleanup
    del pipe, model_with_lora, base_model_inference
    gc.collect()
    torch.cuda.empty_cache()
    
except Exception as e:
    print(f"❌ Inference test failed: {e}")
    print("💡 The model may still be valid - check the saved adapter manually")

## 11 Summary & Next Steps

### ✅ Completed:
- **Model**: Fine-tuned Gemma 3 27B IT with LoRA (r=64, α=16)
- **Dataset**: Processed Teach framework data into ChatML format
- **Training**: Advanced checkpointing with auto-resume capability
- **Hardware**: Optimized for A100 40GB with BF16 precision
- **Monitoring**: Connection keepalive to prevent disconnections

### 🚀 Next Steps:
1. **Evaluate Performance**: Run the model on your Teach evaluation pipeline
2. **Hyperparameter Tuning**: Experiment with different LoRA ranks and learning rates
3. **Production Deployment**: Use the uploaded model for inference
4. **Continuous Training**: Resume training with additional data if needed

### 📊 Model Files:
- **LoRA Adapter**: `gemma3-27b-teach-lora-final/` (size depends on rank and target modules; the final cell prints it)
- **Checkpoints**: `./gemma3-27b-teach-checkpoints/checkpoint-*/` (auto-saved every 50 steps)
- **Hub Repository**: `https://huggingface.co/your-username/gemma3-27b-teach-lora` (if uploaded)

### 🛠️ Usage Example:
```python
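import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load base model in BF16 and let it spread across the available devices
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-27b-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)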

# Load fine-tuned adapter
model = PeftModel.from_pretrained(base_model, "your-username/gemma3-27b-teach-lora")
tokenizer = AutoTokenizer.from_pretrained("your-username/gemma3-27b-teach-lora")

# Use for Teach framework scoring!
```

**🎉 Training Complete! Your Gemma 3 27B model is ready for Teach framework evaluation.**

In [None]:
# 🧹 Final Cleanup and Summary
import shutil
from datetime import datetime

print("🧹 Performing final cleanup...")

# Memory cleanup
gc.collect()
torch.cuda.empty_cache()

# Print final summary
print("\n" + "="*80)
print("🎉 TRAINING SESSION COMPLETE")
print("="*80)
print(f"⏰ Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"🎯 Model: {MODEL_NAME}")
print(f"📊 Training examples: {len(train_ds):,}")
print(f"📊 Validation examples: {len(val_ds):,}")
print(f"⚙️ LoRA configuration: r={peft_cfg.r}, α={peft_cfg.lora_alpha}")
print(f"💾 LoRA adapter saved: {ADAPTER_DIR}/")
print(f"📁 Checkpoints saved: {OUTPUT_DIR}/")

# List saved checkpoints
if os.path.exists(OUTPUT_DIR):
    checkpoints = [d for d in os.listdir(OUTPUT_DIR) if d.startswith('checkpoint-')]
    if checkpoints:
        print(f"🔄 Available checkpoints: {len(checkpoints)}")
        latest = max(checkpoints, key=lambda x: int(x.split('-')[1]))
        print(f"📍 Latest checkpoint: {latest}")

# Print file sizes
if os.path.exists(ADAPTER_DIR):
    adapter_size = sum(os.path.getsize(os.path.join(ADAPTER_DIR, f)) 
                      for f in os.listdir(ADAPTER_DIR) 
                      if os.path.isfile(os.path.join(ADAPTER_DIR, f)))
    print(f"💾 LoRA adapter size: {adapter_size / 1024**2:.1f} MB")

print("\n🚀 Ready for evaluation and deployment!")
print("🔗 Don't forget to update the HF_USERNAME variable if uploading")
print("🎯 Use the fine-tuned model in your Teach evaluation pipeline")
print("="*80)

# Stop keepalive (optional)
stop_keepalive = input("\n🔌 Stop connection keepalive? (y/N): ").lower().strip()
if stop_keepalive in ['y', 'yes']:
    display(JavaScript("stopKeepAlive();"))
    print("🔌 Connection keepalive stopped")
else:
    print("🔌 Connection keepalive still active")

print("\n✨ All done! Happy evaluating! ✨")