# Phase2 Scattering Router A/B (Colab ready)
- Clone repo (if needed)
- Install deps
- Train baseline router
- Train scattering router (`--use-scattering-router`), once per `--scattering-scale` value in the sweep
- Show stdout/stderr of every run (to diagnose failures)
- Summarize loss/PPL from logs


In [None]:
# Repo setup: locate (or clone) the repository, chdir into it, and add it to sys.path.
# Defines `repo_root` (pathlib.Path of the checkout); the install, training and
# summary cells below pass it as the subprocess cwd and build checkpoint paths from it.
import os, sys, subprocess, pathlib
REPO_URL = 'https://github.com/neko-jpg/Project-ResNet-BK-An-O-N-Language-Model-Architecture.git'
REPO_DIR = 'Project-ResNet-BK-An-O-N-Language-Model-Architecture'
cwd = pathlib.Path.cwd()
# A directory is taken to be the repo root if it contains `src/`: check the current
# directory, its parent, and a REPO_DIR checkout inside either of them.
candidates = [cwd, cwd.parent, cwd / REPO_DIR, cwd.parent / REPO_DIR]
root = next((p for p in candidates if (p / 'src').exists()), None)
if root is None:
    root = cwd / REPO_DIR
    if not root.exists():
        subprocess.run(['git', 'clone', REPO_URL, str(root)], check=True)
    if not (root / 'src').exists():
        # REPO_DIR exists but was not recognized above: it may be a partial or stale checkout.
        print(f'WARNING: {root} has no src/ directory; the checkout may be incomplete')
if root != pathlib.Path.cwd():
    os.chdir(root)
# Later cells reference `repo_root`; bind it to the directory we are now in.
repo_root = pathlib.Path.cwd()
root_str = str(repo_root)
if root_str not in sys.path:
    sys.path.insert(0, root_str)
print('PWD:', root_str)


In [None]:
import subprocess, sys

