In [1]:
%load_ext autoreload
%autoreload 2
import os
from pathlib import Path
import sys
import torch
from src.models import Resnet18, Resnet152
from src.dataset import FlowImageDataset
from src.trainer import train
from src.predictor import Predictor
from src.metrics import transl_error, rot_error, recall
from src.result_visualization import ResultVisualization
sys.path.append(os.path.abspath(os.path.join('..', 'utils')))
from load_data import load_data, load_meshes_o3d



Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# --- Experiment configuration ---------------------------------------------
# Which backbone to build/load: 'resnet18' or 'resnet152' (the only values
# handled by the model cells below).
model_type = 'resnet152'
# Forwarded as `pretrained=` to the model constructor when training.
use_pretrained = True
# When False, training is skipped and only the saved checkpoint is evaluated.
do_training = True
# Object diameter forwarded to the model constructor; presumably in meters,
# like the translational thresholds used later (TODO confirm).
obj_diameter = 0.24

In [3]:
# Select the compute device: prefer CUDA, then Apple's MPS backend, and fall
# back to the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [4]:
# Dataset locations, relative to this notebook.
flow_images_path = Path("../../output") / "flow_images"
poses_path = Path("../../output") / "dataset_rendered" / "shoe"

# Train and val are both read from the "train" flow-image folder and are
# distinguished only by `subset`; they are needed only when training here.
if do_training:
    train_ds, val_ds = (
        FlowImageDataset(flow_image_path=flow_images_path / "train", poses_path=poses_path, subset=split)
        for split in ("train", "val")
    )

# The test split has its own folder and is always needed for evaluation.
test_ds = FlowImageDataset(flow_image_path=flow_images_path / "test", poses_path=poses_path, subset="test")

100%|██████████| 40960/40960 [01:01<00:00, 661.08it/s] 
32768it [00:01, 28323.69it/s]


torch.Size([32768, 4, 4])


100%|██████████| 40960/40960 [00:33<00:00, 1230.83it/s]
8192it [00:00, 27742.76it/s]


torch.Size([8192, 4, 4])


100%|██████████| 10240/10240 [00:17<00:00, 578.15it/s]
10240it [00:00, 34821.92it/s]

torch.Size([10240, 4, 4])





In [5]:
# Build the selected backbone and train it (only when `do_training` is set).
# An unsupported `model_type` must fail loudly: otherwise `model` is either
# undefined (NameError) or, on a warm kernel, a stale model from an earlier run
# that would be silently retrained under the new name.
if do_training:
    if model_type == 'resnet152':
        model = Resnet152(obj_diameter=obj_diameter, pretrained=use_pretrained).to(device)
    elif model_type == 'resnet18':
        model = Resnet18(obj_diameter=obj_diameter, pretrained=use_pretrained).to(device)
    else:
        raise ValueError(f"Unknown model_type {model_type!r}; expected 'resnet18' or 'resnet152'")
    # The reload cell reads "saved_models/<model_type>_best"; presumably the
    # trainer derives its checkpoint file name from `model.name` (TODO confirm).
    model.name = model_type
    # NOTE(review): in the recorded run, val loss last improved around epoch 27
    # while train loss kept falling toward ~0.007, i.e. the model overfits.
    # Fewer epochs or early stopping would save roughly 6 hours of GPU time.
    model = train(model, train_ds, val_ds, lr=0.01, gamma=0.9, num_epochs=100, device=device, batch_size=32)

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Batches: 100%|██████████| 1024/1024 [05:19<00:00,  3.20it/s]


train loss: 0.18458347888372373
val loss: 0.17918225919129327
Saving model, val loss improved from inf to 0.17918225919129327


Batches: 100%|██████████| 1024/1024 [05:24<00:00,  3.16it/s]


train loss: 0.16309547120908974
val loss: 0.1680671272915788
Saving model, val loss improved from 0.17918225919129327 to 0.1680671272915788


Batches: 100%|██████████| 1024/1024 [05:10<00:00,  3.30it/s]


train loss: 0.15100617118150694
val loss: 0.14751798199722543
Saving model, val loss improved from 0.1680671272915788 to 0.14751798199722543


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.14080013160128146


Epochs:   4%|▍         | 4/100 [22:14<8:46:14, 328.90s/it]

