In [None]:
!pip install gdown -q
print("Downloading folder from Drive...")
# Downloads the folder structure containing the Linemod dataset
!gdown "https://drive.google.com/file/d/1Zwh-gAk_-CBgpOcNLPLdFNxggi3NTh-S/view?usp=drive_link" --fuzzy
import glob
zip_files = glob.glob("**/Linemod_preprocessed.zip", recursive=True)

if zip_files:
    zip_path = zip_files[0]
    print(f"Unzipping {zip_path}...")
    !unzip -q -o "{zip_path}"
    print("Extraction complete!")
else:
    print("Error: Linemod_preprocessed.zip not found. Check the download.")

Cloning into '6D_pose'...
remote: Enumerating objects: 174, done.[K
remote: Counting objects: 100% (174/174), done.[K
remote: Compressing objects: 100% (120/120), done.[K
remote: Total 174 (delta 87), reused 113 (delta 45), pack-reused 0 (from 0)[K
Receiving objects: 100% (174/174), 2.23 MiB | 6.95 MiB/s, done.
Resolving deltas: 100% (87/87), done.
Cloned https://github.com/fraco03/6D_pose.git to /content/6D_pose


In [None]:
import os
import sys

# Clone or pull part
repo_url = "https://github.com/fraco03/6D_pose.git"
repo_dir = "/kaggle/working/6D_pose"   #Modify here for kaggle
branch = "main"

# Clone if missing
if not os.path.exists(repo_dir):
    !git clone -b {branch} {repo_url}
    print(f"Cloned {repo_url} to {repo_dir}")
else:
    %cd {repo_dir}
    !git fetch origin
    !git checkout {branch}
    !git reset --hard origin/{branch}
    %cd ..
    print(f"Updated {repo_url} to {repo_dir}")

# Add repository to Python path
if repo_dir not in sys.path:
    sys.path.insert(0, repo_dir)

In [None]:
# Recursively delete every __pycache__ folder below the current working directory,
# so that freshly pulled modules are not shadowed by stale bytecode.
!find . -name "__pycache__" -type d -exec rm -rf {} +
print("üóëÔ∏è Cache pulita dal disco.")

In [None]:
!pip install plyfile
from src.pose_rgb.dataset import LineModPoseDataset
from src.pose_rgb.model import ResNetRotation, TranslationNet
from src.pose_rgb.pose_utils import quaternion_to_rotation_matrix, convert_rotation_to_quaternion, inverse_pinhole_projection
from src.pose_rgb.test_dataset import *
from src.pose_rgb.loss import CombinedPoseLoss, MultiObjectPointMatchingLoss, TranslationLoss
from torch.utils.data import Dataset, DataLoader
import pathlib
import torch.optim as optim
from tqdm import tqdm
from utils.projection_utils import *
from utils.linemod_config import *
from metrics import compute_ADD_metric_quaternion


In [None]:
# Location of the preprocessed LineMod data.
root_dir = '/kaggle/input/line-mode/Linemod_preprocessed' #Modify here for kaggle

# Instantiate the two splits in order (train first, then test).
train_dataset, test_dataset = (
    LineModPoseDataset(split=split_name, root_dir=root_dir)
    for split_name in ('train', 'test')
)

# Dataloaders: identical settings, except that only the training loader shuffles.
_loader_opts = dict(batch_size=32, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_opts)

 Loaded LineModPoseDataset
   Split: train
   Dir : [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]
   Total samples: 3631
 Loaded LineModPoseDataset
   Split: test
   Dir : [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]
   Total samples: 20528




In [None]:
!pip install trimesh
import torch
import numpy as np
import trimesh
import os

def load_all_object_points(models_dir, valid_obj_ids, num_points=1000, seed=None):
    """
    Load the .ply model of every object and stack sampled vertices into one tensor.

    Args:
        models_dir (str): Folder containing the .ply files, named 'obj_XX.ply'
                          (XX is the zero-padded object ID).
        valid_obj_ids (list): Integer object IDs (e.g., [1, 5, 6...]).
        num_points (int): Number of vertices to sample per object.
        seed (int | None): Optional seed for the vertex sampling. With the
                           default ``None`` the draw uses NumPy's global random
                           state, exactly as before this parameter existed.

    Returns:
        torch.Tensor: float32, shape (len(valid_obj_ids), num_points, 3).
                      Row i holds the points of valid_obj_ids[i]. Coordinates
                      are the file's values divided by 1000 (mm -> m, assuming
                      the models are stored in millimetres).

    Raises:
        FileNotFoundError: If the .ply file of a requested object is missing.
        ValueError: If a model contains no vertices.
    """
    # A dedicated generator keeps a seeded call from disturbing (or depending
    # on) the global NumPy random state.
    rng = np.random if seed is None else np.random.RandomState(seed)
    sampled_per_object = []

    print(f"üì¶ Loading {len(valid_obj_ids)} 3D models from {models_dir}...")

    for obj_id in valid_obj_ids:
        ply_path = os.path.join(models_dir, f"obj_{obj_id:02d}.ply")
        if not os.path.exists(ply_path):
            raise FileNotFoundError(f"Model not found: {ply_path}")

        mesh = trimesh.load(ply_path)
        vertices = np.array(mesh.vertices)
        if len(vertices) == 0:
            raise ValueError(f"Model has no vertices: {ply_path}")

        # Sample without replacement whenever the model has enough vertices
        # (including exactly num_points); repeat vertices only when it has fewer.
        needs_padding = len(vertices) < num_points
        idx = rng.choice(len(vertices), num_points, replace=needs_padding)
        sampled_per_object.append(vertices[idx])

    if not sampled_per_object:
        # Keep the documented rank even when no object was requested.
        return torch.zeros((0, num_points, 3), dtype=torch.float32)

    # Shape: (Num_Classes, Num_Points, 3), e.g. (13, 1000, 3)
    bank_tensor = torch.from_numpy(np.stack(sampled_per_object)).float()

    # Unit conversion (mm to meters)
    return bank_tensor / 1000

In [None]:
# Class names in class-index order. Position i pairs with the i-th LineMod
# object ID of [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15].
LINEMOD_NAMES = [
    'ape', 'benchvise', 'camera', 'can', 'cat',          # indices 0-4   (IDs 1, 2, 4, 5, 6)
    'driller', 'duck', 'eggbox', 'glue', 'holepuncher',  # indices 5-9   (IDs 8, 9, 10, 11, 12)
    'iron', 'lamp', 'phone',                             # indices 10-12 (IDs 13, 14, 15)
]
# Reverse lookup: class name -> class index.
name_to_idx = dict(zip(LINEMOD_NAMES, range(len(LINEMOD_NAMES))))

In [None]:
# ONLY ROTATION TRAINING SCRIPT
import os
import torch
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from datetime import datetime
from itertools import islice
import numpy as np

# ==========================================
# 1. SETUP & HYPERPARAMETERS
# ==========================================
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
LEARNING_RATE = 0.0001
NUM_EPOCHS = 50

# --- PATHS ---
# Folder holding the obj_XX.ply models
MODELS_DIR = '/kaggle/input/line-mode/Linemod_preprocessed/models'
# Object IDs loaded into the point bank. Row i of the bank belongs to
# VALID_OBJ_IDS[i], so the order must match LINEMOD_NAMES / name_to_idx.
VALID_OBJ_IDS = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]

# --- LOGGING SETUP ---
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')  # not used below; outputs go to a fixed folder
# Directory to save checkpoints and logs
CHECKPOINT_DIR = '/kaggle/working/run_rotation'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
run_dir = CHECKPOINT_DIR

print(f"\nüî• STARTING ROTATION-ONLY TRAINING on {DEVICE}...")
print(f"üìÅ Saving outputs to: {run_dir}")


def _names_to_class_ids(names, device):
    """Map a batch of class names (e.g. ['can', 'ape']) to point-bank row indices.

    Returns a LongTensor on `device`. An unknown name is reported and the
    original KeyError is re-raised, because a wrong index would silently pair
    an image with the wrong 3D model.
    """
    try:
        indices = [name_to_idx[name] for name in names]
    except KeyError as e:
        print(f"‚ùå ERRORE CRITICO: Trovato nome '{e}' non presente nella lista LINEMOD_NAMES!")
        raise
    return torch.tensor(indices, dtype=torch.long, device=device)


# ==========================================
# 2. INITIALIZE LOSS & MODELS
# ==========================================

# A. LOAD 3D POINTS FOR LOSS
# The point-matching loss needs the sampled point cloud of every object.
print("üì¶ Loading 3D Point Clouds for Loss Function...")
point_bank = load_all_object_points(MODELS_DIR, VALID_OBJ_IDS, num_points=1000)
point_bank = point_bank.to(DEVICE)  # Move entire bank to the training device

# B. DEFINE LOSS FUNCTION
criterion = MultiObjectPointMatchingLoss(point_bank).to(DEVICE)

# C. INITIALIZE MODEL
# Only the rotation network is trained in this cell
model_rot = ResNetRotation(freeze_backbone=False).to(DEVICE)

# D. OPTIMIZER
optimizer = optim.Adam(model_rot.parameters(), lr=LEARNING_RATE)

# E. METRICS STORAGE
train_losses = []
val_losses = []
best_val_loss = float('inf')

# Validate on a subset to save time per epoch.
# NOTE(review): test_loader is built with shuffle=False, so this is always the
# same leading slice of the test split; if the dataset is ordered by object it
# may cover only the first object(s). The "best" model is also selected on
# test data - confirm both are intended.
val_batches_limit = 50

# ==========================================
# 3. TRAINING LOOP
# ==========================================
for epoch in range(NUM_EPOCHS):

    # --- A. TRAIN PHASE ---
    model_rot.train()
    running_train_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]")

    for batch in pbar:
        # 1. Move data to the training device
        imgs = batch['image'].to(DEVICE)
        gt_rot = batch['rotation'].to(DEVICE)

        # batch['class_idx'] carries class NAMES (e.g. ['can', 'ape', 'driller']);
        # the loss needs the matching row index of the point bank instead.
        class_ids = _names_to_class_ids(batch['class_idx'], DEVICE)

        # 2. Forward Pass
        pred_rot = model_rot(imgs)

        # 3. Calculate Loss
        # class_ids tells the loss which 3D model belongs to each image
        loss = criterion(pred_rot, gt_rot, class_ids)

        # 4. Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 5. Logging
        running_train_loss += loss.item()
        pbar.set_postfix({'ADD Loss': f"{loss.item():.4f}"})

    # max(..., 1) only guards the division when the loader yields no batch
    avg_train_loss = running_train_loss / max(len(train_loader), 1)
    train_losses.append(avg_train_loss)

    # --- B. EVALUATION PHASE ---
    model_rot.eval()
    running_val_loss = 0.0
    count_batches = 0

    with torch.no_grad():
        val_iterator = islice(test_loader, val_batches_limit)
        # The bar total must not exceed what the loader can actually deliver
        val_pbar = tqdm(
            val_iterator,
            total=min(val_batches_limit, len(test_loader)),
            desc="Validating",
        )

        for batch in val_pbar:
            imgs = batch['image'].to(DEVICE)
            gt_rot = batch['rotation'].to(DEVICE)
            class_ids = _names_to_class_ids(batch['class_idx'], DEVICE)

            pred_rot = model_rot(imgs)
            loss = criterion(pred_rot, gt_rot, class_ids)

            running_val_loss += loss.item()
            count_batches += 1

    # With no validation batch the loss is unknown. Report +inf rather than 0,
    # so an unevaluated model can never be recorded as the best one.
    avg_val_loss = running_val_loss / count_batches if count_batches > 0 else float('inf')
    val_losses.append(avg_val_loss)

    # --- C. REPORT & SAVE ---
    print(f"üìä Epoch {epoch+1} Summary: Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # Save Best Model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        save_path = os.path.join(CHECKPOINT_DIR, "best_model_rot.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model_rot.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': best_val_loss
        }, save_path)
        print(f"üèÜ New Best Rotation Model Saved! (Loss: {best_val_loss:.4f})")

    # Save Last Checkpoint (for resuming if needed). It is overwritten after
    # every epoch and includes the optimizer state, so an interrupted run can
    # be resumed from the most recent epoch.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model_rot.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': avg_val_loss
    }, os.path.join(CHECKPOINT_DIR, "checkpoint_last.pth"))

