In [1]:
import os
import torch
import random
import pprint
import numpy as np
from loguru import logger
from torch.utils.data import DataLoader
from torch import nn

from pose.dataset import pose_dataset
from pose.utils import (
    collate_fn,
    geodesic_distance,
    relative_pose_error,
    relative_pose_error_np,
    recall_object,
    aggregate_metrics
)
from pose.model import Mkpts_Reg_Model
from pose.animator import Animator


if os.name == 'nt':
    LM_dataset_path = 'd:/git_project/POPE/data/LM_dataset/'
    LM_dataset_json_path = 'd:/git_project/POPE/data/pairs/LINEMOD-test.json'
    LM_dataset_points_path = 'd:/git_project/POPE/data/LM_dataset-points/'

    onepose_path = 'e:/datasets/OnePose/test_data/'
    onepose_json_path = 'd:/git_project/POPE/data/pairs/Onepose-test.json'
    onepose_points_path = 'd:/git_project/POPE/data/onepose-points/'

    oneposeplusplus_path = 'e:/datasets/OnePose++/lowtexture_test_data/'
    oneposeplusplus_json_path = 'd:/git_project/POPE/data/pairs/OneposePlusPlus-test.json'
    oneposeplusplus_points_path = 'd:/git_project/POPE/data/oneposeplusplus-points/'
elif os.name == 'posix':
    LM_dataset_path = 'data/LM_dataset/'
    LM_dataset_json_path = 'data/pairs/LINEMOD-test.json'
    LM_dataset_points_path = 'data/LM_dataset-points/'

    onepose_path = 'data/onepose/'
    onepose_json_path = 'data/pairs/Onepose-test.json'
    onepose_points_path = 'data/onepose-points/'

    oneposeplusplus_path = 'data/oneposeplusplus/'
    oneposeplusplus_json_path = 'data/pairs/OneposePlusPlus-test.json'
    oneposeplusplus_points_path = 'data/oneposeplusplus-points/'

paths = [
    # ('linemod', LM_dataset_path, LM_dataset_json_path, LM_dataset_points_path),
    # ('onepose', onepose_path, onepose_json_path, onepose_points_path),
    ('oneposeplusplus', oneposeplusplus_path, oneposeplusplus_json_path, oneposeplusplus_points_path),
]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset = pose_dataset(paths)
mkpts_max_len, mkpts_sum_len = dataset.get_mkpts_info()

 44%|████▍     | 4/9 [00:41<00:52, 10.44s/it]

d:/git_project/POPE/data/oneposeplusplus-points/0706-teabox-box\mkpts0\820.png-761.png.txt does not exist


 89%|████████▉ | 8/9 [01:25<00:10, 10.22s/it]

d:/git_project/POPE/data/oneposeplusplus-points/0712-insta-others\mkpts0\1570.png-125.png.txt does not exist
d:/git_project/POPE/data/oneposeplusplus-points/0712-insta-others\mkpts0\1605.png-185.png.txt does not exist
d:/git_project/POPE/data/oneposeplusplus-points/0712-insta-others\mkpts0\1628.png-210.png.txt does not exist


100%|██████████| 9/9 [01:45<00:00, 11.70s/it]


In [2]:
random.seed(20231223)
torch.manual_seed(20231223)
torch.cuda.manual_seed(20231223)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

In [3]:
num_sample = 500
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True, collate_fn=collate_fn(num_sample))
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, drop_last=True, collate_fn=collate_fn(num_sample))

In [4]:
net = torch.load('./weights/oneposeplusplus-relative_r-gt_t-6d-500-2024-01-08-12-54-48-0.2991.pth').to(device)

net.eval()

