In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from trainer.RetrievalTrainer import trainer
from datasets.ScannetDataset import ScannetDataset
from datasets.CategoryTestTimeDataset import *
from datasets.Reader import *

from test_time.FeatureExtractor import FeatureExtractor
from test_time.RetrievalModule import RetrievalModule

from utils.logger import logger
from utils.ckpts import load_checkpoint, save_checkpoint
import utils.transform_estimation as te
from utils.eval import *
from utils.retrieval import *
from utils.visualize import *
from utils.preprocess import random_rotation, load_norm_pc, apply_transform
from utils.symmetry import *
from utils.eval_pose import *
from utils.Info.Scan2cadInfo import Scan2cadInfo
from utils.Info.CADLib import *
from utils.read_json import build_pcd

from model import load_model, fc

from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import KMeans
import matplotlib
import time

import open3d as o3d
import MinkowskiEngine as ME

torch.manual_seed(0)
torch.cuda.manual_seed(0)
        
class Config():
    def __init__(self, category):

        self.root = "/scannet/ShapeNetCore.v2.PC15k"
        self.scan2cad_root = "/scannet/crop_scan2cad_filter/data"
        self.cad_root = "/scannet/ShapeNetCore.v2.PC15k"
        if category == "chair":
            self.catid = "03001627"
        elif category == "table":
            self.catid = "04379243"
        else:
            self.catid = ""
        self.voxel_size = 0.03
        if category == "chair":
            self.resume = "./ckpts/cat_pose_id_01_FCGF16"
        elif category == "table":
            self.resume = "./ckpts/cat_table_pose_id_01_FCGF16"
        else:
            self.resume = ""
        
        self.split = "val"
        self.voxel_size = 0.03
        self.dim = [1024, 512,  256]
        self.embedding = "identity"
        self.model = "ResUNetBN2C"
        self.model_n_out = 16
        self.normalize_feature = True
        self.conv1_kernel_size = 3
        self.bn_momentum = 0.05
        self.nn_max_n = 500
        self.lib_path = "./CadLib"
        self.scan2cad_dict = "/scannet/scan2cad_download_link/unique_cads.csv"
        self.annotation_dir = "/scannet/scan2cad_download_link"
        self.device = torch.device("cuda")


ModuleNotFoundError: No module named 'utils.transform_estimation'

In [2]:

config = Config("chair")

logger = logger("./logs", "debug_shapenet_eval.txt")
logger.log(config.catid)
logger.log(config.resume)
logger.log(config.split)

dataset = CategoryDataset(config.cad_root, config.split, config.catid, "/scannet/tables", 
                                                     0.1, 0.5, config.voxel_size)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, 
                                        shuffle=False, num_workers=2, collate_fn=dataset.collate_pair_fn)

num_feats = 1
Model = load_model("ResUNetBN2C")
model = Model(
    num_feats,
    16,
    bn_momentum=0.05,
    normalize_feature=True,
    conv1_kernel_size=3,
    D=3)
model = model.to(torch.device("cuda"))


# Embedding network for retrieval
embedding = fc.identity().to("cuda")
distance = "chamfer"


checkpoint = torch.load(config.resume)

model.load_state_dict(checkpoint["state_dict"])

embedding.load_state_dict(checkpoint["embedding_state_dict"])

logger.log(checkpoint["epoch"])

model.eval()
embedding.eval()

np.random.seed(31)
torch.random.manual_seed(31)

2021-03-21 03:03:31 03001627


  0%|          | 0/662 [00:00<?, ?it/s]

2021-03-21 03:03:31 ./ckpts/cat_pose_id_01_FCGF16
2021-03-21 03:03:31 val


100%|██████████| 662/662 [00:01<00:00, 632.92it/s]


593
592
2021-03-21 03:03:35 99


<torch._C.Generator at 0x7fa0200e97d0>

In [None]:
########################################################################################


logger.log("start eval")
np.random.seed(31)
torch.random.manual_seed(31)
model.eval()
embedding.eval()

base_outputs = [];pos_outputs = []
base_origins = [];pos_origins = []

glob_feats = []

