# SRL GRPO Training (Colab, A100) - Unsloth Edition
End-to-end notebook that builds step-wise SRL data from s1K-1.1, splits it 95/5 by trajectory, and trains with GRPO + LoRA on an A100.

**Uses Unsloth, which advertises 2-4x faster training with 50-80% less memory.**

In [None]:
# ============================================================================
# Setup and Installation
# ============================================================================
# Runtime: GPU (A100 recommended)
# Make sure to select GPU: Runtime -> Change runtime type -> GPU

import os
from pathlib import Path

# Repository configuration
REPO_URL = "https://github.com/iroblesrazzaq/SRL-reasoning.git"
BRANCH = "main"
WORKDIR = "/content/SRL-reasoning"

# Install Unsloth (must be before other imports)
# This installs optimized versions of transformers, trl, etc.
!pip install -q "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

# Install remaining dependencies
!pip install -q bitsandbytes datasets

# Clone repo if not exists
if not os.path.exists(WORKDIR):
    !git clone --branch $BRANCH $REPO_URL $WORKDIR

%cd $WORKDIR
!git pull

# Install package (without overwriting unsloth's dependencies)
!pip install -e . --no-deps

# Verify GPU
import torch
print("=" * 80)
print("SETUP COMPLETE (Unsloth Edition)")
print("=" * 80)
print(f"✓ Repository: {WORKDIR}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"✓ GPU Memory: {gpu_mem:.2f} GB")
print("=" * 80)

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
/content/SRL-reasoning
Already up to date.
Obtaining file:///content/SRL-reasoning
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Installing backend dependencies ... done
  Preparing editable metadata (pyproject.toml) ... done
Building wheels for collected packages: srl-reasoning
  Building editable for srl-reasoning (pyproject.toml) ... done
  Created wheel for srl-reasoning: filename=srl_reasoning-0.1.0-py3-none-any.whl size=1404 sha256=709c98779cf3acf43898fa0311a6bc3dfa9c0a26edf08eba546c08575ee7cbbf
  Stored in directory: /tmp/pip-ephem-wheel-cache-nb3ew93q/wheels/9b/30/2b/cb824dafeae6a6c41265bcc7f5fed90a664ea6e694ffd44198
Successfully bui
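The OutOfMemoryError raised in step 6 reports 23.02 GiB "reserved but unallocated" and suggests setting `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` to avoid fragmentation. PyTorch reads this variable when its CUDA caching allocator starts up, so it only takes effect if it is set before the GPU is first used in the process. The sketch below is meant to sit at the very top of the setup cell, before `import torch`. If CUDA has already been initialized, restart the runtime first. This may reduce fragmentation-driven OOMs, but it does not lower the peak memory a training step actually needs.

```python
import os

# Must be set before PyTorch initializes its CUDA caching allocator,
# i.e. before the first GPU allocation in this Python process.
# setdefault keeps any value that was already configured.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
print("PYTORCH_CUDA_ALLOC_CONF =", os.environ["PYTORCH_CUDA_ALLOC_CONF"])
```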

In [None]:
# Verify Unsloth installation
import unsloth
print(f"Unsloth version: {unsloth.__version__}")
! python -c "import torch; print(f'PyTorch {torch.__version__}, CUDA {torch.version.cuda}')"

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
Unsloth version: 2025.11.6
PyTorch 2.9.0+cu126, CUDA 12.6


In [None]:
#@title 1. Global config
import random, numpy as np
from pathlib import Path
import gc

SEED = 42
BASE_MODEL = 'Qwen/Qwen3-4B-Instruct-2507'
REPO_DIR = Path('/content/SRL-reasoning')
OUTPUT_DIR = Path('/content/outputs/srl_grpo')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
DATA_DIR = REPO_DIR / 'data'
DATA_DIR.mkdir(parents=True, exist_ok=True)

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
print('DATA_DIR:', DATA_DIR)

DATA_DIR: /content/SRL-reasoning/data


In [None]:
#@title 2. Build SRL data (s1K-1.1 -> step-wise JSONL)
from src.shared.build_srl_data import load_teacher_dataset, normalize_dataset, build_srl_dataset, save_jsonl
from src.shared.splits import split_by_trajectory

raw_ds = load_teacher_dataset('simplescaling/s1K-1.1', split='train')
norm_trajs = normalize_dataset(raw_ds)
srl_examples = build_srl_dataset(norm_trajs)

all_path = DATA_DIR / 'srl_steps.jsonl'
save_jsonl(srl_examples, all_path)

train_examples, val_examples, _ = split_by_trajectory(
    str(all_path),
    train_ratio=0.95,
    val_ratio=0.05,
    test_ratio=0.0,
    seed=SEED,
)

train_path = DATA_DIR / 'train.jsonl'
val_path = DATA_DIR / 'val.jsonl'
save_jsonl(train_examples, train_path)
save_jsonl(val_examples, val_path)

print(f'Train examples: {len(train_examples)}')
print(f'Val examples:   {len(val_examples)}')

Loading dataset: simplescaling/s1K-1.1 (split: train)...
Loaded 1000 examples


Normalizing trajectories: 100%|██████████| 1000/1000 [00:00<00:00, 4220.03example/s]
Building SRL examples: 100%|██████████| 606/606 [00:00<00:00, 102717.65trajectory/s]
Saving to JSONL: 100%|██████████| 2675/2675 [00:00<00:00, 53412.18example/s]
Saving to JSONL: 100%|██████████| 2528/2528 [00:00<00:00, 53459.45example/s]
Saving to JSONL: 100%|██████████| 143/143 [00:00<00:00, 44069.47example/s]

Train examples: 2528
Val examples:   143
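Judging by its name, `split_by_trajectory` presumably keeps every step of a solution trajectory in a single split. That way, validation steps never come from trajectories seen in training. A quick leak check is cheap insurance.

This minimal sketch assumes each JSONL row carries a trajectory identifier under a key such as `trajectory_id`. That field name is hypothetical, so substitute whatever `build_srl_dataset` actually writes.

```python
import json


def load_jsonl(path):
    """Read a JSONL file (one JSON object per line) into a list of dicts."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def check_trajectory_split(train_rows, val_rows, key="trajectory_id"):
    """Return (n_train_trajectories, n_val_trajectories, overlapping_ids).

    `key` is a HYPOTHETICAL field name; replace it with the trajectory
    identifier that build_srl_dataset actually writes.
    """
    train_ids = {row[key] for row in train_rows}
    val_ids = {row[key] for row in val_rows}
    return len(train_ids), len(val_ids), train_ids & val_ids


# Toy rows: two steps from trajectory 0, one from 1, one from 2.
demo_train = [{"trajectory_id": 0}, {"trajectory_id": 0}, {"trajectory_id": 1}]
demo_val = [{"trajectory_id": 2}]
print(check_trajectory_split(demo_train, demo_val))  # (2, 1, set())
```

In the notebook, the call would be `check_trajectory_split(load_jsonl(train_path), load_jsonl(val_path))`, and the overlap set should come back empty.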





In [None]:
#@title 3. Prepare HF datasets for GRPO
from scripts.train_srl import load_srl_dataset

train_dataset = load_srl_dataset(str(train_path))
val_dataset = load_srl_dataset(str(val_path))

print(f'HF datasets -> train {len(train_dataset)}, val {len(val_dataset)}')

HF datasets -> train 2528, val 143
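Step 5 caps prompts at `max_prompt_length=512` tokens to save memory. Before training, it is worth checking how many training prompts exceed that limit. This minimal sketch is tokenizer-agnostic: pass any callable that returns a token count. The whitespace "tokenizer" in the demo is only a stand-in so the example runs offline.

```python
from typing import Callable, Iterable


def prompt_length_report(prompts: Iterable[str], count_tokens: Callable[[str], int], limit: int = 512):
    """Count prompts longer than `limit` tokens and report the longest length."""
    lengths = [count_tokens(p) for p in prompts]
    over = sum(1 for n in lengths if n > limit)
    return {"n": len(lengths), "over_limit": over, "max_len": max(lengths, default=0)}


# Toy whitespace "tokenizer" so the sketch runs without downloading a model.
demo = ["a b c", "a b c d e f", "a"]
print(prompt_length_report(demo, lambda s: len(s.split()), limit=4))
# {'n': 3, 'over_limit': 1, 'max_len': 6}
```

Run this once the tokenizer from step 4 is loaded. It assumes the dataset exposes plain-string prompts in a `prompt` column, which is the column name TRL's GRPO trainer expects. Conversational prompts would need the chat template applied first. With those assumptions, a call might look like `prompt_length_report(train_dataset['prompt'], lambda s: len(tokenizer(s).input_ids), limit=512)`.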


In [None]:
#@title 4. Load model with Unsloth (2-4x faster, 50% less memory!)
from unsloth import FastLanguageModel

# Clear any existing GPU memory
gc.collect()
torch.cuda.empty_cache()

# Unsloth optimizes the model automatically
# - Fused attention kernels
# - Optimized LoRA
# - Memory-efficient forward/backward
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    max_seq_length=1024,           # Reduced for memory (512 prompt + 256 completion + buffer)
    dtype=torch.bfloat16,
    load_in_4bit=True,             # 4-bit for memory efficiency
    trust_remote_code=True,
)

# Apply LoRA with Unsloth's optimized implementation
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,             # Nonzero dropout skips Unsloth's fast LoRA kernels (see warning in output)
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    bias="none",
    use_gradient_checkpointing="unsloth",  # Unsloth's optimized checkpointing
    random_state=SEED,
)

# Set padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

# Print trainable-parameter counts and GPU memory usage
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f'\n✓ Unsloth model loaded!')
print(f'  Trainable params: {trainable_params/1e6:.1f}M / {total_params/1e6:.1f}M')
if torch.cuda.is_available():
    print(f'  GPU Memory used: {torch.cuda.memory_allocated(0)/1e9:.2f} GB')

Are you certain you want to do remote code execution?
==((====))==  Unsloth 2025.11.6: Fast Qwen3 patching. Transformers: 4.57.2.
   \\   /|    NVIDIA A100-SXM4-80GB. Num GPUs = 1. Max memory: 79.318 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 8.0. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
Unsloth 2025.11.6 patched 36 layers with 0 QKV layers, 0 O layers and 0 MLP layers.



✓ Unsloth model loaded!
  Trainable params: 33.0M / 4055.5M
  GPU Memory used: 8.20 GB
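The 33.0M trainable parameters reported above can be checked by hand. A rank-r LoRA adapter on a linear layer with input size d_in and output size d_out adds r·(d_in + d_out) parameters (matrices A and B). With `bias="none"`, nothing else is trained.

The Qwen3-4B dimensions below are assumptions taken from its published model config, not from this notebook. The layer count matches the "patched 36 layers" log line.

```python
# ASSUMED Qwen3-4B shapes (from its published config.json):
# hidden_size=2560, intermediate_size=9728, 32 query heads, 8 KV heads,
# head_dim=128, 36 decoder layers. LoRA rank matches r=16 used above.
HIDDEN, INTERMEDIATE = 2560, 9728
N_HEADS, N_KV_HEADS, HEAD_DIM = 32, 8, 128
N_LAYERS, RANK = 36, 16

# (in_features, out_features) of each adapted projection in one decoder layer
proj_shapes = {
    "q_proj": (HIDDEN, N_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "o_proj": (N_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

# LoRA adds A (r x d_in) and B (d_out x r): r * (d_in + d_out) params per module.
per_layer = sum(RANK * (d_in + d_out) for d_in, d_out in proj_shapes.values())
total = per_layer * N_LAYERS
print(f"{per_layer:,} per layer x {N_LAYERS} layers = {total:,}")
# 917,504 per layer x 36 layers = 33,030,144
```

The result agrees with the 33,030,144 trainable parameters reported in the training banner in step 6.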


In [None]:
#@title 5. Configure GRPO trainer (Unsloth-optimized settings)
from inspect import signature
from trl import GRPOConfig
from scripts.train_srl import SRLGRPOTrainer, create_reward_function

reward_fn = create_reward_function(tokenizer)

# Check which parameter names are supported
supported = set(signature(GRPOConfig.__init__).parameters)

# ============================================================================
# UNSLOTH-OPTIMIZED GRPO CONFIGURATION
# ============================================================================
# Training schedule and GRPO hyperparameters follow the paper.
# Batch size, generations per prompt, and sequence lengths are cut down
# to fit in A100 memory.
# ============================================================================

grpo_kwargs = {
    'output_dir': str(OUTPUT_DIR),

    # === TRAINING SCHEDULE (Paper settings) ===
    'num_train_epochs': 30,              # Paper: 30 epochs
    'learning_rate': 5e-7,               # Paper: 5e-7
    'warmup_ratio': 0.0,                 # Paper: no warmup
    'max_grad_norm': 1.0,                # Paper: 1.0

    # === BATCH SIZE (reduced to fit in memory) ===
    'per_device_train_batch_size': 1,     # Reduced for memory
    'num_generations': 2,                 # Minimal for memory
    'per_device_eval_batch_size': 2,      # Must be divisible by num_generations
    'gradient_accumulation_steps': 512,   # 1 × 512 = 512 completions per optimizer step

    # === GRPO-SPECIFIC ===
    'beta': 0.0,                         # Paper: no KL penalty for SRL
    'temperature': 1.0,                  # Paper: 1.0 for rollouts

    # === TOKEN LIMITS (reduced to fit in memory) ===
    'max_prompt_length': 512,            # Reduced for memory
    'max_completion_length': 256,        # Reduced for memory

    # === CHECKPOINTING ===
    'save_strategy': 'epoch',
    'save_total_limit': 2,
    'load_best_model_at_end': True,
    'metric_for_best_model': 'eval_reward',
    'greater_is_better': True,

    # === OPTIMIZATION ===
    'optim': 'adamw_8bit',               # Memory-efficient optimizer
    'bf16': True,                        # bfloat16 precision

    # === LOGGING ===
    'logging_steps': 1,
    'report_to': 'none',
    'seed': SEED,
}

# Handle eval_strategy naming
if 'eval_strategy' in supported:
    grpo_kwargs['eval_strategy'] = 'epoch'
else:
    grpo_kwargs['evaluation_strategy'] = 'epoch'

# Filter to only supported parameters
grpo_config = GRPOConfig(**{k: v for k, v in grpo_kwargs.items() if k in supported})

print("=" * 80)
print("GRPO Config Summary (Unsloth-Optimized)")
print("=" * 80)
# In recent TRL versions per_device_train_batch_size counts completions, not prompts,
# so the effective batch of 512 is 512 completions (256 prompts × 2 generations).
print(f"  per_device_train_batch_size: {grpo_kwargs['per_device_train_batch_size']}")
print(f"  num_generations: {grpo_kwargs['num_generations']}")
print(f"  Sequences per step: {grpo_kwargs['per_device_train_batch_size']} × {grpo_kwargs['num_generations']} = {grpo_kwargs['per_device_train_batch_size'] * grpo_kwargs['num_generations']}")
print(f"  Effective batch size: {grpo_kwargs['per_device_train_batch_size']} × {grpo_kwargs['gradient_accumulation_steps']} = {grpo_kwargs['per_device_train_batch_size'] * grpo_kwargs['gradient_accumulation_steps']}")
print(f"  Max tokens: prompt={grpo_kwargs['max_prompt_length']}, completion={grpo_kwargs['max_completion_length']}")
print(f"  Total max sequence: {grpo_kwargs['max_prompt_length'] + grpo_kwargs['max_completion_length']}")
print("=" * 80)

GRPO Config Summary (Unsloth-Optimized)
  per_device_train_batch_size: 1
  num_generations: 2
  Sequences per step: 1 × 2 = 2
  Effective batch size: 1 × 512 = 512
  Max tokens: prompt=512, completion=256
  Total max sequence: 768
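The `supported` filter above silently drops any key that the installed `GRPOConfig` does not accept. A renamed field, like the `eval_strategy`/`evaluation_strategy` case the cell already handles, would therefore vanish without warning. Here is a small variant that also reports what was dropped. `demo_config` is a hypothetical stand-in used only to exercise the helper.

```python
from inspect import signature


def split_supported_kwargs(fn, kwargs):
    """Split kwargs into (accepted, dropped) based on the parameter names of fn."""
    params = signature(fn).parameters
    accepted = {k: v for k, v in kwargs.items() if k in params}
    dropped = sorted(k for k in kwargs if k not in params)
    return accepted, dropped


# Hypothetical stand-in for GRPOConfig.__init__, used only to demo the helper.
def demo_config(output_dir, learning_rate=1e-5, eval_strategy="no"):
    return {"output_dir": output_dir, "learning_rate": learning_rate, "eval_strategy": eval_strategy}


accepted, dropped = split_supported_kwargs(
    demo_config, {"output_dir": "out", "learning_rate": 5e-7, "max_prompt_length": 512}
)
print(dropped)  # ['max_prompt_length']
```

In step 5, the same call would be `accepted, dropped = split_supported_kwargs(GRPOConfig.__init__, grpo_kwargs)`, followed by `GRPOConfig(**accepted)`. Printing `dropped` makes any ignored setting visible.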


In [None]:
#@title 6. Initialize trainer and start training

# Clear cache before training
gc.collect()
torch.cuda.empty_cache()

trainer = SRLGRPOTrainer(
    model=model,
    args=grpo_config,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    reward_funcs=reward_fn,
    filter_epsilon=1e-4,
)

print("\n" + "=" * 80)
print("STARTING TRAINING (Unsloth-Optimized)")
print("=" * 80)
print(f"GPU Memory before training: {torch.cuda.memory_allocated(0)/1e9:.2f} GB")

# Switch to training mode
FastLanguageModel.for_training(model)

train_result = trainer.train()
print(train_result)

The model is already on multiple devices. Skipping the move to device specified in `args`.



STARTING TRAINING (Unsloth-Optimized)
GPU Memory before training: 8.20 GB


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 2,528 | Num Epochs = 30 | Total steps = 270
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 512
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 512 x 1) = 512
 "-____-"     Trainable parameters = 33,030,144 of 4,055,498,240 (0.81% trained)
`generation_config` default values have been modified to match model-specific defaults: {'max_length': 262144, 'temperature': 0.7, 'top_p': 0.8}. If this is not desired, please set these values explicitly.


OutOfMemoryError: CUDA out of memory. Tried to allocate 9.50 GiB. GPU 0 has a total capacity of 79.32 GiB of which 1.70 GiB is free. Process 55935 has 77.61 GiB memory in use. Of the allocated memory 54.07 GiB is allocated by PyTorch, and 23.02 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
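The training banner reports 270 total steps for 2,528 examples over 30 epochs. That figure is consistent with TRL's convention, in recent versions, that `per_device_train_batch_size` counts completions rather than prompts. Under that convention, each optimizer step covers 512 completions, i.e. 256 prompts × 2 generations. That gives 2528 // 256 = 9 steps per epoch.

TRL also documents that `steps_per_generation` defaults to `gradient_accumulation_steps`. That would mean all 512 completions, each up to 768 tokens, are generated as one batch. This is one plausible contributor to the OutOfMemoryError above, not a confirmed diagnosis. The helper below is a hypothetical sketch of this arithmetic under those assumptions.

```python
def grpo_batch_math(num_prompts, num_generations, per_device_bs, grad_accum,
                    epochs, num_processes=1, steps_per_generation=None):
    """Batch arithmetic under the ASSUMED TRL convention that
    per_device_train_batch_size counts completions, not prompts."""
    completions_per_step = per_device_bs * num_processes * grad_accum
    prompts_per_step = completions_per_step // num_generations
    # Incomplete final batch of prompts is dropped each epoch.
    steps_per_epoch = num_prompts // prompts_per_step
    # ASSUMED TRL default: steps_per_generation falls back to grad_accum, so one
    # generation call produces a whole optimizer step's worth of completions.
    spg = grad_accum if steps_per_generation is None else steps_per_generation
    generation_batch = per_device_bs * num_processes * spg
    return {
        "completions_per_step": completions_per_step,
        "prompts_per_step": prompts_per_step,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": steps_per_epoch * epochs,
        "generation_batch": generation_batch,
    }


print(grpo_batch_math(2528, 2, 1, 512, 30))
# {'completions_per_step': 512, 'prompts_per_step': 256, 'steps_per_epoch': 9, 'total_steps': 270, 'generation_batch': 512}
```

The installed TRL/Unsloth `GRPOConfig` may accept `steps_per_generation`; if it does not, the `supported` filter in step 5 would silently drop it. If it is accepted, adding e.g. `'steps_per_generation': 8` to `grpo_kwargs` would shrink each generation batch to 8 completions while keeping 512 completions per optimizer step. Whether that alone avoids the OOM is untested here.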

In [None]:
#@title 7. Save best model to Google Drive
from google.colab import drive

drive.mount('/content/drive')

BEST_DIR = Path('/content/drive/MyDrive/SRL_Best_Model')
BEST_DIR.mkdir(parents=True, exist_ok=True)

# On a PEFT/LoRA model this writes the LoRA adapter weights and config, not the full model
model.save_pretrained(str(BEST_DIR))
tokenizer.save_pretrained(str(BEST_DIR))

print('Saved best model to', BEST_DIR)
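After `save_pretrained` on the LoRA model, the target folder should contain PEFT's adapter files. This small check assumes the PEFT default file names `adapter_config.json` and `adapter_model.safetensors`. Older PEFT versions, or `safe_serialization=False`, write `adapter_model.bin` instead.

Note that if `trainer.train()` failed before completing any optimizer step, as appears to be the case in the OOM run above, the saved adapter has not been trained.

```python
from pathlib import Path

# ASSUMPTION: PEFT's default file names for a LoRA adapter saved as safetensors.
EXPECTED_ADAPTER_FILES = ("adapter_config.json", "adapter_model.safetensors")


def missing_adapter_files(save_dir):
    """Return the expected adapter files that are not present in save_dir."""
    save_dir = Path(save_dir)
    return [name for name in EXPECTED_ADAPTER_FILES if not (save_dir / name).is_file()]
```

Usage in this notebook: `print(missing_adapter_files(BEST_DIR) or "adapter files present")`.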

In [None]:
#@title 8. (Optional) Merge LoRA and save full model
# Uncomment to merge LoRA weights into the base model for easier deployment

# MERGED_DIR = Path('/content/drive/MyDrive/SRL_Merged_Model')
# MERGED_DIR.mkdir(parents=True, exist_ok=True)

# model.save_pretrained_merged(
#     str(MERGED_DIR),
#     tokenizer,
#     save_method="merged_16bit",  # Full 16-bit merged model
# )
# print('Saved merged model to', MERGED_DIR)