In [None]:
import os
import sys

# Clone or pull part
repo_url = "https://github.com/fraco03/6D_pose.git"
repo_dir = "/content/6D_pose"   #Modify here for kaggle
branch = "main"

# Clone if missing
if not os.path.exists(repo_dir):
    !git clone -b {branch} {repo_url}
    print(f"Cloned {repo_url} to {repo_dir}")
else:
    %cd {repo_dir}
    !git fetch origin
    !git checkout {branch}
    !git reset --hard origin/{branch}
    %cd ..
    print(f"Updated {repo_url} to {repo_dir}")

# Add repository to Python path
if repo_dir not in sys.path:
    sys.path.insert(0, repo_dir)


In [None]:
# Show the current working directory; the relative paths used in later cells
# ('../..' and the dataset root) are resolved against it.
%pwd

In [None]:
import sys

# Make modules located two directories above the working directory importable.
# Guarded so that re-running this cell does not append duplicate entries.
if '../..' not in sys.path:
    sys.path.append('../..')

In [None]:
from google.colab import drive
from utils.load_data import mount_drive

# Mount the storage that holds the dataset and run outputs, via the project helper.
# NOTE(review): mount_drive() lives in utils.load_data (not shown here); presumably
# it wraps google.colab's drive.mount — confirm before running outside Colab.
mount_drive()

In [None]:
# Root folder of the preprocessed Linemod data, relative to the working directory.
# Drive alternative: "/content/drive/MyDrive/Linemod_preprocessed" (adjust for Kaggle).
dataset_root = "../../Linemod_preprocessed_small"

# Report the active configuration in a single call.
print(
    "\n‚úÖ Setup complete!",
    f"üìÅ Dataset path: {dataset_root}",
    sep="\n",
)


In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from src.pose_rgbd.loss import GeodesicLoss
from src.pose_rgbd.dataset import LineModPoseDepthDataset

# Build the train and test splits (in that order) from the same dataset root.
train_dataset, test_dataset = [
    LineModPoseDepthDataset(root_dir=dataset_root, split=split_name)
    for split_name in ("train", "test")
]


In [None]:
# Fetch the first training example and report its keys and tensor shapes.
sample = train_dataset[0]

print(
    f"Sample keys: {sample.keys()}",
    f"Depth shape: {sample['depth'].shape}",
    f"RGB shape: {sample['image'].shape}",
    sep="\n",
)

In [None]:
from torch.utils.data import DataLoader

# Wrap both splits in loaders with identical settings; only the training
# split is shuffled.
train_loader, test_loader = (
    DataLoader(split_dataset, batch_size=32, shuffle=shuffle_split, num_workers=2)
    for split_dataset, shuffle_split in ((train_dataset, True), (test_dataset, False))
)


In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import MSELoss
from src.pose_rgbd.model import RotationPredictionModel
from src.pose_rgbd.dataset import LineModPoseDepthDataset
from src.pose_rgbd.loss import GeodesicLoss

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Instantiate the rotation network (pretrained=True, freeze_backbone=True)
# and move it to the selected device in one step.
model = RotationPredictionModel(pretrained=True, freeze_backbone=True).to(device)

# Optimizer over the model parameters, and the rotation loss used for
# both training and evaluation.
optimizer = Adam(model.parameters(), lr=1e-4)
criterion = GeodesicLoss()

In [None]:
from tqdm import tqdm
import os
import datetime

# Training loop: for every epoch, train on train_loader, evaluate on
# test_loader, record both average losses, and checkpoint the model whenever
# the average test loss improves.
num_epochs = 50
best_test_loss = float('inf')
# checkpoint_dir = "checkpoints"
# `datetime` is imported as a module in this cell, so the class has to be
# reached as datetime.datetime; datetime.now() on the module raises
# AttributeError.
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
checkpoint_dir = f'/content/drive/MyDrive/runs/{timestamp}' # modify here for kaggle
os.makedirs(checkpoint_dir, exist_ok=True)

