In [1]:
import os
import torch
import random
import pprint
import numpy as np
from loguru import logger
from torch.utils.data import DataLoader
from torch import nn

from pose.dataset import pose_dataset
from pose.utils import (
    collate_fn,
    geodesic_distance,
    relative_pose_error,
    relative_pose_error_np,
    recall_object,
    aggregate_metrics
)
from pose.model import Mkpts_Reg_Model
from pose.animator import Animator


# Root of the data tree. On Windows the data lives in an absolute checkout
# location; on every other platform the paths are resolved relative to the
# current working directory. (Previously only 'nt' and 'posix' were handled,
# so any other os.name left all of the variables below undefined.)
DATA_ROOT = 'd:/git_project/POPE/data/' if os.name == 'nt' else 'data/'

# For each benchmark: image directory, test-pair JSON file, and directory of
# the precomputed match files (the "-points" directories).
LM_dataset_path = f'{DATA_ROOT}LM_dataset/'
LM_dataset_json_path = f'{DATA_ROOT}pairs/LINEMOD-test.json'
LM_dataset_points_path = f'{DATA_ROOT}LM_dataset-points/'

onepose_path = f'{DATA_ROOT}onepose/'
onepose_json_path = f'{DATA_ROOT}pairs/Onepose-test.json'
onepose_points_path = f'{DATA_ROOT}onepose-points/'

onepose_plusplus_path = f'{DATA_ROOT}onepose_plusplus/'
onepose_plusplus_json_path = f'{DATA_ROOT}pairs/OneposePlusPlus-test.json'
onepose_plusplus_points_path = f'{DATA_ROOT}onepose_plusplus-points/'

ycbv_path = f'{DATA_ROOT}ycbv/'
ycbv_json_path = f'{DATA_ROOT}pairs/YCB-VIDEO-test.json'
# No trailing slash here, unlike the other "-points" directories (kept as-is).
ycbv_points_path = f'{DATA_ROOT}ycbv-points'

# Datasets handed to pose_dataset as (tag, image dir, pairs json, points dir).
# Uncomment an entry to include that benchmark in the evaluation.
paths = [
    # ('linemod', LM_dataset_path, LM_dataset_json_path, LM_dataset_points_path),
    # ('onepose', onepose_path, onepose_json_path, onepose_points_path),
    # ('onepose_plusplus', onepose_plusplus_path, onepose_plusplus_json_path, onepose_plusplus_points_path),
    ('ycbv', ycbv_path, ycbv_json_path, ycbv_points_path),
]

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the dataset from every entry enabled in `paths`.
dataset = pose_dataset(paths)
# get_mkpts_info() yields two values; judging by the names they are the
# maximum and the total matched-keypoint count — TODO confirm in pose.dataset.
mkpts_max_len, mkpts_sum_len = dataset.get_mkpts_info()

  0%|          | 0/10 [00:00<?, ?it/s]

data/ycbv-points/0801-1-other/mkpts0/1572-1.png-1536-1.png.txt does not exist
data/ycbv-points/0801-1-other/mkpts0/1586-1.png-1566-1.png.txt does not exist


100%|██████████| 10/10 [00:07<00:00,  1.26it/s]


In [2]:
# One seed shared by every random number generator used in this notebook.
SEED = 20231223
random.seed(SEED)
np.random.seed(SEED)  # NumPy's global RNG was previously left unseeded
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# 80/20 random split of the full dataset.
# NOTE(review): train_dataset / test_dataset are not referenced by the cells
# visible below (the DataLoader is built from the full `dataset`) — confirm
# whether the evaluation is meant to run on test_dataset only.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

In [3]:
# Argument forwarded to collate_fn to build the batch collation callable.
# NOTE(review): presumably the number of matched keypoints sampled per image
# pair — confirm against pose.utils.collate_fn.
num_sample = 300

# NOTE(review): shuffle=True and drop_last=True are kept from the original;
# with drop_last the final incomplete batch never reaches the evaluation below.
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn(num_sample),
)

