# Phase 4. Extensions

## Pipeline requested for the extension

**Object detection**

- Using an existing, already-trained YOLO model, we find the bounding boxes (BBoxes) of the objects in the RGB images

**ROI**

- for each BBox crop RGB image accordingly
- crop the exact same region from the corresponding depth map

**Feature extraction**

Feature extraction uses two separate CNNs:
- *RGB Branch*: feature extraction from the cropped RGB image
- *Depth Branch*: feature extraction from the cropped depth image. Depth is treated as a 2D image, not as a point cloud.

**Fusion**

$f_{\text{fused}}=\text{concat}(f_{\text{rgb}},f_{\text{depth}})$

**Pose estimation**

Pose estimation is done by a regressor (MLP)

# Step 1: Object detection

In [None]:
import sys
from pathlib import Path

# Put the parent of the working directory on the import path (once), so the
# project-local packages imported below can be resolved.
project_root = Path.cwd().parent
_root_entry = str(project_root)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2

from config import Config
from models.yolo_detector import YOLODetector

# Report the runtime environment once all imports above have succeeded.
print(
    "‚úÖ Imports completati",
    f"   Device: {Config.DEVICE}",
    f"   PyTorch: {torch.__version__}",
    sep="\n",
)

In [None]:
# Locate the trained YOLO checkpoint and fail fast if it is not on disk.
yolo_ckpt = Config.CHECKPOINT_DIR / 'yolo' / 'yolo_train20' / 'weights' / 'best.pt'

if not yolo_ckpt.exists():
    raise FileNotFoundError(f"YOLO checkpoint not found: {yolo_ckpt}")

# The detector wrapper receives the checkpoint path as its model name.
yolo_detector = YOLODetector(
    model_name=str(yolo_ckpt),
    num_classes=Config.NUM_CLASSES,
)
print(f"‚úÖ YOLO loaded from: {yolo_ckpt}")

In [None]:
# Test YOLO detection on a sample image
from dataset.linemod_pose import LineMODPoseDataset

# Build the LineMOD test split; it supplies the sample image and its labels.
test_dataset = LineMODPoseDataset(
    dataset_root=Config.LINEMOD_ROOT,
    split='test',
    crop_margin=Config.POSE_CROP_MARGIN,
    output_size=Config.POSE_IMAGE_SIZE,
)

# Take the first sample and pull out the fields the following cells rely on.
sample = test_dataset[0]
rgb_path, depth_path = sample['rgb_path'], sample['depth_path']
gt_bbox = sample['bbox'].numpy()  # ground-truth box as [x, y, w, h]

print(f"üì∑ Sample image: {rgb_path}")
print(f"üìè GT BBox [x,y,w,h]: {gt_bbox}")

In [None]:
# Run YOLO detection on the sample image.
# cv2.imread reports failure by returning None instead of raising, so check
# explicitly: otherwise the error only surfaces later, inside cvtColor, with
# a message that does not mention the path.
image_bgr = cv2.imread(str(rgb_path))
if image_bgr is None:
    raise FileNotFoundError(f"Could not read RGB image: {rgb_path}")

# RGB copy for matplotlib; the detector is given the BGR array as loaded.
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

detections = yolo_detector.detect_objects(image_bgr, conf_threshold=0.3)

print(f"üéØ Detected {len(detections)} object(s)")
for i, det in enumerate(detections):
    print(f"   [{i+1}] Class: {det['class_name']}, Conf: {det['confidence']:.2f}, BBox: {det['bbox']}")

In [None]:
# Compare the annotated box with the detector output, one panel each.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: the ground-truth box, stored as [x, y, w, h].
x, y, w, h = gt_bbox
rect_gt = plt.Rectangle((x, y), w, h, fill=False, edgecolor='green', linewidth=2, label='GT')
axes[0].imshow(image_rgb)
axes[0].add_patch(rect_gt)
axes[0].set_title('Ground Truth BBox')
axes[0].legend()
axes[0].axis('off')