# Track per-epoch average losses for plotting
train_losses = []
test_losses = []

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0

    train_pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}] - Training")
    # Count batches with enumerate: tqdm only synchronises its `.n` attribute
    # when the bar is redrawn, so `.n + 1` can lag behind the real batch count
    # and distort the running average shown in the postfix.
    for step, batch in enumerate(train_pbar, start=1):
        rgb = batch['image'].to(device)  # RGB image, presumably (B, 3, H, W) — confirm against the dataset
        depth = batch['depth'].unsqueeze(1).to(device)  # inserts a channel axis at dim 1
        rotations = batch['rotation'].to(device)  # ground-truth rotation targets

        # Forward pass
        optimizer.zero_grad()
        outputs = model(rgb, depth)

        # Loss computation and parameter update
        loss = criterion(outputs, rotations)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        train_pbar.set_postfix({'loss': epoch_loss / step})

    # Mean of the per-batch losses (each batch weighted equally)
    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    model.eval()
    test_loss = 0.0

    test_pbar = tqdm(test_loader, desc=f"Epoch [{epoch+1}/{num_epochs}] - Testing")
    with torch.no_grad():
        for step, batch in enumerate(test_pbar, start=1):
            rgb = batch['image'].to(device)
            depth = batch['depth'].unsqueeze(1).to(device)
            rotations = batch['rotation'].to(device)

            outputs = model(rgb, depth)
            loss = criterion(outputs, rotations)
            test_loss += loss.item()
            test_pbar.set_postfix({'loss': test_loss / step})

    avg_test_loss = test_loss / len(test_loader)
    test_losses.append(avg_test_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Test Loss: {avg_test_loss:.4f}")

    # Save checkpoint if test loss improved; the file is overwritten each
    # time so only the best model of this run is kept.
    if avg_test_loss < best_test_loss:
        best_test_loss = avg_test_loss
        checkpoint_path = os.path.join(checkpoint_dir, "best_model.pth")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'test_loss': avg_test_loss,
        }, checkpoint_path)
        print(f"‚úÖ Checkpoint saved! Best Test Loss: {best_test_loss:.4f}")
    else:
        print(f"‚ö†Ô∏è  No improvement. Best Test Loss: {best_test_loss:.4f}")

In [None]:
import matplotlib.pyplot as plt

# Create plots directory
plots_dir = "plots"
os.makedirs(plots_dir, exist_ok=True)


