# CADFire Training Notebook

Training notebook for vast.ai GPU instances.

## Workflow
1. Clone repo, install deps
2. **Phase 1**: Tool classifier pretraining (~1 min) — teaches text→tool mapping
3. **Phase 2**: Cursor imitation pretraining (~10-20 min) — teaches vision→cursor placement
4. **Phase 3**: RL fine-tuning — PPO from warm-started checkpoint
5. Monitor diagnostics
6. Pull updates with `git pull` — training resumes from the latest checkpoint automatically

> **Note**: Phases 1 and 2 only need to run once. Once `pretrained_full.pt` has been saved,
> the first RL run is seeded from it; later RL runs resume from the latest RL checkpoint.

In [None]:
# Cell 1: Setup (run once per instance)
!pip install -q torch numpy matplotlib

import sys, os
sys.path.insert(0, os.getcwd())

In [None]:
# Cell 2: Verify environment
# Prints the PyTorch version and CUDA status (plus GPU name and memory when a
# GPU is visible), then checks that the task registry and the tool config can
# be imported and queried.
import torch

cuda_ok = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    # Device 0 is the only GPU inspected here.
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"Memory: {gpu_bytes / 1e9:.1f} GB")

from cadfire.tasks.registry import TaskRegistry
TaskRegistry.discover()
task_total = TaskRegistry.count()
task_names = TaskRegistry.list_tasks()
print(f"\nRegistered tasks ({task_total}): {task_names}")

from cadfire.utils.config import num_tools, tool_list
tool_total = num_tools()
tool_preview = tool_list()[:10]  # only the first ten, to keep the output short
print(f"Tools ({tool_total}): {tool_preview}...")

In [None]:
# Cell 3a: Phase 1 — Tool Classifier Pretraining
# Runs pretrain_tool_classifier on a fresh CADAgent and stores the result at
# TOOL_CKPT. Per the project notes, this stage teaches the text encoder,
# fusion and tool head to map prompts → tools.
# Run-once stage: when TOOL_CKPT is already on disk the cell only prints a
# notice and does no training.
import os
from pathlib import Path

TOOL_CKPT = "checkpoints_1/pretrained_tools.pt"

if not Path(TOOL_CKPT).exists():
    import torch
    from cadfire.model.cad_agent import CADAgent
    from cadfire.training.pretrain_tools import pretrain_tool_classifier
    from cadfire.utils.config import load_config

    config = load_config()
    agent = CADAgent(config)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Phase 1: Tool classifier pretraining on {device}")
    phase1_hparams = dict(
        num_epochs=50,
        lr=1e-3,
        batch_size=64,
        device=device,
    )
    history = pretrain_tool_classifier(agent, config, **phase1_hparams)

    # The last entry of the accuracy history is treated as the final score.
    final_acc = history['accuracies'][-1]
    print(f"\nFinal tool accuracy: {final_acc:.1%}")
    # Stop before saving if the classifier did not reach 80% accuracy.
    assert final_acc > 0.80, "Tool accuracy too low — check tokenizer"

    os.makedirs("checkpoints_1", exist_ok=True)
    checkpoint_meta = {
        "phase": "tool_pretrain",
        "tool_accuracy": final_acc,
    }
    agent.save_checkpoint(TOOL_CKPT, extra_meta=checkpoint_meta)
    print(f"Saved {TOOL_CKPT}")
else:
    print(f"Found {TOOL_CKPT} — skipping Phase 1 (delete to re-run)")

In [None]:
# Cell 3b: Phase 2 — Cursor Imitation Pretraining
# Runs pretrain_cursor_imitation on the agent and stores the result at
# FULL_CKPT. Per the project notes, this stage teaches the vision encoder +
# cursor head to click the right pixel.
# Only needs to run ONCE. Skipped if pretrained_full.pt already exists.
from pathlib import Path

TOOL_CKPT = "checkpoints_1/pretrained_tools.pt"
FULL_CKPT = "checkpoints_1/pretrained_full.pt"

if Path(FULL_CKPT).exists():
    print(f"Found {FULL_CKPT} — skipping Phase 2 (delete to re-run)")
