In [None]:
import os
import sys

# Clone or pull part
repo_url = "https://github.com/fraco03/6D_pose.git"
repo_dir = "/kaggle/working/6D_pose"   #Modify here for kaggle
branch = "pose_rgbd"

# Clone if missing
if not os.path.exists(repo_dir):
    !git clone -b {branch} {repo_url}
    print(f"Cloned {repo_url} to {repo_dir}")
else:
    %cd {repo_dir}
    !git fetch origin
    !git checkout {branch}
    !git reset --hard origin/{branch}
    # %cd ..
    print(f"Updated {repo_url} to {repo_dir}")

# Add repository to Python path
if repo_dir not in sys.path:
    sys.path.insert(0, repo_dir)


In [None]:
%cd ..
!gdown --fuzzy https://drive.google.com/file/d/1zNthSyiBdPUfn7BmUKPbKoGgQdG1vGnS/view?usp=drive_link -O Linemod_preprocessed.zip
!unzip Linemod_preprocessed.zip
%cd 6D_pose

In [None]:
# NOTE(review): google.colab is only importable on Colab (fails on Kaggle);
# `drive` itself is not used in this cell.
from google.colab import drive
from utils.load_data import mount_drive

# Mounting part: delegates to the project helper (presumably mounts Google Drive —
# confirm in utils/load_data).
mount_drive()

In [None]:
# Move the extracted dataset folder into ./working (relative to the current directory).
%mv Linemod_preprocessed working/

In [None]:
# Dataset location: keep exactly one assignment active for the current environment.
# dataset_root = "/content/drive/MyDrive/Linemod_preprocessed" # Modify here for kaggle
dataset_root = "../../Linemod_preprocessed_small"
# dataset_root = "/content/Linemod_preprocessed"
# dataset_root = "/kaggle/working/Linemod_preprocessed"

# Relative paths resolve against the current working directory, so surface a wrong
# location here instead of as a confusing error inside the dataset loader.
if not os.path.isdir(dataset_root):
    print(f"WARNING: dataset_root not found (cwd={os.getcwd()}): {dataset_root}")

print("\n✅ Setup complete!")
print(f"📁 Dataset path: {dataset_root}")


In [None]:
# Automatically reload edited Python modules before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import shutil

# The dataset folder may already have been moved by the earlier `%mv` cell; only move
# it if it is still here, so Restart & Run All does not fail on a missing source.
# shutil.move, like `mv`, moves into ./working when that directory exists.
if os.path.isdir('Linemod_preprocessed'):
    shutil.move('Linemod_preprocessed', './working')

In [None]:
!pip install plyfile

In [None]:
import sys

# Make packages two directories above the current working directory importable.
# Guarded so re-running the cell does not keep appending duplicate entries.
if '../..' not in sys.path:
    sys.path.append('../..')

In [None]:
from src.pose_pointnet.dataset import PointNetLineModDataset

# One dataset per split, both read from the same LineMOD root
# (train is constructed first, then test).
train_dataset, test_dataset = (
    PointNetLineModDataset(root_dir=dataset_root, split=split_name)
    for split_name in ("train", "test")
)


In [None]:
import torch

# Peek at the first training item to see which fields the dataset yields.
sample = train_dataset[0]

print(f"Sample keys: {sample.keys()}")
for field_name, field_value in sample.items():
    if not isinstance(field_value, torch.Tensor):
        print(f"  {field_name}: {type(field_value)} with value {field_value}")
        continue
    print(f"  {field_name}: Tensor of shape {field_value.shape} and dtype {field_value.dtype}")

In [None]:
import torch

# Use CUDA when a GPU is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# NOTE(review): duplicate of the earlier plyfile install cell; redundant but harmless.
!pip install plyfile

In [None]:
from utils.linemod_config import get_linemod_config
import numpy as np
import torch


linemod_config = get_linemod_config(dataset_root)

NUM_POINTS = 1000  # Number of points to sample from each model
VALID_OBJ_IDS = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]
SEED = 42  # Fixes which model points are sampled, so the loss targets are reproducible

# Dedicated generator: reproducible without touching NumPy's global RNG state.
rng = np.random.default_rng(SEED)