base_Ts = [];pos_Ts = [];
base_syms = [];pos_syms = []
        
        
num_data = len(dataset)

with torch.no_grad():

    for idx, data in enumerate(loader):
        logger.log("Eval feature Index: {}/{}".format(idx+1, len(loader)))

        base_input = ME.SparseTensor(data["base_feat"], data["base_coords"]).to("cuda")
        pos_input = ME.SparseTensor(data["pos_feat"], data["pos_coords"]).to("cuda")

        base_Ts.append(data["base_T"])
        pos_Ts.append(data["pos_T"])

        base_syms.append(data["base_sym"])
        pos_syms.append(data["pos_sym"])

        batch_size = len(data["base_idx"])

        base_output, base_feat = model(base_input)

        pos_output, pos_feat = model(pos_input)

        base_feat = embedding(base_feat)

        for i in range(batch_size):
            base_mask = base_output.C[:,0]==i
            pos_mask = pos_output.C[:,0]==i

            base_outputs.append(base_output.F[base_mask, :])
            pos_outputs.append(pos_output.F[pos_mask, :])
            base_origins.append(data["base_origin"][base_mask, :])
            pos_origins.append(data["pos_origin"][pos_mask, :])


        if distance == "l2":
            base_feat_norm = torch.nn.functional.normalize(base_feat, dim=1) 
            glob_feats.append(base_feat_norm)

        else:
            base_feat_norm = torch.nn.functional.normalize(base_feat.F, dim=1) 

            for i in range(batch_size):
                feat_mask = base_feat.C[:,0]==i
                glob_feats.append(base_feat_norm[feat_mask, :])  


    base_Ts = torch.cat(base_Ts, 0)
    pos_Ts = torch.cat(pos_Ts, 0)
    base_syms = torch.cat(base_syms, 0)
    pos_syms = torch.cat(pos_syms, 0)
    


k_nn=5
max_corr = 0.20

with torch.no_grad():
    t_losses_ransac = []
    r_losses_ransac = []

    t_losses_sym = []
    r_losses_sym = []
    
    chamf_ransac = []
    chamf_sym = []
    
    for idx in range(len(dataset)):
        logger.log("----------------------")
        if (idx+1) % 10 == 0:
            
            logger.log("Eval align Index: {}/{}  ".format(idx+1, len(dataset))) 
                            
            logger.log("RANSAC: T error: {:.4f} R error: {:.4f}".format(sum(t_losses_ransac)/len(t_losses_ransac), 
                                                           sum(r_losses_ransac)/len(r_losses_ransac)))
            
            logger.log("SYM: T error: {:.4f} R error: {:.4f}".format(sum(t_losses_sym)/len(t_losses_sym), 
                                                           sum(r_losses_sym)/len(r_losses_sym)))
            logger.log("----------------------")

        # matching based pose estimation
        xyz0, xyz1, T0, T1 = base_origins[idx], pos_origins[idx], base_Ts[idx,:,:], pos_Ts[idx,:,:]
        baseF, posF = base_outputs[idx], pos_outputs[idx]

        base_sym = base_syms[idx].item()
        pos_sym = pos_syms[idx].item()
        
        T_est_best, chamf_dist_best, T_est_ransac, chamf_dist_ransac = sym_pose(baseF, xyz0, posF, 
                                                                                xyz1, pos_sym, k_nn, max_corr)
        
        t_loss_sym, r_loss_sym = eval_pose(T_est_best, T0, T1, axis_symmetry=max(pos_sym, base_sym))

        t_loss_ransac, r_loss_ransac = eval_pose(T_est_ransac, T0, T1, axis_symmetry=max(pos_sym, base_sym))
        
        t_losses_ransac.append(t_loss_ransac)
        r_losses_ransac.append(r_loss_ransac)
        chamf_ransac.append(chamf_dist_ransac)

        logger.log("ransac: t: {:.4f} r: {:.4f} chamf: {:.4f}".format(t_loss_ransac, 
                                                                 r_loss_ransac, 
                                                                 chamf_dist_ransac ))
        
        logger.log("ransac avg: t: {:.4f} r: {:.4f} chamf: {:.4f}".format(sum(t_losses_ransac)/len(t_losses_ransac), 
                                                                     sum(r_losses_ransac)/len(r_losses_ransac),
                                                                     sum(chamf_ransac)/len(chamf_ransac)))
        
            
        t_losses_sym.append(t_loss_sym)
        r_losses_sym.append(r_loss_sym)
        chamf_sym.append(chamf_dist_best)

        logger.log("sym: t: {:.4f} r: {:.4f} chamf: {:.4f}".format(t_loss_sym, 
                                                              r_loss_sym, 
                                                              chamf_dist_best))
        
        logger.log("sym avg: t: {:.4f} r: {:.4f} chamf: {:.4f}".format(sum(t_losses_sym)/len(t_losses_sym), 
                                                                  sum(r_losses_sym)/len(r_losses_sym),
                                                                  sum(chamf_sym)/len(chamf_sym)))
        logger.log("----------------------")

        

