## 1. Setup & Imports

In [1]:
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import yaml
import torch
from tqdm import tqdm

# Add project root to path
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from config import Config
from models.pose_estimator import PoseEstimator
from dataset.custom_dataset import PoseDataset
from utils.transforms import (
    quaternion_to_rotation_matrix,
    crop_image_from_bbox,
    get_pose_transforms,
    project_3d_points
)
from utils.metrics import (
    load_all_models,
    load_models_info,
    compute_add
)

print(f"üìÇ Project root: {project_root}")
print(f"üñ•Ô∏è  Device: {Config.DEVICE}")

# Show device info
if Config.DEVICE == 'mps':
    print(f"   ‚úÖ Apple Silicon GPU detected (MPS)")
elif Config.DEVICE == 'cuda':
    print(f"   ‚úÖ NVIDIA GPU detected (CUDA)")
else:
    print(f"   ‚ö†Ô∏è  Using CPU (slower training)")

üìÇ Project root: /Users/nicolotermine/zMellow/GitHub-Poli/Polito/polito-aml-6D_pose_estimation
üñ•Ô∏è  Device: mps
   ‚úÖ Apple Silicon GPU detected (MPS)


## 2. Train and then Load Model & Dataset

In [2]:
# üöÄ Training Options - Choose your speed/quality tradeoff!

# Option 1: SUPER FAST TEST (2-3 min) - Freeze backbone, train only head
# ‚ö° Trains only ~3M params instead of ~26M - Perfect for quick testing!
!python ../scripts/train_pose.py \
    --epochs 2 \
    --batch_size 8 \
    --gradient_accum_steps 2 \
    --val_interval 1 \
    --save_interval 1 \
    --num_workers 2 \
    --freeze_backbone

# Option 2: MEDIUM TRAINING (10-15 min) - Train everything, few epochs
# !python ../scripts/train_pose.py \
#     --epochs 5 \
#     --batch_size 4 \
#     --gradient_accum_steps 2 \
#     --val_interval 1 \
#     --save_interval 5 \
#     --num_workers 2

# Option 3: FULL TRAINING (2-4 hours) - Best results
# !python ../scripts/train_pose.py \
#     --epochs 50 \
#     --batch_size 8 \
#     --gradient_accum_steps 4 \
#     --val_interval 5 \
#     --save_interval 10 \
#     --use_wandb

print("\n" + "="*60)
print("üí° Training Info:")
print("="*60)
print(f"   Device: {Config.DEVICE}")
print(f"   Checkpoint: checkpoints/best_model.pth")
print("\nüìä Options comparison:")
print("   1. Freeze backbone: 2-3 min, ~3M params, good for testing")
print("   2. Train all (few epochs): 10-15 min, ~26M params, better quality")
print("   3. Full training: 2-4 hours, best results")
print("="*60)

üñ•Ô∏è  Using device: mps

üì¶ Loading dataset from: /Users/nicolotermine/zMellow/GitHub-Poli/Polito/polito-aml-6D_pose_estimation/test/../data/Linemod_preprocessed
‚úÖ PoseDataset initialized: 3759 train samples
‚úÖ PoseDataset initialized: 3759 train samples
‚úÖ PoseDataset initialized: 21218 test samples

üìä Pose DataLoaders created:
   Training samples: 3759
   Training batches: 470
   Test samples: 21218
   Test batches: 2653

üìê Loading 3D models for ADD metric...
‚úÖ Loaded model 01: 5841 points
‚úÖ Loaded model 02: 38325 points
‚úÖ Loaded model 03: 40759 points
‚úÖ PoseDataset initialized: 21218 test samples

üìä Pose DataLoaders created:
   Training samples: 3759
   Training batches: 470
   Test samples: 21218
   Test batches: 2653

üìê Loading 3D models for ADD metric...
‚úÖ Loaded model 01: 5841 points
‚úÖ Loaded model 02: 38325 points
‚úÖ Loaded model 03: 40759 points
‚úÖ Loaded model 04: 18995 points
‚úÖ Loaded model 04: 18995 points
‚úÖ Loaded model 05: 22831 poin

In [2]:
# Load trained model
checkpoint_path = Config.CHECKPOINT_DIR / 'best_model.pth'

if not checkpoint_path.exists():
    print(f"‚ùå Checkpoint not found: {checkpoint_path}")
    print("   Please train the model first using scripts/train_pose.py")
