In [1]:
%load_ext autoreload
%autoreload 2
import os
from pathlib import Path
import sys
import torch
from src.models import Resnet18, Resnet152
from src.dataset import FlowImageDataset
from src.trainer import train
from src.predictor import Predictor
from src.metrics import transl_error, rot_error, recall
from src.result_visualization import ResultVisualization
sys.path.append(os.path.abspath(os.path.join('..', 'utils')))
from load_data import load_data, load_meshes_o3d



Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Run configuration for this notebook.

# Which network to build and which checkpoint to reload: 'resnet152' or 'resnet18'.
model_type = 'resnet152'
# Forwarded as `pretrained=` to the model constructor when training.
use_pretrained = True
# When True, the train/val datasets are built and the model is trained
# before evaluation.
do_training = True
# Passed as `obj_diameter=` to the model constructors.
# NOTE(review): presumably in meters, like the translation threshold — confirm.
obj_diameter = 0.24

In [3]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [4]:
# Locations of the flow images and of the rendered dataset holding the poses.
output_root = Path("../../output")
flow_images_path = output_root / "flow_images"
poses_path = output_root / "dataset_rendered" / "shoe"

if do_training:
    # Train and validation sets both read from the "train" directory and are
    # told apart only by the `subset` argument.
    train_flow_dir = flow_images_path / "train"
    train_ds = FlowImageDataset(
        flow_image_path=train_flow_dir,
        poses_path=poses_path,
        subset="train",
    )
    val_ds = FlowImageDataset(
        flow_image_path=train_flow_dir,
        poses_path=poses_path,
        subset="val",
    )

# The test set is always built because evaluation runs with or without training.
test_ds = FlowImageDataset(
    flow_image_path=flow_images_path / "test",
    poses_path=poses_path,
    subset="test",
)

100%|██████████| 40960/40960 [00:54<00:00, 751.64it/s] 
32768it [00:00, 37236.43it/s]


torch.Size([32768, 4, 4])


100%|██████████| 40960/40960 [00:20<00:00, 2042.34it/s]
8192it [00:00, 38646.83it/s]


torch.Size([8192, 4, 4])


100%|██████████| 10240/10240 [00:17<00:00, 596.59it/s]
10240it [00:00, 38935.04it/s]

torch.Size([10240, 4, 4])





In [5]:
# Build the network selected by `model_type` and train it on the train/val
# datasets. Skipped entirely when `do_training` is False.
if do_training:
    if model_type == 'resnet152':
        model = Resnet152(obj_diameter=obj_diameter, pretrained=use_pretrained).to(device)
    elif model_type == 'resnet18':
        model = Resnet18(obj_diameter=obj_diameter, pretrained=use_pretrained).to(device)
    else:
        # Without this branch an unknown model_type would surface as a
        # confusing NameError on `model` (or silently reuse a stale `model`
        # left in the kernel from an earlier run).
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected 'resnet152' or 'resnet18'"
        )
    # NOTE(review): the name is presumably used by `train` when saving
    # checkpoints (the reload step reads "saved_models/<model_type>_best") — confirm.
    model.name = model_type
    model = train(model, train_ds, val_ds, num_epochs=15, device=device, batch_size=32)

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Batches: 100%|██████████| 1024/1024 [05:14<00:00,  3.25it/s]


train loss: 0.287620918097673
val loss: 0.27776656387140974
Saving model, val loss improved from inf to 0.27776656387140974


Batches: 100%|██████████| 1024/1024 [05:17<00:00,  3.23it/s]


train loss: 0.27634164511982817
val loss: 0.26992710668127984
Saving model, val loss improved from 0.27776656387140974 to 0.26992710668127984


Batches: 100%|██████████| 1024/1024 [05:17<00:00,  3.23it/s]


train loss: 0.26757992315106094
val loss: 0.2643775647156872
Saving model, val loss improved from 0.26992710668127984 to 0.2643775647156872


Batches: 100%|██████████| 1024/1024 [05:17<00:00,  3.23it/s]


train loss: 0.2574311575444881
val loss: 0.23683356307446957
Saving model, val loss improved from 0.2643775647156872 to 0.23683356307446957


Batches: 100%|██████████| 1024/1024 [05:17<00:00,  3.22it/s]


