# Train Stage 2 (Keypoints)

In [None]:
import sys
sys.path.append('..')

import torch
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt

import config
from models import KeypointModel
from data import RobotKeypointDataset
from utils import train_stage2, plot_training_history, visualize_sample

# Select the compute device: the first CUDA GPU when one is usable, else the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')
print(f"Device: {device}")
if has_cuda:
    # Report which GPU (index 0) was picked up and its total memory in GB.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Build the datasets. Everything under config.TRAIN_DIRS is loaded as one
# dataset and then divided 80/20 into train and validation subsets; the test
# set is loaded separately from config.TEST_DIR.
full_train_dataset = RobotKeypointDataset(data_dirs=config.TRAIN_DIRS, config=config)

# Validation takes the remainder so the two sizes always sum to the total,
# whatever the rounding of the 80% share.
n_total = len(full_train_dataset)
train_size = int(0.8 * n_total)
val_size = n_total - train_size

# A fixed seed makes the split identical from run to run.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size], generator=split_generator
)

test_dataset = RobotKeypointDataset(data_dirs=[config.TEST_DIR], config=config)

print(f"Train: {len(train_dataset)} samples")
print(f"Val: {len(val_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")

In [None]:
# Sanity check before training: take the first item of the un-split training
# dataset, hand it to visualize_sample, and display the resulting figure.
# NOTE(review): what visualize_sample draws is defined in utils — confirm there.
sample = full_train_dataset[0]
fig = visualize_sample(sample, config=config)
plt.show()

In [None]:
# Wrap each split in a DataLoader. All three share batch size and worker
# count from config; only the training loader shuffles, so validation and
# test batches are produced in a fixed order.
#
# Pinned (page-locked) host memory only speeds up host-to-GPU copies. Asking
# for it on a machine without CUDA brings no benefit, and recent PyTorch
# versions emit a warning, so it is enabled only when CUDA is available.
loader_kwargs = dict(
    batch_size=config.BATCH_SIZE,
    num_workers=config.NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Instantiate the Stage 2 keypoint model with the keypoint count and backbone
# taken from config, then print its structure and total parameter count.
model = KeypointModel(
    num_keypoints=config.NUM_JOINTS,
    backbone=config.STAGE2_BACKBONE,
    pretrained=True,
)
print(model)

# Count every parameter element, whether or not it requires gradients.
num_params = sum(p.numel() for p in model.parameters())
print(f"\nparameters: {num_params:,}")

In [None]:
# Run Stage 2 training. train_stage2 hands back the model and a history
# object; both are rebound here for use by the following cells.
# Hyperparameters that come straight from config are grouped so the call
# itself shows only the objects being wired together.
# NOTE(review): save_dir / save_every presumably control checkpoint output —
# confirm against utils.train_stage2.
stage2_hparams = dict(
    epochs=config.EPOCHS_STAGE2,
    lr=config.LR,
    save_every=config.SAVE_EVERY,
    early_stop_patience=config.EARLY_STOP_PATIENCE,
    img_size=config.STAGE2_SIZE,
)

model, history = train_stage2(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    save_dir='../checkpoints',
    **stage2_hparams,
)

In [None]:
# Plot the Stage 2 training history, save the current figure next to the
# checkpoints, and display it.
fig = plot_training_history(history, title_prefix='Stage 2: ')
plt.savefig('../checkpoints/stage2_history.png', dpi=150)
plt.show()

# Lowest value recorded under history['val_error'], reported in pixels.
best_val_error = min(history['val_error'])
print(f"Best validation error: {best_val_error:.1f}px")

In [None]:
# Visualize Predictions
# Draw up to six randomly chosen test samples in a 2x3 grid. Each panel shows
# the Stage 2 input image with ground-truth keypoints (lime circles) and model
# predictions (red crosses), titled with the mean keypoint distance in pixels.
from utils import denormalize_image
import numpy as np

model.eval()
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# Never request more samples than the test set holds: sampling without
# replacement raises ValueError when the sample size exceeds the population.
num_shown = min(len(axes), len(test_dataset))
indices = np.random.choice(len(test_dataset), num_shown, replace=False)

for ax, idx in zip(axes, indices):
    # Cast the numpy integer to a plain int so the dataset sees an ordinary index.
    sample = test_dataset[int(idx)]

    with torch.no_grad():
        # Add a batch dimension for the forward pass, then take the single result.
        img = sample['img_stage2'].unsqueeze(0).to(device)
        pred = model(img).cpu().numpy()[0]

    # Display
    img_display = denormalize_image(sample['img_stage2'])
    ax.imshow(img_display)

    # Ground truth: flat keypoint vector -> (num_keypoints, 2) pairs, scaled
    # by the Stage 2 image size.
    # NOTE(review): assumes keypoints are stored normalized to [0, 1] — confirm.
    gt = sample['keypoints'].numpy().reshape(-1, 2) * config.STAGE2_SIZE
    ax.scatter(gt[:, 0], gt[:, 1], c='lime', s=100, marker='o', label='GT')

    # Prediction, converted with the same reshape and scale as the ground truth
    pred_px = pred.reshape(-1, 2) * config.STAGE2_SIZE
    ax.scatter(pred_px[:, 0], pred_px[:, 1], c='red', s=100, marker='x', label='Pred')

    # Error: Euclidean distance per keypoint, averaged over the sample's keypoints
    error = np.linalg.norm(pred_px - gt, axis=1).mean()
    ax.set_title(f'error: {error:.1f}px')
    ax.axis('off')

# Blank any panels left unused when the test set has fewer than six samples.
for ax in axes[num_shown:]:
    ax.axis('off')

# Only the first panel carries the legend; skip it when nothing was drawn.
if num_shown:
    axes[0].legend()
plt.tight_layout()
plt.savefig('../checkpoints/stage2_predictions.png', dpi=150)
plt.show()