val loss: 0.2674938851268962


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.13443644208746264
val loss: 0.1435037330375053
Saving model, val loss improved from 0.14751798199722543 to 0.1435037330375053


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.13006486438098364
val loss: 0.13420761303859763
Saving model, val loss improved from 0.1435037330375053 to 0.13420761303859763


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.12884418181783985


Epochs:   7%|▋         | 7/100 [37:54<8:12:12, 317.55s/it]

val loss: 0.14844125171657652


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.12374034424283309
val loss: 0.12389422618434764
Saving model, val loss improved from 0.13420761303859763 to 0.12389422618434764


Batches: 100%|██████████| 1024/1024 [05:00<00:00,  3.40it/s]


train loss: 0.12096088723046705


Epochs:   9%|▉         | 9/100 [48:28<8:01:39, 317.57s/it]

val loss: 0.21902211441192776


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.11629187023208942
val loss: 0.11929396408959292
Saving model, val loss improved from 0.12389422618434764 to 0.11929396408959292


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.11409525988347013


Epochs:  11%|█         | 11/100 [58:54<7:47:32, 315.20s/it]

val loss: 0.12637451192131266


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.11134654535271693


Epochs:  12%|█▏        | 12/100 [1:04:07<7:41:04, 314.37s/it]

val loss: 0.12285113826510496


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.10984927478421014
val loss: 0.11597447391250171
Saving model, val loss improved from 0.11929396408959292 to 0.11597447391250171


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]]


train loss: 0.10673686757945688


Epochs:  14%|█▍        | 14/100 [1:14:33<7:29:33, 313.64s/it]

val loss: 0.11808934563305229


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.10353326156109688


Epochs:  15%|█▌        | 15/100 [1:19:45<7:23:49, 313.28s/it]

val loss: 0.12300014894572087


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.10060893171612406


Epochs:  16%|█▌        | 16/100 [1:24:58<7:18:13, 313.01s/it]

val loss: 0.12008620597771369


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.1046256441368314


Epochs:  17%|█▋        | 17/100 [1:30:10<7:12:46, 312.84s/it]

val loss: 0.21540052228374407


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.09914891182052088
val loss: 0.11377903845277615
Saving model, val loss improved from 0.11597447391250171 to 0.11377903845277615


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]]


train loss: 0.0950524995860178


Epochs:  19%|█▉        | 19/100 [1:40:36<7:02:24, 312.89s/it]

val loss: 0.1566012630937621


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.09185875917683006
val loss: 0.11065335251623765
Saving model, val loss improved from 0.11377903845277615 to 0.11065335251623765


Batches: 100%|██████████| 1024/1024 [04:54<00:00,  3.48it/s]]


train loss: 0.08807422863173997


Epochs:  21%|██        | 21/100 [1:51:04<6:52:34, 313.35s/it]

val loss: 0.11806579618132673


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.08509467281328398
val loss: 0.1046891455334844
Saving model, val loss improved from 0.11065335251623765 to 0.1046891455334844


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]]


train loss: 0.0829674150008941


Epochs:  23%|██▎       | 23/100 [2:01:30<6:41:53, 313.16s/it]

val loss: 0.10984811205707956


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.07955158176991972


Epochs:  24%|██▍       | 24/100 [2:06:42<6:36:25, 312.97s/it]

val loss: 0.18581404435099103


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.07738268315733876


Epochs:  25%|██▌       | 25/100 [2:11:55<6:31:02, 312.83s/it]

val loss: 0.14685940550407395


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.07389818027695583


Epochs:  26%|██▌       | 26/100 [2:17:07<6:25:42, 312.74s/it]

val loss: 0.12809747227584012


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.07019550841323507
val loss: 0.10124453554453794
Saving model, val loss improved from 0.1046891455334844 to 0.10124453554453794


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]]


train loss: 0.06705344829060778


Epochs:  28%|██▊       | 28/100 [2:27:34<6:15:28, 312.89s/it]

val loss: 0.10426049646048341


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.49it/s]


train loss: 0.0647681712780468


Epochs:  29%|██▉       | 29/100 [2:32:46<6:10:08, 312.80s/it]

val loss: 0.10289272177033126


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.06160580897631007


Epochs:  30%|███       | 30/100 [2:37:59<6:04:51, 312.74s/it]

val loss: 0.10157460383197758


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.05765492416685447


Epochs:  31%|███       | 31/100 [2:43:11<5:59:33, 312.66s/it]

