# Qwen3-1.7B GSM8K PPO Training (VERL)

| Model | GSM8K Accuracy |
|---|---|
| Qwen3-1.7B base | ~69.2% |
| Qwen3-1.7B + PPO | ~82.7% |

**Before running this notebook**, set up the environment in the Colab terminal:
```bash
git clone https://github.com/verl-project/verl /content/verl && pip install -e /content/verl
pip install flash-attn --no-build-isolation
pip install trl datasets sympy regex pandas pyarrow huggingface-hub
git clone https://github.com/jiayiderekchen/LLM_RL /content/LLM_RL
```
See `COLAB_GUIDE.md` for full setup details.

## 0. Check GPU

In [None]:
!nvidia-smi
import torch
gpu = torch.cuda.get_device_properties(0)
GPU_MEM_GB = gpu.total_memory / 1e9
print(f"GPU : {gpu.name}")
print(f"VRAM: {GPU_MEM_GB:.1f} GB")
IS_A100_OR_H100 = GPU_MEM_GB >= 38
print(f"Config: {'large (A100/H100)' if IS_A100_OR_H100 else 'small (T4)'}")

## 1. Mount Google Drive
Checkpoints and logs are saved directly to Drive so they survive session resets.

In [None]:
# Mount Google Drive and define where this run's checkpoints and log live.
import os
from google.colab import drive

drive.mount('/content/drive')

DRIVE_DIR = '/content/drive/MyDrive/LLM_RL_runs'
CKPT_DIR = DRIVE_DIR + '/checkpoints'
LOG_FILE = DRIVE_DIR + '/train.log'

# makedirs builds missing parents, so this also creates DRIVE_DIR itself.
os.makedirs(CKPT_DIR, exist_ok=True)

print(f'Checkpoints → {CKPT_DIR}')
print(f'Log         → {LOG_FILE}')

## 2. Clone / Update Repo

In [None]:
REPO_DIR = '/content/LLM_RL'
REPO_URL = 'https://github.com/jiayiderekchen/LLM_RL.git'

if not os.path.exists(REPO_DIR):
    !git clone {REPO_URL} {REPO_DIR}
else:
    !git -C {REPO_DIR} pull

os.chdir(REPO_DIR)
!ls

## 3. HuggingFace Login
Add `HF_TOKEN` to Colab Secrets (key icon in left sidebar) for one-click login.

In [None]:
# Log in to the Hugging Face Hub.
# First choice: read the HF_TOKEN entry from Colab Secrets and log in with it
# non-interactively. Everything on that path sits inside one `try`, so a
# failure at any step (the google.colab import, the secret lookup, or the
# login call itself) drops through to the interactive prompt below.
from huggingface_hub import login
try:
    from google.colab import userdata
    login(token=userdata.get('HF_TOKEN'), add_to_git_credential=False)
    print('Logged in via Colab secret.')
except Exception:
    login()  # interactive fallback

## 4. Prepare GSM8K Dataset

In [None]:
# The script path and output directory below are relative, so run from the
# repo root.
os.chdir(REPO_DIR)
# Run the repo's GSM8K preparation script with data/gsm8k as its output
# directory, then list that directory to confirm files were produced.
!python data/prepare_gsm8k.py --output_dir data/gsm8k
!ls -lh data/gsm8k/

## 5. PPO Training

Runs `scripts/train_ppo.sh` and streams its output to both the notebook and the Drive log file.

> **Tip:** Enable *Background execution* (Runtime → Change runtime type) so the cell keeps running if the browser tab goes to sleep.

In [None]:
import subprocess, sys, os
import signal

# scripts/train_ppo.sh is referenced by a relative path, so run from the repo root.
os.chdir(REPO_DIR)


def _stop_process_group(proc, grace_s=30):
    """Stop the process group led by *proc*: SIGTERM first, SIGKILL after *grace_s*.

    Relies on the child having been started with ``start_new_session=True``,
    which makes ``proc.pid`` the id of its own process group, so the signal
    also reaches processes spawned by the launcher script.
    """
    for sig in (signal.SIGTERM, signal.SIGKILL):
        try:
            os.killpg(proc.pid, sig)
        except ProcessLookupError:
            break  # the group has already exited
        try:
            proc.wait(timeout=grace_s)
            break
        except subprocess.TimeoutExpired:
            continue  # escalate to the next signal
    proc.wait()  # reap the launcher so returncode is populated


# GPU-aware batch sizes: the large preset for >= 38 GB devices, otherwise a
# reduced preset with shorter sequences, parameter offload and fewer epochs.
if IS_A100_OR_H100:
    extra_overrides = [
        'data.train_batch_size=128',
        'actor_rollout_ref.actor.ppo_mini_batch_size=64',
        'actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4',
        'actor_rollout_ref.rollout.gpu_memory_utilization=0.5',
        'critic.ppo_micro_batch_size_per_gpu=4',
    ]
else:  # T4
    extra_overrides = [
        'data.train_batch_size=32',
        'data.max_prompt_length=384',
        'data.max_response_length=512',
        'actor_rollout_ref.actor.ppo_mini_batch_size=16',
        'actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2',
        'actor_rollout_ref.actor.fsdp_config.param_offload=true',
        'actor_rollout_ref.rollout.gpu_memory_utilization=0.35',
        'critic.ppo_micro_batch_size_per_gpu=2',
        'trainer.total_epochs=10',
        'trainer.test_freq=50',
    ]

# Checkpoints go straight to Drive; the remaining overrides are appended as
# extra command-line arguments to the launcher script.
cmd = [
    'bash', 'scripts/train_ppo.sh',
    f'trainer.default_local_dir={CKPT_DIR}',
] + extra_overrides

env = {**os.environ, 'N_GPUS': '1'}

print('Command:', ' '.join(cmd))
print(f'Checkpoints → {CKPT_DIR}')
print(f'Log         → {LOG_FILE}\n')

with open(LOG_FILE, 'w') as logf:
    # start_new_session puts the launcher (and what it spawns) in its own
    # process group so the whole tree can be stopped if this cell is aborted.
    proc = subprocess.Popen(cmd, env=env,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT,
                            text=True, bufsize=1,
                            start_new_session=True)
    try:
        for line in proc.stdout:
            print(line, end='')   # stream to notebook
            logf.write(line)       # write to Drive
            logf.flush()           # keep the Drive log current line by line
        proc.wait()
    finally:
        # Reached with the child still alive only when the loop was interrupted
        # or raised (e.g. cell interrupt, Drive write error): do not leave an
        # orphaned training job holding the GPU.
        if proc.poll() is None:
            _stop_process_group(proc)

if proc.returncode == 0:
    print(f'\nTraining finished. Exit code: {proc.returncode}')
else:
    # Make failures unmistakable instead of reporting them as "finished".
    print(f'\nTraining FAILED. Exit code: {proc.returncode} (see {LOG_FILE})')

## 6. Monitor Log (run while training is in progress)
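Run this cell to watch key metrics from the Drive log. A notebook kernel executes one cell at a time, so while the training cell above is still running in this notebook, this cell will wait in the queue; to monitor in parallel, run it from a separate notebook session that has Drive mounted and `LOG_FILE` defined.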

In [None]:
import time, re
from IPython.display import clear_output

# Seconds between refreshes of the monitor display.
POLL_S = 30
step_re   = re.compile(r'global_step[:\s]+(\d+)')
# Capture one well-formed number (optional sign, fraction and exponent) so the
# float() conversion below cannot fail on text such as "..." or "1.2.3", and
# so negative rewards are picked up as well.
reward_re = re.compile(
    r'reward[\w/]*[:\s]+(-?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)', re.I)

print(f'Monitoring {LOG_FILE} — Ctrl+C to stop (training keeps running).')
try:
    while True:
        # The log may not exist yet or may be briefly unreadable; treat that
        # as "no lines" and retry on the next poll. The context manager closes
        # the handle on every iteration, and errors='replace' tolerates a
        # partially written multi-byte character at the end of the file.
        try:
            with open(LOG_FILE, errors='replace') as log_fh:
                lines = log_fh.readlines()
        except OSError:
            lines = []

        # Only the tail of the log is scanned for the most recent metrics.
        recent  = lines[-300:]
        steps   = [int(m.group(1)) for line in recent if (m := step_re.search(line))]
        rewards = [float(m.group(1)) for line in recent if (m := reward_re.search(line))]

        clear_output(wait=True)
        print(f'=== Training Monitor [{time.strftime("%H:%M:%S")}] ===')
        print(f'Log lines : {len(lines)}')
        print(f'Step      : {steps[-1] if steps else "–"}')
        print(f'Reward    : {rewards[-1]:.4f}' if rewards else 'Reward    : –')
        if len(rewards) >= 5:
            print(f'Avg(last5): {sum(rewards[-5:])/5:.4f}')
        print('\n--- Last 15 lines ---')
        for line in lines[-15:]:
            print(line.rstrip())
        time.sleep(POLL_S)
except KeyboardInterrupt:
    # Interrupting is the documented way to stop; exit without a traceback.
    print('\nMonitor stopped.')

## 7. Evaluate Checkpoint from Google Drive

Run after training (or mid-training to check a specific step).

In [None]:
import glob
import re


def _ckpt_sort_key(path):
    """Sort key: run directory first, then the numeric step of global_step_<N>.

    A plain string sort puts global_step_1000 before global_step_500, which
    would make the "latest" pick below wrong once step counts differ in
    length. Paths without a trailing step number sort first within their run.
    """
    step_match = re.search(r'global_step_(\d+)$', path)
    step = int(step_match.group(1)) if step_match else -1
    return (os.path.dirname(path), step)


# Auto-detect latest checkpoint in Drive: list every <run>/global_step_<N>
# under CKPT_DIR and take the last one after ordering steps numerically.
ckpts = sorted(glob.glob(f'{CKPT_DIR}/*/global_step_*'), key=_ckpt_sort_key)
if ckpts:
    for c in ckpts:
        print(c)
    EVAL_CKPT = ckpts[-1]
else:
    print('No checkpoints found in Drive yet.')
    EVAL_CKPT = None

print(f'\nWill evaluate: {EVAL_CKPT}')

In [None]:
# Override EVAL_CKPT manually if needed:
# EVAL_CKPT = f'{CKPT_DIR}/qwen3_gsm8k_ppo/global_step_500'

if EVAL_CKPT:
    os.chdir(REPO_DIR)
    !python evaluation/eval_gsm8k.py \
        --model_path {EVAL_CKPT} \
        --split test \
        --max_new_tokens 512 \
        --batch_size 16
else:
    print('Set EVAL_CKPT to a valid checkpoint path.')

## 8. Compare Base vs PPO

In [None]:
# Evaluate base model for comparison
# Uses the same script and the same flags (test split, 512 new tokens,
# batch size 16) as the checkpoint evaluation cell; only --model_path
# differs, here given as the model id Qwen/Qwen3-1.7B instead of a
# checkpoint path.
os.chdir(REPO_DIR)
!python evaluation/eval_gsm8k.py \
    --model_path Qwen/Qwen3-1.7B \
    --split test \
    --max_new_tokens 512 \
    --batch_size 16

In [None]:
# Print summary of all eval results
import json, glob

result_files = sorted(glob.glob(f'{REPO_DIR}/evaluation/results/*.json'))
print(f'{'Model':<55} {'Accuracy':>9}  Correct/Total')
print('-' * 80)
for rf in result_files:
    with open(rf) as f:
        r = json.load(f)['report']
    name = os.path.basename(r['model_path'])
    print(f"{name:<55} {r['accuracy']*100:>8.2f}%  {r['correct']}/{r['total']}")