In [1]:
import os
import cv2
import json
import torch
import pprint
import numpy as np
from tqdm import tqdm
from loguru import logger
from tabulate import tabulate

from pose.utils import collate_fn, geodesic_distance, relative_pose_error, aggregate_metrics, recall_object, project_points

In [2]:
# Load the evaluation pair lists. The result is indexed per object below: each
# entry is a dict whose values are lists of pair-name strings.
# JSON is UTF-8 by specification; pin the encoding so the file is decoded the
# same way on Windows, where open() otherwise uses the locale code page.
with open("data/pairs/OneposePlusPlus-test.json", encoding="utf-8") as f:
    dir_list = json.load(f)
len(dir_list)

9

In [3]:
# Dataset root, chosen per platform.
# os.name is one of 'posix', 'nt' or 'java' -- it is never 'unix' -- so treat
# Windows as the special case and use the repo-relative path everywhere else.
# This guarantees ROOT_DIR is always defined.
if os.name == 'nt':
    ROOT_DIR = 'e:/datasets/OnePose++/lowtexture_test_data/'
else:
    ROOT_DIR = 'data/oneposeplusplus/'

In [4]:
# Accumulates one result row per evaluated object: [name, metric values...].
res_table = list()

In [5]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Number of keypoint matches fed to the network per image pair; the evaluation
# loop subsamples or zero-pads every pair to exactly this many rows.
num_sample = 3000

# NOTE(security): this checkpoint is a pickled model object, so torch.load
# executes code stored in the file -- only load checkpoints you trust.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
# pickled modules; pass weights_only=False there if loading fails.
# map_location remaps tensors saved from CUDA so the checkpoint also loads on a
# CPU-only machine (without it torch.load raises before .to(device) runs).
net = torch.load(
    './weights/oneposeplusplus-6d-3000-2024-01-06-00-25-35-0.2782.pth',
    map_location=device,
).to(device)

# Switch to inference mode (disables dropout); the returned module is the
# cell's displayed output.
net.eval()