val loss: 0.1032660627970472


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.05459813262314128


Epochs:  32%|███▏      | 32/100 [2:48:24<5:54:18, 312.63s/it]

val loss: 0.10301046645327006


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.051531642600821215


Epochs:  33%|███▎      | 33/100 [2:53:36<5:49:05, 312.61s/it]

val loss: 0.10171873883518856


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.04846961762814317


Epochs:  34%|███▍      | 34/100 [2:58:49<5:43:52, 312.62s/it]

val loss: 0.10273090329428669


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.04468151664877951


Epochs:  35%|███▌      | 35/100 [3:04:02<5:38:39, 312.60s/it]

val loss: 0.1023027904011542


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.04180139896379842


Epochs:  36%|███▌      | 36/100 [3:09:14<5:33:23, 312.55s/it]

val loss: 0.10220722256053705


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.039056406585586956


Epochs:  37%|███▋      | 37/100 [3:14:27<5:28:09, 312.54s/it]

val loss: 0.10270441762986593


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.037632820083672414


Epochs:  38%|███▊      | 38/100 [3:19:39<5:22:56, 312.52s/it]

val loss: 0.10299710846447852


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.03357050178965437


Epochs:  39%|███▉      | 39/100 [3:24:52<5:17:44, 312.54s/it]

val loss: 0.10367671790299937


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.031115279857658606


Epochs:  40%|████      | 40/100 [3:30:04<5:12:32, 312.54s/it]

val loss: 0.10407665514503606


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.028666786900430452


Epochs:  41%|████      | 41/100 [3:35:17<5:07:20, 312.54s/it]

val loss: 0.10482774402771611


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.026650783580407733


Epochs:  42%|████▏     | 42/100 [3:40:29<5:02:08, 312.56s/it]

val loss: 0.10459270514547825


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.02460253593562811


Epochs:  43%|████▎     | 43/100 [3:45:42<4:56:55, 312.55s/it]

val loss: 0.10502217376779299


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.02260251587267703


Epochs:  44%|████▍     | 44/100 [3:50:54<4:51:43, 312.57s/it]

val loss: 0.10608575701189693


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.020720983920000435


Epochs:  45%|████▌     | 45/100 [3:56:07<4:46:30, 312.56s/it]

val loss: 0.10580888838740066


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.019402310192162986


Epochs:  46%|████▌     | 46/100 [4:01:20<4:41:18, 312.57s/it]

val loss: 0.1057721934048459


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.01792389429465402


Epochs:  47%|████▋     | 47/100 [4:06:32<4:36:07, 312.59s/it]

val loss: 0.10637813572247978


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.49it/s]


train loss: 0.016669014142280503


Epochs:  48%|████▊     | 48/100 [4:11:45<4:30:54, 312.58s/it]

val loss: 0.10657858489139471


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.0155041937932765


Epochs:  49%|████▉     | 49/100 [4:16:57<4:25:41, 312.57s/it]

val loss: 0.10605456623306964


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.014734410153323552


Epochs:  50%|█████     | 50/100 [4:22:10<4:20:29, 312.59s/it]

val loss: 0.10613852569076698


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.013727382501201646


Epochs:  51%|█████     | 51/100 [4:27:23<4:15:16, 312.58s/it]

val loss: 0.10719715966843069


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.01293220004572504


Epochs:  52%|█████▏    | 52/100 [4:32:35<4:10:03, 312.57s/it]

val loss: 0.10700489878945518


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.01221306989509685


Epochs:  53%|█████▎    | 53/100 [4:37:48<4:04:50, 312.57s/it]

val loss: 0.10696141264634207


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.011612859176693746


Epochs:  54%|█████▍    | 54/100 [4:43:00<3:59:39, 312.59s/it]

val loss: 0.10718280936998781


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.011058689506171504


Epochs:  55%|█████▌    | 55/100 [4:48:13<3:54:26, 312.58s/it]

val loss: 0.10704968174104579


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.01070017675829149


Epochs:  56%|█████▌    | 56/100 [4:53:26<3:49:24, 312.82s/it]

val loss: 0.10719361435621977


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.01022781214214774


Epochs:  57%|█████▋    | 57/100 [4:58:39<3:44:15, 312.92s/it]

val loss: 0.10715497013006825


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.009748621559538151


