# CADFire Training Notebook — Iterative Branch Edition

Designed for **repeated, interruptible training runs** across multiple named
checkpoint branches. Each phase (1–4) can be run as many times as you like;
every run continues from the last checkpoint in the active branch.

## Branch Metaphor
A *branch* is a named training lineage stored in `model_saves/<branch>/`. You can:
- **create** a fresh branch to start a new experiment
- **fork** an existing branch to diverge from any saved snapshot
- **resume** the latest checkpoint in any branch and keep training
- **snapshot** the current weights with a human-readable tag at any step

## Predefined Branch Names
Branches use memorable names from a curated pool so GIFs and checkpoints are easy
to refer to. You can also supply any custom name.

## Phase Design
| Phase | What trains | Re-runnable? | When to re-run |
|-------|-------------|-------------|----------------|
| **1 – Tool Classifier** | text + fusion + tool head | Yes | After new tools / task prompts |
| **2 – Semantic Cursor** | all params | Yes | After new supervised task types |
| **3 – Teacher Forcing** | all params | Yes | After new multi-step task types |
| **— Diagnostics** | — | Yes | Health check at any time |
| **4 – PPO RL** | all params | Yes | Resume whenever; pause anytime |

> **Tip**: If the agent learned faster than expected, stop Phase 4 early with the
> interrupt button (■) and continue later. Progress is checkpointed every
> `save_interval` steps, so an interrupt loses at most the steps since the last save.


In [None]:
# Cell 1: Setup (run once per instance)
# IPython shell escape: quietly installs the packages the notebook imports.
!pip install -q torch numpy matplotlib pillow imageio

import sys, os
# Put the current working directory first on sys.path so the later
# `from train import ...` / `cadfire...` imports resolve from it
# (presumably the repository root — confirm the notebook is launched there).
sys.path.insert(0, os.getcwd())
print('Python path set. Ready.')


In [None]:
# Cell 2: Branch management
# Run this cell once per session.  It defines all branch helpers and loads
# (or creates) the branch registry stored in model_saves/branches.json.

import datetime, json, shutil, time
from pathlib import Path

# Root directory; each branch lives in its own sub-directory underneath it.
SAVES_ROOT    = Path('model_saves')
# JSON file mapping branch name -> metadata dict ('created', 'forked_from', 'note').
REGISTRY_FILE = SAVES_ROOT / 'branches.json'
# Created when the cell runs so the helpers below can write into it right away.
SAVES_ROOT.mkdir(parents=True, exist_ok=True)

# Memorable auto-assign pool; _next_auto_name() returns the first entry that
# is not yet a key in the registry.
_BRANCH_POOL = [
    'brian', 'jimmy', 'karen', 'duke', 'atlas', 'nova', 'remy', 'sage',
    'felix', 'iris', 'coda', 'echo', 'zara', 'orion', 'pixel', 'nexus',
    'drift', 'quill', 'ember', 'frost', 'spark', 'lyra', 'cedar', 'storm',
]

def _load_registry():
    """Return the branch registry parsed from REGISTRY_FILE.

    An empty dict is returned when the registry file does not exist yet.
    """
    if not REGISTRY_FILE.exists():
        return {}
    with REGISTRY_FILE.open() as handle:
        return json.load(handle)

def _save_registry(reg):
    """Write the registry dict *reg* to REGISTRY_FILE as indented JSON."""
    with REGISTRY_FILE.open('w') as handle:
        json.dump(reg, handle, indent=2)

def _branch_dir(name):
    """Return the directory path of branch *name* under SAVES_ROOT."""
    return SAVES_ROOT.joinpath(name)

def _next_auto_name():
    """Pick a name for a new branch.

    Returns the first entry of _BRANCH_POOL that is not already a registry
    key; once the pool is exhausted, falls back to a 'run_<timestamp>' name
    with minute resolution.
    """
    taken = set(_load_registry())
    available = [candidate for candidate in _BRANCH_POOL if candidate not in taken]
    if available:
        return available[0]
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M')
    return 'run_' + stamp