else:
    print(f"üì¶ Loading checkpoint: {checkpoint_path}")
    
    # Create model
    device = torch.device(Config.DEVICE)
    model = PoseEstimator(pretrained=False, dropout=Config.POSE_DROPOUT)
    
    # Load weights
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
    
    # Print checkpoint info
    epoch = checkpoint.get('epoch', 'N/A')
    metrics = checkpoint.get('metrics', {})
    print(f"‚úÖ Model loaded successfully")
    print(f"   Epoch: {epoch}")
    if 'mean_add' in metrics:
        print(f"   Val ADD: {metrics['mean_add']:.2f} mm")
    if 'accuracy' in metrics:
        print(f"   Val Accuracy: {metrics['accuracy']:.2f}%")

‚ùå Checkpoint not found: /Users/nicolotermine/zMellow/GitHub-Poli/Polito/polito-aml-6D_pose_estimation/checkpoints/best_model.pth
   Please train the model first using scripts/train_pose.py


In [None]:
# Load test dataset
print(f"\nüì¶ Loading test dataset...")
test_dataset = PoseDataset(
    dataset_root=str(Config.DATA_ROOT),
    split='test',
    crop_margin=Config.POSE_CROP_MARGIN,
    output_size=Config.POSE_IMAGE_SIZE
)

print(f"‚úÖ Test samples: {len(test_dataset)}")

# Load 3D models for visualization
print(f"\nüìê Loading 3D models...")
models_dict = load_all_models(Config.MODELS_PATH)
models_info = load_models_info(Config.MODELS_INFO_PATH)
print(f"‚úÖ Loaded {len(models_dict)} 3D models")

## 3. Helper Functions

In [None]:
def get_3d_bbox_corners(model_points):
    """Get 8 corners of 3D bounding box from model points."""
    min_xyz = model_points.min(axis=0)
    max_xyz = model_points.max(axis=0)
    
    corners = np.array([
        [min_xyz[0], min_xyz[1], min_xyz[2]],
        [max_xyz[0], min_xyz[1], min_xyz[2]],
        [max_xyz[0], max_xyz[1], min_xyz[2]],
        [min_xyz[0], max_xyz[1], min_xyz[2]],
        [min_xyz[0], min_xyz[1], max_xyz[2]],
        [max_xyz[0], min_xyz[1], max_xyz[2]],
        [max_xyz[0], max_xyz[1], max_xyz[2]],
        [min_xyz[0], max_xyz[1], max_xyz[2]]
    ])
    
    return corners


def draw_3d_bbox(ax, corners_2d, color='g', linewidth=2, label=None):
    """Draw 3D bounding box on image."""
    # Define edges of bounding box
    edges = [
        [0, 1], [1, 2], [2, 3], [3, 0],  # Bottom face
        [4, 5], [5, 6], [6, 7], [7, 4],  # Top face
        [0, 4], [1, 5], [2, 6], [3, 7]   # Vertical edges
    ]
    
    # Draw edges
    for i, (start, end) in enumerate(edges):
        if i == 0 and label:
            ax.plot([corners_2d[start, 0], corners_2d[end, 0]],
                   [corners_2d[start, 1], corners_2d[end, 1]],
                   color=color, linewidth=linewidth, label=label)
        else:
            ax.plot([corners_2d[start, 0], corners_2d[end, 0]],
                   [corners_2d[start, 1], corners_2d[end, 1]],
                   color=color, linewidth=linewidth)


