In [None]:
import os
import yaml
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import DataLoader
from google.colab import drive
import cv2
import torch.nn as nn
import torchvision.models as models
import random
from torchvision.models import resnet50, ResNet50_Weights
!pip install trimesh
import trimesh

def compute_add(model_points, gt_R, gt_t, pred_R, pred_t):
    """Average Distance of model points (ADD) between predicted and GT poses.

    For every sample the model points are moved by the predicted pose and by
    the ground-truth pose; the mean Euclidean distance between corresponding
    moved points is the per-sample ADD. The result is the batch mean.

    Args:
        model_points: indexable collection with one point tensor per sample;
            it is transposed and left-multiplied by a 3x3 rotation, so each
            entry is presumably ``[N, 3]`` (confirm with the caller).
        gt_R, pred_R: per-sample ``[3, 3]`` rotation matrices.
        gt_t, pred_t: per-sample ``[3]`` translations.

    Returns:
        Scalar tensor: mean over the batch of the per-sample mean distance.
    """
    target_device = pred_R.device
    per_sample = []
    for idx in range(pred_R.size(0)):
        # Point tensors may live on another device than the predictions.
        points_t = model_points[idx].to(target_device).T                  # [3, N]
        moved_pred = pred_R[idx] @ points_t + pred_t[idx].unsqueeze(1)    # [3, N]
        moved_gt = gt_R[idx] @ points_t + gt_t[idx].unsqueeze(1)          # [3, N]
        gap = (moved_pred - moved_gt).norm(dim=0)                         # [N]
        per_sample.append(gap.mean())
    return torch.stack(per_sample).mean()

def load_ply_vertices(ply_path):
    """Load a mesh file with trimesh and return its vertices as float32.

    Args:
        ply_path: path of the mesh file (a .ply model in this notebook).

    Returns:
        NumPy array of the mesh vertices cast to ``np.float32``.
    """
    vertices = trimesh.load(ply_path).vertices
    return vertices.astype(np.float32)

def set_seed(seed=42):
    """Seed every random generator used here and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy, torch on the CPU, the current GPU and
    all GPUs, then turns off cuDNN autotuning in favour of deterministic
    kernels so that runs are reproducible.

    Args:
        seed: integer seed applied to every generator (default 42).
    """
    seeders = (
        random.seed,                # Python
        np.random.seed,             # NumPy
        torch.manual_seed,          # CPU
        torch.cuda.manual_seed,     # current GPU
        torch.cuda.manual_seed_all, # all GPUs
    )
    for seeder in seeders:
        seeder(seed)

    # Reproducibility over speed: deterministic algorithms, no autotuning.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def normalize_quaternion(q, eps=1e-8):
    """Scale quaternions to unit length along the last dimension.

    Args:
        q: tensor whose last dimension holds the quaternion components.
        eps: lower bound applied to the norm before dividing. Without it an
            all-zero quaternion (possible as a raw network output) yields
            0 / 0 = NaN, which then propagates into the rotation and loss.

    Returns:
        Tensor with the same shape as ``q`` whose last dimension has unit
        norm; an all-zero input stays all-zero instead of becoming NaN.
    """
    return q / q.norm(dim=-1, keepdim=True).clamp_min(eps)

def quaternion_to_matrix(q):
    """Convert quaternions in (x, y, z, w) order into rotation matrices.

    The scalar component is expected LAST. No normalisation is performed
    here, so ``q`` should already have unit norm for the result to be a
    proper rotation.

    Args:
        q: tensor of shape ``[..., 4]``, e.g. ``[B, 4]``. Any number of
            leading dimensions is supported, including none (``[4]``).

    Returns:
        Tensor of shape ``[..., 3, 3]`` with the dtype and device of ``q``.

    Raises:
        ValueError: if the last dimension of ``q`` is not 4.
    """
    if q.size(-1) != 4:
        raise ValueError(
            f"expected quaternions with last dimension 4, got shape {tuple(q.shape)}")
    x, y, z, w = q.unbind(-1)

    rows = (
        (1 - 2*(y*y + z*z), 2*(x*y - z*w), 2*(x*z + y*w)),
        (2*(x*y + z*w), 1 - 2*(x*x + z*z), 2*(y*z - x*w)),
        (2*(x*z - y*w), 2*(y*z + x*w), 1 - 2*(x*x + y*y)),
    )
    # Stacking (instead of filling a pre-allocated [B, 3, 3] tensor) keeps any
    # leading shape; element [..., i, j] is rows[i][j].
    return torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)

