In [None]:
from google.colab import drive
drive.mount('/content/drive')

!pip install open3d
!pip install ruamel.yaml
!pip install trimesh
!pip install ultralytics
!pip install pcl
!pip install pyyaml
!pip install plotly

# Optional: to skip YOLO fine-tuning (Block 3), uncomment the line below and point it at your fine-tuned YOLO archive
#!unzip -q "/content/drive/MyDrive/2024-25_S2/01TXFSM - MLADL/04_3DPE_PROJECT/04_COLAB_NOTEBOOK/01_DETECTION/02_YOLO/YOLOv11_finetuning.zip"

In [None]:
# ==============================================================================
# BLOCK 1: LIBRARIES IMPORTING AND DATASET PREPROCESSING FOR YOLO FINE-TUNING
# ==============================================================================

import os
import yaml
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import open3d as o3d
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import json
import glob
import re
import shutil
import sys
from sklearn.model_selection import train_test_split
from PIL import Image
from matplotlib.patches import Rectangle
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
from scipy.spatial.transform import Rotation as R
from scipy.spatial import cKDTree
import trimesh
import subprocess
import datetime
import gc
import random
from tqdm import tqdm
from ultralytics import YOLO
import warnings
warnings.filterwarnings("ignore")

# ==============================================================================
# CONFIGURATION - MODIFY THESE PATHS FOR YOUR SETUP
# ==============================================================================
# --- Dataset locations --------------------------------------------------------
# Point LINEMOD_ZIP_PATH at the archive inside your own Google Drive.
LINEMOD_ZIP_PATH = "/content/drive/MyDrive/2024-25_S2/01TXFSM - MLADL/04_3DPE_PROJECT/03_Docs/00_DATASET/Linemod_preprocessed.zip"
LINEMOD_ROOT = "/content/datasets/linemod"
# The YOLO-formatted dataset lives in a sub-folder of LINEMOD_ROOT.
YOLO_DATASET_ROOT = f"{LINEMOD_ROOT}/Linemod_preprocessed_yolo_2"

# --- 3D model locations -------------------------------------------------------
PLY_MODELS_DIR = "/content/Linemod_preprocessed/models"
FINAL_MODEL_DIR = f"{YOLO_DATASET_ROOT}/pose_models"

# --- Object catalogue ---------------------------------------------------------
# Objects removed in DenseFusion preprocessing.
OBJECTS_TO_SKIP = ["03", "07"]
# Zero-padded folder ids 01..15 minus the skipped objects.
OBJECT_IDS = [
    f"{number:02d}"
    for number in range(1, 16)
    if f"{number:02d}" not in OBJECTS_TO_SKIP
]
# Human-readable names, in the same order as OBJECT_IDS.
OBJECT_NAMES = (
    "ape benchvise camera can cat driller duck "
    "eggbox glue holepuncher iron lamp phone"
).split()

# --- Split ratios (following the DenseFusion paper) ---------------------------
TRAIN_RATIO, VAL_RATIO, TEST_RATIO = 0.7, 0.1, 0.2
RANDOM_SEED = 42

# --- Depth handling -----------------------------------------------------------
INCLUDE_DEPTH = True
DEPTH_SUBFOLDER = "depth"
# Filled during conversion: folder id -> depth scale factor.
DEPTH_SCALE_FACTORS = dict()

# ==============================================================================
# INITIAL SETUP
# ==============================================================================
def extract_and_setup_dataset():
    """Extract the Linemod archive and place its ``data`` folder under LINEMOD_ROOT.

    The archive at LINEMOD_ZIP_PATH is unpacked into ``/content`` (the location
    the rest of this notebook hard-codes for ``Linemod_preprocessed``), then
    ``/content/Linemod_preprocessed/data`` is moved to ``LINEMOD_ROOT/data``.

    The function is safe to re-run: ``unzip`` overwrites without prompting and a
    ``data`` folder left by an earlier run is replaced by the fresh extraction.

    Raises:
        FileNotFoundError: if the archive is missing, or if the archive does not
            contain ``Linemod_preprocessed/data``.
        subprocess.CalledProcessError: if ``unzip`` exits with a non-zero status.
    """
    print("Setting up Linemod dataset...")

    extraction_dir = "/content"
    extracted_data_dir = os.path.join(extraction_dir, "Linemod_preprocessed", "data")
    target_data_dir = os.path.join(LINEMOD_ROOT, "data")

    if not os.path.isfile(LINEMOD_ZIP_PATH):
        raise FileNotFoundError(f"Linemod archive not found: {LINEMOD_ZIP_PATH}")

    # Argument list (no shell) so spaces in the Drive path need no quoting.
    # -o overwrites silently; without it a re-run blocks on an interactive prompt.
    subprocess.run(
        ["unzip", "-q", "-o", LINEMOD_ZIP_PATH, "-d", extraction_dir],
        check=True,
    )

    if not os.path.isdir(extracted_data_dir):
        raise FileNotFoundError(
            f"Expected folder missing after extraction: {extracted_data_dir}"
        )

    os.makedirs(LINEMOD_ROOT, exist_ok=True)

    # Drop the copy from a previous run only after the fresh extraction is
    # confirmed, so a failed unzip never destroys an existing dataset.
    if os.path.isdir(target_data_dir):
        shutil.rmtree(target_data_dir)

    # The destination does not exist at this point, so shutil.move renames the
    # folder instead of nesting it as data/data.
    shutil.move(extracted_data_dir, target_data_dir)

    print("✓ Dataset extraction completed")

# ==============================================================================
# DATASET CONVERSION UTILITIES
# ==============================================================================

def create_yolo_directories():
    """Create the images/labels (and optional depth/metadata) folders under YOLO_DATASET_ROOT."""
    splits = ('train', 'val', 'test')

    # images/* first, then labels/*, each in train/val/test order.
    relative_dirs = [
        f"{kind}/{split}"
        for kind in ('images', 'labels')
        for split in splits
    ]

    if INCLUDE_DEPTH:
        relative_dirs += [f"depth/{split}" for split in splits]
        relative_dirs.append('metadata')

    for relative_dir in relative_dirs:
        target_dir = os.path.join(YOLO_DATASET_ROOT, relative_dir)
        os.makedirs(target_dir, exist_ok=True)

    print("✓ YOLO directory structure created")

def get_all_samples(dataset_root):
    """Return ``(folder_id, sample_id)`` pairs for every RGB ``.png`` of the configured objects.

    Folders are visited in OBJECT_IDS order; sample ids within a folder are the
    integer file stems, sorted ascending. Folders without an ``rgb`` directory
    are skipped.
    """
    collected = []

    for folder_id in OBJECT_IDS:
        rgb_dir = os.path.join(dataset_root, 'data', folder_id, "rgb")
        if not os.path.exists(rgb_dir):
            continue

        frame_numbers = [
            int(os.path.splitext(file_name)[0])
            for file_name in os.listdir(rgb_dir)
            if file_name.endswith('.png')
        ]
        frame_numbers.sort()

        for frame_number in frame_numbers:
            collected.append((folder_id, frame_number))

    return collected

def load_yaml_file(file_path):
    """Parse a YAML file and return its content, or ``None`` if anything goes wrong.

    Any exception (missing file, parse error, ...) is reported on stdout and
    swallowed, so callers must handle a ``None`` result.
    """
    parsed = None
    try:
        with open(file_path, 'r') as handle:
            # NOTE(review): FullLoader is used here; only feed it trusted files.
            parsed = yaml.load(handle, Loader=yaml.FullLoader)
    except Exception as err:
        print(f"Error loading {file_path}: {err}")
    return parsed

def convert_bbox_to_yolo(bbox_linemod, image_width, image_height):
    """Convert a Linemod box to a normalized YOLO box.

    Args:
        bbox_linemod: ``[x_min, y_min, width, height]`` in pixels (list, tuple
            or numpy array of length 4).
        image_width: image width in pixels.
        image_height: image height in pixels.

    Returns:
        ``[center_x, center_y, width, height]`` normalized to ``[0, 1]``, or
        ``None`` when the input is malformed, the image size is not positive,
        or the box has no area inside the image.

    The box corners are clipped to the image before the centre and size are
    derived. Clamping centre and size independently would keep the original
    size for a box that crosses the border, producing a box that no longer
    matches the visible region.
    """
    if not isinstance(bbox_linemod, (list, tuple, np.ndarray)) or len(bbox_linemod) != 4:
        return None
    if image_width <= 0 or image_height <= 0:
        return None

    x_min, y_min, width_px, height_px = (float(value) for value in bbox_linemod)

    # Clip the two corners to the image rectangle.
    left = min(max(x_min, 0.0), float(image_width))
    top = min(max(y_min, 0.0), float(image_height))
    right = min(max(x_min + width_px, 0.0), float(image_width))
    bottom = min(max(y_min + height_px, 0.0), float(image_height))

    clipped_width = right - left
    clipped_height = bottom - top

    # Degenerate, inverted, or fully off-image boxes carry no usable label.
    if clipped_width <= 0 or clipped_height <= 0:
        return None

    return [
        (left + clipped_width / 2.0) / image_width,
        (top + clipped_height / 2.0) / image_height,
        clipped_width / image_width,
        clipped_height / image_height,
    ]

def load_depth_scale_factor(dataset_root, folder_id):
    """Return the ``depth_scale`` entry of ``data/<folder_id>/camera.yml``.

    Falls back to ``1000.0`` when the file is missing, cannot be parsed, is
    empty, or does not contain a ``depth_scale`` key.
    """
    fallback_scale = 1000.0  # Default: mm to meters
    camera_yml = os.path.join(dataset_root, 'data', folder_id, 'camera.yml')

    if not os.path.exists(camera_yml):
        return fallback_scale

    try:
        camera_info = load_yaml_file(camera_yml)
        if camera_info:
            return camera_info.get('depth_scale', 1000.0)
    except Exception:
        # Best effort: any failure simply selects the default below.
        pass

    return fallback_scale

def process_depth_image(depth_src_path, depth_target_path):
    """Create (or replace) a symlink at ``depth_target_path`` pointing to ``depth_src_path``.

    Returns:
        ``True`` when the link was created, ``False`` when the filesystem
        operation failed (e.g. missing target directory, no symlink support).
    """
    try:
        # lexists() also reports dangling symlinks left by an earlier run;
        # exists() follows the link, returns False for those, and the
        # subsequent symlink() would then fail with FileExistsError.
        if os.path.lexists(depth_target_path):
            os.remove(depth_target_path)
        os.symlink(depth_src_path, depth_target_path)
        return True
    except OSError:
        return False

# ==============================================================================
# MAIN CONVERSION PROCESS
# ==============================================================================

def _split_linemod_samples(all_samples):
    """Split samples into (train, val, test) lists using the configured ratios and seed."""
    train_samples, held_out_samples = train_test_split(
        all_samples, train_size=TRAIN_RATIO, random_state=RANDOM_SEED, shuffle=True
    )

    if TEST_RATIO > 0:
        # Share of the held-out part that becomes the validation split.
        val_size_relative = VAL_RATIO / (VAL_RATIO + TEST_RATIO)
        val_samples, test_samples = train_test_split(
            held_out_samples, train_size=val_size_relative, random_state=RANDOM_SEED
        )
    else:
        val_samples, test_samples = held_out_samples, []

    return train_samples, val_samples, test_samples


def _annotation_to_label_lines(annotation, folder_id, sample_id, class_map,
                               image_width, image_height):
    """Build the label rows for one ground-truth annotation.

    Returns three rows (bbox, rotation, translation) for an annotation that
    belongs to ``folder_id`` and is complete, otherwise an empty list.
    """
    obj_id = annotation.get('obj_id')
    if obj_id is None:
        bbox = annotation.get('obj_bb')
        print(f"Warning: Invalid annotation for sample {sample_id}: obj_id={obj_id}, bbox={bbox}")
        return []

    # Only the annotation of the folder's own object is exported.
    folder_match = f"{obj_id:02d}"
    if folder_match != folder_id:
        return []

    bbox = annotation.get('obj_bb')
    if bbox is None:
        print(f"Warning: Invalid annotation for sample {sample_id}: obj_id={obj_id}, bbox={bbox}")
        return []

    # Validate the pose BEFORE unpacking it, so one bad entry cannot abort the run.
    cam_R = annotation.get('cam_R_m2c')
    cam_t = annotation.get('cam_t_m2c')
    if cam_R is None or cam_t is None or len(cam_R) != 9 or len(cam_t) != 3:
        print(f"Warning: Invalid pose for sample {sample_id}: cam_R_m2c={cam_R}, cam_t_m2c={cam_t}")
        return []

    class_id = class_map.get(folder_match)
    if class_id is None:
        print(f"Warning: Object ID {obj_id} ({folder_match}) not in OBJECT_IDS for sample {sample_id}")
        return []

    yolo_bbox = convert_bbox_to_yolo(bbox, image_width, image_height)
    if not yolo_bbox:
        return []

    # NOTE(review): the rotation and translation rows are not part of the
    # standard 5-column YOLO detection label format; presumably a pose dataset
    # loader elsewhere in the notebook reads them - confirm before changing.
    bbox_row = f"{class_id} " + " ".join(f"{value:.6f}" for value in yolo_bbox)
    rotation_row = f"{class_id} " + " ".join(f"{value:.6f}" for value in cam_R)
    translation_row = f"{class_id} " + " ".join(f"{value:.6f}" for value in cam_t)
    return [bbox_row, rotation_row, translation_row]


def _link_or_copy_image(img_src_path, img_target_path):
    """Symlink the image into the YOLO tree, copying it if linking fails.

    Returns ``True`` when the target exists afterwards, ``False`` otherwise.
    """
    try:
        # lexists() also catches dangling links left by an earlier run.
        if os.path.lexists(img_target_path):
            os.remove(img_target_path)
        os.symlink(img_src_path, img_target_path)
        return True
    except OSError as e:
        print(f"Warning: Could not create symlink for {img_src_path}: {e}")

    # Try copying instead
    try:
        shutil.copy2(img_src_path, img_target_path)
        return True
    except OSError as e2:
        print(f"Warning: Could not copy image {img_src_path}: {e2}")
        return False


def convert_linemod_to_yolo():
    """Convert the Linemod dataset under LINEMOD_ROOT into the YOLO layout.

    For every RGB frame of the configured objects this writes a label file,
    links (or copies) the image and, when INCLUDE_DEPTH is set, links the depth
    map into ``YOLO_DATASET_ROOT/{labels,images,depth}/<split>``. Files are
    renamed to the zero-padded global sample index. DEPTH_SCALE_FACTORS is
    filled as a side effect.

    Raises:
        ValueError: if no samples are found under ``LINEMOD_ROOT/data``.
    """
    print("Starting Linemod to YOLO conversion...")

    # Debug: Check if data directory exists
    data_dir = os.path.join(LINEMOD_ROOT, 'data')
    print(f"Checking data directory: {data_dir}")
    print(f"Data directory exists: {os.path.exists(data_dir)}")

    if os.path.exists(data_dir):
        subdirs = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]
        print(f"Found subdirectories: {subdirs}")

        # Check RGB folders
        for subdir in subdirs[:3]:  # Check first few
            rgb_dir = os.path.join(data_dir, subdir, 'rgb')
            print(f"  {subdir}/rgb exists: {os.path.exists(rgb_dir)}")
            if os.path.exists(rgb_dir):
                files = os.listdir(rgb_dir)[:5]  # Show first few files
                print(f"    Sample files: {files}")

    # Get all samples and split dataset
    all_samples = get_all_samples(LINEMOD_ROOT)
    print(f"Found {len(all_samples)} total samples")
    if not all_samples:
        # Fail with an actionable message instead of a cryptic sklearn error.
        raise ValueError(
            f"No samples found under {data_dir}; check that the dataset was extracted"
        )

    # Split dataset (following DenseFusion methodology)
    train_samples, val_samples, test_samples = _split_linemod_samples(all_samples)
    print(f"Dataset split: {len(train_samples)} train, {len(val_samples)} val, {len(test_samples)} test")

    # Create sample-to-split mapping
    sample_to_split = {}
    for split_label, members in (('train', train_samples),
                                 ('val', val_samples),
                                 ('test', test_samples)):
        for member in members:
            sample_to_split[member] = split_label

    # Load depth scale factors
    if INCLUDE_DEPTH:
        for folder_id in OBJECT_IDS:
            DEPTH_SCALE_FACTORS[folder_id] = load_depth_scale_factor(LINEMOD_ROOT, folder_id)

    # Cache YAML data: each gt.yml / info.yml is parsed once per object folder.
    gt_cache = {}
    info_cache = {}

    def get_cached_yaml(folder_id):
        if folder_id not in gt_cache:
            gt_cache[folder_id] = load_yaml_file(
                os.path.join(LINEMOD_ROOT, 'data', folder_id, 'gt.yml')
            )
        if folder_id not in info_cache:
            info_cache[folder_id] = load_yaml_file(
                os.path.join(LINEMOD_ROOT, 'data', folder_id, 'info.yml')
            )
        return gt_cache[folder_id], info_cache[folder_id]

    # Process samples
    processed_count = 0
    samples_with_annotations = 0
    total_annotations = 0
    depth_processed = 0

    # Map object folder IDs to class IDs (1-based indexing: 1 to 13)
    # NOTE(review): Ultralytics documents zero-indexed class ids; the rest of
    # this notebook appears to rely on the 1-based ids - confirm before changing.
    object_id_to_class_id = {obj_id: i + 1 for i, obj_id in enumerate(OBJECT_IDS)}

    # Linemod standard image dimensions (all images are the same size)
    image_width = 640
    image_height = 480

    print(f"Processing {len(all_samples)} samples...")
    print(f"Object ID mapping (1-based): {object_id_to_class_id}")

    for sample_idx, (folder_id, sample_id) in enumerate(all_samples):
        split_name = sample_to_split.get((folder_id, sample_id))
        if not split_name:
            print(f"Warning: Sample ({folder_id}, {sample_id}) not found in split mapping")
            continue

        # Load YAML data
        gt_data, info_data = get_cached_yaml(folder_id)
        if not gt_data or not info_data:
            continue

        # Construct file paths
        img_filename = f"{sample_id:04d}.png"
        img_src_path = os.path.join(LINEMOD_ROOT, 'data', folder_id, 'rgb', img_filename)
        if not os.path.exists(img_src_path):
            continue

        # Process annotations (a null entry is treated like a missing one)
        image_annotations = gt_data.get(sample_id, []) or []
        if not image_annotations:
            print(f"Debug: No annotations found for sample {sample_id} in folder {folder_id}")

        yolo_lines = []
        for annotation in image_annotations:
            annotation_lines = _annotation_to_label_lines(
                annotation, folder_id, sample_id, object_id_to_class_id,
                image_width, image_height,
            )
            if annotation_lines:
                yolo_lines.extend(annotation_lines)
                total_annotations += 1

        if yolo_lines:
            samples_with_annotations += 1

        # Create target paths
        new_img_name = f"{sample_idx:05d}.png"
        new_label_name = f"{sample_idx:05d}.txt"

        img_target_path = os.path.join(YOLO_DATASET_ROOT, 'images', split_name, new_img_name)
        label_target_path = os.path.join(YOLO_DATASET_ROOT, 'labels', split_name, new_label_name)

        # Write label file
        try:
            with open(label_target_path, 'w') as f:
                for line in yolo_lines:
                    f.write(line + '\n')
        except Exception as e:
            print(f"Warning: Could not write label file {label_target_path}: {e}")
            continue

        # Link image file (falls back to a copy)
        if not _link_or_copy_image(img_src_path, img_target_path):
            continue

        # Process depth if enabled
        if INCLUDE_DEPTH:
            depth_src_path = os.path.join(LINEMOD_ROOT, 'data', folder_id, DEPTH_SUBFOLDER, img_filename)
            if os.path.exists(depth_src_path):
                depth_target_path = os.path.join(YOLO_DATASET_ROOT, 'depth', split_name, new_img_name)
                if process_depth_image(depth_src_path, depth_target_path):
                    depth_processed += 1

        processed_count += 1

        # Progress reporting
        if processed_count % 2000 == 0:
            print(f"Processed {processed_count}/{len(all_samples)} samples...")

    print(f"Conversion completed:")
    print(f"  Processed: {processed_count} samples")
    print(f"  Annotations: {total_annotations}")
    print(f"  Samples with annotations: {samples_with_annotations}")
    if INCLUDE_DEPTH:
        print(f"  Depth images: {depth_processed}")

def create_data_yaml():
    """Write ``data.yaml`` (and, with depth enabled, ``metadata/depth_scales.json``).

    ``data.yaml`` lists the image split directories, optionally the depth
    split directories and the depth-scale metadata file, followed by the class
    count and class names. The JSON file stores DEPTH_SCALE_FACTORS.
    """
    def split_dir(kind, split):
        return os.path.join(YOLO_DATASET_ROOT, kind, split)

    depth_scales_json = os.path.join(YOLO_DATASET_ROOT, 'metadata', 'depth_scales.json')

    yaml_lines = [
        "# Linemod Dataset Configuration for DenseFusion",
        "# Following original paper specifications",
        "",
        "# Image directories",
        f"train: {split_dir('images', 'train')}",
        f"val: {split_dir('images', 'val')}",
        f"test: {split_dir('images', 'test')}",
    ]

    if INCLUDE_DEPTH:
        yaml_lines += [
            "",
            "# Depth directories",
            f"depth_train: {split_dir('depth', 'train')}",
            f"depth_val: {split_dir('depth', 'val')}",
            f"depth_test: {split_dir('depth', 'test')}",
            "",
            "# Depth metadata",
            f"depth_scales: {depth_scales_json}",
        ]

    # NOTE(review): the trailing comment documents 1-based class ids while
    # `names` is a plain list - confirm how consumers index it.
    yaml_lines += [
        "",
        "# Object classes (13 objects after removing 03 and 07)",
        f"nc: {len(OBJECT_IDS)}",
        f"names: {OBJECT_NAMES}",
        "",
        "# Class mapping: 1-13 (not 0-12)",
        "# Class 1: ape, Class 2: benchvise, Class 3: camera, etc.",
    ]

    with open(os.path.join(YOLO_DATASET_ROOT, 'data.yaml'), 'w') as yaml_file:
        yaml_file.write("\n".join(yaml_lines) + "\n")

    # Save depth scale factors if enabled
    if INCLUDE_DEPTH:
        with open(depth_scales_json, 'w') as json_file:
            json.dump(DEPTH_SCALE_FACTORS, json_file, indent=2)

    print(f"✓ Created data.yaml configuration")

# ==============================================================================
# 3D MODEL PROCESSING
# ==============================================================================

def process_ply_models():
    """Rename ``obj_NN.ply`` files in PLY_MODELS_DIR to consecutive numbers starting at 01.

    Files whose number string is listed in OBJECTS_TO_SKIP are left out of the
    renumbering; the remaining files are renamed in ascending numeric order.
    """
    print("Processing 3D PLY models...")

    name_pattern = re.compile(r"obj_(\d+)\.ply")

    # Collect (original number, path) for every model that is kept.
    numbered_paths = []
    for candidate_path in glob.glob(os.path.join(PLY_MODELS_DIR, "obj_*.ply")):
        hit = name_pattern.match(os.path.basename(candidate_path))
        if not hit:
            continue
        number_text = hit.group(1)
        if number_text in OBJECTS_TO_SKIP:
            continue
        numbered_paths.append((int(number_text), candidate_path))

    # Ascending order matters: each rename only ever targets a lower number.
    numbered_paths = sorted(numbered_paths, key=lambda pair: pair[0])

    for position, (_, current_path) in enumerate(numbered_paths, start=1):
        renamed_path = os.path.join(PLY_MODELS_DIR, f"obj_{position:02d}.ply")
        if current_path != renamed_path:
            os.rename(current_path, renamed_path)

    print(f"✓ Processed {len(numbered_paths)} PLY models")