Epochs:  58%|█████▊    | 58/100 [5:03:52<3:38:58, 312.82s/it]

val loss: 0.10678760905284435


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.009540376274344453


Epochs:  59%|█████▉    | 59/100 [5:09:05<3:33:43, 312.77s/it]

val loss: 0.10708678032096941


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.009195955424274871


Epochs:  60%|██████    | 60/100 [5:14:17<3:28:29, 312.75s/it]

val loss: 0.10664306068792939


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.008977113491710043


Epochs:  61%|██████    | 61/100 [5:19:30<3:23:14, 312.69s/it]

val loss: 0.10689619864569977


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.008758518190461473


Epochs:  62%|██████▏   | 62/100 [5:24:43<3:18:02, 312.69s/it]

val loss: 0.10674839047715068


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.008540687144659387


Epochs:  63%|██████▎   | 63/100 [5:29:55<3:12:48, 312.67s/it]

val loss: 0.10728257795562968


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.008387019864130707


Epochs:  64%|██████▍   | 64/100 [5:35:08<3:07:34, 312.63s/it]

val loss: 0.10702602539095096


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.008147835503223178


Epochs:  65%|██████▌   | 65/100 [5:40:20<3:02:21, 312.63s/it]

val loss: 0.10783588024787605


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.00805584102135981


Epochs:  66%|██████▌   | 66/100 [5:45:33<2:57:10, 312.65s/it]

val loss: 0.10700647241901606


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.007897585927139517


Epochs:  67%|██████▋   | 67/100 [5:50:46<2:51:57, 312.65s/it]

val loss: 0.10682258468295913


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.007874108280248038


Epochs:  68%|██████▊   | 68/100 [5:55:58<2:46:44, 312.65s/it]

val loss: 0.10749687692441512


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.49it/s]


train loss: 0.007725472936726874


Epochs:  69%|██████▉   | 69/100 [6:01:11<2:41:31, 312.63s/it]

val loss: 0.10764883596857544


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.00768292591737918


Epochs:  70%|███████   | 70/100 [6:06:24<2:36:18, 312.63s/it]

val loss: 0.10771310293057468


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.00766284284327412


Epochs:  71%|███████   | 71/100 [6:11:36<2:31:06, 312.65s/it]

val loss: 0.10710155917331576


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.007556885430631155


Epochs:  72%|███████▏  | 72/100 [6:16:49<2:25:53, 312.64s/it]

val loss: 0.1071447620051913


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.007409393024772726


Epochs:  73%|███████▎  | 73/100 [6:22:01<2:20:40, 312.61s/it]

val loss: 0.10774025395221543


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.007446199426567546


Epochs:  74%|███████▍  | 74/100 [6:27:14<2:15:28, 312.62s/it]

val loss: 0.10727056337054819


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.00732971825345885


Epochs:  75%|███████▌  | 75/100 [6:32:27<2:10:14, 312.59s/it]

val loss: 0.10741435746604111


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.007269351586955963


Epochs:  76%|███████▌  | 76/100 [6:37:39<2:05:02, 312.60s/it]

val loss: 0.10683195595629513


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.0073194034368953


Epochs:  77%|███████▋  | 77/100 [6:42:52<1:59:50, 312.61s/it]

val loss: 0.1075487092602998


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.007197170675226516


Epochs:  78%|███████▊  | 78/100 [6:48:04<1:54:36, 312.59s/it]

val loss: 0.10705840155424085


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.007257690981987253


Epochs:  79%|███████▉  | 79/100 [6:53:17<1:49:24, 312.61s/it]

val loss: 0.10700013922178186


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.0071379713399437605


Epochs:  80%|████████  | 80/100 [6:58:30<1:44:12, 312.61s/it]

val loss: 0.10735132238187362


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.007195452568339533


Epochs:  81%|████████  | 81/100 [7:03:42<1:38:59, 312.59s/it]

val loss: 0.10734699065505993


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.007138637718981045


Epochs:  82%|████████▏ | 82/100 [7:08:55<1:33:46, 312.56s/it]

val loss: 0.10687710893398616


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.007123184342617606


Epochs:  83%|████████▎ | 83/100 [7:14:07<1:28:33, 312.59s/it]

val loss: 0.10729435745452065


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.0070240648465187405


Epochs:  84%|████████▍ | 84/100 [7:19:20<1:23:21, 312.59s/it]

