# SRL GRPO Training (Colab, A100)
End-to-end notebook for SRL GRPO training with LoRA.

**Features:**
- Saves checkpoint to disk AND Drive after each epoch
- No validation (removed to avoid failures)
- Crash-resilient: progress saved to Drive survives session kills
- **Unique run names** for parallel experiments

In [None]:
#@title 0. Mount Google Drive FIRST
from google.colab import drive
# Mount Google Drive at /content/drive; later cells write checkpoints below it.
drive.mount('/content/drive')
print('✓ Drive mounted')

In [None]:
#@title 1. ⚠️ SET RUN CONFIG HERE ⚠️ (change for each experiment)
from pathlib import Path
from datetime import datetime

# ------------------------------------------------------------
# Per-experiment settings -- edit these before every run
# ------------------------------------------------------------
# RUN_NAME labels this experiment. Suggested values:
#   "baseline_string"   -> 100% string similarity
#   "cosine_only"       -> 100% cosine similarity
#   "hybrid_90_10"      -> 90% string + 10% cosine

RUN_NAME = "baseline_string"  # <-- CHANGE THIS

# ------------------------------------------------------------
# Reward mix
# ------------------------------------------------------------
# Weight pairs for the three experiments listed above:
#   baseline_string:  STRING_WEIGHT=1.0, COSINE_WEIGHT=0.0
#   cosine_only:      STRING_WEIGHT=0.0, COSINE_WEIGHT=1.0
#   hybrid_90_10:     STRING_WEIGHT=0.9, COSINE_WEIGHT=0.1

STRING_WEIGHT = 1.0  # share given to string similarity (difflib)
COSINE_WEIGHT = 0.0  # share given to cosine similarity (embeddings)

# Fail fast on an unusable weight combination: neither weight may be
# negative, and at least one of them must be positive.
if any(weight < 0 for weight in (STRING_WEIGHT, COSINE_WEIGHT)):
    raise ValueError("Weights must be non-negative")
if STRING_WEIGHT + COSINE_WEIGHT <= 0:
    raise ValueError("STRING_WEIGHT + COSINE_WEIGHT must be > 0")

# Embedding model name (per the original note, only loaded if COSINE_WEIGHT > 0)
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"

# ============================================================
# Output locations, all derived from RUN_NAME
# ============================================================
DRIVE_BASE = Path('/content/drive/MyDrive/srl_outputs')
DRIVE_CKPT_DIR = DRIVE_BASE.joinpath(f'checkpoints_{RUN_NAME}')
DRIVE_FINAL_DIR = DRIVE_BASE.joinpath(f'final_{RUN_NAME}')

DISK_CKPT_DIR = Path('/content').joinpath(f'checkpoints_{RUN_NAME}')
OUTPUT_DIR = Path('/content').joinpath(f'outputs_{RUN_NAME}')

# Create the directories used during training (DRIVE_FINAL_DIR is not
# created here).
for _run_dir in (DRIVE_CKPT_DIR, DISK_CKPT_DIR, OUTPUT_DIR):
    _run_dir.mkdir(parents=True, exist_ok=True)

# Summary banner: run name followed by one aligned line per location.
_banner = '=' * 60
print(_banner)
print(f'RUN NAME: {RUN_NAME}')
print(_banner)
for _label, _location in (
    ('Drive checkpoints:', DRIVE_CKPT_DIR),
    ('Drive final:', DRIVE_FINAL_DIR),
    ('Disk checkpoints:', DISK_CKPT_DIR),
    ('Output dir:', OUTPUT_DIR),
):
    # Labels are left-aligned in a 19-character column.
    print(f'  {_label:<19}{_location}')
print(_banner)

# Report checkpoints that already exist on Drive for this run.
def _checkpoint_order(path):
    """Sort key that orders 'epoch_<E>_step_<S>' names numerically.

    A plain string sort would place 'epoch_10_...' before 'epoch_2_...'.
    The name itself is appended as a tie-breaker.
    """
    import re  # local import keeps this cell self-contained
    return ([int(num) for num in re.findall(r'\d+', path.name)], path.name)


existing = list(DRIVE_CKPT_DIR.glob('epoch_*'))
if existing:
    print(f'\n⚠️  Found {len(existing)} existing checkpoint(s) for this run:')
    for p in sorted(existing, key=_checkpoint_order):
        print(f'      - {p.name}')
    print('\n   (Training will add more checkpoints to this run)')