2021-03-21 03:03:37 start eval
2021-03-21 03:03:45 Eval feature Index: 1/19
2021-03-21 03:03:51 Eval feature Index: 2/19
2021-03-21 03:03:59 Eval feature Index: 3/19
2021-03-21 03:04:04 Eval feature Index: 4/19
2021-03-21 03:04:13 Eval feature Index: 5/19
2021-03-21 03:04:18 Eval feature Index: 6/19
2021-03-21 03:04:27 Eval feature Index: 7/19
2021-03-21 03:04:32 Eval feature Index: 8/19
2021-03-21 03:04:41 Eval feature Index: 9/19
2021-03-21 03:04:46 Eval feature Index: 10/19
2021-03-21 03:04:54 Eval feature Index: 11/19
2021-03-21 03:04:59 Eval feature Index: 12/19
2021-03-21 03:05:08 Eval feature Index: 13/19
2021-03-21 03:05:13 Eval feature Index: 14/19
2021-03-21 03:05:22 Eval feature Index: 15/19
2021-03-21 03:05:27 Eval feature Index: 16/19
2021-03-21 03:05:36 Eval feature Index: 17/19
2021-03-21 03:05:41 Eval feature Index: 18/19
2021-03-21 03:05:45 Eval feature Index: 19/19
2021-03-21 03:05:47 ----------------------
2021-03-21 03:05:54 ransac: t: 0.1250 r: 0.0966 chamf: 0.0413

2021-03-21 03:08:44 ransac: t: 0.1113 r: 0.1282 chamf: 0.0576
2021-03-21 03:08:44 ransac avg: t: 0.0669 r: 0.0959 chamf: 0.0529
2021-03-21 03:08:44 sym: t: 0.0355 r: 0.1174 chamf: 0.0356
2021-03-21 03:08:44 sym avg: t: 0.0472 r: 0.0700 chamf: 0.0488
2021-03-21 03:08:44 ----------------------
2021-03-21 03:08:44 ----------------------
2021-03-21 03:08:54 ransac: t: 0.0358 r: 0.1602 chamf: 0.0630
2021-03-21 03:08:54 ransac avg: t: 0.0655 r: 0.0987 chamf: 0.0534
2021-03-21 03:08:54 sym: t: 0.0410 r: 0.1278 chamf: 0.0558
2021-03-21 03:08:54 sym avg: t: 0.0469 r: 0.0726 chamf: 0.0491
2021-03-21 03:08:54 ----------------------
2021-03-21 03:08:54 ----------------------
2021-03-21 03:09:01 ransac: t: 0.0343 r: 0.0567 chamf: 0.0357
2021-03-21 03:09:01 ransac avg: t: 0.0642 r: 0.0970 chamf: 0.0526
2021-03-21 03:09:01 sym: t: 0.0300 r: 0.0085 chamf: 0.0352
2021-03-21 03:09:01 sym avg: t: 0.0462 r: 0.0699 chamf: 0.0485
2021-03-21 03:09:01 ----------------------
2021-03-21 03:09:01 ---------------

