# Train Stage 1 (BBox)

In [None]:
import sys
sys.path.append('..')

import torch
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np

import config
from models import BBoxModel
from data import RobotKeypointDataset
from utils import train_stage1, plot_training_history, visualize_bbox_predictions, denormalize_image

# Choose the compute device once: CUDA when available, otherwise CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(f"Device: {device}")
if use_cuda:
    # Report the name of CUDA device 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Build the training pool from config.TRAIN_DIRS and carve an 80/20
# train/val split out of it; the test set is read from config.TEST_DIR.
full_train_dataset = RobotKeypointDataset(data_dirs=config.TRAIN_DIRS, config=config)

n_total = len(full_train_dataset)
train_size = int(0.8 * n_total)
val_size = n_total - train_size  # remainder, so the two sizes always sum to n_total

# Seeded generator so the split is identical across runs / kernel restarts.
split_rng = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size], generator=split_rng
)

test_dataset = RobotKeypointDataset(data_dirs=[config.TEST_DIR], config=config)

for split_name, split_ds in (
    ("Train", train_dataset),
    ("Val", val_dataset),
    ("Test", test_dataset),
):
    print(f"{split_name}: {len(split_ds)} samples")

In [None]:
# All three loaders use twice the configured batch size. Only the training
# loader shuffles; val/test iterate in dataset order.
loader_batch_size = config.BATCH_SIZE * 2

# Pinned host memory only speeds up host->GPU copies. Without CUDA it brings
# no benefit and PyTorch warns about it, so enable it only when CUDA exists.
pin = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=loader_batch_size,
    shuffle=True,
    num_workers=config.NUM_WORKERS,
    pin_memory=pin,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=loader_batch_size,
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=pin,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=loader_batch_size,
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=pin,
)

In [None]:
# Instantiate the stage-1 bbox model on the configured backbone
# (pretrained=True presumably loads pretrained backbone weights — see BBoxModel).
model = BBoxModel(backbone=config.STAGE1_BACKBONE, pretrained=True)
print(model)

# Total parameter count (trainable and frozen alike).
n_params = sum(param.numel() for param in model.parameters())
print(f"\nParameters: {n_params:,}")

In [None]:
# Run stage-1 training. train_stage1 returns the model together with a
# training history that is indexed by metric name (e.g. history['val_iou']).
stage1_kwargs = dict(
    device=device,
    epochs=config.EPOCHS_STAGE1,
    lr=config.LR,
    save_dir='../checkpoints',
    save_every=config.SAVE_EVERY,
    early_stop_patience=config.EARLY_STOP_PATIENCE,
)
model, history = train_stage1(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    **stage1_kwargs,
)

In [None]:
# Plot the stage-1 training history and save the figure under ../checkpoints.
fig = plot_training_history(history, title_prefix='Stage 1: ')
plt.savefig('../checkpoints/stage1_history.png', dpi=150)
plt.show()

print("Best model saved")
# max() raises ValueError on an empty sequence (e.g. training stopped before
# any validation epoch was recorded), so report that case explicitly.
val_ious = history['val_iou']
if len(val_ious) > 0:
    print(f"Best validation IoU: {max(val_ious):.3f}")
else:
    print("Best validation IoU: n/a (no validation epochs recorded)")

In [None]:
# Visualize bbox predictions for a few test samples and save the figure.
n_vis = 6
fig = visualize_bbox_predictions(
    test_dataset, model, device, num_samples=n_vis, config=config
)
plt.savefig('../checkpoints/stage1_predictions.png', dpi=150)
plt.show()

In [None]:
from utils import compute_bbox_iou

# Score every test sample individually (batch of 1) so the histogram below
# shows the per-sample IoU distribution rather than per-batch averages.
model.eval()
ious = []

with torch.no_grad():
    for i in range(len(test_dataset)):
        sample = test_dataset[i]
        img = sample['img_stage1'].unsqueeze(0).to(device)
        pred = model(img)
        gt = sample['bbox'].unsqueeze(0).to(device)
        iou = compute_bbox_iou(pred, gt)
        # float() accepts a plain number or a single-element tensor (CPU or
        # CUDA), so `ious` always holds Python floats that numpy can convert.
        ious.append(float(iou))

ious = np.asarray(ious, dtype=np.float64)

# Compute the summary statistics once; they feed both the plot and the printout.
mean_iou = np.mean(ious)
median_iou = np.median(ious)

fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(ious, bins=30, edgecolor='black', alpha=0.7)
ax.axvline(mean_iou, color='red', linestyle='--', label=f'Mean: {mean_iou:.3f}')
ax.axvline(median_iou, color='orange', linestyle='--', label=f'Median: {median_iou:.3f}')
ax.set_xlabel('IoU')
ax.set_ylabel('Count')
ax.set_title('BBox IoU Distribution')
ax.legend()
plt.tight_layout()
plt.savefig('../checkpoints/stage1_iou_dist.png', dpi=150)
plt.show()

print(f"\nIoU Statistics:")
print(f"  Mean:   {mean_iou:.3f}")
print(f"  Median: {median_iou:.3f}")
print(f"  >0.8:   {(ious > 0.8).mean()*100:.1f}%")
print(f"  >0.9:   {(ious > 0.9).mean()*100:.1f}%")