# --- D. PLOTTING ---
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses, label='Train Loss (ADD Metric)')
ax.plot(val_losses, label='Val Loss (ADD Metric)')
ax.set_title('Rotation Training Convergence')
ax.set_xlabel('Epochs')
ax.set_ylabel('Average Distance (m)')
ax.legend()
ax.grid(True)
fig.savefig(os.path.join(CHECKPOINT_DIR, 'rotation_training_curve.png'))
print("üéâ TRAINING COMPLETE! Training curve saved.")

In [None]:
import torch
import numpy as np
import os
import trimesh
import pandas as pd
from tqdm.auto import tqdm

# Ensure compute_ADD_metric_quaternion is imported or defined in your notebook

# ==========================================
# 1. LOAD 3D MODELS AND DIAMETERS
# ==========================================
def _model_diameter_m(vertices_mm, chunk=256):
    """Return the largest distance between any two vertices, converted mm -> metres."""
    pts = np.asarray(vertices_mm, dtype=np.float64).reshape(-1, 3)
    if len(pts) < 2:
        return 0.0
    try:
        # The farthest pair of a point set always lies on its convex hull, so
        # restricting the search to hull vertices gives the same answer faster.
        from scipy.spatial import ConvexHull
        pts = pts[ConvexHull(pts).vertices]
    except Exception:
        # SciPy missing or degenerate geometry (e.g. coplanar points):
        # search all vertices instead - slower, but still exact.
        pass
    pts = pts - pts.mean(axis=0)  # centring keeps the squared terms small
    sq_norms = np.einsum('ij,ij->i', pts, pts)
    max_sq = 0.0
    # Process rows in chunks so the pairwise matrix never has to fit in memory.
    for start in range(0, len(pts), chunk):
        block = pts[start:start + chunk]
        sq_dists = sq_norms[start:start + chunk, None] + sq_norms[None, :] - 2.0 * (block @ pts.T)
        max_sq = max(max_sq, float(sq_dists.max()))
    return float(np.sqrt(max(max_sq, 0.0))) / 1000.0


