# CADFire Training Notebook

Training notebook for vast.ai GPU instances.

## Four-Phase Pipeline

| Phase | Cell | Description | Typical time |
|-------|------|-------------|-------------|
| **1** | 3a | **Tool Classifier** — text → tool, no vision. Text encoder + fusion + tool head trained from scratch via cross-entropy. | ~1 min |
| **2** | 3b | **Semantic Cursor** — vision + text → tool + cursor, single-step, 11 task types. *All* parameters unfrozen (incl. text encoder) so the model links object names to visual appearances. | ~10–30 min |
| **3** | 3c | **Teacher Forcing** — vision + text → tool + cursor, 2–9-step trajectories. Oracle advances the env at each step; loss computed per step. Bridges single-step Phase 2 and sparse-reward Phase 4. | ~20–60 min |
| — | 3d | **Diagnostics** — oracle + free-run GIFs to evaluate polygon-tracing capability before RL. | ~1 min |
| **4** | 3e | **PPO RL** — full agent, curriculum learning. Resumes from Phase-3 checkpoint. | hours |

> **Tip**: Phases 1–3 only need to run once per fresh instance.  
> After a pretrained checkpoint is saved, subsequent RL runs resume from it automatically.
>
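> Weights accumulate across phases: Cells 3b–3d pass `agent=globals().get("agent")`, which re-uses the in-memory weights when an earlier phase has already run in this kernel;  
> if no `agent` exists yet (for example after a kernel restart), the weights are loaded from the checkpoint file on disk instead.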

In [None]:
# Cell 1: Setup (run once per instance)
# pillow is required for diagnostic GIF generation (Phase 3d)
!pip install -q torch numpy matplotlib pillow

import sys, os
sys.path.insert(0, os.getcwd())

In [None]:
# Cell 2: Verify environment
#
# Prints the PyTorch version and CUDA status, then the registered tasks and
# the start of the tool list, as a quick sanity check before training.
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only device index 0 is inspected; total_memory is divided by 1e9 so it
    # is printed in GB.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

from cadfire.tasks.registry import TaskRegistry
# discover() is called before the registry is queried — presumably it
# populates the registry; confirm in cadfire.tasks.registry.
TaskRegistry.discover()
print(f"\nRegistered tasks ({TaskRegistry.count()}): {TaskRegistry.list_tasks()}")

from cadfire.utils.config import num_tools, tool_list
# Only the first 10 tool names are shown to keep the output short.
print(f"Tools ({num_tools()}): {tool_list()[:10]}...")

In [None]:
# Cell 3a: Phase 1 — Tool Classifier Pretraining
#
# What learns   : text encoder, fusion bridge, tool head
# What is frozen: vision encoder, cursor head (UNet decoder)
# Objective     : cross-entropy over (prompt → tool_id) pairs
#
# Safe to re-run: any existing checkpoint is loaded first and training
# continues on top of it. For ~95%+ accuracy raise num_epochs to 50.

from train import run_pretrain_tool

# Phase-1 hyperparameters, gathered in one place and splatted into the call.
phase1_settings = {
    "num_epochs": 30,                   # ~1 min on GPU; raise to 50 for higher accuracy
    "lr": 1e-3,
    "batch_size": 64,
    "checkpoint_dir": "checkpoints_1",
}

agent = run_pretrain_tool(**phase1_settings)

In [None]:
# Cell 3b: Phase 2 — Semantic Cursor Pretraining
#
# What learns   : every parameter (vision encoder, text encoder, fusion,
#                 tool head, cursor head)
# What is frozen: nothing — the text encoder stays trainable on purpose so the
#                 model ties object names ("hexagon", "circle") to how those
#                 objects look.
# Objective     : cross-entropy tool loss + MSE cursor loss (Gaussian blob targets)
#
# The 11 one-step supervised task types:
#   SELECT, MULTISELECT, ERASE, PAN (4 dirs), ZOOM_IN, ZOOM_OUT,
#   HATCH, POLYLINE trace-next, COPY, MOVE, ROTATE
#
# When `agent` is not already in scope, Phase-1 weights are loaded from the
# checkpoint instead.

from train import run_pretrain_semantic

# Phase-2 hyperparameters, gathered in one place and splatted into the call.
phase2_settings = {
    "agent": globals().get("agent"),    # reuse Phase-1 agent if it ran above
    "num_samples": 20_000,              # generated samples per epoch; raise to 50_000 for better coverage
    "num_epochs": 20,
    "lr": 3e-4,
    "batch_size": 32,                   # keep small — images are 256×256×19
    "sigma": 12.0,                      # Gaussian blob radius in pixels
    "cursor_weight": 1.0,               # relative weight of cursor loss vs tool loss
    "num_workers": 0,                   # set to 4 on multi-core instances for faster dataloading
    "checkpoint_dir": "checkpoints_1",
}

agent = run_pretrain_semantic(**phase2_settings)

In [None]:
# Cell 3c: Phase 3 — Teacher-Forced Multi-Step Pretraining
#
# What learns: every parameter (same as Phase 2)
# Method     : teacher forcing — at each trajectory step the ORACLE action is
#              executed in the real environment, and the agent loss is
#              computed per step.
# Objective  : cross-entropy tool loss + MSE cursor loss, summed over
#              trajectory steps
#
# Default trajectory mix:
#   70% polygon-trace  (4–9 POLYLINE steps + CONFIRM)
#   30% 2-step chains  (select→erase, select→rotate, select→copy)
#
# When `agent` is not already in scope, Phase-2 weights are loaded from the
# checkpoint instead.

from train import run_pretrain_teacher

# Phase-3 hyperparameters, gathered in one place and splatted into the call.
phase3_settings = {
    "agent": globals().get("agent"),    # reuse Phase-2 agent if it ran above
    "num_trajectories": 5_000,          # trajectories generated per epoch
    "num_epochs": 15,
    "lr": 1e-4,
    "sigma": 12.0,                      # Gaussian blob radius in pixels
    "cursor_weight": 1.5,               # cursor loss weighted higher — sequencing is hard
    "polygon_ratio": 0.7,               # fraction of polygon-trace trajectories
    "checkpoint_dir": "checkpoints_1",
}

agent = run_pretrain_teacher(**phase3_settings)

In [None]:
# Cell 3d: Post-Phase-3 Diagnostics — Polygon Tracing GIFs
#
# Two GIF types are written per episode into diagnostics/:
#   oracle_ep<N>.gif  — oracle-driven rollout with agent heatmap overlay
#   free_ep<N>.gif    — fully autonomous agent rollout (no teacher forcing)
#
# Look at the free_ep GIFs first: an agent that traces polygons end-to-end
# without oracle guidance is ready for Phase 4 RL.
# Requires: pillow  (installed in Cell 1)

from train import run_diagnostics
from pathlib import Path

# Named once so the run and the listing below always point at the same folder.
gif_dir = "diagnostics"

diagnostics_settings = {
    "agent": globals().get("agent"),    # reuse Phase-3 agent if it ran above
    "n_episodes": 6,
    "output_dir": gif_dir,
    "fps": 1.5,
    "checkpoint_dir": "checkpoints_1",
}
metrics = run_diagnostics(**diagnostics_settings)

# List every GIF now present in the output folder, in sorted order.
print("\nGenerated GIFs:")
gif_files = sorted(Path(gif_dir).glob("*.gif"))
for p in gif_files:
    print(f"  {p}")

# The summary is printed only when a non-empty metrics mapping came back;
# missing entries fall back to 0.
if metrics:
    print(f"\nFree-run mean reward  : {metrics.get('mean_free_reward', 0):.3f}")
    print(f"Oracle mean reward    : {metrics.get('mean_oracle_reward', 0):.3f}")

In [None]:
# Cell 3e: Phase 4 — PPO RL Training
#
# Trains the full agent with curriculum learning. The Phase-3 checkpoint is
# loaded automatically, so everything learned during supervised pretraining
# carries over.
#
# Curriculum: difficulty cap starts at 2.0, grows every 5 000 steps.
# Checkpoints saved to checkpoints_1/latest.pt and checkpoints_1/best.pt.

from train import run_training

# Every payload the trainer passes to the callback is kept here, in order.
metrics_history = []


def collect_metrics(m):
    """Append the payload `m` received from the training callback to metrics_history."""
    metrics_history.append(m)


# Phase-4 settings, gathered in one place and splatted into the call.
ppo_settings = {
    "num_steps": 100_000,               # total PPO environment steps
    "resume": True,                     # auto-resume from checkpoints_1/latest.pt
    "device": None,                     # auto-detect GPU
    "max_difficulty": None,             # curriculum default: starts at 2.0, grows every 5 000 steps
    "checkpoint_dir": "checkpoints_1",
    "callback": collect_metrics,
}

trainer = run_training(**ppo_settings)

In [None]:
# Cell 4: Plot training progress
#
# Reads checkpoints_1/diagnostics.json, draws one panel per logged metric in a
# 2×3 grid, saves the figure next to the JSON file as training_curves.png and
# prints the run totals. Prints a hint instead when the file does not exist.
import json
from pathlib import Path

import matplotlib.pyplot as plt

diag_path = Path("checkpoints_1/diagnostics.json")
if diag_path.exists():
    with open(diag_path, encoding="utf-8") as f:
        diag = json.load(f)
    # A missing "training_log" entry is treated like an empty log.
    log = diag.get("training_log", [])

    if log:
        steps = [e["step"] for e in log]

        # (log key, panel title) in row-major order of the 2×3 grid.
        panels = [
            ("avg_reward", "Average Reward"),
            ("policy_loss", "Policy Loss"),
            ("difficulty", "Curriculum Difficulty"),
            ("value_loss", "Value Loss"),
            ("entropy", "Entropy"),
            ("steps_per_second", "Steps/sec"),
        ]

        fig, axes = plt.subplots(2, 3, figsize=(18, 8))
        for ax, (key, title) in zip(axes.flat, panels):
            # Entries that lack a metric are plotted as 0.
            ax.plot(steps, [e.get(key, 0) for e in log])
            ax.set_title(title)
            ax.set_xlabel("Step")

        plt.tight_layout()
        # Derive the output path from diag_path so the two cannot drift apart.
        plt.savefig(diag_path.parent / "training_curves.png", dpi=100)
        plt.show()
    else:
        # Skip plotting rather than saving a blank figure.
        print("Training log is empty — nothing to plot yet.")

    print(f"Total steps    : {diag['total_steps']}")
    print(f"Total episodes : {diag['total_episodes']}")
    print(f"Best reward    : {diag['best_reward']:.4f}")
else:
    print("No diagnostics found yet. Run Phase 4 (Cell 3e) first.")

In [None]:
# Cell 5: Visualize agent behavior on a specific task
#
# Loads checkpoints_1/best.pt, rolls the agent out deterministically on one
# task for at most max_steps steps, records an RGB frame after every step and
# shows the first, middle and last frame side by side.
import numpy as np
import matplotlib.pyplot as plt
import torch
from cadfire.env.cad_env import CADEnv
from cadfire.model.cad_agent import CADAgent
from cadfire.renderer.rasterizer import Renderer
from cadfire.tasks.registry import TaskRegistry
from cadfire.utils.config import load_config

config = load_config()
env = CADEnv(config)
agent_vis = CADAgent.load_checkpoint("checkpoints_1/best.pt", config)
agent_vis.eval()
renderer = Renderer(config)

# Pick a task — change task_name to explore other behaviours
task_name = "draw_circle"
task = TaskRegistry.create(task_name, seed=42)
obs, info = env.reset(task=task)
print(f"Prompt: {info['prompt']}")

# The first frame is the scene before the agent has acted.
frames = [renderer.render_rgb_only(env.engine)]
total_reward = 0
max_steps = 50

# Inference only: disabling autograd avoids building a graph on every step and
# guarantees the returned tensors can be converted with .numpy()/.item().
# NOTE(review): the observation tensors are created on the CPU; if
# load_checkpoint places the agent on a GPU they may need moving to that
# device first — confirm.
with torch.no_grad():
    for step in range(max_steps):
        # Add a leading batch dimension of 1 to every observation component.
        obs_t = {
            "image":    torch.tensor(obs["image"]).unsqueeze(0),
            "text_ids": torch.tensor(obs["text_ids"], dtype=torch.long).unsqueeze(0),
            "state_vec":torch.tensor(obs["state_vec"]).unsqueeze(0),
        }
        action_info = agent_vis.act(obs_t, deterministic=True)
        # Strip the batch dimension again and convert to plain Python/NumPy
        # values before handing the action to the environment.
        action = {
            "tool_id": action_info["tool_id"].item(),
            "cursor":  action_info["cursor"].cpu().numpy()[0],
            "param":   action_info["param"].item(),
        }
        obs, reward, term, trunc, info = env.step(action)
        total_reward += reward
        frames.append(renderer.render_rgb_only(env.engine))
        if term or trunc:
            break

print(f"Total reward: {total_reward:.3f}, Steps: {step + 1}")

# Show first, middle, and last frame
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, idx, title in zip(axes, [0, len(frames) // 2, -1], ["Start", "Middle", "End"]):
    ax.imshow(frames[idx])
    ax.set_title(title)
    ax.axis("off")
plt.tight_layout()
plt.show()

In [None]:
# Cell 6: Export a drawing to DXF
#
# Depends on `env` created in Cell 5: the engine state left behind by that
# rollout is what gets written, so run Cell 5 first.
from cadfire.export.dxf_writer import DXFWriter

writer = DXFWriter()
# The output path is relative, so the file lands in the current working directory.
writer.write(env.engine, "output.dxf")
print("Exported to output.dxf")
print(f"Entities: {env.engine.entity_count()}")

In [None]:
# Cell 7: Pull updates and continue training
# After new tasks or config changes are pushed upstream:
!git pull origin main

# NOTE(review): modules this kernel has already imported (e.g. `train`) stay
# cached in sys.modules, so pulled changes to their code presumably only take
# effect after a kernel restart or an explicit importlib.reload — confirm
# whether reload() below covers this.
# Reload config and re-discover tasks (handles new task files automatically)
from cadfire.utils.config import reload
reload()
from cadfire.tasks.registry import TaskRegistry
TaskRegistry.discover()
print(f"Tasks after update: {TaskRegistry.list_tasks()}")

# Continue Phase 4 — auto-resumes, handles tool-list growth seamlessly
# No callback is passed here, so this run does not append to the
# metrics_history list from Cell 3e.
from train import run_training
trainer = run_training(num_steps=50_000, resume=True, checkpoint_dir="checkpoints_1")