sampled_clouds = []
for obj_id in VALID_OBJ_IDS:
    model_points = linemod_config.get_model_3d(obj_id, unit='m')  # (N, 3)
    n_available = model_points.shape[0]
    if n_available == 0:
        raise ValueError(f"3D model for object {obj_id} has no points")
    # Sample without replacement when the model has enough points; otherwise
    # resample with replacement to reach NUM_POINTS.
    choice = rng.choice(n_available, NUM_POINTS, replace=n_available < NUM_POINTS)
    sampled_clouds.append(torch.tensor(model_points[choice, :], dtype=torch.float32))

# (Num_Classes, NUM_POINTS, 3); row i corresponds to VALID_OBJ_IDS[i]
all_model_points = torch.stack(sampled_clouds, dim=0).to(device)

max_obj_id = max(VALID_OBJ_IDS)

# Lookup table: raw object ID -> row in all_model_points; -1 marks IDs without a model.
obj_id_to_idx = torch.full((max_obj_id + 1,), -1, dtype=torch.long, device=device)
for idx, obj_id in enumerate(VALID_OBJ_IDS):
    obj_id_to_idx[obj_id] = idx


In [None]:
# Sanity check: expected (len(VALID_OBJ_IDS), NUM_POINTS, 3).
all_model_points.shape

In [None]:
import torch.nn as nn
from torch.optim import Adam

from src.pose_pointnet.loss import MultiObjectPointMatchingLoss
from src.pose_pointnet.model import PointNetPoseModel

# Build the network; replicate it across GPUs only when more than one is visible.
model = PointNetPoseModel()
if torch.cuda.device_count() > 1:
    print(f"üî• Using {torch.cuda.device_count()} GPU!")
    model = nn.DataParallel(model)
model = model.to(device)

# Point-matching loss over the per-class model point clouds, optimised with Adam.
criterion = MultiObjectPointMatchingLoss(all_model_points)
optimizer = Adam(params=model.parameters(), lr=1e-4)

In [None]:
from torch.utils.data import DataLoader

batch_size = 64  # sized for two GPUs under DataParallel

# Training shuffles and drops the ragged final batch; evaluation keeps every sample in order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

In [None]:
from tqdm import tqdm
import os
import torch
from datetime import datetime

# ==========================================
# 0. SETUP AND CONFIGURATION
# ==========================================
num_epochs = 25  # PointNet converges relatively fast
best_test_loss = float('inf')
# NOTE: the batch size is fixed by train_loader / test_loader (built in an earlier cell);
# reassigning `batch_size` here would have no effect on training.

# Setup checkpoint directory
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# checkpoint_dir = f'/kaggle/working/POINTNET_{timestamp}'
checkpoint_dir = f'./POINTNET_{timestamp}'
os.makedirs(checkpoint_dir, exist_ok=True)

# Trackers for plotting
train_losses = []
test_losses = []

# Entries >= 0 are real objects; -1 entries are unused slots of the lookup table.
num_mapped_objects = int((obj_id_to_idx >= 0).sum().item())

print(f"🚀 Starting PointNet Training on {device}")
print(f"📁 Checkpoints will be saved to: {checkpoint_dir}")
print(f"🗺️  Object ID Mapping created for {num_mapped_objects} objects.")


def _map_object_ids(raw_obj_ids):
    """Map raw object IDs (e.g. 15) to rows of the model-point buffer (e.g. 12).

    One vectorised lookup on ``device`` instead of a per-sample GPU read.
    Raises ValueError for IDs outside the lookup table or without model points.
    """
    ids = torch.as_tensor(raw_obj_ids, dtype=torch.long).cpu()
    # Range check on the CPU: out-of-range indexing on CUDA is a device-side assert.
    if ids.numel() > 0 and (ids.min() < 0 or ids.max() >= obj_id_to_idx.numel()):
        raise ValueError(f"Object IDs outside the lookup table: {sorted(set(ids.tolist()))}")
    indices = obj_id_to_idx[ids.to(device)]
    missing = (indices < 0).cpu()
    if missing.any():
        raise ValueError(f"Object IDs without model points: {sorted(set(ids[missing].tolist()))}")
    return indices