else:
    print('\n✓ Fresh run - no existing checkpoints')

In [None]:
#@title 2. Environment setup (clone + installs)
import os, sys
from pathlib import Path
# Sets TORCH_CUDA_ARCH_LIST=8.0 for this session.
# NOTE(review): presumably limits CUDA extension builds (e.g. flash-attn
# below) to the A100 named in the notebook title -- confirm.
%env TORCH_CUDA_ARCH_LIST=8.0

# Project repository and the directory it is cloned into.
REPO_URL = "https://github.com/iroblesrazzaq/SRL-reasoning.git"
REPO_DIR = Path('/content/SRL-reasoning')
# Plain-string form used by the shell command and the os/sys calls below.
REPO_DIR_STR = str(REPO_DIR)

# Clone only when the directory does not exist yet, so re-running the cell
# reuses the existing checkout.
if not REPO_DIR.exists():
    !git clone $REPO_URL $REPO_DIR_STR

# Work from the repository root and put it on sys.path (later cells import
# from `src.` and `scripts.`).
os.chdir(REPO_DIR_STR)
if REPO_DIR_STR not in sys.path:
    sys.path.append(REPO_DIR_STR)

# Third-party training dependencies.
!pip install transformers peft bitsandbytes accelerate datasets trl --no-build-isolation
!pip install --no-build-isolation --no-cache-dir flash-attn

# TRL is installed a second time, from its GitHub repository.
# NOTE(review): this presumably supersedes the PyPI `trl` installed above --
# confirm which version is intended.
!pip install git+https://github.com/huggingface/trl.git
# Editable install of the cloned repository (the current working directory).
!pip install -e .

# Report the PyTorch version and the first CUDA device, or 'CPU' if none.
import torch
device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
print(f'PyTorch {torch.__version__} | Device: {device_name}')

In [None]:
#@title 3. Global config
import random, numpy as np
from pathlib import Path

# Constants shared by the remaining cells.
SEED = 42
BASE_MODEL = 'Qwen/Qwen3-4B-Instruct-2507'
DATA_DIR = REPO_DIR.joinpath('data')
DATA_DIR.mkdir(parents=True, exist_ok=True)

# Seed the Python, NumPy and PyTorch random number generators alike.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

# Echo the effective configuration.
print(
    f'RUN_NAME: {RUN_NAME}',
    f'DATA_DIR: {DATA_DIR}',
    f'OUTPUT_DIR: {OUTPUT_DIR}',
    f'Reward weights: string={STRING_WEIGHT}, cosine={COSINE_WEIGHT}',
    sep='\n',
)
if COSINE_WEIGHT > 0:
    print(f'Embedding model: {EMBEDDING_MODEL}')

In [None]:
#@title 4. Build SRL data (s1K-1.1 -> step-wise JSONL)
from src.shared.build_srl_data import load_teacher_dataset, normalize_dataset, build_srl_dataset, save_jsonl
from src.shared.splits import split_by_trajectory

# Load the 'train' split of simplescaling/s1K-1.1, normalize it, and build
# the SRL examples from the normalized trajectories.
raw_ds = load_teacher_dataset('simplescaling/s1K-1.1', split='train')
norm_trajs = normalize_dataset(raw_ds)
srl_examples = build_srl_dataset(norm_trajs)

# Persist the full example set; split_by_trajectory below is given its path.
all_path = DATA_DIR / 'srl_steps.jsonl'
save_jsonl(srl_examples, all_path)

# All data is assigned to train (ratio 1.0); the val and test ratios are 0.0.
# The third return value is discarded.
train_examples, val_examples, _ = split_by_trajectory(
    str(all_path),
    train_ratio=1.0,
    val_ratio=0.0,
    test_ratio=0.0,  # no test split (an earlier "discard 90%" note did not match these ratios)
    seed=SEED,
)

# Both splits are written out, although val_ratio is 0.0 above.
train_path = DATA_DIR / 'train.jsonl'
val_path = DATA_DIR / 'val.jsonl'
save_jsonl(train_examples, train_path)
save_jsonl(val_examples, val_path)

print(f'Train examples: {len(train_examples)}')
print(f'Val examples:   {len(val_examples)}')

