# CADFire Training Notebook

Training notebook for vast.ai GPU instances.

## Workflow
1. Clone repo, install deps
2. **Phase 1**: Tool classifier pretraining (~1 min) — teaches text→tool mapping
3. **Phase 2**: Semantic cursor pretraining (~10-30 min) — teaches SELECT/MULTISELECT with precise cursor placement
4. **Phase 3**: RL fine-tuning — PPO from warm-started checkpoint
5. Monitor diagnostics
6. Pull updates with `git pull` — training resumes from the latest checkpoint automatically

> **Note**: Phases 1 and 2 only need to run once. After a pretrained checkpoint is saved,
> subsequent RL runs resume from it automatically.

In [None]:
# Cell 1: Setup (run once per instance)
!pip install -q torch numpy matplotlib

import sys, os
sys.path.insert(0, os.getcwd())

In [None]:
# Cell 2: Verify environment
# Sanity-check the instance: PyTorch/CUDA status first, then the registered
# tasks and the tool vocabulary.
import torch

cuda_ok = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    device_name = torch.cuda.get_device_name(0)
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {device_name}")
    print(f"Memory: {total_gb:.1f} GB")

from cadfire.tasks.registry import TaskRegistry

TaskRegistry.discover()
task_count = TaskRegistry.count()
task_names = TaskRegistry.list_tasks()
print(f"\nRegistered tasks ({task_count}): {task_names}")

from cadfire.utils.config import num_tools, tool_list

n_tools = num_tools()
first_tools = tool_list()[:10]  # preview only; the full list can be long
print(f"Tools ({n_tools}): {first_tools}...")

In [None]:
# Cell 3a: Phase 1 — Tool Classifier Pretraining
# Teaches the text encoder + fusion + tool head to map prompts → tools.
# Safe to re-run: loads any existing checkpoint first, then continues training
# on top of it (so each re-run trains further rather than starting over).
# Increase num_epochs to 50 for higher accuracy (~95%+).

import os
import sys

# Repeated from Cell 1 so this cell also works straight after a kernel restart;
# skip the insert when the working directory is already first on sys.path.
_repo_root = os.getcwd()
if not sys.path or sys.path[0] != _repo_root:
    sys.path.insert(0, _repo_root)

from train import run_pretrain_tool

agent = run_pretrain_tool(
    num_epochs=30,        # ~1 min on GPU; raise to 50 for higher accuracy
    lr=1e-3,
    batch_size=64,
    checkpoint_dir="checkpoints_1",
)

In [None]:
# Cell 3b: Phase 2 — Semantic Cursor Pretraining
# Trains the vision encoder and cursor head to find shapes named in the prompt:
#   SELECT      → "Select the <shape>"   — one Gaussian blob on the target centroid
#   MULTISELECT → "Select all <shape>s"  — one blob per target centroid
# The text encoder stays FROZEN so the Phase-1 weights are preserved.
# When no `agent` is bound in this kernel, Phase-1 weights are loaded from the checkpoint.
# NOTE(review): globals().get("agent") takes whatever `agent` is currently bound —
# including the eval-mode agent Cell 5 loads from best.pt if cells ran out of order.

from train import run_pretrain_semantic

agent = run_pretrain_semantic(
    agent=globals().get("agent"),   # reuse the in-memory Phase-1 agent, else None
    checkpoint_dir="checkpoints_1",
    # --- sample generation ---
    num_samples=20_000,   # generated samples per epoch; raise to 50_000 for better coverage
    multi_ratio=0.4,      # share of MULTISELECT samples (rest are SELECT)
    sigma=12.0,           # Gaussian blob radius in pixels
    num_workers=0,        # set to 4 on multi-core instances for faster dataloading
    # --- optimisation ---
    num_epochs=20,
    lr=3e-4,
    batch_size=32,        # keep small — images are 256×256×19
    cursor_weight=1.0,    # cursor-loss weight relative to tool loss
)

In [None]:
# Cell 3c: Phase 3 — RL Training (PPO)
# Picks up from the most recent pretraining checkpoint on its own.
# Tune num_steps and max_difficulty as needed.

from train import run_training

# Every metrics report the trainer passes to the callback is kept here, in order.
metrics_history = []


def collect_metrics(m):
    """Callback for run_training: record one metrics report in metrics_history."""
    metrics_history.append(m)


trainer = run_training(
    checkpoint_dir="checkpoints_1",
    resume=True,              # auto-resume from checkpoints_1/latest.pt
    num_steps=100_000,        # total PPO environment steps
    max_difficulty=None,      # curriculum default: starts at 2.0, grows every 5 000 steps
    device=None,              # auto-detect GPU
    callback=collect_metrics,
)

In [None]:
# Cell 4: Plot training progress
# Reads checkpoints_1/diagnostics.json, draws one panel per logged metric on a
# 2×3 grid, saves the figure next to the diagnostics, and prints run totals.
import matplotlib.pyplot as plt
import json
from pathlib import Path