def load_models_info(models_dir, obj_ids, num_points=1000):
    """
    Load the 3D mesh of each object, sample surface points and compute its diameter.

    The diameter is the largest distance between two model vertices. It is used
    for the ADD-0.1d threshold; the bounding-box diagonal used previously is
    never smaller than this value, so it can only loosen the threshold.
    NOTE(review): if the dataset ships reference diameters (e.g. a
    models_info file next to the meshes), prefer those - not verified here.

    Args:
        models_dir (str): Folder containing 'obj_XX.ply' files.
        obj_ids (iterable): Object IDs; duplicates are ignored.
        num_points (int): Number of surface points sampled per object.

    Returns:
        point_cache: {id: points (N, 3)} in metres (file values / 1000)
        diameters:   {id: diameter (float)} in metres
        Both dicts always have the same keys. Objects whose file is missing or
        fails to load are reported on stdout and left out of both.
    """
    point_cache = {}
    diameters = {}

    unique_ids = sorted(set(obj_ids))
    print(f"‚è≥ Loading info for {len(unique_ids)} 3D models...")

    for oid in tqdm(unique_ids, desc="Mesh Analysis"):
        filename = f"obj_{int(oid):02d}.ply"
        path = os.path.join(models_dir, filename)

        if not os.path.exists(path):
            print(f"‚ö†Ô∏è Missing model file: {path}")
            continue

        try:
            mesh = trimesh.load(path)
            # 1. Sample Points (for ADD calculation)
            points, _ = trimesh.sample.sample_surface(mesh, num_points)
            # 2. Calculate Diameter (for Accuracy threshold)
            diameter = _model_diameter_m(mesh.vertices)
        except Exception as e:
            # Best effort: one broken model must not abort the whole evaluation
            print(f"‚ùå Error loading {filename}: {e}")
            continue

        # Store both values only after both succeeded, so callers can index
        # `diameters` with any key found in `point_cache`.
        point_cache[oid] = points / 1000.0  # Convert mm -> Meters
        diameters[oid] = diameter

    return point_cache, diameters