Mkpts_Reg_Model(
  (embedding): Embedding()
  (transformerlayer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=76, out_features=76, bias=True)
    )
    (linear1): Linear(in_features=76, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=76, bias=True)
    (norm1): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=76, out_features=76, bias=True)
        )
        (linear1): Linear(in_features=76, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inp

In [5]:
id2name_dict = {
    1: "ape",
    2: "benchvise",
    4: "camera",
    5: "can",
    6: "cat",
    8: "driller",
    9: "duck",
    10: "eggbox",
    11: "glue",
    12: "holepuncher",
    13: "iron",
    14: "lamp",
    15: "phone",
}

In [6]:
# linemod
ape_data = []
benchvise_data = []
camera_data = []
can_data = []
cat_data = []
driller_data = []
duck_data = []
eggbox_data = []
glue_data = []
holepuncher_data = []
iron_data = []
lamp_data = []
phone_data = []
# onepose
aptamil_data = []
jzhg_data = []
minipuff_data = []
hlyormosiapie_data = []
brownhouse_data = []
oreo_data = []
mfmilkcake_data = []
diycookies_data = []
taipingcookies_data = []
tee_data = []
# onepose++
toyrobot_data = []
yellowduck_data = []
sheep_data = []
fakebanana_data = []
teabox_data = []
orange_data = []
greenteapot_data = []
lecreusetcup_data = []
insta_data = []

all_data = {
    # linemod
    'ape_data': ape_data,
    'benchvise_data': benchvise_data,
    'camera_data': camera_data,
    'can_data': can_data,
    'cat_data': cat_data,
    'driller_data': driller_data,
    'duck_data': duck_data,
    'eggbox_data': eggbox_data,
    'glue_data': glue_data,
    'holepuncher_data': holepuncher_data,
    'iron_data': iron_data,
    'lamp_data': lamp_data,
    'phone_data': phone_data,
    # onepose
    'aptamil_data': aptamil_data,
    'jzhg_data': jzhg_data,
    'minipuff_data': minipuff_data,
    'hlyormosiapie_data': hlyormosiapie_data,
    'brownhouse_data': brownhouse_data,
    'oreo_data': oreo_data,
    'mfmilkcake_data': mfmilkcake_data,
    'diycookies_data': diycookies_data,
    'taipingcookies_data': taipingcookies_data,
    'tee_data': tee_data,
    # onepose++
    'toyrobot_data': toyrobot_data,
    'yellowduck_data': yellowduck_data,
    'sheep_data': sheep_data,
    'fakebanana_data': fakebanana_data,
    'teabox_data': teabox_data,
    'orange_data': orange_data,
    'greenteapot_data': greenteapot_data,
    'lecreusetcup_data': lecreusetcup_data,
    'insta_data': insta_data,
}

In [7]:
linemod_type = ['ape_data', 'benchvise_data', 'camera_data', 'can_data', 'cat_data', 'driller_data', 'duck_data', 'eggbox_data', 'glue_data', 'holepuncher_data', 'iron_data', 'lamp_data', 'phone_data']
onepose_type = ['aptamil_data', 'jzhg_data', 'minipuff_data', 'hlyormosiapie_data', 'brownhouse_data', 'oreo_data', 'mfmilkcake_data', 'diycookies_data', 'taipingcookies_data', 'tee_data']
oneposeplusplus_type = ['toyrobot_data', 'yellowduck_data', 'sheep_data', 'fakebanana_data', 'teabox_data', 'orange_data', 'greenteapot_data', 'lecreusetcup_data', 'insta_data']

In [8]:
for i, batch in enumerate(test_dataloader):
    for data in batch:
        if 'lm' in data['name']:
            all_data[f"{id2name_dict[int(data['name'][2:])]}_data"].append(data)
        else:
            all_data[f"{data['name']}_data"].append(data)

empty_keys = []
for key in all_data.keys():
    if len(all_data[key]) == 0:
        empty_keys.append(key)

for key in empty_keys:
    all_data.pop(key)

for key in all_data.keys():
    print(key, len(all_data[key]))

print('len(all_data):', len(all_data))

toyrobot_data 57
yellowduck_data 66
sheep_data 52
fakebanana_data 48
teabox_data 65
orange_data 75
greenteapot_data 40
lecreusetcup_data 47
insta_data 94
len(all_data): 9


In [9]:
res_table = []

model_type = 'relative_r-gt_t'

for key in all_data.keys():
    if key in linemod_type:
        logger.info(f"LINEMOD: {key}")
    elif key in onepose_type:
        logger.info(f"ONEPOSE: {key}")
    elif key in oneposeplusplus_type:
        logger.info(f"ONEPOSE++: {key}")
    metrics = dict()
    metrics.update({'R_errs':[], 't_errs':[], 'inliers':[], "identifiers":[]})
    recall_image, all_image = 0, 0
    for item in all_data[key]:
        all_image += 1
        K0 = item['K0']
        K1 = item['K1']
        pose0 = item['pose0']
        pose1 = item['pose1']
        pre_bbox = item['pre_bbox']
        gt_bbox = item['gt_bbox']
        mkpts0 = item['mkpts0']
        mkpts1 = item['mkpts1']
        pre_K = item['pre_K']
        name = item['name']
        pair_name = item['pair_name']
        if 'lm' in name:
            name = id2name_dict[int(name[2:])]
        if name not in key:
            print(f'name: {name}, key: {key}')
            continue

        is_recalled = recall_object(pre_bbox, gt_bbox)

        recall_image = recall_image + int(is_recalled > 0.5)

        batch_mkpts0 = torch.from_numpy(mkpts0).unsqueeze(0).float().to(device)
        batch_mkpts1 = torch.from_numpy(mkpts1).unsqueeze(0).float().to(device)
        pre_t, pre_rot = net(batch_mkpts0, batch_mkpts1)
        pre_t = pre_t.squeeze(0).detach().cpu().numpy()
        pre_rot = pre_rot.squeeze(0).detach().cpu().numpy()

        if model_type == 'gt':
            t_err, R_err = relative_pose_error_np(pose1, pre_rot, pre_t, ignore_gt_t_thr=0.0)
        elif model_type == 'relative':
            relative_pose = np.matmul(pose1, np.linalg.inv(pose0))
            t_err, R_err = relative_pose_error_np(relative_pose, pre_rot, pre_t, ignore_gt_t_thr=0.0)
        elif model_type == 'relative_r-gt_t':
            relative_pose = np.matmul(pose1, np.linalg.inv(pose0))
            gt_pose = np.zeros_like(pose1)
            gt_pose[:3, :3] = relative_pose[:3, :3]
            gt_pose[:3, 3] = pose1[:3, 3]
            t_err, R_err = relative_pose_error_np(gt_pose, pre_rot, pre_t, ignore_gt_t_thr=0.0)

        metrics['R_errs'].append(R_err)
        metrics['t_errs'].append(t_err)
        metrics['identifiers'].append(pair_name)

    print(f"Acc: {recall_image}/{all_image}")
    val_metrics_4tb = aggregate_metrics(metrics, 5e-4)
    val_metrics_4tb["AP50"] = recall_image / all_image
    logger.info('\n' + pprint.pformat(val_metrics_4tb))
    res_table.append([f"{name}"] + list(val_metrics_4tb.values()))

[32m2024-01-09 18:46:34.078[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: toyrobot_data[0m


[32m2024-01-09 18:46:36.505[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 57 unique items...[0m
[32m2024-01-09 18:46:36.507[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m63[0m - [1m
{'AP50': 0.9824561403508771,
 'R:ACC15': 0.42105263157894735,
 'R:ACC30': 1.0,
 'R:auc@15': 0.17441816157479478,
 'R:auc@30': 0.45039877170968623,
 'R:meanErr': 16.746686732132154,
 'R:medianErr': 16.910678289008633,
 't:ACC15': 0.9122807017543859,
 't:ACC30': 1.0,
 't:auc@15': 0.4860189299902005,
 't:auc@30': 0.7326261309560894,
 't:meanErr': 8.210698529548818,
 't:medianErr': 7.001649781412204}[0m
[32m2024-01-09 18:46:36.508[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: yellowduck_data[0m


Acc: 56/57


[32m2024-01-09 18:46:37.060[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 66 unique items...[0m
[32m2024-01-09 18:46:37.061[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m63[0m - [1m
{'AP50': 1.0,
 'R:ACC15': 0.5909090909090909,
 'R:ACC30': 0.9848484848484849,
 'R:auc@15': 0.2765160081798727,
 'R:auc@30': 0.5564451853669257,
 'R:meanErr': 13.587550045729929,
 'R:medianErr': 12.99495717049,
 't:ACC15': 0.8484848484848485,
 't:ACC30': 1.0,
 't:auc@15': 0.381567459568012,
 't:auc@30': 0.6805587330374812,
 't:meanErr': 9.741439352180862,
 't:medianErr': 8.871546270649084}[0m
[32m2024-01-09 18:46:37.062[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: sheep_data[0m


Acc: 66/66


[32m2024-01-09 18:46:37.499[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 52 unique items...[0m
[32m2024-01-09 18:46:37.501[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m63[0m - [1m
{'AP50': 0.36538461538461536,
 'R:ACC15': 0.4807692307692308,
 'R:ACC30': 0.9423076923076923,
 'R:auc@15': 0.26698720127905856,
 'R:auc@30': 0.48839144596287204,
 'R:meanErr': 16.54961828413549,
 'R:medianErr': 15.751907093999314,
 't:ACC15': 0.8653846153846154,
 't:ACC30': 0.9615384615384616,
 't:auc@15': 0.4891064703521516,
 't:auc@30': 0.7133443335622872,
 't:meanErr': 9.249308320796477,
 't:medianErr': 7.284670053653413}[0m
[32m2024-01-09 18:46:37.502[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: fakebanana_data[0m


Acc: 19/52


[32m2024-01-09 18:46:37.931[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 48 unique items...[0m
[32m2024-01-09 18:46:37.932[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m63[0m - [1m
{'AP50': 0.9791666666666666,
 'R:ACC15': 0.4583333333333333,
 'R:ACC30': 0.9791666666666666,
 'R:auc@15': 0.17004019232694242,
 'R:auc@30': 0.4723171283854175,
 'R:meanErr': 16.14300129286243,
 'R:medianErr': 16.2673991395759,
 't:ACC15': 0.8333333333333334,
 't:ACC30': 1.0,
 't:auc@15': 0.37851097769017006,
 't:auc@30': 0.6749347872327565,
 't:meanErr': 10.048648565956848,
 't:medianErr': 9.575894216131594}[0m
[32m2024-01-09 18:46:37.932[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: teabox_data[0m


Acc: 47/48


[32m2024-01-09 18:46:38.484[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 65 unique items...[0m
[32m2024-01-09 18:46:38.485[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m63[0m - [1m
{'AP50': 0.9538461538461539,
 'R:ACC15': 0.5846153846153846,
 'R:ACC30': 0.9692307692307692,
 'R:auc@15': 0.2952544465009674,
 'R:auc@30': 0.5425873659278255,
 'R:meanErr': 14.019385777084452,
 'R:medianErr': 13.036377316456589,
 't:ACC15': 0.8461538461538461,
 't:ACC30': 1.0,
 't:auc@15': 0.40097474640438296,
 't:auc@30': 0.6870902322417695,
 't:meanErr': 9.55057130682673,
 't:medianErr': 8.005827965965667}[0m
[32m2024-01-09 18:46:38.486[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: orange_data[0m


Acc: 62/65


[32m2024-01-09 18:46:39.131[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 75 unique items...[0m
[32m2024-01-09 18:46:39.133[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m63[0m - [1m
{'AP50': 0.9866666666666667,
 'R:ACC15': 0.4266666666666667,
 'R:ACC30': 0.9733333333333334,
 'R:auc@15': 0.21006696667742955,
 'R:auc@30': 0.4702857812312772,
 'R:meanErr': 16.152388619565446,
 'R:medianErr': 15.782031879196564,
 't:ACC15': 0.9466666666666667,
 't:ACC30': 0.9866666666666667,
 't:auc@15': 0.4963366165880925,
 't:auc@30': 0.7365126366257913,
 't:meanErr': 8.205684904915957,
 't:medianErr': 7.504364951294381}[0m
[32m2024-01-09 18:46:39.134[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: greenteapot_data[0m


Acc: 74/75


[32m2024-01-09 18:46:39.472[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 40 unique items...[0m
[32m2024-01-09 18:46:39.474[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m63[0m - [1m
{'AP50': 0.925,
 'R:ACC15': 0.45,
 'R:ACC30': 0.975,
 'R:auc@15': 0.24792726908893983,
 'R:auc@30': 0.497338270481018,
 'R:meanErr': 15.471956317318307,
 'R:medianErr': 16.52800862766831,
 't:ACC15': 0.65,
 't:ACC30': 1.0,
 't:auc@15': 0.34094300949045003,
 't:auc@30': 0.6377749247147259,
 't:meanErr': 11.123970178345344,
 't:medianErr': 10.496113477478195}[0m
[32m2024-01-09 18:46:39.474[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: lecreusetcup_data[0m


Acc: 37/40


[32m2024-01-09 18:46:39.880[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 47 unique items...[0m
[32m2024-01-09 18:46:39.882[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m63[0m - [1m
{'AP50': 0.8297872340425532,
 'R:ACC15': 0.5319148936170213,
 'R:ACC30': 1.0,
 'R:auc@15': 0.3021982389868752,
 'R:auc@30': 0.5438976496853016,
 'R:meanErr': 13.996661578413418,
 'R:medianErr': 12.11751671267035,
 't:ACC15': 0.723404255319149,
 't:ACC30': 0.9787234042553191,
 't:auc@15': 0.3355227267020558,
 't:auc@30': 0.6291141161580183,
 't:meanErr': 11.781983064644159,
 't:medianErr': 9.633321072300133}[0m
[32m2024-01-09 18:46:39.883[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mONEPOSE++: insta_data[0m


Acc: 39/47


[32m2024-01-09 18:46:40.682[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 94 unique items...[0m
[32m2024-01-09 18:46:40.684[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m63[0m - [1m
{'AP50': 0.5638297872340425,
 'R:ACC15': 0.425531914893617,
 'R:ACC30': 0.9148936170212766,
 'R:auc@15': 0.19249210014402962,
 'R:auc@30': 0.43140082592587065,
 'R:meanErr': 20.29240936292988,
 'R:medianErr': 17.55420453460684,
 't:ACC15': 0.8297872340425532,
 't:ACC30': 0.9468085106382979,
 't:auc@15': 0.41311574965308573,
 't:auc@30': 0.6625310291437257,
 't:meanErr': 12.372718798740152,
 't:medianErr': 8.693510440130208}[0m


Acc: 53/94


In [10]:
from tabulate import tabulate
headers = ["Category"] + list(val_metrics_4tb.keys())
all_data = np.array(res_table)[:, 1:].astype(np.float32)
res_table.append(["Avg"] + all_data.mean(0).tolist())
print(tabulate(res_table, \
    headers=headers, tablefmt='fancy_grid'))

╒══════════════╤════════════╤════════════╤═══════════╤═══════════╤═══════════════╤═════════════╤════════════╤════════════╤═══════════╤═══════════╤═══════════════╤═════════════╤══════════╕
│ Category     │   R:auc@15 │   R:auc@30 │   R:ACC15 │   R:ACC30 │   R:medianErr │   R:meanErr │   t:auc@15 │   t:auc@30 │   t:ACC15 │   t:ACC30 │   t:medianErr │   t:meanErr │     AP50 │
╞══════════════╪════════════╪════════════╪═══════════╪═══════════╪═══════════════╪═════════════╪════════════╪════════════╪═══════════╪═══════════╪═══════════════╪═════════════╪══════════╡
│ toyrobot     │   0.174418 │   0.450399 │  0.421053 │  1        │       16.9107 │     16.7467 │   0.486019 │   0.732626 │  0.912281 │  1        │       7.00165 │     8.2107  │ 0.982456 │
├──────────────┼────────────┼────────────┼───────────┼───────────┼───────────────┼─────────────┼────────────┼────────────┼───────────┼───────────┼───────────────┼─────────────┼──────────┤
│ yellowduck   │   0.276516 │   0.556445 │  0.590909 │  0.98