val loss: 0.10729391440690961


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.0070728102407429105


Epochs:  85%|████████▌ | 85/100 [7:24:32<1:18:08, 312.59s/it]

val loss: 0.10708676195645239


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.007082822025040514


Epochs:  86%|████████▌ | 86/100 [7:29:45<1:12:56, 312.61s/it]

val loss: 0.10694304294884205


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.00714281573846165


Epochs:  87%|████████▋ | 87/100 [7:34:58<1:07:43, 312.60s/it]

val loss: 0.10740416552289389


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.007089902144798543


Epochs:  88%|████████▊ | 88/100 [7:40:10<1:02:30, 312.57s/it]

val loss: 0.10672250534116756


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.007072038432397676


Epochs:  89%|████████▉ | 89/100 [7:45:23<57:18, 312.58s/it]  

val loss: 0.10762579904985614


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.007010468798171132


Epochs:  90%|█████████ | 90/100 [7:50:35<52:05, 312.59s/it]

val loss: 0.10718669020570815


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.49it/s]


train loss: 0.006981026992889383


Epochs:  91%|█████████ | 91/100 [7:55:48<46:53, 312.59s/it]

val loss: 0.10691308059904259


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.006966480551227505


Epochs:  92%|█████████▏| 92/100 [8:01:01<41:40, 312.60s/it]

val loss: 0.107053330342751


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.007000626757871942


Epochs:  93%|█████████▎| 93/100 [8:06:13<36:28, 312.61s/it]

val loss: 0.10707954164536204


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.0069846140834215475


Epochs:  94%|█████████▍| 94/100 [8:11:26<31:15, 312.60s/it]

val loss: 0.10705991406575777


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.006938392284610018


Epochs:  95%|█████████▌| 95/100 [8:16:38<26:02, 312.57s/it]

val loss: 0.10700226882181596


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.006973384451157472


Epochs:  96%|█████████▌| 96/100 [8:21:51<20:50, 312.59s/it]

val loss: 0.10740478265506681


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.0069631517710604385


Epochs:  97%|█████████▋| 97/100 [8:27:04<15:37, 312.59s/it]

val loss: 0.10749675107945222


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.006950721161047113


Epochs:  98%|█████████▊| 98/100 [8:32:16<10:25, 312.58s/it]

val loss: 0.10724482165824156


Batches: 100%|██████████| 1024/1024 [04:53<00:00,  3.49it/s]


train loss: 0.006970251349684986


Epochs:  99%|█████████▉| 99/100 [8:37:29<05:12, 312.59s/it]

val loss: 0.10678169057064224


Batches: 100%|██████████| 1024/1024 [04:52<00:00,  3.50it/s]


train loss: 0.006923157507344513


Epochs: 100%|██████████| 100/100 [8:42:41<00:00, 313.62s/it]

val loss: 0.10714975230803248





In [6]:
# Reload the "<model_type>_best" checkpoint for evaluation. The training log
# shows the trainer saving whenever val loss improves, so this is the
# lowest-val-loss model rather than the last epoch. This works whether or not
# training ran in this session. An unsupported `model_type` raises instead of
# silently evaluating whatever `model` is still in the kernel.
if model_type == 'resnet152':
    model = Resnet152(obj_diameter=obj_diameter).load("saved_models/resnet152_best").to(device)
elif model_type == 'resnet18':
    model = Resnet18(obj_diameter=obj_diameter).load("saved_models/resnet18_best").to(device)
else:
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'resnet18' or 'resnet152'")

In [16]:
# Evaluate on the test split (built from the separate "test" flow-image folder).
ds_predict = test_ds
predictor = Predictor(model, device)
predictions = predictor.predict_all(ds_predict)

100%|██████████| 10240/10240 [02:57<00:00, 57.59it/s]


In [17]:
# Success thresholds used for the recall metrics.
transl_thresh = 0.1  # meters
rot_thresh = 10  # degrees

# Ground-truth target of every sample (item index 1 of the dataset entry),
# moved to `device`.
targets = [ds_predict[idx][1].to(device) for idx in range(len(ds_predict))]