# ==========================================
# 2. PANDAS EVALUATION FUNCTION
# ==========================================
def evaluate_with_pandas(model_rot, dataloader, device, models_dir, model_trans=None):
    """
    Benchmark the rotation network with the ADD metric and ADD-0.1d accuracy.

    Only the rotation is predicted: the ground-truth translation is used as
    the "predicted" translation, so the reported error isolates rotation.

    Args:
        model_rot: Rotation network, called as ``model_rot(images)``.
        dataloader: Yields dict batches with keys 'image', 'rotation',
            'translation' and 'object_id'.
        device: Device the image and ground-truth tensors are moved to.
        models_dir (str): Folder with the 'obj_XX.ply' meshes.
        model_trans: Optional translation network. It is only switched to eval
            mode; its predictions are not used here.

    Returns:
        (report, df): per-object summary DataFrame and per-sample DataFrame,
        or (None, None) when no sample could be evaluated.
    """
    model_rot.eval()
    if model_trans is not None:
        model_trans.eval()

    # 1. Decide which meshes to load from the model files themselves.
    # Iterating `dataloader.dataset` for this would load every sample once
    # before the benchmark even starts.
    all_obj_ids = []
    if os.path.isdir(models_dir):
        for fname in os.listdir(models_dir):
            stem, ext = os.path.splitext(fname)
            if ext == '.ply' and stem.startswith('obj_') and stem[4:].isdigit():
                all_obj_ids.append(int(stem[4:]))
    if not all_obj_ids:
        # Nothing discovered: request the usual LineMod IDs so that the loader
        # reports exactly which files are missing.
        all_obj_ids = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]

    points_dict, diameters_dict = load_models_info(models_dir, all_obj_ids)

    # List to accumulate raw results
    raw_results = []

    print("\nüöÄ Starting Benchmark (ADD Error + Accuracy)...")

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Inference"):
            imgs = batch['image'].to(device)
            gt_quats = batch['rotation'].to(device)
            gt_trans = batch['translation'].to(device)
            obj_ids = batch['object_id']

            # Predict rotation; translation is taken from the ground truth
            pred_quats = model_rot(imgs)
            pred_trans_batch = gt_trans

            # Convert to Numpy for metric calculation
            pred_quats_np = pred_quats.cpu().numpy()
            pred_trans_np = pred_trans_batch.cpu().numpy()
            gt_quats_np = gt_quats.cpu().numpy()
            gt_trans_np = gt_trans.cpu().numpy()

            for i in range(imgs.shape[0]):
                curr_id = int(obj_ids[i])

                # Skip if we don't have 3D info for this object
                if curr_id not in points_dict or curr_id not in diameters_dict:
                    continue

                # --- CALCULATE ADD METRIC (in Meters) ---
                add_error = compute_ADD_metric_quaternion(
                    model_points=points_dict[curr_id],
                    gt_quat=gt_quats_np[i],
                    gt_translation=gt_trans_np[i],
                    pred_quat=pred_quats_np[i],
                    pred_translation=pred_trans_np[i]
                )

                # --- CALCULATE THRESHOLD & ACCURACY ---
                diam = diameters_dict[curr_id]
                threshold = diam * 0.1  # 10% of diameter
                is_correct = add_error < threshold

                raw_results.append({
                    'obj_id': curr_id,
                    'diameter_cm': diam * 100,
                    'add_error_m': add_error,
                    'add_error_cm': add_error * 100,
                    'threshold_cm': threshold * 100,
                    'is_correct': is_correct
                })

    # ==========================================
    # 3. GENERATE PANDAS REPORT
    # ==========================================
    if not raw_results:
        print("‚ùå No results collected. Check your dataloader or model paths.")
        return None, None

    df = pd.DataFrame(raw_results)

    # Group by Object ID and calculate stats
    report = df.groupby('obj_id').agg(
        Samples=('obj_id', 'count'),
        Diameter_cm=('diameter_cm', 'first'),
        Mean_Error_cm=('add_error_cm', 'mean'),
        Accuracy_pct=('is_correct', 'mean')  # Mean of booleans is the hit rate
    )

    # Express the hit rate as a percentage (0.69 -> 69.0)
    report['Accuracy_pct'] = report['Accuracy_pct'] * 100

    # --- PRINT TABLE ---
    print("\n" + "="*60)
    print("üìä DETAILED REPORT BY OBJECT")
    print("="*60)
    print(report.to_string(float_format="{:.2f}".format))
    print("="*60)

    # --- CALCULATE GLOBAL METRICS ---
    total_correct = df['is_correct'].sum()
    total_samples = len(df)
    global_acc = (total_correct / total_samples) * 100
    global_err = df['add_error_cm'].mean()

    print(f"\nüèÜ GLOBAL RESULTS (Entire Dataset)")
    print(f"   ‚û§ Total Samples:       {total_samples}")
    print(f"   ‚û§ Mean Error (ADD):    {global_err:.2f} cm")
    print(f"   ‚û§ Accuracy (ADD-0.1d): {global_acc:.2f} %")
    print("="*60)

    return report, df