# Right panel: every YOLO box, given as corner coordinates [x1, y1, x2, y2],
# with a class/confidence caption just above its top-left corner.
axes[1].imshow(image_rgb)
for det in detections:
    x1, y1, x2, y2 = det['bbox']
    rect = plt.Rectangle((x1, y1), x2-x1, y2-y1, fill=False, edgecolor='red', linewidth=2)
    axes[1].add_patch(rect)
    caption = f"{det['class_name']} {det['confidence']:.2f}"
    axes[1].text(x1, y1-5, caption, color='red', fontsize=10, backgroundcolor='white')
axes[1].set_title(f'YOLO Detections ({len(detections)} objects)')
axes[1].axis('off')

plt.tight_layout()
plt.show()

print("\n‚úÖ Step 1 completed: Object detection with pre-trained YOLO")

# Step 2: ROI - Crop RGB and Depth

For each detected bounding box:
1. Crop the RGB image with a margin
2. Crop the corresponding depth map at the **same exact coordinates**
3. Resize both to a fixed size (224x224) for the CNN

In [None]:
def crop_roi(image: np.ndarray, bbox_xyxy: np.ndarray, margin: float = 0.15, output_size: int = 224,
             interpolation: "int | None" = None):
    """
    Crop a square region of interest around a bounding box and resize it.

    The crop window is centred on the box and its side is the longer box edge
    enlarged by ``margin``. The window is clipped to the image, so near a
    border it may no longer be square and the resize then changes the aspect
    ratio.

    Args:
        image: Input image (H, W, C) for RGB or (H, W) for depth
        bbox_xyxy: Bounding box [x1, y1, x2, y2] in pixel coordinates
        margin: Margin to add around the bbox (as fraction of bbox size)
        output_size: Side length in pixels of the square output
        interpolation: OpenCV interpolation flag passed to ``cv2.resize``.
            ``None`` (default) keeps the original behaviour,
            ``cv2.INTER_LINEAR``. Callers that must not blend neighbouring
            values (e.g. depth across an object edge) can pass
            ``cv2.INTER_NEAREST``.

    Returns:
        The cropped region resized to (output_size, output_size), with the
        channel axis kept for 3-D inputs.

    Raises:
        ValueError: If the clipped crop window is empty, i.e. the box lies
            outside the image or has zero area.
    """
    x1, y1, x2, y2 = bbox_xyxy
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2

    # Half the side of the square window: longest box edge plus margin
    half = max(x2 - x1, y2 - y1) * (1 + margin) / 2

    # Window corners, clipped to the image bounds
    img_h, img_w = image.shape[:2]
    x1_crop = int(max(0, cx - half))
    y1_crop = int(max(0, cy - half))
    x2_crop = int(min(img_w, cx + half))
    y2_crop = int(min(img_h, cy + half))

    # An empty slice would otherwise fail inside cv2.resize with an opaque
    # assertion, so reject it here with the offending values.
    if x2_crop <= x1_crop or y2_crop <= y1_crop:
        raise ValueError(
            f"Empty crop for bbox {list(bbox_xyxy)} on image of size {img_w}x{img_h}"
        )

    # Slicing only the first two axes works for both (H, W) and (H, W, C)
    crop = image[y1_crop:y2_crop, x1_crop:x2_crop]

    if interpolation is None:
        interpolation = cv2.INTER_LINEAR
    return cv2.resize(crop, (output_size, output_size), interpolation=interpolation)


# Choose the box to crop: the first YOLO detection, or the GT box as fallback.
if len(detections) == 0:
    # No detection: convert the GT box from [x, y, w, h] to [x1, y1, x2, y2].
    x, y, w, h = gt_bbox
    bbox = np.array([x, y, x+w, y+h])
    print(f"‚ö†Ô∏è No YOLO detection, using GT bbox: {bbox}")
else:
    det = detections[0]
    bbox = det['bbox']

    print(f"üéØ Using detection: {det['class_name']} (conf: {det['confidence']:.2f})")
    print(f"   BBox [x1,y1,x2,y2]: {bbox}")

In [None]:
# Crop RGB image
rgb_crop = crop_roi(image_rgb, bbox, margin=0.15, output_size=224)

# Load the depth map. Image.open keeps the file open until the image object
# is closed, so copy the pixels inside a `with` block to release the handle
# immediately instead of waiting for garbage collection.
with Image.open(depth_path) as depth_image:
    depth_raw = np.array(depth_image)  # expected: uint16, values in mm

