# CADFire Training Notebook

Interactive training notebook for vast.ai GPU instances.

## Four-Phase Pipeline

| Phase | Cell | What trains | Why | Typical time |
|-------|------|-------------|-----|--------------|
| **1** | 3a | Text encoder + fusion + tool head | Text → tool mapping, no vision | ~1 min |
| **2** | 3b | All parameters (incl. text encoder) | Vision + text → tool + cursor, 20 task types, single-step | ~15–30 min |
| **3** | 3c | All parameters | 2–9-step trajectories with oracle teacher forcing | ~30–60 min |
| — | 3d | — | GIF diagnostics to evaluate polygon tracing before RL | ~1 min |
| **4** | 3e | All parameters (PPO) | Curriculum RL, sparse rewards, exploration | hours |

## Phase-2 Supervised Tasks (20 types)

**Original 11:** SELECT · MULTISELECT · ERASE · PAN · ZOOM_IN · ZOOM_OUT · HATCH · POLYLINE-next-vertex · COPY · MOVE · ROTATE

**New single-step:** SCALE · MIRROR · OFFSET

**New multi-turn chat** *(prompt = `"Draw a {shape} | <instruction>"`)*:
- `ScaleFromChatTask`       – *"Draw a circle | make it smaller"*  → SCALE
- `MoveFromChatTask`        – *"Draw a circle | move it right"*    → MOVE + cursor at destination
- `RotateFromChatTask`      – *"Draw a circle | rotate it 90°"*    → ROTATE
- `EraseFromChatTask`       – *"Draw a circle | delete it"*        → ERASE
- `ChangeColorFromChatTask` – *"Draw a circle | change it to red"* → COLOR_SET
- `CopyFromChatTask`        – *"Draw a circle | copy it to the right"* → COPY + cursor at destination

## Checkpoint Convention

```
checkpoints/
  run_20260219_1400/          ← one directory per training run
    latest.pt                 ← overwritten every save_interval steps
    best.pt                   ← highest RL reward seen
    phase1_final.pt           ← snapshot after Phase 1
    phase2_final.pt           ← snapshot after Phase 2
    phase3_final.pt           ← snapshot after Phase 3
    diagnostics.json          ← PPO training log
  run_20260220_0900/
    ...
```

> **Tip**: Set `RESUME_FROM` in Cell 2 to warm-start from a previous run.

In [None]:
# Cell 1: Setup (run once per instance)
# pillow is required for diagnostic GIF generation (Cell 3d)
!pip install -q torch numpy matplotlib pillow

import sys, os
sys.path.insert(0, os.getcwd())

In [None]:
# Cell 2: Git pull + Run configuration
# ─────────────────────────────────────────────────────────────────────────────
# Run this at the start of every session.
# It pulls the latest code, names this training run, and sets up the
# checkpoint directory.  All subsequent cells read CKPT_DIR automatically.

import datetime, json, shutil, subprocess
from pathlib import Path

# ── Pull latest code ──────────────────────────────────────────────────────────
result = subprocess.run(["git", "pull", "origin", "main"],
                        capture_output=True, text=True)
print(result.stdout.strip() or result.stderr.strip())
if result.returncode != 0:
    # Best-effort: keep going with the code already on disk, but say so.
    print(f"Warning: git pull failed (exit code {result.returncode}); "
          "continuing with the local checkout.")

# ── Run name ──────────────────────────────────────────────────────────────────
# Override RUN_NAME with a memorable label, e.g. "multiturn_v1" or "scale_fix".
# Default: timestamp so every run is uniquely tracked.
RUN_NAME    = f"run_{datetime.datetime.now().strftime('%Y%m%d_%H%M')}"
CKPT_DIR    = f"checkpoints/{RUN_NAME}"

# ── Optional: warm-start from a previous run ──────────────────────────────────
# Set to a previous CKPT_DIR to copy its checkpoints in as the starting point.
# Leave as None to start fresh.
RESUME_FROM = None   # e.g. "checkpoints/run_20260219_1400"

# ── Create checkpoint directory ───────────────────────────────────────────────
Path(CKPT_DIR).mkdir(parents=True, exist_ok=True)

