In [33]:
# Scegliere versione modello YOLO retrained da usare
yolomodelid = 30
# Scegliere versione modello PoseModel RGB da usare
posemodelidrgb = 161
# Scegliere versione modello PoseModel RGBD da usare
posemodelrgbdid = 189

img_dir = f'/content/drive/MyDrive/LineMODZ/Linemod_preprocessed/data/01/rgb' # Insert RGB image path
img_depth_dir = f'/content/drive/MyDrive/LineMODZ/Linemod_preprocessed/data/01/depth' # Insert Depth image path

img_filenames = ["0000.png"] # Insert images names as "name.png"

objdet_saved_model_path = f'/content/drive/MyDrive/RetrainedYOLO_Models_Correct/YOLO_retrained_{yolomodelid}.pt'
pose_saved_model_path_rgb = f'/content/drive/MyDrive/RGB_PoseModels_Correct/posemodel_rgb_{posemodelidrgb}.pth'
pose_saved_model_path_rgbd = f'/content/drive/MyDrive/RGB+DEPTH_PoseModels_Correct/posemodel_rgb+dpt_{posemodelrgbdid}.pth'

only_rgb = False

In [3]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import plotly.graph_objects as go
from torch.utils.data import DataLoader
!pip install ultralytics
from ultralytics import YOLO
from google.colab import drive
import numpy as np
import cv2
import torch.nn as nn
import torchvision.models as models
import torch
from torchvision.transforms import Compose, Resize, ToTensor
import random
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms.functional import crop, resize, to_tensor

def set_seed(seed=42):
    random.seed(seed)  # Python
    np.random.seed(seed)  # NumPy
    torch.manual_seed(seed)  # CPU
    torch.cuda.manual_seed(seed)  # GPU singola
    torch.cuda.manual_seed_all(seed)  # Tutte le GPU

    # Comportamento deterministico per reproducibilità
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def normalize_quaternion(q):
    return q / q.norm(dim=-1, keepdim=True)

def quaternion_to_matrix(q):
    x, y, z, w = q.unbind(-1)
    B = q.size(0)
    R = torch.empty(B, 3, 3, device=q.device, dtype=q.dtype)
    R[:, 0, 0] = 1 - 2*(y*y + z*z)
    R[:, 0, 1] = 2*(x*y - z*w)
    R[:, 0, 2] = 2*(x*z + y*w)
    R[:, 1, 0] = 2*(x*y + z*w)
    R[:, 1, 1] = 1 - 2*(x*x + z*z)
    R[:, 1, 2] = 2*(y*z - x*w)
    R[:, 2, 0] = 2*(x*z - y*w)
    R[:, 2, 1] = 2*(y*z + x*w)
    R[:, 2, 2] = 1 - 2*(x*x + y*y)
    return R

class PoseRegressor_RGB(nn.Module):
    def __init__(self, num_obj_ids, embedding_dim=32, use_learned_default=True):
        super(PoseRegressor_RGB, self).__init__()

        resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])  # [B, 2048, 1, 1]

        self.obj_id_embedding = nn.Embedding(num_obj_ids, embedding_dim)

        # Embedding predefinito (se richiesto)
        self.use_learned_default = use_learned_default
        if use_learned_default:
            self.default_obj_embedding = nn.Parameter(torch.zeros(embedding_dim))

        self.fc_common = nn.Sequential(
            nn.Linear(2048 + embedding_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU())

        self.fc_translation = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 3))

        self.fc_rotation = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 4))

    def forward(self, x, obj_id=None):
        features = self.feature_extractor(x)  # [B, 2048, 1, 1]
        features = features.view(features.size(0), -1)  # [B, 2048]

        if obj_id is not None:
            obj_embed = self.obj_id_embedding(obj_id)  # [B, D]
        else:
            if self.use_learned_default:
                obj_embed = self.default_obj_embedding.unsqueeze(0).expand(x.size(0), -1)  # [B, D]
            else:
                obj_embed = torch.zeros(x.size(0), self.obj_id_embedding.embedding_dim, device=x.device)

        x = torch.cat([features, obj_embed], dim=1)  # [B, 2048 + D]
        x = self.fc_common(x)
        translation = self.fc_translation(x)
        rotation_q = self.fc_rotation(x)  # [B, 4]
        rotation_q = normalize_quaternion(rotation_q)  # normalizza il quaternione
        rotation = quaternion_to_matrix(rotation_q)     # converte in matrice 3x3
        return translation, rotation

