In [33]:
import numpy as np
from network import Network
from torch.utils.data import DataLoader
from config import get_config, print_usage
from data_utils.ShapeNetLoader import ShapeNetLoader

In [34]:
# Load the project configuration; `unparsed` holds whatever get_config()
# returned as its second value.
config, unparsed = get_config()

# Notebook-specific overrides applied on top of the loaded configuration.
config.mode = "test"
config.cat_id = 2  # NOTE(review): presumably the ShapeNet category index - confirm against ShapeNetLoader
config.indim = 3
config.log_dir = "plane_dim3"

In [35]:
# Build the test data loader and the model.
from models.acne_ae import AcneAe
import torch

# Normals are requested from the loader only when the input feature is "ppf"
# or the pose code is "weighted_qt".
require_normal = (
    config.input_feat in ("ppf",)
    or config.pose_code in ("weighted_qt",)
)

# Test-split dataset, served in order and loaded in the main process.
dataset = ShapeNetLoader(
    data_dump_folder=config.data_dump_folder,
    indim=config.indim,
    freeze_data=False,
    id=config.cat_id,
    require_normal=require_normal,
    num_pts=config.num_pts,
    mode="test",
    jitter_type=config.pc_jitter_type,
)
data_loader_te = DataLoader(
    dataset=dataset,
    batch_size=config.test_batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
)

# Fix every random number generator used below so repeated runs match.
torch.backends.cudnn.deterministic = True
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)
np.random.seed(1234)

# NOTE(review): eval() executes whatever string is stored in config.model.
# That is only safe while the configuration comes from a trusted source;
# an explicit name -> class mapping would be safer.
model = eval(config.model)(config)

if config.use_cuda:
    model.cuda()

cn_type: acn_b-10
num_head: 10
bn type: bn
cn_type: acn_b-10
num_head: 10
bn type: bn
cn_type: acn_b-10
num_head: 10
bn type: bn
cn_type: acn_b-10
num_head: 10
bn type: bn
cn_type: acn_b-10
num_head: 10
bn type: bn
cn_type: acn_b-10
num_head: 10
bn type: bn
cn_type: acn_b-10
num_head: 10
bn type: bn
num K: 10
cn_type: acn_b-10
num_head: 10
bn type: bn
cn_type: acn_b-10
num_head: 10
bn type: bn
cn_type: acn_b-10
num_head: 10
bn type: bn
cn_type: acn_b-10
num_head: 10
bn type: bn
cn_type: acn_b-10
num_head: 10
bn type: bn
cn_type: acn_b-10
num_head: 10
bn type: bn
cn_type: acn_b-10
num_head: 10
bn type: bn
num K: 10


In [36]:
# See data input
import open3d as o3d