# Crop depth with the same box and margin so it stays aligned with the RGB crop
depth_crop = crop_roi(depth_raw, bbox, margin=0.15, output_size=224)

# Normalize depth to [0, 1] for visualization and network input;
# readings beyond DEPTH_MAX are clipped to 1.
DEPTH_MAX = 2000.0  # mm (typical max depth in LineMOD)
depth_crop_normalized = np.clip(depth_crop / DEPTH_MAX, 0, 1)

print(f"‚úÖ ROI crops created:")
print(f"   RGB crop shape: {rgb_crop.shape}")
print(f"   Depth crop shape: {depth_crop.shape}")
print(f"   Depth range: [{depth_crop.min():.0f}, {depth_crop.max():.0f}] mm")

In [None]:
# Three panels: the RGB crop, the normalised depth crop, and both overlaid.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Panel 0: RGB crop, titled with its height x width
crop_h, crop_w = rgb_crop.shape[:2]
axes[0].imshow(rgb_crop)
axes[0].set_title(f'RGB Crop ({crop_h}x{crop_w})')

# Panel 1: normalised depth with a colour bar attached to that panel
im = axes[1].imshow(depth_crop_normalized, cmap='viridis')
axes[1].set_title('Depth Crop (normalized)')
plt.colorbar(im, ax=axes[1], fraction=0.046, label='Depth (normalized)')

# Panel 2: depth drawn half-transparent on top of the RGB crop
axes[2].imshow(rgb_crop)
axes[2].imshow(depth_crop_normalized, cmap='viridis', alpha=0.5)
axes[2].set_title('RGB + Depth Overlay')

# Hide ticks and frames on all three image panels
for panel in axes:
    panel.axis('off')

plt.tight_layout()
plt.show()

print("\n‚úÖ Step 2 completed: ROI crops for RGB and Depth")

# Step 3: Feature Extraction

- **RGB Branch**: ResNet-50 backbone → 2048-dim features
- **Depth Branch**: DepthEncoder CNN → 256-dim features (see `models/depth_encoder.py`)

In [None]:
# Feature extraction models
import torchvision.models as models
from torchvision.models import ResNet50_Weights
import torch.nn as nn
from models.depth_encoder import DepthEncoder

# RGB branch: ImageNet-pretrained ResNet-50 with its last child module
# (the final FC layer) removed, used in inference mode.
resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet_children = list(resnet.children())
rgb_encoder = nn.Sequential(*resnet_children[:-1])
rgb_encoder = rgb_encoder.to(Config.DEVICE)
rgb_encoder.eval()

# Depth branch: project-local DepthEncoder with a 256-wide output
depth_encoder = DepthEncoder(output_dim=256)
depth_encoder = depth_encoder.to(Config.DEVICE)
depth_encoder.eval()

print(
    "‚úÖ Feature extractors loaded:",
    "   RGB: ResNet-50 ‚Üí 2048-dim",
    "   Depth: DepthEncoder ‚Üí 256-dim",
    sep="\n",
)

In [None]:
# Prepare tensors
from torchvision import transforms
import torch.nn as nn

# Per-channel normalisation with the ImageNet mean / standard deviation
imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# RGB: (H, W, C) crop -> (C, H, W) float scaled by 1/255 -> normalised ->
# leading batch axis added -> moved to the configured device
rgb_chw = torch.from_numpy(rgb_crop).permute(2, 0, 1).float() / 255.0
rgb_tensor = imagenet_normalize(rgb_chw).unsqueeze(0).to(Config.DEVICE)

# Depth: (H, W) map -> float -> batch and channel axes added in front
depth_hw = torch.from_numpy(depth_crop_normalized).float()
depth_tensor = depth_hw[None, None].to(Config.DEVICE)

print(f"‚úÖ Input tensors: RGB {rgb_tensor.shape}, Depth {depth_tensor.shape}")

In [None]:
# Run both encoders without tracking gradients; squeeze() drops every
# size-1 axis, leaving one flat feature vector per branch for this sample.
with torch.no_grad():
    rgb_out = rgb_encoder(rgb_tensor)
    f_rgb = rgb_out.squeeze()        # (2048,)
    depth_out = depth_encoder(depth_tensor)
    f_depth = depth_out.squeeze()    # (256,)

