# SRL GRPO Training (Colab, A100)
End-to-end notebook that builds the step-wise SRL dataset, splits it 95/5 into train/validation sets, trains with GRPO + LoRA on an A100, and saves the resulting model to Google Drive.

In [None]:
#@title 0. Environment setup (clone + installs)
# Clones the project repository into the Colab filesystem, makes it importable,
# installs the training dependencies, and reports the PyTorch version / device.
import os, sys
from pathlib import Path

REPO_URL = "https://github.com/iroblesrazzaq/SRL-reasoning.git"
REPO_DIR = Path('/content/SRL-reasoning')
# String form of the path, reused by the shell command, os.chdir and sys.path.
REPO_DIR_STR = str(REPO_DIR)

# Clone only when the directory is missing, so re-running the cell does not
# attempt a second clone into an existing checkout.
if not REPO_DIR.exists():
    !git clone $REPO_URL $REPO_DIR_STR

# Work from the repo root (the editable install below targets `.`) and put it
# on sys.path so later cells can import `src.*` and `scripts.*`.
os.chdir(REPO_DIR_STR)
if REPO_DIR_STR not in sys.path:
    sys.path.append(REPO_DIR_STR)

# NOTE(review): `--no-build-isolation` is passed for the whole install line;
# presumably it is only needed for flash-attn — confirm.
!pip install -q transformers peft bitsandbytes accelerate datasets trl vllm flash-attn --no-build-isolation
# NOTE(review): trl is installed from PyPI above and again from the GitHub
# repository here; presumably the development version is the intended one —
# confirm, and consider pinning a commit for reproducibility.
!pip install -q git+https://github.com/huggingface/trl.git
# Editable install of the current directory (the cloned repo after os.chdir).
!pip install -q -e .

# Imported after the installs; report which accelerator (if any) is visible.
import torch
device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
print(f'PyTorch {torch.__version__} | Device: {device_name}')



In [None]:
#@title 1. Global config
# Shared constants for the rest of the notebook: RNG seed, base model id,
# and the output / data directories (created here if missing).
import random
from pathlib import Path

import numpy as np

SEED = 42
BASE_MODEL = 'Qwen/Qwen2.5-7B-Instruct'

# Seed the Python, NumPy and PyTorch random number generators with SEED.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Training outputs live outside the repo; datasets live inside the clone.
OUTPUT_DIR = Path('/content/outputs/srl_grpo')
DATA_DIR = REPO_DIR / 'data'
DATA_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print('DATA_DIR:', DATA_DIR)



In [None]:
#@title 2. Build SRL data (s1K-1.1 -> step-wise JSONL)
# Loads the teacher dataset, converts it to step-wise SRL examples, writes the
# full set to disk, then splits it 95/5 into train/val JSONL files.
from src.shared.build_srl_data import (
    build_srl_dataset,
    load_teacher_dataset,
    normalize_dataset,
    save_jsonl,
)
from src.shared.splits import split_by_trajectory

# Output locations, all under DATA_DIR.
all_path = DATA_DIR / 'srl_steps.jsonl'
train_path = DATA_DIR / 'train.jsonl'
val_path = DATA_DIR / 'val.jsonl'

# Teacher dataset -> normalized trajectories -> SRL examples.
raw_ds = load_teacher_dataset('simplescaling/s1K-1.1', split='train')
norm_trajs = normalize_dataset(raw_ds)
srl_examples = build_srl_dataset(norm_trajs)
save_jsonl(srl_examples, all_path)

# test_ratio is 0.0, so the third element of the returned split is discarded.
train_examples, val_examples, _ = split_by_trajectory(
    str(all_path),
    seed=SEED,
    train_ratio=0.95,
    val_ratio=0.05,
    test_ratio=0.0,
)

save_jsonl(train_examples, train_path)
save_jsonl(val_examples, val_path)

print(f'Train examples: {len(train_examples)}')
print(f'Val examples:   {len(val_examples)}')



In [None]:
#@title 3. Prepare HF datasets for GRPO
# Loads the train/val JSONL files written by the previous cell through the
# project's `load_srl_dataset` helper and prints a short preview.
from scripts.train_srl import load_srl_dataset

# Train is loaded first, then val (same order as the tuple below).
train_dataset, val_dataset = (
    load_srl_dataset(str(jsonl_path)) for jsonl_path in (train_path, val_path)
)

# Preview the first two training records and report dataset sizes.
print(train_dataset[:2])
print(f'HF datasets -> train {len(train_dataset)}, val {len(val_dataset)}')