In [None]:
#@title 5. Prepare HF datasets for GRPO (train only - no eval)
from scripts.train_srl import load_srl_dataset

# Load the training JSONL through the project's dataset loader.
train_dataset = load_srl_dataset(str(train_path))
# NOTE: val_dataset not used - validation removed to avoid failures

# Sanity check: show the first two rows and the dataset size.
print(train_dataset[:2])
print(f'HF datasets -> train {len(train_dataset)}')

In [None]:
#@title 6. Load model + tokenizer (LoRA, flash-attn, grad checkpointing)
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType

# Tokenizer for the base model, configured to pad on the left.
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    padding_side='left',
    trust_remote_code=True,
)
# Fall back to the EOS token when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Base model in bfloat16 with FlashAttention-2 and automatic device placement.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    attn_implementation='flash_attention_2',
    torch_dtype=torch.bfloat16,
    device_map='auto',
    trust_remote_code=True,
)

# LoRA adapter settings: rank 16, alpha 32, dropout 0.05, applied to all
# linear layers, no bias training.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules='all-linear',
    task_type=TaskType.CAUSAL_LM,
    bias='none',
)
model = get_peft_model(model, lora_config)
# Gradient checkpointing setup.
# NOTE(review): enable_input_require_grads and use_cache=False are presumably
# needed for checkpointing to work with the adapter-wrapped model -- confirm.
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
model.config.use_cache = False

# Report how many parameters will receive gradients versus the total.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f'Trainable params: {trainable_params/1e6:.1f}M / {total_params/1e6:.1f}M')
if torch.cuda.is_available():
    print('Model device:', next(model.parameters()).device)

In [None]:
#@title 7. Configure GRPO trainer (with epoch checkpoints)
import shutil
from inspect import signature
from trl import GRPOConfig
from transformers import TrainerCallback
from scripts.train_srl import SRLGRPOTrainer, create_reward_function

# Build the reward function from the weights and embedding model chosen in
# the run-config cell.
reward_fn = create_reward_function(
    tokenizer=tokenizer,
    string_weight=STRING_WEIGHT,
    cosine_weight=COSINE_WEIGHT,
    embedding_model_name=EMBEDDING_MODEL,
)