def _batch_loss(batch):
    """Run the model on one batch and return the ADD point-matching loss.

    The network predicts a quaternion and a translation residual relative to the
    point-cloud centroid; the absolute translation is rebuilt as centroid + residual
    before being compared with the absolute ground truth.
    """
    # PointNet input: (Batch, 3, Num_Points)
    points = batch['points'].to(device)
    centroids = batch['centroid'].to(device)            # (B, 3)
    gt_rotations = batch['rotation'].to(device)         # (B, 4)
    gt_t_absolute = batch['gt_translation'].to(device)  # (B, 3) - Absolute target
    target_indices = _map_object_ids(batch['object_id'])

    pred_q, pred_t_res = model(points)
    pred_t_abs = centroids + pred_t_res  # Absolute_Pos = Centroid + Residual

    return criterion(
        pred_q=pred_q,
        pred_t=pred_t_abs,
        gt_q=gt_rotations,
        gt_t=gt_t_absolute,
        class_indices=target_indices,
    )


for epoch in range(num_epochs):
    # ==========================================
    # 1. TRAINING LOOP
    # ==========================================
    model.train()
    epoch_loss = 0.0
    train_pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}] - Training")

    for batch in train_pbar:
        optimizer.zero_grad()
        loss = _batch_loss(batch)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        train_pbar.set_postfix({'ADD Loss (m)': f"{loss.item():.4f}"})

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # ==========================================
    # 2. VALIDATION LOOP
    # ==========================================
    model.eval()
    test_loss = 0.0
    val_pbar = tqdm(test_loader, desc="Validating")

    with torch.no_grad():
        for batch in val_pbar:
            loss = _batch_loss(batch)
            test_loss += loss.item()
            val_pbar.set_postfix({'Val Loss': f"{loss.item():.4f}"})

    avg_test_loss = test_loss / len(test_loader)
    test_losses.append(avg_test_loss)

    print(f"📊 Epoch [{epoch+1}/{num_epochs}] | Train Loss: {avg_train_loss:.4f} m | Val Loss: {avg_test_loss:.4f} m")

    # ==========================================
    # 3. CHECKPOINT SAVING
    # ==========================================
    if avg_test_loss < best_test_loss:
        best_test_loss = avg_test_loss
        checkpoint_path = os.path.join(checkpoint_dir, "best_model.pth")

        # Save the unwrapped network so the weights carry no 'module.' prefix
        model_state = model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict()

        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model_state,
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_test_loss,
            'config': {
                # NOTE(review): hard-coded; confirm it matches the points per sample of the dataset.
                'num_points': 1024,
                # Stored on CPU so the checkpoint can be loaded without a GPU / map_location.
                'obj_map': obj_id_to_idx.cpu(),
            }
        }, checkpoint_path)
        print(f"✅ New Record! Model saved with Loss: {best_test_loss:.4f} m")
    else:
        print(f"⏳ No improvement (Best: {best_test_loss:.4f} m)")

    print("-" * 60)

In [None]:
import matplotlib.pyplot as plt

# Figures are stored next to the checkpoint so each run is self-contained.
# plots_dir = "plots"
plots_dir = checkpoint_dir
os.makedirs(plots_dir, exist_ok=True)

train_epochs = range(1, len(train_losses) + 1)
epochs_range = range(1, len(test_losses) + 1)


def _finish_plot(fig, ax, ylabel, title, filename, legend=False):
    """Label ``ax``, save ``fig`` to ``plots_dir/filename`` (300 dpi), display it, and return the path."""
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    if legend:
        ax.legend(fontsize=11)
    fig.tight_layout()
    path = os.path.join(plots_dir, filename)
    fig.savefig(path, dpi=300, bbox_inches='tight')
    print(f"✅ Plot saved: {path}")
    plt.show()
    return path


# Plot 1: Training vs Test Loss
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_epochs, train_losses, 'b-o', label='Training Loss', linewidth=2, markersize=6)
ax.plot(epochs_range, test_losses, 'r-s', label='Test Loss', linewidth=2, markersize=6)
loss_plot_path = _finish_plot(fig, ax, 'Loss', 'Training vs Test Loss', "loss_comparison.png", legend=True)