# --- RUN THE ROTATION BENCHMARK ---
# Requires `compute_ADD_metric_quaternion`, `model_rot`, `test_loader` and
# `DEVICE` to exist already (they are created by earlier cells).
MODELS_ROOT = '/kaggle/input/line-mode/Linemod_preprocessed/models'

report_df, raw_df = evaluate_with_pandas(
    model_rot,
    test_loader,
    DEVICE,
    MODELS_ROOT,
    model_trans=None,
)

In [None]:
# ONLY TRANSLATION TRAINING SCRIPT
from src.pose_rgb import model
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import os
import numpy as np

# ==========================================
# 1. CONFIGURATION
# ==========================================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: BATCH_SIZE and DATA_ROOT are not read anywhere in this cell. The loop
# below iterates `train_loader` / `test_loader` from the data-loading cell,
# which were created with batch_size=32; changing these values has no effect.
BATCH_SIZE = 64
LR = 0.001             # Constant Learning Rate
NUM_EPOCHS = 60
CHECKPOINT_DIR = '/kaggle/working/run_translation'
DATA_ROOT = '/kaggle/input/line-mode/Linemod_preprocessed'

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Fail before training starts instead of dividing by zero after a full epoch.
if len(train_loader) == 0 or len(test_loader) == 0:
    raise ValueError("train_loader and test_loader must each yield at least one batch")

