In [1]:
import os
import torch
import random
import pprint
import numpy as np
from loguru import logger
from torch.utils.data import DataLoader
from torch import nn

from pose.dataset import pose_dataset
from pose.utils import (
    collate_fn,
    geodesic_distance,
    relative_pose_error,
    relative_pose_error_np,
    recall_object,
    aggregate_metrics
)
from pose.model import Mkpts_Reg_Model
from pose.animator import Animator


# Dataset locations. Only the root directory differs between platforms, so it
# is selected once and every dataset path is derived from it.
if os.name == 'nt':
    data_root = 'd:/git_project/POPE/data/'
else:
    # 'posix' and any other platform: resolve relative to the current working
    # directory, so the path variables below are always defined.
    data_root = 'data/'

# Three locations per dataset: its base directory, its test-pairs JSON file
# and its '-points' directory.
LM_dataset_path = data_root + 'LM_dataset/'
LM_dataset_json_path = data_root + 'pairs/LINEMOD-test.json'
LM_dataset_points_path = data_root + 'LM_dataset-points/'

onepose_path = data_root + 'onepose/'
onepose_json_path = data_root + 'pairs/Onepose-test.json'
onepose_points_path = data_root + 'onepose-points/'

onepose_plusplus_path = data_root + 'onepose_plusplus/'
onepose_plusplus_json_path = data_root + 'pairs/OneposePlusPlus-test.json'
onepose_plusplus_points_path = data_root + 'onepose_plusplus-points/'

ycbv_path = data_root + 'ycbv/'
ycbv_json_path = data_root + 'pairs/YCB-VIDEO-test.json'
# No trailing slash here, unlike the other '-points' directories (kept as is).
ycbv_points_path = data_root + 'ycbv-points'

# Datasets handed to pose_dataset, one tuple of
# (name, base directory, pairs JSON, points directory) per dataset.
# Uncomment a line to include that dataset as well.
paths = []
# paths.append(('linemod', LM_dataset_path, LM_dataset_json_path, LM_dataset_points_path))
paths.append(('onepose', onepose_path, onepose_json_path, onepose_points_path))
# paths.append(('onepose_plusplus', onepose_plusplus_path, onepose_plusplus_json_path, onepose_plusplus_points_path))
# paths.append(('ycbv', ycbv_path, ycbv_json_path, ycbv_points_path))

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the dataset, then unpack the two values returned by get_mkpts_info().
dataset = pose_dataset(paths)
mkpts_max_len, mkpts_sum_len = dataset.get_mkpts_info()

  from .autonotebook import tqdm as notebook_tqdm
  0%|          | 0/10 [00:00<?, ?it/s]

data/onepose-points/0409-aptamil-box/mkpts0/566.png-581.png.txt does not exist
data/onepose-points/0409-aptamil-box/mkpts0/544.png-557.png.txt does not exist
data/onepose-points/0409-aptamil-box/mkpts0/562.png-621.png.txt does not exist
data/onepose-points/0409-aptamil-box/mkpts0/620.png-562.png.txt does not exist
data/onepose-points/0409-aptamil-box/mkpts0/579.png-561.png.txt does not exist
data/onepose-points/0409-aptamil-box/mkpts0/564.png-543.png.txt does not exist
data/onepose-points/0409-aptamil-box/mkpts0/607.png-553.png.txt does not exist


 10%|█         | 1/10 [00:04<00:44,  4.93s/it]

data/onepose-points/0409-aptamil-box/mkpts0/601.png-549.png.txt does not exist
data/onepose-points/0409-aptamil-box/mkpts0/606.png-547.png.txt does not exist


100%|██████████| 10/10 [00:34<00:00,  3.43s/it]


In [2]:
# Seed every RNG in scope from a single constant so the split below, and any
# random sampling done afterwards, is repeatable from a fresh kernel.
SEED = 20231223
random.seed(SEED)
np.random.seed(SEED)              # seed NumPy too, in case numpy-based sampling is used
torch.manual_seed(SEED)           # drives the permutation used by random_split
torch.cuda.manual_seed_all(SEED)  # every visible GPU; silently ignored without CUDA

# 80 % / 20 % train/test split of the full dataset.
# NOTE(review): presumably this has to reproduce the split used when the
# checkpoint was trained so the test part is really held out - confirm.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

In [3]:
# Value passed to collate_fn.
# NOTE(review): presumably the number of matched keypoints kept per pair - confirm.
num_sample = 500
batch_size = 8