# ============================================================
# Callback to save after each epoch to disk AND Drive
# ============================================================
class EpochCheckpointCallback(TrainerCallback):
    """Save the model to local disk and mirror it to Drive after each epoch.

    Every checkpoint is a directory named ``epoch_<E>_step_<S>`` holding the
    output of ``trainer.save_model``, the tokenizer files and a
    ``trainer_state.json``. The trainer state is written so that
    ``trainer.train(resume_from_checkpoint=...)`` can continue from the saved
    step instead of restarting the step/epoch count. Optimizer and
    LR-scheduler state are not written.

    Failures are reported with ``print`` and never raised, so a failed save
    does not stop training.

    Args:
        trainer_ref: Trainer whose ``save_model`` writes the weights.
        tokenizer: Tokenizer saved next to the weights.
        disk_dir: Local directory that receives each checkpoint first.
        drive_dir: Drive directory that receives a copy of each checkpoint.
        run_name: Label used in log messages.
    """

    # File name for the trainer state inside a checkpoint directory.
    # NOTE(review): chosen to match the name transformers' Trainer reads on
    # resume -- confirm against the installed version.
    _STATE_FILE = 'trainer_state.json'

    def __init__(self, trainer_ref, tokenizer, disk_dir, drive_dir, run_name):
        self.trainer_ref = trainer_ref
        self.tokenizer = tokenizer
        self.disk_dir = Path(disk_dir)
        self.drive_dir = Path(drive_dir)
        self.run_name = run_name
        # Highest epoch already checkpointed by this callback instance.
        self.last_saved_epoch = -1

    @staticmethod
    def _checkpoint_order(path):
        """Sort key ordering 'epoch_<E>_step_<S>' names numerically.

        A plain string sort would place 'epoch_10_...' before 'epoch_2_...'.
        """
        import re
        return ([int(num) for num in re.findall(r'\d+', path.name)], path.name)

    def on_epoch_end(self, args, state, control, **kwargs):
        """Checkpoint to disk, then back the checkpoint up to Drive."""
        # state.epoch is a float; rounding keeps a value a hair below a whole
        # epoch (e.g. 2.9999) from being truncated and skipped as a duplicate.
        epoch = int(round(state.epoch))

        # Avoid duplicate saves for an epoch that was already handled.
        if epoch <= self.last_saved_epoch:
            return
        self.last_saved_epoch = epoch

        ckpt_name = f'epoch_{epoch}_step_{state.global_step}'
        disk_path = self.disk_dir / ckpt_name
        drive_path = self.drive_dir / ckpt_name

        banner = '=' * 60
        print(f'\n{banner}')
        print(f'[{self.run_name}] EPOCH {epoch} COMPLETE - SAVING CHECKPOINT')
        print(banner)

        # Don't try Drive if the disk save failed: there is nothing to copy.
        if not self._save_to_disk(disk_path, state):
            return
        self._backup_to_drive(disk_path, drive_path)

        print(f'{banner}\n')

    def _save_to_disk(self, disk_path, state):
        """Write weights, tokenizer and trainer state; return True on success."""
        try:
            self.trainer_ref.save_model(str(disk_path))
            self.tokenizer.save_pretrained(str(disk_path))
            # Verify something was written before adding the state file.
            if not list(disk_path.glob('*')):
                print(f'  ✗ ERROR: No files saved to {disk_path}')
                return False
        except Exception as e:
            print(f'  ✗ ERROR saving to disk: {e}')
            return False

        # Best effort: the weights above are usable even without the state.
        try:
            state.save_to_json(str(disk_path / self._STATE_FILE))
        except Exception as e:
            print(f'  ⚠️  WARNING: trainer state not saved '
                  f'(a resume would restart the step count): {e}')

        saved_files = list(disk_path.glob('*'))
        print(f'  ✓ Saved to disk: {disk_path} ({len(saved_files)} files)')
        return True

    def _backup_to_drive(self, disk_path, drive_path):
        """Copy a finished disk checkpoint to Drive without raising."""
        try:
            import os
            # Replace any previous copy of the same checkpoint.
            if drive_path.exists():
                shutil.rmtree(drive_path)
            shutil.copytree(disk_path, drive_path)
            os.sync()  # Force flush to ensure Drive write completes
            print(f'  ✓ Backed up to Drive: {drive_path}')

            # Show the three most recent Drive checkpoints, in epoch order.
            existing = sorted(self.drive_dir.glob('epoch_*'),
                              key=self._checkpoint_order)
            print(f'  ✓ Drive checkpoints for {self.run_name}: {len(existing)}')
            for p in existing[-3:]:
                print(f'      - {p.name}')
            if len(existing) > 3:
                print(f'      ... and {len(existing) - 3} more')
        except Exception as e:
            print(f'  ⚠️  WARNING: Drive backup failed (disk copy still exists): {e}')
            print(f'      Disk checkpoint at: {disk_path}')

# ============================================================
# GRPO Config - NO EVALUATION
# ============================================================
grpo_kwargs = {
    'output_dir': str(OUTPUT_DIR),
    'num_train_epochs': 30,             # Matches paper
    'per_device_train_batch_size': 4,   # A100 80GB capacity
    # Effective batch size: 4 * 32 = 128 (reduced from paper's 512 for memory)
    # Note: GRPO generates num_generations=4 per prompt, so actual throughput is higher
    'gradient_accumulation_steps': 32,
    'learning_rate': 5e-7,              # Paper uses 5e-7
    'beta': 0.0,                        # KL coeff is 0 for SRL
    'warmup_ratio': 0.0,                # No warmup
    'max_grad_norm': 1.0,
    'num_generations': 4,
    'temperature': 1.0,                 # Explicitly set rollout temp
    'max_prompt_length': 512,           # Max input prompt tokens
    'max_completion_length': 256,       # Max generated tokens per completion

    # === SAVE STRATEGY ===
    'save_strategy': 'no',              # We handle saves manually via callback

    # === NO EVALUATION (removed to avoid failures) ===
    'evaluation_strategy': 'no',        # No validation - was causing failures
    'load_best_model_at_end': False,    # Disabled since no eval

    # Convention / Engineering settings
    'logging_steps': 1,
    'optim': 'adamw_8bit',              # Convention (Paper uses H100s, likely 8bit or full)
    'bf16': True,
    'report_to': 'none',
    'seed': SEED,
}

# Names accepted by the installed GRPOConfig; the library versions are not
# pinned, so the accepted set can differ between sessions.
supported = set(signature(GRPOConfig.__init__).parameters)