# Plot 2: Only Training Loss
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_epochs, train_losses, 'b-o', linewidth=2, markersize=6)
train_loss_path = _finish_plot(fig, ax, 'Training Loss', 'Training Loss Over Epochs', "training_loss.png")

# Plot 3: Only Test Loss, with the best validation loss as a reference line
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs_range, test_losses, 'r-s', linewidth=2, markersize=6)
ax.axhline(y=best_test_loss, color='g', linestyle='--', label=f'Best: {best_test_loss:.4f}', linewidth=2)
test_loss_path = _finish_plot(fig, ax, 'Test Loss', 'Test Loss Over Epochs', "test_loss.png", legend=True)

print(f"\n✅ All plots saved in '{plots_dir}' directory!")

In [None]:
import pickle

# Persist the per-epoch loss curves next to the checkpoint so they can be re-plotted later.
losses_path = os.path.join(checkpoint_dir, "losses.pkl")
losses_dict = dict(train_losses=train_losses, test_losses=test_losses)
with open(losses_path, mode='wb') as losses_file:
    pickle.dump(losses_dict, losses_file)


# Visualize samples

In [None]:
# Fields of `sample` (train_dataset[0], loaded in the inspection cell above).
sample.keys()

In [None]:
import random
import cv2
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
from collections import OrderedDict
from utils.projection_utils import setup_projection_utils, visualize_pose_comparison

# ==========================================
# 1. SETUP AND MODEL LOADING
# ==========================================

# Set up projection utilities (assumes dataset_root is defined)
setup_projection_utils(dataset_root)

# Load best model
best_checkpoint_path = os.path.join(checkpoint_dir, "best_model.pth")
if not os.path.exists(best_checkpoint_path):
    raise FileNotFoundError(f"Checkpoint non trovato: {best_checkpoint_path}")

print(f"📂 Caricamento checkpoint da: {best_checkpoint_path}")
checkpoint = torch.load(best_checkpoint_path, map_location=device)

state_dict = checkpoint['model_state_dict']

# Strip a 'module.' prefix in case the weights were saved from a DataParallel wrapper
new_state_dict = OrderedDict(
    (k[len('module.'):] if k.startswith('module.') else k, v)
    for k, v in state_dict.items()
)

# The prefix-free weights belong to the bare network. If `model` is still wrapped in
# DataParallel (multi-GPU training), load into the wrapped module; loading into the
# wrapper itself would fail because it expects 'module.'-prefixed keys.
# model = PointNetPoseModel(num_points=1024).to(device)  # Uncomment if you need to instantiate it
target_model = model.module if isinstance(model, torch.nn.DataParallel) else model
target_model.load_state_dict(new_state_dict)
model.eval()

# Format the stored loss only when present, so a missing key does not crash ':.4f'
best_val_loss = checkpoint.get('best_val_loss')
loss_display = f"{best_val_loss:.4f}" if isinstance(best_val_loss, (int, float)) else '?'
print(f"✅ Modello caricato dall'epoca {checkpoint.get('epoch', '?')} con loss: {loss_display}")

# ==========================================
# 2. SAMPLE SELECTION AND PREPARATION
# ==========================================

# Pick a random index from the test dataset
random_idx = random.randint(0, len(test_dataset) - 1)
sample = test_dataset[random_idx]

print(f"\n📷 Visualizing Sample {random_idx}:")
print(f"   Object ID: {sample['object_id']}")
# Tolerate a missing 'img_id' (older dataset versions)
img_id_display = sample.get('img_id', 'N/A')
print(f"   Image ID: {img_id_display}")

# The PointNet dataset does not load the RGB image in __getitem__, so read it here
# from the sample's path, or rebuild the path from the IDs.
if 'img_path' in sample:
    img_path = sample['img_path']
else:
    # Fallback: rebuild the path assuming the standard LineMOD directory layout.
    # NOTE(review): requires 'img_id'; int() accepts ints, NumPy scalars and 0-d tensors.
    obj_id = int(sample['object_id'])
    img_id = int(sample['img_id'])
    img_path = os.path.join(dataset_root, 'data', f"{obj_id:02d}", 'rgb', f"{img_id:04d}.png")

image_bgr = cv2.imread(str(img_path))
if image_bgr is None:
    raise FileNotFoundError(f"Impossibile caricare immagine da: {img_path}")
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