def geodesic_loss(R_pred, R_gt):
    """Mean geodesic (angular) distance, in radians, between two rotation batches.

    For each sample the relative rotation ``R_pred^T @ R_gt`` is formed and
    its angle is ``acos((trace - 1) / 2)``. The cosine is clamped to
    ``[-1 + 1e-6, 1 - 1e-6]`` so that ``acos`` stays finite and
    differentiable; as a consequence the smallest angle reported is about
    1.4e-3 rad rather than exactly 0.

    Args:
        R_pred: ``[B, 3, 3]`` predicted rotation matrices.
        R_gt: ``[B, 3, 3]`` ground-truth rotation matrices.

    Returns:
        Scalar tensor: the batch-mean angle in radians.
    """
    relative = torch.bmm(R_pred.transpose(1, 2), R_gt)
    tr = relative[:, 0, 0] + relative[:, 1, 1] + relative[:, 2, 2]
    margin = 1e-6
    angles = ((tr - 1) / 2).clamp(-1 + margin, 1 - margin).acos()
    return angles.mean()

class PreprocessedPoseDataset_RGBD(Dataset):
    """Pose dataset over preloaded RGB and depth tensors.

    Item ``idx`` pairs ``rgb_images[idx]`` and ``depth_images[idx]`` with the
    FIRST annotation stored under ``pose_data[<filenames[idx] without its
    extension>]``. The translation target is standardised as
    ``(t - mean_t) / std_t`` and the raw object id is remapped to a
    contiguous class index.

    Args:
        filenames: file names aligned with ``rgb_images`` and ``depth_images``.
        rgb_images: indexable sequence of RGB tensors.
        depth_images: indexable sequence of depth tensors in the same order.
        pose_data: mapping ``name -> list of annotation dicts``; each dict has
            ``'cam_R_m2c'`` (9 values, reshaped row-major to 3x3),
            ``'cam_t_m2c'`` (3 values) and ``'obj_id'``.
        mean_t, std_t: tensors used to standardise the translation.
        obj_id_to_idx: optional raw-id -> class-index mapping; defaults to
            ``DEFAULT_OBJ_ID_TO_IDX``.

    Raises:
        ValueError: if ``filenames``, ``rgb_images`` and ``depth_images`` do
            not all have the same length.
    """

    # Raw dataset object id -> contiguous class index (ids 3 and 7 have no entry).
    DEFAULT_OBJ_ID_TO_IDX = {
        1: 0, 2: 1, 4: 2, 5: 3, 6: 4,
        8: 5, 9: 6, 10: 7, 11: 8, 12: 9,
        13: 10, 14: 11, 15: 12
    }

    def __init__(self, filenames, rgb_images, depth_images, pose_data, mean_t, std_t,
                 obj_id_to_idx=None):
        # Fail early on misaligned inputs rather than with an IndexError (or
        # silently mismatched RGB/depth pairs) inside a DataLoader worker.
        if not len(filenames) == len(rgb_images) == len(depth_images):
            raise ValueError(
                "filenames, rgb_images and depth_images must have the same length, got "
                f"{len(filenames)}, {len(rgb_images)} and {len(depth_images)}")

        self.rgb_images = rgb_images
        self.pose_data = pose_data
        self.depth_images = depth_images
        self.filenames = filenames
        self.mean_t = mean_t
        self.std_t = std_t

        # Copy so that callers' dicts (and the class default) are never mutated.
        self.obj_id_to_idx = dict(
            self.DEFAULT_OBJ_ID_TO_IDX if obj_id_to_idx is None else obj_id_to_idx)

    def __len__(self):
        return len(self.rgb_images)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        name_key = os.path.splitext(filename)[0]

        rgb = self.rgb_images[idx]
        depth = self.depth_images[idx]
        # Only the first annotation of the image is used.
        sample = self.pose_data[name_key][0]

        rotation = torch.tensor(sample['cam_R_m2c'], dtype=torch.float32).view(3, 3)
        translation = torch.tensor(sample['cam_t_m2c'], dtype=torch.float32)
        translation = (translation - self.mean_t) / self.std_t  # standardisation

        obj_id_raw = sample['obj_id']
        obj_id = torch.tensor(self.obj_id_to_idx[obj_id_raw], dtype=torch.long)

        return {
            'image': rgb,
            'depth_map': depth,
            'cam_R_m2c': rotation,
            'cam_t_m2c': translation,
            'obj_id': obj_id
        }