print(
    "‚úÖ Step 3 completed:",
    f"   f_rgb: {f_rgb.shape}",
    f"   f_depth: {f_depth.shape}",
    sep="\n",
)

# Step 4: Fusion

Late fusion via concatenation:

$$f_{\text{fused}} = \text{concat}(f_{\text{rgb}}, f_{\text{depth}}) \in \mathbb{R}^{2304}$$

In [None]:
# Late fusion: join the two 1-D feature vectors end to end (RGB first).
f_fused = torch.cat((f_rgb, f_depth))

print(
    "‚úÖ Step 4 completed: Feature Fusion",
    f"   f_rgb:   {f_rgb.shape[0]} dims",
    f"   f_depth: {f_depth.shape[0]} dims",
    f"   f_fused: {f_fused.shape[0]} dims (concatenated)",
    sep="\n",
)

In [None]:
# Visualize fused feature vector
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Read the branch widths from the tensors instead of hard-coding 2048 / 256 /
# 2304, so the plot stays correct if an encoder's output size changes.
n_rgb, n_depth = f_rgb.shape[0], f_depth.shape[0]
n_fused = f_fused.shape[0]

# Lay the vector out on the smallest square grid that holds it. Unused cells
# are padded with NaN, which imshow leaves blank; for 2304 features the grid
# is exactly 48x48 and no padding is added.
f_np = f_fused.cpu().numpy()
side = int(np.ceil(np.sqrt(n_fused)))
padded = np.pad(f_np.astype(np.float64), (0, side * side - n_fused), constant_values=np.nan)
f_2d = padded.reshape(side, side)
im = axes[0].imshow(f_2d, cmap='coolwarm', aspect='auto')
axes[0].set_title(f'Fused Features ({n_fused} dims reshaped to {side}√ó{side})')
axes[0].set_xlabel('Feature index')
axes[0].set_ylabel('Feature index')
plt.colorbar(im, ax=axes[0], fraction=0.046)

# Show how many of the fused dimensions come from each branch
branch_dims = [n_depth, n_rgb]
axes[1].barh([f'Depth\n({n_depth} dims)', f'RGB\n({n_rgb} dims)'], branch_dims, color=['orange', 'steelblue'])
axes[1].set_xlabel('Feature dimensions')
axes[1].set_title('Feature Contribution per Branch')
for i, v in enumerate(branch_dims):
    # Annotate each bar with its size and its share of the fused vector
    axes[1].text(v + 50, i, f'{v} ({v/n_fused*100:.1f}%)', va='center')

plt.tight_layout()
plt.show()

# Step 5: Pose Estimation

MLP regressor predicts 6D pose from fused features:
- **Quaternion** (4D): $[q_w, q_x, q_y, q_z]$ normalized to unit norm
- **Translation** (3D): $[t_x, t_y, t_z]$ in meters

See `models/pose_regressor.py`

In [None]:
# Pose regressor
from models.pose_regressor import PoseRegressor

# Untrained regressor in eval mode: this cell only checks that the pipeline
# runs end to end, not that the predicted pose is meaningful.
pose_regressor = PoseRegressor(input_dim=2304, dropout=0.3)
pose_regressor = pose_regressor.to(Config.DEVICE)
pose_regressor.eval()

# Predict pose (model has random weights - just testing the pipeline)
with torch.no_grad():
    f_fused_batch = f_fused[None]  # add the batch axis: (1, 2304)
    pose_pred = pose_regressor(f_fused_batch).squeeze()  # (7,)

# First four outputs are the quaternion, the remaining ones the translation.
pose_np = pose_pred.cpu().numpy()
quat_pred, trans_pred = pose_np[:4], pose_np[4:]

print(
    "‚úÖ Step 5 completed: Pose Estimation",
    f"   Quaternion [qw,qx,qy,qz]: {quat_pred}",
    f"   Quaternion norm: {np.linalg.norm(quat_pred):.6f}",
    f"   Translation [tx,ty,tz]: {trans_pred}",
    sep="\n",
)

In [None]:
# Print the (random-weight) prediction next to the sample's ground-truth pose.
gt_quat, gt_trans = sample['quaternion'].numpy(), sample['translation'].numpy()