train loss: 0.2358769579732325
val loss: 0.2239613906131126
Saving model, val loss improved from 0.23683356307446957 to 0.2239613906131126


Batches: 100%|██████████| 1024/1024 [05:19<00:00,  3.21it/s]


train loss: 0.2249023077893071
val loss: 0.21884603495709598
Saving model, val loss improved from 0.2239613906131126 to 0.21884603495709598


Batches: 100%|██████████| 1024/1024 [05:20<00:00,  3.19it/s]


train loss: 0.22111121437046677
val loss: 0.21040274846018292
Saving model, val loss improved from 0.21884603495709598 to 0.21040274846018292


Batches: 100%|██████████| 1024/1024 [05:19<00:00,  3.21it/s]


train loss: 0.20866857252258342


Epochs:  53%|█████▎    | 8/15 [45:22<39:47, 341.12s/it]

val loss: 0.24330987792927772


Batches: 100%|██████████| 1024/1024 [05:19<00:00,  3.21it/s]


train loss: 0.20998961423174478
val loss: 0.20155801571672782
Saving model, val loss improved from 0.21040274846018292 to 0.20155801571672782


Batches: 100%|██████████| 1024/1024 [05:18<00:00,  3.21it/s]


train loss: 0.19758297451335238
val loss: 0.1950924547854811
Saving model, val loss improved from 0.20155801571672782 to 0.1950924547854811


Batches: 100%|██████████| 1024/1024 [05:23<00:00,  3.17it/s]


train loss: 0.1943886100416421


Epochs:  73%|███████▎  | 11/15 [1:02:29<22:49, 342.32s/it]

val loss: 0.22460994130233303


Batches: 100%|██████████| 1024/1024 [05:29<00:00,  3.11it/s]


train loss: 0.1956089488521684


Epochs:  80%|████████  | 12/15 [1:08:21<17:15, 345.13s/it]

val loss: 0.1953978799865581


Batches: 100%|██████████| 1024/1024 [05:17<00:00,  3.22it/s]


train loss: 0.18796783794823568
val loss: 0.18815096112666652
Saving model, val loss improved from 0.1950924547854811 to 0.18815096112666652


Batches: 100%|██████████| 1024/1024 [05:22<00:00,  3.17it/s]


train loss: 0.19842633749794913


Epochs:  93%|█████████▎| 14/15 [1:19:46<05:44, 344.08s/it]

val loss: 0.22167186182923615


Batches: 100%|██████████| 1024/1024 [05:17<00:00,  3.22it/s]


train loss: 0.19444437419588212


Epochs: 100%|██████████| 15/15 [1:25:25<00:00, 341.70s/it]

val loss: 0.25620286050252616





In [6]:
# Reload a saved checkpoint for the selected architecture so that evaluation
# uses the stored weights rather than whatever `model` currently holds.
# NOTE(review): the "*_best" files are presumably the ones written by `train`
# whenever the validation loss improved — confirm.
if model_type == 'resnet152':
    model = Resnet152(obj_diameter=obj_diameter).load("saved_models/resnet152_best").to(device)
elif model_type == 'resnet18':
    model = Resnet18(obj_diameter=obj_diameter).load("saved_models/resnet18_best").to(device)
else:
    # Fail loudly: otherwise an unknown model_type would leave `model`
    # undefined, or silently evaluate a stale model already in the kernel.
    raise ValueError(
        f"Unsupported model_type {model_type!r}; expected 'resnet152' or 'resnet18'"
    )

In [7]:
# Dataset to evaluate on; later cells read targets and image names from it.
ds_predict = test_ds

# Run the model over every sample of the evaluation dataset.
predictor = Predictor(model, device)
predictions = predictor.predict_all(ds_predict)

100%|██████████| 10240/10240 [02:46<00:00, 61.54it/s]


In [8]:
# Thresholds below which a prediction counts as correct for the recall metric.
transl_thresh = 0.1  # meters
rot_thresh = 10  # degrees

# Collect the target (second element of each dataset item) for every sample,
# moved onto the same device as the model.
targets = []
for sample_idx in range(len(ds_predict)):
    targets.append(ds_predict[sample_idx][1].to(device))