# ==========================================
# 3. POINTNET INFERENCE
# ==========================================

# Prepare tensors (add a batch dimension)
points = sample['points'].unsqueeze(0).to(device)       # (1, 3, N)
centroid = sample['centroid'].unsqueeze(0).to(device)   # (1, 3)

with torch.no_grad():
    # The model returns the rotation AND the translation residual
    pred_q, pred_t_res = model(points)

    # Rebuild the absolute translation: T_abs = Centroid + Residual
    pred_trans_abs = centroid + pred_t_res

# Convert to numpy for visualization
pred_rotation = pred_q[0].cpu().numpy()
pred_translation = pred_trans_abs[0].cpu().numpy()

# Ground Truth
gt_rotation = sample['rotation'].numpy()
gt_translation = sample['gt_translation'].numpy()  # or sample['translation']
cam_K = sample['cam_K'].numpy()

print(f"\n📊 Ground Truth vs Prediction:")
print(f"   GT Rotation:   {gt_rotation}")
print(f"   Pred Rotation: {pred_rotation}")
print(f"   GT Trans (m):   {gt_translation}")
print(f"   Pred Trans (m): {pred_translation}")

# Translation distance error (informational only)
dist_error = np.linalg.norm(gt_translation - pred_translation)
print(f"   Translation Error: {dist_error*100:.2f} cm")

# ==========================================
# 4. VISUALIZATION
# ==========================================

# Compare GT and predicted poses; visualize_pose_comparison takes an RGB numpy image
img_vis = visualize_pose_comparison(
    image_rgb,
    object_id=sample['object_id'],
    cam_K=cam_K,
    gt_rotation=gt_rotation,
    gt_translation=gt_translation,
    pred_rotation=pred_rotation,
    pred_translation=pred_translation  # predicted (not GT) translation
)

# Plot with Matplotlib (the function returns RGB when given RGB)
fig, ax = plt.subplots(1, 1, figsize=(14, 8))
ax.imshow(img_vis)
ax.axis('off')
ax.set_title(f"PointNet Pose - Obj {sample['object_id']} (Err: {dist_error*100:.1f}cm)", fontsize=16, fontweight='bold')

plt.tight_layout()
plt.show()

print(f"\n✅ Visualizzazione completata!")

In [None]:
!pip install trimesh

In [None]:
import torch
import numpy as np
import os
import trimesh
import pandas as pd
from collections import defaultdict
from tqdm.auto import tqdm
from metrics.ADD_metric import compute_ADD_metric_quaternion, compute_ADDs_metric_quaternion
# Ensure you import the correct PointNet model class here
from src.pose_pointnet.model import PointNetPoseModel 

# ==========================================
# 1. LOAD DATA AND DIAMETERS
# ==========================================
def _max_pairwise_distance(vertices, chunk_size=256):
    """Return the largest Euclidean distance between any two rows of ``vertices`` (N, 3).

    Exact O(N^2) search, processed in row chunks so memory stays ~chunk_size * N floats.
    """
    vertices = np.asarray(vertices, dtype=np.float64)
    if len(vertices) < 2:
        return 0.0
    sq_norms = np.einsum('ij,ij->i', vertices, vertices)
    best_sq = 0.0
    for start in range(0, len(vertices), chunk_size):
        block = vertices[start:start + chunk_size]
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
        sq_dists = sq_norms[start:start + chunk_size, None] + sq_norms[None, :] - 2.0 * (block @ vertices.T)
        best_sq = max(best_sq, float(sq_dists.max()))
    # Clamp tiny negative round-off before the square root.
    return float(np.sqrt(max(best_sq, 0.0)))