print(
    "üìä Comparison (random weights vs GT):",
    f"   Predicted quat: {quat_pred}",
    f"   GT quat:        {gt_quat}",
    f"   Predicted trans: {trans_pred}",
    f"   GT trans:        {gt_trans}",
    sep="\n",
)

print("\nüéâ Pipeline complete! Next: train the model on LineMOD dataset.")

# Step 6: Training

Train the complete RGB-D fusion model end-to-end:
- **RGB Encoder**: ResNet-50 (pretrained, fine-tuned)
- **Depth Encoder**: DepthEncoder (trained from scratch)
- **Pose Regressor**: MLP (trained from scratch)

Loss: Geodesic (rotation) + Smooth L1 (translation)

In [None]:
# Hyper-parameters read by the dataloader, optimizer, scheduler, loss and
# training-loop cells below.
TRAIN_CONFIG = dict(
    epochs=50,
    batch_size=16,
    lr=1e-4,
    weight_decay=1e-5,
    lambda_rot=1,
    lambda_trans=1,
)

# Echo every setting, one per line, in declaration order.
print("üìã Training configuration:")
for k, v in TRAIN_CONFIG.items():
    print(f"   {k}: {v}")

In [None]:
# Custom Dataset for RGB-D fusion training
from torch.utils.data import Dataset, DataLoader, random_split