def _branch_stats(name):
    """Summarise the checkpoints and best reward stored for branch *name*.

    Returns:
        dict with keys
        'n_ckpts'     - number of ``*.pt`` files in the branch directory,
        'size_mb'     - their combined size in MB (1e6 bytes), one decimal,
        'best_reward' - ``best_reward`` from diagnostics.json, or None when
                        the file is missing, unreadable, or lacks the key,
        'tags'        - checkpoint file stems, in sorted path order.
    """
    d   = _branch_dir(name)
    pts = sorted(d.glob('*.pt')) if d.exists() else []
    diag = d / 'diagnostics.json'
    best = None
    if diag.exists():
        try:
            # Context manager so the handle is closed even if parsing fails.
            with open(diag) as f:
                dat = json.load(f)
        except (OSError, ValueError):
            # Best-effort summary: an unreadable or malformed diagnostics
            # file (JSONDecodeError is a ValueError) just means "no reward".
            dat = None
        if isinstance(dat, dict):
            best = dat.get('best_reward')
    return {'n_ckpts': len(pts),
            'size_mb': round(sum(p.stat().st_size for p in pts)/1e6, 1),
            'best_reward': best,
            'tags': [p.stem for p in pts]}

# ---- Public API ----------------------------------------------------------

def create_branch(name=None, note=''):
    """Register a new training branch and create its directory.

    With ``name=None`` a name is auto-assigned via _next_auto_name().  If the
    name is already registered, nothing is created and the name is returned
    unchanged.  ``note`` is stored in the registry entry.  Returns the branch
    name.
    """
    branch = _next_auto_name() if name is None else name
    registry = _load_registry()

    if branch in registry:
        print(f"Branch '{branch}' already exists. Use switch_branch('{branch}').")
        return branch

    target = _branch_dir(branch)
    target.mkdir(parents=True, exist_ok=True)
    registry[branch] = {
        'created': datetime.datetime.now().isoformat(),
        'forked_from': None,
        'note': note,
    }
    _save_registry(registry)
    print(f"Created branch: '{branch}'  ->  {target}/")
    return branch

def fork_branch(from_name, to_name=None, from_tag='latest', note=''):
    """Start branch *to_name* from the checkpoint ``from_name/<from_tag>.pt``.

    The source checkpoint is copied to ``latest.pt`` in the destination
    branch and the registry records where it was forked from.  With
    ``to_name=None`` a name is auto-assigned.  Returns the destination name.

    Raises:
        FileNotFoundError: if the requested source checkpoint does not exist.
    """
    dest_name = _next_auto_name() if to_name is None else to_name
    source_dir = _branch_dir(from_name)
    checkpoint = source_dir / f'{from_tag}.pt'

    if not checkpoint.exists():
        available = [p.name for p in source_dir.glob('*.pt')]
        raise FileNotFoundError(
            f"'{from_tag}.pt' not found in branch '{from_name}'.\n"
            f"Available: {available}")

    dest_dir = _branch_dir(dest_name)
    dest_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy(checkpoint, dest_dir / 'latest.pt')

    registry = _load_registry()
    registry[dest_name] = {
        'created': datetime.datetime.now().isoformat(),
        'forked_from': f'{from_name}/{from_tag}',
        'note': note,
    }
    # Make sure the source branch has a registry entry too.
    if from_name not in registry:
        registry[from_name] = {}
    _save_registry(registry)

    print(f"Forked '{from_name}/{from_tag}' -> branch '{dest_name}'")
    return dest_name