# ==========================================
# 2. MODEL & OPTIMIZER SETUP
# ==========================================
print(f"üß† Initializing Model on {DEVICE}...")

# Initialize the TranslationNet imported from src.pose_rgb.model
model_transl = TranslationNet().to(DEVICE)

# Define Loss.
# NOTE(review): z_weight=1 presumably leaves depth weighted like X and Y;
# raise it to actually prioritise Z - confirm against TranslationLoss.
criterion = TranslationLoss(z_weight=1)

# Simple Adam Optimizer (No Scheduler)
optimizer = optim.Adam(model_transl.parameters(), lr=LR)

# ==========================================
# 3. TRAINING LOOP
# ==========================================
best_val_mae = float('inf')  # Lowest mean absolute error (cm) seen so far

print("üöÄ Starting Training Loop...")

for epoch in range(NUM_EPOCHS):

    # --- TRAIN PHASE ---
    model_transl.train()
    running_loss = 0.0

    train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]")

    for batch in train_loop:
        # Move data to the training device.
        # Shapes below are the ones the original author noted - TODO confirm
        # against the dataset implementation.
        imgs = batch['image'].to(DEVICE)           # (B, 3, 224, 224)
        bbox_info = batch['bbox_info'].to(DEVICE)  # (B, 4) normalized bbox descriptor
        gt_trans = batch['translation'].to(DEVICE) # (B, 3) translation in metres

        # Forward Pass
        preds = model_transl(imgs, bbox_info)

        # Calculate Loss
        loss = criterion(preds, gt_trans)

        # Backward Pass (Update Weights)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update stats
        running_loss += loss.item()
        train_loop.set_postfix(loss=loss.item())

    avg_train_loss = running_loss / len(train_loader)

    # --- VALIDATION PHASE ---
    # NOTE(review): this runs over the full test split every epoch and the
    # best checkpoint is selected on it - confirm that no separate validation
    # split is intended.
    model_transl.eval()
    val_loss = 0.0

    # Per-axis sum of absolute errors (metres) and number of samples seen
    error_sum_xyz = np.array([0.0, 0.0, 0.0])
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Eval]"):
            imgs = batch['image'].to(DEVICE)
            bbox_info = batch['bbox_info'].to(DEVICE)
            gt_trans = batch['translation'].to(DEVICE)

            # Predict
            preds = model_transl(imgs, bbox_info)

            # Calculate Loss
            loss = criterion(preds, gt_trans)
            val_loss += loss.item()

            # Accumulate the absolute error per axis (in metres)
            abs_err = torch.abs(preds - gt_trans).cpu().numpy()
            error_sum_xyz += abs_err.sum(axis=0)
            total_samples += imgs.shape[0]

    avg_val_loss = val_loss / len(test_loader)

    # Mean error per axis, converted to centimetres for readability
    mean_error_m = error_sum_xyz / total_samples
    mean_error_cm = mean_error_m * 100.0
    total_mae_cm = np.mean(mean_error_cm)  # Average error across X, Y, Z

    # --- REPORTING ---
    print(f"\nüìä REPORT EPOCH {epoch+1}")
    print(f"   Train Loss:    {avg_train_loss:.5f}")
    print(f"   Val Loss:      {avg_val_loss:.5f}")
    print(f"   --------------------------------")
    print(f"   Error X:       {mean_error_cm[0]:.2f} cm")
    print(f"   Error Y:       {mean_error_cm[1]:.2f} cm")
    print(f"   Error Z:       {mean_error_cm[2]:.2f} cm (Depth)")
    print(f"   --------------------------------")

    # Save Best Model (if error is lower than previous best)
    if total_mae_cm < best_val_mae:
        best_val_mae = total_mae_cm
        torch.save(model_transl.state_dict(), os.path.join(CHECKPOINT_DIR, "best_translation_model.pth"))
        print(f"   üíæ New Best Model Saved! (Avg Error: {total_mae_cm:.2f} cm)")

    # Save periodic checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(model_transl.state_dict(), os.path.join(CHECKPOINT_DIR, f"translation_ep{epoch+1}.pth"))