def load_models_info(models_dir, obj_ids, num_points=1000):
    """
    Load sampled surface points and the diameter of each object's .ply model.

    Args:
        models_dir: directory containing ``obj_XX.ply`` files (coordinates in mm).
        obj_ids: object IDs to load; duplicates are ignored.
        num_points: number of surface points sampled per model for the ADD metric.

    Returns:
        (point_cache, diameters): dicts keyed by object ID with the sampled points
        (in metres) and the object diameter (in metres). IDs whose file is missing
        are reported and left out of both dicts.
    """
    point_cache = {}
    diameters = {}

    unique_ids = list(set(obj_ids))
    print(f"⏳ Loading info for {len(unique_ids)} models...")

    for oid in tqdm(unique_ids, desc="Mesh Analysis"):
        filename = f"obj_{int(oid):02d}.ply"
        path = os.path.join(models_dir, filename)

        if os.path.exists(path):
            mesh = trimesh.load(path)
            # 1. Sample points for ADD metric calculation
            points, _ = trimesh.sample.sample_surface(mesh, num_points)
            point_cache[oid] = points / 1000.0  # Convert mm to Meters

            # 2. Diameter = largest distance between two mesh vertices, as used by the
            #    ADD-0.1d threshold. (The bounding-box diagonal only upper-bounds it,
            #    which would loosen the threshold and inflate accuracy.)
            diameters[oid] = _max_pairwise_distance(mesh.vertices) / 1000.0  # Meters
        else:
            print(f"⚠️ Missing model file: {path}")

    return point_cache, diameters

# ==========================================
# 2. COMPREHENSIVE EVALUATION (POINTNET VERSION)
# ==========================================
def _collect_pose_errors(model, dataloader, device, points_dict, diameters_dict, symmetric_ids):
    """Run ``model`` over ``dataloader`` and gather per-object ADD/ADD-S errors and 0.1d hits.

    Returns (errors_dict, accuracy_stats), both keyed by integer object ID.
    """
    errors_dict = defaultdict(list)
    accuracy_stats = defaultdict(lambda: {"correct": 0, "total": 0})

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            # PointNet consumes point clouds + centroids, not images
            points = batch['points'].to(device)           # (B, 3, N)
            centroids = batch['centroid'].to(device)      # (B, 3)

            # Predict rotation and residual translation; Abs_Trans = Centroid + Residual
            pred_quats, pred_t_res = model(points)
            pred_trans_abs = centroids + pred_t_res

            # One device->host copy per batch instead of one per sample and tensor.
            pred_quats_np = pred_quats.cpu().numpy()
            pred_trans_np = pred_trans_abs.cpu().numpy()
            gt_quats_np = batch['rotation'].cpu().numpy()        # (B, 4)
            gt_trans_np = batch['gt_translation'].cpu().numpy()  # (B, 3) - Absolute GT
            obj_ids = [int(oid) for oid in batch['object_id']]

            for i, curr_id in enumerate(obj_ids):
                if curr_id not in points_dict:
                    continue

                # ADD-S for symmetric objects, ADD for the others
                metric = compute_ADDs_metric_quaternion if curr_id in symmetric_ids else compute_ADD_metric_quaternion

                # Full 6D pose: the PREDICTED translation is evaluated, not the GT one.
                add_error = metric(
                    model_points=points_dict[curr_id],
                    gt_quat=gt_quats_np[i],
                    gt_translation=gt_trans_np[i],
                    pred_quat=pred_quats_np[i],
                    pred_translation=pred_trans_np[i],
                )
                errors_dict[curr_id].append(add_error)

                # Accuracy: correct when the error is below 10% of the object diameter
                accuracy_stats[curr_id]["total"] += 1
                if add_error < diameters_dict[curr_id] * 0.1:
                    accuracy_stats[curr_id]["correct"] += 1

    return errors_dict, accuracy_stats


def _build_report(errors_dict, accuracy_stats, diameters_dict, id_to_name):
    """Summarise per-object and global ADD error / accuracy into a DataFrame ending with a GLOBAL row."""
    results_data = []
    all_errors = []
    total_correct = 0
    total_count = 0

    for oid in sorted(errors_dict):
        errors = errors_dict[oid]
        all_errors.extend(errors)

        stats = accuracy_stats[oid]
        acc_perc = (stats["correct"] / stats["total"]) * 100.0 if stats["total"] > 0 else 0.0
        total_correct += stats["correct"]
        total_count += stats["total"]

        results_data.append({
            "Object ID": oid,
            "Class Name": id_to_name.get(oid, "Unknown"),
            "Diameter (cm)": round(diameters_dict[oid] * 100.0, 2),
            "Mean ADD Error (cm)": round(np.mean(errors) * 100.0, 2),
            "Accuracy (%)": round(acc_perc, 2),
            "Samples": stats['total'],
        })

    global_mean_error_cm = np.mean(all_errors) * 100.0 if all_errors else 0.0
    global_accuracy = (total_correct / total_count * 100.0) if total_count > 0 else 0.0

    results_data.append({
        "Object ID": "GLOBAL",
        "Class Name": "ALL",
        "Diameter (cm)": "-",
        "Mean ADD Error (cm)": round(global_mean_error_cm, 2),
        "Accuracy (%)": round(global_accuracy, 2),
        "Samples": total_count,
    })
    return pd.DataFrame(results_data)