else:
    import torch
    from cadfire.model.cad_agent import CADAgent
    from cadfire.training.pretrain_cursor import pretrain_cursor_imitation
    from cadfire.utils.config import load_config

    config = load_config()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create the checkpoint directory BEFORE training. When Phase 1 was
    # skipped (the "starting from scratch" branch below) nothing else in this
    # cell creates it, and a missing directory would otherwise only surface
    # at save time, after the whole training run has finished.
    Path(FULL_CKPT).parent.mkdir(parents=True, exist_ok=True)

    # Load Phase 1 weights if available, otherwise start fresh
    if Path(TOOL_CKPT).exists():
        print(f"Loading Phase 1 weights from {TOOL_CKPT}")
        agent = CADAgent.load_checkpoint(TOOL_CKPT, config)
    else:
        print("No Phase 1 checkpoint found — starting from scratch")
        agent = CADAgent(config)

    print(f"Phase 2: Cursor imitation pretraining on {device}")
    history = pretrain_cursor_imitation(
        agent, config,
        num_samples=20000,   # Increase to 50000 for better coverage
        num_epochs=20,
        lr=3e-4,
        batch_size=32,       # Keep small — images are 256x256x19
        device=device,
    )

    # The last entry of each history list is reported as the final value.
    final_tool_acc = history['tool_accuracies'][-1]
    final_cursor_loss = history['cursor_losses'][-1]
    print(f"\nFinal tool accuracy:  {final_tool_acc:.1%}")
    print(f"Final cursor loss:    {final_cursor_loss:.4f}")

    agent.save_checkpoint(FULL_CKPT, extra_meta={
        "phase": "cursor_pretrain",
        "cursor_tool_accuracy": final_tool_acc,
        "cursor_final_loss":    final_cursor_loss,
    })
    print(f"Saved {FULL_CKPT} — use this as the RL starting point")

In [None]:
# Cell 3c: Phase 3 — RL Training
# Starting point: an existing latest.pt in RL_CKPT_DIR is left untouched.
# Only when latest.pt is absent and pretrained_full.pt is present is the
# pretrained checkpoint copied to latest.pt before training starts.
import os
from pathlib import Path
from train import run_training

FULL_CKPT = "checkpoints_1/pretrained_full.pt"
RL_CKPT_DIR = "checkpoints_1"

# Seed the RL run from the pretrained weights when no RL checkpoint exists yet.
# NOTE(review): this relies on run_training(resume=True) loading latest.pt
# from this directory — confirm against train.py.
rl_latest = Path(RL_CKPT_DIR) / "latest.pt"
needs_seed = not rl_latest.exists()
if needs_seed and Path(FULL_CKPT).exists():
    import shutil
    shutil.copy(FULL_CKPT, rl_latest)
    print(f"Seeded RL from pretrained checkpoint: {FULL_CKPT}")

# Every metrics object handed to the callback is kept here for later plotting.
metrics_history = []


def collect_metrics(m):
    """Append one metrics record from the training callback to metrics_history."""
    metrics_history.append(m)


rl_settings = dict(
    num_steps=100000,     # Adjust as needed
    resume=True,          # Auto-resume from checkpoint
    device=None,          # Auto-detect GPU
    max_difficulty=None,  # Use curriculum default (starts at 2.0)
    callback=collect_metrics,
)
trainer = run_training(**rl_settings)

In [None]:
# Cell 4: Plot training progress
# Reads checkpoints_1/diagnostics.json, draws six metric curves against the
# training step in a 2x3 grid, saves the figure as a PNG and prints the
# summary counters stored in the same file.
import matplotlib.pyplot as plt
import json
from pathlib import Path

# Load diagnostics
diag_path = Path("checkpoints_1/diagnostics.json")
if not diag_path.exists():
    print("No diagnostics found yet. Run training first.")