class LineMODFusionDataset(Dataset):
    """Dataset that returns RGB crop, Depth crop, and pose labels.

    Wraps a base dataset whose samples provide ``'rgb_path'``,
    ``'depth_path'``, ``'bbox'`` ([x, y, w, h] tensor), ``'quaternion'`` and
    ``'translation'``. For every sample the full RGB and depth images are
    read from disk and cropped with the module-level ``crop_roi`` using the
    same box, so the two crops stay aligned.

    Args:
        base_dataset: Indexable dataset yielding the sample dicts above.
        crop_margin: Margin around the box, forwarded to ``crop_roi``.
        output_size: Side length of the square crops, forwarded to ``crop_roi``.
        depth_max: Divisor used to scale depth into [0, 1]; larger readings
            are clipped to 1.
    """

    def __init__(self, base_dataset, crop_margin=0.15, output_size=224, depth_max=2000.0):
        self.base_dataset = base_dataset
        self.crop_margin = crop_margin
        self.output_size = output_size
        self.depth_max = depth_max
        # ImageNet channel statistics, applied to the RGB crop
        self.imagenet_normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Load, crop and normalise the RGB-D pair for sample ``idx``.

        Returns:
            dict with ``'rgb'`` (normalised float tensor, channels first),
            ``'depth'`` (float tensor in [0, 1] with a leading channel axis),
            and the base sample's ``'quaternion'`` and ``'translation'``
            passed through unchanged.

        Raises:
            FileNotFoundError: If the RGB image cannot be read.
        """
        sample = self.base_dataset[idx]

        # cv2.imread returns None on failure instead of raising; check it so
        # the error names the file rather than failing inside cvtColor.
        bgr_full = cv2.imread(str(sample['rgb_path']))
        if bgr_full is None:
            raise FileNotFoundError(f"Could not read RGB image: {sample['rgb_path']}")
        rgb_full = cv2.cvtColor(bgr_full, cv2.COLOR_BGR2RGB)

        # Close the depth file as soon as its pixels are copied: this method
        # runs once per sample per epoch, so leaked handles would pile up.
        with Image.open(sample['depth_path']) as depth_image:
            depth_full = np.array(depth_image)

        # Get bbox (convert from [x,y,w,h] to [x1,y1,x2,y2])
        bbox = sample['bbox'].numpy()
        x, y, w, h = bbox
        bbox_xyxy = np.array([x, y, x+w, y+h])

        # Crop both RGB and Depth at same coordinates
        rgb_crop = crop_roi(rgb_full, bbox_xyxy, self.crop_margin, self.output_size)
        depth_crop = crop_roi(depth_full, bbox_xyxy, self.crop_margin, self.output_size)

        # Normalize depth to [0, 1]
        depth_crop = np.clip(depth_crop / self.depth_max, 0, 1).astype(np.float32)

        # Convert RGB to tensor (channels first, scaled by 1/255) and normalize
        rgb_tensor = torch.from_numpy(rgb_crop).permute(2, 0, 1).float() / 255.0
        rgb_tensor = self.imagenet_normalize(rgb_tensor)

        # Convert depth to tensor, adding a leading channel axis
        depth_tensor = torch.from_numpy(depth_crop).unsqueeze(0).float()

        return {
            'rgb': rgb_tensor,
            'depth': depth_tensor,
            'quaternion': sample['quaternion'],
            'translation': sample['translation']
        }

# Base LineMOD splits (full-image samples with paths, boxes and poses)
full_train = LineMODPoseDataset(Config.LINEMOD_ROOT, split='train')
full_test = LineMODPoseDataset(Config.LINEMOD_ROOT, split='test')

# Hold out 15% of the training split for validation; the fixed seed makes
# the partition reproducible across runs.
train_len = int(len(full_train) * 0.85)
val_len = len(full_train) - train_len
split_rng = torch.Generator().manual_seed(42)
train_base, val_base = random_split(full_train, [train_len, val_len], generator=split_rng)

# Wrap each split so samples come out as aligned RGB / depth crops
train_dataset = LineMODFusionDataset(train_base)
val_dataset = LineMODFusionDataset(val_base)
test_dataset_fusion = LineMODFusionDataset(full_test)

# Only the training loader shuffles; both load in the main process
loader_kwargs = dict(batch_size=TRAIN_CONFIG['batch_size'], num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(
    "‚úÖ Datasets created:",
    f"   Train: {len(train_dataset)} samples",
    f"   Val: {len(val_dataset)} samples",
    f"   Test: {len(test_dataset_fusion)} samples",
    sep="\n",
)

In [None]:
# Fresh copies of the three sub-networks for training, each moved to the
# configured device. The RGB encoder is a pretrained ResNet-50 without its
# last child module.
rgb_encoder_train = nn.Sequential(
    *list(models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).children())[:-1]
).to(Config.DEVICE)
depth_encoder_train = DepthEncoder(output_dim=256).to(Config.DEVICE)
pose_regressor_train = PoseRegressor(input_dim=2304, dropout=0.3).to(Config.DEVICE)

# One flat parameter list (RGB, then depth, then regressor) so a single
# optimizer updates the whole model.
all_params = [
    param
    for network in (rgb_encoder_train, depth_encoder_train, pose_regressor_train)
    for param in network.parameters()
]

# AdamW with a cosine learning-rate schedule spanning the configured epochs
optimizer = torch.optim.AdamW(
    all_params,
    lr=TRAIN_CONFIG['lr'],
    weight_decay=TRAIN_CONFIG['weight_decay'],
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TRAIN_CONFIG['epochs'])

# Loss functions
from utils.losses import PoseLoss
criterion = PoseLoss(lambda_trans=TRAIN_CONFIG['lambda_trans'], lambda_rot=TRAIN_CONFIG['lambda_rot'])

print("‚úÖ Models and optimizer initialized")
print(f"   Total parameters: {sum(p.numel() for p in all_params):,}")

In [None]:
# Training loop
from tqdm.auto import tqdm

def train_epoch(rgb_enc, depth_enc, pose_reg, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        rgb_enc: RGB encoder; its output has its last two axes squeezed away.
        depth_enc: Depth encoder producing one feature row per sample.
        pose_reg: Regressor mapping fused features to a pose whose first four
            columns are the quaternion and the rest the translation.
        loader: Iterable of batches with keys 'rgb', 'depth', 'quaternion'
            and 'translation'; must support ``len()``.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable returning a dict that contains a 'total' loss.
        device: Device every batch tensor is moved to.

    Returns:
        Sum of the per-batch total losses divided by ``len(loader)``.
    """
    for network in (rgb_enc, depth_enc, pose_reg):
        network.train()

    batch_losses = []
    for batch in tqdm(loader, desc="Training", leave=False):
        rgb, depth = batch['rgb'].to(device), batch['depth'].to(device)
        gt_quat = batch['quaternion'].to(device)
        gt_trans = batch['translation'].to(device)

        optimizer.zero_grad()

        # Encode each modality, then fuse along the feature axis.
        f_rgb = rgb_enc(rgb).squeeze(-1).squeeze(-1)
        f_fused = torch.cat([f_rgb, depth_enc(depth)], dim=1)
        pose = pose_reg(f_fused)

        # Only the combined 'total' entry of the loss dict is optimised.
        loss = criterion(pose[:, :4], pose[:, 4:], gt_quat, gt_trans)['total']
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(loader)

