In [2]:
import torch
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.pvd_data_pc import ShapeNet15kPointCloudsPVD
from metrics.evaluation_metrics import compute_all_metrics

Jitting Chamfer 3D
Loaded JIT 3D CUDA chamfer distance


In [3]:
# Root directories for the two point-cloud sources compared in this notebook.
SHAPENET_ROOT = "./ShapeNetCore.v2.PC15k"  # ShapeNet val split (reference clouds)
PVD_ROOT = "./custom-dataset"  # loaded via ShapeNet15kPointCloudsPVD ("train" split)
N_POINTS = 2048  # points sampled per cloud for both train/test views

# Sampling / normalization options shared by every dataset below. Keeping them in
# one place prevents the four constructors from silently drifting apart.
# NOTE(review): random_subsample=True presumably picks a random subset of points
# and no seed is set in this notebook, so metric values may vary between runs —
# confirm which RNG the dataset classes use and seed it for reproducibility.
DATASET_KWARGS = dict(
    tr_sample_size=N_POINTS,
    te_sample_size=N_POINTS,
    scale=1.0,
    normalize_per_shape=False,
    normalize_std_per_axis=False,
    random_subsample=True,
)

# Construction order (chair, airplane, PVD airplane, PVD chair) is kept so any
# output the constructors print appears in the same order as before.
chair_dataset = ShapeNet15kPointClouds(
    root_dir=SHAPENET_ROOT, categories=["chair"], split="val", **DATASET_KWARGS
)

airplane_dataset = ShapeNet15kPointClouds(
    root_dir=SHAPENET_ROOT, categories=["airplane"], split="val", **DATASET_KWARGS
)

pvd_airplane_dataset = ShapeNet15kPointCloudsPVD(
    root_dir=PVD_ROOT, categories=["airplane"], split="train", **DATASET_KWARGS
)

pvd_chair_dataset = ShapeNet15kPointCloudsPVD(
    root_dir=PVD_ROOT, categories=["chair"], split="train", **DATASET_KWARGS
)

Total number of data:662
Min number of points: (train)2048 (test)2048
Total number of data:405
Min number of points: (train)2048 (test)2048
Total number of data:400
Min number of points: (train)2048 (test)2048
Total number of data:400
Min number of points: (train)2048 (test)2048


In [4]:
BATCH_SIZE = 50