else:
    with open(diag_path) as diag_file:
        diag = json.load(diag_file)
    log = diag["training_log"]

    fig, axes = plt.subplots(2, 3, figsize=(18, 8))

    steps = [entry["step"] for entry in log]

    # (log key, panel title), listed in row-major order of the 2x3 grid.
    panels = [
        ("avg_reward", "Average Reward"),
        ("policy_loss", "Policy Loss"),
        ("difficulty", "Curriculum Difficulty"),
        ("value_loss", "Value Loss"),
        ("entropy", "Entropy"),
        ("steps_per_second", "Steps/sec"),
    ]
    for panel_ax, (metric_key, panel_title) in zip(axes.flat, panels):
        # Entries that lack the metric are plotted as 0.
        series = [entry.get(metric_key, 0) for entry in log]
        panel_ax.plot(steps, series)
        panel_ax.set_title(panel_title)
        panel_ax.set_xlabel("Step")

    plt.tight_layout()
    plt.savefig("checkpoints_1/training_curves.png", dpi=100)
    plt.show()

    print(f"Total steps: {diag['total_steps']}")
    print(f"Total episodes: {diag['total_episodes']}")
    print(f"Best reward: {diag['best_reward']:.4f}")

In [None]:
# Cell 5: Visualize agent behavior on a specific task
# Loads checkpoints_1/best.pt, rolls the agent out deterministically on one
# task for at most 50 steps, records a rendered frame after every step, and
# shows the first, middle and last frame side by side.
import numpy as np
import matplotlib.pyplot as plt
from cadfire.env.cad_env import CADEnv
from cadfire.model.cad_agent import CADAgent
from cadfire.renderer.rasterizer import Renderer
from cadfire.tasks.registry import TaskRegistry
from cadfire.utils.config import load_config
import torch

config = load_config()
env = CADEnv(config)
agent = CADAgent.load_checkpoint("checkpoints_1/best.pt", config)
agent.eval()
renderer = Renderer(config)

# Pick a task
task = TaskRegistry.create("draw_circle", seed=42)
obs, info = env.reset(task=task)
print(f"Prompt: {info['prompt']}")

# frames[0] is the scene before any action; one frame is appended per step.
frames = [renderer.render_rgb_only(env.engine)]
total_reward = 0

for step in range(50):
    # Add a leading batch dimension of 1 to each observation component.
    # NOTE(review): tensors are created on the CPU — this assumes the loaded
    # agent is on the CPU too (or moves inputs itself); confirm.
    obs_t = {
        "image": torch.tensor(obs["image"]).unsqueeze(0),
        "text_ids": torch.tensor(obs["text_ids"], dtype=torch.long).unsqueeze(0),
        "state_vec": torch.tensor(obs["state_vec"]).unsqueeze(0),
    }
    # Pure inference: disable autograd so no graph is built for each step.
    with torch.no_grad():
        action_info = agent.act(obs_t, deterministic=True)
    # Convert the batched tensor outputs into plain Python/NumPy values.
    action = {
        "tool_id": action_info["tool_id"].item(),
        "cursor": action_info["cursor"].cpu().numpy()[0],
        "param": action_info["param"].item(),
    }
    obs, reward, term, trunc, info = env.step(action)
    total_reward += reward
    frames.append(renderer.render_rgb_only(env.engine))
    if term or trunc:
        break

# `step` is the zero-based index of the last executed step.
print(f"Total reward: {total_reward:.3f}, Steps: {step+1}")

# Show first, middle, and last frame
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, idx, title in zip(axes, [0, len(frames)//2, -1], ["Start", "Middle", "End"]):
    ax.imshow(frames[idx])
    ax.set_title(title)
    ax.axis("off")
plt.tight_layout()
plt.show()

In [None]:
# Cell 6: Export a drawing to DXF
# Writes the engine of `env` (the environment created in Cell 5) to a DXF
# file, then prints the engine's entity count.
from cadfire.export.dxf_writer import DXFWriter

dxf_path = "output.dxf"
writer = DXFWriter()
writer.write(env.engine, dxf_path)
print(f"Exported to {dxf_path}")

entity_total = env.engine.entity_count()
print(f"Entities: {entity_total}")

In [None]:
# Cell 7: Pull updates and continue training
# After pushing new tasks or config changes:
!git pull origin main

# Reload config and re-discover tasks
from cadfire.utils.config import reload
reload()
TaskRegistry.discover()
print(f"Tasks after update: {TaskRegistry.list_tasks()}")

# Continue training (auto-resumes, handles tool list growth)
trainer = run_training(num_steps=50000, resume=True)