def _save_loss_plot(curves, ylabel, title, out_dir, filename, reference_line=None):
    """Draw per-epoch loss curves, save the figure as a PNG and display it.

    Args:
        curves: Iterable of ``(values, fmt, label)`` tuples. ``values`` is the
            list of per-epoch losses, ``fmt`` a matplotlib format string and
            ``label`` the legend entry (``None`` for an unlabelled curve).
        ylabel: Label of the y axis.
        title: Figure title.
        out_dir: Directory the image is written to.
        filename: Name of the image file inside ``out_dir``.
        reference_line: Optional ``(y, label)`` pair drawn as a dashed green
            horizontal line.

    Returns:
        The path of the saved image.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    for values, fmt, label in curves:
        # Build the x axis from the epochs actually recorded so the plot still
        # works when training stopped early or num_epochs was edited since.
        ax.plot(range(1, len(values) + 1), values, fmt,
                label=label, linewidth=2, markersize=6)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    if reference_line is not None:
        y_value, line_label = reference_line
        ax.axhline(y=y_value, color='g', linestyle='--', label=line_label, linewidth=2)
    # Only draw a legend when at least one artist carries a label.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=11)
    fig.tight_layout()
    out_path = os.path.join(out_dir, filename)
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"‚úÖ Plot saved: {out_path}")
    plt.show()
    return out_path


# Epoch indices of the recorded training losses (name kept from the original cell).
epochs_range = range(1, len(train_losses) + 1)

# Plot 1: Training vs Test Loss
loss_plot_path = _save_loss_plot(
    [(train_losses, 'b-o', 'Training Loss'), (test_losses, 'r-s', 'Test Loss')],
    ylabel='Loss',
    title='Training vs Test Loss',
    out_dir=plots_dir,
    filename="loss_comparison.png",
)

# Plot 2: Only Training Loss
train_loss_path = _save_loss_plot(
    [(train_losses, 'b-o', None)],
    ylabel='Training Loss',
    title='Training Loss Over Epochs',
    out_dir=plots_dir,
    filename="training_loss.png",
)

# Plot 3: Only Test Loss, with the best value marked as a horizontal line
test_loss_path = _save_loss_plot(
    [(test_losses, 'r-s', None)],
    ylabel='Test Loss',
    title='Test Loss Over Epochs',
    out_dir=plots_dir,
    filename="test_loss.png",
    reference_line=(best_test_loss, f'Best: {best_test_loss:.4f}'),
)

print(f"\n‚úÖ All plots saved in '{plots_dir}' directory!")

# Visualize samples

In [None]:
import random
import cv2
from utils.projection_utils import setup_projection_utils, visualize_pose_comparison, get_image

# Setup projection utils
setup_projection_utils(dataset_root)

# Load best model
best_checkpoint_path = os.path.join(checkpoint_dir, "best_model.pth")
# map_location remaps stored tensors onto the active device, so a checkpoint
# written on a GPU runtime also loads on a CPU-only one.
checkpoint = torch.load(best_checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"‚úÖ Loaded best model from epoch {checkpoint['epoch']} with test loss: {checkpoint['test_loss']:.4f}")

# Select a random sample from test dataset (randint is inclusive on both ends)
random_idx = random.randint(0, len(test_dataset) - 1)
sample = test_dataset[random_idx]

print(f"\nüì∑ Visualizing sample {random_idx}:")
print(f"   Object ID: {sample['object_id']}")
print(f"   Image ID: {sample['img_id']}")

# Get the original image
img_path = sample['img_path']
image_bgr = cv2.imread(str(img_path))
# cv2.imread signals failure by returning None instead of raising; fail here
# with a clear message rather than inside cvtColor.
if image_bgr is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

# Prepare input for model
rgb = sample['image'].unsqueeze(0).to(device)  # Add batch dimension
depth = sample['depth'].unsqueeze(0).unsqueeze(0).to(device)  # Add batch and channel dimensions

# Get model prediction (first and only element of the batch)
model.eval()
with torch.no_grad():
    pred_rotation = model(rgb, depth)[0].cpu().numpy()

# Get ground truth
gt_rotation = sample['rotation'].numpy()
gt_translation = sample['translation'].numpy()

# Get camera intrinsics
cam_K = sample['cam_K'].numpy()

print(f"\nüìä Ground Truth vs Prediction:")
print(f"   GT Rotation: {gt_rotation}")
print(f"   Pred Rotation: {pred_rotation}")
print(f"   GT Translation: {gt_translation}")

# Visualize pose comparison
fig, ax = plt.subplots(1, 1, figsize=(14, 8))
img_vis = visualize_pose_comparison(
    image_rgb,
    object_id=sample['object_id'],
    cam_K=cam_K,
    gt_rotation=gt_rotation,
    gt_translation=gt_translation,
    pred_rotation=pred_rotation,
    pred_translation=gt_translation  # Using GT translation for now: the model only predicts rotation
)

# Convert BGR to RGB for matplotlib
# NOTE(review): the image passed in above is already RGB; whether the helper
# returns BGR depends on visualize_pose_comparison (not shown) — confirm this
# conversion does not swap the channels back.
img_vis_rgb = cv2.cvtColor(img_vis, cv2.COLOR_BGR2RGB)
ax.imshow(img_vis_rgb)
ax.axis('off')
ax.set_title(f"Pose Visualization - Object {sample['object_id']}, Image {sample['img_id']}", fontsize=14, fontweight='bold')
plt.tight_layout()
# Uncomment to also write the figure to disk:
# plt.savefig(os.path.join(plots_dir, f"pose_visualization_sample_{random_idx}.png"), dpi=150, bbox_inches='tight')
# The figure is only displayed, not saved, so the message must not claim a file was written.
print("\n‚úÖ Visualization ready!")
plt.show()