def evaluate_comprehensive(model, dataloader, device, models_dir, output_csv="evaluation_results.csv"):
    """
    Evaluate the PointNet pose model with ADD / ADD-S errors and ADD-0.1d accuracy.

    Args:
        model: network returning (quaternion, translation residual) for a point batch.
        dataloader: yields dicts with 'points', 'centroid', 'rotation',
            'gt_translation' and 'object_id'.
        device: device the inputs are moved to.
        models_dir: directory containing the ``obj_XX.ply`` meshes.
        output_csv: path where the report is written.

    Returns:
        pandas.DataFrame: one row per evaluated object plus a GLOBAL row; it is also
        printed and saved to ``output_csv``.
    """
    model.eval()

    # --- MAPPING ID TO NAMES (LineMOD Standard) ---
    id_to_name = {
        1: 'ape', 2: 'benchvise', 4: 'camera', 5: 'can', 6: 'cat',
        8: 'driller', 9: 'duck', 10: 'eggbox', 11: 'glue',
        12: 'holepuncher', 13: 'iron', 14: 'lamp', 15: 'phone'
    }

    # Load mesh data (points and diameters) for every known object
    points_dict, diameters_dict = load_models_info(models_dir, list(id_to_name.keys()))

    # Objects requiring ADD-S (Symmetric objects)
    SYMMETRIC_OBJECTS = [10, 11]  # Eggbox, Glue

    print("\n🚀 Starting Comprehensive Benchmark (ADD Error + ADD-0.1d Accuracy)...")

    errors_dict, accuracy_stats = _collect_pose_errors(
        model, dataloader, device, points_dict, diameters_dict, SYMMETRIC_OBJECTS
    )
    df = _build_report(errors_dict, accuracy_stats, diameters_dict, id_to_name)

    # Print and Save (display options are scoped, not changed globally)
    print("\n" + "=" * 80)
    print("FINAL EVALUATION REPORT (POINTNET)")
    print("=" * 80)
    with pd.option_context('display.max_rows', None,
                           'display.max_columns', None,
                           'display.width', 1000):
        print(df.to_string(index=False))
    print("=" * 80)

    df.to_csv(output_csv, index=False)
    print(f"✅ Results saved to {output_csv}")
    return df

In [None]:
# --- USAGE ---
# Self-contained imports: this cell must not depend on names leaked from earlier cells.
import os
from collections import OrderedDict

# 1. Define Paths
# Meshes live under the dataset root; deriving the path keeps it in sync with dataset_root.
# MODELS_ROOT = '/kaggle/working/Linemod_preprocessed/models'
MODELS_ROOT = os.path.join(dataset_root, 'models')
checkpoint_path = os.path.join(checkpoint_dir, "best_model.pth")

# 2. Load Checkpoint
print(f"📂 Loading checkpoint from: {checkpoint_path}")
data = torch.load(checkpoint_path, map_location=device)
state_dict = data['model_state_dict']

# 3. Clean State Dict (Remove 'module.' prefix from DataParallel)
new_state_dict = OrderedDict(
    (k[len('module.'):] if k.startswith('module.') else k, v)
    for k, v in state_dict.items()
)

# 4. Initialize and Load Model (a fresh, unwrapped network)
model = PointNetPoseModel()
model.load_state_dict(new_state_dict)
model.to(device)

# 5. Run Evaluation (the report is printed and also saved next to the checkpoint)
results_df = evaluate_comprehensive(
    model,
    test_loader,
    device,
    MODELS_ROOT,
    output_csv=os.path.join(checkpoint_dir, 'linemod_results.csv')
)