def make_eval_loader(dataset, batch_size=BATCH_SIZE):
    """Build a deterministic DataLoader over ``dataset`` for metric evaluation.

    Args:
        dataset: a map-style dataset accepted by ``torch.utils.data.DataLoader``.
        batch_size: samples per batch; defaults to ``BATCH_SIZE``.

    Returns:
        A ``DataLoader`` that iterates in dataset order (no shuffling) and drops
        the final partial batch, so at most ``len(dataset) % batch_size``
        trailing samples are never yielded.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=None,
        # Explicitly off: evaluation iterates in a fixed, reproducible order.
        shuffle=False,
        num_workers=1,
        # Keeps every batch exactly `batch_size` long; the trailing remainder
        # of the dataset is discarded.
        drop_last=True,
    )


shapenet_dataloader_airplane = make_eval_loader(airplane_dataset)
shapenet_dataloader_chair = make_eval_loader(chair_dataset)
pvd_dataloader_airplane = make_eval_loader(pvd_airplane_dataset)
pvd_dataloader_chair = make_eval_loader(pvd_chair_dataset)

In [5]:
# Walk the ShapeNet (reference) and PVD airplane loaders in lockstep and move each
# batch of test points onto the GPU. zip() stops at the shorter loader, so both
# lists always end up with the same number of batches.
shapenet_input_arr, pvd_output_arr = [], []
for ref_batch, gen_batch in zip(shapenet_dataloader_airplane, pvd_dataloader_airplane):
    shapenet_input_arr.append(ref_batch["test_points"].to("cuda"))
    pvd_output_arr.append(gen_batch["test_points"].to("cuda"))

In [6]:
# Join the per-batch tensors along the batch axis into one (num_clouds, points, 3)
# tensor per set. torch.cat derives the sizes from the data, so nothing breaks
# if the number of batches or points per cloud changes.
input_airplane_pcs = torch.cat(shapenet_input_arr, dim=0)
output_airplane_pcs = torch.cat(pvd_output_arr, dim=0)

# zip() paired the batches one-to-one, so both sets should be the same shape.
assert input_airplane_pcs.shape == output_airplane_pcs.shape, (
    input_airplane_pcs.shape,
    output_airplane_pcs.shape,
)
input_airplane_pcs.shape, output_airplane_pcs.shape

(torch.Size([400, 2048, 3]), torch.Size([400, 2048, 3]))

In [7]:
# Third positional argument of compute_all_metrics; presumably the internal
# batch size it uses for pairwise distances — confirm in metrics.evaluation_metrics.
METRICS_BATCH_SIZE = 100

# NOTE(review): PointFlow/PVD-style metric code usually has the signature
# compute_all_metrics(sample_pcs, ref_pcs, batch_size), but here the ShapeNet
# reference set is passed first. MMD/COV are generally not symmetric in their
# two inputs, so confirm the intended argument order.
results = compute_all_metrics(input_airplane_pcs, output_airplane_pcs, METRICS_BATCH_SIZE)

# Convert tensor-valued metrics to plain Python numbers. Tensor.item() works on
# CUDA tensors directly. Non-tensor values (e.g. floats) are kept unchanged.
results = {k: v.item() if torch.is_tensor(v) else v for k, v in results.items()}

print(results)

100%|██████████| 400/400 [03:43<00:00,  1.79it/s]
100%|██████████| 400/400 [06:51<00:00,  1.17s/it]
100%|██████████| 400/400 [07:44<00:00,  1.11s/it]

{'lgan_mmd-CD': 0.06689093261957169, 'lgan_cov-CD': 0.3400000035762787, 'lgan_mmd_smp-CD': 0.020403025671839714, 'lgan_mmd-EMD': 0.5590656995773315, 'lgan_cov-EMD': 0.3725000023841858, 'lgan_mmd_smp-EMD': 0.29166239500045776, '1-NN-CD-acc_t': 0.6175000071525574, '1-NN-CD-acc_f': 0.8924999833106995, '1-NN-CD-acc': 0.7549999952316284, '1-NN-EMD-acc_t': 0.5774999856948853, '1-NN-EMD-acc_f': 0.7925000190734863, '1-NN-EMD-acc': 0.6850000023841858}





In [None]:
# Snapshot of the airplane metrics printed by the metrics cell above
# (400 ShapeNet val clouds vs 400 PVD clouds). It is kept as a literal so the
# numbers are available without re-running the expensive metric computation.
airplane = {
    "lgan_mmd-CD": 0.06689093261957169,
    "lgan_cov-CD": 0.3400000035762787,
    "lgan_mmd_smp-CD": 0.020403025671839714,
    "lgan_mmd-EMD": 0.5590656995773315,
    "lgan_cov-EMD": 0.3725000023841858,
    "lgan_mmd_smp-EMD": 0.29166239500045776,
    "1-NN-CD-acc_t": 0.6175000071525574,
    "1-NN-CD-acc_f": 0.8924999833106995,
    "1-NN-CD-acc": 0.7549999952316284,
    "1-NN-EMD-acc_t": 0.5774999856948853,
    "1-NN-EMD-acc_f": 0.7925000190734863,
    "1-NN-EMD-acc": 0.6850000023841858,
}

# Headline 1-NNA accuracies as percentages rounded to 2 decimals. They are
# derived from the dict so they cannot drift from it.
# 1-NNA-CD: 75.50%
# 1-NNA-EMD: 68.50%
airplane_1nna_pct = {
    name: round(100 * airplane[key], 2)
    for name, key in (("1-NNA-CD", "1-NN-CD-acc"), ("1-NNA-EMD", "1-NN-EMD-acc"))
}

In [11]:
# Pair chair batches from the ShapeNet (reference) and PVD loaders and move their
# test points onto the GPU. zip() stops at the shorter loader. If the ShapeNet
# chair loader yields more batches than the PVD one (the printed dataset totals
# above suggest 662 vs 400 shapes — verify), only its leading batches are used
# as the reference set.
shapenet_input_arr = []
pvd_output_arr = []
batch_pairs = zip(shapenet_dataloader_chair, pvd_dataloader_chair)
for reference, generated in batch_pairs:
    reference_points = reference["test_points"]
    generated_points = generated["test_points"]
    shapenet_input_arr.append(reference_points.to("cuda"))
    pvd_output_arr.append(generated_points.to("cuda"))

In [12]:
# Join the per-batch tensors along the batch axis into one (num_clouds, points, 3)
# tensor per set. torch.cat derives the sizes from the data instead of assuming
# a fixed count of clouds and points.
input_chair_pcs = torch.cat(shapenet_input_arr, dim=0)
output_chair_pcs = torch.cat(pvd_output_arr, dim=0)

# zip() paired the batches one-to-one, so both sets should be the same shape.
assert input_chair_pcs.shape == output_chair_pcs.shape, (
    input_chair_pcs.shape,
    output_chair_pcs.shape,
)
input_chair_pcs.shape, output_chair_pcs.shape

(torch.Size([400, 2048, 3]), torch.Size([400, 2048, 3]))

In [13]:
# Third positional argument of compute_all_metrics; presumably the internal
# batch size it uses for pairwise distances — confirm in metrics.evaluation_metrics.
METRICS_BATCH_SIZE = 100

# NOTE(review): PointFlow/PVD-style metric code usually has the signature
# compute_all_metrics(sample_pcs, ref_pcs, batch_size), but here the ShapeNet
# reference set is passed first. MMD/COV are generally not symmetric in their
# two inputs, so confirm the intended argument order.
results = compute_all_metrics(input_chair_pcs, output_chair_pcs, METRICS_BATCH_SIZE)

# Convert tensor-valued metrics to plain Python numbers. Tensor.item() works on
# CUDA tensors directly. Non-tensor values (e.g. floats) are kept unchanged.
results = {k: v.item() if torch.is_tensor(v) else v for k, v in results.items()}

print(results)

100%|██████████| 400/400 [07:44<00:00,  1.17s/it]
100%|██████████| 400/400 [07:44<00:00,  1.17s/it]
100%|██████████| 400/400 [07:43<00:00,  1.08s/it]

{'lgan_mmd-CD': 0.09327861666679382, 'lgan_cov-CD': 0.5174999833106995, 'lgan_mmd_smp-CD': 0.10668836534023285, 'lgan_mmd-EMD': 0.5605520009994507, 'lgan_cov-EMD': 0.5400000214576721, 'lgan_mmd_smp-EMD': 0.6099603176116943, '1-NN-CD-acc_t': 0.6200000047683716, '1-NN-CD-acc_f': 0.5475000143051147, '1-NN-CD-acc': 0.5837500095367432, '1-NN-EMD-acc_t': 0.5525000095367432, '1-NN-EMD-acc_f': 0.5674999952316284, '1-NN-EMD-acc': 0.5600000023841858}





In [None]:
# Snapshot of the chair metrics printed by the metrics cell above
# (400 ShapeNet val clouds vs 400 PVD clouds). It is kept as a literal so the
# numbers are available without re-running the expensive metric computation.
chair = {
    "lgan_mmd-CD": 0.09327861666679382,
    "lgan_cov-CD": 0.5174999833106995,
    "lgan_mmd_smp-CD": 0.10668836534023285,
    "lgan_mmd-EMD": 0.5605520009994507,
    "lgan_cov-EMD": 0.5400000214576721,
    "lgan_mmd_smp-EMD": 0.6099603176116943,
    "1-NN-CD-acc_t": 0.6200000047683716,
    "1-NN-CD-acc_f": 0.5475000143051147,
    "1-NN-CD-acc": 0.5837500095367432,
    "1-NN-EMD-acc_t": 0.5525000095367432,
    "1-NN-EMD-acc_f": 0.5674999952316284,
    "1-NN-EMD-acc": 0.5600000023841858,
}
# Headline 1-NNA accuracies as percentages rounded to 2 decimals. They are
# derived from the dict so they cannot drift from it.
# 1-NNA-CD: 58.38%
# 1-NNA-EMD: 56.00%
chair_1nna_pct = {
    name: round(100 * chair[key], 2)
    for name, key in (("1-NNA-CD", "1-NN-CD-acc"), ("1-NNA-EMD", "1-NN-EMD-acc"))
}