if RESUME_FROM and Path(RESUME_FROM).is_dir():
    # Copy every checkpoint plus the PPO log so this run continues from them.
    copied = []
    for pt in Path(RESUME_FROM).glob("*.pt"):
        shutil.copy(pt, Path(CKPT_DIR) / pt.name)
        copied.append(pt.name)
    diag = Path(RESUME_FROM) / "diagnostics.json"
    if diag.exists():
        shutil.copy(diag, Path(CKPT_DIR) / "diagnostics.json")
    print(f"Warm-started from  : {RESUME_FROM}")
    print(f"  Copied files     : {list(Path(CKPT_DIR).iterdir())}")
    if not copied:
        print("  Warning: no .pt checkpoints found there – nothing to resume from.")
elif RESUME_FROM:
    # A mistyped path must not look like a deliberate fresh start.
    print(f"Warning: RESUME_FROM={RESUME_FROM!r} is not an existing directory – "
          "starting a fresh run instead.")
else:
    print("Starting fresh run (no RESUME_FROM)")

print(f"Run name           : {RUN_NAME}")
print(f"Checkpoint dir     : {CKPT_DIR}")

In [None]:
# Cell 3: Helper functions (run once after Cell 2)
# ─────────────────────────────────────────────────────────────────────────────
# checkpoint_status()    – table of every .pt file across all runs
# save_phase(label)      – copy latest.pt → <label>.pt (permanent snapshot)
# plot_phase_history()   – loss/accuracy curves for a pretraining history dict

import json, shutil, time
from pathlib import Path
import matplotlib.pyplot as plt


def checkpoint_status(base_dir: str = "checkpoints") -> None:
    """List every ``.pt`` file found recursively under *base_dir*.

    Prints one row per checkpoint (local modification time, size in MB and
    path), ordered by path, or a short notice when none exists.
    """
    found = sorted(Path(base_dir).rglob("*.pt"))
    if not found:
        print(f"No checkpoints found under {base_dir}/")
        return

    print(f"{'Modified':<18}  {'Size':<9}  Path")
    print("-" * 72)
    for ckpt in found:
        info = ckpt.stat()
        stamp = time.strftime("%Y-%m-%d %H:%M", time.localtime(info.st_mtime))
        size = f"{info.st_size / 1e6:.1f} MB"
        print(f"{stamp:<18}  {size:<9}  {ckpt}")


def save_phase(label: str, ckpt_dir: "str | None" = None) -> None:
    """Copy latest.pt → <label>.pt for a permanent per-phase snapshot.

    Args:
        label: File stem of the snapshot, e.g. "phase1_final".
        ckpt_dir: Run directory that holds latest.pt.  When omitted, the
            notebook-level CKPT_DIR is looked up at call time, so re-running
            Cell 2 (which renames the run) takes effect without having to
            re-define this helper.

    Prints the snapshot path and size, or a warning when latest.pt is missing.
    """
    if ckpt_dir is None:
        # Resolved here rather than in the signature: a default argument is
        # evaluated once, when the function is defined, and would go stale.
        ckpt_dir = CKPT_DIR
    src = Path(ckpt_dir) / "latest.pt"
    if src.exists():
        dst = Path(ckpt_dir) / f"{label}.pt"
        shutil.copy(src, dst)
        sz = dst.stat().st_size / 1e6
        print(f"  Saved: {dst}  ({sz:.1f} MB)")
    else:
        print(f"  Warning: {src} not found – nothing saved.")


def plot_phase_history(history: dict, title: str = "") -> None:
    """Draw the per-epoch curves stored in a pretrain_* history dict.

    Always shows tool loss and tool accuracy (y-axis fixed to 0..1); a third
    cursor-loss panel is added only when ``history["cursor_losses"]`` is
    present and non-empty.  *title*, if given, becomes the figure title.
    """
    cursor_series = history.get("cursor_losses")
    panels = [
        ("Tool Loss", history.get("tool_losses", []), "steelblue"),
        ("Tool Accuracy", history.get("tool_accuracies", []), "green"),
    ]
    if cursor_series:
        panels.append(("Cursor Loss", cursor_series, "orange"))

    n_panels = len(panels)
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 3))
    for ax, (caption, series, colour) in zip(axes, panels):
        ax.plot(series, color=colour)
        ax.set_title(caption)
        ax.set_xlabel("Epoch")
    # Accuracy is a fraction, so pin its axis to the full 0..1 range.
    axes[1].set_ylim(0, 1)

    if title:
        fig.suptitle(title, fontsize=12)
    plt.tight_layout()
    plt.show()


print("Helpers loaded: checkpoint_status(), save_phase(), plot_phase_history()")