# transformers renamed `evaluation_strategy` to `eval_strategy`; pass the
# setting under whichever name this install accepts.
if 'evaluation_strategy' not in supported and 'eval_strategy' in supported:
    grpo_kwargs['eval_strategy'] = grpo_kwargs.pop('evaluation_strategy')

# Make dropped settings visible instead of discarding them silently.
_dropped_kwargs = sorted(k for k in grpo_kwargs if k not in supported)
if _dropped_kwargs:
    print(f'⚠️  WARNING: GRPOConfig does not accept {_dropped_kwargs}; '
          'these settings are ignored')

grpo_config = GRPOConfig(**{k: v for k, v in grpo_kwargs.items() if k in supported})

# Create trainer (NO eval_dataset).
# TRL versions differ in the keyword that carries the tokenizer, so inspect
# the signature of GRPOTrainer (the parent class). Per the original note,
# SRLGRPOTrainer itself takes **kwargs and cannot be inspected this way.
from trl import GRPOTrainer
_grpo_params = set(signature(GRPOTrainer.__init__).parameters)
_trainer_kwargs = {
    "model": model,
    "args": grpo_config,
    "train_dataset": train_dataset,
    "reward_funcs": reward_fn,
    "filter_epsilon": 1e-4,
}
# TRL versions vary - some use 'tokenizer', some use 'processing_class'
if "tokenizer" in _grpo_params:
    _trainer_kwargs["tokenizer"] = tokenizer
elif "processing_class" in _grpo_params:
    _trainer_kwargs["processing_class"] = tokenizer
else:
    # The trainer is still built, just without an explicit tokenizer argument.
    print("⚠️  Warning: Neither 'tokenizer' nor 'processing_class' found in GRPOTrainer params")

trainer = SRLGRPOTrainer(**_trainer_kwargs)

# Register the per-epoch checkpoint callback (local disk first, then Drive).
epoch_callback = EpochCheckpointCallback(
    trainer_ref=trainer,
    tokenizer=tokenizer,
    disk_dir=DISK_CKPT_DIR,
    drive_dir=DRIVE_CKPT_DIR,
    run_name=RUN_NAME,
)
trainer.add_callback(epoch_callback)

# Print a human-readable summary of the training configuration.
_per_device_bs = grpo_kwargs["per_device_train_batch_size"]
_grad_accum = grpo_kwargs["gradient_accumulation_steps"]

print('=' * 60)
print(f'TRAINING CONFIG - {RUN_NAME}')
print('=' * 60)
print(f'  Run name: {RUN_NAME}')
print(f'  Epochs: {grpo_kwargs["num_train_epochs"]}')
print(f'  Batch size: {_per_device_bs} x {_grad_accum} = {_per_device_bs * _grad_accum}')
print(f'  Max prompt length: {grpo_kwargs["max_prompt_length"]}')
print(f'  Max completion length: {grpo_kwargs["max_completion_length"]}')
print(f'  Learning rate: {grpo_kwargs["learning_rate"]}')
# Show the weights as fractions of their sum; the run-config cell already
# guarantees the sum is positive.
total_w = STRING_WEIGHT + COSINE_WEIGHT
print(f'  Reward weights: string={STRING_WEIGHT/total_w:.2f}, cosine={COSINE_WEIGHT/total_w:.2f}')
if COSINE_WEIGHT > 0:
    print(f'  Embedding model: {EMBEDDING_MODEL}')
print('  ✓ Validation: DISABLED')
print('  ✓ Checkpoint after each epoch to:')
print(f'      - Disk: {DISK_CKPT_DIR}')
print(f'      - Drive: {DRIVE_CKPT_DIR}')
print('=' * 60)

In [None]:
#@title 8. TRAIN (checkpoints saved after each epoch)
# Resume only when the optional resume cell has defined a non-empty
# RESUME_FROM; otherwise start a fresh run.
resume_checkpoint = globals().get('RESUME_FROM') or None
if resume_checkpoint:
    print(f'\n🔄 RESUMING TRAINING: {RUN_NAME}')
    print(f'   From checkpoint: {resume_checkpoint}')
else:
    print(f'\n🚀 STARTING TRAINING: {RUN_NAME}')

print('Progress is backed up to Drive after each epoch.')
print('If session dies, your checkpoints are safe!\n')