class PoseRegressor_RGBD(nn.Module):
    """RGB-D pose regressor: ResNet-50 (RGB) + ResNet-18 (depth) + object embedding.

    Globally pooled RGB features (2048-d), pooled depth features (512-d) and
    an object-id embedding are concatenated and fed to a shared MLP. Two
    heads follow: one regresses a 3-d translation, the other a quaternion
    that is normalised and converted into a 3x3 rotation matrix.

    NOTE: submodule/parameter names define the checkpoint format used with
    ``load_state_dict``; do not rename them.
    """

    def __init__(self, num_obj_ids, embedding_dim=32, use_learned_default=True):
        super(PoseRegressor_RGBD, self).__init__()

        # RGB feature extractor: pretrained ResNet-50 without its final FC layer.
        resnet_rgb = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.rgb_extractor = nn.Sequential(*list(resnet_rgb.children())[:-1])  # [B, 2048, 1, 1]

        # Depth feature extractor: ResNet-18 from scratch, first conv takes 1 channel.
        resnet_depth = models.resnet18(weights=None)
        resnet_depth.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.depth_extractor = nn.Sequential(*list(resnet_depth.children())[:-1])  # [B, 512, 1, 1]

        self.obj_id_embedding = nn.Embedding(num_obj_ids, embedding_dim)

        # Learned embedding used when forward() gets no obj_id (if enabled).
        self.use_learned_default = use_learned_default
        if use_learned_default:
            self.default_obj_embedding = nn.Parameter(torch.zeros(embedding_dim))

        self.fc_common = nn.Sequential(
            nn.Linear(2048 + embedding_dim + 512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU())

        self.fc_translation = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 3))

        self.fc_rotation = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 4))

    def forward(self, rgb, depth, obj_id=None):
        """Predict translation and rotation for a batch.

        Args:
            rgb: input of the ResNet-50 branch (3 channels; presumably
                ``[B, 3, H, W]``).
            depth: input of the 1-channel ResNet-18 branch (presumably
                ``[B, 1, H, W]``).
            obj_id: optional ``[B]`` tensor of class indices. When ``None``
                the learned default embedding is used, or zeros if
                ``use_learned_default`` is False.

        Returns:
            ``(translation, rotation)``: ``[B, 3]`` translation in the same
            space as the training targets and ``[B, 3, 3]`` rotation matrices.
        """
        feat_rgb = self.rgb_extractor(rgb)       # [B, 2048, 1, 1]
        feat_depth = self.depth_extractor(depth) # [B, 512, 1, 1]

        feat_rgb = feat_rgb.view(feat_rgb.size(0), -1)       # [B, 2048]
        feat_depth = feat_depth.view(feat_depth.size(0), -1) # [B, 512]

        if obj_id is not None:
            obj_embed = self.obj_id_embedding(obj_id)  # [B, D]
        elif self.use_learned_default:
            obj_embed = self.default_obj_embedding.unsqueeze(0).expand(rgb.size(0), -1)  # [B, D]
        else:
            # Batch size comes from the input: `x` is not defined at this point
            # (referencing it raised UnboundLocalError).
            obj_embed = torch.zeros(rgb.size(0), self.obj_id_embedding.embedding_dim,
                                    device=rgb.device, dtype=feat_rgb.dtype)

        features = torch.cat([feat_rgb, feat_depth], dim=1)
        x = torch.cat([features, obj_embed], dim=1)  # [B, 2048 + 512 + D]
        x = self.fc_common(x)
        translation = self.fc_translation(x)
        rotation_q = self.fc_rotation(x)  # [B, 4]
        rotation_q = normalize_quaternion(rotation_q)  # unit quaternion
        rotation = quaternion_to_matrix(rotation_q)     # [B, 3, 3] rotation matrix
        return translation, rotation

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Mount Google Drive, where the dataset, object models and checkpoints are read from.
drive.mount('/content/drive', force_remount=True)

Collecting trimesh
  Downloading trimesh-4.6.10-py3-none-any.whl.metadata (18 kB)