In [None]:
# Cell 4: Verify environment
# Prints the PyTorch / CUDA setup, the discovered RL tasks, the first ten
# tools, and the Phase-2 supervised task table with each entry's weight share.
import torch

print(f"PyTorch  : {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA     : {cuda_ok}")
if cuda_ok:
    print(f"GPU      : {torch.cuda.get_device_name(0)}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"Memory   : {gpu_props.total_memory / 1e9:.1f} GB")

from cadfire.tasks.registry import TaskRegistry
TaskRegistry.discover()
print(f"\nRL tasks ({TaskRegistry.count()}): {TaskRegistry.list_tasks()}")

from cadfire.utils.config import num_tools, tool_list
print(f"Tools ({num_tools()}): {tool_list()[:10]}...")

# Supervised task mix (Phase 2): each registry entry unpacks to
# (weight, task class, <third field unused here>).
from cadfire.training.pretrain_semantic import _TASK_REGISTRY
total_w = sum(weight for weight, _, _ in _TASK_REGISTRY)
print(f"\nPhase-2 supervised tasks: {len(_TASK_REGISTRY)}")
for weight, task_cls, _ in _TASK_REGISTRY:
    share = 100 * weight / total_w
    print(f"  {task_cls.__name__:<28} weight={weight:.1f} ({share:.1f}%)")

In [None]:
# Cell 5: Checkpoint browser
# ─────────────────────────────────────────────────────────────────────────────
# Lists every checkpoint on disk so it can be reviewed BEFORE training starts.
# When RESUME_FROM is set, also summarises that run's diagnostics.json.

checkpoint_status()

# Summary of the run being warm-started from (skipped when there is none)
if RESUME_FROM:
    diag_path = Path(RESUME_FROM) / "diagnostics.json"
    if diag_path.exists():
        prev = json.loads(diag_path.read_text())
        print(f"\nPrevious run diagnostics ({RESUME_FROM}):")
        print(f"  Total RL steps : {prev.get('total_steps', 0):,}")
        print(f"  Total episodes : {prev.get('total_episodes', 0):,}")
        print(f"  Best reward    : {prev.get('best_reward', float('-inf')):.4f}")
        prev_log = prev.get("training_log")
        if prev_log:
            print(f"  Last log entry : {prev_log[-1]}")

In [None]:
# Cell 3a: Phase 1 — Tool Classifier Pretraining
# ─────────────────────────────────────────────────────────────────────────────
# Trains: text encoder + fusion bridge + tool head
# Frozen: vision encoder, cursor head (UNet decoder)
# Loss:   cross-entropy over (prompt → tool_id) pairs
#
# Idempotent: loads any existing checkpoint first, trains on top of it.
# Raise num_epochs to 50 for ~95%+ accuracy.
#
# Outputs: `agent` (picked up by Cells 3b–3d via globals().get("agent")) and
# `history1`, whose "tool_accuracies" / "tool_losses" series Cell 3a-eval reads.

from train import run_pretrain_tool

agent, history1 = run_pretrain_tool(
    num_epochs=30,        # ~1 min on GPU; raise to 50 for higher accuracy
    lr=1e-3,
    batch_size=64,
    checkpoint_dir=CKPT_DIR,   # run directory created in Cell 2
)

In [None]:
# Cell 3a-eval: Post-Phase-1 evaluation + checkpoint snapshot
# Reports the last-epoch metrics, plots the curves, then freezes latest.pt as
# phase1_final.pt.
print(f"Phase 1 final tool accuracy : {history1['tool_accuracies'][-1]:.1%}")
print(f"Phase 1 final tool loss     : {history1['tool_losses'][-1]:.4f}")

plot_phase_history(history1, title="Phase 1 – Tool Classifier")
# Name the run directory explicitly (as the training cells do) so the snapshot
# always lands in the current run, whatever default save_phase() was built with.
save_phase("phase1_final", ckpt_dir=CKPT_DIR)

In [None]:
# Cell 3b: Phase 2 — Semantic Cursor Pretraining
# ─────────────────────────────────────────────────────────────────────────────
# Trains:  ALL parameters (vision encoder, text encoder, fusion, tool head,
#          cursor head).  Text encoder intentionally unfrozen so the model
#          links object names ("hexagon", "circle") to their visual appearance.
# Loss:    cross-entropy tool loss + focal-BCE cursor loss (Gaussian blobs)
#
# 20 supervised task types (original 11 + SCALE + MIRROR + OFFSET +
#   6 multi-turn chat tasks).
#
# Loads Phase-1 weights from checkpoint if `agent` is not in scope.
#
# Outputs: the updated `agent` and `history2`, whose "tool_accuracies" and
# "cursor_losses" series Cell 3b-eval reads.

from train import run_pretrain_semantic

agent, history2 = run_pretrain_semantic(
    agent=globals().get("agent"),   # reuse Phase-1 agent if it ran above (None otherwise)
    num_samples=20_000,   # generated samples per epoch; raise to 50_000 for better coverage
    num_epochs=20,
    lr=3e-4,
    batch_size=32,        # keep small — images are 256×256×19
    sigma=12.0,           # Gaussian blob radius in pixels
    cursor_weight=1.0,    # global scale on cursor-loss term
    num_workers=0,        # set to 4 on multi-core instances
    checkpoint_dir=CKPT_DIR,   # run directory created in Cell 2
)

In [None]:
# Cell 3b-eval: Post-Phase-2 evaluation + checkpoint snapshot
# Reports the last-epoch metrics, plots the curves, freezes latest.pt as
# phase2_final.pt, then compares tool accuracy against Phase 1 when available.
print(f"Phase 2 final tool accuracy : {history2['tool_accuracies'][-1]:.1%}")
print(f"Phase 2 final cursor loss   : {history2['cursor_losses'][-1]:.4f}")

plot_phase_history(history2, title="Phase 2 – Semantic Cursor")
# Name the run directory explicitly (as the training cells do) so the snapshot
# always lands in the current run, whatever default save_phase() was built with.
save_phase("phase2_final", ckpt_dir=CKPT_DIR)

# Quick sanity: compare Phase 1 vs Phase 2 tool accuracy.
# history1 only exists if Cell 3a ran in this kernel session; Cell 3b can also
# start from a checkpoint, so look it up defensively instead of assuming it.
phase1_history = globals().get("history1")
if phase1_history and phase1_history.get("tool_accuracies"):
    acc1 = phase1_history["tool_accuracies"][-1]
    acc2 = history2["tool_accuracies"][-1]
    delta = acc2 - acc1
    # "+" in the format spec prints an explicit sign for gains as well as drops.
    print(f"\nTool accuracy: Phase 1 → Phase 2  ({acc1:.1%} → {acc2:.1%}, {delta:+.1%})")
    if delta < -0.05:
        print("  ⚠ Accuracy dropped >5pp — consider more Phase 1 epochs or a lower Phase 2 LR.")
else:
    print("\nPhase 1 history not available in this session – skipping the Phase 1 → Phase 2 comparison.")

In [None]:
# Cell 3c: Phase 3 — Teacher-Forced Multi-Step Pretraining
# ─────────────────────────────────────────────────────────────────────────────
# Trains:  ALL parameters (same as Phase 2)
# Method:  at each trajectory step the ORACLE action is executed in the real
#          environment; agent loss computed per step — no error accumulation.
# Loss:    cross-entropy tool loss + focal-BCE cursor loss, summed over steps
#
# Trajectory mix:
#   70%  polygon-trace  (4–9 POLYLINE clicks + CONFIRM)
#   30%  2-step chains  (select→erase, select→rotate, select→copy)
#
# Loads Phase-2 weights from checkpoint if `agent` is not in scope.
#
# Outputs: the updated `agent` and `history3`; Cell 3c-eval reads its
# "tool_accuracies", "cursor_losses" and (when present) "traj_lengths" series.

from train import run_pretrain_teacher

agent, history3 = run_pretrain_teacher(
    agent=globals().get("agent"),    # reuse Phase-2 agent if it ran above (None otherwise)
    num_trajectories=5_000,  # trajectories generated per epoch
    num_epochs=15,
    lr=1e-4,
    sigma=12.0,
    cursor_weight=1.5,       # cursor loss weighted higher — sequencing is hard
    polygon_ratio=0.7,       # the 70% polygon-trace share listed above
    checkpoint_dir=CKPT_DIR,   # run directory created in Cell 2
)

In [None]:
# Cell 3c-eval: Post-Phase-3 evaluation + checkpoint snapshot
# Reports the last-epoch metrics, plots the curves, freezes latest.pt as
# phase3_final.pt, then prints the tool accuracy of every phase that ran.
print(f"Phase 3 final tool accuracy : {history3['tool_accuracies'][-1]:.1%}")
print(f"Phase 3 final cursor loss   : {history3['cursor_losses'][-1]:.4f}")
if "traj_lengths" in history3:
    print(f"Phase 3 avg traj length     : {history3['traj_lengths'][-1]:.2f} steps")

plot_phase_history(history3, title="Phase 3 – Teacher Forcing")
# Name the run directory explicitly (as the training cells do) so the snapshot
# always lands in the current run, whatever default save_phase() was built with.
save_phase("phase3_final", ckpt_dir=CKPT_DIR)

# Three-phase summary
# Earlier phases may have been skipped in this kernel session (their cells
# fall back to checkpoints), so look each history up by name instead of
# assuming the variable exists.
print("\n── Three-phase tool-accuracy summary ──")
for label, var_name in [("Phase 1", "history1"), ("Phase 2", "history2"), ("Phase 3", "history3")]:
    hist = globals().get(var_name)
    if hist and hist.get("tool_accuracies"):
        acc = hist["tool_accuracies"][-1]
        print(f"  {label}: {acc:.1%}")
    else:
        print(f"  {label}: n/a (not run in this session)")

In [None]:
# Cell 3d: Post-Phase-3 Diagnostics — Polygon Tracing GIFs
# ─────────────────────────────────────────────────────────────────────────────
# Writes two GIF types per episode into {CKPT_DIR}/diagnostics/:
#   oracle_ep<N>.gif  – oracle-driven rollout with agent heatmap overlay
#   free_ep<N>.gif    – fully autonomous agent rollout (no teacher forcing)
#
# How to read the free_ep GIFs:
#   ✓ Vertices traced in order, finished with CONFIRM → ready for RL
#   ✗ Agent wanders / repeats a tool → run more Phase 3 epochs before continuing
# Requires: pillow (installed in Cell 1)

from train import run_diagnostics

gif_dir = f"{CKPT_DIR}/diagnostics"

diag_metrics = run_diagnostics(
    agent=globals().get("agent"),   # None when no agent is in scope
    n_episodes=6,
    output_dir=gif_dir,
    fps=1.5,
    checkpoint_dir=CKPT_DIR,
)

# List what was produced, in name order
print("\nGenerated GIFs:")
for gif in sorted(Path(gif_dir).glob("*.gif")):
    print(f"  {gif}")

# Reward summary; missing keys are reported as 0
if diag_metrics:
    print(f"\nFree-run mean reward  : {diag_metrics.get('mean_free_reward', 0):.3f}")
    print(f"Oracle mean reward    : {diag_metrics.get('mean_oracle_reward', 0):.3f}")

In [None]:
# Cell 3e: Phase 4 — PPO RL Training
# ─────────────────────────────────────────────────────────────────────────────
# Full agent training with curriculum learning.
# Automatically loads from the Phase-3 checkpoint; all prior supervised
# learning is preserved.
#
# Curriculum: difficulty cap starts at 2.0, increases every 5 000 steps.
# Checkpoints saved to {CKPT_DIR}/latest.pt and {CKPT_DIR}/best.pt.
#
# Note: re-running this cell resets metrics_history to an empty list.

from train import run_training

metrics_history = []
def collect_metrics(m):
    """Append one metrics record to metrics_history.

    Handed to run_training as ``callback``; presumably invoked once per
    logging interval — TODO confirm against train.run_training.
    """
    metrics_history.append(m)

trainer = run_training(
    num_steps=100_000,
    resume=True,           # auto-resume from {CKPT_DIR}/latest.pt
    device=None,           # auto-detect GPU
    max_difficulty=None,   # curriculum default
    checkpoint_dir=CKPT_DIR,
    callback=collect_metrics,
)

In [None]:
# Cell 6: Plot PPO training progress
# Reads {CKPT_DIR}/diagnostics.json, draws a 2×3 grid of training curves,
# saves it as training_curves.png next to the checkpoints, and prints totals.
import matplotlib.pyplot as plt
import json
from pathlib import Path

diag_path = Path(CKPT_DIR) / "diagnostics.json"
if diag_path.exists():
    with open(diag_path) as f:
        diag = json.load(f)
    # Tolerate missing keys the same way Cells 5 and 7 do, so a partially
    # written log does not abort the cell with a KeyError.
    log = diag.get("training_log", [])

    if log:
        # (log key, panel title), laid out row by row across the 2×3 grid.
        panels = [
            ("avg_reward",       "Average Reward"),
            ("policy_loss",      "Policy Loss"),
            ("difficulty",       "Curriculum Difficulty"),
            ("value_loss",       "Value Loss"),
            ("entropy",          "Entropy"),
            ("steps_per_second", "Steps/sec"),
        ]
        fig, axes = plt.subplots(2, 3, figsize=(18, 8))
        steps = [e["step"] for e in log]
        for ax, (key, panel_title) in zip(axes.flat, panels):
            # Entries lacking a metric are plotted as 0, as before.
            ax.plot(steps, [e.get(key, 0) for e in log])
            ax.set_title(panel_title)
            ax.set_xlabel("Step")

        plt.suptitle(f"PPO — {RUN_NAME}", fontsize=13)
        plt.tight_layout()
        plt.savefig(Path(CKPT_DIR) / "training_curves.png", dpi=100)
        plt.show()
    else:
        print(f"{diag_path} has no training_log entries yet – nothing to plot.")

    print(f"Total steps    : {diag.get('total_steps', 0):,}")
    print(f"Total episodes : {diag.get('total_episodes', 0):,}")
    print(f"Best reward    : {diag.get('best_reward', float('-inf')):.4f}")
else:
    print(f"No diagnostics.json found in {CKPT_DIR}. Run Phase 4 (Cell 3e) first.")

In [None]:
# Cell 7: Multi-run comparison
# ─────────────────────────────────────────────────────────────────────────────
# Compare reward curves across multiple training runs on one plot.
# Edit RUN_DIRS to select the runs you want to compare.

import matplotlib.pyplot as plt
import json
from pathlib import Path

# ── Edit this list ────────────────────────────────────────────────────────────
# Every sub-directory of checkpoints/ counts as a run, so runs given a custom
# RUN_NAME in Cell 2 (e.g. "multiturn_v1") are included alongside "run_*".
RUN_DIRS = sorted(p for p in Path("checkpoints").glob("*") if p.is_dir())  # compare all runs
# RUN_DIRS = ["checkpoints/run_A", "checkpoints/run_B"]  # or pick specific ones
# ─────────────────────────────────────────────────────────────────────────────

# Gather the curves first; the figure is only created once there is something
# to draw, so a notebook without PPO data does not render an empty plot.
curves = []
for run_dir in RUN_DIRS:
    dp = Path(run_dir) / "diagnostics.json"
    if not dp.exists():
        continue
    with open(dp) as f:
        d = json.load(f)
    log = d.get("training_log", [])
    if not log:
        continue
    curves.append((
        Path(run_dir).name,                      # legend label
        [e["step"]              for e in log],
        [e.get("avg_reward", 0) for e in log],
        [e.get("entropy",    0) for e in log],
    ))

if curves:
    fig, axes = plt.subplots(1, 2, figsize=(14, 4))
    for label, steps, rewards, entropy in curves:
        axes[0].plot(steps, rewards, label=label)
        axes[1].plot(steps, entropy, label=label)
    axes[0].set_title("Average Reward");  axes[0].set_xlabel("Step"); axes[0].legend(fontsize=8)
    axes[1].set_title("Policy Entropy");  axes[1].set_xlabel("Step"); axes[1].legend(fontsize=8)
    plt.tight_layout()
    plt.show()
else:
    print("No PPO diagnostics found yet. Run Phase 4 first.")

In [None]:
# Cell 8: Visualize agent behavior on a specific task
# Loads a checkpoint from CKPT_DIR, rolls the agent out on one task for at
# most 50 steps, and shows the first, middle and last rendered frames.
# Relies on `Path` and `CKPT_DIR` from Cell 2.
import numpy as np
import matplotlib.pyplot as plt
import torch
from cadfire.env.cad_env import CADEnv
from cadfire.model.cad_agent import CADAgent
from cadfire.renderer.rasterizer import Renderer
from cadfire.tasks.registry import TaskRegistry
from cadfire.utils.config import load_config

config   = load_config()
env      = CADEnv(config)
renderer = Renderer(config)

# Load the first checkpoint that exists, in order of preference:
# best → latest → phase3_final
for tag in ("best", "latest", "phase3_final"):
    ckpt_path = Path(CKPT_DIR) / f"{tag}.pt"
    if ckpt_path.exists():
        agent_vis = CADAgent.load_checkpoint(str(ckpt_path), config)
        print(f"Loaded: {ckpt_path}")
        break
else:
    # for/else: reached only when the loop never hit `break`, i.e. none of the
    # three files exists.
    raise FileNotFoundError(f"No checkpoint found in {CKPT_DIR}")
agent_vis.eval()

# ── Change task_name to explore other behaviours ──────────────────────────────
TASK_NAME = "draw_circle"  # e.g. "draw_rectangle", "select_shape", "move_shape"

task = TaskRegistry.create(TASK_NAME, seed=42)
obs, info = env.reset(task=task)
print(f"Prompt: {info['prompt']}")

# One rendered frame before any action, then one more after every step.
frames = [renderer.render_rgb_only(env.engine)]
total_reward = 0

for step in range(50):
    # unsqueeze(0) adds a leading dimension of size 1 to each observation
    # field (presumably the batch axis the agent expects — confirm).
    obs_t = {
        "image":     torch.tensor(obs["image"]).unsqueeze(0),
        "text_ids":  torch.tensor(obs["text_ids"], dtype=torch.long).unsqueeze(0),
        "state_vec": torch.tensor(obs["state_vec"]).unsqueeze(0),
    }
    action_info = agent_vis.act(obs_t, deterministic=True)
    # Convert the agent's tensors to plain Python / NumPy values for env.step;
    # [0] drops the leading dimension added above.
    action = {
        "tool_id": action_info["tool_id"].item(),
        "cursor":  action_info["cursor"].cpu().numpy()[0],
        "param":   action_info["param"].item(),
    }
    obs, reward, term, trunc, info = env.step(action)
    total_reward += reward
    frames.append(renderer.render_rgb_only(env.engine))
    if term or trunc:
        break

# `step` is the zero-based index of the last executed step.
print(f"Total reward: {total_reward:.3f}, Steps: {step + 1}")

# Show the initial, midpoint and final frames side by side.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, idx, title in zip(axes, [0, len(frames) // 2, -1], ["Start", "Middle", "End"]):
    ax.imshow(frames[idx])
    ax.set_title(title)
    ax.axis("off")
plt.tight_layout()
plt.show()

In [None]:
# Cell 9: Export drawing to DXF
# Writes the drawing currently held by `env` (the environment created and
# stepped in Cell 8) to {CKPT_DIR}/output.dxf, so run Cell 8 first.
from cadfire.export.dxf_writer import DXFWriter

out_path = Path(CKPT_DIR) / "output.dxf"
writer = DXFWriter()
writer.write(env.engine, str(out_path))
print(f"Exported to {out_path}")
print(f"Entities: {env.engine.entity_count()}")

In [None]:
# Cell 10: Pull updates + continue training
# ─────────────────────────────────────────────────────────────────────────────
# After new tasks, config changes, or bug fixes are pushed:
# 1. Pull latest code
# 2. (Optional) create a git tag at the current code+checkpoint state
# 3. Reload config + re-discover tasks
# 4. Resume Phase 4 — handles tool-list growth seamlessly
#
# NOTE(review): Python caches imported modules, so `from train import ...`
# below re-uses the `train` module already loaded by earlier cells.  Pulled
# edits to already-imported .py files likely need a kernel restart (or an
# importlib.reload) to take effect — confirm what reload()/discover() cover.

import subprocess
from pathlib import Path

# Pull
r = subprocess.run(["git", "pull", "origin", "main"], capture_output=True, text=True)
# Show git's stdout, or its stderr when stdout is empty.
print(r.stdout.strip() or r.stderr.strip())

# (Optional) tag this checkpoint state in git for easy rollback
# Uncomment the lines below to create a lightweight git tag:
# tag_name = f"ckpt/{RUN_NAME}/step{trainer.total_steps}"
# subprocess.run(["git", "tag", tag_name], check=True)
# print(f"Git tag created: {tag_name}")

# Reload config + re-discover tasks (picks up any new task files)
from cadfire.utils.config import reload
reload()
from cadfire.tasks.registry import TaskRegistry
TaskRegistry.discover()
print(f"\nTasks after update ({TaskRegistry.count()}): {TaskRegistry.list_tasks()}")

# Continue Phase 4 — auto-resumes, extends tool head if new tools were added
# No callback is passed here, so metrics_history from Cell 3e is not extended.
from train import run_training
trainer = run_training(num_steps=50_000, resume=True, checkpoint_dir=CKPT_DIR)