In [None]:
#@title 4. Load model + tokenizer (LoRA, flash-attn, grad checkpointing)
# Loads the base model in bf16 with FlashAttention-2, wraps it with LoRA
# adapters, enables gradient checkpointing, and reports parameter counts.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, TaskType, get_peft_model

# Left-padded tokenizer; reuse EOS as the pad token when none is defined.
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    padding_side='left',
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2',
)

# LoRA adapters (rank 16, alpha 32) on every linear layer; biases untouched.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules='all-linear',
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias='none',
)
model = get_peft_model(model, lora_config)

# Order kept as in the original setup: input grads, then checkpointing, then
# the KV cache is switched off (presumably because it is not used together
# with gradient checkpointing during training — confirm).
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
model.config.use_cache = False

# Report how many parameters will actually be trained.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable params: {trainable_params/1e6:.1f}M / {total_params/1e6:.1f}M')
if torch.cuda.is_available():
    print('Model device:', next(iter(model.parameters())).device)



In [None]:
#@title 5. Configure GRPO trainer and train
# Builds the GRPO configuration (paper hyper-parameters plus engineering
# settings), constructs the project's SRLGRPOTrainer, and runs training.
from inspect import signature
from trl import GRPOConfig
from scripts.train_srl import SRLGRPOTrainer, create_reward_function

reward_fn = create_reward_function(tokenizer)

# Argument names accepted by the installed GRPOConfig; used both to pick the
# right evaluation-strategy key and to filter the kwargs below.
supported = set(signature(GRPOConfig.__init__).parameters)

# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`.
# If the old key were used there, the filter below would drop it silently,
# evaluation would fall back to its default, and `load_best_model_at_end=True`
# would then be rejected because the save and eval strategies no longer match.
eval_strategy_key = (
    'eval_strategy' if 'eval_strategy' in supported else 'evaluation_strategy'
)

grpo_kwargs = {
    'output_dir': str(OUTPUT_DIR),
    'num_train_epochs': 30,             # [cite: 235] Matches paper
    'per_device_train_batch_size': 8,   # A100 80GB capacity
    # To match Paper Batch Size of 512: 8 * 64 = 512
    # NOTE(review): in TRL's GRPO the per-device batch may count completions
    # rather than prompts, so this could be 512 completions (64 prompts x 8
    # generations) — confirm against the paper's definition of batch size.
    'gradient_accumulation_steps': 64,  # [cite: 523]
    'learning_rate': 5e-7,              # [cite: 530] Paper uses 5e-7
    'beta': 0.0,                        # [cite: 536] KL coeff is 0 for SRL
    'warmup_ratio': 0.0,                # [cite: 531] No warmup
    'max_grad_norm': 1.0,               # [cite: 525]
    'num_generations': 8,               # [cite: 534]
    'temperature': 1.0,                 # [cite: 533] Explicitly set rollout temp

    # Convention / Engineering settings
    'logging_steps': 1,
    'save_strategy': 'epoch',
    eval_strategy_key: 'epoch',         # Paper evaluates on val set
    'save_total_limit': 2,
    'load_best_model_at_end': True,     # [cite: 235] "select checkpoint with best perf"
    'metric_for_best_model': 'reward',
    'greater_is_better': True,
    'optim': 'adamw_8bit',              # Convention (Paper uses H100s, likely 8bit or full)
    'bf16': True,                       # [cite: 527]
    'report_to': 'none',
    'seed': SEED,
}

# Keep the version-tolerant filtering, but make dropped settings visible so a
# renamed or removed option cannot change the training recipe unnoticed.
dropped = sorted(k for k in grpo_kwargs if k not in supported)
if dropped:
    print('WARNING: GRPOConfig does not accept these settings; ignoring:', ', '.join(dropped))

grpo_config = GRPOConfig(**{k: v for k, v in grpo_kwargs.items() if k in supported})

trainer = SRLGRPOTrainer(
    model=model,
    args=grpo_config,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    reward_funcs=reward_fn,
    filter_epsilon=1e-4,
)

train_result = trainer.train()
print(train_result)



In [None]:
#@title 6. Save best model to Google Drive
# Mounts Google Drive and writes the trainer's model plus the tokenizer into a
# dedicated folder there.
from google.colab import drive

drive.mount('/content/drive')

# Destination folder on the mounted drive; created on first use.
BEST_DIR = Path('/content/drive/MyDrive/SRL_Best_Model')
BEST_DIR.mkdir(parents=True, exist_ok=True)

# Both save calls take the same string path.
best_dir_str = str(BEST_DIR)
trainer.save_model(best_dir_str)
tokenizer.save_pretrained(best_dir_str)

print('Saved best model to', BEST_DIR)