# Translational and rotational errors between predictions and targets.
te = transl_error(predictions, targets)
re = rot_error(predictions, targets)

# Recall of each error type at its threshold.
t_recall, r_recall = recall(te, transl_thresh), recall(re, rot_thresh)

In [9]:
# Convert the error sequences to tensors once and reuse them for every statistic.
te_tensor = torch.tensor(te)
re_tensor = torch.tensor(re)

# Mean errors.
print(f"Mean translational error: {torch.mean(te_tensor)}")
print(f"Mean rotational error: {torch.mean(re_tensor)}")

# Median errors.
print(f"Median translational error: {torch.median(te_tensor)}")
print(f"Median rotational error: {torch.median(re_tensor)}")

# Recall at the configured thresholds.
print(f"Recall for translational error < {transl_thresh}: {t_recall}")
print(f"Recall for rotational error < {rot_thresh}: {r_recall}")

Mean translational error: 0.44666481018066406
Mean rotational error: 73.95912170410156
Median translational error: 0.42612534761428833
Median rotational error: 60.67575454711914
Recall for translational error < 0.1: 0.02109375
Recall for rotational error < 10: 0.0169921875


In [10]:
# Rank all samples from best to worst by a combined score: translational
# error plus rotational error scaled down by 100.
def _combined_error(sample_idx):
    """Return the combined ranking score for one sample index."""
    return te[sample_idx] + re[sample_idx] / 100


sorted_te = sorted(range(len(te)), key=_combined_error)

# Print every sample's errors and image name in ranked order.
for sample_idx in sorted_te:
    print(f"Index: {sample_idx}, Translational error: {te[sample_idx]}, Rotational error: {re[sample_idx]}")
    print(f"Image name: {ds_predict.flow_image_names[sample_idx]}")
    print("\n")

Index: 878, Translational error: 0.04551338031888008, Rotational error: 5.428813792702358
Image name: shoe-asifn_yellow_right_34__shoe-green_viva_sandal_right_13


Index: 465, Translational error: 0.06321455538272858, Rotational error: 4.071310486739343
Image name: shoe-asifn_yellow_right_22__shoe-green_viva_sandal_right_24


Index: 1631, Translational error: 0.042953018099069595, Rotational error: 6.8505992176421895
Image name: shoe-asifn_yellow_right_56__shoe-crocs_yellow_sandal_right_12


Index: 360, Translational error: 0.06239660084247589, Rotational error: 5.302525968721283
Image name: shoe-asifn_yellow_right_19__shoe-green_viva_sandal_right_8


Index: 3335, Translational error: 0.05909223482012749, Rotational error: 6.509153565696849
Image name: shoe-crocs_white_cyan_right_46__shoe-asifn_yellow_right_25


Index: 10210, Translational error: 0.039956796914339066, Rotational error: 8.561853606181934
Image name: shoe-pink_tiny_crocs_right_8__shoe-crocs_yellow_sandal_right_56


Index

In [11]:
# Inputs for the visualisation: the rendered dataset and the object meshes.
rendered_dataset_dir = Path("../../output", "dataset_rendered", "shoe")
meshes_dir = "../../data/housecat6d_meshes/shoe"

# Only `rgb_images` and `poses` are used by the visualisation call below;
# the masks and scene names are unpacked but not used in this notebook section.
rgb_images, object_masks, poses, scene_names = load_data(rendered_dataset_dir)
meshes = load_meshes_o3d(meshes_dir)

Loading meshes...
Loaded 21 meshes.


In [12]:
# Visualise the predictions for the evaluated dataset, passing the per-sample
# errors along with the images and poses.
visualizer = ResultVisualization(meshes)
visualizer.visualize_results(
    dataset=ds_predict,
    rgb_images=rgb_images,
    poses=poses,
    predictions=predictions,
    te=te,
    re=re,
)

[Open3D INFO] EGL headless mode enabled.
FEngine (64 bits) created at 0x561ea59a3e50 (threading is enabled)
EGL(1.5)
OpenGL(4.1)


100%|██████████| 10240/10240 [41:45<00:00,  4.09it/s]


- pretrained weights for resnet18 (full retraining, only layer 4)
- different rotation representations (e.g. quaternions), see https://github.com/naver/roma