class PoseRegressor_RGBD(nn.Module):
    def __init__(self, num_obj_ids, embedding_dim=32, use_learned_default=True):
        super(PoseRegressor_RGBD, self).__init__()

        # Feature extractor RGB (pre-trained)
        resnet_rgb = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.rgb_extractor = nn.Sequential(*list(resnet_rgb.children())[:-1])  # [B, 2048, 1, 1]

        # Feature extractor depth (lighter, from scratch or pre-trained on grayscale)
        resnet_depth = models.resnet18(weights=None)
        resnet_depth.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.depth_extractor = nn.Sequential(*list(resnet_depth.children())[:-1])  # [B, 512, 1, 1]

        self.obj_id_embedding = nn.Embedding(num_obj_ids, embedding_dim)

        # Embedding predefinito (se richiesto)
        self.use_learned_default = use_learned_default
        if use_learned_default:
            self.default_obj_embedding = nn.Parameter(torch.zeros(embedding_dim))

        self.fc_common = nn.Sequential(
            nn.Linear(2048 + embedding_dim + 512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU())

        self.fc_translation = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 3))

        self.fc_rotation = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 4))

    def forward(self, rgb, depth, obj_id=None):
        feat_rgb = self.rgb_extractor(rgb)       # [B, 2048, 1, 1]
        feat_depth = self.depth_extractor(depth) # [B, 512, 1, 1]

        feat_rgb = feat_rgb.view(feat_rgb.size(0), -1)       # [B, 2048]
        feat_depth = feat_depth.view(feat_depth.size(0), -1) # [B, 512]

        if obj_id is not None:
            obj_embed = self.obj_id_embedding(obj_id)  # [B, D]
        else:
            if self.use_learned_default:
                obj_embed = self.default_obj_embedding.unsqueeze(0).expand(rgb.size(0), -1)  # [B, D]
            else:
                obj_embed = torch.zeros(x.size(0), self.obj_id_embedding.embedding_dim, device=rgb.device)

        features = torch.cat([feat_rgb, feat_depth], dim=1)
        x = torch.cat([features, obj_embed], dim=1)  # [B, 2048 + D]
        x = self.fc_common(x)
        translation = self.fc_translation(x)
        rotation_q = self.fc_rotation(x)  # [B, 4]
        rotation_q = normalize_quaternion(rotation_q)  # normalizza il quaternione
        rotation = quaternion_to_matrix(rotation_q)     # converte in matrice 3x3
        return translation, rotation

if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

drive.mount('/content/drive', force_remount=True)


Mounted at /content/drive


In [6]:
## Object Detection
set_seed(42)
ObjDetModel = YOLO(objdet_saved_model_path).to(device)
results = ObjDetModel.predict(source=img_dir, save=True, stream = True)
preds = {}
for i, r in enumerate(results):
  if i < len(img_filenames):
    boxes = r.boxes.xyxy       # Bounding boxes
    confs = r.boxes.conf       # Confidence scores
    classes = r.boxes.cls      # Class IDs
    if len(confs) == 0: # Nessun oggetto trovato in immagine
        continue
    max_conf_idx = confs.argmax() # Trova l'indice della box con confidenza massima

    best_box = boxes[max_conf_idx].cpu().numpy()
    best_conf = confs[max_conf_idx].item()
    best_class = classes[max_conf_idx].item()

    preds[img_filenames[i][:-4]] = {'pred_bbox': best_box, 'pred_class': best_class}


image 1/1236 /content/drive/MyDrive/LineMODZ/Linemod_preprocessed/data/01/rgb/0000.png: 480x640 1 obj_01, 7.0ms
image 2/1236 /content/drive/MyDrive/LineMODZ/Linemod_preprocessed/data/01/rgb/0001.png: 480x640 1 obj_01, 6.9ms
image 3/1236 /content/drive/MyDrive/LineMODZ/Linemod_preprocessed/data/01/rgb/0002.png: 480x640 1 obj_01, 7.5ms
image 4/1236 /content/drive/MyDrive/LineMODZ/Linemod_preprocessed/data/01/rgb/0003.png: 480x640 1 obj_01, 6.1ms
image 5/1236 /content/drive/MyDrive/LineMODZ/Linemod_preprocessed/data/01/rgb/0004.png: 480x640 1 obj_01, 6.5ms
image 6/1236 /content/drive/MyDrive/LineMODZ/Linemod_preprocessed/data/01/rgb/0005.png: 480x640 1 obj_01, 6.3ms
image 7/1236 /content/drive/MyDrive/LineMODZ/Linemod_preprocessed/data/01/rgb/0006.png: 480x640 1 obj_01, 6.5ms
image 8/1236 /content/drive/MyDrive/LineMODZ/Linemod_preprocessed/data/01/rgb/0007.png: 480x640 1 obj_01, 6.4ms
image 9/1236 /content/drive/MyDrive/LineMODZ/Linemod_preprocessed/data/01/rgb/0008.png: 480x640 1 obj_0