print("\n‚úÖ Training Complete. Best model saved in:", CHECKPOINT_DIR)

In [None]:
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

def evaluate_translation_only(model_trans, dataloader, device):
    """
    Evaluate only the translation model.

    Reports the Mean Absolute Error (MAE) in cm for X, Y and Z, per object and
    over all samples, together with the mean Euclidean distance.

    Args:
        model_trans: Translation network, called as
            ``model_trans(images, bbox_info)``.
        dataloader: Yields dict batches with keys 'image', 'bbox_info',
            'translation' and 'object_id'.
        device: Device the input tensors are moved to.

    Returns:
        (report, df): per-object summary DataFrame and per-sample DataFrame,
        or (None, None) when the dataloader produced no sample. The empty case
        returns a pair so that ``report, df = evaluate_translation_only(...)``
        never fails while unpacking.
    """
    model_trans.eval()

    raw_results = []

    print("\nüöÄ Starting Translation Benchmark...")

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Inference"):
            imgs = batch['image'].to(device)
            bbox_info = batch['bbox_info'].to(device)
            gt_trans = batch['translation'].to(device)
            obj_ids = batch['object_id']

            # Predict Translation
            pred_trans = model_trans(imgs, bbox_info)

            # Per-axis absolute error, one row per sample: [err_x, err_y, err_z]
            abs_err = torch.abs(pred_trans - gt_trans).cpu().numpy()

            gt_np = gt_trans.cpu().numpy()
            pred_np = pred_trans.cpu().numpy()

            for i in range(imgs.shape[0]):
                raw_results.append({
                    'obj_id': int(obj_ids[i]),
                    'err_x_cm': abs_err[i, 0] * 100,
                    'err_y_cm': abs_err[i, 1] * 100,
                    'err_z_cm': abs_err[i, 2] * 100,
                    # Euclidean length of the error vector
                    'total_err_cm': np.linalg.norm(abs_err[i]) * 100,
                    # Kept to check whether the error correlates with depth
                    'gt_z_m': gt_np[i, 2],
                    'pred_z_m': pred_np[i, 2]
                })

    if not raw_results:
        print("‚ùå No results collected.")
        return None, None

    df = pd.DataFrame(raw_results)

    # Group by Object ID
    report = df.groupby('obj_id').agg(
        Samples=('obj_id', 'count'),
        MAE_X_cm=('err_x_cm', 'mean'),
        MAE_Y_cm=('err_y_cm', 'mean'),
        MAE_Z_cm=('err_z_cm', 'mean'),
        Mean_Total_Error_cm=('total_err_cm', 'mean')
    )

    print("\n" + "="*65)
    print("üìä TRANSLATION REPORT (Mean Absolute Error in cm)")
    print("="*65)
    print(report.to_string(float_format="{:.2f}".format))
    print("="*65)

    # Global Stats
    print(f"\nüèÜ GLOBAL TRANSLATION RESULTS")
    print(f"   ‚û§ Mean Error X: {df['err_x_cm'].mean():.2f} cm")
    print(f"   ‚û§ Mean Error Y: {df['err_y_cm'].mean():.2f} cm")
    print(f"   ‚û§ Mean Error Z: {df['err_z_cm'].mean():.2f} cm (Depth)")
    print(f"   ‚û§ Mean Euclidean Dist: {df['total_err_cm'].mean():.2f} cm")
    print("="*65)

    return report, df

# Benchmark the translation network trained above on the test loader.
report_df, raw_df = evaluate_translation_only(
    model_trans=model_transl,
    dataloader=test_loader,
    device=DEVICE,
)

In [None]:
import torch
import numpy as np
import os
import trimesh
import pandas as pd
from tqdm.auto import tqdm