In [4]:
# Load the full pickled model object and move it to the evaluation device.
# map_location remaps the stored tensors onto `device` while loading, so a
# checkpoint saved from a CUDA run can still be opened on a CPU-only machine.
# SECURITY: torch.load unpickles arbitrary Python objects — only load
# checkpoint files from a trusted source.
net = torch.load(
    './weights/20240505/ycbv-mkpts-imgs-relative_r-gt_t-6d-300-2024-05-04-19-36-56-5580.2383.pth',
    map_location=device,
).to(device)

# Switch to inference mode (disables dropout); as the last expression of the
# cell this also displays the module structure.
net.eval()

Mkpts_Reg_Model(
  (embedding): Embedding()
  (transformer_mkpts): Transformer(
    (self_attn): MultiheadAttention(
      (out_proj): _LinearWithBias(in_features=76, out_features=76, bias=True)
    )
    (linear1): Linear(in_features=76, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=76, bias=True)
    (norm1): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (mlp1): Sequential(
    (0): Linear(in_features=22800, out_features=11400, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=11400, out_features=2000, bias=True)
    (4): LeakyReLU(negative_slope=0.01)
    (5): Dropout(p=0.2, inplace=False)
  )
  (convnextv2): ConvNeXtV2(
    (model): ConvNeXtV2(
      (downsample_laye

In [5]:
# LINEMOD object id -> object name. Ids 3 and 7 have no entry in this table.
id2name_dict = dict([
    (1, "ape"),
    (2, "benchvise"),
    (4, "camera"),
    (5, "can"),
    (6, "cat"),
    (8, "driller"),
    (9, "duck"),
    (10, "eggbox"),
    (11, "glue"),
    (12, "holepuncher"),
    (13, "iron"),
    (14, "lamp"),
    (15, "phone"),
])

# YCB-V object id (1..10) -> spelled-out number used to build bucket keys.
ycbv_dict = dict(enumerate(
    ('one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten'),
    start=1,
))

# Bucket keys per benchmark. Every key is "<object name>_data".
linemod_type = [f'{obj}_data' for obj in (
    'ape', 'benchvise', 'camera', 'can', 'cat', 'driller', 'duck',
    'eggbox', 'glue', 'holepuncher', 'iron', 'lamp', 'phone',
)]
onepose_type = [f'{obj}_data' for obj in (
    'aptamil', 'jzhg', 'minipuff', 'hlyormosiapie', 'brownhouse',
    'oreo', 'mfmilkcake', 'diycookies', 'taipingcookies', 'tee',
)]
oneposeplusplus_type = [f'{obj}_data' for obj in (
    'toyrobot', 'yellowduck', 'sheep', 'fakebanana', 'teabox',
    'orange', 'greenteapot', 'lecreusetcup', 'insta',
)]
ycbv_type = [f'{obj}_data' for obj in (
    'one', 'two', 'three', 'four', 'five',
    'six', 'seven', 'eight', 'nine', 'ten',
)]

# One independent, initially empty sample list per bucket, ordered
# linemod -> onepose -> onepose++ -> ycbv.
all_data = {
    bucket: []
    for bucket in linemod_type + onepose_type + oneposeplusplus_type + ycbv_type
}

# Keep the per-object module-level names (ape_data, ..., ten_data) bound to the
# very same list objects that are stored in all_data.
globals().update(all_data)

# Distribute every sample yielded by the dataloader into its per-object bucket
# in all_data, keyed by "<object name>_data".

# YCB-V samples are named by the decimal string of their object id. Exact
# membership is required here: the former `name in ('12345678910')` was a
# substring test against a single string, so '', '23' or '910' also matched.
ycbv_names = {str(obj_id) for obj_id in ycbv_dict}

for batch in dataloader:
    for data in batch:
        obj_name = data['name']
        if obj_name.startswith('lm'):
            # LINEMOD names are "lm<id>"; the id is translated to the object name.
            bucket = f"{id2name_dict[int(obj_name[2:])]}_data"
        elif obj_name in ycbv_names:
            bucket = f"{ycbv_dict[int(obj_name)]}_data"
        else:
            # Remaining names are used verbatim as the bucket prefix.
            bucket = f"{obj_name}_data"
        all_data[bucket].append(data)

# Gather the buckets that received no samples up front, so that all_data is
# not resized while it is being iterated, then drop them in place.
empty_keys = [key for key in all_data if len(all_data[key]) == 0]
for key in empty_keys:
    all_data.pop(key)

# Report how many samples each remaining bucket holds.
for key in all_data:
    print(key, len(all_data[key]))

print('len(all_data):', len(all_data))

one_data 213
two_data 126
three_data 248
four_data 175
five_data 142
six_data 143
seven_data 16
eight_data 107
nine_data 122
ten_data 140
len(all_data): 10


In [6]:
# Run the network on every sample of every bucket in all_data, compute the
# rotation / translation error against the ground truth selected by
# model_type, and remember the TOP_K samples with the lowest rotation error.
res_table = []

# How the ground-truth pose passed to relative_pose_error_np is assembled:
#   'gt'              -> pose1 as-is
#   'relative'        -> pose1 @ inv(pose0)
#   'relative_r-gt_t' -> rotation of pose1 @ inv(pose0), translation of pose1
model_type = 'relative_r-gt_t'

# Number of lowest-R_err samples retained in `choose`.
TOP_K = 10
# Entries are (name, pair_name, R_err, t_err) tuples.
choose = []

# YCB-V samples are named by the decimal string of their object id. Exact
# membership is required: the former `name in '12345678910'` was a substring
# test, so '', '23' or '910' also matched.
ycbv_names = {str(obj_id) for obj_id in ycbv_dict}

for key in all_data.keys():
    if key in linemod_type:
        logger.info(f"LINEMOD: {key}")
    elif key in onepose_type:
        logger.info(f"ONEPOSE: {key}")
    elif key in oneposeplusplus_type:
        logger.info(f"ONEPOSE++: {key}")
    elif key in ycbv_type:
        logger.info(f"YCBV: {key}")

    # NOTE(review): metrics, recall_image and all_image are reset per bucket
    # but never read in this cell — presumably consumed by aggregation code
    # that is not part of it; confirm.
    metrics = {'R_errs': [], 't_errs': [], 'inliers': [], 'identifiers': []}
    recall_image, all_image = 0, 0

    for item in all_data[key]:
        all_image += 1
        pose0 = item['pose0']
        pose1 = item['pose1']
        pre_bbox = item['pre_bbox']
        gt_bbox = item['gt_bbox']
        mkpts0 = item['mkpts0']
        mkpts1 = item['mkpts1']
        img0 = item['img0']
        img1 = item['img1']
        name = item['name']
        pair_name = item['pair_name']

        # Translate dataset-specific sample names into the object name used
        # in the bucket keys.
        if name.startswith('lm'):
            # linemod: "lm<id>"
            name = id2name_dict[int(name[2:])]
        elif name in ycbv_names:
            # ycbv: "<id>"
            name = ycbv_dict[int(name)]

        # Skip samples that ended up in the wrong bucket (exact key match
        # instead of the former substring test `name not in key`).
        if f"{name}_data" != key:
            print(f'name: {name}, key: {key}')
            continue

        is_recalled = recall_object(pre_bbox, gt_bbox)

        recall_image = recall_image + int(is_recalled > 0.5)

        batch_mkpts0 = torch.from_numpy(mkpts0).unsqueeze(0).float().to(device)
        batch_mkpts1 = torch.from_numpy(mkpts1).unsqueeze(0).float().to(device)
        img0 = torch.from_numpy(img0).unsqueeze(0).float().to(device)
        img1 = torch.from_numpy(img1).unsqueeze(0).float().to(device)
        # NOTE(review): permute(0, 3, 2, 1) swaps axes 1 and 3. If the images
        # are stored H x W x C this yields N,C,W,H rather than N,C,H,W —
        # presumably this matches the training pipeline; confirm before changing.
        img0 = img0.permute(0, 3, 2, 1)
        img1 = img1.permute(0, 3, 2, 1)

        # Inference only: do not record an autograd graph for the forward pass.
        with torch.no_grad():
            pre_t, pre_rot = net(batch_mkpts0, batch_mkpts1, img0, img1)
        pre_t = pre_t.squeeze(0).detach().cpu().numpy()
        pre_rot = pre_rot.squeeze(0).detach().cpu().numpy()

        if model_type == 'gt':
            t_err, R_err = relative_pose_error_np(pose1, pre_rot, pre_t, ignore_gt_t_thr=0.0)
        elif model_type == 'relative':
            relative_pose = np.matmul(pose1, np.linalg.inv(pose0))
            t_err, R_err = relative_pose_error_np(relative_pose, pre_rot, pre_t, ignore_gt_t_thr=0.0)
        elif model_type == 'relative_r-gt_t':
            relative_pose = np.matmul(pose1, np.linalg.inv(pose0))
            # Only the 3x3 rotation block and the translation column are
            # filled; every other entry of gt_pose stays zero.
            gt_pose = np.zeros_like(pose1)
            gt_pose[:3, :3] = relative_pose[:3, :3]
            gt_pose[:3, 3] = pose1[:3, 3]
            t_err, R_err = relative_pose_error_np(gt_pose, pre_rot, pre_t, ignore_gt_t_thr=0.0)
        else:
            # Without this branch an unknown mode silently reused the errors
            # of the previous sample (or raised NameError on the first one).
            raise ValueError(f"unsupported model_type: {model_type!r}")

        # Keep the TOP_K entries with the lowest R_err in `choose`: fill the
        # list first, afterwards overwrite the current worst entry (moved to
        # index 0 by the descending sort) whenever a better sample shows up.
        entry = (name, pair_name, R_err, t_err)
        if len(choose) < TOP_K:
            choose.append(entry)
        else:
            choose.sort(key=lambda x: x[2], reverse=True)
            if R_err < choose[0][2]:
                choose[0] = entry

[32m2024-05-08 10:34:16.399[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m15[0m - [1mYCBV: one_data[0m
[32m2024-05-08 10:34:21.833[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m15[0m - [1mYCBV: two_data[0m
[32m2024-05-08 10:34:24.851[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m15[0m - [1mYCBV: three_data[0m
[32m2024-05-08 10:34:30.775[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m15[0m - [1mYCBV: four_data[0m
[32m2024-05-08 10:34:34.947[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m15[0m - [1mYCBV: five_data[0m
[32m2024-05-08 10:34:38.335[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m15[0m - [1mYCBV: six_data[0m
[32m2024-05-08 10:34:41.789[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m15[0m - [1mYCBV: seven_data[0m
[32m2024-05-08 10:34:42.179[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m15[0m - 

In [7]:
# Print one line per retained (name, pair_name, R_err, t_err) entry, in the
# order the entries are currently stored in `choose`.
for entry in choose:
    name, pair_name, R_err, t_err = entry
    print(f"{name} {pair_name} R_err: {R_err:.4f}, t_err: {t_err:.4f}")

eight 0801-8-other/8-3/color/1068-8.png-1082-8.png R_err: 0.6368, t_err: 8.4450
three 0801-3-other/3-3/color/649-3.png-654-3.png R_err: 0.6149, t_err: 7.7593
five 0801-5-other/5-3/color/721-5.png-751-5.png R_err: 0.5854, t_err: 12.1255
five 0801-5-other/5-3/color/169-5.png-175-5.png R_err: 0.5442, t_err: 1.6957
six 0801-6-other/6-3/color/2104-6.png-2083-6.png R_err: 0.5436, t_err: 3.0736
four 0801-4-other/4-3/color/1719-4.png-1711-4.png R_err: 0.4736, t_err: 6.9853
two 0801-2-other/2-3/color/651-2.png-629-2.png R_err: 0.4438, t_err: 8.2516
seven 0801-7-other/7-3/color/1669-7.png-1639-7.png R_err: 0.4285, t_err: 5.7694
nine 0801-9-other/9-3/color/1027-9.png-1034-9.png R_err: 0.3911, t_err: 1.6336
three 0801-3-other/3-3/color/1543-3.png-1547-3.png R_err: 0.3901, t_err: 7.9239