In [37]:
### RGB + Depth Pose Prediction
if only_rgb is False:
  set_seed(42)

  transform_rgb = transforms.Compose([
      transforms.Resize((224, 224)),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.3349, 0.3165, 0.3106], std=[0.2419, 0.2388, 0.2392])
  ])

  transform_depth = transforms.Compose([
      transforms.Resize((224, 224)),
      transforms.ToTensor(),
      transforms.Lambda(lambda img: torch.from_numpy(np.array(img, dtype=np.float32)) / 65535.0),
      transforms.Normalize(mean=[0.5], std=[0.5])
  ])

  pose_model_rgbd = PoseRegressor_RGBD(num_obj_ids=13).to(device)
  optimizer = torch.optim.Adam([
    {"params": pose_model_rgbd.rgb_extractor.parameters(), "lr": 1e-6},
    {"params": pose_model_rgbd.depth_extractor.parameters(), "lr": 1e-6},
    {"params": pose_model_rgbd.fc_common.parameters(), "lr": 1e-6},
    {"params": pose_model_rgbd.fc_translation.parameters(), "lr": 1e-6},
    {"params": pose_model_rgbd.fc_rotation.parameters(), "lr": 1e-6},], lr=1e-6)
  checkpoint = torch.load(pose_saved_model_path_rgbd)
  pose_model_rgbd.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  pose_model_rgbd.eval()

  cropped_tensors = []
  cropped_tensors_depth = []
  classes = []
  for i, filename in enumerate(img_filenames):
      name_key = filename[:-4]
      if name_key not in preds:
          continue  # Salta immagini senza predizione

      img_path = os.path.join(img_dir, filename)
      image = Image.open(img_path).convert("RGB")

      depth_path = os.path.join(img_depth_dir, filename)
      depth_map = Image.open(depth_path)

      x1, y1, x2, y2 = preds[name_key]['pred_bbox']
      classes.append(preds[name_key]['pred_class'])
      cropped = image.crop((x1, y1, x2, y2))
      cropped_tensor = transform_rgb(cropped)
      cropped_tensors.append(cropped_tensor)
      cropped_depth = depth_map.crop((x1, y1, x2, y2))
      cropped_tensor_depth = transform_depth(cropped_depth)
      cropped_tensors_depth.append(cropped_tensor_depth)

  batch_tensor_rgb = torch.stack(cropped_tensors).to(device)
  batch_tensor_depth = torch.stack(cropped_tensors_depth).to(device)
  obj_ids = torch.tensor(classes).to(device)

  with torch.no_grad():
      obj_ids = obj_ids.to(dtype=torch.long, device=device)
      translations_rgbd, rotations_rgbd = pose_model_rgbd(batch_tensor_rgb, batch_tensor_depth, obj_ids)  # [B, 3] and [B, 3, 3]
      scale = torch.tensor([77.1325, 67.1468, 128.0780], device=translations_rgbd.device)
      offset = torch.tensor([10.1119, -40.2, 874.9376], device=translations_rgbd.device)
      translations_rgbd = translations_rgbd * scale + offset
      print(translations_rgbd)
      print(rotations_rgbd)

else:
    set_seed(42)

    transform_rgb = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.3349, 0.3165, 0.3106], std=[0.2419, 0.2388, 0.2392])
    ])

    pose_model_rgb = PoseRegressor_RGB(num_obj_ids=13).to(device)
    optimizer = torch.optim.Adam([
    {"params": pose_model_rgb.rgb_extractor.parameters(), "lr": 1e-6},
    {"params": pose_model_rgb.depth_extractor.parameters(), "lr": 1e-6},
    {"params": pose_model_rgb.fc_common.parameters(), "lr": 1e-6},
    {"params": pose_model_rgb.fc_translation.parameters(), "lr": 1e-6},
    {"params": pose_model_rgb.fc_rotation.parameters(), "lr": 1e-6},], lr=1e-6)
    checkpoint = torch.load(pose_saved_model_path_rgb)
    pose_model_rgb.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    pose_model_rgb.eval()

    cropped_tensors = []
    classes = []
    for i, filename in enumerate(img_filenames):
        name_key = filename[:-4]
        if name_key not in preds:
            continue  # Salta immagini senza predizione

        img_path = os.path.join(img_dir, filename)
        image = Image.open(img_path).convert("RGB")


        x1, y1, x2, y2 = preds[name_key]['pred_bbox']
        classes.append(preds[name_key]['pred_class'])
        cropped = image.crop((x1, y1, x2, y2))
        cropped_tensor = transform_rgb(cropped)
        cropped_tensors.append(cropped_tensor)

    batch_tensor_rgb = torch.stack(cropped_tensors).to(device)
    obj_ids = torch.tensor(classes).to(device)

    with torch.no_grad():
        obj_ids = obj_ids.to(dtype=torch.long, device=device)
        translations_rgb, rotations_rgb = pose_model_rgb(batch_tensor_rgb, obj_ids)  # [B, 3] and [B, 3, 3]
        scale = torch.tensor([77.1325, 67.1468, 128.0780], device=translations_rgb.device)
        offset = torch.tensor([10.1119, -40.2, 874.9376], device=translations_rgb.device)
        translations_rgb = translations_rgb * scale + offset
        print(translations_rgb)
        print(rotations_rgb)



tensor([[  12.6521,   -1.9922, 1012.2130]], device='cuda:0')
tensor([[[-0.6387,  0.7538, -0.1542],
         [ 0.5876,  0.3485, -0.7303],
         [-0.4968, -0.5570, -0.6655]]], device='cuda:0')
