# CADFire Training Notebook

Training notebook for vast.ai GPU instances.

## Workflow
1. Clone repo, install deps
2. Run training cells
3. Monitor diagnostics
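4. Pull updates with `git pull`; training then resumes from the saved checkpoint automatically. (Restart the kernel after pulling code changes so that already-imported modules are re-imported.)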

In [None]:
# Cell 1: Setup (run once per instance)
# Install runtime dependencies (-q keeps pip's output quiet).
!pip install -q torch numpy matplotlib

import sys, os
# Put the current working directory first on the import path so `train` and the
# `cadfire` package can be imported — assumes the notebook is run from the repo
# root (TODO confirm on a fresh clone).
sys.path.insert(0, os.getcwd())

In [None]:
# Cell 2: Verify environment
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # The device-properties field is `total_memory` (in bytes); `total_mem`
    # does not exist and raised AttributeError on every GPU instance.
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {props.total_memory / 1e9:.1f} GB")

# Populate the task registry and list what it found.
from cadfire.tasks.registry import TaskRegistry
TaskRegistry.discover()
print(f"\nRegistered tasks ({TaskRegistry.count()}): {TaskRegistry.list_tasks()}")

# Show the tool count and the first ten tool names.
from cadfire.utils.config import num_tools, tool_list
print(f"Tools ({num_tools()}): {tool_list()[:10]}...")

In [None]:
# Cell 3: Train
from train import run_training

# Every metrics record the trainer hands to the callback is kept here for live plotting.
metrics_history = []


def collect_metrics(m):
    """Trainer callback: append one metrics record to ``metrics_history``."""
    metrics_history.append(m)


trainer = run_training(
    callback=collect_metrics,
    device=None,          # None: auto-detect GPU
    max_difficulty=None,  # None: use the curriculum default
    num_steps=100000,     # total training steps; adjust as needed
    resume=True,          # auto-resume from checkpoint
)

In [None]:
# Cell 4: Plot training progress
import matplotlib.pyplot as plt
import json
from pathlib import Path

# Load diagnostics written next to the checkpoints.
diag_path = Path("checkpoints/diagnostics.json")
if diag_path.exists():
    with open(diag_path, encoding="utf-8") as f:
        diag = json.load(f)
    log = diag["training_log"]

    # (log key, panel title) for each subplot, filled in row-major order.
    panels = [
        ("avg_reward", "Average Reward"),
        ("policy_loss", "Policy Loss"),
        ("value_loss", "Value Loss"),
        ("entropy", "Entropy"),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(14, 8))
    steps = [e["step"] for e in log]

    for ax, (key, title) in zip(axes.flat, panels):
        # A missing metric becomes NaN so matplotlib leaves a gap, rather than
        # drawing a misleading drop to 0 as the previous `.get(key, 0)` did.
        ax.plot(steps, [e.get(key, float("nan")) for e in log])
        ax.set_title(title)
        ax.set_xlabel("Step")

    plt.tight_layout()
    # Save the figure in the same directory as the diagnostics file.
    plt.savefig(diag_path.parent / "training_curves.png", dpi=100)
    plt.show()

    print(f"Total steps: {diag['total_steps']}")
    print(f"Total episodes: {diag['total_episodes']}")
    print(f"Best reward: {diag['best_reward']:.4f}")
else:
    print("No diagnostics found yet. Run training first.")

In [None]:
# Cell 5: Visualize agent behavior on a specific task
import numpy as np
import matplotlib.pyplot as plt
from cadfire.env.cad_env import CADEnv
from cadfire.model.cad_agent import CADAgent
from cadfire.renderer.rasterizer import Renderer
from cadfire.tasks.registry import TaskRegistry
from cadfire.utils.config import load_config
import torch

# Rollout settings: change these to inspect another task or a longer episode.
TASK_NAME = "draw_circle"  # any name from TaskRegistry.list_tasks()
TASK_SEED = 42
MAX_STEPS = 50             # hard cap on rollout length

config = load_config()
env = CADEnv(config)
agent = CADAgent.load_checkpoint("checkpoints/best.pt", config)
agent.eval()
renderer = Renderer(config)

# Pick a task
task = TaskRegistry.create(TASK_NAME, seed=TASK_SEED)
obs, info = env.reset(task=task)
print(f"Prompt: {info['prompt']}")

frames = [renderer.render_rgb_only(env.engine)]
total_reward = 0
steps_taken = 0

for _ in range(MAX_STEPS):
    # Add a leading batch dimension of 1 to each observation component.
    obs_t = {
        "image": torch.tensor(obs["image"]).unsqueeze(0),
        "text_ids": torch.tensor(obs["text_ids"], dtype=torch.long).unsqueeze(0),
        "state_vec": torch.tensor(obs["state_vec"]).unsqueeze(0),
    }
    # Inference only: don't build an autograd graph for every rollout step.
    with torch.no_grad():
        action_info = agent.act(obs_t, deterministic=True)
    action = {
        "tool_id": action_info["tool_id"].item(),
        "cursor": action_info["cursor"].cpu().numpy()[0],
        "param": action_info["param"].item(),
    }
    obs, reward, term, trunc, info = env.step(action)
    total_reward += reward
    steps_taken += 1
    frames.append(renderer.render_rgb_only(env.engine))
    if term or trunc:
        break

print(f"Total reward: {total_reward:.3f}, Steps: {steps_taken}")

# Show first, middle, and last frame
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, idx, title in zip(axes, [0, len(frames)//2, -1], ["Start", "Middle", "End"]):
    ax.imshow(frames[idx])
    ax.set_title(title)
    ax.axis("off")
plt.tight_layout()
plt.show()

In [None]:
# Cell 6: Export a drawing to DXF
# Writes the drawing currently held by `env.engine` (the env from Cell 5).
from cadfire.export.dxf_writer import DXFWriter

dxf_path = "output.dxf"
writer = DXFWriter()
writer.write(env.engine, dxf_path)

entity_total = env.engine.entity_count()
print(f"Exported to {dxf_path}")
print(f"Entities: {entity_total}")

In [None]:
# Cell 7: Pull updates and continue training
# After pushing new tasks or config changes:
!git pull origin main

# Reload config and re-discover tasks
from cadfire.utils.config import reload
reload()
TaskRegistry.discover()
print(f"Tasks after update: {TaskRegistry.list_tasks()}")

# Continue training (auto-resumes, handles tool list growth)
trainer = run_training(num_steps=50000, resume=True)