def validate_epoch(rgb_enc, depth_enc, pose_reg, loader, criterion, device):
    """Evaluate the model on ``loader`` and return the mean batch loss.

    All three networks are switched to eval mode and the pass runs without
    gradient tracking; no parameters are updated.

    Args:
        rgb_enc: RGB encoder; its output has its last two axes squeezed away.
        depth_enc: Depth encoder producing one feature row per sample.
        pose_reg: Regressor mapping fused features to a pose whose first four
            columns are the quaternion and the rest the translation.
        loader: Iterable of batches with keys 'rgb', 'depth', 'quaternion'
            and 'translation'; must support ``len()``.
        criterion: Callable returning a dict that contains a 'total' loss.
        device: Device every batch tensor is moved to.

    Returns:
        Sum of the per-batch total losses divided by ``len(loader)``.
    """
    for network in (rgb_enc, depth_enc, pose_reg):
        network.eval()

    batch_losses = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation", leave=False):
            rgb, depth = batch['rgb'].to(device), batch['depth'].to(device)
            gt_quat = batch['quaternion'].to(device)
            gt_trans = batch['translation'].to(device)

            # Same forward path as training: encode, fuse, regress.
            f_rgb = rgb_enc(rgb).squeeze(-1).squeeze(-1)
            f_fused = torch.cat([f_rgb, depth_enc(depth)], dim=1)
            pose = pose_reg(f_fused)

            loss = criterion(pose[:, :4], pose[:, 4:], gt_quat, gt_trans)['total']
            batch_losses.append(loss.item())

    return sum(batch_losses) / len(loader)

print("‚úÖ Training functions defined")

In [None]:
# Per-epoch loss history and the best validation loss seen so far
train_losses, val_losses = [], []
best_val_loss = float('inf')

# Directory for the best checkpoint, created if it does not exist
checkpoint_dir = Config.CHECKPOINT_DIR / 'pose' / 'fusion_rgbd'
checkpoint_dir.mkdir(parents=True, exist_ok=True)

print(f"üöÄ Starting training for {TRAIN_CONFIG['epochs']} epochs...")
print(f"   Checkpoint dir: {checkpoint_dir}")

networks = (rgb_encoder_train, depth_encoder_train, pose_regressor_train)
for epoch in range(TRAIN_CONFIG['epochs']):
    train_loss = train_epoch(*networks, train_loader, optimizer, criterion, Config.DEVICE)
    val_loss = validate_epoch(*networks, val_loader, criterion, Config.DEVICE)

    # Advance the learning-rate schedule once per epoch
    scheduler.step()

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    summary = f"Epoch {epoch+1:3d} | Train: {train_loss:.4f} | Val: {val_loss:.4f}"
    if not val_loss < best_val_loss:
        print(summary)
        continue

    # New best validation loss: overwrite the checkpoint with the current
    # weights and optimizer state.
    best_val_loss = val_loss
    torch.save({
        'epoch': epoch,
        'rgb_encoder': rgb_encoder_train.state_dict(),
        'depth_encoder': depth_encoder_train.state_dict(),
        'pose_regressor': pose_regressor_train.state_dict(),
        'optimizer': optimizer.state_dict(),
        'val_loss': val_loss
    }, checkpoint_dir / 'best.pt')
    print(summary + " ‚≠ê (best)")

print(f"\n‚úÖ Training complete! Best val loss: {best_val_loss:.4f}")

In [None]:
# Draw the train and validation loss histories on one set of axes.
fig, ax = plt.subplots(figsize=(10, 5))
for series, series_label, series_color in (
    (train_losses, 'Train Loss', 'steelblue'),
    (val_losses, 'Val Loss', 'orange'),
):
    ax.plot(series, label=series_label, color=series_color)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('RGB-D Fusion Training')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Summarise the last epoch next to the best validation loss.
print(
    "üìä Final losses:",
    f"   Train: {train_losses[-1]:.4f}",
    f"   Val: {val_losses[-1]:.4f}",
    f"   Best Val: {best_val_loss:.4f}",
    sep="\n",
)