# ==========================================
# MAIN: FULL 6D EVALUATION
# ==========================================
def evaluate_full_6d(model_rot, model_trans, dataloader, device, models_dir):
    """
    Evaluate the complete 6D pose pipeline (rotation + translation networks).

    Both networks predict from the same batch; the combined pose is scored
    with the ADD metric and the ADD-0.1d accuracy (error below 10% of the
    object diameter).

    Args:
        model_rot: Rotation network, called as ``model_rot(images)``.
        model_trans: Translation network, called as
            ``model_trans(images, bbox_info)``.
        dataloader: Yields dict batches with keys 'image', 'bbox_info',
            'rotation', 'translation' and 'object_id'.
        device: Device the input tensors are moved to.
        models_dir (str): Folder with the 'obj_XX.ply' meshes.

    Returns:
        (report, df): per-object summary DataFrame and per-sample DataFrame,
        or (None, None) when no sample could be evaluated.
    """
    model_rot.eval()
    model_trans.eval()

    # 1. Decide which meshes to load from the model files themselves.
    # Iterating `dataloader.dataset` for this would load every sample once
    # before the benchmark even starts.
    all_obj_ids = []
    if os.path.isdir(models_dir):
        for fname in os.listdir(models_dir):
            stem, ext = os.path.splitext(fname)
            if ext == '.ply' and stem.startswith('obj_') and stem[4:].isdigit():
                all_obj_ids.append(int(stem[4:]))
    if not all_obj_ids:
        # Nothing discovered: request the usual LineMod IDs so that the loader
        # reports exactly which files are missing.
        all_obj_ids = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]

    points_dict, diameters_dict = load_models_info(models_dir, all_obj_ids)

    raw_results = []

    print("\nüöÄ Starting Full 6D Pose Benchmark...")

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Inference"):
            # Inputs
            imgs = batch['image'].to(device)
            bbox_info = batch['bbox_info'].to(device)
            obj_ids = batch['object_id']

            # Ground Truth
            gt_quats = batch['rotation'].to(device)
            gt_trans = batch['translation'].to(device)

            # --- PREDICTIONS ---
            pred_quats = model_rot(imgs)
            pred_trans = model_trans(imgs, bbox_info)

            # Convert to Numpy
            pred_q_np = pred_quats.cpu().numpy()
            pred_t_np = pred_trans.cpu().numpy()
            gt_q_np = gt_quats.cpu().numpy()
            gt_t_np = gt_trans.cpu().numpy()

            for i in range(imgs.shape[0]):
                curr_id = int(obj_ids[i])

                # Skip objects without 3D info
                if curr_id not in points_dict or curr_id not in diameters_dict:
                    continue

                # --- METRIC COMPUTATION (ADD) ---
                # Scores the predicted pose against the ground-truth pose on
                # the sampled model points.
                add_error = compute_ADD_metric_quaternion(
                    model_points=points_dict[curr_id],
                    gt_quat=gt_q_np[i],
                    gt_translation=gt_t_np[i],
                    pred_quat=pred_q_np[i],
                    pred_translation=pred_t_np[i]
                )

                # --- ACCURACY CHECK ---
                diam = diameters_dict[curr_id]
                threshold = diam * 0.1  # 10% of diameter
                is_correct = add_error < threshold

                raw_results.append({
                    'obj_id': curr_id,
                    'diameter_cm': diam * 100,
                    'add_error_cm': add_error * 100,
                    'is_correct': is_correct,
                    # Translation error only (Euclidean distance)
                    'err_trans_cm': np.linalg.norm(pred_t_np[i] - gt_t_np[i]) * 100
                })

    if not raw_results:
        print("‚ùå No results collected.")
        return None, None

    # --- REPORTING ---
    df = pd.DataFrame(raw_results)

    report = df.groupby('obj_id').agg(
        Samples=('obj_id', 'count'),
        Diameter_cm=('diameter_cm', 'first'),
        ADD_Error_cm=('add_error_cm', 'mean'),    # Combined Error (Rot + Trans)
        Trans_Error_cm=('err_trans_cm', 'mean'),  # Translation Error only
        Accuracy_pct=('is_correct', 'mean')
    )

    report['Accuracy_pct'] = report['Accuracy_pct'] * 100

    print("\n" + "="*80)
    print("üìä FULL 6D POSE REPORT (Rotation + Translation)")
    print("="*80)
    print(report.to_string(float_format="{:.2f}".format))
    print("="*80)

    # Global Stats
    total_acc = (df['is_correct'].sum() / len(df)) * 100
    print(f"\nüèÜ GLOBAL 6D RESULTS")
    print(f"   ‚û§ Mean ADD Error:      {df['add_error_cm'].mean():.2f} cm")
    print(f"   ‚û§ Mean Trans. Error:   {df['err_trans_cm'].mean():.2f} cm")
    print(f"   ‚û§ Final Accuracy:      {total_acc:.2f} %")
    print("="*80)

    return report, df

# Score the combined rotation + translation pipeline on the test loader.
MODELS_ROOT = '/kaggle/input/line-mode/Linemod_preprocessed/models'

report_df, raw_df = evaluate_full_6d(
    model_rot=model_rot,
    model_trans=model_transl,
    dataloader=test_loader,
    device=DEVICE,
    models_dir=MODELS_ROOT,
)