2021-03-21 03:12:00 ransac: t: 0.0587 r: 0.0531 chamf: 0.0542
2021-03-21 03:12:00 ransac avg: t: 0.0661 r: 0.1095 chamf: 0.0535
2021-03-21 03:12:00 sym: t: 0.0184 r: 0.0582 chamf: 0.0506
2021-03-21 03:12:00 sym avg: t: 0.0518 r: 0.0859 chamf: 0.0499
2021-03-21 03:12:00 ----------------------
2021-03-21 03:12:00 ----------------------
2021-03-21 03:12:09 ransac: t: 0.0153 r: 0.0753 chamf: 0.0595
2021-03-21 03:12:09 ransac avg: t: 0.0650 r: 0.1088 chamf: 0.0536
2021-03-21 03:12:09 sym: t: 0.0153 r: 0.0753 chamf: 0.0595
2021-03-21 03:12:09 sym avg: t: 0.0510 r: 0.0856 chamf: 0.0501
2021-03-21 03:12:09 ----------------------
2021-03-21 03:12:09 ----------------------
2021-03-21 03:12:18 ransac: t: 0.0916 r: 0.1144 chamf: 0.0577
2021-03-21 03:12:18 ransac avg: t: 0.0656 r: 0.1089 chamf: 0.0537
2021-03-21 03:12:18 sym: t: 0.0505 r: 0.0168 chamf: 0.0551
2021-03-21 03:12:18 sym avg: t: 0.0510 r: 0.0842 chamf: 0.0502
2021-03-21 03:12:18 ----------------------
2021-03-21 03:12:18 ---------------

2021-03-21 03:15:06 ransac: t: 0.2205 r: 0.2202 chamf: 0.0710
2021-03-21 03:15:06 ransac avg: t: 0.0671 r: 0.1083 chamf: 0.0532
2021-03-21 03:15:06 sym: t: 0.1363 r: 0.1444 chamf: 0.0696
2021-03-21 03:15:06 sym avg: t: 0.0533 r: 0.0849 chamf: 0.0504
2021-03-21 03:15:06 ----------------------
2021-03-21 03:15:06 ----------------------
2021-03-21 03:15:13 ransac: t: 0.0921 r: 0.1601 chamf: 0.0504
2021-03-21 03:15:13 ransac avg: t: 0.0674 r: 0.1090 chamf: 0.0532
2021-03-21 03:15:13 sym: t: 0.0557 r: 0.0633 chamf: 0.0474
2021-03-21 03:15:13 sym avg: t: 0.0533 r: 0.0846 chamf: 0.0504
2021-03-21 03:15:13 ----------------------
2021-03-21 03:15:13 ----------------------
2021-03-21 03:15:21 ransac: t: 1.3530 r: 3.1309 chamf: 0.0952
2021-03-21 03:15:21 ransac avg: t: 0.0853 r: 0.1510 chamf: 0.0538
2021-03-21 03:15:21 sym: t: 0.0829 r: 0.0838 chamf: 0.0603
2021-03-21 03:15:21 sym avg: t: 0.0537 r: 0.0846 chamf: 0.0505
2021-03-21 03:15:21 ----------------------
2021-03-21 03:15:21 ---------------

2021-03-21 03:18:04 ransac: t: 0.0966 r: 0.1308 chamf: 0.0500
2021-03-21 03:18:04 ransac avg: t: 0.0837 r: 0.1461 chamf: 0.0536
2021-03-21 03:18:04 sym: t: 0.0966 r: 0.1308 chamf: 0.0500
2021-03-21 03:18:04 sym avg: t: 0.0581 r: 0.0869 chamf: 0.0503
2021-03-21 03:18:04 ----------------------
2021-03-21 03:18:04 ----------------------
2021-03-21 03:18:11 ransac: t: 0.0730 r: 0.0527 chamf: 0.0612
2021-03-21 03:18:11 ransac avg: t: 0.0836 r: 0.1451 chamf: 0.0537
2021-03-21 03:18:11 sym: t: 0.0730 r: 0.0527 chamf: 0.0612
2021-03-21 03:18:11 sym avg: t: 0.0582 r: 0.0866 chamf: 0.0504
2021-03-21 03:18:11 ----------------------
2021-03-21 03:18:11 ----------------------
2021-03-21 03:18:18 ransac: t: 0.5751 r: 3.0462 chamf: 0.0522
2021-03-21 03:18:18 ransac avg: t: 0.0887 r: 0.1754 chamf: 0.0537
2021-03-21 03:18:18 sym: t: 0.5751 r: 3.0462 chamf: 0.0522
2021-03-21 03:18:18 sym avg: t: 0.0636 r: 0.1174 chamf: 0.0504
2021-03-21 03:18:18 ----------------------
2021-03-21 03:18:18 ---------------