def process_model_info_yaml():
    """Move ``models_info.yml`` to FINAL_MODEL_DIR and renumber it to match the kept objects.

    Entries for the objects in OBJECTS_TO_SKIP are removed and the remaining
    entries are re-keyed consecutively from 1 (ascending original key), each
    written in YAML flow style.

    Re-running is safe: when the source file has already been moved and the
    target exists, the file is assumed to be processed and is left untouched
    (processing it twice would drop two valid objects).

    Raises:
        FileNotFoundError: if the file exists neither in PLY_MODELS_DIR nor in
            FINAL_MODEL_DIR.
        ValueError: if the YAML file is empty.
    """
    os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

    source_yml = os.path.join(PLY_MODELS_DIR, "models_info.yml")
    target_yml = os.path.join(FINAL_MODEL_DIR, "models_info.yml")

    if os.path.exists(source_yml):
        shutil.move(source_yml, target_yml)
    elif os.path.exists(target_yml):
        print("✓ models_info.yml already processed; skipping")
        return
    else:
        raise FileNotFoundError(
            f"models_info.yml not found in {PLY_MODELS_DIR} or {FINAL_MODEL_DIR}"
        )

    # Load and modify YAML
    yaml_processor = YAML()
    yaml_processor.preserve_quotes = True
    yaml_processor.width = sys.maxsize  # keep every flow-style entry on one line

    with open(target_yml, "r") as f:
        data = yaml_processor.load(f)

    if data is None:
        raise ValueError(f"{target_yml} is empty")

    # Remove the skipped objects (keys are integers in the file).
    for skipped_id in OBJECTS_TO_SKIP:
        data.pop(int(skipped_id), None)

    # Renumber keys consecutively starting from 1, forcing flow style on mappings.
    new_data = CommentedMap()
    for new_key, old_key in enumerate(sorted(data), 1):
        value = data[old_key]
        if isinstance(value, dict):
            flow_entry = CommentedMap(value)
            flow_entry.fa.set_flow_style()
            new_data[new_key] = flow_entry
        else:
            new_data[new_key] = value

    # Save updated YAML
    with open(target_yml, "w") as f:
        yaml_processor.dump(new_data, f)

    print("✓ Processed models_info.yml")

def _copy_ply_files(source_dir, ply_files, target_dir):
    """Copy the named PLY files from ``source_dir`` to ``target_dir``; return how many succeeded."""
    os.makedirs(target_dir, exist_ok=True)

    copied_count = 0
    for ply_file in ply_files:
        try:
            shutil.copy2(os.path.join(source_dir, ply_file), os.path.join(target_dir, ply_file))
            copied_count += 1
            print(f"    Copied: {ply_file}")
        except Exception as e:
            # Keep going: one unreadable model should not block the others.
            print(f"    Failed to copy {ply_file}: {e}")
    return copied_count


def setup_model_directories():
    """Create the model/checkpoint folders and copy the PLY models to FINAL_MODEL_DIR/models.

    PLY_MODELS_DIR is tried first, followed by a few known alternate
    locations; the first directory that actually contains ``.ply`` files is
    used. When an alternate location is used, the module-level PLY_MODELS_DIR
    is updated to point at it.
    """
    global PLY_MODELS_DIR  # Declare global at the start

    # Only create directories we actually need
    directories = [
        'pose_models/models',      # 3D PLY models
        'trained_models',          # Saved DenseFusion models
        'checkpoints'              # Training checkpoints
    ]
    for directory in directories:
        os.makedirs(os.path.join(YOLO_DATASET_ROOT, directory), exist_ok=True)
        print(f"  Created: {directory}")

    alternate_locations = [
        "/content/Linemod_preprocessed/models",
        "/content/datasets/linemod/models",
        "/content/datasets/linemod/Linemod_preprocessed/models"
    ]
    # Configured directory first, then every alternate that differs from it.
    candidate_dirs = [PLY_MODELS_DIR] + [
        location for location in alternate_locations if location != PLY_MODELS_DIR
    ]

    ply_target = os.path.join(FINAL_MODEL_DIR, 'models')
    print(f"Target PLY directory: {ply_target}")

    for candidate_dir in candidate_dirs:
        print(f"Checking source PLY directory: {candidate_dir}")
        if not os.path.isdir(candidate_dir):
            print(f"  WARNING: Source PLY directory not found: {candidate_dir}")
            continue

        ply_files = [f for f in os.listdir(candidate_dir) if f.endswith('.ply')]
        if not ply_files:
            # An existing but empty folder must not stop the search.
            print(f"  No PLY files in: {candidate_dir}")
            continue

        print(f"Found PLY files: {ply_files}")
        if candidate_dir != PLY_MODELS_DIR:
            # Update the global variable so later cells see the real location.
            PLY_MODELS_DIR = candidate_dir

        copied_count = _copy_ply_files(candidate_dir, ply_files, ply_target)
        print(f"  Copied {copied_count}/{len(ply_files)} PLY models to: {ply_target}")

        # Verify the copy worked
        copied_files = [f for f in os.listdir(ply_target) if f.endswith('.ply')]
        print(f"  Verification: {len(copied_files)} PLY files in target directory")
        break
    else:
        print("  WARNING: No PLY models found in any known location")

    print("✓ Essential model directories created")

# ==============================================================================
# MAIN EXECUTION
# ==============================================================================

def main():
    """Run the whole preprocessing pipeline, one stage after another."""
    banner = "="*60

    print(banner)
    print("LINEMOD DATASET PREPROCESSING FOR DENSEFUSION")
    print(banner)

    # Stages run strictly in this order; each relies on the previous one.
    pipeline = (
        extract_and_setup_dataset,   # Step 1: extract and set up the dataset
        create_yolo_directories,     # Step 2: create the YOLO directory structure
        convert_linemod_to_yolo,     # Step 3: convert the dataset to YOLO format
        create_data_yaml,            # Step 4: create configuration files
        process_ply_models,          # Step 5: process the 3D models
        process_model_info_yaml,
        setup_model_directories,
    )
    for stage in pipeline:
        stage()

    print(banner)
    print(f"Dataset ready for DenseFusion training:")
    print(f"  YOLO dataset: {YOLO_DATASET_ROOT}")
    print(f"  3D models: {FINAL_MODEL_DIR}")
    print(f"  Configuration: {YOLO_DATASET_ROOT}/data.yaml")
    print(banner)
    print("\n✓ Block 1 completed: Preprocessing completed succesfully")

# Run the pipeline when this cell/script is executed directly.
if __name__ == "__main__":
    main()

In [None]:
# ==============================================================================
# BLOCK 2: SETUP AND CONFIGURATION - when using CUDA, you may need to re-run from this block after running Block 4
# ==============================================================================

# ==============================================================================
# CONFIGURATION PARAMETERS
# ==============================================================================

class Config:
    """Central configuration for the YOLO + DenseFusion pipeline.

    Holds every path and hyper-parameter the later blocks read, and offers
    helpers to prepare the runtime environment, check that the required paths
    exist, and print a short summary.
    """

    def __init__(self):
        # Paths - UPDATE THESE FOR YOUR SETUP
        self.LINEMOD_ROOT = "/content/datasets/linemod/Linemod_preprocessed_yolo_2"
        # Built from self.LINEMOD_ROOT (not the Block 1 global of the same name):
        # data.yaml is written inside the YOLO dataset folder, and this cell must
        # also work when it is re-run on its own.
        self.DATA_YAML_PATH = os.path.join(self.LINEMOD_ROOT, 'data.yaml')
        self.PLY_MODELS_DIR = f"{self.LINEMOD_ROOT}/pose_models/models"
        self.DIAMETER_INFO_PATH = f"{self.LINEMOD_ROOT}/pose_models/models_info.yml"
        self.MODELS_SAVE_DIR = "/content/drive/MyDrive/2024-25_S2/01TXFSM - MLADL/04_3DPE_PROJECT/04_COLAB_NOTEBOOK/02_POSE_ESTIMATION/00_DENSEFUSION/02_DEV/01_FINAL_DEV_20250621/01_MODEL_STATS"
        self.CHECKPOINTS_DIR = f"{self.LINEMOD_ROOT}/checkpoints"
        self.YOLO_PROJ_NAME = "YOLO_Linemod"
        self.YOLO_NAME = "yolov11s_adam_finetuning"
        # Where Block 3 is expected to save the fine-tuned weights.
        self.YOLO_SAVING_PATH = "/content/"+self.YOLO_PROJ_NAME+"/"+self.YOLO_NAME+"/weights/best.pt"
        self.YOLO_MODEL_PATH = "/content/YOLOv11_finetuning/weights/best.pt" # YOLO has been moved here for dev purposes
        self.MODELS_NAME = "EC_500_512_MLP" # change name to load model (for dev purpose e.g. XX_npoints_npatches_MLP/TRS)

        # Device configuration
        self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

        # Scale factors
        self.MODEL_SCALE_MM_TO_M = 0.001   # multiply model coordinates (mm) to get meters
        # NOTE(review): despite the name this is 1000.0, presumably used as a
        # divisor for depth values - confirm against the code that reads it.
        self.DEPTH_SCALE_MM_TO_M = 1000.0

        # Camera intrinsics (Linemod standard)
        self.K = np.array([
            [572.4114, 0,        325.2611],
            [0,        573.57043, 242.04899],
            [0,        0,        1        ]
        ], dtype=np.float32)

        # List of symmetric objects
        self.SYMMETRIC_LIST = [7,8]

        # YOLO hyperparameters
        self.YOLO_NUM_EPOCHS = 10
        self.YOLO_IMG_SIZE = 640
        self.YOLO_BATCH_SIZE = 10
        self.YOLO_LR = 0.008

        # Model configuration
        self.USE_SEGMENTATION = True
        self.MAX_EVAL_SAMPLES = 100
        # Transformer fusion options
        self.USE_TRANSFORMER_FUSION = False
        self.TRANSFORMER_HEADS = 2
        self.TRANSFORMER_LAYERS = 4
        self.TRANSFORMER_DIM = 128
        self.TRANSFORMER_DROPOUT = 0.1

        # Model hyperparameters
        self.NUM_POINTS = 500
        self.PATCH_SIZE = 512

        # Training hyperparameters
        self.BATCH_SIZE = 12
        self.NUM_EPOCHS = 15
        self.LEARNING_RATE = 1e-4
        self.USE_MIXED_PRECISION = True
        self.GRADIENT_ACCUMULATION_STEPS = 2

    def setup_environment(self):
        """Apply cuDNN settings, environment variables and the 'spawn' start method."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # Favour speed over bit-exact reproducibility.
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False

        # Set environment variables
        os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
        os.environ['TORCH_USE_CUDA_DSA'] = '1'

        # Multiprocessing setup
        import torch.multiprocessing as mp
        try:
            if mp.get_start_method(allow_none=True) != 'spawn':
                mp.set_start_method('spawn', force=True)
        except RuntimeError:
            # Start method could not be changed; keep whatever is active.
            pass

    def verify_paths(self):
        """Print the status of every required path.

        Returns:
            ``True`` when all paths exist, ``False`` otherwise.
        """
        paths = {
            'LINEMOD dataset': self.LINEMOD_ROOT,
            'YOLO model': self.YOLO_MODEL_PATH,
            'PLY models': self.PLY_MODELS_DIR,
            'Diameter info': self.DIAMETER_INFO_PATH
        }

        all_good = True
        for name, path in paths.items():
            if os.path.exists(path):
                print(f"✓ {name}: {path}")
            else:
                print(f"✗ {name} NOT FOUND: {path}")
                all_good = False

        return all_good

    def print_config(self):
        """Print the device plus the main training and feature settings."""
        print("="*60)
        print("DENSEFUSION CONFIGURATION")
        print("="*60)
        print(f"Device: {self.DEVICE}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name()}")

        print(f"\nTraining Configuration:")
        print(f"  Batch size: {self.BATCH_SIZE}")
        print(f"  Epochs: {self.NUM_EPOCHS}")
        print(f"  Learning rate: {self.LEARNING_RATE}")
        print(f"  Points per sample: {self.NUM_POINTS}")
        print(f"  Patch size: {self.PATCH_SIZE}")

        print(f"\nFeatures:")
        print(f"  Segmentation: {self.USE_SEGMENTATION}")
        print(f"  Mixed precision: {self.USE_MIXED_PRECISION}")

# Initialize configuration: `config` is the shared object read by the later blocks
config = Config()
# Apply cuDNN / environment / multiprocessing settings, then show a summary
config.setup_environment()
config.print_config()

# Verify paths: a missing path only produces a warning, execution continues
if not config.verify_paths():
    print("\n⚠ Please update the paths in the Config class before proceeding")

print("\n✓ Block 2 completed: Configuration ready")

In [None]:
# ==============================================================================
# BLOCK 3: YOLO FINE-TUNING (skip if unzipped finetuned yolov11 at start)
# ==============================================================================

# We fine-tune the pretrained YOLOv11s checkpoint, as v11 is the current standard.
model = YOLO('yolo11s.pt')
print(f"Model loaded: {model.model.__class__.__name__}")

# Ultralytics takes a GPU index or "cpu" for `device`; derive it from the
# shared configuration instead of relying on an undefined global.
yolo_device = 0 if config.DEVICE.type == "cuda" else "cpu"

results = model.train(
    data=config.DATA_YAML_PATH,
    epochs=config.YOLO_NUM_EPOCHS,
    imgsz=config.YOLO_IMG_SIZE,
    batch=config.YOLO_BATCH_SIZE,
    optimizer='Adam',
    device=yolo_device,
    lr0=config.YOLO_LR,
    patience=3,
    project=config.YOLO_PROJ_NAME,
    name=config.YOLO_NAME,
    cache='disk',
)
print("\nTraining finished.")

# Fine-tuned YOLO evaluation
# Prefer the path recorded by the trainer: the run folder can differ from the
# configured one when a run with the same name already exists. Fall back to
# the configured saving path otherwise.
best_weights_path = config.YOLO_SAVING_PATH
trainer_best = getattr(getattr(model, "trainer", None), "best", None)
if trainer_best is not None and os.path.exists(str(trainer_best)):
    best_weights_path = str(trainer_best)

best_model = YOLO(best_weights_path)
metrics = best_model.val(
    data=config.DATA_YAML_PATH,
    split="test",
)

# Display results
print(metrics.results_dict)
print("\n✓ Block 3 completed: YOLO finetuning complete")

In [None]:
# ==============================================================================
# BLOCK 4: UTILITY FUNCTIONS
# ==============================================================================

def cleanup_memory():
    """Release cached CUDA memory (when CUDA is available) and run the garbage collector."""
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    gc.collect()

def load_dataset_config(linemod_root):
    """Read ``<linemod_root>/data.yaml`` and make its split paths absolute.

    Relative string entries for the image and depth splits are joined onto
    ``linemod_root``; absolute or non-string entries are left as they are.
    """
    with open(os.path.join(linemod_root, 'data.yaml'), 'r') as stream:
        dataset_conf = yaml.safe_load(stream)

    path_keys = ('train', 'val', 'test', 'depth_train', 'depth_val', 'depth_test')
    for path_key in path_keys:
        if path_key not in dataset_conf:
            continue
        entry = dataset_conf[path_key]
        # Convert relative paths to absolute
        if isinstance(entry, str) and not os.path.isabs(entry):
            dataset_conf[path_key] = os.path.join(linemod_root, entry)

    return dataset_conf

def load_model_diameters(diameter_yml_path):
    """Return ``{class_id - 1: diameter * config.MODEL_SCALE_MM_TO_M}`` from a models-info YAML file.

    Only entries that are mappings with a ``diameter`` key are kept; the YAML
    keys are shifted by one to obtain 0-based class ids.
    """
    with open(diameter_yml_path, 'r') as stream:
        raw_info = yaml.safe_load(stream)

    return {
        int(class_key) - 1: float(entry['diameter']) * config.MODEL_SCALE_MM_TO_M
        for class_key, entry in raw_info.items()
        if isinstance(entry, dict) and 'diameter' in entry
    }

def load_yolo_model(model_path):
    """Load a YOLO model and smoke-test it on a random 640x640 image.

    Returns the model, or ``None`` (after printing the error) when loading or
    the test inference fails.
    """
    try:
        detector = YOLO(model_path)
        # One inference on a dummy frame proves the weights are usable.
        dummy_frame = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
        detector(dummy_frame, verbose=False)
    except Exception as err:
        print(f"✗ Failed to load YOLO model: {err}")
        return None
    else:
        print(f"✓ YOLO model loaded successfully")
        return detector

def decompose_pose_numpy(pose_numpy):
    """Split a pose ``[tx, ty, tz, qw, qx, qy, qz]`` into rotation and translation.

    Args:
        pose_numpy: NumPy array holding the translation followed by a
            scalar-first quaternion.

    Returns:
        ``(rotation, translation)`` where ``rotation`` is a 3x3 matrix and
        ``translation`` is ``pose_numpy[:3]``. A quaternion whose norm is below
        1e-6 yields the identity rotation.
    """
    translation = pose_numpy[:3]
    quat_wxyz = pose_numpy[3:]

    magnitude = np.linalg.norm(quat_wxyz)
    if magnitude < 1e-6:
        # Degenerate quaternion: fall back to "no rotation".
        return np.identity(3), translation

    unit_wxyz = quat_wxyz / magnitude
    # SciPy wants scalar-last order, so move w from the front to the back.
    unit_xyzw = unit_wxyz[[1, 2, 3, 0]]
    return R.from_quat(unit_xyzw).as_matrix(), translation

def compute_add_metric(pred_pose_numpy, gt_pose_numpy, model_vertices):
    """ADD: mean distance between corresponding model points under two poses.

    Args:
        pred_pose_numpy: Predicted pose ``[tx, ty, tz, qw, qx, qy, qz]``.
        gt_pose_numpy: Ground-truth pose in the same layout.
        model_vertices: ``(M, 3)`` model points; they are multiplied by
            ``config.MODEL_SCALE_MM_TO_M`` before being transformed.

    Returns:
        The mean point-to-point distance, or ``inf`` when the model is missing
        or empty, or when any step raises (the error is printed).
    """
    has_vertices = model_vertices is not None and model_vertices.shape[0] != 0
    if not has_vertices:
        return float('inf')

    def _place(rotation, translation, cloud):
        # Rigidly move an (M, 3) cloud: rotate, then translate.
        return (rotation @ cloud.T).T + translation

    try:
        rot_pred, trans_pred = decompose_pose_numpy(pred_pose_numpy)
        rot_gt, trans_gt = decompose_pose_numpy(gt_pose_numpy)

        cloud_m = model_vertices * config.MODEL_SCALE_MM_TO_M

        offsets = _place(rot_pred, trans_pred, cloud_m) - _place(rot_gt, trans_gt, cloud_m)
        return np.mean(np.linalg.norm(offsets, axis=1))

    except Exception as e:
        print(f"ADD computation error: {e}")
        return float('inf')

def compute_add_s_metric(pred_pose_numpy, gt_pose_numpy, model_vertices):
    """ADD-S: mean closest-point distance between the model under two poses.

    Unlike ADD, every predicted point is matched to its nearest ground-truth
    point rather than to the same vertex.

    Args:
        pred_pose_numpy: Predicted pose ``[tx, ty, tz, qw, qx, qy, qz]``.
        gt_pose_numpy: Ground-truth pose in the same layout.
        model_vertices: ``(M, 3)`` model points; they are multiplied by
            ``config.MODEL_SCALE_MM_TO_M`` before being transformed.

    Returns:
        The mean nearest-neighbour distance, or ``inf`` when the model is
        missing or empty, or when any step raises (a message is printed in
        both cases).
    """
    model_missing = model_vertices is None or model_vertices.shape[0] == 0
    if model_missing:
        print("ERROR: INVALID 3D MODEL")
        return float('inf')

    def _place(rotation, translation, cloud):
        # Rigidly move an (M, 3) cloud: rotate, then translate.
        return (rotation @ cloud.T).T + translation

    try:
        rot_pred, trans_pred = decompose_pose_numpy(pred_pose_numpy)
        rot_gt, trans_gt = decompose_pose_numpy(gt_pose_numpy)

        cloud_m = model_vertices * config.MODEL_SCALE_MM_TO_M
        pred_cloud = _place(rot_pred, trans_pred, cloud_m)  # [M, 3]
        gt_cloud = _place(rot_gt, trans_gt, cloud_m)        # [M, 3]

        # Index the ground-truth cloud once, then look up each predicted point.
        nearest_dist, _ = cKDTree(gt_cloud).query(pred_cloud, k=1)
        return np.mean(nearest_dist)

    except Exception as e:
        print(f"ADD-S computation error: {e}")
        return float('inf')

def compute_add_metrics_with_thresholds(pred_pose, gt_pose, class_id, sym_list, model_vertices, diameter=None):
    """Evaluate ADD (or ADD-S for symmetric classes) against success thresholds.

    Args:
        pred_pose: Predicted pose ``[tx, ty, tz, qw, qx, qy, qz]``.
        gt_pose: Ground-truth pose in the same layout.
        class_id: Object class; ADD-S is used when it is in ``sym_list``.
        sym_list: Container of class ids treated as symmetric.
        model_vertices: Model points forwarded to the metric function.
        diameter: Optional object diameter; when positive, relative
            thresholds (5 %, 10 %, 20 % of the diameter) are reported too.

    Returns:
        dict with ``add_value``, the absolute ``add_success_{2,5,10}cm`` flags
        and, if a positive diameter was given, ``diameter`` plus the
        ``add_success_{5,10,20}p`` flags.
    """
    metric_fn = compute_add_s_metric if class_id in sym_list else compute_add_metric
    add_value = metric_fn(pred_pose, gt_pose, model_vertices)

    results = {"add_value": add_value}
    for label, limit in (("2cm", 0.02), ("5cm", 0.05), ("10cm", 0.10)):
        results[f"add_success_{label}"] = add_value < limit

    if diameter is not None and diameter > 0:
        results["diameter"] = diameter
        for label, fraction in (("5p", 0.05), ("10p", 0.10), ("20p", 0.20)):
            results[f"add_success_{label}"] = add_value < (fraction * diameter)

    return results

def compute_rotation_difference_degrees(pred_pose, gt_pose):
    """Angular error between two poses, overall and per Euler axis.

    Args:
        pred_pose: Predicted pose ``[tx, ty, tz, qw, qx, qy, qz]``.
        gt_pose: Ground-truth pose in the same layout.

    Returns:
        dict with keys ``overall`` (angle of the relative rotation, degrees),
        ``x_axis`` / ``y_axis`` / ``z_axis`` (wrapped absolute differences of
        the intrinsic-free ``'xyz'`` Euler angles, degrees), ``pred_euler`` and
        ``gt_euler`` (the Euler angles themselves).

        If a pose has no usable quaternion (too short, non-finite, or norm
        below 1e-6) or SciPy raises, the same keys are returned with
        ``overall = inf``, the per-axis differences set to 0 and both Euler
        arrays filled with NaN, so callers can always index every key.
    """
    def _to_rotation(pose):
        # Quaternion is stored scalar-first; SciPy expects [x, y, z, w].
        quat_wxyz = np.asarray(pose, dtype=float)[3:7]
        norm = np.linalg.norm(quat_wxyz)
        if quat_wxyz.shape[0] != 4 or not np.isfinite(norm) or norm < 1e-6:
            raise ValueError("pose does not contain a usable quaternion")
        w, x, y, z = quat_wxyz / norm
        return R.from_quat([x, y, z, w])

    def _wrapped_diff(a, b):
        # Shortest absolute difference between two angles given in degrees.
        delta = abs(a - b)
        return min(delta, 360 - delta)

    try:
        pred_rot = _to_rotation(pred_pose)
        gt_rot = _to_rotation(gt_pose)

        # Angle of the rotation that takes the ground truth onto the prediction.
        overall_angle_deg = np.degrees((pred_rot * gt_rot.inv()).magnitude())

        pred_euler = pred_rot.as_euler('xyz', degrees=True)
        gt_euler = gt_rot.as_euler('xyz', degrees=True)

        return {
            'overall': overall_angle_deg,
            'x_axis': _wrapped_diff(pred_euler[0], gt_euler[0]),
            'y_axis': _wrapped_diff(pred_euler[1], gt_euler[1]),
            'z_axis': _wrapped_diff(pred_euler[2], gt_euler[2]),
            'pred_euler': pred_euler,
            'gt_euler': gt_euler
        }
    except Exception:
        # Best-effort metric: signal failure through the 'overall' sentinel
        # while keeping the result schema identical to the success case.
        return {
            'overall': float('inf'),
            'x_axis': 0,
            'y_axis': 0,
            'z_axis': 0,
            'pred_euler': np.full(3, np.nan),
            'gt_euler': np.full(3, np.nan)
        }

def convert_yolo_bbox_to_pixel(bbox_normalized, image_width, image_height):
    """Turn a normalized YOLO box into integer pixel corners.

    Args:
        bbox_normalized: ``(x_center, y_center, width, height)`` as fractions
            of the image size.
        image_width: Image width in pixels.
        image_height: Image height in pixels.

    Returns:
        ``[x1, y1, x2, y2]`` with the top-left corner floored at 0 and the
        bottom-right corner capped at the image size. Coordinates are
        truncated with ``int``.
    """
    cx, cy, box_w, box_h = bbox_normalized

    cx_px, cy_px = cx * image_width, cy * image_height
    half_w = (box_w * image_width) / 2
    half_h = (box_h * image_height) / 2

    left = max(0, int(cx_px - half_w))
    top = max(0, int(cy_px - half_h))
    right = min(image_width, int(cx_px + half_w))
    bottom = min(image_height, int(cy_px + half_h))

    return [left, top, right, bottom]

def create_directories():
    """Ensure the model and checkpoint output folders from ``config`` exist.

    Existing directories are left alone; missing parents are created.
    """
    for target in (config.MODELS_SAVE_DIR, config.CHECKPOINTS_DIR):
        os.makedirs(target, exist_ok=True)

print("✓ Block 4 completed: Utility functions ready")

In [None]:
# ==============================================================================
# BLOCK 5: SEGMENTATION MODULE
# ==============================================================================

import torchvision
from torchvision import transforms

class DenseFusionSegmentationModule:
    """Refine YOLO boxes into instance masks with a COCO-pretrained Mask R-CNN.

    ``refine_detection`` always returns a binary ``uint8`` mask with the
    image's height and width plus an info dict whose ``'source'`` entry says
    whether the mask came from Mask R-CNN (``'mask_rcnn'``) or from a
    rectangular fallback built from the bounding box (``'bbox_fallback_*'``).

    Attributes:
        confidence_threshold: Minimum Mask R-CNN score for a candidate mask.
        device: Device taken from ``config.DEVICE``.
        model: The Mask R-CNN network, or ``None`` if loading failed.
        stats: Counters; ``total_calls`` equals ``successful_segmentations``
            plus ``bbox_fallbacks``.
    """

    def __init__(self, confidence_threshold=0.5):
        """Set up counters and try to load the Mask R-CNN weights.

        Args:
            confidence_threshold: Minimum score a Mask R-CNN candidate needs
                to be considered in ``_find_best_mask``.
        """
        self.confidence_threshold = confidence_threshold
        self.device = config.DEVICE
        self.model = None
        self.transform = transforms.Compose([transforms.ToTensor()])

        # Statistics tracking
        self.stats = {
            'total_calls': 0,
            'successful_segmentations': 0,
            'bbox_fallbacks': 0
        }

        self._initialize_model()

    def _initialize_model(self):
        """Load Mask R-CNN in eval mode; leave ``self.model`` as None on failure."""
        try:
            self.model = torchvision.models.detection.maskrcnn_resnet50_fpn(
                weights=torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.COCO_V1
            )
            self.model.eval()
            self.model = self.model.to(self.device)
            print("✓ Mask R-CNN model loaded successfully")
        except Exception as e:
            # Best effort: the module keeps working through bbox fallbacks.
            print(f"Failed to initialize Mask R-CNN: {e}")
            self.model = None

    def refine_detection(self, rgb_image, bbox_pixel, class_id=None):
        """Return a mask for the object inside ``bbox_pixel``.

        Args:
            rgb_image: ``(H, W, 3)`` image. Anything ``np.asarray`` accepts is
                allowed; non-``uint8`` data is multiplied by 255 and cast
                (i.e. it is assumed to be in ``[0, 1]``).
            bbox_pixel: ``[x1, y1, x2, y2]`` box in pixels.
            class_id: Accepted for interface compatibility; not used.

        Returns:
            ``(mask, info)`` where ``mask`` is an ``(H, W)`` ``uint8`` array of
            0/1 and ``info['source']`` names the origin of the mask. On an
            internal error ``info`` also carries the message under ``'error'``.
        """
        self.stats['total_calls'] += 1

        # Normalise the input once so every path below can rely on ``.shape``.
        rgb_image = np.asarray(rgb_image)
        image_shape = rgb_image.shape[:2]

        if self.model is None:
            fallback_info = {'source': 'bbox_fallback_no_model'}
        else:
            try:
                if rgb_image.dtype != np.uint8:
                    rgb_image = (rgb_image * 255).astype(np.uint8)
                pil_image = Image.fromarray(rgb_image)
                image_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)

                with torch.no_grad():
                    predictions = self.model(image_tensor)

                if len(predictions) == 0 or len(predictions[0]['masks']) == 0:
                    fallback_info = {'source': 'bbox_fallback_no_detections'}
                else:
                    best_mask, best_info = self._find_best_mask(
                        predictions[0], bbox_pixel, image_shape
                    )
                    if best_mask is not None:
                        self.stats['successful_segmentations'] += 1
                        return best_mask, best_info
                    fallback_info = {'source': 'bbox_fallback_low_overlap'}

            except Exception as e:
                fallback_info = {'source': 'bbox_fallback_error', 'error': str(e)}

        # Every non-Mask-R-CNN outcome is counted exactly once here, including
        # the "model failed to load" case.
        self.stats['bbox_fallbacks'] += 1
        return self._bbox_to_mask(image_shape, bbox_pixel), fallback_info

    def _find_best_mask(self, prediction, yolo_bbox, image_shape):
        """Pick the Mask R-CNN instance that best agrees with the YOLO box.

        Candidates below ``confidence_threshold`` are ignored. The rest are
        ranked by ``0.4 * score + 0.3 * overlap + 0.3 * IoU`` where ``overlap``
        is the share of the (image-clamped) YOLO box covered by the mask and
        ``IoU`` compares the predicted box with the YOLO box. A candidate must
        cover more than 10 % of the box to be accepted.

        Args:
            prediction: One Mask R-CNN output dict with ``'masks'``,
                ``'scores'`` and ``'boxes'``.
            yolo_bbox: ``[x1, y1, x2, y2]`` in pixels.
            image_shape: ``(H, W)`` of the image.

        Returns:
            ``(mask, info)``; ``mask`` is ``None`` and ``info`` is
            ``{'source': 'bbox_fallback'}`` when no candidate qualifies.
        """
        img_h, img_w = int(image_shape[0]), int(image_shape[1])
        target_shape = (img_h, img_w)

        # Clamp to the image first: a negative start index would otherwise
        # wrap around when slicing the mask and silently select the wrong
        # (usually empty) region.
        x1, y1, x2, y2 = map(int, yolo_bbox)
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(img_w, x2), min(img_h, y2)
        # Clip each side separately so an inverted box cannot yield a
        # positive area through the product of two negative extents.
        yolo_area = max(1, max(0, x2 - x1) * max(0, y2 - y1))

        best_mask = None
        best_score = 0
        best_info = {'source': 'bbox_fallback'}

        candidates = zip(prediction['masks'], prediction['scores'], prediction['boxes'])
        for mask_tensor, score_tensor, box in candidates:
            score = float(score_tensor)
            if score < self.confidence_threshold:
                continue

            # Soft mask -> binary mask at image resolution.
            mask_np = mask_tensor.detach().squeeze().cpu().numpy()
            if mask_np.shape != target_shape:
                mask_np = cv2.resize(mask_np, (img_w, img_h),
                                     interpolation=cv2.INTER_NEAREST)
            mask_binary = (mask_np > 0.5).astype(np.uint8)

            # Share of the YOLO box that the mask fills.
            overlap_ratio = float(mask_binary[y1:y2, x1:x2].sum()) / yolo_area

            # IoU between the Mask R-CNN box and the YOLO box.
            pred_x1, pred_y1, pred_x2, pred_y2 = (float(v) for v in box.detach().cpu().numpy())
            inter_w = min(x2, pred_x2) - max(x1, pred_x1)
            inter_h = min(y2, pred_y2) - max(y1, pred_y1)

            iou = 0.0
            if inter_w > 0 and inter_h > 0:
                inter_area = inter_w * inter_h
                pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1)
                union_area = yolo_area + pred_area - inter_area
                if union_area > 0:
                    iou = inter_area / union_area

            combined_score = score * 0.4 + overlap_ratio * 0.3 + iou * 0.3

            if combined_score > best_score and overlap_ratio > 0.1:
                best_score = combined_score
                best_mask = mask_binary
                best_info = {
                    'source': 'mask_rcnn',
                    'confidence': score,
                    'overlap': overlap_ratio,
                    'iou': iou
                }

        return best_mask, best_info

    def _bbox_to_mask(self, image_shape, bbox_pixel):
        """Fallback: rectangular mask that is 1 inside the image-clamped box."""
        mask = np.zeros(image_shape, dtype=np.uint8)
        x1, y1, x2, y2 = map(int, bbox_pixel)

        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(image_shape[1], x2), min(image_shape[0], y2)

        mask[y1:y2, x1:x2] = 1
        return mask

    def print_stats(self):
        """Print call count and success rate; prints nothing before the first call."""
        if self.stats['total_calls'] > 0:
            success_rate = self.stats['successful_segmentations'] / self.stats['total_calls']
            print(f"Segmentation Statistics:")
            print(f"  Total calls: {self.stats['total_calls']}")
            print(f"  Success rate: {success_rate:.2%}")

# Build the Mask R-CNN refiner only when the configuration enables it;
# otherwise the name is bound to None.
segmentation_module = DenseFusionSegmentationModule() if config.USE_SEGMENTATION else None

print("✓ Block 5 completed: Segmentation module ready")

In [None]:
# ==============================================================================
# BLOCK 6: DENSEFUSION MODEL ARCHITECTURE
# ==============================================================================

class RGBFeatureExtractor(nn.Module):
    """Small CNN that condenses an RGB patch into one global feature vector.

    Four 3x3 conv + batch-norm + ReLU stages (strides 1, 2, 2, 2; widths
    64, 128, 256, 512) are followed by global average pooling and a 1x1
    convolution to ``d_rgb`` channels, so the output is ``[B, d_rgb, 1, 1]``.
    """
    def __init__(self, d_rgb = 32):
        super().__init__()

        # (in_channels, out_channels, stride) of each backbone stage. Layers
        # are registered as conv1/bn1 ... conv4/bn4, in that order.
        stage_specs = ((3, 64, 1), (64, 128, 2), (128, 256, 2), (256, 512, 2))
        for idx, (c_in, c_out, stride) in enumerate(stage_specs, start=1):
            setattr(self, f"conv{idx}",
                    nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1))
            setattr(self, f"bn{idx}", nn.BatchNorm2d(c_out))

        # Global average pooling followed by a 1x1 projection to d_rgb.
        self.psp = nn.AdaptiveAvgPool2d((1, 1))
        self.final_conv = nn.Conv2d(512, d_rgb, kernel_size=1)

    def forward(self, x):
        """Map ``[B, 3, H, W]`` images to ``[B, d_rgb, 1, 1]`` features."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))

        return self.final_conv(self.psp(x))