Mkpts_Reg_Model(
  (embedding): Embedding()
  (transformerlayer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=76, out_features=76, bias=True)
    )
    (linear1): Linear(in_features=76, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=76, bias=True)
    (norm1): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((76,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=76, out_features=76, bias=True)
        )
        (linear1): Linear(in_features=76, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inp

In [6]:
# Evaluate the pose-regression network on every object of the test split.
#
# For each object this cell:
#   * counts how many pairs have a predicted box that matches the ground-truth
#     box (score from recall_object > 0.5),
#   * runs the network on the pre-computed keypoint matches and collects the
#     rotation / translation errors against the ground-truth relative pose,
#   * appends one row [name, metric values...] to res_table.

# Directory with the pre-computed per-pair files (pre_bbox / mkpts0 / mkpts1).
# The default is the original absolute location; set POPE_POINTS_DIR to run on
# another machine without editing the notebook.
POINTS_ROOT_DIR = os.environ.get('POPE_POINTS_DIR', 'd:/git_project/POPE/data/oneposeplusplus-points/')

# Seeded generator so the keypoint subsampling (and thus the reported metrics)
# is reproducible across runs.
rng = np.random.default_rng(0)


def _as_homogeneous_pose(pose):
    """Return `pose` as a 4x4 matrix, appending the [0, 0, 0, 1] row to a 3x4 input."""
    if pose.shape[0] == 3:
        return np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
    return pose


for label_idx, test_dict in enumerate(dir_list):
    logger.info(f"OneposePlusPlus: {label_idx + 1}/{len(dir_list)}")
    metrics = {'R_errs': [], 't_errs': [], 'inliers': [], 'identifiers': []}

    # The first pair of the "0" group identifies the object directory.
    sample_data = test_dict["0"][0]
    label = sample_data.split("/")[0]
    name = label.split("-")[1]
    dir_name = os.path.dirname(sample_data)
    FULL_ROOT_DIR = os.path.join(ROOT_DIR, dir_name)

    # The 3D box corners belong to the object, not to a pair: load them once.
    _3d_bbox = np.loadtxt(os.path.join(ROOT_DIR, label, "box3d_corners.txt"))

    recall_image, all_image = 0, 0
    for rotation_list in test_dict.values():
        for pair_name in tqdm(rotation_list):
            # Every listed pair counts towards the denominator, including
            # pairs that are skipped below because they have no prediction.
            all_image += 1

            points_file_path = os.path.join(POINTS_ROOT_DIR, pair_name.split("/")[0])
            points_name = pair_name.split("/")[-1]
            pre_bbox_path = os.path.join(points_file_path, "pre_bbox", f'{points_name}.txt')
            mkpts0_path = os.path.join(points_file_path, "mkpts0", f'{points_name}.txt')
            mkpts1_path = os.path.join(points_file_path, "mkpts1", f'{points_name}.txt')

            # No predicted box file for this pair: nothing to evaluate.
            if not os.path.exists(pre_bbox_path):
                continue

            # The pair's basename is "<image0>-<image1>".
            base_name = os.path.basename(pair_name)
            idx0_name = base_name.split("-")[0]
            idx1_name = base_name.split("-")[1]
            image0_name = os.path.join(FULL_ROOT_DIR, idx0_name)
            image1_name = os.path.join(FULL_ROOT_DIR, idx1_name)

            # Intrinsics and poses sit next to the colour images: swap the
            # "color" path component and the "png" extension.
            K1_path = image1_name.replace("color", "intrin_ba").replace("png", "txt")
            K1 = np.loadtxt(K1_path)

            pose0_path = image0_name.replace("color", "poses_ba").replace("png", "txt")
            pose1_path = image1_name.replace("color", "poses_ba").replace("png", "txt")
            # Pad each pose independently so mixed 3x4 / 4x4 files also work.
            pose0 = _as_homogeneous_pose(np.loadtxt(pose0_path))
            pose1 = _as_homogeneous_pose(np.loadtxt(pose1_path))

            pre_bbox = np.loadtxt(pre_bbox_path)
            # reshape(-1, 2) keeps the (N, 2) layout even when a file holds a
            # single match (loadtxt returns shape (2,)) or no match at all.
            mkpts0 = np.loadtxt(mkpts0_path).reshape(-1, 2)
            mkpts1 = np.loadtxt(mkpts1_path).reshape(-1, 2)

            # The network takes exactly num_sample matches: subsample when
            # there are more, zero-pad otherwise.
            if mkpts0.shape[0] > num_sample:
                rand_idx = rng.choice(mkpts0.shape[0], num_sample, replace=False)
                mkpts0 = mkpts0[rand_idx]
                mkpts1 = mkpts1[rand_idx]
            else:
                mkpts0 = np.concatenate([mkpts0, np.zeros((num_sample - mkpts0.shape[0], 2))], axis=0)
                mkpts1 = np.concatenate([mkpts1, np.zeros((num_sample - mkpts1.shape[0], 2))], axis=0)

            # Ground-truth 2D box: bounding rectangle of the 3D box corners
            # passed through project_points with the pose and intrinsics of
            # image 1 (presumably a pinhole projection -- confirm in pose.utils).
            bbox_pts_3d, _ = project_points(_3d_bbox, pose1[:3, :4], K1)
            bbox_pts_3d = bbox_pts_3d.astype(np.int32)
            x0, y0, w, h = cv2.boundingRect(bbox_pts_3d)
            gt_bbox = np.array([x0, y0, x0 + w, y0 + h])
            is_recalled = recall_object(pre_bbox, gt_bbox)
            recall_image += int(is_recalled > 0.5)

            batch_mkpts0 = torch.from_numpy(mkpts0).unsqueeze(0).float().to(device)
            batch_mkpts1 = torch.from_numpy(mkpts1).unsqueeze(0).float().to(device)

            # Ground-truth relative pose taking camera 0 to camera 1.
            batch_relative_pose = np.matmul(pose1, np.linalg.inv(pose0))
            batch_relative_pose = torch.from_numpy(batch_relative_pose).unsqueeze(0).float().to(device)

            # Pure inference: skip autograd bookkeeping to save memory and time.
            with torch.no_grad():
                pre_t, pre_rot = net(batch_mkpts0, batch_mkpts1)
                t_err, R_err = relative_pose_error(batch_relative_pose, pre_rot, pre_t, ignore_gt_t_thr=0.0)

            metrics['t_errs'].extend(t_err.reshape(-1).cpu().detach().numpy().tolist())
            metrics['R_errs'].extend(R_err.reshape(-1).cpu().detach().numpy().tolist())
            metrics['identifiers'].append(pair_name)

    print(f'Acc: {recall_image}/{all_image}')
    val_metrics_4tb = aggregate_metrics(metrics, 5e-4)
    # Guard the ratio so an object with no listed pairs does not raise.
    val_metrics_4tb['AP50'] = recall_image / all_image if all_image else 0.0
    logger.info('\n' + pprint.pformat(val_metrics_4tb))

    res_table.append([f"{name}"] + list(val_metrics_4tb.values()))

[32m2024-01-06 00:32:22.207[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOneposePlusPlus: 1/9[0m
  0%|          | 0/26 [00:00<?, ?it/s]

100%|██████████| 26/26 [00:02<00:00, 11.09it/s]
100%|██████████| 32/32 [00:00<00:00, 48.25it/s]
100%|██████████| 34/34 [00:00<00:00, 46.83it/s]
100%|██████████| 52/52 [00:01<00:00, 49.64it/s]
100%|██████████| 48/48 [00:01<00:00, 46.08it/s]
100%|██████████| 47/47 [00:00<00:00, 49.27it/s]
[32m2024-01-06 00:32:29.004[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.9623430962343096,
 'R:ACC15': 0.4686192468619247,
 'R:ACC30': 0.9330543933054394,
 'R:auc@15': 0.207892341161539,
 'R:auc@30': 0.4606191817354125,
 'R:medianErr': 15.887227058410645,
 't:ACC15': 0.008368200836820083,
 't:ACC30': 0.1087866108786611,
 't:auc@15': 0.004231427537181221,
 't:auc@30': 0.030818501875490324,
 't:medianErr': 62.63356018066406}[0m
[32m2024-01-06 00:32:29.005[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOneposePlusPlus: 2/9[0m


Acc: 230/239


100%|██████████| 45/45 [00:00<00:00, 46.54it/s]
100%|██████████| 50/50 [00:01<00:00, 45.59it/s]
100%|██████████| 73/73 [00:01<00:00, 41.56it/s]
100%|██████████| 55/55 [00:01<00:00, 49.07it/s]
100%|██████████| 50/50 [00:00<00:00, 52.17it/s]
100%|██████████| 42/42 [00:00<00:00, 52.57it/s]
[32m2024-01-06 00:32:35.717[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.9809523809523809,
 'R:ACC15': 0.5968253968253968,
 'R:ACC30': 0.9619047619047619,
 'R:auc@15': 0.24581368354262495,
 'R:auc@30': 0.5359132167965016,
 'R:medianErr': 12.971778869628906,
 't:ACC15': 0.03492063492063492,
 't:ACC30': 0.12698412698412698,
 't:auc@15': 0.01341314129097752,
 't:auc@30': 0.04309797319785627,
 't:medianErr': 58.431270599365234}[0m
[32m2024-01-06 00:32:35.718[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOneposePlusPlus: 3/9[0m


Acc: 309/315


100%|██████████| 38/38 [00:00<00:00, 40.41it/s]
100%|██████████| 44/44 [00:00<00:00, 50.80it/s]
100%|██████████| 43/43 [00:00<00:00, 52.09it/s]
100%|██████████| 35/35 [00:00<00:00, 44.41it/s]
100%|██████████| 35/35 [00:00<00:00, 41.49it/s]
100%|██████████| 38/38 [00:00<00:00, 42.43it/s]
[32m2024-01-06 00:32:40.896[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.3562231759656652,
 'R:ACC15': 0.5536480686695279,
 'R:ACC30': 0.9356223175965666,
 'R:auc@15': 0.26547054691887717,
 'R:auc@30': 0.5171009637426067,
 'R:medianErr': 13.465435981750488,
 't:ACC15': 0.02575107296137339,
 't:ACC30': 0.0944206008583691,
 't:auc@15': 0.008980405586471887,
 't:auc@30': 0.03140952679220699,
 't:medianErr': 64.56393432617188}[0m
[32m2024-01-06 00:32:40.897[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOneposePlusPlus: 4/9[0m


Acc: 83/233


100%|██████████| 42/42 [00:00<00:00, 44.87it/s]
100%|██████████| 62/62 [00:01<00:00, 49.65it/s]
100%|██████████| 58/58 [00:01<00:00, 49.16it/s]
100%|██████████| 45/45 [00:00<00:00, 50.56it/s]
100%|██████████| 44/44 [00:00<00:00, 52.46it/s]
100%|██████████| 49/49 [00:00<00:00, 52.67it/s]
[32m2024-01-06 00:32:46.937[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.97,
 'R:ACC15': 0.6166666666666667,
 'R:ACC30': 0.9333333333333333,
 'R:auc@15': 0.2537344587379032,
 'R:auc@30': 0.5395139815674888,
 'R:medianErr': 12.78298282623291,
 't:ACC15': 0.02666666666666667,
 't:ACC30': 0.10333333333333333,
 't:auc@15': 0.010024604479471842,
 't:auc@30': 0.03635841555065579,
 't:medianErr': 64.46404266357422}[0m
[32m2024-01-06 00:32:46.938[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOneposePlusPlus: 5/9[0m


Acc: 291/300


100%|██████████| 63/63 [00:01<00:00, 48.58it/s]
100%|██████████| 66/66 [00:01<00:00, 47.72it/s]
100%|██████████| 58/58 [00:01<00:00, 47.41it/s]
100%|██████████| 57/57 [00:01<00:00, 46.46it/s]
100%|██████████| 67/67 [00:01<00:00, 47.69it/s]
100%|██████████| 61/61 [00:01<00:00, 50.40it/s]
[32m2024-01-06 00:32:54.703[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.9758064516129032,
 'R:ACC15': 0.5902964959568733,
 'R:ACC30': 0.9622641509433962,
 'R:auc@15': 0.25961142929416464,
 'R:auc@30': 0.5384947600176179,
 'R:medianErr': 12.700307846069336,
 't:ACC15': 0.013477088948787063,
 't:ACC30': 0.09973045822102426,
 't:auc@15': 0.005785726322746449,
 't:auc@30': 0.02863331466565128,
 't:medianErr': 60.335289001464844}[0m
[32m2024-01-06 00:32:54.704[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOneposePlusPlus: 6/9[0m


Acc: 363/372


100%|██████████| 32/32 [00:00<00:00, 42.18it/s]
100%|██████████| 63/63 [00:01<00:00, 43.90it/s]
100%|██████████| 59/59 [00:01<00:00, 47.66it/s]
100%|██████████| 53/53 [00:01<00:00, 49.34it/s]
100%|██████████| 66/66 [00:01<00:00, 49.47it/s]
100%|██████████| 67/67 [00:01<00:00, 49.27it/s]
[32m2024-01-06 00:33:01.921[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.9823529411764705,
 'R:ACC15': 0.5411764705882353,
 'R:ACC30': 0.9382352941176471,
 'R:auc@15': 0.2179476830655453,
 'R:auc@30': 0.48623555885226116,
 'R:medianErr': 14.106171607971191,
 't:ACC15': 0.011764705882352941,
 't:ACC30': 0.14411764705882352,
 't:auc@15': 0.006713291710498287,
 't:auc@30': 0.038047041098276765,
 't:medianErr': 55.93498420715332}[0m
[32m2024-01-06 00:33:01.922[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOneposePlusPlus: 7/9[0m


Acc: 334/340


100%|██████████| 25/25 [00:00<00:00, 42.82it/s]
100%|██████████| 41/41 [00:00<00:00, 43.13it/s]
100%|██████████| 30/30 [00:00<00:00, 48.47it/s]
100%|██████████| 32/32 [00:00<00:00, 49.92it/s]
100%|██████████| 32/32 [00:00<00:00, 49.96it/s]
100%|██████████| 34/34 [00:00<00:00, 42.15it/s]
[32m2024-01-06 00:33:06.181[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.9742268041237113,
 'R:ACC15': 0.6082474226804123,
 'R:ACC30': 0.9845360824742269,
 'R:auc@15': 0.2610518759468577,
 'R:auc@30': 0.5488145997434137,
 'R:medianErr': 12.470929145812988,
 't:ACC15': 0.030927835051546393,
 't:ACC30': 0.12371134020618557,
 't:auc@15': 0.015338228166717845,
 't:auc@30': 0.040225840516106774,
 't:medianErr': 60.05063438415527}[0m
[32m2024-01-06 00:33:06.182[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOneposePlusPlus: 8/9[0m


Acc: 189/194


100%|██████████| 32/32 [00:00<00:00, 46.23it/s]
100%|██████████| 35/35 [00:00<00:00, 44.78it/s]
100%|██████████| 38/38 [00:00<00:00, 48.24it/s]
100%|██████████| 59/59 [00:01<00:00, 44.84it/s]
100%|██████████| 30/30 [00:00<00:00, 48.43it/s]
100%|██████████| 43/43 [00:00<00:00, 54.37it/s]
[32m2024-01-06 00:33:11.189[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.8312236286919831,
 'R:ACC15': 0.48523206751054854,
 'R:ACC30': 0.9578059071729957,
 'R:auc@15': 0.22278595701719298,
 'R:auc@30': 0.4875080503324248,
 'R:medianErr': 15.292442321777344,
 't:ACC15': 0.02109704641350211,
 't:ACC30': 0.08016877637130802,
 't:auc@15': 0.013964076115947566,
 't:auc@30': 0.03352129435908106,
 't:medianErr': 63.18133544921875}[0m
[32m2024-01-06 00:33:11.190[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mOneposePlusPlus: 9/9[0m


Acc: 197/237


100%|██████████| 36/36 [00:00<00:00, 47.43it/s]
100%|██████████| 74/74 [00:01<00:00, 49.91it/s]
100%|██████████| 103/103 [00:02<00:00, 46.28it/s]
100%|██████████| 100/100 [00:02<00:00, 45.27it/s]
100%|██████████| 96/96 [00:02<00:00, 47.63it/s]
100%|██████████| 112/112 [00:02<00:00, 51.45it/s]
[32m2024-01-06 00:33:22.079[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m88[0m - [1m
{'AP50': 0.6238003838771593,
 'R:ACC15': 0.3861003861003861,
 'R:ACC30': 0.9362934362934363,
 'R:auc@15': 0.15111212662901813,
 'R:auc@30': 0.42010370067954833,
 'R:medianErr': 17.830310821533203,
 't:ACC15': 0.02702702702702703,
 't:ACC30': 0.12162162162162163,
 't:auc@15': 0.011116256930192924,
 't:auc@30': 0.04031679820907009,
 't:medianErr': 63.180198669433594}[0m


Acc: 325/521


In [7]:
# Print the per-object metrics with a trailing "Avg" row (column-wise mean).
# Column names come from the metrics dict of the last evaluated object.
headers = ["Category"] + list(val_metrics_4tb.keys())

# Drop an "Avg" row left by a previous run of this cell so that re-running it
# does not stack duplicate average rows.
category_rows = [row for row in res_table if row[0] != "Avg"]

# Build the numeric matrix from the metric columns only, rather than casting
# the mixed name/number rows through a string array.
all_data = np.array([row[1:] for row in category_rows], dtype=np.float32)

# Update res_table in place so it still ends with the "Avg" row afterwards.
res_table[:] = category_rows + [["Avg"] + all_data.mean(0).tolist()]
print(tabulate(res_table, headers=headers, tablefmt='fancy_grid'))

╒══════════════╤════════════╤════════════╤═══════════╤═══════════╤═══════════════╤════════════╤════════════╤═══════════╤═══════════╤═══════════════╤══════════╕
│ Category     │   R:auc@15 │   R:auc@30 │   R:ACC15 │   R:ACC30 │   R:medianErr │   t:auc@15 │   t:auc@30 │   t:ACC15 │   t:ACC30 │   t:medianErr │     AP50 │
╞══════════════╪════════════╪════════════╪═══════════╪═══════════╪═══════════════╪════════════╪════════════╪═══════════╪═══════════╪═══════════════╪══════════╡
│ toyrobot     │   0.207892 │   0.460619 │  0.468619 │  0.933054 │       15.8872 │ 0.00423143 │  0.0308185 │ 0.0083682 │ 0.108787  │       62.6336 │ 0.962343 │
├──────────────┼────────────┼────────────┼───────────┼───────────┼───────────────┼────────────┼────────────┼───────────┼───────────┼───────────────┼──────────┤
│ yellowduck   │   0.245814 │   0.535913 │  0.596825 │  0.961905 │       12.9718 │ 0.0134131  │  0.043098  │ 0.0349206 │ 0.126984  │       58.4313 │ 0.980952 │
├──────────────┼────────────┼───────────