def list_branches(verbose=True):
    """Return all registered branch names in sorted order.

    With ``verbose=True`` a table of checkpoint count, total size, best
    reward and note / fork origin is printed as well.  An empty registry
    prints a hint and returns an empty list.
    """
    registry = _load_registry()
    if not registry:
        print('No branches yet. Use create_branch() to start.')
        return []

    if verbose:
        print(f"{'Branch':<12}  {'Ckpts':<6}  {'Size':<8}  {'Best Reward':<13}  Note / Forked From")
        print('-' * 72)

    ordered = sorted(registry)
    for branch in ordered:
        stats = _branch_stats(branch)
        reward = stats['best_reward']
        reward_txt = '-' if reward is None else f"{reward:.4f}"
        entry = registry[branch]
        remark = entry.get('note') or entry.get('forked_from') or ''
        if verbose:
            print(f"{branch:<12}  {stats['n_ckpts']:<6}  {stats['size_mb']:>5.1f} MB  {reward_txt:<13}  {remark}")

    if verbose:
        print()
    return ordered

def switch_branch(name):
    """Make *name* the active branch for this session.

    Updates the module-level ACTIVE_BRANCH and CKPT_DIR.  If the branch
    directory does not exist, a message is printed and nothing changes.
    """
    global ACTIVE_BRANCH, CKPT_DIR
    target = _branch_dir(name)
    if target.exists():
        ACTIVE_BRANCH = name
        CKPT_DIR = str(target)
        print(f"Active branch: '{ACTIVE_BRANCH}'  ->  {CKPT_DIR}/")
    else:
        print(f"Branch '{name}' not found. Create it first.")

def save_snapshot(tag, branch=None):
    """Copy latest.pt -> <tag>.pt as a permanent labelled snapshot.

    Args:
        tag: Snapshot label; the copy is written as ``<tag>.pt`` inside the
            branch directory.
        branch: Branch to snapshot.  Defaults to the active branch.

    Prints a message and returns without copying when no branch is
    selected, when the branch has no latest.pt, or when *tag* names the
    live checkpoint itself.
    """
    b = branch or ACTIVE_BRANCH
    if not b:
        # Without this guard `SAVES_ROOT / None` fails with an opaque TypeError.
        print('No active branch. Call switch_branch(name) first or pass branch=.')
        return
    src = _branch_dir(b) / 'latest.pt'
    dst = _branch_dir(b) / f'{tag}.pt'
    if not src.exists():
        print(f"No latest.pt in branch '{b}'.")
        return
    if dst == src:
        # shutil.copy raises SameFileError when source and target coincide.
        print(f"Tag '{tag}' is the live checkpoint itself; choose a different tag.")
        return
    shutil.copy(src, dst)
    print(f"Snapshot saved: {dst}  ({dst.stat().st_size/1e6:.1f} MB)")

def branch_history(branch=None, last_n=20):
    """Print recent PPO training log entries for a branch.

    Args:
        branch: Branch to inspect.  Defaults to the active branch.
        last_n: How many of the most recent 'training_log' entries to print;
            0 (or a negative value) prints only the summary line.

    Reads ``diagnostics.json`` from the branch directory; prints a message
    and returns when no branch is selected or the file is missing.
    """
    b = branch or ACTIVE_BRANCH
    if not b:
        # Without this guard `SAVES_ROOT / None` fails with an opaque TypeError.
        print('No active branch. Call switch_branch(name) first or pass branch=.')
        return
    diag = _branch_dir(b) / 'diagnostics.json'
    if not diag.exists():
        print(f"No diagnostics.json in branch '{b}'.")
        return
    # Context manager so the file handle is closed promptly.
    with open(diag) as f:
        d = json.load(f)
    log = d.get('training_log', [])
    print(f"Branch '{b}' | steps={d.get('total_steps',0):,} | best={d.get('best_reward', 0):.4f}")
    # log[-0:] would be the whole list, so handle last_n <= 0 explicitly.
    recent = log[-last_n:] if last_n > 0 else []
    for e in recent:
        print(f"  step={e.get('step',0):>8,}  reward={e.get('avg_reward',0):.3f}"
              f"  len={e.get('avg_episode_length',0):.1f}  diff={e.get('difficulty',0):.1f}")

