In [1]:
%load_ext autoreload
%autoreload 2
import os
from pathlib import Path
import sys
import torch
from src.models import Resnet18
from src.dataset import FlowImageDataset
from src.trainer import train
from src.predictor import Predictor
from src.metrics import transl_error, rot_error, recall
from src.result_visualization import ResultVisualization
# Add the sibling "../utils" directory to the module search path so that the
# `load_data` module imported on the next line can be found.
# NOTE(review): os.path.abspath resolves '..' against the current working
# directory, so this presumably expects the kernel to run from the notebook's
# own folder — confirm.
sys.path.append(os.path.abspath(os.path.join('..', 'utils')))
from load_data import load_data, load_meshes_o3d



Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Set to True to build the train/val datasets and run training;
# when False those steps are skipped.
do_training = False

# Object diameter handed to the Resnet18 constructor.
# NOTE(review): units are not stated here — presumably meters; confirm.
obj_diameter = 0.24

In [3]:
# Choose the compute device: CUDA if available, otherwise Apple MPS,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

Using cuda device


In [4]:
# Locations of the flow images and of the rendered "shoe" dataset that is
# passed to the datasets as the poses path.
output_root = Path("../../output")
flow_images_path = output_root / "flow_images"
poses_path = output_root / "dataset_rendered" / "shoe"

# The train/val datasets are only constructed when training is enabled.
if do_training:
    # NOTE(review): the validation dataset reads from the same "train" folder
    # as the training dataset; presumably `subset` selects the split inside
    # it — confirm.
    train_flow_dir = flow_images_path / "train"
    train_ds = FlowImageDataset(flow_image_path=train_flow_dir, poses_path=poses_path, subset="train")
    val_ds = FlowImageDataset(flow_image_path=train_flow_dir, poses_path=poses_path, subset="val")

# The test dataset is always constructed, regardless of `do_training`.
test_flow_dir = flow_images_path / "test"
test_ds = FlowImageDataset(flow_image_path=test_flow_dir, poses_path=poses_path, subset="test")

100%|██████████| 10240/10240 [00:03<00:00, 2637.47it/s]
10240it [00:00, 35375.44it/s]

torch.Size([10240, 4, 4])





In [5]:
# Train a newly constructed ResNet-18 only when training has been requested.
if do_training:
    model = Resnet18(obj_diameter=obj_diameter)
    model = model.to(device)
    # NOTE(review): `name` is presumably used by train() to label saved
    # checkpoints — confirm against src.trainer.
    model.name = "resnet18"
    model = train(
        model,
        train_ds,
        val_ds,
        num_epochs=15,
        batch_size=32,
        device=device,
    )

In [6]:
# Build a model, load the saved "resnet18_best" state into it and move the
# result to the selected device.
# NOTE(review): this cell also runs when do_training is True and then replaces
# the model returned by train() — presumably intended (best checkpoint); confirm.
checkpoint_path = "saved_models/resnet18_best"
loaded_model = Resnet18(obj_diameter=obj_diameter).load(checkpoint_path)
model = loaded_model.to(device)

In [7]:
# Dataset to run predictions on; later cells refer to it as `ds_predict`.
ds_predict = test_ds

# Run the model over every sample of the chosen dataset.
predictor = Predictor(model, device)
predictions = predictor.predict_all(ds_predict)

  return F.conv2d(input, weight, bias, self.stride,
Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853085/work/aten/src/ATen/native/Cross.cpp:62.)
  R3 = torch.cross(R1, rot_r2)
100%|██████████| 10240/10240 [00:25<00:00, 397.13it/s]


In [8]:
# Thresholds passed to `recall` for the two error types.
transl_thresh = 0.1 # meters
rot_thresh = 10 # degrees

# Gather the target (element [1] of each dataset sample), moved to `device`.
targets = []
for sample_idx in range(len(ds_predict)):
    targets.append(ds_predict[sample_idx][1].to(device))

# Per-sample translational and rotational errors of the predictions.
te = transl_error(predictions, targets)
re = rot_error(predictions, targets)

# Recall of each error type at its threshold.
t_recall = recall(te, transl_thresh)
r_recall = recall(re, rot_thresh)

In [9]:
# Convert the per-sample errors to tensors once and reuse them for every
# summary statistic below.
te_tensor = torch.tensor(te)
re_tensor = torch.tensor(re)

# Mean errors.
print(f"Mean translational error: {te_tensor.mean()}")
print(f"Mean rotational error: {re_tensor.mean()}")

# Median errors.
print(f"Median translational error: {te_tensor.median()}")
print(f"Median rotational error: {re_tensor.median()}")

# Recall at the configured thresholds.
print(f"Recall for translational error < {transl_thresh}: {t_recall}")
print(f"Recall for rotational error < {rot_thresh}: {r_recall}")

Mean translational error: 0.4530269205570221
Mean rotational error: 73.93890380859375
Median translational error: 0.43452417850494385
Median rotational error: 60.874534606933594
Recall for translational error < 0.1: 0.021484375
Recall for rotational error < 10: 0.018359375


In [10]:
# Rank every sample by a combined score: translational error plus rotational
# error divided by 100.
# NOTE(review): the 1/100 weight matches the ratio of the two recall thresholds
# (0.1 vs. 10) — presumably chosen to put both errors on a comparable scale; confirm.
# Despite its name, `sorted_te` holds sample indices ordered by this combined
# score (lowest first), not by translational error alone.
sorted_te = sorted(range(len(te)), key=lambda k: te[k] + re[k] / 100)

# Printing one entry per sample floods the notebook output, so only the
# best-ranked entries are listed. Set `num_to_show` to None to list them all.
num_to_show = 20
for i in sorted_te[:num_to_show]:
    print(f"Index: {i}, Translational error: {te[i]}, Rotational error: {re[i]}")
    print(f"Image name: {ds_predict.flow_image_names[i]}")
    print("\n")

Index: 6595, Translational error: 0.03743818774819374, Rotational error: 5.545174422342444
Image name: shoe-green_viva_sandal_right_22__shoe-asifn_yellow_right_24


Index: 4021, Translational error: 0.033299338072538376, Rotational error: 6.100979023226451
Image name: shoe-crocs_white_cyan_right_6__shoe-green_viva_sandal_right_8


Index: 7812, Translational error: 0.03323008120059967, Rotational error: 6.329857362450149
Image name: shoe-green_viva_sandal_right_58__shoe-crocs_white_cyan_right_57


Index: 1937, Translational error: 0.05988691747188568, Rotational error: 4.762713859813057
Image name: shoe-asifn_yellow_right_7__shoe-crocs_white_cyan_right_57


Index: 765, Translational error: 0.07558545470237732, Rotational error: 3.384139726130533
Image name: shoe-asifn_yellow_right_30__shoe-green_viva_sandal_right_20


Index: 2665, Translational error: 0.03556310385465622, Rotational error: 8.500895654855505
Image name: shoe-crocs_white_cyan_right_27__shoe-asifn_yellow_right_20


Index: 

In [11]:
# Load the rendered "shoe" data; load_data returns the RGB images, object
# masks, poses and scene names in that order.
rendered_shoe_dir = Path("../../output", "dataset_rendered", "shoe")
rgb_images, object_masks, poses, scene_names = load_data(rendered_shoe_dir)

# Load the shoe meshes handed to the result visualization.
shoe_mesh_dir = "../../data/housecat6d_meshes/shoe"
meshes = load_meshes_o3d(shoe_mesh_dir)

Loading meshes...


In [None]:
# Visualize the predictions together with their errors for the evaluated dataset.
# NOTE(review): the recorded run progressed at under 2 items/s over 10240
# items and did not finish — consider visualizing a subset; confirm whether
# visualize_results supports that.
visualizer = ResultVisualization(meshes)
visualizer.visualize_results(
    dataset=ds_predict,
    rgb_images=rgb_images,
    poses=poses,
    predictions=predictions,
    te=te,
    re=re,
)

  0%|          | 4/10240 [00:02<1:41:50,  1.68it/s]

: 

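- TODO: try pretrained weights for ResNet-18 (options: full retraining, or training only layer 4)
- TODO: try different rotation representations (e.g. quaternions); see RoMa: https://github.com/naver/roma
- TODO: repeat the same experiments with ResNet-34
- TODO: evaluate on a novel object from the same category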