# (key in each training_log entry, panel title), laid out row-major on the grid.
PANELS = [
    ("avg_reward", "Average Reward"),
    ("policy_loss", "Policy Loss"),
    ("difficulty", "Curriculum Difficulty"),
    ("value_loss", "Value Loss"),
    ("entropy", "Entropy"),
    ("steps_per_second", "Steps/sec"),
]

# Load diagnostics
diag_path = Path("checkpoints_1/diagnostics.json")
if diag_path.exists():
    diag = json.loads(diag_path.read_text(encoding="utf-8"))
    log = diag.get("training_log", [])

    fig, axes = plt.subplots(2, 3, figsize=(18, 8))
    steps = [e["step"] for e in log]

    for ax, (key, title) in zip(axes.flat, PANELS):
        # NaN rather than 0 for entries missing this metric: matplotlib leaves a
        # gap instead of drawing a misleading drop to zero.
        ax.plot(steps, [e.get(key, float("nan")) for e in log])
        ax.set_title(title)
        ax.set_xlabel("Step")

    plt.tight_layout()
    plt.savefig(diag_path.parent / "training_curves.png", dpi=100)
    plt.show()

    print(f"Total steps: {diag['total_steps']}")
    print(f"Total episodes: {diag['total_episodes']}")
    print(f"Best reward: {diag['best_reward']:.4f}")
else:
    print("No diagnostics found yet. Run training first.")

In [None]:
# Cell 5: Visualize agent behavior on a specific task
# Loads checkpoints_1/best.pt, runs one deterministic episode, and shows the
# start / middle / end renderings of the drawing.
import numpy as np
import matplotlib.pyplot as plt
from cadfire.env.cad_env import CADEnv
from cadfire.model.cad_agent import CADAgent
from cadfire.renderer.rasterizer import Renderer
from cadfire.tasks.registry import TaskRegistry
from cadfire.utils.config import load_config
import torch

TASK_NAME = "draw_circle"   # any name from TaskRegistry.list_tasks()
TASK_SEED = 42
MAX_STEPS = 50              # cap on rollout length for this preview


def _to_batch(obs):
    """Wrap one env observation dict as a batch-of-one tensor dict for `agent.act`."""
    return {
        "image": torch.tensor(obs["image"]).unsqueeze(0),
        "text_ids": torch.tensor(obs["text_ids"], dtype=torch.long).unsqueeze(0),
        "state_vec": torch.tensor(obs["state_vec"]).unsqueeze(0),
    }


config = load_config()
env = CADEnv(config)
agent = CADAgent.load_checkpoint("checkpoints_1/best.pt", config)
agent.eval()
renderer = Renderer(config)

# Pick a task
task = TaskRegistry.create(TASK_NAME, seed=TASK_SEED)
obs, info = env.reset(task=task)
print(f"Prompt: {info['prompt']}")

frames = [renderer.render_rgb_only(env.engine)]
total_reward = 0
steps_taken = 0

for _ in range(MAX_STEPS):
    # Pure inference: no_grad skips autograd bookkeeping, and the resulting
    # tensors carry no grad history (`.numpy()` raises on tensors requiring grad).
    with torch.no_grad():
        action_info = agent.act(_to_batch(obs), deterministic=True)
    action = {
        "tool_id": action_info["tool_id"].item(),
        "cursor": action_info["cursor"].cpu().numpy()[0],
        "param": action_info["param"].item(),
    }
    obs, reward, term, trunc, info = env.step(action)
    total_reward += reward
    steps_taken += 1
    frames.append(renderer.render_rgb_only(env.engine))
    if term or trunc:
        break

print(f"Total reward: {total_reward:.3f}, Steps: {steps_taken}")

# Show first, middle, and last frame
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, idx, title in zip(axes, [0, len(frames) // 2, -1], ["Start", "Middle", "End"]):
    ax.imshow(frames[idx])
    ax.set_title(title)
    ax.axis("off")
plt.tight_layout()
plt.show()

In [None]:
# Cell 6: Export a drawing to DXF
# Saves the drawing currently held by `env` (left over from Cell 5's rollout),
# then reports how many entities it contains.
from cadfire.export.dxf_writer import DXFWriter

dxf_path = "output.dxf"
DXFWriter().write(env.engine, dxf_path)
print(f"Exported to {dxf_path}")
print(f"Entities: {env.engine.entity_count()}")

In [None]:
# Cell 7: Pull updates and continue training
# After new tasks or config changes are pushed upstream:
!git pull origin master

# Reload config and re-discover tasks
from cadfire.utils.config import reload
reload()
from cadfire.tasks.registry import TaskRegistry
TaskRegistry.discover()
print(f"Tasks after update: {TaskRegistry.list_tasks()}")

# Continue training — auto-resumes, handles tool-list growth seamlessly
from train import run_training
trainer = run_training(num_steps=50_000, resume=True, checkpoint_dir="checkpoints_1")