# Session state written by switch_branch().  Both are reset to None each time
# this cell is executed, so re-running the cell deselects the active branch.
ACTIVE_BRANCH = None
CKPT_DIR      = None

# Quick reference for the public helpers defined in this cell.
print('Branch helpers loaded:')
print('  create_branch([name])    - start a new training lineage')
print('  fork_branch(src, [dst])  - copy a snapshot into a new branch')
print('  switch_branch(name)      - activate a branch for training')
print('  list_branches()          - show all branches')
print('  save_snapshot(tag)       - pin weights with a label')
print('  branch_history([name])   - recent RL training log')


In [None]:
# Cell 3: Pick (or create) the active branch
# -----------------------------------------------
# Option A: Create a fresh auto-named branch
# name = create_branch()            # -> 'brian', 'jimmy', ...

# Option B: Create with a custom name
# name = create_branch('my_exp')

# Option C: Resume an existing branch
# switch_branch('brian')

# Option D: Fork from a snapshot in another branch
# fork_branch('brian', 'jimmy', from_tag='phase3_final')
# switch_branch('jimmy')

# ---- Default: create a new auto-named branch ----
# NOTE: create_branch() with no name takes the first unused pool name, so
# every execution of this cell registers a NEW branch.  To keep training an
# existing branch, comment these two lines out and use Option C instead.
name = create_branch()
switch_branch(name)

# Blank line, then the branch table for a quick overview.
print()
list_branches()