def convert_to_o3dpcd(pts, color=None):
    """Wrap a set of points in an Open3D point cloud.

    Args:
        pts: point coordinates, passed unchanged to
            ``o3d.utility.Vector3dVector``.
        color: optional colour handed to ``paint_uniform_color``; when
            ``None`` the cloud is left unpainted.

    Returns:
        The newly created ``o3d.geometry.PointCloud``.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(pts)
    if color is None:
        return cloud
    cloud.paint_uniform_color(color)
    return cloud
    

In [42]:
# Import the KITTI incomplete point clouds and preview one of them.
import glob

# glob returns paths in arbitrary (filesystem-dependent) order; sorting makes
# an index such as `kidx` refer to the same file on every run.
kitti_pcds = sorted(glob.glob('/canonical_capsules/data_dump/kitti/cars/*'))
if not kitti_pcds:
    # Fail with an actionable message instead of a bare IndexError below.
    raise FileNotFoundError(
        "No KITTI point clouds found under /canonical_capsules/data_dump/kitti/cars/")

kidx = 0
o3dpc = o3d.io.read_point_cloud(kitti_pcds[kidx])
# Centre the points on the cloud's centre. `np.float` was removed in
# NumPy 1.24; it was an alias of the builtin float, i.e. float64.
pc_pts = np.asarray(o3dpc.points, dtype=np.float64) - o3dpc.get_center()
# The preview shows the cloud as read from disk (not the centred copy).
o3d.visualization.draw_geometries([o3dpc])

In [45]:
# Reconstruct KITTI point clouds with the trained model and show each input
# (grey) next to its reconstruction (red).
# Adapted from network.test and model.forward_test.

from tqdm import tqdm
from geom_torch import trans_pc_random, procruste_pose, get_rots
import os
# NOTE(review): `evaluate_pose` is not imported by name in this cell, so it
# presumably comes from one of these star imports - confirm and import it
# explicitly.
from loss_util import *
from vis_util import *

# Run inference on the device the model was placed on when it was built.
device = torch.device("cuda" if config.use_cuda else "cpu")

model.eval()

# Read checkpoint file. map_location lets a checkpoint saved on GPU be
# restored on a CPU-only machine as well.
load_res = torch.load(config.pt_file, map_location=device)

# Resume iterations
iter_idx = load_res["iter_idx"]
# Resume model (strict=False tolerates missing/unexpected keys).
model.load_state_dict(load_res["model"], strict=False)

prefix = "testing"
# Bookkeeping names kept from the original cell; nothing below reads them.
oas = []
vis_idx = [1, 2, 3]
accs = {}
feats = {}

jdx = 0
vis_every = 10  # visualise one KITTI cloud every `vis_every` loader steps
# model.aligner = None

# NOTE(review): the loader batches are never consumed; the loader only paces
# the loop and bounds how many KITTI clouds are visited. Iterating over
# `kitti_pcds` directly would avoid loading them at all.
for data in tqdm(data_loader_te, desc=prefix):
    jdx += 1
    if jdx % vis_every != 0:
        continue
    if jdx >= len(kitti_pcds):
        # No KITTI cloud left for this index; stop instead of raising IndexError.
        break

    with torch.no_grad():
        # Read the KITTI cloud and centre it on its own centre.
        # `np.float` was removed in NumPy 1.24; float64 is what it aliased.
        o3dpc = o3d.io.read_point_cloud(kitti_pcds[jdx])
        pc_pts = np.asarray(o3dpc.points, dtype=np.float64) - o3dpc.get_center()
        # Add a leading batch axis of size 1 and move to the model's device.
        pc_pts = torch.from_numpy(pc_pts[np.newaxis, ...])
        pc = pc_pts.to(device).float()

        rt = None
        att_aligner = None
        if model.aligner is not None:
            model.aligner.eval()
            pc, ret_aligner = model.aligner.forward_align(pc, rt=rt)
            att_aligner = ret_aligner["att_aligner"]

        x = pc.transpose(2, 1)

        # encoding x
        input_feat = x[..., None]
        gc_att = model.encoder(input_feat, att_aligner=att_aligner, return_att=True) # BCK1, B1GN1
        gc, att = gc_att

        # Evaluating pose
        pose_local = evaluate_pose(x, att)
        kps = pose_local.squeeze(-1)

        # w/o canonicalized descriptor
        if config.pose_block == "procruste":
            if model.ref_kp_type != "None":
                # Fixed: this read `self.ref_kp_type`, but there is no `self`
                # at notebook scope, so the branch raised NameError.
                if model.ref_kp_type.startswith("mlp"):
                    kps_ref = model.ref_kp_net(gc.reshape(gc.shape[0], -1))
                    kps_ref = kps_ref - kps_ref.mean(dim=2, keepdim=True)
                else:
                    raise NotImplementedError
                R_can, T_can = procruste_pose(kps, kps_ref, std_noise=0) # kps_ref = R * kpsi  + T
                kps = torch.matmul(R_can, kps) + T_can
            else:
                R_can = None

        # reconstruction from canonical capsules
        gc = torch.cat([kps[..., None], gc], dim=1)
        y = model.decoder(gc.transpose(2, 1).squeeze(-1))

        print(f'Visualising data #{jdx}')
        input_pcd = convert_to_o3dpcd(np.array(pc[0].cpu()), color=[0.85,0.85,0.85])
        out_pcd = convert_to_o3dpcd(np.array(y[0].cpu()), color=[1,0,0])
#         mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.6, origin=[0,0,0])
        o3d.visualization.draw_geometries([input_pcd, out_pcd])

testing:   0%|                                              | 0/703 [00:00<?, ?it/s]

Visualising data #10


testing:   1%|▌                                    | 10/703 [00:26<30:26,  2.64s/it]

Visualising data #20


testing:   3%|█                                    | 20/703 [00:39<25:27,  2.24s/it]

Visualising data #30


testing:   4%|█▌                                   | 30/703 [00:49<20:56,  1.87s/it]

Visualising data #40


testing:   6%|██                                   | 39/703 [01:11<20:25,  1.85s/it]


KeyboardInterrupt: 

# Archive

In [163]:
# Taken from vis_single: every selected sample is pushed through the model
# under each rotation returned by get_rots, and the reconstruction is shown
# once per rotation (presumably so the frames can be assembled into a gif).

from tqdm import tqdm
from geom_torch import trans_pc_random, procruste_pose, get_rots
import os
# NOTE(review): `evaluate_pose` is not imported by name in this cell, so it
# presumably comes from one of these star imports - confirm and import it
# explicitly.
from loss_util import *
from vis_util import *

model.eval()

# Read checkpoint file. map_location lets a checkpoint saved on GPU be
# restored on a CPU-only machine as well.
load_res = torch.load(
    config.pt_file,
    map_location=torch.device("cuda" if config.use_cuda else "cpu"))

# Resume iterations
iter_idx = load_res["iter_idx"]
# Resume model (strict=False tolerates missing/unexpected keys).
model.load_state_dict(load_res["model"], strict=False)

prefix = "vis"
dump_res = []
# 1-based positions (in loader order) of the samples to visualise.
if config.vis_idx > 0:
    vis_idx = [config.vis_idx]
else:
    vis_idx = list(range(1, 4))

# model.aligner = None

# `idx` counts loader batches and is used ONLY for that. The original cell
# also reused `idx` as the rotation counter and as a batch index inside the
# loop, which clobbered this counter and broke the `vis_idx` filter.
idx = 0
for data in tqdm(data_loader_te, desc=prefix):
    idx += 1
    if idx > max(vis_idx):
        # Every requested sample has been shown; skip the remaining batches.
        break
    if idx not in vis_idx:
        continue

    # move tensor into cuda
    if config.use_cuda:
        for key in data.keys():
            data[key] = data[key].cuda()

    # Leftovers of the function signature this cell was copied from.
    in_dict = {"data": data, "mode": "vis", "iter_idx": iter_idx, "writer": None}
    mode = in_dict["mode"]
    writer = in_dict["writer"]

    with torch.no_grad():
        # rotate input and show the decomposition and reconstruction
        pc = data["pc"]

        assert pc.shape[2] == config.indim
        x_can = pc.transpose(2, 1)
        Rs = get_rots(config.indim)

        for R in Rs:
            x = torch.matmul(
                torch.from_numpy(R).to(x_can.device), x_can)

            att_aligner = None
            if model.aligner is not None:
                model.aligner.eval()
                x_, ret_aligner = model.aligner.forward_align(x.transpose(2, 1), mode="vis")
                att_aligner = ret_aligner["att_aligner"]

                # First cloud of the batch before alignment (kept from the
                # original cell; nothing below reads it).
                x_pts = x[0].transpose(1, 0).cpu().numpy()

                x = x_.transpose(2, 1)

            input_feat = x[..., None]
            gc_att = model.encoder(input_feat, att_aligner=att_aligner, return_att=True) # BCK1, B1GN1
            gc, att = gc_att

            # Evaluating pose
            pose_local = evaluate_pose(x, att)
            kps = pose_local.squeeze(-1)
            gc = torch.cat([kps[..., None], gc], dim=1)
            pc_recons = model.decoder(
                gc.transpose(2, 1).squeeze(-1), return_splits=True)
            y = torch.cat(pc_recons, dim=2).transpose(2, 1)
            loss_chamfer = model.chamfer_loss(x.transpose(2, 1), y)
            print(f"R: {R}; loss: {loss_chamfer}")

            # Gather the per-split reconstructions of the first cloud in the
            # batch; label_map records which split each point came from.
            label_map = []
            pts = []
            for i, patch in enumerate(pc_recons):
                pts_cur = patch[0].transpose(1, 0).cpu().numpy()
                label_map += [np.ones(len(pts_cur)) * i]
                pts += [pts_cur]

            recon_pts = np.concatenate(pts, axis=0)

            # NOTE(review): the grey cloud is the un-rotated loader sample,
            # while the reconstruction comes from the rotated (and, if an
            # aligner is present, aligned) input - confirm this is intended.
            input_pcd = convert_to_o3dpcd(np.array(data["pc"][0].cpu()), color=[0.85,0.85,0.85])
            out_pcd = convert_to_o3dpcd(recon_pts)
            mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.6, origin=[0,0,0])
            o3d.visualization.draw_geometries([mesh_frame, input_pcd, out_pcd])

vis:   0%|                                                  | 0/809 [00:00<?, ?it/s]

R: [[[1. 0. 0.]
  [0. 1. 0.]
  [0. 0. 1.]]]; loss: 0.0019144623074680567
R: [[[ 0.9045085  -0.29389262  0.309017  ]
  [ 0.38471043  0.875      -0.29389262]
  [-0.18401699  0.38471043  0.9045085 ]]]; loss: 0.0018811644986271858
R: [[[ 0.6545085  -0.47552827  0.58778524]
  [ 0.7550368   0.4514337  -0.47552827]
  [-0.03921894  0.7550368   0.6545085 ]]]; loss: 0.001962190493941307
R: [[[ 0.3454915  -0.47552827  0.809017  ]
  [ 0.8602387  -0.18401699 -0.47552827]
  [ 0.375       0.8602387   0.3454915 ]]]; loss: 0.002017565770074725
R: [[[ 0.09549151 -0.29389262  0.95105654]
  [ 0.57340115 -0.7647472  -0.29389262]
  [ 0.81369066  0.57340115  0.09549151]]]; loss: 0.00177321198862046
R: [[[ 3.7493994e-33 -6.1232343e-17  1.0000000e+00]
  [ 1.2246469e-16 -1.0000000e+00 -6.1232343e-17]
  [ 1.0000000e+00  1.2246469e-16  3.7493994e-33]]]; loss: 0.001931279432028532
R: [[[ 0.09549151  0.29389262  0.95105654]
  [-0.57340115 -0.7647472   0.29389262]
  [ 0.81369066 -0.57340115  0.09549151]]]; loss: 0.0

vis:   0%|                                                  | 0/809 [00:28<?, ?it/s]


KeyboardInterrupt: 