2021-03-21 03:21:10 ransac: t: 0.0274 r: 0.0436 chamf: 0.0590
2021-03-21 03:21:10 ransac avg: t: 0.0827 r: 0.1590 chamf: 0.0523
2021-03-21 03:21:10 sym: t: 0.0274 r: 0.0436 chamf: 0.0590
2021-03-21 03:21:10 sym avg: t: 0.0616 r: 0.1086 chamf: 0.0494
2021-03-21 03:21:10 ----------------------
2021-03-21 03:21:10 ----------------------
2021-03-21 03:21:18 ransac: t: 0.0583 r: 0.0784 chamf: 0.0605
2021-03-21 03:21:18 ransac avg: t: 0.0825 r: 0.1584 chamf: 0.0524
2021-03-21 03:21:18 sym: t: 0.0378 r: 0.0448 chamf: 0.0549
2021-03-21 03:21:18 sym avg: t: 0.0614 r: 0.1081 chamf: 0.0494
2021-03-21 03:21:18 ----------------------
2021-03-21 03:21:18 ----------------------
2021-03-21 03:21:18 Eval align Index: 120/592  
2021-03-21 03:21:18 RANSAC: T error: 0.0825 R error: 0.1584
2021-03-21 03:21:18 SYM: T error: 0.0614 R error: 0.1081
2021-03-21 03:21:18 ----------------------
2021-03-21 03:21:25 ransac: t: 0.0829 r: 0.0768 chamf: 0.0317
2021-03-21 03:21:25 ransac avg: t: 0.0825 r: 0.1577 chamf:

2021-03-21 03:24:12 ransac: t: 0.0311 r: 0.0914 chamf: 0.0468
2021-03-21 03:24:12 ransac avg: t: 0.0882 r: 0.1723 chamf: 0.0525
2021-03-21 03:24:12 sym: t: 0.0311 r: 0.0914 chamf: 0.0468
2021-03-21 03:24:12 sym avg: t: 0.0631 r: 0.1095 chamf: 0.0496
2021-03-21 03:24:12 ----------------------
2021-03-21 03:24:12 ----------------------
2021-03-21 03:24:20 ransac: t: 0.0168 r: 0.1375 chamf: 0.0595
2021-03-21 03:24:20 ransac avg: t: 0.0877 r: 0.1720 chamf: 0.0525
2021-03-21 03:24:20 sym: t: 0.0168 r: 0.1375 chamf: 0.0595
2021-03-21 03:24:20 sym avg: t: 0.0628 r: 0.1097 chamf: 0.0496
2021-03-21 03:24:20 ----------------------
2021-03-21 03:24:20 ----------------------
2021-03-21 03:24:27 ransac: t: 0.0815 r: 0.1386 chamf: 0.0362
2021-03-21 03:24:27 ransac avg: t: 0.0876 r: 0.1718 chamf: 0.0524
2021-03-21 03:24:27 sym: t: 0.0248 r: 0.0213 chamf: 0.0263
2021-03-21 03:24:27 sym avg: t: 0.0625 r: 0.1091 chamf: 0.0495
2021-03-21 03:24:27 ----------------------
2021-03-21 03:24:27 ---------------