train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)
print(train_result)

In [None]:
#@title 9. Save final model to Google Drive
# Write the final model and tokenizer straight to the run's Drive directory.
DRIVE_FINAL_DIR.mkdir(parents=True, exist_ok=True)

trainer.save_model(str(DRIVE_FINAL_DIR))
tokenizer.save_pretrained(str(DRIVE_FINAL_DIR))
import os
os.sync()  # Ensure final model is written to Drive

print('=' * 60)
print(f'TRAINING COMPLETE - {RUN_NAME}')
print('=' * 60)
print(f'Final model saved to: {DRIVE_FINAL_DIR}')
print(f'Epoch checkpoints at: {DRIVE_CKPT_DIR}')


def _checkpoint_order(path):
    """Sort key ordering 'epoch_<E>_step_<S>' names numerically.

    A plain string sort would list 'epoch_10_...' before 'epoch_2_...'.
    """
    import re
    return ([int(num) for num in re.findall(r'\d+', path.name)], path.name)


# List all checkpoints in epoch order.
print('\nAll saved checkpoints:')
for p in sorted(DRIVE_CKPT_DIR.glob('epoch_*'), key=_checkpoint_order):
    print(f'  - {p.name}')

---
# Resume Training from Checkpoint

If your session crashed, follow these steps to resume:

1. Run **Cells 0-6** (with the **same `RUN_NAME`**!)
2. Run the **Resume cell below** to set `RESUME_FROM`
3. Run **Cell 7** (Configure GRPO trainer)
4. Run **Cell 8** (TRAIN) - it will automatically resume from the checkpoint

In [None]:
#@title [OPTIONAL] Resume from Drive checkpoint
# Run this cell BEFORE Cell 8 (TRAIN) to resume from a checkpoint.
# After running this cell, run Cell 7 (Configure GRPO) then Cell 8 (TRAIN).

print(f'Looking for checkpoints for run: {RUN_NAME}')
print(f'Directory: {DRIVE_CKPT_DIR}')
print()


def _checkpoint_order(path):
    """Sort key ordering 'epoch_<E>_step_<S>' names numerically.

    A plain string sort would rank 'epoch_9_...' after 'epoch_10_...', so
    the "most recent" pick below would be wrong from epoch 10 onwards.
    """
    import re
    return ([int(num) for num in re.findall(r'\d+', path.name)], path.name)


RESUME_FROM = None  # Stays None when no checkpoint exists
checkpoints = sorted(DRIVE_CKPT_DIR.glob('epoch_*'), key=_checkpoint_order)
if checkpoints:
    print(f'Found {len(checkpoints)} checkpoint(s):')
    for i, p in enumerate(checkpoints):
        print(f'  [{i}] {p.name}')

    # Use the checkpoint with the highest (epoch, step).
    RESUME_FROM = str(checkpoints[-1])
    print(f'\n✓ RESUME_FROM set to: {RESUME_FROM}')
    print('\n⚠️  Now run Cell 7 (Configure GRPO) then Cell 8 (TRAIN) to resume.')
else:
    print('No checkpoints found - start fresh training.')
    print('Run Cell 7 and Cell 8 normally.')

---
# Compare All Runs

List all experiment runs saved in Drive.

In [None]:
#@title List all runs in Drive
print('All SRL training runs in Drive:')
print('=' * 60)


def _strip_prefix(name, prefix):
    """Return name without a leading prefix.

    str.replace would also delete later occurrences, turning a run called
    'semi_final_v2' into 'semi_v2'.
    """
    return name[len(prefix):] if name.startswith(prefix) else name


checkpoint_dirs = sorted(DRIVE_BASE.glob('checkpoints_*'))
final_dirs = sorted(DRIVE_BASE.glob('final_*'))

if checkpoint_dirs:
    print('\nCheckpoint directories:')
    for d in checkpoint_dirs:
        run_name = _strip_prefix(d.name, 'checkpoints_')
        epochs = list(d.glob('epoch_*'))
        print(f'  {run_name}: {len(epochs)} checkpoint(s)')

if final_dirs:
    print('\nFinal models:')
    for d in final_dirs:
        run_name = _strip_prefix(d.name, 'final_')
        print(f'  {run_name}')

if not checkpoint_dirs and not final_dirs:
    print('No runs found yet.')