class PointNetFeatureExtractor(nn.Module):
    """PointNet-style encoder: per-point MLP, global max-pool, then an MLP.

    Takes a point cloud ``[B, N, 3]`` and returns one ``[B, d_geo]`` vector
    per cloud.
    """
    def __init__(self, d_geo = 32):
        super().__init__()

        # Shared per-point MLP implemented as 1x1 convolutions
        # (registered as conv1..conv3, then bn1..bn3).
        widths = (3, 64, 128, 1024)
        for idx in (1, 2, 3):
            setattr(self, f"conv{idx}", nn.Conv1d(widths[idx - 1], widths[idx], 1))
        for idx in (1, 2, 3):
            setattr(self, f"bn{idx}", nn.BatchNorm1d(widths[idx]))

        # MLP applied to the pooled global descriptor.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, d_geo)

        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Encode ``[B, N, 3]`` points into ``[B, d_geo]`` features."""
        # Conv1d wants channels first: [B, N, 3] -> [B, 3, N].
        x = x.transpose(1, 2)

        for conv, norm in ((self.conv1, self.bn1),
                           (self.conv2, self.bn2),
                           (self.conv3, self.bn3)):
            x = F.relu(norm(conv(x)))

        # Order-invariant pooling over the point axis, flattened to [B, 1024].
        x = x.max(dim=2, keepdim=True).values.view(-1, 1024)

        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        return self.fc3(x)

class TransformerFuser(nn.Module):
    """Fuse one global RGB vector and one global point-cloud vector with a Transformer.

    Both inputs are projected to ``embed_dim`` and placed behind a learnable
    [CLS] token, giving a three-token sequence ``[CLS, RGB, PointCloud]``.
    After the encoder and a final layer norm, the [CLS] output is returned as
    the fused representation.

    Args:
        rgb_feature_dim: Size of the incoming RGB feature vector.
        geo_feature_dim: Size of the incoming point-cloud feature vector.
        embed_dim: Token width ``E`` inside the Transformer.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        dropout: Dropout used inside the encoder layers.
    """
    def __init__(self, rgb_feature_dim, geo_feature_dim, embed_dim=config.TRANSFORMER_DIM, num_heads=config.TRANSFORMER_HEADS, num_layers=config.TRANSFORMER_LAYERS, dropout=config.TRANSFORMER_DROPOUT):
        super().__init__()
        # The original code assigned an undefined name here, which made the
        # constructor raise NameError; keep the two real input widths instead.
        self.rgb_feature_dim = rgb_feature_dim
        self.geo_feature_dim = geo_feature_dim
        self.embed_dim = embed_dim

        # Project each modality into the common embedding space
        self.rgb_proj = nn.Linear(rgb_feature_dim, embed_dim)
        self.point_proj = nn.Linear(geo_feature_dim, embed_dim)

        # A learnable token that will act as the final fused representation
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Learnable positional embeddings for the 3 tokens: [CLS, RGB, PointCloud]
        self.pos_embedding = nn.Parameter(torch.randn(1, 3, embed_dim))

        # The Transformer Encoder (batch-first, GELU, feed-forward width 4 * E)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # The final layer normalization
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, rgb_features, point_features):
        """Fuse the two feature vectors.

        Args:
            rgb_features: ``[B, rgb_feature_dim]`` tensor.
            point_features: ``[B, geo_feature_dim]`` tensor.

        Returns:
            ``[B, embed_dim]`` tensor: the encoder output at the [CLS] position.
        """
        batch_size = rgb_features.shape[0]

        # 1. Project features into the embedding dimension -> [B, 1, E] each
        rgb_embed = self.rgb_proj(rgb_features).unsqueeze(1)
        point_embed = self.point_proj(point_features).unsqueeze(1)

        # 2. Prepare the sequence [CLS, RGB, PointCloud] -> [B, 3, E]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        token_sequence = torch.cat((cls_tokens, rgb_embed, point_embed), dim=1)

        # 3. Add positional embeddings (broadcast over the batch)
        token_sequence = token_sequence + self.pos_embedding

        # 4. Pass through the transformer -> [B, 3, E]
        fused_sequence = self.transformer(token_sequence)
        fused_sequence = self.norm(fused_sequence)

        # 5. The [CLS] position carries the fused vector -> [B, E]
        cls_output = fused_sequence[:, 0]

        return cls_output



class DenseFusionNetwork(nn.Module):
    """Pose network that fuses a global RGB feature with a global point-cloud feature.

    The two encoders each produce one vector per sample. They are fused
    either by ``TransformerFuser`` or by a two-layer MLP on their
    concatenation, and the fused vector feeds a pose head
    (``[tx, ty, tz, qw, qx, qy, qz]``) and a confidence head (one logit).

    Args:
        num_objects: Number of object classes (stored, not used by the layers).
        use_transformer: Select Transformer fusion instead of the MLP fusion.
        d_rgb: Width of the RGB feature vector.
        d_geo: Width of the point-cloud feature vector.
    """
    def __init__(self, num_objects=13, use_transformer=config.USE_TRANSFORMER_FUSION, d_rgb = 32, d_geo = 32):
        super(DenseFusionNetwork, self).__init__()
        self.num_objects = num_objects

        #TRANSFORMER USE
        self.use_transformer = use_transformer

        # Feature extractors
        self.rgb_extractor = RGBFeatureExtractor(d_rgb)
        self.point_extractor = PointNetFeatureExtractor(d_geo)

        if self.use_transformer:
            self.transformer_fuser = TransformerFuser(rgb_feature_dim=d_rgb, geo_feature_dim=d_geo)
            # Take the width from the fuser itself so the heads stay in sync
            # with config.TRANSFORMER_DIM instead of assuming 128.
            fused_dim = self.transformer_fuser.embed_dim
        else:
            self.fusion_head = nn.Sequential(
                nn.Linear(d_rgb + d_geo, 512),
                nn.ReLU(),
                nn.Linear(512, 1024),
                nn.ReLU()
            )
            fused_dim = 1024

        # Pose regression head: [tx, ty, tz, qw, qx, qy, qz]
        self.pose_fc1 = nn.Linear(fused_dim, 512)
        self.pose_fc2 = nn.Linear(512, 256)
        self.pose_fc3 = nn.Linear(256, 7)

        # Confidence estimation head (single logit)
        self.conf_fc1 = nn.Linear(fused_dim, 256)
        self.conf_fc2 = nn.Linear(256, 64)
        self.conf_fc3 = nn.Linear(64, 1)

        self.dropout = nn.Dropout(0.3)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise convolutions (Kaiming), batch norms (1/0) and linears (Xavier).

        The bias of the final pose layer is then set so the untrained network
        predicts translation ``(0, 0, 0.3)`` with the identity quaternion.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Target the pose layer explicitly. Matching on ``out_features == 7``
        # would also hit any unrelated 7-wide layer (e.g. d_geo == 7).
        with torch.no_grad():
            self.pose_fc3.bias.copy_(
                torch.tensor([0.0, 0.0, 0.3,        # translation
                              1.0, 0.0, 0.0, 0.0])  # quaternion (w, x, y, z)
            )

    def forward(self, rgb, points):
        """Predict pose and confidence.

        Args:
            rgb: ``[B, 3, H, W]`` image patches.
            points: ``[B, N, 3]`` point clouds.

        Returns:
            ``(pose, confidence)``: ``pose`` is ``[B, 7]`` with a unit-norm
            quaternion in its last four entries; ``confidence`` is ``[B, 1]``
            raw logits.
        """
        batch_size = rgb.size(0)

        # Extract features
        rgb_features = self.rgb_extractor(rgb).view(batch_size, -1)  # [B, d_rgb]
        point_features = self.point_extractor(points)                # [B, d_geo]

        # Conditional fusion
        if self.use_transformer:
            x = self.transformer_fuser(rgb_features, point_features)  # [B, embed_dim]
        else:
            combined_features = torch.cat([rgb_features, point_features], dim=1)  # [B, d_rgb + d_geo]
            x = self.fusion_head(combined_features)                                # [B, 1024]

        # Pose prediction
        pose_x = F.relu(self.pose_fc1(x))
        pose_x = self.dropout(pose_x)
        pose_x = F.relu(self.pose_fc2(pose_x))
        pose_x = self.dropout(pose_x)
        pose = self.pose_fc3(pose_x)

        # Normalize the quaternion part so it is a valid rotation
        translation = pose[:, :3]
        quaternion = F.normalize(pose[:, 3:], p=2, dim=1)
        pose = torch.cat([translation, quaternion], dim=1)

        # Confidence prediction
        conf_x = F.relu(self.conf_fc1(x))
        conf_x = self.dropout(conf_x)
        conf_x = F.relu(self.conf_fc2(conf_x))
        confidence = self.conf_fc3(conf_x)

        return pose, confidence


#######

def test_model_architecture():
    """Smoke-test ``DenseFusionNetwork`` with one random forward pass.

    Builds the network with its default fusion mode, feeds a batch of two
    random RGB patches (``config.PATCH_SIZE`` square) and point clouds
    (``config.NUM_POINTS`` points) on ``config.DEVICE`` and prints the output
    shapes and parameter count.

    Returns:
        True if the forward pass completed, False if anything raised (the
        error is printed).
    """
    try:
        print("\nTesting Global Fusion version...")

        net = DenseFusionNetwork(num_objects=13).to(config.DEVICE)

        # Random stand-ins for a batch of image patches and point clouds.
        n_samples = 2
        side = config.PATCH_SIZE
        rgb_batch = torch.randn(n_samples, 3, side, config.PATCH_SIZE).to(config.DEVICE)
        cloud_batch = torch.randn(n_samples, config.NUM_POINTS, 3).to(config.DEVICE)

        with torch.no_grad():
            pose_out, conf_out = net(rgb_batch, cloud_batch)

        print(f"  ✓ Output pose shape: {pose_out.shape}")
        print(f"  ✓ Output confidence shape: {conf_out.shape}")
        print(f"  ✓ Parameters: {sum(p.numel() for p in net.parameters()):,}")

    except Exception as e:
        print(f"✗ Model test failed: {e}")
        return False

    return True

# Smoke-test the architecture once when the cell runs; `test_model` holds the
# boolean outcome (True on success, False if the forward pass raised).
test_model = test_model_architecture()
print("\n✓ Block 6 completed: DenseFusion architecture ready")

In [None]:
# ==============================================================================
# BLOCK 7: LOSS FUNCTIONS
# ==============================================================================

class DenseFusionLoss(nn.Module):
    """Pose loss: ADD / ADD-S term + confidence term + translation-range penalty.

    For every sample the model vertices of its class are transformed by the
    predicted and by the ground-truth pose. Classes listed in ``sym_list`` use
    the closest-point distance (ADD-S), all others the point-to-point distance
    (ADD). Classes whose model has fewer than ``MIN_VERTICES`` points fall
    back to a direct pose distance.

    NOTE(review): ``vertices_raw`` is used as stored, without the
    ``config.MODEL_SCALE_MM_TO_M`` factor that ``compute_add_metric`` applies.
    Presumably the models are already in the translation's unit; confirm
    against the dataset loader.

    Args:
        object_models: ``{class_id: {'vertices_raw': (M, 3) array-like}}``.
        sym_list: Class ids evaluated with ADD-S; ``None`` means no symmetric
            classes.
        add_weight: Weight of the ADD / ADD-S term.
        conf_weight: Weight of the confidence (BCE-with-logits) term.
    """

    # Models with fewer vertices than this use the pose-distance fallback.
    MIN_VERTICES = 10

    def __init__(self, object_models, sym_list, add_weight=1.0, conf_weight=0.1):
        super(DenseFusionLoss, self).__init__()
        self.object_models = object_models
        # Needed by forward(); the original constructor dropped this argument.
        self.sym_list = sym_list if sym_list is not None else []
        self.add_weight = add_weight
        self.conf_weight = conf_weight
        self.conf_loss = nn.BCEWithLogitsLoss()
        # (class_id, device) -> vertices tensor, so the model is converted and
        # uploaded once instead of once per sample. Plain dict on purpose: the
        # tensors must not become part of the state_dict.
        self._vertex_cache = {}
        self.reset_stats()

    def reset_stats(self):
        """Reset statistics for tracking"""
        self.stats = {
            'total_samples': 0,
            'add_computed': 0,
            'pose_fallback': 0,
            'avg_add_error': 0.0
        }

    def quaternion_to_matrix(self, q):
        """Convert scalar-first quaternion(s) to rotation matrix/matrices.

        Args:
            q: ``[4]`` or ``[B, 4]`` tensor ``(w, x, y, z)``; it is normalised
                before conversion.

        Returns:
            ``[3, 3]`` for a single quaternion (or a batch of one),
            otherwise ``[B, 3, 3]``.
        """
        if q.dim() == 1:
            q = q.unsqueeze(0)

        q = F.normalize(q, p=2, dim=-1)
        w, x, y, z = q.unbind(-1)

        xx, yy, zz = x*x, y*y, z*z
        xy, xz, yz = x*y, x*z, y*z
        wx, wy, wz = w*x, w*y, w*z

        matrix = torch.stack([
            1 - 2*(yy + zz), 2*(xy - wz), 2*(xz + wy),
            2*(xy + wz), 1 - 2*(xx + zz), 2*(yz - wx),
            2*(xz - wy), 2*(yz + wx), 1 - 2*(xx + yy)
        ], dim=-1).view(q.size(0), 3, 3)

        return matrix.squeeze(0) if matrix.size(0) == 1 else matrix

    def _model_vertices(self, class_id, device):
        """Return the class's vertices as a float32 tensor on ``device`` (cached)."""
        key = (class_id, str(device))
        vertices = self._vertex_cache.get(key)
        if vertices is None:
            vertices = torch.tensor(
                self.object_models[class_id]['vertices_raw'],
                device=device,
                dtype=torch.float32
            )
            self._vertex_cache[key] = vertices
        return vertices

    def _transform_both(self, pred_pose, gt_pose, vertices):
        """Apply the predicted and the ground-truth pose to ``vertices``."""
        R_pred = self.quaternion_to_matrix(pred_pose[3:])
        R_gt = self.quaternion_to_matrix(gt_pose[3:])
        pred_pts = vertices @ R_pred.T + pred_pose[:3]
        gt_pts = vertices @ R_gt.T + gt_pose[:3]
        return pred_pts, gt_pts

    def _record_add(self, add_loss, label):
        """Warn on a non-finite value and update the running ADD statistics."""
        if not torch.isfinite(add_loss.detach()):
            print(f"ERROR: CAN'T COMPUTE {label} LOSS")
        self.stats['add_computed'] += 1
        self.stats['avg_add_error'] += add_loss.item()

    def _pose_distance_loss(self, pred_pose, gt_pose):
        """Fallback when a model has too few vertices for ADD.

        Returns the translation L2 distance plus ``1 - |<q_pred, q_gt>|``,
        which is 0 for identical rotations and treats ``q`` and ``-q`` as the
        same rotation.
        """
        self.stats['pose_fallback'] += 1
        trans_err = torch.norm(pred_pose[:3] - gt_pose[:3], p=2)
        q_pred = F.normalize(pred_pose[3:], p=2, dim=0)
        q_gt = F.normalize(gt_pose[3:], p=2, dim=0)
        rot_err = 1.0 - torch.abs(torch.dot(q_pred, q_gt))
        return trans_err + rot_err

    def compute_add_loss(self, pred_pose, gt_pose, class_id):
        """ADD loss for one sample: mean distance between corresponding points.

        Returns a zero tensor if the class has no model, and the pose-distance
        fallback if the model has fewer than ``MIN_VERTICES`` vertices.
        """
        if class_id not in self.object_models:
            return torch.tensor(0.0, device=pred_pose.device)

        vertices = self._model_vertices(class_id, pred_pose.device)
        if vertices.shape[0] < self.MIN_VERTICES:
            return self._pose_distance_loss(pred_pose, gt_pose)

        pred_pts, gt_pts = self._transform_both(pred_pose, gt_pose, vertices)
        add_loss = torch.norm(pred_pts - gt_pts, p=2, dim=1).mean()

        self._record_add(add_loss, "ADD")
        return add_loss

    def compute_add_s_loss(self, pred_pose, gt_pose, class_id):
        """ADD-S loss for one sample: mean distance to the closest GT point.

        Returns a zero tensor if the class has no model, and the pose-distance
        fallback if the model has fewer than ``MIN_VERTICES`` vertices.
        """
        if class_id not in self.object_models:
            return torch.tensor(0.0, device=pred_pose.device)

        vertices = self._model_vertices(class_id, pred_pose.device)
        if vertices.shape[0] < self.MIN_VERTICES:
            return self._pose_distance_loss(pred_pose, gt_pose)

        pred_pts, gt_pts = self._transform_both(pred_pose, gt_pose, vertices)

        # Pairwise distances [M, M]; keep, for each predicted point, the
        # distance to its nearest ground-truth point.
        dists = torch.cdist(pred_pts.unsqueeze(0), gt_pts.unsqueeze(0)).squeeze(0)
        add_loss = dists.min(dim=1)[0].mean()

        self._record_add(add_loss, "ADD-S")
        return add_loss

    def forward(self, pred_poses, gt_poses, pred_confidences, class_ids):
        """Compute the total loss for a batch.

        Args:
            pred_poses: ``[B, 7]`` predicted poses.
            gt_poses: ``[B, 7]`` ground-truth poses.
            pred_confidences: ``[B, 1]`` confidence logits (target is 1).
            class_ids: ``[B]`` integer class ids.

        Returns:
            ``(total_loss, loss_dict)`` where ``total_loss`` is a scalar tensor
            and ``loss_dict`` holds the float values ``total_loss``,
            ``add_loss``, ``conf_loss``, ``pose_reg_loss`` and the sample
            count ``valid_samples``.
        """
        batch_size = pred_poses.size(0)
        self.stats['total_samples'] += batch_size
        denom = max(batch_size, 1)  # keeps an empty batch from dividing by zero

        total_add_loss = torch.zeros((), device=pred_poses.device)
        valid_samples = 0

        # ADD / ADD-S per sample (models differ per class, so no batching)
        for i in range(batch_size):
            class_id = int(class_ids[i])
            if class_id in self.sym_list:
                add_loss_val = self.compute_add_s_loss(pred_poses[i], gt_poses[i], class_id)
            else:
                add_loss_val = self.compute_add_loss(pred_poses[i], gt_poses[i], class_id)
            total_add_loss = total_add_loss + add_loss_val
            valid_samples += 1

        avg_add_loss = total_add_loss / denom

        # Confidence loss: every sample is pushed towards "confident"
        conf_targets = torch.ones_like(pred_confidences)
        conf_loss = self.conf_loss(pred_confidences, conf_targets)

        # Quadratic penalty on translations whose norm exceeds 2.0
        excess = torch.relu(torch.norm(pred_poses[:, :3], dim=1) - 2.0)
        pose_reg_loss = (excess ** 2).sum() / denom

        # Total loss
        total_loss = (self.add_weight * avg_add_loss +
                     self.conf_weight * conf_loss +
                     0.1 * pose_reg_loss)
        loss_dict = {
            'total_loss': total_loss.item(),
            'add_loss': avg_add_loss.item(),
            'conf_loss': conf_loss.item(),
            'pose_reg_loss': pose_reg_loss.item(),
            'valid_samples': valid_samples
        }

        return total_loss, loss_dict

    def get_stats(self):
        """Return aggregate ratios, or the raw counters if nothing was processed yet."""
        if self.stats['total_samples'] > 0:
            return {
                'total_samples': self.stats['total_samples'],
                'add_computed_ratio': self.stats['add_computed'] / self.stats['total_samples'],
                'pose_fallback_ratio': self.stats['pose_fallback'] / self.stats['total_samples'],
                'avg_add_error': self.stats['avg_add_error'] / max(self.stats['add_computed'], 1)
            }
        return self.stats