In [None]:
# Cell 4: Verify environment
# Prints the PyTorch / CUDA setup, the discovered RL tasks, the tool count
# and a vocabulary coverage summary.
import torch
print(f'PyTorch : {torch.__version__}')
print(f'CUDA    : {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU     : {torch.cuda.get_device_name(0)}')
    print(f'Memory  : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

from cadfire.tasks.registry import TaskRegistry
TaskRegistry.discover()
# Only the first eight task names are shown to keep the output short.
print(f'\nRL tasks ({TaskRegistry.count()}): {TaskRegistry.list_tasks()[:8]}...')

from cadfire.utils.config import num_tools
print(f'Tools: {num_tools()}')

# Vocabulary coverage check
from cadfire.utils.vocab_tracker import VocabTracker
# Build and summarise on the SAME tracker instance.  Previously build() ran
# on one temporary instance and print_summary() on a second, freshly
# constructed one, which only works if VocabTracker shares state across
# instances.
vocab_tracker = VocabTracker()
vocab_tracker.build()
vocab_tracker.print_summary()


## Phase 1 — Tool Classifier Pretraining

**Re-runnable.** Each execution loads the latest checkpoint in the active branch
and continues training. Run again whenever new tools or prompt variants are added.

- Trains: text encoder + fusion bridge + tool head  
- Frozen: vision encoder, spectral encoder, cursor head


In [None]:
# Phase 1 (re-runnable)
# Runs tool-classifier pretraining and writes checkpoints into the active
# branch directory.
from train import run_pretrain_tool
# CKPT_DIR stays None until switch_branch() has been called (Cell 3).
assert CKPT_DIR, 'Run Cell 3 first to pick a branch.'

# `agent` is kept at notebook scope so the later phases and diagnostics cells
# can reuse it; `history1` holds the curves plotted in the next cell.
agent, history1 = run_pretrain_tool(
    num_epochs=30,    # increase to 50 for ~95%+ accuracy
    lr=1e-3,
    batch_size=64,
    checkpoint_dir=CKPT_DIR,
)

# Report the last recorded values of each history list.
print(f'Phase 1 accuracy : {history1["tool_accuracies"][-1]:.1%}')
print(f'Phase 1 loss     : {history1["tool_losses"][-1]:.4f}')


In [None]:
# Phase 1 eval + snapshot
import matplotlib.pyplot as plt

def _plot_history(h, title=''):
    """Plot training curves from a history dict.

    Draws 'tool_losses' and 'tool_accuracies' side by side, plus a third
    panel for 'cursor_losses' when that entry is present and non-empty.
    A non-empty *title* is used as the figure suptitle.
    """
    cursor_curve = h.get('cursor_losses')
    show_cursor = bool(cursor_curve)
    n_panels = 3 if show_cursor else 2
    fig, axes = plt.subplots(1, n_panels, figsize=(5*n_panels, 3))

    loss_ax = axes[0]
    loss_ax.plot(h.get('tool_losses', []), color='steelblue')
    loss_ax.set_title('Tool Loss')
    loss_ax.set_xlabel('Epoch')

    acc_ax = axes[1]
    acc_ax.plot(h.get('tool_accuracies', []), color='green')
    acc_ax.set_title('Tool Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylim(0, 1)

    if show_cursor:
        cursor_ax = axes[2]
        cursor_ax.plot(cursor_curve, color='orange')
        cursor_ax.set_title('Cursor Loss')
        cursor_ax.set_xlabel('Epoch')

    if title:
        fig.suptitle(title, fontsize=12)
    plt.tight_layout()
    plt.show()

# Plot the Phase 1 curves, then pin the branch's latest.pt as phase1_final.pt.
_plot_history(history1, title=f'Phase 1 - {ACTIVE_BRANCH}')
save_snapshot('phase1_final')


## Phase 2 — Semantic Cursor Pretraining

**Re-runnable.** Run again after adding new supervised task types.

Trains all parameters on 20+ task types with Gaussian cursor targets.


In [None]:
# Phase 2 (re-runnable)
# Runs semantic-cursor pretraining and writes checkpoints into the active
# branch directory.
from train import run_pretrain_semantic
# Same guard as Phase 1: CKPT_DIR stays None until switch_branch() has been
# called, and checkpoint_dir=None would detach this run from the branch.
assert CKPT_DIR, 'Run Cell 3 first to pick a branch.'

agent, history2 = run_pretrain_semantic(
    agent=globals().get('agent'),   # reuse if in memory
    num_samples=20_000,
    num_epochs=20,
    lr=3e-4,
    batch_size=32,
    sigma=12.0,
    cursor_weight=1.0,
    num_workers=0,
    checkpoint_dir=CKPT_DIR,
)
# Report the last recorded values of each history list.
print(f'Phase 2 accuracy    : {history2["tool_accuracies"][-1]:.1%}')
print(f'Phase 2 cursor loss : {history2["cursor_losses"][-1]:.4f}')


In [None]:
# Phase 2 eval + snapshot: plot the curves, then pin latest.pt as phase2_final.pt.
_plot_history(history2, title=f'Phase 2 - {ACTIVE_BRANCH}')
save_snapshot('phase2_final')


## Phase 3 — Teacher-Forced Multi-Step Pretraining

**Re-runnable.** Run again after adding new multi-step task types.

Oracle actions drive the environment; agent loss is computed per step.


In [None]:
# Phase 3 (re-runnable)
# Runs teacher-forced multi-step pretraining and writes checkpoints into the
# active branch directory.
from train import run_pretrain_teacher
# Same guard as Phase 1: CKPT_DIR stays None until switch_branch() has been
# called, and checkpoint_dir=None would detach this run from the branch.
assert CKPT_DIR, 'Run Cell 3 first to pick a branch.'

agent, history3 = run_pretrain_teacher(
    agent=globals().get('agent'),   # reuse the in-memory agent when one exists
    num_trajectories=5_000,
    num_epochs=15,
    lr=1e-4,
    batch_size=16,
    polygon_ratio=0.7,
    checkpoint_dir=CKPT_DIR,
)
# Report the last recorded accuracy.
print(f'Phase 3 accuracy : {history3["tool_accuracies"][-1]:.1%}')


In [None]:
# Phase 3 eval + snapshot: plot the curves, then pin latest.pt as phase3_final.pt.
_plot_history(history3, title=f'Phase 3 - {ACTIVE_BRANCH}')
save_snapshot('phase3_final')


## Diagnostics — Multi-Task GIFs with Prompt Overlay

Generates animated GIFs for polygon tracing AND a variety of RL task types.
Each frame shows the viewport, cursor heatmap, and the **current text prompt**.

Run at any point — after any phase, or during a PPO pause.


In [None]:
# Diagnostics - polygon tracing (classic)
# Requires `agent` to be in memory (it is assigned by the Phase 1-3 cells).
import torch
from cadfire.training.diagnostics import generate_diagnostic_gifs

# `device` and `diag_dir` stay at notebook scope; the RL-task rollout cell
# below reuses `device`.
device   = 'cuda' if torch.cuda.is_available() else 'cpu'
diag_dir = f'diagnostics/{ACTIVE_BRANCH}'

gif_metrics = generate_diagnostic_gifs(
    agent, output_dir=diag_dir, n_episodes=6, device=device, fps=1.5, verbose=True
)
print(f'GIFs -> {diag_dir}/')


In [None]:
# Diagnostics - RL task rollouts (draw / select / modify / view)
# Depends on `agent` and on `device`, which this cell does not define; run
# the polygon-tracing diagnostics cell (or another cell that sets `device`)
# first.
from cadfire.training.diagnostics import generate_task_rollout_gifs

task_metrics = generate_task_rollout_gifs(
    agent,
    task_categories=['draw', 'select', 'modify', 'view'],
    n_per_category=2,
    output_dir=f'diagnostics/{ACTIVE_BRANCH}/rl_tasks',
    device=device,
    fps=1.5,
    verbose=True,
)


## Phase 4 — PPO Reinforcement Learning

**Re-runnable / resumable.** Always loads the latest checkpoint in the active
branch. Interrupt at any time with the stop button — progress is saved every
`save_interval` steps.

**Workflow tips:**
- Watch the reward curve. If it plateaus, interrupt, re-run Phase 2/3, then resume.
- Snapshot at milestones: `save_snapshot('rl_v2')`.
- Fork to try hyperparameter variants: `fork_branch(ACTIVE_BRANCH, 'jimmy')`.
- Start a new branch from a specific snapshot: `fork_branch('brian', 'test', from_tag='phase3_final')`.


In [None]:
# Phase 4 - PPO (resumable, interruptible)
from train import run_training
# Same guard as Phase 1: CKPT_DIR stays None until switch_branch() has been
# called, and checkpoint_dir=None would detach this run from the branch.
assert CKPT_DIR, 'Run Cell 3 first to pick a branch.'

run_training(
    num_steps=200_000,   # interrupt anytime; re-run to continue
    resume=True,         # always resume from latest in the branch
    checkpoint_dir=CKPT_DIR,
)


In [None]:
# PPO training curves for the active branch
# Reads <branch>/diagnostics.json and plots reward, curriculum difficulty and
# policy entropy against the training step.
import matplotlib.pyplot as plt, json
from pathlib import Path

# CKPT_DIR stays None until switch_branch() has been called; Path(None)
# would otherwise fail with an opaque TypeError.
assert CKPT_DIR, 'Run Cell 3 first to pick a branch.'

diag_path = Path(CKPT_DIR) / 'diagnostics.json'
if not diag_path.exists():
    print('No diagnostics.json found. Run Phase 4 first.')
else:
    # Context manager so the file handle is closed promptly.
    with open(diag_path) as f:
        d = json.load(f)
    # .get() matches the other diagnostics readers in this notebook and
    # avoids a KeyError when no log has been written yet.
    log = d.get('training_log', [])
    steps   = [e['step'] for e in log]
    rewards = [e.get('avg_reward', 0) for e in log]
    diffs   = [e.get('difficulty', 0) for e in log]
    entropies = [e.get('entropy', 0) for e in log]

    fig, axes = plt.subplots(1, 3, figsize=(15, 3))
    axes[0].plot(steps, rewards, color='steelblue', lw=1)
    axes[0].set_title('Avg Reward'); axes[0].set_xlabel('Step')
    axes[1].plot(steps, diffs, color='orange', lw=1)
    axes[1].set_title('Curriculum Difficulty'); axes[1].set_xlabel('Step')
    axes[2].plot(steps, entropies, color='green', lw=1)
    axes[2].set_title('Policy Entropy'); axes[2].set_xlabel('Step')

    fig.suptitle(f'Branch: {ACTIVE_BRANCH}  Best reward: {d.get("best_reward",0):.4f}')
    plt.tight_layout(); plt.show()
    print(f'Total steps: {d.get("total_steps",0):,}')
    print(f'Best reward: {d.get("best_reward", float("-inf")):.4f}')


In [None]:
# Compare reward curves across all branches
# Overlays the avg-reward curve of every registered branch that has a
# non-empty 'training_log' in its diagnostics.json.
import matplotlib.pyplot as plt, json
from pathlib import Path

colors = ['steelblue','orange','green','red','purple','brown','teal','magenta']
fig, ax = plt.subplots(figsize=(12, 4))

n_plotted = 0
for i, bname in enumerate(list_branches(verbose=False)):
    diag = Path('model_saves') / bname / 'diagnostics.json'
    if not diag.exists():
        continue
    # Context manager so each file handle is closed before the next branch.
    with open(diag) as f:
        d = json.load(f)
    log = d.get('training_log', [])
    if not log:
        continue
    steps   = [e['step'] for e in log]
    rewards = [e.get('avg_reward', 0) for e in log]
    # Colours cycle when there are more branches than palette entries.
    ax.plot(steps, rewards, label=bname, color=colors[i % len(colors)], lw=1.2)
    n_plotted += 1

ax.set_title('Reward Comparison - All Branches')
ax.set_xlabel('Training Step'); ax.set_ylabel('Avg Reward')
# Only draw a legend when at least one curve exists; an empty legend just
# triggers a matplotlib warning.
if n_plotted:
    ax.legend()
plt.tight_layout(); plt.show()


In [None]:
# Checkpoint browser - all branches
# Lists every *.pt file under model_saves/, oldest first, with its
# modification time, size, parent directory (branch) and tag (file stem).
import time
from pathlib import Path

print(f"{'Modified':<18}  {'Size':<9}  {'Branch':<12}  Tag")
print('-' * 62)

def _by_mtime(p):
    return p.stat().st_mtime

for pt in sorted(Path('model_saves').rglob('*.pt'), key=_by_mtime):
    branch = pt.parent.name
    size   = f'{pt.stat().st_size / 1e6:.1f} MB'
    stamp  = time.localtime(pt.stat().st_mtime)
    mtime  = time.strftime('%Y-%m-%d %H:%M', stamp)
    print(f'{mtime:<18}  {size:<9}  {branch:<12}  {pt.stem}')


In [None]:
# Load a specific checkpoint from any branch
# Useful for: inspecting old checkpoints, comparing agent behaviour,
# or forking from a specific milestone.
# The weights are loaded into a separate `agent2`, so the `agent` produced by
# the training cells is left untouched.

import torch
from cadfire.model.cad_agent import CADAgent
from cadfire.training.checkpoint import CheckpointManager
from cadfire.utils.config import load_config

# Choose which branch and tag to load:
LOAD_BRANCH = ACTIVE_BRANCH      # or any other branch name
LOAD_TAG    = 'phase3_final'     # or 'latest', 'best', 'rl_v2', ...

config  = load_config()
agent2  = CADAgent(config)
device  = 'cuda' if torch.cuda.is_available() else 'cpu'

# optimizer=None: only the agent is handed to the loader here.
ckpt = CheckpointManager(f'model_saves/{LOAD_BRANCH}', config)
meta = ckpt.load(agent2, optimizer=None, tag=LOAD_TAG, device=device)
print(f"Loaded branch='{LOAD_BRANCH}' tag='{LOAD_TAG}' step={meta.get('step',0):,}")
agent2.to(device)


In [None]:
# Visualize agent behavior on a specific task
# Rolls the agent out for up to 20 steps on one task and shows a strip of at
# most six frames.
# NOTE(review): this cell drives the global `agent` (from the training
# cells), not the `agent2` loaded by the checkpoint-loading cell — confirm
# that is intended.
import numpy as np, torch, matplotlib.pyplot as plt
from cadfire.env.cad_env import CADEnv
from cadfire.tasks.registry import TaskRegistry
from cadfire.utils.config import load_config

TaskRegistry.discover()
config = load_config()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

env  = CADEnv(config, task_category='draw')  # try 'select', 'modify', 'view'
obs, info = env.reset()
prompt = info.get('prompt', '')
print(f'Prompt: {prompt}')

frames = []
for _ in range(20):
    # ndarray observations become float tensors with a leading batch axis;
    # any other value is wrapped as a 1x1 long tensor.
    obs_t = {k: torch.from_numpy(v).float().unsqueeze(0).to(device)
             if isinstance(v, np.ndarray) else
             torch.tensor([[v]], dtype=torch.long).to(device)
             for k, v in obs.items()}
    with torch.no_grad():
        act = agent.act(obs_t, deterministic=True)
    # Convert the action tensors back to plain Python / numpy values for the env.
    obs, r, term, trunc, info = env.step({
        'tool_id': act['tool_id'].item(),
        'cursor':  act['cursor'].squeeze().cpu().numpy(),
        'param':   act['param'].item(),
    })
    # Keep the first three channels of the image (presumably RGB — confirm).
    frames.append(obs['image'][:,:,:3].copy())
    if term or trunc: break

cols = min(6, len(frames))
fig, axes = plt.subplots(1, cols, figsize=(3*cols, 3))
# plt.subplots returns a bare Axes (not an array) when there is one column.
axes = axes if cols > 1 else [axes]
# Sub-sample the frames with a fixed stride; zip() truncates to `cols` panels.
for ax, fr in zip(axes, frames[::max(1, len(frames)//cols)]):
    # Scaling by 255 assumes pixel values in [0, 1] — TODO confirm.
    ax.imshow((fr*255).astype('uint8')); ax.axis('off')
plt.suptitle(f'Prompt: {prompt}'); plt.tight_layout(); plt.show()


In [None]:
# Export to DXF
# Writes the contents of a CADEngine to output_<branch>.dxf in the current
# working directory.
from cadfire.export.dxf_writer import DXFWriter
from cadfire.engine.cad_engine import CADEngine
from cadfire.utils.config import load_config

config = load_config()
# A freshly constructed engine: as written, nothing has been drawn into it.
engine = CADEngine(config)
# (populate engine.entities first, e.g. by running a task)

out_path = f'output_{ACTIVE_BRANCH}.dxf'
DXFWriter(config).write(engine, out_path)
print(f'DXF written to: {out_path}')


In [None]:
# Pull latest code + continue (tool-list growth handled automatically)
import subprocess, importlib

# The return code is not checked: the cell prints git's stdout, or stderr
# when stdout is empty, and carries on either way.
result = subprocess.run(['git', 'pull', 'origin', 'main'],
                        capture_output=True, text=True)
print(result.stdout.strip() or result.stderr.strip())

# Reload the config module, then re-import num_tools from the reloaded module
# so the printed count comes from the pulled code.
import cadfire.utils.config as _cfg
importlib.reload(_cfg)
from cadfire.utils.config import num_tools
print(f'Tools after pull: {num_tools()}')

# Re-run Phase 1 to absorb new prompts, then continue PPO
# from train import run_pretrain_tool, run_training
# run_pretrain_tool(num_epochs=20, checkpoint_dir=CKPT_DIR)
# run_training(num_steps=100_000, resume=True, checkpoint_dir=CKPT_DIR)
print('Uncomment the lines above to retrain + continue after pull.')