# Per-sample errors, then recall at each threshold. Recall is presumably the
# fraction of samples whose error is below the threshold (the summary prints
# use "<"); TODO confirm in src.metrics.
# NOTE: `re` shadows the name of the stdlib regex module in this notebook.
te, re = transl_error(predictions, targets), rot_error(predictions, targets)
t_recall, r_recall = recall(te, transl_thresh), recall(re, rot_thresh)

In [18]:
# Convert the per-sample error lists to tensors once and reuse them for the stats.
te_tensor, re_tensor = torch.tensor(te), torch.tensor(re)

print(f"Mean translational error: {torch.mean(te_tensor)}")
print(f"Mean rotational error: {torch.mean(re_tensor)}")

print(f"Median translational error: {torch.median(te_tensor)}")
print(f"Median rotational error: {torch.median(re_tensor)}")

print(f"Recall for translational error < {transl_thresh}: {t_recall}")
print(f"Recall for rotational error < {rot_thresh}: {r_recall}")

Mean translational error: 0.40934133529663086
Mean rotational error: 63.2928352355957
Median translational error: 0.35013917088508606
Median rotational error: 39.128639221191406
Recall for translational error < 0.1: 0.04482421875
Recall for rotational error < 10: 0.04912109375


In [10]:
# Rank test samples from best to worst by a combined score:
#   translational error (m) + rotational error (deg) / rot_scale,
# so `rot_scale` degrees weigh as much as 1 m (with 100: 1 degree ~ 1 cm).
rot_scale = 100
# How many of the best-ranked samples to print. The full test set would
# produce ~30k lines of output; set to None to print everything.
n_show = 20

# `sorted_te` keeps the full ranking (dataset indices), not just the shown slice.
sorted_te = sorted(range(len(te)), key=lambda k: te[k] + re[k] / rot_scale)
for i in sorted_te[:n_show]:
    print(f"Index: {i}, Translational error: {te[i]}, Rotational error: {re[i]}")
    print(f"Image name: {ds_predict.flow_image_names[i]}")
    print("\n")

Index: 7864, Translational error: 0.018313737586140633, Rotational error: 2.4241432554672273
Image name: shoe-green_viva_sandal_right_5__shoe-asifn_yellow_right_17


Index: 3300, Translational error: 0.038770515471696854, Rotational error: 2.0578544577615534
Image name: shoe-crocs_white_cyan_right_45__shoe-asifn_yellow_right_1


Index: 1300, Translational error: 0.06406048685312271, Rotational error: 2.3189997003474856
Image name: shoe-asifn_yellow_right_47__shoe-crocs_yellow_sandal_right_24


Index: 373, Translational error: 0.039687495678663254, Rotational error: 5.121319488058804
Image name: shoe-asifn_yellow_right_1__shoe-crocs_white_cyan_right_41


Index: 8146, Translational error: 0.053556300699710846, Rotational error: 4.509541580531632
Image name: shoe-green_viva_sandal_right_9__shoe-pink_tiny_crocs_right_6


Index: 4754, Translational error: 0.05111522972583771, Rotational error: 4.857525743093997
Image name: shoe-crocs_yellow_sandal_right_28__shoe-green_viva_sandal_right_42



In [11]:
# Raw renders and meshes, used only for visualization. The rendered dataset
# directory is the same one the datasets read poses from, so reuse
# `poses_path` instead of hard-coding the path a second time (the two could
# otherwise drift apart).
rgb_images, object_masks, poses, scene_names = load_data(poses_path)
meshes = load_meshes_o3d("../../data/housecat6d_meshes/shoe")

Loading meshes...
Loaded 21 meshes.


In [12]:
# Render predictions vs. ground truth for every sample in `ds_predict`, using
# the loaded meshes. The recorded run took ~33 minutes for the full test set.
ResultVisualization(meshes).visualize_results(
    dataset=ds_predict,
    rgb_images=rgb_images,
    poses=poses,
    predictions=predictions,
    te=te,
    re=re,
)

[Open3D INFO] EGL headless mode enabled.
FEngine (64 bits) created at 0x564e053df8a0 (threading is enabled)
EGL(1.5)
OpenGL(4.1)


100%|██████████| 10240/10240 [33:00<00:00,  5.17it/s]


**Ideas / next steps:**
- Try different rotation representations (e.g. quaternions); see https://github.com/naver/roma
- Add an adaptation layer
- Use only pairs of objects at the same distance
- Use a fixed radius
- Possibly use polar coordinates
- Possibly try another object
- Differentiable rendering