Downloading trimesh-4.6.10-py3-none-any.whl (711 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 711.2/711.2 kB 10.9 MB/s eta 0:00:00
Installing collected packages: trimesh
Successfully installed trimesh-4.6.10
Mounted at /content/drive


In [None]:
set_seed(42)

# Ground-truth poses (YAML) and the preprocessed RGB / depth tensor files.
pose_file = "/content/drive/MyDrive/DatasetCorrect/gt_all_new.yml"
img_dir = "/content/drive/MyDrive/DatasetCorrect/dataset_tensors/tensors_rgb.pt"
img_depth_dir = "/content/drive/MyDrive/DatasetCorrect/dataset_tensors/tensors_dpt.pt"

# Raw object id -> contiguous class index (ids 3 and 7 have no entry).
id_dataset_to_ordered = {1: 0, 2: 1, 4: 2, 5: 3, 6: 4,
                         8: 5, 9: 6, 10: 7, 11: 8, 12: 9,
                         13: 10, 14: 11, 15: 12}

# Vertices of every normalised object model, keyed by class index and moved
# to `device` so the ADD loss can use them directly. Driving the loop from the
# id map loads exactly the objects that have a class index.
models_points = {}
for raw_id, class_idx in id_dataset_to_ordered.items():
    ply_path = f"/content/drive/MyDrive/DatasetCorrect/normalized_models/obj_norm_{raw_id:02d}.ply"
    vertices = load_ply_vertices(ply_path)
    models_points[class_idx] = torch.tensor(vertices, dtype=torch.float32).to(device)

with open(pose_file, 'r') as f:
    pose_data = yaml.load(f, Loader=yaml.FullLoader)

def compute_translation_stats(pose_data):
    """Per-axis mean and standard deviation of the ground-truth translations.

    Only the first annotation of every entry (``entry[0]['cam_t_m2c']``) is
    taken into account.

    Args:
        pose_data: mapping whose values are lists of annotation dicts, each
            holding a ``'cam_t_m2c'`` translation.

    Returns:
        ``(mean, std)`` NumPy arrays reduced over the sample axis
        (population standard deviation, NumPy's default ``ddof=0``).
    """
    stacked = np.stack([np.array(entry[0]['cam_t_m2c']) for entry in pose_data.values()])
    return stacked.mean(axis=0), stacked.std(axis=0)

# Translation standardisation statistics.
# NOTE(review): computed over every entry of pose_data, i.e. train AND validation.
mean_t, std_t = compute_translation_stats(pose_data)
mean_t = torch.tensor(mean_t, dtype=torch.float32)
std_t = torch.tensor(std_t, dtype=torch.float32)

## Creating DataLoader

filenames, rgb_tensors = torch.load(img_dir)  # Load all RGB tensors with their file names

# Split lists: one name per line, compared against file names without extension.
with open("/content/drive/MyDrive/DatasetCorrect/dataset_indexes/train_indexes.txt", "r") as f:
    valid_names = set(line.strip() for line in f if line.strip())

with open("/content/drive/MyDrive/DatasetCorrect/dataset_indexes/val_indexes.txt", "r") as f:
    valid_names_val = set(line.strip() for line in f if line.strip())

tr_filenames = []
val_filenames = []

tr_rgb_tensors = []
val_rgb_tensors = []

for fname, tensor in zip(filenames, rgb_tensors):
    name_no_ext = os.path.splitext(fname)[0]
    if name_no_ext in valid_names:
        tr_filenames.append(fname)
        tr_rgb_tensors.append(tensor)
    elif name_no_ext in valid_names_val:
        val_filenames.append(fname)
        val_rgb_tensors.append(tensor)

torch.cuda.empty_cache()

depth_tensors = torch.load(img_depth_dir)  # Load all depth tensors

tr_dpt_tensors = []
val_dpt_tensors = []

# Assumes depth_tensors is stored in the same order as `filenames` (the zip
# relies on it). Apply exactly the same stem-set membership tests as the RGB
# split so RGB and depth stay aligned. Testing stems against the lists of full
# file names never matches when names carry an extension, and scans a whole
# list per sample.
for fname, tensor in zip(filenames, depth_tensors):
    name_no_ext = os.path.splitext(fname)[0]
    if name_no_ext in valid_names:
        tr_dpt_tensors.append(tensor)
    elif name_no_ext in valid_names_val:
        val_dpt_tensors.append(tensor)

del filenames
torch.cuda.empty_cache()

tr_dataset = PreprocessedPoseDataset_RGBD(tr_filenames, tr_rgb_tensors, tr_dpt_tensors, pose_data, mean_t, std_t)
val_dataset = PreprocessedPoseDataset_RGBD(val_filenames, val_rgb_tensors, val_dpt_tensors, pose_data, mean_t, std_t)

# NOTE(review): drop_last=True on the validation loader leaves the final
# partial batch out of the validation metrics.
tr_dataloader = DataLoader(tr_dataset, batch_size=8, shuffle=True, num_workers=2, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=2, drop_last=True)

In [None]:
set_seed(42)
# NOTE: translation predictions are standardised ((t - mean_t) / std_t); map them
# back with t * std_t + mean_t to obtain real-world values.
LR = 0.000000001

model = PoseRegressor_RGBD(num_obj_ids=13).to(device)

alpha = 0.01  # Translation loss coefficient
beta = 0.01  # Rotation (geodesic) loss coefficient
gamma = 1  # ADD loss coefficient
beta_t_lossL1smooth = 1  # `beta` (L1/L2 switch point) of the smooth-L1 translation loss

# Per-group learning rates. They are re-applied after the checkpoint is loaded
# because optimizer.load_state_dict() restores the learning rates saved in it.
lr_backbone_rgb = 0.000000001
lr_backbone_dpt = 0.000000001
lr_common = 0.000000001
lr_translation = 0.000000001
lr_rotation = 0.000000005

# NOTE(review): model.obj_id_embedding and model.default_obj_embedding belong to
# no parameter group, so the optimizer never updates them. Adding a group would
# change the group count and make optimizer.load_state_dict() reject the
# existing checkpoints.
optimizer = torch.optim.Adam([
    {"params": model.rgb_extractor.parameters(), "lr": lr_backbone_rgb},
    {"params": model.depth_extractor.parameters(), "lr": lr_backbone_dpt},
    {"params": model.fc_common.parameters(), "lr": lr_common},
    {"params": model.fc_translation.parameters(), "lr": lr_translation},
    {"params": model.fc_rotation.parameters(), "lr": lr_rotation},], lr=LR)

posemodelidrgb = 189  # Checkpoint version to resume from (<= 0: no checkpoint is loaded)
n_epochs = 30  # Number of epochs to train

# Make sure gradients are enabled for every model parameter.
for param in model.parameters():
    param.requires_grad = True

if posemodelidrgb > 0:
    # map_location lets a checkpoint saved on GPU be resumed on a CPU-only runtime.
    checkpoint = torch.load(
        f"/content/drive/MyDrive/RGB+DEPTH_PoseModels_Correct/posemodel_rgb+dpt_{posemodelidrgb}.pth",
        map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

optimizer.param_groups[0]['lr'] = lr_backbone_rgb
optimizer.param_groups[1]['lr'] = lr_backbone_dpt
optimizer.param_groups[2]['lr'] = lr_common
optimizer.param_groups[3]['lr'] = lr_translation
optimizer.param_groups[4]['lr'] = lr_rotation

for i in range(n_epochs):
    model.train()

    total_loss = 0.0
    total_t_loss = 0.0
    total_r_loss = 0.0
    total_samples = 0
    total_add = 0.0

    for batch in tr_dataloader:
        batch_size = batch['image'].size(0)
        total_samples += batch_size

        images = batch['image'].to(device)
        depth_maps = batch['depth_map'].to(device)
        gt_t = batch['cam_t_m2c'].to(device).float()   # [B, 3], standardised
        gt_r = batch['cam_R_m2c'].to(device).float()   # [B, 3, 3]
        obj_ids = batch['obj_id'].to(device)           # [B]

        pred_t, pred_r = model(images, depth_maps, obj_ids)

        model_points_batch = [models_points[obj_ids[j].item()] for j in range(batch_size)]

        t_loss = nn.functional.smooth_l1_loss(pred_t, gt_t, beta=beta_t_lossL1smooth, reduction='mean')
        r_loss = geodesic_loss(pred_r, gt_r)
        # NOTE(review): gt_t / pred_t are standardised here while the model
        # points are not, so this "ADD" is not in the model points' units.
        add_loss = compute_add(model_points_batch, gt_r, gt_t, pred_r, pred_t)
        loss = gamma * add_loss + alpha * t_loss + beta * r_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch_size
        total_t_loss += t_loss.item() * batch_size
        total_r_loss += r_loss.item() * batch_size
        total_add += add_loss.item() * batch_size

    avg_loss = total_loss / total_samples
    avg_rot_loss = total_r_loss / total_samples
    avg_trans_loss = total_t_loss / total_samples
    avg_add = total_add / total_samples

    # VALIDATION
    model.eval()
    val_total_loss = 0.0
    val_total_t_loss = 0.0
    val_total_r_loss = 0.0
    val_total_samples = 0
    val_total_add = 0.0

    with torch.no_grad():
        for val_batch in val_dataloader:
            batch_size = val_batch['image'].size(0)
            val_total_samples += batch_size

            images = val_batch['image'].to(device)
            # Must come from the validation batch; reading `batch` here paired
            # validation RGB with the depth maps of the last TRAINING batch.
            depth_maps = val_batch['depth_map'].to(device)
            gt_t = val_batch['cam_t_m2c'].to(device).float()
            gt_r = val_batch['cam_R_m2c'].to(device).float()
            obj_ids = val_batch['obj_id'].to(device)

            pred_t, pred_r = model(images, depth_maps, obj_ids)

            model_points_batch = [models_points[obj_ids[j].item()] for j in range(batch_size)]

            t_loss = nn.functional.smooth_l1_loss(pred_t, gt_t, beta=beta_t_lossL1smooth, reduction='mean')
            r_loss = geodesic_loss(pred_r, gt_r)
            add_loss = compute_add(model_points_batch, gt_r, gt_t, pred_r, pred_t)
            loss = gamma * add_loss + alpha * t_loss + beta * r_loss

            val_total_loss += loss.item() * batch_size
            val_total_t_loss += t_loss.item() * batch_size
            val_total_r_loss += r_loss.item() * batch_size
            val_total_add += add_loss.item() * batch_size

    val_avg_loss = val_total_loss / val_total_samples
    val_avg_rot_loss = val_total_r_loss / val_total_samples
    val_avg_trans_loss = val_total_t_loss / val_total_samples
    val_avg_add = val_total_add / val_total_samples

    torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, f"/content/drive/MyDrive/RGB+DEPTH_PoseModels_Correct/posemodel_rgb+dpt_{posemodelidrgb+i+1}.pth")
    print(f"{i+1}/{n_epochs} -> - Epoch {posemodelidrgb+i+1}")
    print(f"Average Training ADD (per Sample): {avg_add:.4f} | Avg. Rot. Loss: {avg_rot_loss:.4f} rad / {avg_rot_loss*(180 / torch.pi):.4f}° | Avg. Trans. Loss: {avg_trans_loss:.4f}" )
    # Report the validation ADD (not the weighted total loss) under the ADD label.
    print(f"Average Validation ADD (per Sample): {val_avg_add:.4f} | Avg. Rot. Loss: {val_avg_rot_loss:.4f} rad / {val_avg_rot_loss*(180 / torch.pi):.4f}° | Avg. Trans. Loss: {val_avg_trans_loss:.4f}")
    print("-----------------------------------------------------------------------")

1/30 -> - Epoch 190
Average Training ADD (per Sample): 1.4705 | Avg. Rot. Loss: 0.7592 rad / 43.4992° | Avg. Trans. Loss: 0.2211
Average Validation ADD (per Sample): 2.0414 | Avg. Rot. Loss: 0.7826 rad / 44.8388° | Avg. Trans. Loss: 0.4810
-----------------------------------------------------------------------
2/30 -> - Epoch 191
Average Training ADD (per Sample): 1.4699 | Avg. Rot. Loss: 0.7570 rad / 43.3716° | Avg. Trans. Loss: 0.2208
Average Validation ADD (per Sample): 2.0460 | Avg. Rot. Loss: 0.8087 rad / 46.3351° | Avg. Trans. Loss: 0.4851
-----------------------------------------------------------------------
3/30 -> - Epoch 192
Average Training ADD (per Sample): 1.4575 | Avg. Rot. Loss: 0.7471 rad / 42.8069° | Avg. Trans. Loss: 0.2192
Average Validation ADD (per Sample): 2.1372 | Avg. Rot. Loss: 0.7736 rad / 44.3212° | Avg. Trans. Loss: 0.5358
-----------------------------------------------------------------------
4/30 -> - Epoch 193
Average Training ADD (per Sample): 1.4658 | 

In [None]:
print("✅ Addestramento completato. Puoi disconnettere il runtime.")
# Hard-kill this process with signal 9 (SIGKILL on POSIX), presumably to
# release the Colab runtime once training has finished.
own_pid = os.getpid()
os.kill(own_pid, 9)