# 04 — Model Training & Analysis

Train the hierarchical VLA models and analyze results.

1. Train Level 2 (Skill Selector) via behavioral cloning
2. Train Level 3 (Motor Policy) via diffusion
3. Visualize training curves
4. Test inference

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import torch
import matplotlib.pyplot as plt

# Choose the compute device once: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Report the runtime environment (device first, then the PyTorch version).
print(f'Using device: {device}', f'PyTorch: {torch.__version__}', sep='\n')
if device == 'cuda':
    # Name of CUDA device index 0.
    print(f'GPU: {torch.cuda.get_device_name(0)}')

%matplotlib inline

## 4.1 — Model Architecture Overview

In [None]:
from safedisassemble.models.skill_selector.selector import SkillSelector, SKILL_VOCAB
from safedisassemble.models.motor_policy.diffusion_policy import DiffusionMotorPolicy

# Hyperparameters used both to build the models and in the summary printed
# below, so the reported numbers cannot drift from the constructor arguments.
IMAGE_SIZE = 84
EMBED_DIM = 256
COND_DIM = 256
ACTION_DIM = 7
ACTION_HORIZON = 16
NUM_DIFFUSION_STEPS = 100
NUM_INFERENCE_STEPS = 10


def count_parameters(module):
    """Return the total number of elements over ``module.parameters()``."""
    return sum(p.numel() for p in module.parameters())


# Level 2: Skill Selector
selector = SkillSelector(image_size=IMAGE_SIZE, embed_dim=EMBED_DIM)
selector_params = count_parameters(selector)
print(f"Skill Selector: {selector_params:,} parameters")
print(f"  Skills: {SKILL_VOCAB}")

# Level 3: Motor Policy
policy = DiffusionMotorPolicy(
    action_dim=ACTION_DIM,
    action_horizon=ACTION_HORIZON,
    image_size=IMAGE_SIZE,
    cond_dim=COND_DIM,
    num_diffusion_steps=NUM_DIFFUSION_STEPS,
    num_inference_steps=NUM_INFERENCE_STEPS,
)
policy_params = count_parameters(policy)
print(f"\nMotor Policy: {policy_params:,} parameters")
print(f"  Action horizon: {ACTION_HORIZON} steps")
print(f"  Diffusion steps (train): {NUM_DIFFUSION_STEPS}")
print(f"  Diffusion steps (inference): {NUM_INFERENCE_STEPS}")

print(f"\nTotal parameters: {selector_params + policy_params:,}")

## 4.2 — Forward Pass Test

In [None]:
# Smoke test: push random inputs through both models to confirm that the
# forward passes run on `device` and to report the output shapes. The inputs
# are random, so the printed values carry no meaning beyond "it ran".
selector = selector.to(device)
policy = policy.to(device)

smoke_batch_size = 4

# Test Skill Selector
# Tensors are created directly on the target device instead of being built on
# the CPU and copied over.
batch = {
    'images': torch.randn(smoke_batch_size, 3, 84, 84, device=device),
    'tokens': torch.randint(0, 1000, (smoke_batch_size, 32), device=device),
    'proprio': torch.randn(smoke_batch_size, 19, device=device),
}

with torch.no_grad():
    out = selector(batch['images'], batch['tokens'], batch['proprio'])

# Convert the arg-max indices to plain Python ints before indexing SKILL_VOCAB,
# rather than indexing with 0-d (possibly CUDA) tensors.
predicted_skill_ids = out["skill_logits"].argmax(dim=1).tolist()
print('Skill Selector output:')
print(f'  skill_logits: {out["skill_logits"].shape}')
print(f'  skill_params: {out["skill_params"].shape}')
print(f'  confidence:   {out["confidence"].shape}')
print(f'  Predicted skills: {[SKILL_VOCAB[i] for i in predicted_skill_ids]}')

# Test Motor Policy
# NOTE(review): the upper bound 8 is presumably the number of skills the policy
# accepts -- confirm it matches len(SKILL_VOCAB).
skill_ids = torch.randint(0, 8, (smoke_batch_size,), device=device)
skill_params = torch.randn(smoke_batch_size, 10, device=device)
target_actions = torch.randn(smoke_batch_size, 16, 7, device=device)

loss = policy.compute_loss(
    batch['images'], batch['proprio'], skill_ids, skill_params,
    target_actions,
)
print(f'\nMotor Policy diffusion loss: {loss.item():.4f}')

# Sampling is inference only, so no autograd graph is needed for this call.
with torch.no_grad():
    actions = policy.predict_action(
        batch['images'], batch['proprio'], skill_ids, skill_params,
    )
print(f'Predicted actions: {actions.shape}  (batch, horizon, action_dim)')
print(f'Action range: [{actions.min():.3f}, {actions.max():.3f}]')

## 4.3 — Training (Small-Scale Demo)

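For full training, use `scripts/train.py`. This cell only runs a quick sanity check: it repeatedly trains the Skill Selector on a single fixed random batch, so the loss should fall toward zero if the model and its loss are wired correctly.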

In [None]:
# Quick overfit test on a small batch: train a fresh Skill Selector on one
# fixed random batch and check that the loss collapses.
selector = SkillSelector(image_size=84, embed_dim=128).to(device)
optimizer = torch.optim.Adam(selector.parameters(), lr=1e-3)

# Fixed batch to overfit on (created directly on the target device).
overfit_batch_size = 8
fixed_batch = {
    'images': torch.randn(overfit_batch_size, 3, 84, 84, device=device),
    'tokens': torch.randint(0, 1000, (overfit_batch_size, 32), device=device),
    'proprio': torch.randn(overfit_batch_size, 19, device=device),
    'skill_target': torch.randint(0, len(SKILL_VOCAB), (overfit_batch_size,), device=device),
    'param_target': torch.randn(overfit_batch_size, 10, device=device),
}

losses = []
for step in range(200):
    out = selector(fixed_batch['images'], fixed_batch['tokens'], fixed_batch['proprio'])
    loss_dict = selector.compute_loss(out, fixed_batch['skill_target'], fixed_batch['param_target'])

    optimizer.zero_grad()
    loss_dict['total_loss'].backward()
    optimizer.step()
    losses.append(loss_dict['total_loss'].item())

plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Skill Selector — Overfit Test (should go to ~0)')
plt.yscale('log')
plt.grid(True, alpha=0.3)
plt.show()

first_loss, final_loss = losses[0], losses[-1]
# A perfectly overfit model can report a final loss of exactly 0.0; guard the
# ratio so the summary line prints "inf" instead of raising ZeroDivisionError.
reduction = first_loss / final_loss if final_loss != 0 else float('inf')

print(f'Final loss: {final_loss:.6f}')
print(f'Loss decreased: {first_loss:.4f} -> {final_loss:.6f} ({reduction:.0f}x reduction)')

## 4.4 — Visualize Diffusion Denoising Process

In [None]:
from safedisassemble.models.motor_policy.diffusion_policy import DiffusionScheduler

scheduler = DiffusionScheduler(num_train_steps=100)

# Show the noise schedule: one panel per scheduler tensor, all plotted against
# the timestep index.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

schedule_panels = (
    (scheduler.betas, 'Beta Schedule'),
    (scheduler.alphas_cumprod, 'Cumulative Alpha (signal remaining)'),
    (scheduler.sqrt_one_minus_alphas_cumprod, 'Noise Level'),
)
for panel_ax, (curve, panel_title) in zip(axes, schedule_panels):
    panel_ax.plot(curve.numpy())
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Timestep')

plt.suptitle('DDPM Noise Schedule (cosine)', fontsize=13)
plt.tight_layout()
plt.show()

# Visualize progressive noising of an action trajectory.
# The clean trajectory is laid out as (batch, horizon, action_dim) = (1, 16, 7),
# the same layout as the motor-policy action tensors in this notebook: one sine
# wave along the horizon, copied into every action dimension.
horizon, action_dim = 16, 7
original_actions = (
    torch.sin(torch.linspace(0, 4 * np.pi, horizon))
    .view(1, horizon, 1)
    .repeat(1, 1, action_dim)
)

fig, axes = plt.subplots(2, 5, figsize=(18, 6))
timesteps_to_show = [0, 10, 25, 50, 99]

for i, t in enumerate(timesteps_to_show):
    # Fresh Gaussian noise per panel; the scheduler mixes it into the clean
    # trajectory for timestep t.
    noise = torch.randn_like(original_actions)
    noisy = scheduler.add_noise(original_actions, noise, torch.tensor([t]))

    # Top row: action dimension 0 across the horizon, clean vs. noised.
    axes[0, i].plot(original_actions[0, :, 0].numpy(), 'b-', alpha=0.3, label='original')
    axes[0, i].plot(noisy[0, :, 0].numpy(), 'r-', label='noisy')
    axes[0, i].set_title(f't = {t}')
    axes[0, i].set_ylim(-3, 3)
    if i == 0:
        axes[0, i].legend(fontsize=8)

    # Bottom row: the noise sample that was mixed in for this panel.
    axes[1, i].plot(noise[0, :, 0].numpy(), 'g-', alpha=0.5, label='target noise')
    axes[1, i].set_ylim(-3, 3)
    if i == 0:
        axes[1, i].set_ylabel('Noise to predict')
        # Show the 'target noise' label, mirroring the legend on the top row.
        axes[1, i].legend(fontsize=8)

axes[0, 0].set_ylabel('Action trajectory')
plt.suptitle('Forward Diffusion: Progressive Noising of Action Trajectories', fontsize=13)
plt.tight_layout()
plt.show()