def test_loss_function():
    """Smoke-test ``DenseFusionLoss``: one forward and one backward pass.

    Three dummy object models (100 random vertices each) are created, random
    poses with unit quaternions are fed through the loss, and ``backward`` is
    called to confirm gradients reach the predictions.

    Returns:
        The ``DenseFusionLoss`` instance on success, or ``None`` if anything
        raised (the traceback is printed).
    """
    try:
        # 1) Dummy object models: 100 random points each, in the object frame
        dummy_models = {
            class_id: {'vertices_raw': (np.random.randn(100, 3) * 0.05).astype(np.float32)}
            for class_id in range(3)
        }

        loss_fn = DenseFusionLoss(dummy_models, config.SYMMETRIC_LIST)

        # 2) Generate test inputs
        batch_size = 2

        def _random_poses():
            # [tx, ty, tz, qw, qx, qy, qz] with a unit-norm quaternion part.
            poses = torch.randn(batch_size, 7)
            poses[:, 3:] = F.normalize(poses[:, 3:], p=2, dim=1)
            return poses

        # Only the predictions need gradients; the ground truth is a constant.
        test_pred_poses = _random_poses().requires_grad_(True)
        test_gt_poses = _random_poses()

        # Predicted confidences (logits)
        test_pred_confs = torch.randn(batch_size, 1, requires_grad=True)
        # Random class ids in [0, 2]
        test_class_ids = torch.randint(0, 3, (batch_size,))

        # 3) Call the loss with its four inputs
        total_loss, loss_dict = loss_fn(test_pred_poses, test_gt_poses, test_pred_confs, test_class_ids)

        # 4) Backward to verify gradients flow
        total_loss.backward()

        print(f"✓ Loss function test successful:")
        print(f"  Total loss:   {loss_dict['total_loss']:.6f}")
        print(f"  ADD loss:     {loss_dict['add_loss']:.6f}")
        print(f"  Conf loss:    {loss_dict['conf_loss']:.6f}")

        return loss_fn

    except Exception as e:
        print(f"✗ Loss function test failed: {e}")
        import traceback
        traceback.print_exc()
        return None

# Run the loss smoke test once when the cell runs; `test_loss` is the loss
# instance on success, or None if the check raised.
test_loss = test_loss_function()
print("\n✓ Block 7 completed: Loss functions ready")

In [None]:
# ==============================================================================
# BLOCK 8: DATASET CLASS
# ==============================================================================