2021-03-21 03:27:16 ransac: t: 0.0503 r: 0.0751 chamf: 0.0718
2021-03-21 03:27:16 ransac avg: t: 0.0870 r: 0.1774 chamf: 0.0526
2021-03-21 03:27:16 sym: t: 0.0503 r: 0.0751 chamf: 0.0718
2021-03-21 03:27:16 sym avg: t: 0.0638 r: 0.1222 chamf: 0.0496
2021-03-21 03:27:16 ----------------------
2021-03-21 03:27:16 ----------------------
2021-03-21 03:27:23 ransac: t: 0.0605 r: 0.0488 chamf: 0.0261
2021-03-21 03:27:23 ransac avg: t: 0.0869 r: 0.1766 chamf: 0.0525
2021-03-21 03:27:23 sym: t: 0.0221 r: 0.0551 chamf: 0.0231
2021-03-21 03:27:23 sym avg: t: 0.0636 r: 0.1218 chamf: 0.0494
2021-03-21 03:27:23 ----------------------
2021-03-21 03:27:23 ----------------------
2021-03-21 03:27:30 ransac: t: 0.0367 r: 0.0579 chamf: 0.0332
2021-03-21 03:27:30 ransac avg: t: 0.0866 r: 0.1759 chamf: 0.0524
2021-03-21 03:27:30 sym: t: 0.0367 r: 0.0579 chamf: 0.0332
2021-03-21 03:27:30 sym avg: t: 0.0634 r: 0.1214 chamf: 0.0493
2021-03-21 03:27:30 ----------------------
2021-03-21 03:27:30 ---------------

2021-03-21 03:30:20 ransac: t: 0.1065 r: 0.1053 chamf: 0.0624
2021-03-21 03:30:20 ransac avg: t: 0.0902 r: 0.1824 chamf: 0.0526
2021-03-21 03:30:20 sym: t: 0.1554 r: 0.1515 chamf: 0.0614
2021-03-21 03:30:20 sym avg: t: 0.0618 r: 0.1155 chamf: 0.0495
2021-03-21 03:30:20 ----------------------
2021-03-21 03:30:20 ----------------------
2021-03-21 03:30:20 Eval align Index: 190/592  
2021-03-21 03:30:20 RANSAC: T error: 0.0902 R error: 0.1824
2021-03-21 03:30:20 SYM: T error: 0.0618 R error: 0.1155
2021-03-21 03:30:20 ----------------------
2021-03-21 03:30:28 ransac: t: 0.0273 r: 0.0352 chamf: 0.0296
2021-03-21 03:30:28 ransac avg: t: 0.0898 r: 0.1816 chamf: 0.0525
2021-03-21 03:30:28 sym: t: 0.0438 r: 0.0217 chamf: 0.0292
2021-03-21 03:30:28 sym avg: t: 0.0617 r: 0.1150 chamf: 0.0494
2021-03-21 03:30:28 ----------------------
2021-03-21 03:30:28 ----------------------
2021-03-21 03:30:34 ransac: t: 0.0577 r: 0.1025 chamf: 0.0485
2021-03-21 03:30:34 ransac avg: t: 0.0897 r: 0.1812 chamf:

In [None]:
"""              
np.save("/zty-vol/results/stats/shapenet-{}-T-ransac.npy".format(config.catid), np.array(t_losses_ransac))
np.save("/zty-vol/results/stats/shapenet-{}-R-ransac.npy".format(config.catid), np.array(r_losses_ransac))

np.save("/zty-vol/results/stats/shapenet-{}-T-sym.npy".format(config.catid), np.array(t_losses_sym))
np.save("/zty-vol/results/stats/shapenet-{}-R-sym.npy".format(config.catid), np.array(r_losses_sym))


np.save("/zty-vol/results/stats/shapenet-{}-chamf-ransac.npy".format(config.catid), np.array(chamf_ransac))
np.save("/zty-vol/results/stats/shapenet-{}-chamf-sym.npy".format(config.catid), np.array(chamf_sym))
"""  

logger.log("ransac avg: t: {:.4f} r: {:.4f} chamf: {:.4f}".format(sum(t_losses_ransac)/len(t_losses_ransac), 
                                                             sum(r_losses_ransac)/len(r_losses_ransac),
                                                             sum(chamf_ransac)/len(chamf_ransac)))

logger.log("sym avg: t: {:.4f} r: {:.4f} chamf: {:.4f}".format(sum(t_losses_sym)/len(t_losses_sym), 
                                                                  sum(r_losses_sym)/len(r_losses_sym),
                                                                  sum(chamf_sym)/len(chamf_sym)))