# Install the repo's requirements into the interpreter running this notebook,
# then echo pip's output and stop the notebook if the install failed.
print('Installing requirements...')
pip_cmd = [sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt']
res = subprocess.run(pip_cmd, cwd=repo_root, text=True, capture_output=True)
print('returncode', res.returncode)
for stream_text in (res.stdout, res.stderr):
    print(stream_text)
assert res.returncode == 0, 'pip install failed'


In [None]:
import subprocess, sys

# Baseline run: `baseline` preset with the default router. Checkpoints and logs
# are written under checkpoints/phase2_baseline.
baseline_dir = repo_root / 'checkpoints' / 'phase2_baseline'
baseline_dir.mkdir(parents=True, exist_ok=True)
cmd = [
    sys.executable, 'train.py',
    '--config-preset', 'baseline',
    '--save-dir', str(baseline_dir),
]
print('Running baseline router:', ' '.join(cmd))
res = subprocess.run(cmd, cwd=repo_root, text=True, capture_output=True)
print('returncode', res.returncode)
# Echo both captured streams so failures can be diagnosed in the notebook.
for header, text in (('--- stdout ---', res.stdout), ('--- stderr ---', res.stderr)):
    print(header)
    print(text)
assert res.returncode == 0, 'baseline training failed'
print('Logs in', baseline_dir / 'logs')


In [None]:
import subprocess, sys

# Scattering-router sweep: one training run per scattering scale. Each run uses
# the same preset as the baseline plus --use-scattering-router, and passes
# `warmup` as --scattering-scale-warmup-steps.
scales = [0.05, 0.1, 0.2]
warmup = 100
summaries = []  # NOTE(review): not used by any cell shown here
for scale in scales:
    scatt_dir = repo_root / 'checkpoints' / f'phase2_scattering_{scale}'
    scatt_dir.mkdir(parents=True, exist_ok=True)
    cmd = [
        sys.executable, 'train.py',
        '--config-preset', 'baseline',
        '--use-scattering-router',
        '--scattering-scale', str(scale),
        '--scattering-scale-warmup-steps', str(warmup),
        '--save-dir', str(scatt_dir),
    ]
    print('Running scattering router:', ' '.join(cmd))
    res = subprocess.run(cmd, cwd=repo_root, text=True, capture_output=True)
    print('returncode', res.returncode)
    # Echo both captured streams so failures can be diagnosed in the notebook.
    for header, text in (('--- stdout ---', res.stdout), ('--- stderr ---', res.stderr)):
        print(header)
        print(text)
    # Stop the sweep at the first failing scale.
    assert res.returncode == 0, f'scattering training failed (scale={scale})'
    print('Logs in', scatt_dir / 'logs')


In [None]:
import csv, pathlib

def latest_csv(log_dir: pathlib.Path):
    """Return the most recently modified ``*.csv`` directly inside *log_dir*, or None.

    When several files share the newest modification time, the one the glob
    yields first wins.
    """
    return max(log_dir.glob('*.csv'), key=lambda path: path.stat().st_mtime, default=None)

def read_last_row(csv_path: pathlib.Path):
    """Return the last data row of *csv_path* as a dict keyed by the header row.

    Returns None when the file has no data rows (it is empty or contains only a
    header). The file is opened with ``newline=''``, as the csv module requires,
    so quoted fields containing line breaks are read verbatim. Only the most
    recent row is kept in memory instead of loading the whole log.
    """
    last = None
    with csv_path.open(newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            last = row
    return last

def summarize(run_name, save_dir: pathlib.Path):
    """Print and return the final logged metrics of one training run.

    Reads the last row of the newest CSV under ``save_dir / 'logs'``. The
    ``step``, ``epoch``, ``loss``, ``perplexity`` and ``grad_norm`` columns are
    required (a missing one raises KeyError). ``routing_entropy``,
    ``num_nan_grads`` and ``num_inf_grads`` fall back to 0 when missing or empty.

    Returns:
        A dict with keys path, step, epoch, loss, ppl, grad_norm,
        routing_entropy, nan and inf. Returns None (after printing a message)
        when there is no CSV or the CSV has no data row.
    """
    log_dir = save_dir / 'logs'
    csv_path = latest_csv(log_dir)
    if not csv_path:
        print(f"[{run_name}] no CSV in {log_dir}")
        return None
    last = read_last_row(csv_path)
    if not last:
        print(f"[{run_name}] empty CSV {csv_path}")
        return None

    # (output key, CSV column, converter, optional?), converted in this order.
    fields = (
        ('step', 'step', int, False),
        ('epoch', 'epoch', int, False),
        ('loss', 'loss', float, False),
        ('ppl', 'perplexity', float, False),
        ('grad_norm', 'grad_norm', float, False),
        ('routing_entropy', 'routing_entropy', float, True),
        ('nan', 'num_nan_grads', int, True),
        ('inf', 'num_inf_grads', int, True),
    )
    out = {'path': str(csv_path)}
    for key, column, convert, optional in fields:
        raw = (last.get(column, 0) or 0) if optional else last[column]
        out[key] = convert(raw)

    print(
        f"[{run_name}] step {out['step']} epoch {out['epoch']} "
        f"loss {out['loss']:.4f} ppl {out['ppl']:.1f} "
        f"grad_norm {out['grad_norm']:.3f} entropy {out['routing_entropy']:.4f} "
        f"nan {out['nan']} inf {out['inf']} ({csv_path.name})"
    )
    return out

# Compare each scattering run's final metrics with the baseline's final metrics.
# A delta line is printed only when both runs produced a summary.
base_summary = summarize('baseline', baseline_dir)
for scale in scales:
    scatt_dir = repo_root / 'checkpoints' / f'phase2_scattering_{scale}'
    scatt_summary = summarize(f'scattering_{scale}', scatt_dir)
    if base_summary and scatt_summary:
        d_loss = scatt_summary['loss'] - base_summary['loss']
        d_ppl = scatt_summary['ppl'] - base_summary['ppl']
        d_entropy = scatt_summary['routing_entropy'] - base_summary['routing_entropy']
        print(f"Δloss (scale {scale}): {d_loss:+.4f}, Δppl: {d_ppl:+.1f}, Δentropy: {d_entropy:+.4f}")
    print('---')