class DenseFusionDataset(Dataset):
    """Dataset for DenseFusion training and evaluation.

    Each sample is built from one RGB image found in ``data_config[split]``:
    the ground-truth label file supplies class, bounding box and pose, the
    image (and optional depth map) is cropped to the bounding box, and the
    depth crop is back-projected to a fixed-size point cloud.

    The class relies on module-level names defined elsewhere in the notebook:
    ``config``, ``segmentation_module`` and ``convert_yolo_bbox_to_pixel``.
    """

    def __init__(self, data_config, split='train', num_points=None,
                patch_size=None, use_segmentation=config.USE_SEGMENTATION):
        """Index the images of ``split`` and load the object models.

        Args:
            data_config: Mapping with the image directory under ``split``,
                an optional depth directory under ``depth_<split>`` and an
                optional ``names`` list of object names.
            split: Key of the image directory inside ``data_config``.
            num_points: Points per cloud; falls back to ``config.NUM_POINTS``.
            patch_size: Side of the square patches; falls back to
                ``config.PATCH_SIZE``.
            use_segmentation: Whether ``__getitem__`` asks the segmentation
                module for an object mask.

        Raises:
            FileNotFoundError: If the RGB directory does not exist.
        """
        self.data_config = data_config
        self.split = split
        self.num_points = num_points or config.NUM_POINTS
        self.patch_size = patch_size or config.PATCH_SIZE
        self.use_segmentation = use_segmentation

        # Dataset paths
        self.rgb_dir = data_config[split]
        self.depth_dir = data_config.get(f'depth_{split}')

        if not os.path.exists(self.rgb_dir):
            raise FileNotFoundError(f"RGB directory not found: {self.rgb_dir}")

        # Collect image paths; sorting keeps the sample order deterministic.
        all_rgb_paths = []
        for ext in ("*.png", "*.jpg", "*.jpeg"):
            all_rgb_paths.extend(glob.glob(os.path.join(self.rgb_dir, ext)))

        self.rgb_paths = sorted(all_rgb_paths)
        print(f"Dataset '{split}': {len(self.rgb_paths)} images")

        # Load 3D object models (populates self.object_models)
        self._load_raw_models(data_config.get('names', []))

    def _load_raw_models(self, object_names):
        """Load 3D object models from PLY files into ``self.object_models``.

        For object index ``i`` the candidates ``obj_{i+1:02d}.ply``,
        ``obj_{i+1}.ply`` and ``<name>.ply`` inside ``config.PLY_MODELS_DIR``
        are tried in that order; the first one yielding vertices wins.
        Models with more than ``config.NUM_POINTS`` vertices are randomly
        subsampled. Stored vertices are multiplied by 0.001
        (presumably mm -> m; confirm against the PLY files).
        """
        self.object_models = {}

        for obj_idx, obj_name in enumerate(object_names):
            vertices = None

            if os.path.exists(config.PLY_MODELS_DIR):
                ply_candidates = [
                    os.path.join(config.PLY_MODELS_DIR, f"obj_{obj_idx+1:02d}.ply"),
                    os.path.join(config.PLY_MODELS_DIR, f"obj_{obj_idx+1}.ply"),
                    os.path.join(config.PLY_MODELS_DIR, f"{obj_name}.ply")
                ]

                for ply_path in ply_candidates:
                    if not os.path.exists(ply_path):
                        continue
                    try:
                        mesh = trimesh.load_mesh(ply_path, process=False)
                        vertices = np.asarray(mesh.vertices, dtype=np.float32)
                    except Exception:
                        # Unreadable candidate: fall through to the next one.
                        continue
                    if vertices.size > 0:
                        break

            # Fallback to a dummy cube if loading fails. This keeps the
            # pipeline running but the resulting loss is meaningless for
            # this object, so make the situation clearly visible.
            if vertices is None or vertices.size == 0:
                print(f"Error: no PLY model could be loaded for object "
                      f"{obj_idx} ('{obj_name}'); using dummy cube vertices")
                vertices = np.array([
                    [-20, -20, -20], [20, -20, -20], [20, 20, -20], [-20, 20, -20],
                    [-20, -20, 20], [20, -20, 20], [20, 20, 20], [-20, 20, 20]
                ], dtype=np.float32)

            # Sample vertices if too many
            if vertices.shape[0] > config.NUM_POINTS:
                indices = np.random.choice(vertices.shape[0], config.NUM_POINTS, replace=False)
                vertices = vertices[indices]

            self.object_models[obj_idx] = {
                'name': obj_name,
                'vertices_raw': vertices * 0.001,  # meters
                'vertices_gt': None                # fill per-sample
            }

        print(f"✓ Loaded models for {len(self.object_models)} objects")

    def get_depth_path(self, rgb_path):
        """Return the depth image with the same file name as ``rgb_path``.

        Returns None when no depth directory is configured or the file does
        not exist.
        """
        if not self.depth_dir:
            return None
        rgb_filename = os.path.basename(rgb_path)
        depth_path = os.path.join(self.depth_dir, rgb_filename)
        return depth_path if os.path.exists(depth_path) else None

    def load_ground_truth(self, rgb_path):
        """Load class, bounding box and pose from the label file of an image.

        The label file is the image path with ``/images/`` replaced by
        ``/labels/`` and the extension replaced by ``.txt``. Lines are
        classified by token count (first token is always skipped as data):

        * 5 tokens: ``class xc yc w h`` - only the first such line is used;
        * 4 tokens: translation (3 values, divided by 1000 on return);
        * 7 or more tokens: rotation, which must hold 9 values (3x3).

        Returns:
            ``(class_id, translation, quat_wxyz, rotation_flat, bbox_norm)``
            where ``class_id`` is the label class minus one, clamped to the
            valid range of ``data_config['names']``.

        Raises:
            FileNotFoundError: If the label file does not exist.
            ValueError: If bbox, translation or rotation is missing, or the
                rotation does not contain exactly 9 values.
        """
        gt_class_id = 0
        gt_t = None
        gt_r = None
        bbox_normalized = None
        parse_error = None

        # splitext handles every image extension (including .jpeg) and only
        # touches the real extension, not other parts of the path.
        label_stem, _ = os.path.splitext(rgb_path.replace('/images/', '/labels/'))
        yolo_label_path = label_stem + '.txt'
        if not os.path.exists(yolo_label_path):
            raise FileNotFoundError(f"Label file not found: {yolo_label_path}")

        num_names = len(self.data_config.get('names', []))
        try:
            with open(yolo_label_path, 'r') as f:
                for raw_line in f:
                    tokens = raw_line.strip().split()
                    if len(tokens) == 5 and bbox_normalized is None:
                        yolo_class_id = int(tokens[0])
                        gt_class_id = max(0, min(yolo_class_id - 1, num_names - 1))
                        bbox_normalized = [float(x) for x in tokens[1:5]]
                    elif len(tokens) == 4:
                        gt_t = np.array([float(x) for x in tokens[1:]], dtype=np.float32)
                    elif len(tokens) >= 7:
                        gt_r = np.array([float(x) for x in tokens[1:]], dtype=np.float32)
        except Exception as exc:
            # Parsing stays best-effort (fields read before the failure are
            # kept), but the error is reported below if anything is missing.
            parse_error = exc

        missing = [
            field for field, value in (
                ('bbox', bbox_normalized),
                ('translation', gt_t),
                ('rotation', gt_r),
            ) if value is None
        ]
        if missing:
            detail = f" (parse error: {parse_error})" if parse_error is not None else ""
            raise ValueError(
                f"Label {yolo_label_path} is missing {', '.join(missing)}{detail}"
            )
        if gt_r.size != 9:
            raise ValueError(
                f"Label {yolo_label_path}: rotation has {gt_r.size} values, expected 9"
            )

        rot = R.from_matrix(gt_r.reshape((3, 3)))
        gt_t = gt_t / 1000.0  # convert t mm->m
        quat = rot.as_quat()  # xyzw
        # reorder to wxyz
        quat_wxyz = np.array([quat[3], quat[0], quat[1], quat[2]], np.float32)

        return gt_class_id, gt_t, quat_wxyz, gt_r, bbox_normalized

    def extract_patches_with_segmentation(self, rgb_image, depth_image, bbox_norm, mask=None):
        """Crop RGB and depth patches around a normalized bounding box.

        Args:
            rgb_image: HxWx3 image array.
            depth_image: HxW depth array or None.
            bbox_norm: ``[xc, yc, w, h]`` normalized to the image size.
            mask: Optional HxW mask; pixels where it equals 0 are zeroed in
                both patches (only when the crop shapes match).

        Returns:
            ``(rgb_patch, depth_patch, [x1, y1, x2, y2])`` with both patches
            resized to ``patch_size`` x ``patch_size``; ``depth_patch`` is
            None when no depth is available or the crop is empty.
        """
        h, w = rgb_image.shape[:2]

        # Convert normalized bbox to pixel coordinates
        xc_n, yc_n, w_n, h_n = bbox_norm
        xc_px = int(xc_n * w)
        yc_px = int(yc_n * h)
        w_px = int(w_n * w)
        h_px = int(h_n * h)

        x1 = max(0, xc_px - w_px // 2)
        y1 = max(0, yc_px - h_px // 2)
        x2 = min(w, xc_px + w_px // 2)
        y2 = min(h, yc_px + h_px // 2)

        # Degenerate box: fall back to the top-left 100x100 region.
        if x2 <= x1 or y2 <= y1:
            x1, y1, x2, y2 = 0, 0, min(w, 100), min(h, 100)

        # The same mask crop is shared by the RGB and the depth patch.
        mask_patch = mask[y1:y2, x1:x2] if mask is not None else None

        # Extract, mask and resize the RGB patch
        rgb_patch = rgb_image[y1:y2, x1:x2].copy()
        if mask_patch is not None and rgb_patch.shape[:2] == mask_patch.shape:
            rgb_patch[mask_patch == 0] = 0

        if rgb_patch.size == 0:
            rgb_patch = np.zeros((self.patch_size, self.patch_size, 3), dtype=np.uint8)
        else:
            rgb_patch = cv2.resize(rgb_patch, (self.patch_size, self.patch_size))

        # Extract, mask and resize the depth patch
        depth_patch = None
        if depth_image is not None:
            depth_patch = depth_image[y1:y2, x1:x2].copy()
            if mask_patch is not None and depth_patch.shape == mask_patch.shape:
                depth_patch[mask_patch == 0] = 0

            if depth_patch.size > 0:
                # Nearest-neighbour keeps real depth values (no blending
                # across object boundaries).
                depth_patch = cv2.resize(depth_patch, (self.patch_size, self.patch_size),
                                       interpolation=cv2.INTER_NEAREST)
            else:
                depth_patch = None

        return rgb_patch, depth_patch, [x1, y1, x2, y2]

    def _random_fallback_points(self):
        """Return a small random cloud used when no valid depth is available."""
        return np.random.randn(self.num_points, 3).astype(np.float32) * 0.01

    def depth_to_pointcloud(self, depth_patch, bbox_pixel):
        """Back-project a depth patch to a ``(num_points, 3)`` float32 cloud.

        Patch pixel coordinates are mapped back to the original image via
        ``bbox_pixel`` before applying the intrinsics in ``config.K``. Only
        depths in the open interval (0, 5.0) are kept (presumably meters;
        depends on the scaling applied by the caller). The result is
        randomly subsampled or padded by repetition to ``num_points``. A
        small random cloud is returned when there is no usable depth.
        """
        if depth_patch is None:
            return self._random_fallback_points()

        h, w = depth_patch.shape
        if h == 0 or w == 0:
            return self._random_fallback_points()

        fx, fy = config.K[0, 0], config.K[1, 1]
        cx, cy = config.K[0, 2], config.K[1, 2]

        # Create coordinate grids
        y_coords, x_coords = np.mgrid[0:h, 0:w]

        # Map patch coordinates to original image coordinates
        x1, y1, x2, y2 = bbox_pixel
        scale_x = (x2 - x1) / w
        scale_y = (y2 - y1) / h

        x_flat = (x_coords * scale_x + x1).flatten()
        y_flat = (y_coords * scale_y + y1).flatten()
        z_flat = depth_patch.flatten()

        # Keep valid depth values only
        valid_mask = (z_flat > 0) & (z_flat < 5.0)
        x_valid = x_flat[valid_mask]
        y_valid = y_flat[valid_mask]
        z_valid = z_flat[valid_mask]

        if len(z_valid) == 0:
            return self._random_fallback_points()

        # Pinhole back-projection to 3D camera coordinates
        points_x = (x_valid - cx) * z_valid / fx
        points_y = (y_valid - cy) * z_valid / fy
        points_3d = np.column_stack((points_x, points_y, z_valid))

        # Sample or pad to the target number of points (points_3d is
        # guaranteed non-empty here).
        if len(points_3d) > self.num_points:
            indices = np.random.choice(len(points_3d), self.num_points, replace=False)
            points_3d = points_3d[indices]
        elif len(points_3d) < self.num_points:
            num_to_pad = self.num_points - len(points_3d)
            pad_indices = np.random.choice(len(points_3d), num_to_pad, replace=True)
            points_3d = np.vstack([points_3d, points_3d[pad_indices]])

        return points_3d.astype(np.float32)

    def __len__(self):
        """Return the number of indexed RGB images."""
        return len(self.rgb_paths)

    def __getitem__(self, idx):
        """Build one sample dict with keys rgb, points, class_id, gt_pose.

        ``gt_pose`` is ``[tx, ty, tz, qw, qx, qy, qz]``. Any failure while
        building the sample is printed and replaced by a default sample
        (zero image, zero points, class 0, identity rotation) so that a
        DataLoader keeps running; check the log for these messages.
        """
        rgb_path = self.rgb_paths[idx]

        try:
            # Load RGB image
            rgb_image = cv2.imread(rgb_path)
            if rgb_image is None:
                raise FileNotFoundError(f"Could not load RGB image: {rgb_path}")
            rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)

            # Load depth image
            depth_image = None
            depth_path = self.get_depth_path(rgb_path)
            if depth_path:
                depth_image = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
                if depth_image is not None:
                    depth_image = depth_image.astype(np.float32) / config.DEPTH_SCALE_MM_TO_M

            # Load ground truth
            gt_class_id, gt_t, gt_quat, _, bbox_norm = self.load_ground_truth(rgb_path)
            gt_pose_7d = np.array(
                [gt_t[0], gt_t[1], gt_t[2], gt_quat[0], gt_quat[1], gt_quat[2], gt_quat[3]],
                dtype=np.float32
            )

            # A class without a loaded model cannot be used by the loss.
            if gt_class_id not in self.object_models:
                raise KeyError(gt_class_id)

            # Apply segmentation if enabled
            object_mask = None
            if self.use_segmentation and segmentation_module is not None:
                try:
                    bbox_pixel = convert_yolo_bbox_to_pixel(
                        bbox_norm, rgb_image.shape[1], rgb_image.shape[0]
                    )
                    object_mask, _ = segmentation_module.refine_detection(
                        rgb_image, bbox_pixel, gt_class_id
                    )
                except Exception:
                    # Segmentation is optional: continue with the plain crop.
                    object_mask = None

            # Extract patches
            rgb_patch, depth_patch, bbox_pixel = self.extract_patches_with_segmentation(
                rgb_image, depth_image, bbox_norm, object_mask
            )

            # Generate point cloud
            points_3d = self.depth_to_pointcloud(depth_patch, bbox_pixel)

            # Convert to tensors (RGB as CHW in [0, 1])
            return {
                'rgb': torch.from_numpy(rgb_patch.transpose(2, 0, 1)).float() / 255.0,
                'points': torch.from_numpy(points_3d).float(),
                'class_id': torch.tensor(gt_class_id, dtype=torch.long),
                'gt_pose': torch.from_numpy(gt_pose_7d).float(),
            }

        except Exception as e:
            # Return default tensors on error
            print(f"Error loading sample {rgb_path}: {e}")
            print(f"++++++++Error loading sample: please check and don't use default values ")

            return {
                'rgb': torch.zeros((3, self.patch_size, self.patch_size), dtype=torch.float32),
                'points': torch.zeros((self.num_points, 3), dtype=torch.float32),
                'class_id': torch.tensor(0, dtype=torch.long),
                'gt_pose': torch.tensor([0.0, 0.0, 0.3, 1.0, 0.0, 0.0, 0.0], dtype=torch.float32),
            }

def test_dataset():
    """Smoke-test DenseFusionDataset on the 'train' split.

    Builds the dataset without segmentation, fetches sample 0 and prints its
    basic shapes. Returns the dataset, or None if anything raised.
    """
    try:
        cfg = load_dataset_config(config.LINEMOD_ROOT)
        dataset = DenseFusionDataset(cfg, split='train', use_segmentation=False)

        # Fetching one sample exercises the whole loading path.
        first_sample = dataset[0]
        print("✓ Dataset test successful:")
        print(f"  Dataset size: {len(dataset)}")
        print(f"  RGB shape: {first_sample['rgb'].shape}")
        print(f"  Points shape: {first_sample['points'].shape}")
        print(f"  Class ID: {first_sample['class_id'].item()}")
        print(f"  GT pose shape: {first_sample['gt_pose'].shape}")
    except Exception as e:
        print(f"✗ Dataset test failed: {e}")
        return None

    return dataset

# Run the dataset smoke test as soon as this cell executes.
# `test_dataset_obj` is the constructed train-split dataset on success, or
# None when the test raised (the failure is printed inside test_dataset).
test_dataset_obj = test_dataset()
print("\n✓ Block 8 completed: Dataset ready")

In [None]:
# ==============================================================================
# BLOCK 9: TRAINING PIPELINE
# ==============================================================================
def load_checkpoint(path, model, optimizer=None, scheduler=None):
    """
    Restore training state from a checkpoint file.

    The model weights are always loaded; optimizer and scheduler state are
    loaded only when the object is given and the checkpoint contains the
    matching entry. Returns the epoch to resume from (stored epoch + 1).
    """
    state = torch.load(path, map_location=config.DEVICE)

    model.load_state_dict(state['model_state_dict'])
    resume_epoch = state['epoch'] + 1

    # Optional components, restored in the same order: optimizer, scheduler.
    optional_parts = (
        (optimizer, 'optimizer_state_dict'),
        (scheduler, 'scheduler_state_dict'),
    )
    for target, key in optional_parts:
        if target is not None and key in state:
            target.load_state_dict(state[key])

    print(f"Resuming from epoch {resume_epoch}")
    return resume_epoch

def create_data_loaders():
    """Build the train/val datasets and wrap them in DataLoaders.

    Returns ``(train_loader, val_loader, train_dataset)``. The train loader
    shuffles and drops the last incomplete batch; the val loader does
    neither.
    """
    dataset_config = load_dataset_config(config.LINEMOD_ROOT)

    def build_dataset(split_name):
        return DenseFusionDataset(
            dataset_config, split=split_name,
            use_segmentation=config.USE_SEGMENTATION
        )

    train_dataset = build_dataset('train')
    val_dataset = build_dataset('val')

    # Use num_workers=0 for Colab compatibility
    shared_kwargs = dict(batch_size=config.BATCH_SIZE, num_workers=0, pin_memory=False)

    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **shared_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **shared_kwargs)

    return train_loader, val_loader, train_dataset

def save_model_checkpoint(model, optimizer, scheduler, epoch, train_loss, val_loss, is_best=False):
    """Write a training checkpoint and return the path of the saved file.

    With ``is_best`` the full checkpoint goes to ``<name>_full_info.pth`` and
    the bare model weights to ``<name>_densefusion_best.pth`` (the returned
    path), both in ``config.MODELS_SAVE_DIR``. Otherwise a per-epoch
    checkpoint is written to ``config.CHECKPOINTS_DIR``.
    """
    checkpoint_data = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        train_loss=train_loss,
        val_loss=val_loss,
    )
    # Hyper-parameters in effect when the checkpoint was written.
    checkpoint_data['config'] = {
        'num_points': config.NUM_POINTS,
        'patch_size': config.PATCH_SIZE,
        'batch_size': config.BATCH_SIZE,
        'learning_rate': config.LEARNING_RATE
    }

    # Make sure the output directories exist
    create_directories()

    if not is_best:
        # Regular per-epoch checkpoint
        epoch_path = os.path.join(
            config.CHECKPOINTS_DIR,
            f'{config.MODELS_NAME}_checkpoint_epoch_{epoch:03d}.pth'
        )
        torch.save(checkpoint_data, epoch_path)
        return epoch_path

    # Best model: full checkpoint ...
    full_info_path = os.path.join(config.MODELS_SAVE_DIR, f'{config.MODELS_NAME}_full_info.pth')
    torch.save(checkpoint_data, full_info_path)

    # ... plus a plain state dict for simple loading
    weights_path = os.path.join(config.MODELS_SAVE_DIR, f'{config.MODELS_NAME}_densefusion_best.pth')
    torch.save(model.state_dict(), weights_path)

    print(f"✓ Best model saved: {weights_path}")
    return weights_path

def create_training_plots(train_losses, val_losses, save_path=None):
    """Plot train/val loss curves on a linear and a log-scale panel.

    Args:
        train_losses: Per-epoch training losses.
        val_losses: Per-epoch validation losses. May be shorter or longer
            than ``train_losses`` (e.g. training interrupted during
            validation); each curve is plotted against its own epoch axis.
        save_path: If given, the figure is saved there and closed;
            otherwise it is shown.

    Plotting errors are printed, never raised, and the figure is closed so
    repeated failures do not accumulate open figures.
    """
    fig = None
    try:
        fig, (ax_linear, ax_log) = plt.subplots(1, 2, figsize=(12, 4))

        # Separate x axes so that unequal list lengths cannot break plotting.
        train_epochs = range(1, len(train_losses) + 1)
        val_epochs = range(1, len(val_losses) + 1)

        # Loss curves
        ax_linear.plot(train_epochs, train_losses, 'b-', label='Train Loss', linewidth=2)
        ax_linear.plot(val_epochs, val_losses, 'r-', label='Val Loss', linewidth=2)
        ax_linear.set_title('Training Progress')
        ax_linear.set_xlabel('Epoch')
        ax_linear.set_ylabel('Loss')

        # Log scale
        ax_log.semilogy(train_epochs, train_losses, 'b-', label='Train Loss', linewidth=2)
        ax_log.semilogy(val_epochs, val_losses, 'r-', label='Val Loss', linewidth=2)
        ax_log.set_title('Training Progress (Log Scale)')
        ax_log.set_xlabel('Epoch')
        ax_log.set_ylabel('Loss (log)')

        for ax in (ax_linear, ax_log):
            ax.legend()
            ax.grid(True, alpha=0.3)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            plt.close(fig)
        else:
            plt.show()

    except Exception as e:
        print(f"Failed to create plots: {e}")
        if fig is not None:
            plt.close(fig)

def _forward_with_loss(model, loss_fn, batch, use_amp):
    """Move one batch to config.DEVICE, run the model and evaluate the loss.

    Returns whatever ``loss_fn`` returns, i.e. ``(total_loss, loss_dict)``.
    With ``use_amp`` the forward pass and loss run under CUDA autocast.
    """
    rgb = batch['rgb'].to(config.DEVICE, non_blocking=True)
    points = batch['points'].to(config.DEVICE, non_blocking=True)
    gt_poses = batch['gt_pose'].to(config.DEVICE, non_blocking=True)
    class_ids = batch['class_id'].to(config.DEVICE, non_blocking=True)

    if use_amp:
        with torch.cuda.amp.autocast():
            pred_poses, pred_confs = model(rgb, points)
            return loss_fn(pred_poses, gt_poses, pred_confs, class_ids)

    pred_poses, pred_confs = model(rgb, points)
    return loss_fn(pred_poses, gt_poses, pred_confs, class_ids)


def train_densefusion():
    """Complete (simplified) DenseFusion training pipeline.

    Builds the data loaders, model, loss, Adam optimizer and
    ReduceLROnPlateau scheduler, then trains for ``config.NUM_EPOCHS``
    epochs with optional mixed precision. After every epoch a checkpoint is
    saved (the best model when the validation loss improved) and every
    second epoch the loss curves are plotted. A final checkpoint and plot
    are written even when training is interrupted or fails.

    An epoch in which no validation batch succeeded records NaN as its
    validation loss; it neither drives the scheduler nor counts as a new
    best model.

    Returns:
        dict with keys ``model``, ``train_losses``, ``val_losses``,
        ``best_val_loss`` and ``total_time`` (seconds).
    """
    print("=" * 60)
    print("SIMPLIFIED DENSEFUSION TRAINING PIPELINE")
    print("=" * 60)

    # Setup
    config.setup_environment()
    start_time = datetime.datetime.now()

    # Load data
    train_loader, val_loader, train_dataset = create_data_loaders()
    dataset_config = load_dataset_config(config.LINEMOD_ROOT)
    num_classes = len(dataset_config.get('names', []))

    print(f"Training setup:")
    print(f"  Classes: {num_classes}")
    print(f"  Train samples: {len(train_dataset)}")
    print(f"  Epochs: {config.NUM_EPOCHS}")
    print(f"  Batch size: {config.BATCH_SIZE}")
    print(f"  Device: {config.DEVICE}")

    # Initialize model and training components
    model = DenseFusionNetwork( num_objects=num_classes,
                               use_transformer=config.USE_TRANSFORMER_FUSION
                                ).to(config.DEVICE)
    loss_fn = DenseFusionLoss(train_dataset.object_models, config.SYMMETRIC_LIST)

    optimizer = optim.Adam(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=1e-4
    )

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-7
    )

    # Mixed precision scaler (only when requested and CUDA is available)
    scaler = None
    if config.USE_MIXED_PRECISION and torch.cuda.is_available():
        scaler = torch.cuda.amp.GradScaler()
    use_amp = scaler is not None

    # Training state
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    print(f"\nStarting training...")

    try:

        for epoch in range(config.NUM_EPOCHS):
            print(f"\nEpoch {epoch+1}/{config.NUM_EPOCHS}")
            print("-" * 40)

            # ---------------- Training phase ----------------
            model.train()
            train_loss_accum = 0.0
            train_batches = 0
            loss_fn.reset_stats()

            train_pbar = tqdm(train_loader, desc=f"Training", leave=False)

            for batch_idx, batch in enumerate(train_pbar):
                try:
                    optimizer.zero_grad()

                    # Forward pass
                    total_loss, loss_dict = _forward_with_loss(model, loss_fn, batch, use_amp)

                    # Backward pass with gradient clipping
                    if use_amp:
                        scaler.scale(total_loss).backward()
                        # Unscale first so clipping sees the true gradient norm.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        total_loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        optimizer.step()

                    train_loss_accum += loss_dict['total_loss']
                    train_batches += 1

                    train_pbar.set_postfix({
                        'Loss': f"{loss_dict['total_loss']:.6f}",
                        'ADD': f"{loss_dict['add_loss']:.6f}"
                    })

                    # Memory cleanup
                    if batch_idx % 10 == 0:
                        cleanup_memory()

                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print(f"\nOOM at batch {batch_idx}, cleaning up...")
                        cleanup_memory()
                        optimizer.zero_grad()
                    else:
                        print(f"\nError in batch {batch_idx}: {e}")
                    continue

            avg_train_loss = train_loss_accum / max(train_batches, 1)
            train_losses.append(avg_train_loss)

            # ---------------- Validation phase ----------------
            model.eval()
            val_loss_accum = 0.0
            val_batches = 0
            val_failures = 0

            val_pbar = tqdm(val_loader, desc=f"Validation", leave=False)

            with torch.no_grad():
                for batch_idx, batch in enumerate(val_pbar):
                    try:
                        total_loss, loss_dict = _forward_with_loss(model, loss_fn, batch, use_amp)

                        val_loss_accum += loss_dict['total_loss']
                        val_batches += 1

                        val_pbar.set_postfix({'Val Loss': f"{loss_dict['total_loss']:.4f}"})

                    except Exception as e:
                        # Keep validating, but never hide the failure: the
                        # first error of the epoch is printed in full.
                        val_failures += 1
                        if val_failures == 1:
                            print(f"\nValidation error in batch {batch_idx}: {e}")
                        continue

            if val_failures:
                print(f"  Warning: {val_failures} validation batch(es) failed this epoch")

            has_val_result = val_batches > 0
            if has_val_result:
                avg_val_loss = val_loss_accum / val_batches
            else:
                # No usable validation result: record NaN rather than a
                # misleading 0.0 that would look like a perfect model.
                avg_val_loss = float('nan')
                print("  Warning: no validation batch succeeded; "
                      "skipping scheduler step and best-model selection")
            val_losses.append(avg_val_loss)

            # Update learning rate
            if has_val_result:
                scheduler.step(avg_val_loss)
            current_lr = optimizer.param_groups[0]['lr']

            # Print epoch results
            print(f"Epoch {epoch+1} Results:")
            print(f"  Train Loss: {avg_train_loss:.6f}")
            print(f"  Val Loss: {avg_val_loss:.6f}")
            print(f"  Learning Rate: {current_lr:.8f}")

            # Save best model
            is_best = has_val_result and avg_val_loss < best_val_loss
            if is_best:
                best_val_loss = avg_val_loss
                print(f"  ✓ NEW BEST MODEL!")

            save_model_checkpoint(model, optimizer, scheduler, epoch,
                                avg_train_loss, avg_val_loss, is_best)

            # Create plots every 2 epochs
            if (epoch + 1) % 2 == 0:
                plot_path = os.path.join(config.MODELS_SAVE_DIR, f'{config.MODELS_NAME}_training_progress.png')
                create_training_plots(train_losses, val_losses, plot_path)

            cleanup_memory()

    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
    except Exception as e:
        print(f"\nTraining error: {e}")
        import traceback
        traceback.print_exc()

    # Save final results
    total_time = (datetime.datetime.now() - start_time).total_seconds()

    # Final model save. The directory may not exist yet if training stopped
    # before the first checkpoint was written.
    os.makedirs(config.MODELS_SAVE_DIR, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    final_path = os.path.join(config.MODELS_SAVE_DIR, f'{config.MODELS_NAME}_densefusion_final_{timestamp}.pth')
    torch.save({
        'model_state_dict': model.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'best_val_loss': best_val_loss,
        'total_time': total_time,
        # NOTE(review): requires every value in config.__dict__ to be picklable.
        'config_dict': config.__dict__
    }, final_path)

    # Final plots
    if train_losses and val_losses:
        final_plot_path = os.path.join(config.MODELS_SAVE_DIR, f'{config.MODELS_NAME}_final_training_plot_{timestamp}.png')
        create_training_plots(train_losses, val_losses, final_plot_path)

    print(f"\n" + "=" * 60)
    print("TRAINING COMPLETED")
    print(f"=" * 60)
    print(f"Total time: {total_time:.1f} seconds")
    print(f"Epochs completed: {len(train_losses)}")
    print(f"Best validation loss: {best_val_loss:.6f}")
    if train_losses:
        print(f"Final train loss: {train_losses[-1]:.6f}")
    print(f"Models saved to: {config.MODELS_SAVE_DIR}")

    return {
        'model': model,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'best_val_loss': best_val_loss,
        'total_time': total_time
    }

def load_trained_model(model_path=None,use_transformer=config.USE_TRANSFORMER_FUSION):
    """Load a trained DenseFusion model and put it in eval mode.

    Args:
        model_path: File to load. Defaults to
            ``<MODELS_SAVE_DIR>/<MODELS_NAME>_densefusion_best.pth``.
            Both plain state dicts and checkpoint dicts containing a
            ``'model_state_dict'`` entry (full-info, per-epoch and final
            checkpoints written by this notebook) are accepted.
        use_transformer: Must match the architecture the weights were
            trained with.

    Returns:
        The model on ``config.DEVICE``, or None if the file is missing or
        loading fails (the reason is printed).
    """
    if model_path is None:
        # Try to find best model
        model_path = os.path.join(config.MODELS_SAVE_DIR, f'{config.MODELS_NAME}_densefusion_best.pth')

    if not os.path.exists(model_path):
        print(f"Model not found: {model_path}")
        return None

    try:
        # Load dataset config to get number of classes
        dataset_config = load_dataset_config(config.LINEMOD_ROOT)
        num_classes = len(dataset_config.get('names', []))

        # Create model
        model = DenseFusionNetwork(num_objects=num_classes,use_transformer=use_transformer)

        # Load weights. Decide by content, not by file name: checkpoint
        # dicts wrap the weights under 'model_state_dict'.
        loaded = torch.load(model_path, map_location=config.DEVICE)
        if isinstance(loaded, dict) and 'model_state_dict' in loaded:
            state_dict = loaded['model_state_dict']
        else:
            state_dict = loaded
        model.load_state_dict(state_dict)

        model = model.to(config.DEVICE)
        model.eval()

        model_type = "Transformer" if use_transformer else "MLP"
        print(f"✓ {model_type} model loaded from: {model_path}")
        return model

    except Exception as e:
        print(f"Failed to load model: {e}")
        return None

# Progress marker only: this cell defines the training helpers above but does
# not call any of them.
print("✓ Block 9 completed: Training pipeline ready")

In [None]:
# ==============================================================================
# BLOCK 10: EVALUATION PIPELINE
# ==============================================================================

def detect_and_estimate_pose(yolo_model, pose_model, dataset, rgb_path, depth_path=None):
    """Detect an object with YOLO and estimate its pose with the pose model.

    Only the first box of the first YOLO result is used. The crop and point
    cloud are produced with the helper methods of ``dataset`` so that
    inference matches the dataset preprocessing.

    Returns:
        dict with keys ``bbox`` (pixel xyxy), ``class_id`` (YOLO class - 1),
        ``pose`` (flattened model output), ``confidence`` (sigmoid of the
        predicted confidence), ``yolo_confidence`` and ``mask``; or None
        when the image cannot be read, nothing is detected, or any step
        raises (the error is printed).
    """
    try:
        # Load RGB image
        bgr_image = cv2.imread(rgb_path)
        if bgr_image is None:
            return None
        rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        img_h, img_w = rgb_image.shape[:2]

        # YOLO detection
        detections = yolo_model(rgb_image, verbose=False)
        if len(detections) == 0 or len(detections[0].boxes) == 0:
            return None

        # Take the first detection
        top_box = detections[0].boxes[0]
        yolo_conf = float(top_box.conf)
        class_id = int(top_box.cls) - 1  # Convert to 0-based

        # Integer pixel box, clipped to the image
        left, top, right, bottom = (int(v) for v in top_box.xyxy[0].cpu().numpy())
        left = max(0, left)
        top = max(0, top)
        right = min(img_w, right)
        bottom = min(img_h, bottom)
        bbox_xyxy = [left, top, right, bottom]

        # Normalized (xc, yc, w, h) box expected by the dataset helpers
        bbox_normalized = [
            (left + right) / (2 * img_w),
            (top + bottom) / (2 * img_h),
            (right - left) / img_w,
            (bottom - top) / img_h,
        ]

        # Apply segmentation if available
        object_mask = None
        if segmentation_module is not None:
            try:
                object_mask, _ = segmentation_module.refine_detection(
                    rgb_image, [left, top, right, bottom], class_id
                )
            except Exception:
                object_mask = None

        # Load depth image
        depth_image = None
        if depth_path and os.path.exists(depth_path):
            raw_depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
            if raw_depth is not None:
                depth_image = raw_depth.astype(np.float32) / config.DEPTH_SCALE_MM_TO_M

        # Extract patches and build the point cloud
        rgb_patch, depth_patch, bbox_pixel = dataset.extract_patches_with_segmentation(
            rgb_image, depth_image, bbox_normalized, object_mask
        )
        points_3d = dataset.depth_to_pointcloud(depth_patch, bbox_pixel)

        # Batch of one for the pose network (RGB as CHW in [0, 1])
        rgb_chw = torch.from_numpy(rgb_patch.transpose(2, 0, 1)).float() / 255.0
        rgb_batch = rgb_chw.unsqueeze(0).to(config.DEVICE)
        points_batch = torch.from_numpy(points_3d).unsqueeze(0).to(config.DEVICE)

        # Pose estimation
        with torch.no_grad():
            pred_pose, pred_conf = pose_model(rgb_batch, points_batch)

        return {
            'bbox': bbox_xyxy,
            'class_id': class_id,
            'pose': pred_pose.cpu().numpy().flatten(),
            'confidence': torch.sigmoid(pred_conf).cpu().numpy().item(),
            'yolo_confidence': yolo_conf,
            'mask': object_mask
        }

    except Exception as e:
        print(f"Error in pose estimation: {e}")
        return None

def _score_pose_sample(result, gt_pose, class_id, test_dataset, model_diameters):
    """Score a single evaluated sample without touching shared accumulators.

    Args:
        result: Dict returned by ``detect_and_estimate_pose`` or ``None`` when
            the detector/pose stage produced nothing.
        gt_pose: Ground-truth pose array taken from the dataset sample.
        class_id: Ground-truth class id of the sample.
        test_dataset: Dataset exposing ``object_models[class_id]['vertices_raw']``.
        model_diameters: Optional mapping ``class_id -> diameter``.

    Returns:
        Dict with ``detected`` (bool), ``add_value`` (``inf`` when the sample
        could not be scored), ``yolo_confidence``, ``pose_confidence`` and
        ``success`` (0/1 flag per threshold name). Every sample that cannot be
        scored keeps the all-zero failure defaults.
    """
    # A mesh with this many vertices or fewer is treated as unusable for ADD.
    min_model_vertices = 10

    outcome = {
        'detected': result is not None,
        'add_value': float('inf'),
        'yolo_confidence': 0,
        'pose_confidence': 0,
        'success': {'2cm': 0, '5cm': 0, '10cm': 0, '5p': 0, '10p': 0, '20p': 0},
    }
    if result is None:
        return outcome

    pred_pose = result['pose']
    outcome['yolo_confidence'] = result['yolo_confidence']
    outcome['pose_confidence'] = result['confidence']

    # No mesh, or a degenerate one: keep the failure defaults.
    if class_id not in test_dataset.object_models:
        return outcome
    model_vertices = test_dataset.object_models[class_id]['vertices_raw']
    if model_vertices.shape[0] <= min_model_vertices:
        return outcome

    diameter = model_diameters.get(class_id, None) if model_diameters else None
    add_metrics = compute_add_metrics_with_thresholds(
        pred_pose, gt_pose, class_id, config.SYMMETRIC_LIST, model_vertices, diameter
    )

    outcome['add_value'] = add_metrics['add_value']
    for name in ('2cm', '5cm', '10cm'):
        outcome['success'][name] = int(add_metrics[f'add_success_{name}'])
    # Diameter-relative flags are only read when a diameter is known for this class.
    if diameter is not None:
        for name in ('5p', '10p', '20p'):
            outcome['success'][name] = int(add_metrics[f'add_success_{name}'])
    return outcome


def evaluate_model_comprehensive(yolo_model, pose_model, test_dataset, model_diameters=None):
    """Comprehensive model evaluation with ADD metrics.

    Runs detection + pose estimation on the first
    ``min(config.MAX_EVAL_SAMPLES, len(test_dataset))`` samples and aggregates
    ADD success rates overall and per class.

    Each sample is scored completely before anything is appended, so all
    per-sample lists always have the same length even if scoring raises.
    Samples whose processing raises are reported and left out of every list.
    Missed detections count as failures both overall and in the per-class
    statistics, and ``detection_rate`` counts samples for which
    ``detect_and_estimate_pose`` returned a result (independent of whether a
    usable mesh was available for ADD).

    Args:
        yolo_model: Detector passed through to ``detect_and_estimate_pose``.
        pose_model: Pose network; switched to eval mode here.
        test_dataset: Dataset providing samples, ``rgb_paths``,
            ``get_depth_path`` and ``object_models``.
        model_diameters: Optional mapping ``class_id -> diameter``. When given
            (even if empty) the diameter-relative metrics ``add_5p``,
            ``add_10p`` and ``add_20p`` are tracked as well.

    Returns:
        Dict holding the per-sample lists (``add_values``, ``add_<thr>``,
        ``yolo_confidence``, ``pose_confidence``, ``class_ids``), the aggregate
        values (``detection_rate``, ``mean_add``, ``success_rate_<thr>``) and
        ``success_by_class`` with counts and ``rate_success_<thr>`` per class.
        ``pose_estimates`` is kept for compatibility and is not populated.
    """
    print("Starting comprehensive evaluation...")
    num_samples = min(config.MAX_EVAL_SAMPLES, len(test_dataset))
    print(f"Evaluating on {num_samples} samples")

    pose_model.eval()

    use_diameter = model_diameters is not None
    thresholds = ['2cm', '5cm', '10cm']
    if use_diameter:
        thresholds += ['5p', '10p', '20p']

    # Initialize metrics
    metrics = {
        'add_values': [],
        'add_2cm': [],
        'add_5cm': [],
        'add_10cm': [],
        'yolo_confidence': [],
        'pose_confidence': [],
        'class_ids': [],
        'success_by_class': {},
        'detection_rate': 0,
        'pose_estimates': []
    }

    # Add diameter-based metrics if available
    if use_diameter:
        metrics.update({'add_5p': [], 'add_10p': [], 'add_20p': []})

    detection_count = 0

    # Evaluation loop
    for i in tqdm(range(num_samples), desc="Evaluating"):
        try:
            # Load sample
            sample = test_dataset[i]
            rgb_path = test_dataset.rgb_paths[i]
            gt_pose = sample['gt_pose'].cpu().numpy()
            class_id = sample['class_id'].item()
            depth_path = test_dataset.get_depth_path(rgb_path)

            # Run detection and pose estimation
            result = detect_and_estimate_pose(
                yolo_model, pose_model, test_dataset, rgb_path, depth_path
            )

            outcome = _score_pose_sample(
                result, gt_pose, class_id, test_dataset, model_diameters
            )
        except Exception as e:
            print(f"Error evaluating sample {i}: {e}")
            continue

        # Commit the fully computed outcome in one go (nothing below can raise
        # half-way through and leave the lists misaligned).
        detection_count += int(outcome['detected'])
        metrics['add_values'].append(outcome['add_value'])
        for name in thresholds:
            metrics[f'add_{name}'].append(outcome['success'][name])
        metrics['yolo_confidence'].append(outcome['yolo_confidence'])
        metrics['pose_confidence'].append(outcome['pose_confidence'])
        metrics['class_ids'].append(class_id)

        # Per-class tracking (includes missed detections as failures)
        if class_id not in metrics['success_by_class']:
            class_stats = {'count': 0}
            class_stats.update({f'success_{name}': 0 for name in thresholds})
            metrics['success_by_class'][class_id] = class_stats

        class_stats = metrics['success_by_class'][class_id]
        class_stats['count'] += 1
        for name in thresholds:
            class_stats[f'success_{name}'] += outcome['success'][name]

    # Calculate overall metrics
    total_samples = len(metrics['add_values'])
    metrics['detection_rate'] = detection_count / total_samples if total_samples > 0 else 0

    # Calculate mean ADD only from samples that produced a finite ADD value
    valid_add_values = [v for v in metrics['add_values'] if v < float('inf')]
    metrics['mean_add'] = np.mean(valid_add_values) if valid_add_values else float('inf')

    for name in thresholds:
        flags = metrics[f'add_{name}']
        metrics[f'success_rate_{name}'] = np.mean(flags) if flags else 0

    # Calculate per-class success rates
    for stats in metrics['success_by_class'].values():
        if stats['count'] > 0:
            for name in thresholds:
                stats[f'rate_success_{name}'] = stats[f'success_{name}'] / stats['count']

    # Print results
    print("\n" + "="*60)
    print("EVALUATION RESULTS")
    print("="*60)
    print(f"Overall Performance:")
    print(f"  Samples evaluated: {total_samples}")
    print(f"  Detection rate: {metrics['detection_rate']:.4f} ({detection_count}/{total_samples})")
    print(f"  Mean ADD error: {metrics['mean_add']:.4f} m")
    print(f"  Success rate (<2cm): {metrics['success_rate_2cm']:.4f}")
    print(f"  Success rate (<5cm): {metrics['success_rate_5cm']:.4f}")
    print(f"  Success rate (<10cm): {metrics['success_rate_10cm']:.4f}")

    if use_diameter:
        print(f"  Success rate (<5% diameter): {metrics['success_rate_5p']:.4f}")
        print(f"  Success rate (<10% diameter): {metrics['success_rate_10p']:.4f}")
        print(f"  Success rate (<20% diameter): {metrics['success_rate_20p']:.4f}")

    print("="*60)
    return metrics

def _to_json_serializable(value):
    """Recursively convert NumPy arrays/scalars inside ``value`` to plain Python.

    Dicts, lists and tuples are walked at any depth; tuples come back as lists.
    NumPy scalar dict keys are converted too, because ``json`` rejects them.
    Anything else is returned unchanged.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): _to_json_serializable(item)
            for key, item in value.items()
        }
    if isinstance(value, (list, tuple)):
        return [_to_json_serializable(item) for item in value]
    return value


def run_complete_evaluation():
    """Run the complete evaluation pipeline.

    Loads the YOLO detector, dataset configuration, model diameters and the
    pose network (falling back to a randomly initialised network when no
    trained weights are found), evaluates on the ``'val'`` split and writes the
    metrics to a timestamped JSON file under ``config.MODELS_SAVE_DIR``.

    Returns:
        The metrics dict from ``evaluate_model_comprehensive``, or ``None`` when
        the YOLO model cannot be loaded or any step raises. A failure to write
        the JSON file is reported but does not discard the metrics.
    """
    print("="*60)
    print("COMPLETE SIMPLIFIED DENSEFUSION EVALUATION")
    print("="*60)

    try:
        # Load models and data
        yolo_model = load_yolo_model(config.YOLO_MODEL_PATH)
        if yolo_model is None:
            return None

        dataset_config = load_dataset_config(config.LINEMOD_ROOT)
        model_diameters = load_model_diameters(config.DIAMETER_INFO_PATH)

        # Create test dataset
        test_dataset = DenseFusionDataset(
            dataset_config, split='val',
            use_segmentation=config.USE_SEGMENTATION
        )

        # Load DenseFusion model
        pose_model = load_trained_model(use_transformer=config.USE_TRANSFORMER_FUSION)
        if pose_model is None:
            print("⚠ No trained model found. Using random initialization.")
            num_classes = len(dataset_config.get('names', []))
            pose_model = DenseFusionNetwork(
                num_objects=num_classes,
                use_transformer=config.USE_TRANSFORMER_FUSION
            ).to(config.DEVICE)

        # Run evaluation
        metrics = evaluate_model_comprehensive(
            yolo_model, pose_model, test_dataset, model_diameters
        )

        # Save evaluation results
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = os.path.join(
            config.MODELS_SAVE_DIR,
            f'{config.MODELS_NAME}_evaluation_results_{timestamp}.json'
        )

        try:
            # Convert at every nesting level: per-sample lists and the nested
            # per-class dicts may hold NumPy scalars that json cannot encode.
            json_metrics = _to_json_serializable(metrics)

            # Serialise fully in memory first so an encoding error cannot leave
            # a truncated file behind.
            # NOTE: non-finite floats are emitted as Infinity/NaN (json's
            # default), which Python reads back but strict JSON parsers reject.
            payload = json.dumps(json_metrics, indent=2)

            save_dir = os.path.dirname(results_file)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            with open(results_file, 'w') as f:
                f.write(payload)
            print(f"✓ Evaluation results saved to: {results_file}")
        except Exception as e:
            print(f"⚠ Failed to save results: {e}")

        return metrics

    except Exception as e:
        print(f"✗ Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        return None

# Status message printed at the end of this cell, after the evaluation
# functions above have been defined.
print("✓ Block 10 completed: Evaluation pipeline ready")

In [None]:
# ==============================================================================
# BLOCK 11: VISUALIZATION BLOCK (WORK IN PROGRESS - NOT YET VERIFIED END-TO-END)
# ==============================================================================

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

def get_all_objects_in_image(dataset, sample_idx):
    """Find ALL annotated objects in one image by reading its label file.

    The label path is derived from ``dataset.rgb_paths[sample_idx]`` by
    replacing ``/images/`` with ``/labels/`` and ``.png``/``.jpg`` with
    ``.txt``. Lines are classified only by their number of whitespace-separated
    tokens:

    * 5 tokens    -> ``<yolo_class> xc yc w h``; the class id is stored minus 1.
    * 4 tokens    -> ``<tag> tx ty tz``; the translation is divided by 1000
      (presumably mm -> m; confirm against the label writer).
    * >= 7 tokens -> ``<tag>`` followed by a 3x3 rotation matrix in row-major
      order (exactly 9 values are required).

    Lines accumulate into one pending record. A line whose kind is already
    present in the pending record starts the next object, so files containing
    several objects yield several entries. Records missing a bbox, translation
    or valid rotation are dropped.

    Args:
        dataset: Object exposing an ``rgb_paths`` sequence.
        sample_idx: Index into ``dataset.rgb_paths``.

    Returns:
        List of dicts with keys ``sample_idx``, ``class_id``, ``bbox_norm``,
        ``rgb_path``, ``object_idx`` (0-based position among the returned
        objects) and ``pose`` (float tensor ``[tx, ty, tz, qw, qx, qy, qz]``).
        If the label file is missing or yields no complete object, a single
        placeholder entry *without* a ``pose`` key is returned. An empty list
        is returned for an invalid dataset/index or on any parsing error.
    """
    try:
        # Check if dataset is valid
        if not hasattr(dataset, 'rgb_paths') or sample_idx >= len(dataset.rgb_paths):
            print(f"⚠️ Invalid dataset or sample_idx {sample_idx}")
            return []

        rgb_path = dataset.rgb_paths[sample_idx]
        label_path = rgb_path.replace('/images/', '/labels/').replace('.png', '.txt').replace('.jpg', '.txt')

        # Returned whenever no usable annotation exists for this image.
        placeholder = [{'sample_idx': sample_idx, 'class_id': 0, 'bbox_norm': [0.5, 0.5, 0.3, 0.3],
                        'rgb_path': rgb_path, 'object_idx': 0}]

        if not os.path.exists(label_path):
            return placeholder

        objects = []
        pending = {}  # fields collected so far for the object being parsed

        def flush_pending():
            """Append the pending record if it is complete, then reset it."""
            rotation_values = pending.get('gt_r')
            is_complete = (
                'bbox_norm' in pending
                and 'gt_t' in pending
                and rotation_values is not None
                and rotation_values.size == 9
            )
            if is_complete:
                rot = R.from_matrix(rotation_values.reshape((3, 3)))
                # SciPy returns quaternions scalar-last (x, y, z, w); the pose
                # vector stores them scalar-first (w, x, y, z).
                qx, qy, qz, qw = rot.as_quat()
                tx, ty, tz = pending['gt_t'] / 1000.0
                gt_pose_7d = np.array([tx, ty, tz, qw, qx, qy, qz], dtype=np.float32)

                objects.append({
                    'sample_idx': sample_idx,
                    'class_id': pending['class_id'],
                    'bbox_norm': pending['bbox_norm'],
                    'rgb_path': rgb_path,
                    'object_idx': len(objects),
                    'pose': torch.from_numpy(gt_pose_7d).float()
                })
            pending.clear()

        with open(label_path, 'r') as f:
            for line in f:
                parts = line.split()
                if len(parts) == 5:
                    fields = {
                        'class_id': int(parts[0]) - 1,
                        'bbox_norm': [float(x) for x in parts[1:5]],
                    }
                elif len(parts) == 4:
                    fields = {'gt_t': np.array([float(x) for x in parts[1:]], dtype=np.float32)}
                elif len(parts) >= 7:
                    fields = {'gt_r': np.array([float(x) for x in parts[1:]], dtype=np.float32)}
                else:
                    # Blank or unrecognised line
                    continue

                # Seeing the same kind of line twice means a new object begins.
                if any(key in pending for key in fields):
                    flush_pending()
                pending.update(fields)

        flush_pending()

        return objects if objects else placeholder

    except Exception as e:
        print(f"Error in get_all_objects_in_image: {e}")
        return []

def create_sample_for_specific_object(dataset, object_info, rgb_image=None):
    """Create a model-ready sample dict for one annotated object.

    Args:
        dataset: Object providing ``get_depth_path``,
            ``extract_patches_with_segmentation`` and ``depth_to_pointcloud``.
        object_info: Dict with ``class_id``, ``bbox_norm``, ``rgb_path`` and
            ``pose`` (ground-truth pose tensor). Entries without ``pose`` (for
            example the placeholder from ``get_all_objects_in_image``) cannot
            be turned into a sample.
        rgb_image: Optional already-loaded RGB image. When omitted the image is
            read from ``object_info['rgb_path']`` and converted BGR -> RGB.

    Returns:
        Dict with ``rgb`` (CHW float tensor scaled by 1/255), ``points``
        (float tensor), ``class_id`` (long tensor), ``gt_pose`` and
        ``object_info``; or ``None`` when the pose or image is unavailable or
        any processing step fails (the reason is printed).
    """
    try:
        class_id = object_info['class_id']
        bbox_norm = object_info['bbox_norm']
        rgb_path = object_info['rgb_path']

        gt_pose_tensor = object_info.get('pose')
        if gt_pose_tensor is None:
            print(f"Skipping object without ground-truth pose (class_id={class_id})")
            return None

        # Load RGB image if not provided
        if rgb_image is None:
            rgb_image = cv2.imread(rgb_path)
            if rgb_image is not None:
                rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)

        if rgb_image is None:
            print(f"Could not load RGB image: {rgb_path}")
            return None

        # Load depth image (stays None when no depth file is available)
        depth_image = None
        depth_path = dataset.get_depth_path(rgb_path)
        if depth_path and os.path.exists(depth_path):
            depth_image = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
            if depth_image is not None:
                depth_image = depth_image.astype(np.float32) / config.DEPTH_SCALE_MM_TO_M

        # Extract patches for this specific object
        rgb_patch, depth_patch, bbox_pixel = dataset.extract_patches_with_segmentation(
            rgb_image, depth_image, bbox_norm, mask=None
        )

        # Generate point cloud
        points_3d = dataset.depth_to_pointcloud(depth_patch, bbox_pixel)

        # Create tensors (HWC -> CHW, pixel values scaled by 1/255)
        rgb_tensor = torch.from_numpy(rgb_patch.transpose(2, 0, 1)).float() / 255.0
        points_tensor = torch.from_numpy(points_3d).float()

        return {
            'rgb': rgb_tensor,
            'points': points_tensor,
            'class_id': torch.tensor(class_id, dtype=torch.long),
            'gt_pose': gt_pose_tensor,
            'object_info': object_info
        }

    except Exception as e:
        # Best-effort helper: report the reason instead of failing silently.
        print(f"Error in create_sample_for_specific_object: {e}")
        return None

def _normalized_pose_copy(pose):
    """Return a private float64 copy of a 7-D pose with a unit quaternion.

    The pose layout is ``[tx, ty, tz, qw, qx, qy, qz]``. Working on a copy
    keeps the caller's array untouched; this matters for tensors too, because
    ``tensor.cpu().numpy()`` shares memory with a CPU tensor.
    A zero-norm quaternion is left as is.
    """
    if torch.is_tensor(pose):
        pose = pose.detach().cpu().numpy()
    pose = np.array(pose, dtype=np.float64)  # np.array always copies here
    quat_norm = np.linalg.norm(pose[3:])
    if quat_norm > 0:
        pose[3:] = pose[3:] / quat_norm
    return pose


def _posed_model_points(pred_pose, gt_pose, vertices):
    """Transform model vertices by the predicted and the ground-truth pose.

    Returns ``(pred_points, gt_points)``, each of shape ``(N, 3)``.
    """
    pred_pose = _normalized_pose_copy(pred_pose)
    gt_pose = _normalized_pose_copy(gt_pose)

    if torch.is_tensor(vertices):
        vertices = vertices.detach().cpu().numpy()
    vertices = np.asarray(vertices)

    # Unit heuristic: any coordinate magnitude above 1.0 is taken to mean the
    # mesh is in millimetres and is scaled to metres.
    if np.max(np.abs(vertices)) > 1.0:
        print('converting_vertices')
        vertices_m = vertices * 0.001
    else:
        vertices_m = vertices

    def apply_pose(pose):
        # Pose stores the quaternion scalar-first (w, x, y, z); SciPy expects
        # scalar-last (x, y, z, w).
        quat_wxyz = pose[3:]
        quat_xyzw = [quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]]
        rotation = R.from_quat(quat_xyzw).as_matrix()
        return (rotation @ vertices_m.T).T + pose[:3]

    return apply_pose(pred_pose), apply_pose(gt_pose)


def compute_add_visualization_fixed(pred_pose, gt_pose, vertices):
    """ADD computation for visualization.

    ADD is the mean distance between corresponding model points transformed by
    the predicted pose and by the ground-truth pose.

    Args:
        pred_pose: Predicted pose ``[tx, ty, tz, qw, qx, qy, qz]`` (array,
            sequence or tensor). Not modified.
        gt_pose: Ground-truth pose in the same layout. Not modified.
        vertices: ``(N, 3)`` model vertices (array or tensor).

    Returns:
        ``(add_error, pred_points, gt_points)``; on any failure the reason is
        printed and ``(inf, None, None)`` is returned.
    """
    try:
        pred_points, gt_points = _posed_model_points(pred_pose, gt_pose, vertices)

        # Compute ADD
        distances = np.linalg.norm(gt_points - pred_points, axis=1)
        add_error = np.mean(distances)

        return add_error, pred_points, gt_points

    except Exception as e:
        print(f"Error in compute_add_visualization_fixed: {e}")
        return float('inf'), None, None


def compute_add_s_visualization_fixed(pred_pose, gt_pose, vertices):
    """ADD-S computation for visualization (symmetric objects).

    ADD-S is the mean distance from each predicted model point to its nearest
    ground-truth model point.

    Args:
        pred_pose: Predicted pose ``[tx, ty, tz, qw, qx, qy, qz]`` (array,
            sequence or tensor). Not modified.
        gt_pose: Ground-truth pose in the same layout. Not modified.
        vertices: ``(N, 3)`` model vertices (array or tensor).

    Returns:
        ``(add_s_error, pred_points, gt_points)``; on any failure the reason is
        printed and ``(inf, None, None)`` is returned.
    """
    try:
        pred_points, gt_points = _posed_model_points(pred_pose, gt_pose, vertices)

        # Compute ADD-S (nearest-neighbor distance)
        tree = cKDTree(gt_points)
        distances, _ = tree.query(pred_points, k=1)
        add_s_error = np.mean(distances)

        return add_s_error, pred_points, gt_points

    except Exception as e:
        print(f"Error in compute_add_s_visualization_fixed: {e}")
        return float('inf'), None, None

def draw_bboxes_on_image(rgb_image, objects_info, predictions):
    """Draw bounding boxes and metric labels on a copy of the image.

    ``objects_info`` and ``predictions`` are paired positionally with ``zip``,
    so entry ``i`` of one must describe the same object as entry ``i`` of the
    other; surplus entries of the longer list are ignored.

    Each object gets a rectangle plus a filled label panel with four lines:
    object name, ADD error, overall rotation error, per-axis rotation errors.
    The panel normally sits directly above the box; when that would start above
    the top image edge it is shifted down to start at row 0 so the text stays
    visible.

    Args:
        rgb_image: ``H x W x C`` image array; it is not modified.
        objects_info: Sequence of dicts with ``bbox_norm``
            (``[xc, yc, w, h]`` normalised to the image size) and ``class_id``.
        predictions: Sequence of dicts that may provide ``object_name``,
            ``add_error`` and ``rot_errors`` (keys ``overall``, ``x_axis``,
            ``y_axis``, ``z_axis``); missing values default to 0 / a generic
            class name.

    Returns:
        A copy of ``rgb_image`` with all boxes and labels drawn.
    """
    image_with_boxes = rgb_image.copy()
    h, w = image_with_boxes.shape[:2]

    colors = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255)]

    # Text rendering settings shared by every label
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1
    text_color = (255, 255, 255)

    for i, (obj_info, pred_info) in enumerate(zip(objects_info, predictions)):
        color = colors[i % len(colors)]

        # Convert normalized bbox to pixel coordinates
        xc_n, yc_n, w_n, h_n = obj_info['bbox_norm']

        xc_px = int(xc_n * w)
        yc_px = int(yc_n * h)
        w_px = int(w_n * w)
        h_px = int(h_n * h)

        x1 = max(0, xc_px - w_px // 2)
        y1 = max(0, yc_px - h_px // 2)
        x2 = min(w, xc_px + w_px // 2)
        y2 = min(h, yc_px + h_px // 2)

        # Draw rectangle
        cv2.rectangle(image_with_boxes, (x1, y1), (x2, y2), color, 3)

        # Prepare labels
        class_id = obj_info['class_id']
        object_name = pred_info.get('object_name', f'Class_{class_id}')
        add_error = pred_info.get('add_error', 0)
        rot_errors = pred_info.get('rot_errors', {})

        # (text, font scale, stroke thickness) for each label line, top to bottom
        text_lines = [
            (f"{object_name}", font_scale + 0.1, thickness + 1),
            (f"ADD: {add_error:.3f}m", font_scale, thickness),
            (f"Rot: {rot_errors.get('overall', 0):.1f}°", font_scale, thickness),
            (f"X:{rot_errors.get('x_axis', 0):.1f}° Y:{rot_errors.get('y_axis', 0):.1f}° Z:{rot_errors.get('z_axis', 0):.1f}°",
             font_scale - 0.1, thickness),
        ]

        # Calculate text sizes and draw background
        sizes = [cv2.getTextSize(text, font, scale, stroke)[0] for text, scale, stroke in text_lines]
        max_width = max(text_w for text_w, _ in sizes) + 10
        total_height = sum(text_h for _, text_h in sizes) + 20

        # Keep the panel inside the image when the box touches the top edge.
        panel_top = max(0, y1 - total_height)
        cv2.rectangle(image_with_boxes, (x1, panel_top),
                      (x1 + max_width, panel_top + total_height), color, -1)

        # Draw text lines; each baseline advances by that line's height + 3 px
        text_y = panel_top + sizes[0][1] + 5
        for line_idx, (text, scale, stroke) in enumerate(text_lines):
            if line_idx > 0:
                text_y += sizes[line_idx][1] + 3
            cv2.putText(image_with_boxes, text, (x1 + 5, text_y), font, scale, text_color, stroke)

    return image_with_boxes

def visualize_enhanced_all_objects(pose_model, dataset, sample_idx=0):
    """Enhanced visualization with rotation differences and original image"""
    print(f"🎯 Enhanced Visualization for sample {sample_idx}")

    # Validate inputs
    if pose_model is None:
        print("❌ Pose model is None - model loading failed")
        return None

    if not hasattr(dataset, 'rgb_paths'):
        print("❌ Invalid dataset - missing rgb_paths attribute")
        return None

    if sample_idx >= len(dataset.rgb_paths):
        print(f"❌ Sample index {sample_idx} out of range. Dataset has {len(dataset.rgb_paths)} samples")
        return None

    # Find all objects in the image
    objects_in_image = get_all_objects_in_image(dataset, sample_idx)
    if len(objects_in_image) == 0:
        print("❌ No objects found in image")
        return None

    # Load RGB image
    rgb_path = dataset.rgb_paths[sample_idx]
    rgb_image = cv2.imread(rgb_path)
    if rgb_image is not None:
        rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2RGB)
    else:
        print(f"❌ Could not load RGB image: {rgb_path}")
        return None

    print(f"📷 Image: {os.path.basename(rgb_path)}")
    print(f"📦 Objects found: {len(objects_in_image)}")

    # Process each object
    pose_model.eval()
    predictions = []
    total_add = 0
    total_rot = 0
    valid_objects = 0

    with torch.no_grad():
        for obj_idx, object_info in enumerate(objects_in_image):
            sample = create_sample_for_specific_object(dataset, object_info, rgb_image)
            if sample is None:
                continue

            class_id = object_info['class_id']
            object_names = dataset.data_config.get('names', [])
            object_name = object_names[class_id] if class_id < len(object_names) else f"Class_{class_id}"

            # Get prediction
            rgb_batch = sample['rgb'].unsqueeze(0).to(config.DEVICE)
            points_batch = sample['points'].unsqueeze(0).to(config.DEVICE)

            try:
                pred_poses, pred_confs = pose_model(rgb_batch, points_batch)
                pred_pose = pred_poses[0].cpu().numpy()
                gt_pose = sample['gt_pose'].cpu().numpy()

                # Get model vertices
                if class_id not in dataset.object_models:
                    continue

                vertices = dataset.object_models[class_id]['vertices_raw']

                # Compute ADD and rotation differences
                if class_id in config.SYMMETRIC_LIST:
                    add_error, pred_points, gt_points = compute_add_s_visualization_fixed(
                        pred_pose, gt_pose, vertices
                    )
                else:
                    add_error, pred_points, gt_points = compute_add_visualization_fixed(
                        pred_pose, gt_pose, vertices
                    )

                rot_errors = compute_rotation_difference_degrees(pred_pose, gt_pose)

                if pred_points is None:
                    continue

                predictions.append({
                    'object_name': object_name,
                    'class_id': class_id,
                    'add_error': add_error,
                    'rot_errors': rot_errors,
                    'pred_points': pred_points,
                    'gt_points': gt_points,
                    'pred_pose': pred_pose,
                    'gt_pose': gt_pose
                })

                total_add += add_error
                total_rot += rot_errors['overall']
                valid_objects += 1

                print(f"   ✅ {object_name}: ADD={add_error:.4f}m, Rot={rot_errors['overall']:.1f}°")

            except Exception as e:
                print(f"   ❌ Prediction failed for {object_name}: {e}")
                continue

    if valid_objects == 0:
        print("❌ No valid predictions to visualize")
        return None

    # Create visualization
    print(f"\n🎨 Creating enhanced visualization...")

    # Calculate grid dimensions
    n_objects = len(predictions)
    if n_objects == 1:
        rows, cols = 1, 2
    elif n_objects == 2:
        rows, cols = 2, 2
    elif n_objects <= 4:
        rows, cols = 2, 3
    else:
        rows = int(np.ceil((n_objects + 1) / 3))
        cols = 3

    # Create subplot specifications
    subplot_specs = []
    subplot_titles = []

    # First subplot for original image
    subplot_specs.append([{"type": "xy"}])
    subplot_titles.append("Original Image with Detections")

    # Add 3D subplots for each object
    for pred in predictions:
        subplot_specs.append([{"type": "scene"}])
        subplot_titles.append(f"{pred['object_name']}")

    # Adjust specs for layout
    if len(subplot_specs) <= 3:
        spec_list = [[spec[0] for spec in subplot_specs]]
    else:
        spec_list = []
        for i in range(0, len(subplot_specs), cols):
            row_specs = []
            for j in range(cols):
                if i + j < len(subplot_specs):
                    row_specs.append(subplot_specs[i + j][0])
                else:
                    row_specs.append({"type": "xy"})
            spec_list.append(row_specs)

    # Create subplots
    fig = make_subplots(
        rows=len(spec_list),
        cols=cols,
        subplot_titles=subplot_titles[:len(predictions) + 1],
        specs=spec_list,
        horizontal_spacing=0.05,
        vertical_spacing=0.1
    )

    # Add original image with bounding boxes
    image_with_boxes = draw_bboxes_on_image(rgb_image, objects_in_image, predictions)
    fig.add_trace(go.Image(z=image_with_boxes), row=1, col=1)

    # Add 3D plots for each object
    for i, pred in enumerate(predictions):
        row = (i + 1) // cols + 1
        col = (i + 1) % cols + 1
        if col == 0:
            col = cols
            row -= 1

        # Colors for this object
        colors_gt = ['darkgreen', 'darkblue', 'darkred', 'darkorange', 'purple', 'brown']
        colors_pred = ['lightgreen', 'lightblue', 'lightcoral', 'orange', 'violet', 'tan']

        color_gt = colors_gt[i % len(colors_gt)]
        color_pred = colors_pred[i % len(colors_pred)]

        # Subsample points for performance
        gt_points = pred['gt_points']
        pred_points = pred['pred_points']

        if len(gt_points) > 800:
            indices = np.random.choice(len(gt_points), 800, replace=False)
            gt_viz = gt_points[indices]
            pred_viz = pred_points[indices]
        else:
            gt_viz = gt_points
            pred_viz = pred_points

        # Add GT and predicted points
        fig.add_trace(
            go.Scatter3d(
                x=gt_viz[:, 0], y=gt_viz[:, 1], z=gt_viz[:, 2],
                mode='markers',
                marker=dict(size=3, color=color_gt, opacity=0.8),
                name=f'GT_{pred["object_name"]}',
                showlegend=True
            ),
            row=row, col=col
        )

        fig.add_trace(
            go.Scatter3d(
                x=pred_viz[:, 0], y=pred_viz[:, 1], z=pred_viz[:, 2],
                mode='markers',
                marker=dict(size=3, color=color_pred, opacity=0.8),
                name=f'Pred_{pred["object_name"]}',
                showlegend=True
            ),
            row=row, col=col
        )

        # Add coordinate axes
        axis_length = 0.03
        fig.add_trace(go.Scatter3d(x=[0, axis_length], y=[0, 0], z=[0, 0], mode='lines',
                                   line=dict(color='red', width=4), showlegend=False), row=row, col=col)
        fig.add_trace(go.Scatter3d(x=[0, 0], y=[0, axis_length], z=[0, 0], mode='lines',
                                   line=dict(color='green', width=4), showlegend=False), row=row, col=col)
        fig.add_trace(go.Scatter3d(x=[0, 0], y=[0, 0], z=[0, axis_length], mode='lines',
                                   line=dict(color='blue', width=4), showlegend=False), row=row, col=col)

    # Update layout
    mean_add = total_add / valid_objects
    mean_rot = total_rot / valid_objects

    # Update scene properties for 3D subplots
    for i in range(len(predictions)):
        scene_name = f'scene{i+1}' if i > 0 else 'scene'
        fig.update_layout(**{
            scene_name: dict(
                xaxis_title='X (m)',
                yaxis_title='Y (m)',
                zaxis_title='Z (m)',
                aspectmode='cube',  # Back to cube for realistic proportions
                camera=dict(
                    eye=dict(x=1.5, y=1.5, z=1.5),  # Moderate camera distance
                    center=dict(x=0, y=0, z=0),
                    up=dict(x=0, y=0, z=1)
                ),
                # Keep realistic axis ranges
                xaxis=dict(showgrid=True, gridwidth=1, gridcolor='lightgray'),
                yaxis=dict(showgrid=True, gridwidth=1, gridcolor='lightgray'),
                zaxis=dict(showgrid=True, gridwidth=1, gridcolor='lightgray'),
                bgcolor='white'
            )
        })

    # Update main layout with reasonable dimensions
    fig.update_layout(
        title=dict(
            text=f"Enhanced Pose Estimation - Sample {sample_idx}<br>" +
                 f"Objects: {valid_objects} | Mean ADD: {mean_add:.4f}m | Mean Rotation: {mean_rot:.1f}°",
            x=0.5,
            font=dict(size=14)
        ),
        height=700 if len(predictions) <= 2 else 1000,  # Reasonable height
        width=1600,  # Keep original width
        showlegend=True,
        margin=dict(l=20, r=20, t=80, b=20)  # Standard margins
    )

    # Hide axes for image subplot
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=1)
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=1)

    # Show visualization
    fig.show()

    # Print summary
    print(f"\n✅ ENHANCED VISUALIZATION COMPLETE:")
    print(f"   📊 Objects visualized: {valid_objects}")
    print(f"   📈 Mean ADD: {mean_add:.6f}m")
    print(f"   🔄 Mean Rotation Error: {mean_rot:.1f}°")

    for i, pred in enumerate(predictions):
        rot_info = pred['rot_errors']
        print(f"      {i+1}. {pred['object_name']}:")
        print(f"         ADD: {pred['add_error']:.4f}m")
        print(f"         Rotation: {rot_info['overall']:.1f}° (X:{rot_info['x_axis']:.1f}° Y:{rot_info['y_axis']:.1f}° Z:{rot_info['z_axis']:.1f}°)")

    return {
        'mean_add': mean_add,
        'mean_rotation': mean_rot,
        'num_objects': valid_objects,
        'individual_results': predictions,
        'image_path': rgb_path
    }

    with torch.no_grad():
        for obj_idx, object_info in enumerate(objects_in_image):
            sample = create_sample_for_specific_object(dataset, object_info, rgb_image)
            if sample is None:
                continue

            class_id = object_info['class_id']
            object_names = dataset.data_config.get('names', [])
            object_name = object_names[class_id] if class_id < len(object_names) else f"Class_{class_id}"

            # Get prediction
            rgb_batch = sample['rgb'].unsqueeze(0).to(config.DEVICE)
            points_batch = sample['points'].unsqueeze(0).to(config.DEVICE)

            try:
                pred_poses, pred_confs = pose_model(rgb_batch, points_batch)
                pred_pose = pred_poses[0].cpu().numpy()
                gt_pose = sample['gt_pose'].cpu().numpy()

                # Get model vertices
                if class_id not in dataset.object_models:
                    continue

                vertices = dataset.object_models[class_id]['vertices']

                # Compute ADD and rotation differences
                if class_id in config.SYMMETRIC_LIST:
                    add_error, pred_points, gt_points = compute_add_s_visualization_fixed(
                        pred_pose, gt_pose, vertices
                    )
                else:
                    add_error, pred_points, gt_points = compute_add_visualization_fixed(
                        pred_pose, gt_pose, vertices
                    )

                rot_errors = compute_rotation_difference_degrees(pred_pose, gt_pose)

                if pred_points is None:
                    continue

                predictions.append({
                    'object_name': object_name,
                    'class_id': class_id,
                    'add_error': add_error,
                    'rot_errors': rot_errors,
                    'pred_points': pred_points,
                    'gt_points': gt_points,
                    'pred_pose': pred_pose,
                    'gt_pose': gt_pose
                })

                total_add += add_error
                total_rot += rot_errors['overall']
                valid_objects += 1

                print(f"   ✅ {object_name}: ADD={add_error:.4f}m, Rot={rot_errors['overall']:.1f}°")

            except Exception as e:
                print(f"   ❌ Prediction failed for {object_name}: {e}")
                continue

    if valid_objects == 0:
        print("❌ No valid predictions to visualize")
        return None

    # Create visualization
    print(f"\n🎨 Creating enhanced visualization...")

    # Calculate grid dimensions
    n_objects = len(predictions)
    if n_objects == 1:
        rows, cols = 1, 2
    elif n_objects == 2:
        rows, cols = 2, 2
    elif n_objects <= 4:
        rows, cols = 2, 3
    else:
        rows = int(np.ceil((n_objects + 1) / 3))
        cols = 3

    # Create subplot specifications
    subplot_specs = []
    subplot_titles = []

    # First subplot for original image
    subplot_specs.append([{"type": "xy"}])
    subplot_titles.append("Original Image with Detections")

    # Add 3D subplots for each object
    for pred in predictions:
        subplot_specs.append([{"type": "scene"}])
        subplot_titles.append(f"{pred['object_name']}")

    # Adjust specs for layout
    if len(subplot_specs) <= 3:
        spec_list = [[spec[0] for spec in subplot_specs]]
    else:
        spec_list = []
        for i in range(0, len(subplot_specs), cols):
            row_specs = []
            for j in range(cols):
                if i + j < len(subplot_specs):
                    row_specs.append(subplot_specs[i + j][0])
                else:
                    row_specs.append({"type": "xy"})
            spec_list.append(row_specs)

    # Create subplots
    fig = make_subplots(
        rows=len(spec_list),
        cols=cols,
        subplot_titles=subplot_titles[:len(predictions) + 1],
        specs=spec_list,
        horizontal_spacing=0.05,
        vertical_spacing=0.1
    )

    # Add original image with bounding boxes
    image_with_boxes = draw_bboxes_on_image(rgb_image, objects_in_image, predictions)
    fig.add_trace(go.Image(z=image_with_boxes), row=1, col=1)

    # Add 3D plots for each object
    for i, pred in enumerate(predictions):
        row = (i + 1) // cols + 1
        col = (i + 1) % cols + 1
        if col == 0:
            col = cols
            row -= 1

        # Colors for this object
        colors_gt = ['darkgreen', 'darkblue', 'darkred', 'darkorange', 'purple', 'brown']
        colors_pred = ['lightgreen', 'lightblue', 'lightcoral', 'orange', 'violet', 'tan']

        color_gt = colors_gt[i % len(colors_gt)]
        color_pred = colors_pred[i % len(colors_pred)]

        # Subsample points for performance
        gt_points = pred['gt_points']
        pred_points = pred['pred_points']

        if len(gt_points) > 800:
            indices = np.random.choice(len(gt_points), 800, replace=False)
            gt_viz = gt_points[indices]
            pred_viz = pred_points[indices]
        else:
            gt_viz = gt_points
            pred_viz = pred_points

        # Add GT and predicted points
        fig.add_trace(
            go.Scatter3d(
                x=gt_viz[:, 0], y=gt_viz[:, 1], z=gt_viz[:, 2],
                mode='markers',
                marker=dict(size=3, color=color_gt, opacity=0.8),
                name=f'GT_{pred["object_name"]}',
                showlegend=True
            ),
            row=row, col=col
        )

        fig.add_trace(
            go.Scatter3d(
                x=pred_viz[:, 0], y=pred_viz[:, 1], z=pred_viz[:, 2],
                mode='markers',
                marker=dict(size=3, color=color_pred, opacity=0.8),
                name=f'Pred_{pred["object_name"]}',
                showlegend=True
            ),
            row=row, col=col
        )

        # Add coordinate axes
        axis_length = 0.03
        fig.add_trace(go.Scatter3d(x=[0, axis_length], y=[0, 0], z=[0, 0], mode='lines',
                                   line=dict(color='red', width=4), showlegend=False), row=row, col=col)
        fig.add_trace(go.Scatter3d(x=[0, 0], y=[0, axis_length], z=[0, 0], mode='lines',
                                   line=dict(color='green', width=4), showlegend=False), row=row, col=col)
        fig.add_trace(go.Scatter3d(x=[0, 0], y=[0, 0], z=[0, axis_length], mode='lines',
                                   line=dict(color='blue', width=4), showlegend=False), row=row, col=col)

    # Update layout
    mean_add = total_add / valid_objects
    mean_rot = total_rot / valid_objects

    # Update scene properties for 3D subplots
    for i in range(len(predictions)):
        scene_name = f'scene{i+1}' if i > 0 else 'scene'
        fig.update_layout(**{
            scene_name: dict(
                xaxis_title='X (m)',
                yaxis_title='Y (m)',
                zaxis_title='Z (m)',
                aspectmode='cube',
                camera=dict(
                    eye=dict(x=1.2, y=1.2, z=1.2),
                    center=dict(x=0, y=0, z=0),
                    up=dict(x=0, y=0, z=1)
                )
            )
        })

    # Update main layout
    fig.update_layout(
        title=f"Enhanced Pose Estimation - Sample {sample_idx}<br>" +
              f"Objects: {valid_objects} | Mean ADD: {mean_add:.4f}m | Mean Rotation: {mean_rot:.1f}°",
        height=600 if len(predictions) <= 2 else 900,
        width=1600,
        showlegend=True
    )

    # Hide axes for image subplot
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=1)
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=1)

    # Show visualization
    fig.show()

    # Print summary
    print(f"\n✅ ENHANCED VISUALIZATION COMPLETE:")
    print(f"   📊 Objects visualized: {valid_objects}")
    print(f"   📈 Mean ADD: {mean_add:.6f}m")
    print(f"   🔄 Mean Rotation Error: {mean_rot:.1f}°")

    for i, pred in enumerate(predictions):
        rot_info = pred['rot_errors']
        print(f"      {i+1}. {pred['object_name']}:")
        print(f"         ADD: {pred['add_error']:.4f}m")
        print(f"         Rotation: {rot_info['overall']:.1f}° (X:{rot_info['x_axis']:.1f}° Y:{rot_info['y_axis']:.1f}° Z:{rot_info['z_axis']:.1f}°)")

    return {
        'mean_add': mean_add,
        'mean_rotation': mean_rot,
        'num_objects': valid_objects,
        'individual_results': predictions,
        'image_path': rgb_path
    }

def create_individual_3d_plots(pose_model, dataset, sample_idx=0):
    """Render one standalone plotly 3D figure per object found in a sample.

    For every object returned by ``get_all_objects_in_image`` the pose network
    is run, the ADD (or ADD-S for classes in ``config.SYMMETRIC_LIST``) and
    rotation errors are computed, and the ground-truth and predicted point
    sets are shown in their own figure. The source image with bounding boxes
    is shown last.

    Args:
        pose_model: Pose network called as ``pose_model(rgb_batch, points_batch)``.
        dataset: Object exposing ``rgb_paths``, ``data_config`` and
            ``object_models``.
        sample_idx: Index into ``dataset.rgb_paths``.

    Returns:
        List of per-object dicts (``object_name``, ``class_id``, ``add_error``,
        ``rot_errors``, ``pred_points``, ``gt_points``), or ``None`` when the
        inputs are invalid, the image cannot be read, or nothing was predicted.
    """
    print(f"🎯 Creating individual 3D plots for sample {sample_idx}")

    # --- Input validation -----------------------------------------------------
    if pose_model is None:
        print("❌ Pose model is None")
        return None

    index_is_usable = hasattr(dataset, 'rgb_paths') and sample_idx < len(dataset.rgb_paths)
    if not index_is_usable:
        print("❌ Invalid dataset or sample index")
        return None

    objects_in_image = get_all_objects_in_image(dataset, sample_idx)
    if len(objects_in_image) == 0:
        print("❌ No objects found")
        return None

    # --- Image loading (OpenCV reads BGR; everything downstream expects RGB) ---
    rgb_path = dataset.rgb_paths[sample_idx]
    bgr_image = cv2.imread(rgb_path)
    if bgr_image is None:
        print(f"❌ Could not load image: {rgb_path}")
        return None
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    print(f"📷 Image: {os.path.basename(rgb_path)}")
    print(f"📦 Objects found: {len(objects_in_image)}")

    # --- Per-object inference -------------------------------------------------
    pose_model.eval()
    predictions = []

    with torch.no_grad():
        for detected in objects_in_image:
            obj_sample = create_sample_for_specific_object(dataset, detected, rgb_image)
            if obj_sample is None:
                continue

            class_id = detected['class_id']
            known_names = dataset.data_config.get('names', [])
            if class_id < len(known_names):
                object_name = known_names[class_id]
            else:
                object_name = f"Class_{class_id}"

            # Add the batch dimension and move the inputs to the model's device.
            rgb_batch = obj_sample['rgb'].unsqueeze(0).to(config.DEVICE)
            points_batch = obj_sample['points'].unsqueeze(0).to(config.DEVICE)

            try:
                batch_poses, _batch_confs = pose_model(rgb_batch, points_batch)
                estimated_pose = batch_poses[0].cpu().numpy()
                reference_pose = obj_sample['gt_pose'].cpu().numpy()

                # Without a mesh there is nothing to transform or plot.
                if class_id not in dataset.object_models:
                    continue

                mesh_vertices = dataset.object_models[class_id]['vertices_raw']

                # Symmetric objects use the ADD-S variant of the metric.
                add_metric = (
                    compute_add_s_visualization_fixed
                    if class_id in config.SYMMETRIC_LIST
                    else compute_add_visualization_fixed
                )
                add_error, pred_points, gt_points = add_metric(
                    estimated_pose, reference_pose, mesh_vertices
                )
                rot_errors = compute_rotation_difference_degrees(estimated_pose, reference_pose)

                if pred_points is None:
                    continue

                predictions.append({
                    'object_name': object_name,
                    'class_id': class_id,
                    'add_error': add_error,
                    'rot_errors': rot_errors,
                    'pred_points': pred_points,
                    'gt_points': gt_points
                })
                print(f"   ✅ {object_name}: ADD={add_error:.4f}m, Rot={rot_errors['overall']:.1f}°")

            except Exception as e:
                print(f"   ❌ Failed for {object_name}: {e}")
                continue

    if len(predictions) == 0:
        print("❌ No valid predictions")
        return None

    # --- One figure per object ------------------------------------------------
    print(f"🎨 Creating {len(predictions)} individual 3D plots...")

    def _scene_axis(axis_title):
        # Shared styling for the three scene axes.
        return dict(title=axis_title, showgrid=True, gridwidth=1, gridcolor='lightgray')

    axis_length = 0.05
    # (legend name, colour, end point of the axis segment starting at the origin)
    axis_segments = (
        ('X-axis', 'red', (axis_length, 0, 0)),
        ('Y-axis', 'green', (0, axis_length, 0)),
        ('Z-axis', 'blue', (0, 0, axis_length)),
    )

    for plot_number, pred in enumerate(predictions, start=1):
        fig = go.Figure()

        gt_viz = pred['gt_points']
        pred_viz = pred['pred_points']

        # Subsample for performance; the same indices are applied to both sets.
        if len(gt_viz) > 800:
            keep = np.random.choice(len(gt_viz), 800, replace=False)
            gt_viz, pred_viz = gt_viz[keep], pred_viz[keep]

        # Ground truth first, prediction second (legend order).
        clouds = (
            (gt_viz, 'darkgreen', 0.8, 'Ground Truth'),
            (pred_viz, 'red', 0.7, 'Prediction'),
        )
        for cloud, cloud_color, cloud_opacity, cloud_name in clouds:
            fig.add_trace(go.Scatter3d(
                x=cloud[:, 0], y=cloud[:, 1], z=cloud[:, 2],
                mode='markers',
                marker=dict(size=4, color=cloud_color, opacity=cloud_opacity),
                name=cloud_name,
                showlegend=True
            ))

        # Coordinate axes drawn from the origin.
        for axis_name, axis_color, (end_x, end_y, end_z) in axis_segments:
            fig.add_trace(go.Scatter3d(
                x=[0, end_x], y=[0, end_y], z=[0, end_z],
                mode='lines',
                line=dict(color=axis_color, width=6),
                name=axis_name,
                showlegend=True
            ))

        rot = pred['rot_errors']
        title_text = "<br>".join([
            f"{pred['object_name']} - Sample {sample_idx}",
            f"ADD: {pred['add_error']:.4f}m | Rotation: {rot['overall']:.1f}°",
            f"Per-axis Rot: X:{rot['x_axis']:.1f}° Y:{rot['y_axis']:.1f}° Z:{rot['z_axis']:.1f}°",
        ])

        fig.update_layout(
            title=dict(text=title_text, x=0.5, font=dict(size=14)),
            scene=dict(
                xaxis=_scene_axis('X (meters)'),
                yaxis=_scene_axis('Y (meters)'),
                zaxis=_scene_axis('Z (meters)'),
                aspectmode='cube',  # equal-sized axes box
                camera=dict(
                    eye=dict(x=1.5, y=1.5, z=1.5),
                    center=dict(x=0, y=0, z=0),
                    up=dict(x=0, y=0, z=1)
                ),
                bgcolor='white'
            ),
            width=900,
            height=700,
            showlegend=True,
            margin=dict(l=0, r=0, t=100, b=0)
        )

        fig.show()

        print(f"📊 Plot {plot_number}/{len(predictions)}: {pred['object_name']}")

    # --- Source image with bounding boxes -------------------------------------
    image_with_boxes = draw_bboxes_on_image(rgb_image, objects_in_image, predictions)

    hidden_axis = dict(showticklabels=False, showgrid=False, zeroline=False)
    fig_img = go.Figure()
    fig_img.add_trace(go.Image(z=image_with_boxes))
    fig_img.update_layout(
        title=f"Original Image with Detections - Sample {sample_idx}",
        width=800,
        height=600,
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis)
    )
    fig_img.show()

    return predictions

In [None]:
# ==============================================================================
# BLOCK 12: MAIN EXECUTION PIPELINE
# ==============================================================================

def main_pipeline(mode='complete'):
    """
    Main execution pipeline for DenseFusion inspired network.

    Runs, in order: path verification, training, evaluation, visualization and
    a final summary. Training, evaluation and visualization each have their own
    try/except so a failure in one stage is reported and later stages still run.

    Args:
        mode: 'complete' (train + eval + visualize), 'train_only', 'eval_only',
            or 'visualize_only'. Any other value is reported and skips all
            three stages.

    Returns:
        dict that may contain:
          - 'training':      whatever ``train_densefusion()`` returned
          - 'evaluation':    whatever ``run_complete_evaluation()`` returned
          - 'visualization': list with one result per visualized sample
        Returns None when path verification fails.
    """
    print("="*80)
    print("DENSEFUSION 6D POSE ESTIMATION - MAIN PIPELINE")
    print("="*80)

    print("Configuration Summary:")
    config.print_config()
    print("-"*80)

    valid_modes = ('complete', 'train_only', 'eval_only', 'visualize_only')
    if mode not in valid_modes:
        print(f"⚠ Unknown mode '{mode}'; expected one of {valid_modes}. No stage will run.")

    results = {}

    try:
        # Step 1: Verify setup
        print("\nStep 1: Verifying Setup")
        print("-" * 40)

        if not config.verify_paths():
            print("⚠ Please fix path configuration before proceeding")
            return None

        print("✓ All paths verified")

        # Step 2: Training (if requested)
        if mode in ['complete', 'train_only']:
            print("\nStep 2: Training DenseFusion Model")
            print("-" * 40)

            try:
                training_results = train_densefusion()
                results['training'] = training_results
                print("✓ Training completed successfully")
            except Exception as e:
                print(f"✗ Training failed: {e}")
                if mode == 'train_only':
                    return results
        else:
            print("\nStep 2: Skipping Training")
            print("-" * 40)

        # Step 3: Evaluation (if requested)
        if mode in ['complete', 'eval_only']:
            print("\nStep 3: Model Evaluation")
            print("-" * 40)

            try:
                eval_results = run_complete_evaluation()
                if eval_results is not None:
                    results['evaluation'] = eval_results
                    print("✓ Evaluation completed successfully")

                    # Print key metrics
                    print(f"\nKey Results:")
                    print(f"  Detection Rate: {eval_results.get('detection_rate', 0):.3f}")
                    print(f"  Mean ADD Error: {eval_results.get('mean_add', float('inf')):.4f} m")
                    print(f"  Success Rate (<5cm): {eval_results.get('success_rate_5cm', 0):.3f}")
                    print(f"  Success Rate (<10cm): {eval_results.get('success_rate_10cm', 0):.3f}")
                else:
                    print("✗ Evaluation failed")

            except Exception as e:
                print(f"✗ Evaluation error: {e}")
        else:
            print("\nStep 3: Skipping Evaluation")
            print("-" * 40)

        # Step 4: Visualization (if requested)
        if mode in ['complete', 'visualize_only']:
            print("\nStep 4: Enhanced Visualization")
            print("-" * 40)

            try:
                # Candidate sample indices; only the first one is rendered to
                # keep notebook output and runtime small.
                sample_indices = [0, 100, 200, 500, 1000]

                # simple_visualization() loads the trained model and the test
                # dataset and hands both to visualize_enhanced_all_objects(),
                # which the previous direct call did not supply. It returns
                # None on failure, so the per-sample results are collected in
                # a list and failed samples are dropped.
                viz_results = []
                for viz_idx in sample_indices[:1]:
                    viz_result = simple_visualization(sample_idx=viz_idx)
                    if viz_result is not None:
                        viz_results.append(viz_result)

                if viz_results:
                    results['visualization'] = viz_results
                    print(f"✓ Visualization completed for {len(viz_results)} samples")
                else:
                    print("✗ Visualization produced no results")

            except Exception as e:
                print(f"✗ Visualization error: {e}")
        else:
            print("\nStep 4: Skipping Visualization")
            print("-" * 40)

        # Step 5: Final Summary
        print("\nStep 5: Final Summary")
        print("-" * 40)

        if results:
            print("Pipeline completed successfully!")

            if 'training' in results:
                # Guard against a training function that returned None.
                training = results['training'] or {}
                best_val_loss = training.get('best_val_loss')
                # A string fallback cannot be formatted with ':.6f'.
                best_val_loss_text = 'N/A' if best_val_loss is None else f"{best_val_loss:.6f}"
                print(f"\nTraining Results:")
                print(f"  Epochs completed: {len(training.get('train_losses', []))}")
                print(f"  Best validation loss: {best_val_loss_text}")
                print(f"  Training time: {training.get('total_time', 0):.1f} seconds")

            if 'evaluation' in results:
                evaluation = results['evaluation']
                print(f"\nEvaluation Results:")
                print(f"  Detection rate: {evaluation.get('detection_rate', 0):.3f}")
                print(f"  Mean ADD error: {evaluation.get('mean_add', float('inf')):.4f} m")
                print(f"  Success rate (<5cm): {evaluation.get('success_rate_5cm', 0):.3f}")
                print(f"  Success rate (<10cm): {evaluation.get('success_rate_10cm', 0):.3f}")

            if 'visualization' in results:
                viz = results['visualization']
                print(f"\nVisualization Results:")
                print(f"  Samples visualized: {len(viz)}")
                add_values = [r['mean_add'] for r in viz
                              if isinstance(r, dict) and 'mean_add' in r]
                rot_values = [r['mean_rotation'] for r in viz
                              if isinstance(r, dict) and 'mean_rotation' in r]
                # np.mean([]) would yield nan plus a RuntimeWarning.
                if add_values:
                    print(f"  Average ADD error: {np.mean(add_values):.4f} m")
                if rot_values:
                    print(f"  Average rotation error: {np.mean(rot_values):.1f}°")
        else:
            print("Pipeline completed with no results")

        print(f"\nFiles saved to: {config.MODELS_SAVE_DIR}")
        print("="*80)
        return results

    except Exception as e:
        print(f"\n✗ Pipeline failed with error: {e}")
        import traceback
        traceback.print_exc()
        return results

def quick_test():
    """Smoke-test the configuration, dataset, model and loss components.

    Each check is executed in isolation; a check that raises is recorded as a
    failure instead of aborting the run.

    Returns:
        True when every check passed, False otherwise.
    """
    banner = "="*60
    print(banner)
    print("QUICK TEST - DENSEFUSION COMPONENTS")
    print(banner)

    # (label, zero-argument callable returning a truthy value on success)
    checks = (
        ("Configuration", lambda: config.verify_paths()),
        ("Dataset Loading", lambda: test_dataset() is not None),
        ("Model Architecture", lambda: test_model_architecture() is not None),
        ("Loss Function", lambda: test_loss_function() is not None),
    )

    outcomes = {}
    for label, run_check in checks:
        print(f"\nTesting {label}...")
        try:
            outcome = run_check()
            outcomes[label] = outcome
            verdict = "✓ PASS" if outcome else "✗ FAIL"
            print(f"{label}: {verdict}")
        except Exception as err:
            outcomes[label] = False
            print(f"{label}: ✗ FAIL ({err})")

    print("\n" + banner)
    print("QUICK TEST SUMMARY")
    print(banner)

    for label, outcome in outcomes.items():
        mark = "✓" if outcome else "✗"
        print(f"  {mark} {label}")

    total = len(outcomes)
    passed = len([ok for ok in outcomes.values() if ok])
    print(f"\nOverall: {passed}/{total} tests passed")

    if passed != total:
        print("⚠ Some tests failed. Fix issues before running main pipeline.")
        return False

    print("🎉 All tests passed! Ready to run main pipeline.")
    return True

def load_and_use_pretrained_model(model_path=None, sample_idx=0):
    """Load a pretrained model and run inference on one test sample.

    Args:
        model_path: Checkpoint path forwarded to ``load_trained_model``
            (None is passed through unchanged).
        sample_idx: Index of the test-split sample to run inference on.

    Returns:
        dict with keys 'model', 'sample_idx', 'pred_pose', 'gt_pose',
        'confidence', 'class_id' and 'add_error' (None when the dataset has no
        mesh for the sample's class), or None when the model cannot be loaded
        or inference fails.
    """
    print("="*60)
    print("LOADING AND USING PRETRAINED MODEL")
    print("="*60)

    # Load model (attribute name fixed: it was misspelled USE_TRANSFORMER_FUSIO)
    model = load_trained_model(model_path, use_transformer=config.USE_TRANSFORMER_FUSION)
    if model is None:
        print("❌ Failed to load model")
        return None

    # Load dataset for testing
    dataset_config = load_dataset_config(config.LINEMOD_ROOT)
    test_dataset = DenseFusionDataset(
        dataset_config, split='test', use_segmentation=config.USE_SEGMENTATION
    )

    print(f"✓ Model loaded successfully")
    print(f"✓ Test dataset loaded: {len(test_dataset)} samples")

    # Run inference on a sample
    try:
        sample = test_dataset[sample_idx]
        # Add the batch dimension and move the inputs to the model's device.
        rgb = sample['rgb'].unsqueeze(0).to(config.DEVICE)
        points = sample['points'].unsqueeze(0).to(config.DEVICE)
        gt_pose_np = sample['gt_pose'].cpu().numpy()
        class_id = sample['class_id'].item()

        model.eval()
        with torch.no_grad():
            pred_pose, pred_conf = model(rgb, points)

        pred_pose_np = pred_pose.cpu().numpy().flatten()
        # The confidence head output is squashed to (0, 1) with a sigmoid.
        confidence = torch.sigmoid(pred_conf).cpu().numpy().item()

        print(f"\n📊 Inference Results for Sample {sample_idx}:")
        print(f"  Class ID: {class_id}")
        print(f"  Predicted pose: {pred_pose_np}")
        print(f"  Ground truth pose: {gt_pose_np}")
        print(f"  Confidence: {confidence:.4f}")

        # Compute ADD only when a mesh is available for this class.
        add_error = None
        if class_id in test_dataset.object_models:
            vertices = test_dataset.object_models[class_id]['vertices_raw']
            # Symmetric objects use the ADD-S variant of the metric.
            if class_id in config.SYMMETRIC_LIST:
                add_error = compute_add_s_metric(pred_pose_np, gt_pose_np, vertices)
            else:
                add_error = compute_add_metric(pred_pose_np, gt_pose_np, vertices)
            print(f"  ADD error: {add_error:.6f} m")

        return {
            'model': model,
            'sample_idx': sample_idx,
            'pred_pose': pred_pose_np,
            'gt_pose': gt_pose_np,
            'confidence': confidence,
            'class_id': class_id,
            'add_error': add_error
        }

    except Exception as e:
        print(f"❌ Inference failed: {e}")
        return None

def simple_visualization(sample_idx=0):
    """Load the test split and the trained model, then visualize one sample.

    Args:
        sample_idx: Index of the test sample to visualize.

    Returns:
        Whatever ``visualize_enhanced_all_objects`` returns, or None when no
        trained model is available or any step raises.
    """
    try:
        linemod_config = load_dataset_config(config.LINEMOD_ROOT)
        dataset = DenseFusionDataset(
            linemod_config,
            split='test',
            use_segmentation=config.USE_SEGMENTATION,
        )
        pose_net = load_trained_model(use_transformer=config.USE_TRANSFORMER_FUSION)

        if pose_net is not None:
            # Run enhanced visualization
            return visualize_enhanced_all_objects(pose_net, dataset, sample_idx)

        print("⚠ No trained model found for visualization")
        return None

    except Exception as err:
        print(f"❌ Visualization failed: {err}")
        return None


def _run_pipeline_with_fusion(mode, use_transformer=None):
    """Run ``main_pipeline(mode)``, optionally overriding the fusion setting.

    When ``use_transformer`` is not None, ``config.USE_TRANSFORMER_FUSION`` is
    set to that value for the duration of the run and restored afterwards, so
    any stage that reads the setting from ``config`` at call time sees the
    override without the notebook's global configuration being left changed.
    """
    if use_transformer is None:
        return main_pipeline(mode=mode)

    previous_setting = config.USE_TRANSFORMER_FUSION
    config.USE_TRANSFORMER_FUSION = use_transformer
    try:
        return main_pipeline(mode=mode)
    finally:
        config.USE_TRANSFORMER_FUSION = previous_setting


# NOTE: these wrappers previously overwrote ``use_transformer`` with the config
# value and never used it, so the argument had no effect. The default is now
# None ("use whatever config currently says") rather than a value frozen when
# the function was defined.

def train_model(use_transformer=None):
    """Train only. ``use_transformer`` optionally overrides the config flag."""
    return _run_pipeline_with_fusion('train_only', use_transformer)

def evaluate_model(use_transformer=None):
    """Evaluate only. ``use_transformer`` optionally overrides the config flag."""
    return _run_pipeline_with_fusion('eval_only', use_transformer)

def visualize_model(use_transformer=None):
    """Visualize only. ``use_transformer`` optionally overrides the config flag."""
    return _run_pipeline_with_fusion('visualize_only', use_transformer)

def run_full_pipeline(use_transformer=None):
    """Run the complete pipeline. ``use_transformer`` optionally overrides the config flag."""
    return _run_pipeline_with_fusion('complete', use_transformer)

# Print usage information (one entry per output line)
_RULE = "="*80
print(
    "✓ Block 12 completed: Main execution pipeline ready",
    "\n" + _RULE,
    "DENSEFUSION IMPLEMENTATION COMPLETE",
    _RULE,
    "\nAvailable functions:",
    "  🧪 quick_test() - Test all components",
    "  🚀 run_full_pipeline() - Complete training and evaluation",
    "  🏋️ train_model() - Training only",
    "  📊 evaluate_model() - Evaluation only",
    "  🎨 visualize_model() - Visualization only",
    "  📱 simple_visualization(sample_idx) - Quick visualization",
    "  🔧 load_and_use_pretrained_model(model_path) - Load and test pretrained model",
    "\nRecommended usage:",
    "  1. quick_test() - Verify everything works",
    "  2. run_full_pipeline() - Complete pipeline",
    "  3. simple_visualization(sample_idx) - Visualize results",
    "\nTo load existing model:",
    "  result = load_and_use_pretrained_model('/path/to/model.pth')",
    _RULE,
    sep="\n",
)

In [None]:
# Smoke-test the components defined in the previous blocks (configuration
# paths, dataset loading, model architecture, loss function). As the last
# expression of the cell, the returned True/False is displayed as cell output.
quick_test()

In [None]:
# ==============================================================================
# BLOCK 13: MODEL TRAINING
# ==============================================================================
# NOTE: despite the block title, run_full_pipeline() calls
# main_pipeline(mode='complete'), i.e. training followed by evaluation and
# visualization. Use train_model() for training only.
run_full_pipeline()

In [None]:
# ==============================================================================
# BLOCK 14: MODEL EVALUATION
# ==============================================================================
# Rebuilds the network from the best saved checkpoint and evaluates it on the
# entire test split.

dataset_config = load_dataset_config(config.LINEMOD_ROOT)

# To evaluate a differently named checkpoint, set the model name on config
# before this point, e.g.:
#config.MODEL_NAME = "densefusion_custom_v2"  # Your custom name
# NOTE(review): the path below reads config.MODELS_NAME, not config.MODEL_NAME;
# confirm which attribute the config class actually defines.

checkpoint_path = os.path.join(
    config.MODELS_SAVE_DIR,
    config.MODELS_NAME+'_densefusion_best.pth',
)

# Instantiate the architecture, then restore the trained weights.
# torch.load unpickles the file: only load checkpoints you trust.
pose_model = DenseFusionNetwork(num_objects=13, use_transformer=config.USE_TRANSFORMER_FUSION)
state_dict = torch.load(checkpoint_path, map_location=config.DEVICE)
pose_model.load_state_dict(state_dict)
pose_model = pose_model.to(config.DEVICE)
pose_model.eval()

# Test split with the same sampling settings as the rest of the notebook
test_dataset = DenseFusionDataset(
    dataset_config,
    split='test',
    num_points=config.NUM_POINTS,
    patch_size=config.PATCH_SIZE,
    use_segmentation=config.USE_SEGMENTATION,
)

# Detector and per-object diameters needed by the evaluation routine
yolo_model = load_yolo_model(config.YOLO_MODEL_PATH)
model_diameters = load_model_diameters(config.DIAMETER_INFO_PATH)

# Evaluate on every test sample (this overwrites the global config value)
config.MAX_EVAL_SAMPLES = len(test_dataset)

metrics = evaluate_model_comprehensive(yolo_model, pose_model, test_dataset, model_diameters)

# Report
print(f"Complete Evaluation Results on {len(test_dataset)} samples:")
print(f"Detection rate: {metrics['detection_rate']:.3f}")
print(f"Mean ADD error: {metrics['mean_add']:.4f} m")
for threshold_cm in (2, 5, 10):
    success_rate = metrics[f'success_rate_{threshold_cm}cm']
    print(f"Success rate (<{threshold_cm}cm): {success_rate:.3f}")

In [None]:
# ==============================================================================
# BLOCK 15: SAMPLE VISUALIZATION
# ==============================================================================
# Loads one specific checkpoint and renders individual 3D plots for a single
# test sample. The architecture is built from config.USE_TRANSFORMER_FUSION and
# the dataset from config.USE_SEGMENTATION; set both on config before running
# this cell so they match the checkpoint.

# Checkpoint file inside config.MODELS_SAVE_DIR to visualize.
# NOTE(review): the original cell was annotated "RESULTS WITH TRANSFORMERS" but
# this file name contains "MLP" - confirm it matches the fusion setting.
VIS_CHECKPOINT_NAME = 'EC_500_512_MLP_densefusion_best.pth'
VIS_SAMPLE_IDX = 1

dataset_config = load_dataset_config(config.LINEMOD_ROOT)

# Create the model and restore its weights.
# torch.load unpickles the file: only load checkpoints you trust.
pose_model = DenseFusionNetwork(num_objects=13, use_transformer=config.USE_TRANSFORMER_FUSION)
pose_model.load_state_dict(
    torch.load(os.path.join(config.MODELS_SAVE_DIR, VIS_CHECKPOINT_NAME), map_location=config.DEVICE)
)
pose_model = pose_model.to(config.DEVICE)
pose_model.eval()

# Create test dataset
test_dataset = DenseFusionDataset(
    dataset_config,
    split='test',
    num_points=config.NUM_POINTS,
    patch_size=config.PATCH_SIZE,
    use_segmentation=config.USE_SEGMENTATION
)

# Visualize
result = create_individual_3d_plots(pose_model, test_dataset, sample_idx=VIS_SAMPLE_IDX)