def visualize_pose_prediction(rgb_crop, rgb_full, K, pred_R, pred_t, gt_R, gt_t, 
                              model_points, obj_name, add_error):
    """Visualize pose prediction with 3D bounding box overlay."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    # Plot 1: Cropped RGB input
    axes[0].imshow(rgb_crop)
    axes[0].set_title(f"Input Crop\n{obj_name}")
    axes[0].axis('off')
    
    # Get 3D bbox corners
    corners_3d = get_3d_bbox_corners(model_points)
    
    # Plot 2: Predicted pose
    axes[1].imshow(rgb_full)
    pred_corners_2d = project_3d_points(corners_3d, pred_R, pred_t, K)
    draw_3d_bbox(axes[1], pred_corners_2d, color='lime', linewidth=2, label='Predicted')
    axes[1].set_title(f"Predicted Pose\nADD: {add_error:.2f} mm")
    axes[1].axis('off')
    axes[1].legend(loc='upper right')
    
    # Plot 3: Ground truth pose
    axes[2].imshow(rgb_full)
    gt_corners_2d = project_3d_points(corners_3d, gt_R, gt_t, K)
    draw_3d_bbox(axes[2], gt_corners_2d, color='cyan', linewidth=2, label='Ground Truth')
    axes[2].set_title("Ground Truth Pose")
    axes[2].axis('off')
    axes[2].legend(loc='upper right')
    
    plt.tight_layout()
    plt.show()

## 4. Test on Individual Samples

In [None]:
# Select random test samples
import random

num_samples = 5
sample_indices = random.sample(range(len(test_dataset)), num_samples)

print(f"Testing on {num_samples} random samples...\n")

for idx in sample_indices:
    # Get sample
    sample = test_dataset[idx]
    
    # Prepare input
    rgb_crop_tensor = sample['rgb_crop'].unsqueeze(0).to(device)
    
    # Predict
    with torch.no_grad():
        pred_quat, pred_trans = model(rgb_crop_tensor)
    
    # Convert to numpy
    pred_quat = pred_quat.squeeze(0).cpu().numpy()
    pred_trans = pred_trans.squeeze(0).cpu().numpy()
    gt_quat = sample['quaternion'].numpy()
    gt_trans = sample['translation'].numpy()
    
    # Convert quaternion to rotation matrix
    pred_R = quaternion_to_rotation_matrix(torch.tensor(pred_quat)).numpy()
    gt_R = quaternion_to_rotation_matrix(torch.tensor(gt_quat)).numpy()
    
    # Get 3D model
    obj_id = sample['obj_id']
    model_points = models_dict[obj_id]
    obj_name = Config.OBJ_ID_TO_NAME.get(obj_id, f"Object {obj_id}")
    
    # Compute ADD error
    is_symmetric = obj_id in Config.SYMMETRIC_OBJECTS
    add_error = compute_add(
        pred_R, pred_trans,
        gt_R, gt_trans,
        model_points,
        symmetric=is_symmetric
    )
    
    # Get full RGB and camera intrinsics
    rgb_full = sample['rgb_full']
    K = sample['camera_K'].numpy()
    
    # Denormalize crop for visualization
    rgb_crop_vis = rgb_crop_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    rgb_crop_vis = (rgb_crop_vis * np.array([0.229, 0.224, 0.225]) + 
                   np.array([0.485, 0.456, 0.406]))
    rgb_crop_vis = (rgb_crop_vis * 255).clip(0, 255).astype(np.uint8)
    
    # Visualize
    print(f"Sample {idx}: {obj_name}")
    print(f"  ADD Error: {add_error:.2f} mm")
    print(f"  Symmetric: {is_symmetric}\n")
    
    visualize_pose_prediction(
        rgb_crop_vis, rgb_full, K,
        pred_R, pred_trans,
        gt_R, gt_trans,
        model_points, obj_name, add_error
    )

## 5. Evaluate on Full Test Set

In [None]:
print(f"Evaluating on {len(test_dataset)} test samples...\n")

# Storage for results
add_errors_per_object = {obj_id: [] for obj_id in Config.OBJ_ID_TO_NAME.keys()}
all_add_errors = []

# Evaluate
with torch.no_grad():
    for idx in tqdm(range(len(test_dataset))):
        sample = test_dataset[idx]
        
        # Predict
        rgb_crop = sample['rgb_crop'].unsqueeze(0).to(device)
        pred_quat, pred_trans = model(rgb_crop)
        
        # Convert to numpy
        pred_quat = pred_quat.squeeze(0).cpu().numpy()
        pred_trans = pred_trans.squeeze(0).cpu().numpy()
        gt_quat = sample['quaternion'].numpy()
        gt_trans = sample['translation'].numpy()
        
        # Convert to rotation matrix
        pred_R = quaternion_to_rotation_matrix(torch.tensor(pred_quat)).numpy()
        gt_R = quaternion_to_rotation_matrix(torch.tensor(gt_quat)).numpy()
        
        # Compute ADD
        obj_id = sample['obj_id']
        model_points = models_dict[obj_id]
        is_symmetric = obj_id in Config.SYMMETRIC_OBJECTS
        
        add_error = compute_add(
            pred_R, pred_trans,
            gt_R, gt_trans,
            model_points,
            symmetric=is_symmetric
        )
        
        # Store results
        add_errors_per_object[obj_id].append(add_error)
        all_add_errors.append(add_error)

print("\n‚úÖ Evaluation complete!")

## 6. Results Analysis

In [None]:
# Overall statistics
mean_add = np.mean(all_add_errors)
median_add = np.median(all_add_errors)
std_add = np.std(all_add_errors)

# Accuracy at threshold
threshold_mm = Config.ADD_THRESHOLD * 100  # Convert to mm (assuming diameter ~100mm)
accuracy = np.mean([e < threshold_mm for e in all_add_errors]) * 100

print(f"üìä Overall Results:")
print(f"   Mean ADD: {mean_add:.2f} mm")
print(f"   Median ADD: {median_add:.2f} mm")
print(f"   Std ADD: {std_add:.2f} mm")
print(f"   Accuracy @ {threshold_mm:.1f}mm: {accuracy:.2f}%")

# Per-object statistics
print(f"\nüì¶ Per-Object Results:")
obj_results = []

for obj_id, errors in add_errors_per_object.items():
    if len(errors) > 0:
        obj_name = Config.OBJ_ID_TO_NAME[obj_id]
        mean_err = np.mean(errors)
        median_err = np.median(errors)
        
        # Get object diameter for threshold
        diameter = models_info[obj_id]['diameter']
        obj_threshold = Config.ADD_THRESHOLD * diameter
        obj_acc = np.mean([e < obj_threshold for e in errors]) * 100
        
        obj_results.append({
            'id': obj_id,
            'name': obj_name,
            'mean': mean_err,
            'median': median_err,
            'accuracy': obj_acc,
            'count': len(errors)
        })
        
        print(f"   {obj_name:15s} - Mean: {mean_err:6.2f} mm, Accuracy: {obj_acc:5.2f}% ({len(errors)} samples)")

# Sort by mean ADD
obj_results.sort(key=lambda x: x['mean'])

In [None]:
# Visualize ADD distribution
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Plot 1: Overall ADD histogram
axes[0, 0].hist(all_add_errors, bins=50, color='steelblue', alpha=0.7, edgecolor='black')
axes[0, 0].axvline(mean_add, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_add:.2f} mm')
axes[0, 0].axvline(median_add, color='green', linestyle='--', linewidth=2, label=f'Median: {median_add:.2f} mm')
axes[0, 0].set_xlabel('ADD Error (mm)')
axes[0, 0].set_ylabel('Count')
axes[0, 0].set_title('Overall ADD Error Distribution')
axes[0, 0].legend()
axes[0, 0].grid(alpha=0.3)

# Plot 2: Per-object mean ADD
obj_names = [r['name'] for r in obj_results]
obj_means = [r['mean'] for r in obj_results]
axes[0, 1].barh(obj_names, obj_means, color='coral', edgecolor='black')
axes[0, 1].set_xlabel('Mean ADD Error (mm)')
axes[0, 1].set_title('Mean ADD by Object')
axes[0, 1].grid(axis='x', alpha=0.3)

# Plot 3: Per-object accuracy
obj_accs = [r['accuracy'] for r in obj_results]
axes[1, 0].barh(obj_names, obj_accs, color='lightgreen', edgecolor='black')
axes[1, 0].set_xlabel('Accuracy (%)')
axes[1, 0].set_title(f'Accuracy by Object (threshold: {Config.ADD_THRESHOLD*100}% diameter)')
axes[1, 0].grid(axis='x', alpha=0.3)
axes[1, 0].set_xlim([0, 100])

# Plot 4: ADD boxplot per object
add_data = [add_errors_per_object[r['id']] for r in obj_results]
bp = axes[1, 1].boxplot(add_data, labels=obj_names, vert=False, patch_artist=True)
for patch in bp['boxes']:
    patch.set_facecolor('skyblue')
axes[1, 1].set_xlabel('ADD Error (mm)')
axes[1, 1].set_title('ADD Distribution by Object')
axes[1, 1].grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Best and Worst Predictions

In [None]:
# Find best and worst predictions
sorted_indices = np.argsort(all_add_errors)

print("üèÜ Best 3 Predictions:")
for i in range(3):
    idx = sorted_indices[i]
    sample = test_dataset[idx]
    obj_name = Config.OBJ_ID_TO_NAME[sample['obj_id']]
    print(f"   {i+1}. {obj_name}: {all_add_errors[idx]:.2f} mm")

print("\n‚ùå Worst 3 Predictions:")
for i in range(3):
    idx = sorted_indices[-(i+1)]
    sample = test_dataset[idx]
    obj_name = Config.OBJ_ID_TO_NAME[sample['obj_id']]
    print(f"   {i+1}. {obj_name}: {all_add_errors[idx]:.2f} mm")

## 8. Summary

This notebook provides comprehensive evaluation of the trained 6D pose estimation model:

1. **Individual sample visualization** - Inspect predicted vs GT 3D bounding boxes
2. **Full test set evaluation** - Compute ADD metric on all test samples
3. **Per-object analysis** - Identify which objects are easiest/hardest to estimate
4. **Error distribution** - Understand model performance characteristics

**Key Metrics:**
- **ADD (Average Distance of Model Points)**: Average distance between transformed model points
- **ADD-S**: Symmetric variant using closest point matching for symmetric objects (eggbox, glue)
- **Accuracy**: Percentage of predictions with ADD < 10% of object diameter

**Next Steps:**
- Fine-tune hyperparameters (learning rate, batch size, loss weights)
- Try data augmentation (random crops, color jitter, rotation)
- Experiment with different backbones (ResNet-101, EfficientNet)
- Add depth information to input (RGB-D)
- Implement iterative refinement