In [None]:
!pip install ultralytics trimesh

In [None]:
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.models as models
import numpy as np
import os
import csv
import yaml
import trimesh
from ultralytics import YOLO
from PIL import Image
from collections import defaultdict
from pathlib import Path
from enum import Enum

*Pretrained Models and Dataset Paths*

This notebook assumes that all models have already been trained.
The following paths point to pretrained weights and dataset locations and are provided as placeholders for demonstration purposes.

Please replace each `path/to/...` entry with the corresponding local path before running the notebook.

In [None]:
# Trained YOLO detector weights.
yolo_weights_path = "path/to/yolo_weights.pt"
# Trained RGB-D rotation network (DepthRotationNet) state dict.
rotation_extension_weights_path = "path/to/rot_ext_weights.pth"
# Trained encoder-decoder (EncoderDecoderWeightsNet) state dict for translation.
translation_extension_weights_path = "path/to/trans_ext_weights.pth"
# Root directory of the YOLO-format dataset (the folder containing data.yaml).
# The evaluation reads test images from f"{yolo_dataset}/test/images", so this
# must be a directory, not the data.yaml file itself.
yolo_dataset = "path/to/yolo"
# Root of the `Linemod_preprocessed` dataset (contains `data/` and `models/`).
data_root = "path/to/dataset"

# Models

*RGB-D Rotation Estimation Network*

The following model implements a late-fusion architecture for rotation estimation using both RGB and depth information.

The RGB branch is based on a ResNet-50 encoder pretrained on ImageNet.

The depth branch is a lightweight convolutional network that processes single-channel depth maps.

Features from both modalities are concatenated and passed through a fully connected head to predict a 4D unit quaternion representing object rotation.

In [None]:
class DepthNet(nn.Module):
    """Lightweight CNN encoder for single-channel depth maps.

    Five conv -> BatchNorm -> ReLU stages, each with stride 2
    (1 -> 32 -> 64 -> 128 -> 256 -> 512 channels), followed by global average
    pooling.

    Input:  (B, 1, H, W) tensor.
    Output: (B, 512) feature vector.
    """

    def __init__(self):
        super(DepthNet, self).__init__()
        # (in_channels, out_channels, kernel_size, padding) per stage; all stride 2.
        stage_specs = [
            (1, 32, 5, 2),
            (32, 64, 3, 1),
            (64, 128, 3, 1),
            (128, 256, 3, 1),
            (256, 512, 3, 1),
        ]
        # Registered as conv{k} then bn{k}, interleaved and in order, so the
        # state_dict keys match existing checkpoints.
        for k, (c_in, c_out, ksize, pad) in enumerate(stage_specs, start=1):
            setattr(self, f"conv{k}", nn.Conv2d(c_in, c_out, kernel_size=ksize, stride=2, padding=pad))
            setattr(self, f"bn{k}", nn.BatchNorm2d(c_out))
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        for k in range(1, 6):
            conv = getattr(self, f"conv{k}")
            norm = getattr(self, f"bn{k}")
            x = F.relu(norm(conv(x)))
        pooled = self.pool(x)
        return pooled.view(pooled.size(0), -1)


class DepthRotationNet(nn.Module):
    """Late-fusion RGB-D network that regresses object rotation as a unit quaternion.

    - RGB branch: ResNet-50 with its final FC layer removed (2048-d feature).
    - Depth branch: DepthNet (512-d feature).
    - Head: concat -> Linear(2560, 1024) -> ReLU -> Dropout(0.3) -> Linear(1024, 4),
      then L2-normalised so every output row has unit norm.

    Args:
        pretrained: if True, initialise the ResNet-50 backbone with torchvision's
            default ImageNet weights.
    """

    def __init__(self, pretrained=True):
        super(DepthRotationNet, self).__init__()

        try:
            # torchvision >= 0.13 weights API.
            weights = models.ResNet50_Weights.DEFAULT if pretrained else None
            base_resnet = models.resnet50(weights=weights)
        except (AttributeError, TypeError):
            # Older torchvision: `ResNet50_Weights` does not exist (AttributeError)
            # or `resnet50` does not accept `weights=` (TypeError). Fall back to the
            # legacy `pretrained` flag. Anything else (download failures,
            # KeyboardInterrupt, ...) is deliberately not caught.
            base_resnet = models.resnet50(pretrained=pretrained)

        # Drop the final classification layer, keep everything up to global pooling.
        self.rgb_encoder = nn.Sequential(*list(base_resnet.children())[:-1])
        self.depth_encoder = DepthNet()

        # Fusion: 2048 (ResNet-50) + 512 (DepthNet) features.
        self.fc1 = nn.Linear(2048 + 512, 1024)
        self.drop = nn.Dropout(0.3)

        # Output head: quaternion (4D).
        self.head = nn.Linear(1024, 4)

    def forward(self, rgb, depth):
        """Predict a unit quaternion per sample.

        Args:
            rgb:   (B, 3, H, W) ImageNet-normalised RGB crop.
            depth: (B, 1, H', W') normalised depth crop.

        Returns:
            (B, 4) tensor; each row is L2-normalised.
        """
        f_rgb = self.rgb_encoder(rgb).view(rgb.size(0), -1)
        f_depth = self.depth_encoder(depth)

        # Late fusion: concatenate the modality features.
        f_fused = torch.cat((f_rgb, f_depth), dim=1)

        x = F.relu(self.fc1(f_fused))
        x = self.drop(x)

        return F.normalize(self.head(x), p=2, dim=1)

*Encoder-Decoder Network for Translation Estimation*

This encoder-decoder architecture predicts a dense per-pixel weight map used for translation estimation.

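The input is the channel-wise concatenation of the RGB crop (3 channels), the normalized depth crop (1 channel) and a 2-channel normalized pixel-coordinate grid, for 6 channels in total.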

The output is a single-channel, unnormalized weight map that is later used to aggregate pixel-level information for estimating object translation.

In [None]:
class EncoderDecoderWeightsNet(nn.Module):
    """U-Net-style encoder-decoder that predicts one unnormalised weight per pixel.

    Input:  (B, 6, H, W) tensor (in this notebook: RGB + normalised depth +
            2-channel coordinate grid at 64x64). H and W presumably need to be
            divisible by 8 so the skip connections line up — TODO confirm for
            other sizes.
    Output: (B, 1, H, W) logits (no activation applied).
    """

    def __init__(self):
        super().__init__()

        # Encoder: full-resolution stem, then three stride-2 downsampling stages.
        self.conv1 = nn.Conv2d(6, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.down1 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.down2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.down3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        # Decoder: stride-2 transposed convs. up2/up3 consume skip-concatenated
        # inputs, hence their doubled in_channels (64+64, 32+32).
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.bn5 = nn.BatchNorm2d(64)
        self.up2 = nn.ConvTranspose2d(128, 32, kernel_size=4, stride=2, padding=1)
        self.bn6 = nn.BatchNorm2d(32)
        self.up3 = nn.ConvTranspose2d(64, 16, kernel_size=4, stride=2, padding=1)
        self.bn7 = nn.BatchNorm2d(16)

        # 1x1 projection to a single weight logit per pixel.
        self.out_conv = nn.Conv2d(16, 1, kernel_size=1)

    def forward(self, x):
        def stage(layer, norm, inp):
            return F.relu(norm(layer(inp)))

        # Encoder (shapes are for a 64x64 input, written HxWxC).
        s1 = stage(self.conv1, self.bn1, x)             # 64x64x16
        s2 = stage(self.down1, self.bn2, s1)            # 32x32x32
        s3 = stage(self.down2, self.bn3, s2)            # 16x16x64
        bottleneck = stage(self.down3, self.bn4, s3)    # 8x8x128

        # Decoder with skip connections from s3 and s2.
        d1 = torch.cat([stage(self.up1, self.bn5, bottleneck), s3], dim=1)  # 16x16x128
        d2 = torch.cat([stage(self.up2, self.bn6, d1), s2], dim=1)          # 32x32x64
        d3 = stage(self.up3, self.bn7, d2)                                  # 64x64x16
        return self.out_conv(d3)                                            # 64x64x1

# Utility Functions

The following section contains a collection of low-level utility functions used throughout the pipeline.

These utilities handle:
- object symmetries (LineMOD-specific),
- quaternion and rotation algebra,
- 3D point cloud metrics,
- spatial grids and soft attention,
- pinhole camera geometry for depth-based translation estimation.

You can safely skip this entire section when reading the notebook.
The functions are provided here only to keep the notebook fully self-contained and runnable without external imports.

*Object Symmetries (LineMOD)*

Some LineMOD objects exhibit discrete rotational symmetries (e.g., eggbox, glue).

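The following definitions specify:
- which objects are symmetric,
- the set of equivalent quaternions (w, x, y, z) used during evaluation.

Both dictionaries are keyed by the 0-based class index into `LinemodSceneDataset.CLASSES` (the label predicted by YOLO), not by the LineMOD object ID. For example, class 7 is object 10 (eggbox) and class 8 is object 11 (glue).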

These symmetries are taken into account when computing pose errors (the ADD-S metric below), so that physically equivalent poses are not penalized.

In [None]:
class SymmetryType(Enum):
    """Kind of rotational symmetry an object has; selects the pose-error rule."""
    NONE = 0      # no symmetry: compare against the single GT pose
    DISCRETE = 1  # finite set of equivalent poses, listed in SYMMETRIC_QUATS


# NOTE: keys are 0-based class indices into LinemodSceneDataset.CLASSES (the
# label YOLO predicts), NOT LineMOD object IDs:
#   class 7 -> LineMOD object 10 (eggbox), class 8 -> LineMOD object 11 (glue).
LINEMOD_SYMMETRIES = dict.fromkeys((7, 8), SymmetryType.DISCRETE)

# Per symmetric class: quaternions (w, x, y, z) that map the object onto itself,
# i.e. the identity plus a 180° rotation about the model z axis.
SYMMETRIC_QUATS = {
    cls_idx: torch.tensor([
        [1., 0., 0., 0.],   # identity
        [0., 0., 0., 1.],   # 180° z
    ])
    for cls_idx in (7, 8)
}

In [None]:
def rotation_matrix_to_quaternion(R):
    """Convert a 3x3 rotation matrix to a unit quaternion [w, x, y, z].

    Standard four-branch conversion. It uses the trace formula when the trace is
    positive; otherwise it uses the formula keyed on the largest diagonal
    element, which avoids dividing by a small number.

    Args:
        R: (3, 3) rotation matrix tensor.

    Returns:
        (4,) float32 tensor [w, x, y, z] of unit length, created with
        ``torch.tensor`` (default device). The sign is not canonicalised;
        q and -q describe the same rotation.
    """
    trace = R.trace()

    if trace > 0:
        s = 2 * torch.sqrt(trace + 1.0)
        wxyz = [
            0.25 * s,
            (R[2, 1] - R[1, 2]) / s,
            (R[0, 2] - R[2, 0]) / s,
            (R[1, 0] - R[0, 1]) / s,
        ]
    elif R[0, 0] > R[1, 1] and R[0, 0] > R[2, 2]:
        s = 2 * torch.sqrt(1.0 + R[0, 0] - R[1, 1] - R[2, 2])
        wxyz = [
            (R[2, 1] - R[1, 2]) / s,
            0.25 * s,
            (R[0, 1] + R[1, 0]) / s,
            (R[0, 2] + R[2, 0]) / s,
        ]
    elif R[1, 1] > R[2, 2]:
        s = 2 * torch.sqrt(1.0 + R[1, 1] - R[0, 0] - R[2, 2])
        wxyz = [
            (R[0, 2] - R[2, 0]) / s,
            (R[0, 1] + R[1, 0]) / s,
            0.25 * s,
            (R[1, 2] + R[2, 1]) / s,
        ]
    else:
        s = 2 * torch.sqrt(1.0 + R[2, 2] - R[0, 0] - R[1, 1])
        wxyz = [
            (R[1, 0] - R[0, 1]) / s,
            (R[0, 2] + R[2, 0]) / s,
            (R[1, 2] + R[2, 1]) / s,
            0.25 * s,
        ]

    q = torch.tensor(wxyz, dtype=torch.float32)
    return q / q.norm()

def quaternion_to_rotation_matrix(q):
    """Convert a quaternion [w, x, y, z] to a 3x3 rotation matrix.

    The input is normalised first, so non-unit quaternions are accepted
    (a zero quaternion yields NaNs).

    Args:
        q: (4,) tensor [w, x, y, z].

    Returns:
        (3, 3) tensor on ``q.device``, built with ``torch.tensor`` from the
        computed entries.
    """
    unit = q / q.norm()
    w, x, y, z = unit

    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = x * w, y * w, z * w

    rows = [
        [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)],
        [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)],
        [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)],
    ]
    return torch.tensor(rows, device=q.device)


def quat_mul(q1, q2):
    """Hamilton product q1 * q2 of quaternions stored as (w, x, y, z).

    Components are read from the last axis, so any leading batch dimensions
    are supported (they broadcast elementwise). Returns a tensor whose last
    axis holds the product's (w, x, y, z).
    """
    aw, ax, ay, az = q1.unbind(-1)
    bw, bx, by, bz = q2.unbind(-1)

    w = aw * bw - ax * bx - ay * by - az * bz
    x = aw * bx + ax * bw + ay * bz - az * by
    y = aw * by - ax * bz + ay * bw + az * bx
    z = aw * bz + ax * by - ay * bx + az * bw
    return torch.stack((w, x, y, z), dim=-1)

In [None]:
def load_linemod_models(models_dir, device="cpu"):
    """Load the vertex arrays of all LineMOD model meshes in a directory.

    Looks for files matching ``obj_*.ply`` and parses the object id from the
    part of the file stem after the first underscore (``obj_01.ply`` -> 1).
    Meshes are loaded with ``trimesh.load(..., process=False)``, so trimesh's
    mesh processing is disabled.

    Args:
        models_dir: directory containing the ``obj_*.ply`` files.
        device: torch device for the returned tensors.

    Returns:
        dict mapping int object id -> float32 tensor of shape (N, 3) with the
        mesh vertices, in the units stored in the .ply (presumably millimetres
        for LineMOD — TODO confirm). The dict is empty if no file matches.
    """
    vertices_by_id = {}
    for ply_path in Path(models_dir).glob("obj_*.ply"):
        obj_id = int(ply_path.stem.split("_")[1])
        mesh = trimesh.load(ply_path, process=False)
        vertices_by_id[obj_id] = torch.tensor(mesh.vertices, dtype=torch.float32, device=device)
    return vertices_by_id

def add_metric(
    model_points,   # (N,3)
    R_pred, t_pred, # (3,3), (3,)
    R_gt,   t_gt,   # (3,3), (3,)
):
    """Average Distance of model points (ADD) between a predicted and a GT pose.

    Each model point p is transformed as R @ p + t under both poses, and the
    mean Euclidean distance between corresponding transformed points is
    returned.

    Returns:
        0-d tensor with the mean distance (same units as points / translations).
    """
    def transform(R, t):
        return (R @ model_points.T).T + t

    offsets = transform(R_pred, t_pred) - transform(R_gt, t_gt)
    return offsets.norm(dim=1).mean()

In [None]:
def build_uv_grid(box, H, W, device):
    """Full-image pixel coordinates of the centres of an H x W grid over a crop.

    The crop (x, y, w, h) is divided into W columns and H rows. Cell (i, j)
    maps to u = x + (j + 0.5) * w / W and v = y + (i + 0.5) * h / H.

    Args:
        box: (B, 4) tensor of crop boxes (x, y, w, h) in image pixels.
        H, W: grid resolution (e.g. the network input size).
        device: device for the generated index tensors.

    Returns:
        (B, H, W, 2) tensor with (u, v) in the last axis.
    """
    x0 = box[:, 0, None, None]   # (B, 1, 1)
    y0 = box[:, 1, None, None]
    bw = box[:, 2, None, None]
    bh = box[:, 3, None, None]

    col_centres = torch.arange(W, device=device).float() + 0.5   # (W,)
    row_centres = torch.arange(H, device=device).float() + 0.5   # (H,)

    u = x0 + col_centres.view(1, 1, W) * bw / W   # (B, 1, W)
    v = y0 + row_centres.view(1, H, 1) * bh / H   # (B, H, 1)

    u, v = torch.broadcast_tensors(u, v)          # both (B, H, W)
    return torch.stack([u, v], dim=-1)


def spatial_softmax(weight_map, mask=None, tau=0.05):
    """Temperature-scaled softmax over all pixels of each map in a batch.

    Args:
        weight_map: (B, 1, H, W) logits.
        mask: optional tensor with H*W elements per sample (e.g. (B, 1, H, W)).
            Pixels where it equals 0 get a logit of -1e9, i.e. essentially zero
            weight. If every pixel is masked, all logits are equal and the
            result is uniform.
        tau: temperature; smaller values give sharper distributions.

    Returns:
        (B, 1, H, W) weights that sum to 1 over each sample's pixels.
    """
    batch, _, height, width = weight_map.shape
    scores = weight_map.view(batch, -1) / tau

    if mask is not None:
        invalid = mask.view(batch, -1) == 0
        scores = scores.masked_fill(invalid, -1e9)

    probs = torch.softmax(scores, dim=1)
    return probs.view(batch, 1, height, width)

def make_coord_grid(H, W, device):
    """Normalised coordinate grid in [-1, 1], channels first.

    Returns:
        (2, H, W) tensor: channel 0 is the x coordinate (varies along columns),
        channel 1 is the y coordinate (varies along rows).
    """
    x_line = torch.linspace(-1, 1, W, device=device)
    y_line = torch.linspace(-1, 1, H, device=device)
    xx = x_line.unsqueeze(0).expand(H, W)
    yy = y_line.unsqueeze(1).expand(H, W)
    return torch.stack([xx, yy], dim=0)

*Pinhole Camera Model and Translation Estimation*

The following utilities implement basic pinhole camera geometry.

Depth values are back-projected into 3D space using camera intrinsics, and a learned per-pixel weight map is used to compute a weighted average of 3D points, yielding the final translation estimate.

In [None]:
def depth_to_points(depth, K, uv_grid):
    """Back-project per-pixel depth to camera-frame 3D points (inverse pinhole).

    For pixel coordinates (u, v) with depth z:
        X = (u - cx) * z / fx,   Y = (v - cy) * z / fy,   Z = z

    Args:
        depth:   [B, 1, H, W] depth (the returned points share its units).
        K:       [3, 3] camera intrinsic matrix.
        uv_grid: [B, H, W, 2] full-image pixel coordinates (u, v).

    Returns:
        [B, H, W, 3] tensor of (X, Y, Z).
    """
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]

    z = depth.squeeze(1)
    u, v = uv_grid[..., 0], uv_grid[..., 1]

    x_cam = (u - cx) * z / fx
    y_cam = (v - cy) * z / fy
    return torch.stack([x_cam, y_cam, z], dim=-1)


def weighted_translation(points_3d, weights):
    """Per-sample weighted sum of 3D points.

    With softmax weights (summing to 1 per sample) this is a weighted mean,
    used as the translation estimate.

    Args:
        points_3d: (B, H, W, 3) camera-frame points.
        weights:   (B, 1, H, W) per-pixel weights.

    Returns:
        (B, 3) tensor.
    """
    channels_last = weights.movedim(1, -1)   # (B, H, W, 1)
    return (points_3d * channels_last).sum(dim=(1, 2))

The full pipeline from RGB-D input to the translation prediction is wrapped in `DepthTranslationNet`:

In [None]:
class DepthTranslationNet(nn.Module):
    """Translation estimator: RGB-D crop -> per-pixel weights -> 3D translation.

    The encoder-decoder predicts one logit per pixel. A masked spatial softmax
    turns the logits into weights over pixels with valid depth. The depth crop
    is back-projected to camera-frame points with the pinhole model, and the
    translation is the weighted sum of those points.

    Args:
        depth_mean, depth_std: statistics used to normalise the depth input.
            They are stored as buffers so the normalised input can be mapped
            back to raw depth (mm).
    """

    def __init__(self, depth_mean, depth_std):
        super().__init__()
        self.enc_dec = EncoderDecoderWeightsNet()
        self.register_buffer('depth_mean', torch.tensor(depth_mean))
        self.register_buffer('depth_std', torch.tensor(depth_std))

    def forward(self, rgb, depth, coord, box, K):
        """
        Args:
            rgb:   (B, 3, H, W) normalised RGB crop.
            depth: (B, 1, H, W) normalised depth crop.
            coord: (B, 2, H, W) coordinate grid (see make_coord_grid).
            box:   (B, 4) crop box (x, y, w, h) in full-image pixels.
            K:     (3, 3) camera intrinsics.

        Returns:
            (weights, t_pred): (B, 1, H, W) pixel weights and a (B, 3)
            translation in the units of the raw depth.
        """
        _, _, height, width = depth.shape

        # Undo the input normalisation to recover raw depth (mm).
        depth_raw = depth * self.depth_std + self.depth_mean

        # Per-pixel logits from the 6-channel RGB + depth + coordinate input.
        logits = self.enc_dec(torch.cat([rgb, depth, coord], dim=1))

        # Pixels with depth <= 10 mm are treated as background / missing depth.
        valid = (depth_raw > 10.0).float()
        weights = spatial_softmax(logits, valid)

        # Inverse pinhole: full-image (u, v) of every crop pixel -> 3D points.
        uv_grid = build_uv_grid(box, height, width, depth.device)
        points_3d = depth_to_points(depth_raw, K, uv_grid)

        return weights, weighted_translation(points_3d, weights)

# Dataset
This dataset wrapper provides access to RGB, depth, camera intrinsics and ground-truth 6D pose
information from the preprocessed LineMOD dataset.

**Key assumptions and design choices:**
- The dataset is expected to follow the standard `Linemod_preprocessed` structure.
- A **per-object random train/test split** is performed at initialization.
- Bounding boxes are taken directly from the ground-truth annotations.
- Camera intrinsics are assumed to be shared across all objects and images.

> **Note:**  
> This class is provided for completeness and reproducibility.  
> You can safely skip reading its implementation and treat it as a black box.

In [None]:
class LinemodSceneDataset(Dataset):
    """LineMOD samples (``Linemod_preprocessed`` layout) with RGB and GT 6D pose.

    Expected layout::

        <dataset_root>/data/<obj_id:02d>/rgb/<img_id:04d>.png
        <dataset_root>/data/<obj_id:02d>/depth/<img_id:04d>.png
        <dataset_root>/data/<obj_id:02d>/gt.yml
        <dataset_root>/data/<obj_id:02d>/info.yml   (first object's file only)

    A random train/test split is drawn per object. It is reproducible for a
    given ``seed`` and uses a private RNG, so constructing the dataset does not
    touch NumPy's global random state.

    Each item is a dict with:
        img_path, depth_path: ``pathlib.Path`` of the RGB / depth PNG.
        cam_intrinsics: (3, 3) float32 K, shared by all samples.
        rgb: ``ToTensor`` image, (3, H, W) float in [0, 1].
        bbox: ``obj_bb`` entry from gt.yml, as stored.
        label: 0-based index of the object into ``CLASSES``.
        rotation: unit quaternion [w, x, y, z] from ``cam_R_m2c``.
        translation: (3,) float32 ``cam_t_m2c``.
        size: (W, H) of the RGB image.
    """

    # LineMOD object IDs used (IDs 3 and 7 are not included).
    CLASSES = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]
    OBJ_ID_TO_CLASS = {obj_id: i for i, obj_id in enumerate(CLASSES)}

    def __init__(self, dataset_root, split="train", split_ratio=0.8, seed=42):
        # A private legacy RandomState produces exactly the same shuffles as
        # `np.random.seed(seed)` followed by `np.random.shuffle`, so the split is
        # unchanged, while NumPy's global RNG state is left untouched.
        rng = np.random.RandomState(seed)

        self.dataset_root = Path(dataset_root)
        self.split = split
        self.samples = []   # list of (obj_id, img_id)
        self.gt_data = {}   # obj_id -> parsed gt.yml

        for obj_id in self.CLASSES:
            obj_dir = self.dataset_root / "data" / f"{obj_id:02d}"

            # Assumes images are numbered contiguously from 0000.png.
            num_images = len(list((obj_dir / "rgb").glob("*.png")))

            indexes = np.arange(num_images)
            rng.shuffle(indexes)

            split_point = int(split_ratio * num_images)
            # Any split other than "train" selects the held-out part.
            img_ids = indexes[:split_point] if split == "train" else indexes[split_point:]
            self.samples.extend((obj_id, img_id) for img_id in img_ids)

            with open(obj_dir / "gt.yml") as f:
                self.gt_data[obj_id] = yaml.safe_load(f)

        # Intrinsics are read once, from the first entry of the first object's
        # info.yml, and assumed to hold for every image.
        any_obj = self.CLASSES[0]
        info_path = self.dataset_root / "data" / f"{any_obj:02d}" / "info.yml"

        with open(info_path) as f:
            info = yaml.safe_load(f)

        cam_info = next(iter(info.values()))

        self.K = torch.tensor(cam_info["cam_K"], dtype=torch.float32).view(3, 3)
        self.depth_scale = cam_info.get("depth_scale", 1.0)
        self.rgb_transform = T.ToTensor()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        obj_id, img_id = self.samples[idx]

        base_dir = self.dataset_root / "data" / f"{obj_id:02d}"

        img_path = base_dir / "rgb" / f"{img_id:04d}.png"
        depth_path = base_dir / "depth" / f"{img_id:04d}.png"

        img = Image.open(img_path).convert("RGB")
        W, H = img.size
        rgb = self.rgb_transform(img)

        # gt.yml may list several objects per image; pick this sample's object.
        annotation = next(
            (entry for entry in self.gt_data[obj_id][img_id] if int(entry["obj_id"]) == obj_id),
            None,
        )
        if annotation is None:
            raise RuntimeError(
                f"Object {obj_id} not found in image {img_id}"
            )

        R = torch.tensor(annotation["cam_R_m2c"], dtype=torch.float32).view(3, 3)
        q = rotation_matrix_to_quaternion(R)
        t = torch.tensor(annotation["cam_t_m2c"], dtype=torch.float32).view(3)

        return {
            "img_path": img_path,
            "depth_path": depth_path,
            "cam_intrinsics": self.K,
            "rgb": rgb,
            "bbox": annotation["obj_bb"],
            "label": self.OBJ_ID_TO_CLASS[obj_id],
            "rotation": q,
            "translation": t,
            "size": (W, H),
        }

# End-to-End Evaluation Loop

The following cells run the **full 6D pose estimation pipeline** on the LineMOD test set:

1. Object detection with **YOLO**
2. Rotation estimation using the **RGB-D fusion network**
3. Translation estimation via **encoder-decoder weighted depth aggregation**
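4. Evaluation with **ADD-S**: the mean distance between model points under the predicted and ground-truth poses (ADD). For objects with discrete symmetries, this is minimized over the equivalent ground-truth poses. This is a discrete-symmetry variant of ADD, not the closest-point ADD-S formulation.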

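Basic logging tracks detection failures and invalid-depth cases per object. Only aggregate ADD-S statistics are printed below. Per-object summaries and all individual errors are written to `results/eval_extension.csv` and `results/all_adds_extension.txt`.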

In [None]:
# ImageNet preprocessing for the ResNet-50 RGB branch: resize the crop to
# 224x224, convert to a float tensor in [0, 1], then normalise each channel with
# the ImageNet mean / std.
resnet_tf = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

def crop_rgb(img_pil, bbox):
    """Crop a PIL image to an (x, y, w, h) box, clamped to the image bounds.

    The clamped corners are truncated to integers with ``int()``.

    Returns:
        The cropped image, or ``None`` if the box is degenerate (w or h <= 1)
        or the clamped region is empty.
    """
    x, y, w, h = bbox
    if w <= 1 or h <= 1:
        return None

    left, top = int(max(0, x)), int(max(0, y))
    right = int(min(img_pil.width, x + w))
    bottom = int(min(img_pil.height, y + h))

    if right <= left or bottom <= top:
        return None
    return img_pil.crop((left, top, right, bottom))

def bbox_invalid(bbox):
    """Return True if an (x, y, w, h) box is degenerate (width or height <= 1 px)."""
    _, _, width, height = bbox
    return width <= 1 or height <= 1


In [None]:
# Per-object evaluation log, keyed by LineMOD object ID.
log = defaultdict(lambda: {
    "adds": [],           # pose error (mm) of every successfully evaluated sample
    "bbox_missing": 0,    # no detection of the ground-truth class
    "false_positive": 0,  # other-class detections + extra same-class detections
    "bbox_invalid": 0,    # degenerate box or empty crop after clamping
    "depth_missing": 0,   # crop contains no pixel with valid depth
    "total": 0,           # test samples seen
})

device = "cuda" if torch.cuda.is_available() else "cpu"

# LineMOD depth mean and std (mm) used for (un-)normalisation.
depth_mean = 990.7
depth_std = 311.8

# Depth <= this value (mm) is treated as missing; same threshold as the
# valid-pixel mask inside DepthTranslationNet.
MIN_VALID_DEPTH_MM = 10.0

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# ---------------------------
# Models
# ---------------------------
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
yolo = YOLO(yolo_weights_path)

rotation_extension_net = DepthRotationNet(pretrained=False).to(device)
rotation_extension_net.load_state_dict(torch.load(rotation_extension_weights_path, map_location=device))
rotation_extension_net.eval()

translation_extension_net = DepthTranslationNet(depth_mean, depth_std).to(device)
translation_extension_net.enc_dec.load_state_dict(torch.load(translation_extension_weights_path, map_location=device))
translation_extension_net.eval()

ds = LinemodSceneDataset(data_root, split="test")
models_3d = load_linemod_models(f"{data_root}/models", device=device)

# ---------------------------
# Detection
# ---------------------------
# `yolo_dataset` should be the YOLO dataset root. If it points at the data.yaml
# file instead, use the directory that contains it.
yolo_root = Path(yolo_dataset)
if yolo_root.suffix in (".yaml", ".yml"):
    yolo_root = yolo_root.parent
yolo_test_images = yolo_root / "test" / "images"

results = yolo.predict(
    source=str(yolo_test_images),
    imgsz=640,
    batch=16,
    device=device,
    stream=False,
    save=False,
)

# Results are paired with dataset samples purely by position, so both must list
# the same images in the same order (the order itself cannot be checked here).
# zip() would silently drop the tail on a count mismatch, so fail loudly instead.
if len(results) != len(ds):
    raise ValueError(
        f"YOLO produced {len(results)} results from {yolo_test_images}, but the "
        f"test split has {len(ds)} samples; cannot pair predictions with ground truth."
    )

# Loop-invariant preprocessing, built once.
# ToTensor does not rescale integer-mode images, so the depth tensors below hold
# raw depth values (presumably mm — TODO confirm the PNG mode).
depth_tf_224 = T.Compose([
    T.Resize((224, 224), interpolation=T.InterpolationMode.NEAREST),
    T.ToTensor(),
])
rgb_tf_64 = T.Compose([
    T.Resize((64, 64)),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
depth_tf_64 = T.Compose([
    T.Resize((64, 64), interpolation=T.InterpolationMode.NEAREST),
    T.ToTensor(),
])
coord_64 = make_coord_grid(64, 64, device).unsqueeze(0)

errors = []

for r, scene in zip(results, ds):
    # ---------------------------
    # GT
    # ---------------------------
    q_gt = scene["rotation"].to(device)
    t_gt = scene["translation"].to(device)
    obj_class = scene["label"]
    obj_id = LinemodSceneDataset.CLASSES[obj_class]
    log[obj_id]["total"] += 1

    # ---------------------------
    # YOLO bbox
    # ---------------------------
    boxes = r.boxes
    if boxes is None or len(boxes) == 0:
        log[obj_id]["bbox_missing"] += 1
        continue

    xyxy = boxes.xyxy
    cls = boxes.cls.long()
    conf = boxes.conf
    mask = cls == obj_class
    n_match = int(mask.sum())
    if n_match == 0:
        log[obj_id]["bbox_missing"] += 1
        log[obj_id]["false_positive"] += len(cls)
        continue

    # Every detection except the single best GT-class box is a false positive.
    log[obj_id]["false_positive"] += int((~mask).sum()) + (n_match - 1)

    idxs = torch.where(mask)[0]
    best = idxs[conf[idxs].argmax()]
    x1, y1, x2, y2 = xyxy[best]
    bbox = (
        x1.item(),
        y1.item(),
        (x2 - x1).item(),
        (y2 - y1).item(),
    )

    if bbox_invalid(bbox):
        log[obj_id]["bbox_invalid"] += 1
        continue

    # ---------------------------
    # LOAD RGB + DEPTH (files closed deterministically)
    # ---------------------------
    with Image.open(scene["img_path"]) as img_file:
        img = img_file.convert("RGB")
    with Image.open(scene["depth_path"]) as depth_file:
        depth_img = depth_file.copy()

    # ---------------------------
    # CROP
    # ---------------------------
    crop_rgb_img = crop_rgb(img, bbox)
    if crop_rgb_img is None:
        log[obj_id]["bbox_invalid"] += 1
        continue

    # Integer, image-clamped crop window (same rule as crop_rgb).
    x, y, w, h = bbox
    cx1 = int(max(0, x))
    cy1 = int(max(0, y))
    cx2 = int(min(depth_img.width, x + w))
    cy2 = int(min(depth_img.height, y + h))
    if cx2 <= cx1 or cy2 <= cy1:
        log[obj_id]["bbox_invalid"] += 1
        continue

    crop_depth_img = depth_img.crop((cx1, cy1, cx2, cy2))

    depth_224_raw = depth_tf_224(crop_depth_img).unsqueeze(0).to(device)
    depth_64_raw = depth_tf_64(crop_depth_img).unsqueeze(0).to(device)

    # No valid depth at the translation resolution: the weighted average would
    # only see background pixels, so record the case and skip the sample.
    if not bool((depth_64_raw > MIN_VALID_DEPTH_MM).any()):
        log[obj_id]["depth_missing"] += 1
        continue

    # =========================================================
    # ROTATION — RGBD Fusion Net (224)
    # =========================================================
    rgb_224 = resnet_tf(crop_rgb_img).unsqueeze(0).to(device)
    depth_224 = (depth_224_raw - depth_mean) / depth_std

    with torch.no_grad():
        q_pred = rotation_extension_net(rgb_224, depth_224)[0]

    R_pred = quaternion_to_rotation_matrix(q_pred)

    # =========================================================
    # TRANSLATION — Encoder Decoder (64)
    # =========================================================
    rgb_64 = rgb_tf_64(crop_rgb_img).unsqueeze(0).to(device)
    depth_64 = (depth_64_raw - depth_mean) / depth_std

    # The uv grid must describe the pixels actually cropped (integer, clamped
    # window), not the raw float YOLO box; otherwise back-projection is shifted,
    # most visibly for boxes that extend past the image border.
    box = torch.tensor(
        [[cx1, cy1, cx2 - cx1, cy2 - cy1]], dtype=torch.float32, device=device
    )
    K = scene["cam_intrinsics"].to(device)

    with torch.no_grad():
        _, t_pred = translation_extension_net(rgb_64, depth_64, coord_64, box, K)
    t_pred = t_pred[0]

    # =========================================================
    # ADD-S (ADD, minimised over equivalent GT poses for symmetric classes)
    # =========================================================
    pts = models_3d[obj_id]

    # Symmetry tables are keyed by class index, not by LineMOD object ID.
    if LINEMOD_SYMMETRIES.get(obj_class, SymmetryType.NONE) == SymmetryType.DISCRETE:
        errs = []
        for q_sym in SYMMETRIC_QUATS[obj_class]:
            q_gt_sym = quat_mul(q_gt, q_sym.to(device))
            R_gt_sym = quaternion_to_rotation_matrix(q_gt_sym)
            errs.append(add_metric(pts, R_pred, t_pred, R_gt_sym, t_gt))
        err = torch.stack(errs).min()
    else:
        R_gt = quaternion_to_rotation_matrix(q_gt)
        err = add_metric(pts, R_pred, t_pred, R_gt, t_gt)

    log[obj_id]["adds"].append(err.item())
    errors.append(err.item())


errors = torch.tensor(errors)
if errors.numel() == 0:
    # median() raises on an empty tensor; report instead of crashing.
    print("ADD-S: no sample was successfully evaluated.")
else:
    print(f"ADD-S mean: {errors.mean():.2f} mm")
    print(f"ADD-S median: {errors.median():.2f} mm")

out_dir = "results"
os.makedirs(out_dir, exist_ok=True)
csv_path = os.path.join(out_dir, "eval_extension.csv")
adds_txt_path = os.path.join(out_dir, "all_adds_extension.txt")

with open(csv_path, "w", newline="") as f_csv, \
        open(adds_txt_path, "w") as f_txt:
    writer = csv.writer(f_csv)
    writer.writerow([
        "obj_id",
        "num_samples",
        "adds_mean_mm",
        "adds_median_mm",
        "bbox_missing",
        "false_positive",
        "bbox_invalid",
        "depth_missing",
    ])

    for obj_id, d in sorted(log.items()):
        adds = np.array(d["adds"])

        # ---- CSV summary ----
        writer.writerow([
            obj_id,
            d["total"],
            adds.mean() if len(adds) > 0 else np.nan,
            np.median(adds) if len(adds) > 0 else np.nan,
            d["bbox_missing"],
            d["false_positive"],
            d["bbox_invalid"],
            d["depth_missing"],
        ])

        # ---- TXT: all errors ----
        for e in adds:
            f_txt.write(f"{obj_id} {e:.6f}\n")