# Training loader: shuffled, incomplete final batch dropped.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn(num_sample),
)

# Evaluation loader: keep every test sample. With drop_last=True up to
# batch_size - 1 samples were silently excluded from the reported metrics, and
# shuffling serves no purpose when each sample is scored independently.
# NOTE(review): assumes collate_fn copes with a final batch smaller than
# batch_size - confirm.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn(num_sample),
)

In [4]:
# Load the serialized model object and move it to the selected device.
# map_location remaps the stored tensors while loading, so a checkpoint that
# was saved from a GPU can also be opened on a CPU-only machine.
# NOTE: torch.load unpickles arbitrary Python objects - only load checkpoint
# files that come from a trusted source.
net = torch.load(
    './weights/relative_r-gt_t-6d-500-2024-04-20-13-58-49-0.1459.pth',
    map_location=device,
).to(device)

# Switch to evaluation mode; as the last expression of the cell this also
# displays the module structure.
net.eval()

Mkpts_Reg_Model(
  (embedding): Embedding()
  (transformer_mkpts): Transformer(
    (self_attn): MultiheadAttention(
      (out_proj): _LinearWithBias(in_features=76, out_features=76, bias=True)
    )
    (linear1): Linear(in_features=76, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=76, bias=True)
    (norm1): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (mlp1): Sequential(
    (0): Linear(in_features=38000, out_features=19000, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=19000, out_features=2000, bias=True)
    (4): LeakyReLU(negative_slope=0.01)
    (5): Dropout(p=0.2, inplace=False)
  )
  (convnextv2): ConvNeXtV2(
    (model): ConvNeXtV2(
      (downsample_laye

In [5]:
# Numeric id -> object name, used to translate 'lm<id>' sample names.
# Ids 3 and 7 have no entry.
id2name_dict = dict([
    (1, "ape"),
    (2, "benchvise"),
    (4, "camera"),
    (5, "can"),
    (6, "cat"),
    (8, "driller"),
    (9, "duck"),
    (10, "eggbox"),
    (11, "glue"),
    (12, "holepuncher"),
    (13, "iron"),
    (14, "lamp"),
    (15, "phone"),
])

# Numeric id (1..10) -> spelled-out number, used as the category name for
# samples whose name is a bare number.
ycbv_dict = dict(enumerate(
    ('one', 'two', 'three', 'four', 'five',
     'six', 'seven', 'eight', 'nine', 'ten'),
    start=1,
))

In [6]:
# One empty list ("bucket") per object category, grouped by source dataset.
# Each generator yields a distinct list, so no two names share a bucket.
# linemod
(ape_data, benchvise_data, camera_data, can_data, cat_data, driller_data,
 duck_data, eggbox_data, glue_data, holepuncher_data, iron_data, lamp_data,
 phone_data) = ([] for _ in range(13))
# onepose
(aptamil_data, jzhg_data, minipuff_data, hlyormosiapie_data, brownhouse_data,
 oreo_data, mfmilkcake_data, diycookies_data, taipingcookies_data,
 tee_data) = ([] for _ in range(10))
# onepose++
(toyrobot_data, yellowduck_data, sheep_data, fakebanana_data, teabox_data,
 orange_data, greenteapot_data, lecreusetcup_data,
 insta_data) = ([] for _ in range(9))
# ycbv
(one_data, two_data, three_data, four_data, five_data, six_data, seven_data,
 eight_data, nine_data, ten_data) = ([] for _ in range(10))

# Lookup table '<category>_data' -> the bucket of the same name above; the
# values are the very same list objects, not copies.
all_data = dict(
    # linemod
    ape_data=ape_data,
    benchvise_data=benchvise_data,
    camera_data=camera_data,
    can_data=can_data,
    cat_data=cat_data,
    driller_data=driller_data,
    duck_data=duck_data,
    eggbox_data=eggbox_data,
    glue_data=glue_data,
    holepuncher_data=holepuncher_data,
    iron_data=iron_data,
    lamp_data=lamp_data,
    phone_data=phone_data,
    # onepose
    aptamil_data=aptamil_data,
    jzhg_data=jzhg_data,
    minipuff_data=minipuff_data,
    hlyormosiapie_data=hlyormosiapie_data,
    brownhouse_data=brownhouse_data,
    oreo_data=oreo_data,
    mfmilkcake_data=mfmilkcake_data,
    diycookies_data=diycookies_data,
    taipingcookies_data=taipingcookies_data,
    tee_data=tee_data,
    # onepose++
    toyrobot_data=toyrobot_data,
    yellowduck_data=yellowduck_data,
    sheep_data=sheep_data,
    fakebanana_data=fakebanana_data,
    teabox_data=teabox_data,
    orange_data=orange_data,
    greenteapot_data=greenteapot_data,
    lecreusetcup_data=lecreusetcup_data,
    insta_data=insta_data,
    # ycbv
    one_data=one_data,
    two_data=two_data,
    three_data=three_data,
    four_data=four_data,
    five_data=five_data,
    six_data=six_data,
    seven_data=seven_data,
    eight_data=eight_data,
    nine_data=nine_data,
    ten_data=ten_data,
)

In [7]:
# Bucket keys ('<category>_data') belonging to each benchmark; used to tell
# which benchmark a given key of all_data comes from.
linemod_type = [f'{stem}_data' for stem in (
    'ape', 'benchvise', 'camera', 'can', 'cat', 'driller', 'duck',
    'eggbox', 'glue', 'holepuncher', 'iron', 'lamp', 'phone',
)]
onepose_type = [f'{stem}_data' for stem in (
    'aptamil', 'jzhg', 'minipuff', 'hlyormosiapie', 'brownhouse',
    'oreo', 'mfmilkcake', 'diycookies', 'taipingcookies', 'tee',
)]
oneposeplusplus_type = [f'{stem}_data' for stem in (
    'toyrobot', 'yellowduck', 'sheep', 'fakebanana', 'teabox',
    'orange', 'greenteapot', 'lecreusetcup', 'insta',
)]
ycbv_type = [f'{stem}_data' for stem in (
    'one', 'two', 'three', 'four', 'five',
    'six', 'seven', 'eight', 'nine', 'ten',
)]

In [8]:
# Start from empty buckets so that re-running this cell does not append every
# test sample a second time.
for bucket in all_data.values():
    bucket.clear()

# Sample names that are bare numbers with an entry in ycbv_dict ('1' ... '10').
# The former test `name in ('12345678910')` was a substring check on a string,
# so it also accepted values such as '23' or '' and then failed on the lookup.
ycbv_ids = {str(obj_id) for obj_id in ycbv_dict}

# Route every test sample into the bucket of its object category.
for batch in test_dataloader:
    for data in batch:
        sample_name = data['name']
        if sample_name.startswith('lm'):
            # 'lm<id>': translate the numeric id that follows the prefix.
            category = id2name_dict[int(sample_name[2:])]
        elif sample_name in ycbv_ids:
            category = ycbv_dict[int(sample_name)]
        else:
            # Any other name is used as the category directly.
            category = sample_name
        all_data[f"{category}_data"].append(data)

# Remove categories that received no sample so later cells only see
# categories that are present in the test split.
empty_keys = [key for key, samples in all_data.items() if len(samples) == 0]
for key in empty_keys:
    all_data.pop(key)

for key in all_data.keys():
    print(key, len(all_data[key]))

print('len(all_data):', len(all_data))

aptamil_data 77
jzhg_data 73
minipuff_data 41
hlyormosiapie_data 59
brownhouse_data 55
oreo_data 55
mfmilkcake_data 59
diycookies_data 68
taipingcookies_data 56
tee_data 89
len(all_data): 10


In [9]:
# One row per category: [category name, *aggregated metric values].
res_table = []

# Selects which ground-truth pose the prediction is compared against:
#   'gt'              - pose1 as is
#   'relative'        - pose1 @ inv(pose0)
#   'relative_r-gt_t' - rotation of pose1 @ inv(pose0), translation of pose1
model_type = 'relative_r-gt_t'

# Sample names that are bare numbers with an entry in ycbv_dict ('1' ... '10').
# The former test `name in '12345678910'` was a substring check, so it also
# accepted values such as '23' or ''.
ycbv_ids = {str(obj_id) for obj_id in ycbv_dict}

for key in all_data.keys():
    if key in linemod_type:
        logger.info(f"LINEMOD: {key}")
    elif key in onepose_type:
        logger.info(f"ONEPOSE: {key}")
    elif key in oneposeplusplus_type:
        logger.info(f"ONEPOSE++: {key}")
    elif key in ycbv_type:
        logger.info(f"YCBV: {key}")

    # 'inliers' is never filled in this cell; the key is kept because
    # aggregate_metrics receives the whole dict.
    metrics = {'R_errs': [], 't_errs': [], 'inliers': [], "identifiers": []}
    recall_image, all_image = 0, 0
    for item in all_data[key]:
        pose0 = item['pose0']
        pose1 = item['pose1']
        pre_bbox = item['pre_bbox']
        gt_bbox = item['gt_bbox']
        mkpts0 = item['mkpts0']
        mkpts1 = item['mkpts1']
        img0 = item['img0']
        img1 = item['img1']
        name = item['name']
        pair_name = item['pair_name']
        if name.startswith('lm'):
            # linemod: 'lm<id>' -> object name
            name = id2name_dict[int(name[2:])]
        elif name in ycbv_ids:
            # ycbv: '<id>' -> spelled-out number
            name = ycbv_dict[int(name)]

        # Skip samples that do not belong to this bucket. They are skipped
        # before counting so they no longer inflate the AP50 denominator.
        if f"{name}_data" != key:
            print(f'name: {name}, key: {key}')
            continue
        all_image += 1

        is_recalled = recall_object(pre_bbox, gt_bbox)

        recall_image = recall_image + int(is_recalled > 0.5)

        batch_mkpts0 = torch.from_numpy(mkpts0).unsqueeze(0).float().to(device)
        batch_mkpts1 = torch.from_numpy(mkpts1).unsqueeze(0).float().to(device)
        img0 = torch.from_numpy(img0).unsqueeze(0).float().to(device)
        img1 = torch.from_numpy(img1).unsqueeze(0).float().to(device)
        # NOTE(review): if the arrays are (H, W, C) this yields (N, C, W, H),
        # i.e. height and width are swapped relative to the usual
        # (0, 3, 1, 2). Left unchanged because the checkpoint was presumably
        # trained with this layout - confirm.
        img0 = img0.permute(0, 3, 2, 1)
        img1 = img1.permute(0, 3, 2, 1)
        # Inference only: do not record an autograd graph for the forward pass.
        with torch.no_grad():
            pre_t, pre_rot = net(batch_mkpts0, batch_mkpts1, img0, img1)
        pre_t = pre_t.squeeze(0).detach().cpu().numpy()
        pre_rot = pre_rot.squeeze(0).detach().cpu().numpy()

        if model_type == 'gt':
            t_err, R_err = relative_pose_error_np(pose1, pre_rot, pre_t, ignore_gt_t_thr=0.0)
        elif model_type == 'relative':
            relative_pose = np.matmul(pose1, np.linalg.inv(pose0))
            t_err, R_err = relative_pose_error_np(relative_pose, pre_rot, pre_t, ignore_gt_t_thr=0.0)
        elif model_type == 'relative_r-gt_t':
            relative_pose = np.matmul(pose1, np.linalg.inv(pose0))
            gt_pose = np.zeros_like(pose1)
            gt_pose[:3, :3] = relative_pose[:3, :3]
            gt_pose[:3, 3] = pose1[:3, 3]
            t_err, R_err = relative_pose_error_np(gt_pose, pre_rot, pre_t, ignore_gt_t_thr=0.0)
        else:
            # Without this branch an unknown value would silently record the
            # errors of the previous sample (or fail with a NameError).
            raise ValueError(f"Unknown model_type: {model_type!r}")

        metrics['R_errs'].append(R_err)
        metrics['t_errs'].append(t_err)
        metrics['identifiers'].append(pair_name)

    print(f"Acc: {recall_image}/{all_image}")
    val_metrics_4tb = aggregate_metrics(metrics, 5e-4)
    val_metrics_4tb["AP50"] = recall_image / all_image if all_image else 0.0
    # Label the row from the bucket key rather than from whatever `name` the
    # last loop iteration happened to leave behind.
    category = key[:-len('_data')] if key.endswith('_data') else key
    res_table.append([category] + list(val_metrics_4tb.values()))

[32m2024-04-20 15:08:51.477[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mONEPOSE: aptamil_data[0m


[32m2024-04-20 15:08:53.990[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 77 unique items...[0m
[32m2024-04-20 15:08:53.992[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mONEPOSE: jzhg_data[0m


Acc: 73/77


[32m2024-04-20 15:08:56.056[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 73 unique items...[0m
[32m2024-04-20 15:08:56.058[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mONEPOSE: minipuff_data[0m


Acc: 63/73


[32m2024-04-20 15:08:57.198[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 41 unique items...[0m
[32m2024-04-20 15:08:57.199[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mONEPOSE: hlyormosiapie_data[0m


Acc: 39/41


[32m2024-04-20 15:08:58.844[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 59 unique items...[0m
[32m2024-04-20 15:08:58.847[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mONEPOSE: brownhouse_data[0m


Acc: 53/59


[32m2024-04-20 15:09:00.382[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 55 unique items...[0m
[32m2024-04-20 15:09:00.385[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mONEPOSE: oreo_data[0m


Acc: 42/55


[32m2024-04-20 15:09:01.938[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 55 unique items...[0m
[32m2024-04-20 15:09:01.940[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mONEPOSE: mfmilkcake_data[0m


Acc: 50/55


[32m2024-04-20 15:09:03.581[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 59 unique items...[0m
[32m2024-04-20 15:09:03.583[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mONEPOSE: diycookies_data[0m


Acc: 48/59


[32m2024-04-20 15:09:05.482[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 68 unique items...[0m
[32m2024-04-20 15:09:05.484[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mONEPOSE: taipingcookies_data[0m


Acc: 55/68


[32m2024-04-20 15:09:07.046[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 56 unique items...[0m
[32m2024-04-20 15:09:07.048[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mONEPOSE: tee_data[0m


Acc: 35/56


[32m2024-04-20 15:09:09.513[0m | [1mINFO    [0m | [36mpose.utils[0m:[36maggregate_metrics[0m:[36m237[0m - [1mAggregating metrics over 89 unique items...[0m


Acc: 65/89


In [10]:
# Print which benchmark the results belong to, judged by the first category.
first_key = next(iter(all_data), None)
if first_key in linemod_type:
    print(f"LINEMOD")
elif first_key in onepose_type:
    print(f"ONEPOSE")
elif first_key in oneposeplusplus_type:
    print(f"ONEPOSE++")
elif first_key in ycbv_type:
    print(f"YCBV")

from tabulate import tabulate
headers = ["Category"] + list(val_metrics_4tb.keys())

# Drop an "Avg" row left by an earlier run of this cell, so that re-running
# neither duplicates the row nor averages over the previous average.
res_table[:] = [row for row in res_table if row[0] != "Avg"]

# Column-wise mean over all categories. The numeric columns are converted
# directly to float64 (instead of via a string array and float32), and the
# result gets its own name so the `all_data` dict is not overwritten.
if len(res_table) > 0:
    metric_values = np.array([row[1:] for row in res_table], dtype=np.float64)
    res_table.append(["Avg"] + metric_values.mean(0).tolist())
print(tabulate(res_table, headers=headers, tablefmt='fancy_grid'))

ONEPOSE
╒════════════════╤════════════╤════════════╤═══════════╤═══════════╤═══════════════╤═════════════╤════════════╤════════════╤═══════════╤═══════════╤═══════════════╤═════════════╤══════════╕
│ Category       │   R:auc@15 │   R:auc@30 │   R:ACC15 │   R:ACC30 │   R:medianErr │   R:meanErr │   t:auc@15 │   t:auc@30 │   t:ACC15 │   t:ACC30 │   t:medianErr │   t:meanErr │     AP50 │
╞════════════════╪════════════╪════════════╪═══════════╪═══════════╪═══════════════╪═════════════╪════════════╪════════════╪═══════════╪═══════════╪═══════════════╪═════════════╪══════════╡
│ aptamil        │   0.414872 │   0.68968  │  0.844156 │  0.987013 │       8.1579  │     9.64965 │   0.448331 │   0.7144   │  0.935065 │  0.987013 │       8.38028 │     8.7455  │ 0.948052 │
├────────────────┼────────────┼────────────┼───────────┼───────────┼───────────────┼─────────────┼────────────┼────────────┼───────────┼───────────┼───────────────┼─────────────┼──────────┤
│ jzhg           │   0.307717 │   0.575691

In [11]:
import pandas as pd
# Turn the result rows into a table, one column per header.
df = pd.DataFrame(data=res_table, columns=headers)

# Round to 6 decimals, then write to an Excel workbook in the working
# directory without the index column.
df_rounded = df.round(decimals=6)
df_rounded.to_excel(excel_writer="res.xlsx", index=False)