In [15]:
import os
import open3d as o3d

import torch
import merger.merger_net as merger_net
from merger.merger_net import Net
import json
import tqdm
import numpy as np
import argparse

# Parser object kept from the original command-line script; this cell only
# constructs it, the settings are assigned directly below.
arg_parser = argparse.ArgumentParser(
    description=(
        "Predictor for Skeleton Merger on KeypointNet dataset. "
        "Outputs a npz file with two arrays: "
        "kpcd - (N, k, 3) xyz coordinates of keypoints detected; "
        "nfact - (N, 2) normalization factor, or max and min coordinate values in a point cloud."
    ),
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Root directory that is searched recursively for '*processed.obj' meshes.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
pcd_path = '/home/luhr/correspondence/SkeletonMerger_garment/cloth3d/train'
# Checkpoint file; its 'model_state_dict' entry is loaded into the network.
checkpoint_path = '20.pth'
# Device string handed to torch for the model and the input batches.
device = 'cuda'
# Second argument of Net(...).
# NOTE(review): the load_state_dict error recorded in this notebook reports
# checkpoint tensors sized for 10 keypoints (e.g. PT_L.weight is 10x10) while
# the model built with this value expects 20 - confirm which one is intended.
n_keypoint = 20
# Number of point clouds per forward pass.
batch = 4
# First argument of Net(...) and the size every cloud is down-sampled to.
max_points = 10000

In [10]:
def load_cloth_mesh(path):
    """Load .obj of cloth mesh. Only quad-mesh is acceptable!

    Records are split on arbitrary whitespace, so trailing blanks, repeated
    spaces and tab separators are accepted. Only 'v' (vertex) and 'f' (face)
    records are read; every other line is ignored. For face corners written
    as 'v/vt/vn' only the vertex index (first field) is used, converted from
    the 1-based .obj numbering to 0-based.

    Args:
        path: path of the .obj file to read.
    Return:
        - vertices: ndarray, (N, 3)
        - triangle_faces: ndarray, (S, 3)
        - stretch_edges: ndarray, (M1, 2)
        - bend_edges: ndarray, (M2, 2)
        - shear_edges: ndarray, (M3, 2)
        Face/edge arrays keep their second dimension even when they are empty
        (e.g. bend_edges of a single quad has shape (0, 2)).
    Raises:
        AssertionError: if a face does not have exactly four corners.
        ValueError: if a coordinate or index cannot be parsed as a number.
    This function was written by Zhenjia Xu
    email: xuzhenjia [at] cs (dot) columbia (dot) edu
    website: https://www.zhenjiaxu.com/
    """
    from itertools import combinations

    vertices, faces = [], []
    with open(path, 'r') as f:
        for line in f:
            # str.split() without arguments drops empty tokens and the
            # trailing newline, which the former split(' ') choked on.
            tokens = line.split()
            if not tokens:
                continue
            # 3D vertex
            if tokens[0] == 'v':
                vertices.append([float(n) for n in tokens[1:]])
            # Face
            elif tokens[0] == 'f':
                face = [int(n.split('/')[0]) - 1 for n in tokens[1:]]
                assert len(face) == 4, \
                    'only quad faces are supported, got %d corners' % len(face)
                faces.append(face)

    # Each quad (a, b, c, d) is split along the a-c diagonal.
    triangle_faces = []
    for a, b, c, d in faces:
        triangle_faces.append([a, b, c])
        triangle_faces.append([a, c, d])

    stretch_edges, shear_edges, bend_edges = set(), set(), set()

    # Stretch (quad sides) & Shear (quad diagonals); edges are stored as
    # sorted index pairs so that each undirected edge appears once.
    for a, b, c, d in faces:
        for pair in ((a, b), (b, c), (c, d), (d, a)):
            stretch_edges.add(tuple(sorted(pair)))
        for pair in ((a, c), (b, d)):
            shear_edges.add(tuple(sorted(pair)))

    # Bend: connect every two stretch-neighbours of a vertex, unless that
    # pair is already a shear edge.
    neighbours = {vid: set() for vid in range(len(vertices))}
    for first, second in stretch_edges:
        neighbours[first].add(second)
        neighbours[second].add(first)
    for vid in range(len(vertices)):
        for pair in combinations(list(neighbours[vid]), 2):
            bend_edge = tuple(sorted(pair))
            if bend_edge not in shear_edges:
                bend_edges.add(bend_edge)

    def _index_array(rows, width):
        # reshape keeps the (0, width) shape for an empty collection.
        return np.array(list(rows), dtype=int).reshape(-1, width)

    return np.array(vertices), _index_array(triangle_faces, 3),\
        _index_array(stretch_edges, 2), _index_array(bend_edges, 2),\
        _index_array(shear_edges, 2)




def prepare_data(path):
    """Collect the cloth meshes below ``path`` as network-ready point clouds.

    Every file whose name ends with 'processed.obj' is loaded with
    load_cloth_mesh and its vertices are used as a point cloud. Clouds with
    more than ``max_points`` vertices (module-level setting) are randomly
    down-sampled without replacement to exactly ``max_points``; clouds with
    exactly ``max_points`` vertices are used unchanged.

    Meshes with fewer than ``max_points`` vertices are skipped entirely.
    Earlier they were still appended to ``ori_data`` but not to ``data`` and
    ``paths``, so the i-th entries of the three lists no longer described
    the same mesh.

    Directories and files are visited in sorted order so that repeated runs
    yield the meshes in the same sequence. The random down-sampling itself is
    only reproducible if the caller seeds ``np.random``.

    Args:
        path: root directory that is walked recursively.
    Returns:
        (ori_data, data, paths) - three lists of equal length, aligned by
        index: the full vertex array, the (down-sampled) vertex array with
        ``max_points`` rows, and the path of the source .obj file.
    """
    data = []
    paths = []
    ori_data = []
    for root, dirs, files in os.walk(path):
        # In-place sort steers the order in which os.walk descends.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith('processed.obj'):
                continue
            file_path = os.path.join(root, file)
            # Only the vertices are needed here; faces and edges are dropped.
            points = np.array(load_cloth_mesh(file_path)[0])
            if len(points) < max_points:
                # Too small to be sampled without replacement; skipped so
                # that the three lists stay aligned.
                continue
            sampled = points
            if len(points) > max_points:
                # downsample to max_points points
                keep = np.random.choice(len(points), max_points, replace=False)
                sampled = points[keep]
            ori_data.append(points)
            data.append(sampled)
            paths.append(file_path)
    return ori_data, data, paths

def find_nearest_point(points,kp):
    """Return, for each keypoint, the index of the closest point.

    Args:
        points: array-like of point coordinates, one point per row.
        kp: array-like of keypoint coordinates, one keypoint per row, with
            the same number of columns as ``points``.
    Returns:
        Integer ndarray with one entry per keypoint: the row index in
        ``points`` that has the smallest Euclidean distance to that keypoint.
    """
    queries = np.array(kp)
    cloud = np.array(points)
    # Broadcast to a (num_points, num_keypoints, dims) block of offsets.
    offsets = queries[np.newaxis, :, :] - cloud[:, np.newaxis, :]
    distances = np.linalg.norm(offsets, axis=2)
    # Minimise over the point axis -> one winner per keypoint column.
    return distances.argmin(axis=0)

def visualize_pointcloud(pc,kp):
    """Show a point cloud in an Open3D window with selected points highlighted.

    Args:
        pc: point coordinates, one point per row.
        kp: index (or indices) into ``pc`` of the points to highlight.

    Every point gets the colour (0, 0, 0); the rows selected by ``kp`` are
    set to (1, 0, 0) instead. The call to draw_geometries opens the viewer.
    """
    # Per-point colour table, same shape as the coordinates.
    point_colors = np.zeros_like(pc)
    point_colors[kp] = np.array([1, 0, 0])

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(pc)
    cloud.colors = o3d.utility.Vector3dVector(point_colors)
    o3d.visualization.draw_geometries([cloud])

In [16]:
# Build the network for the configured cloud size / keypoint count and move
# it to the target device.
net = Net(max_points, n_keypoint).to(device)
# Restore the weights stored under 'model_state_dict' in the checkpoint.
# load_state_dict is strict by default, so the checkpoint must have been
# trained with the same n_keypoint as configured above.
net.load_state_dict(
    torch.load(
        checkpoint_path,
        map_location=torch.device(device),
    )['model_state_dict']
)
# Inference only: switch to evaluation mode.
net.eval()

RuntimeError: Error(s) in loading state_dict for Net:
	Missing key(s) in state_dict: "DEC.10.0.emb", "DEC.10.1.emb", "DEC.10.2.emb", "DEC.10.3.emb", "DEC.10.4.emb", "DEC.10.5.emb", "DEC.10.6.emb", "DEC.10.7.emb", "DEC.10.8.emb", "DEC.10.9.emb", "DEC.11.0.emb", "DEC.11.1.emb", "DEC.11.2.emb", "DEC.11.3.emb", "DEC.11.4.emb", "DEC.11.5.emb", "DEC.11.6.emb", "DEC.11.7.emb", "DEC.11.8.emb", "DEC.11.9.emb", "DEC.11.10.emb", "DEC.12.0.emb", "DEC.12.1.emb", "DEC.12.2.emb", "DEC.12.3.emb", "DEC.12.4.emb", "DEC.12.5.emb", "DEC.12.6.emb", "DEC.12.7.emb", "DEC.12.8.emb", "DEC.12.9.emb", "DEC.12.10.emb", "DEC.12.11.emb", "DEC.13.0.emb", "DEC.13.1.emb", "DEC.13.2.emb", "DEC.13.3.emb", "DEC.13.4.emb", "DEC.13.5.emb", "DEC.13.6.emb", "DEC.13.7.emb", "DEC.13.8.emb", "DEC.13.9.emb", "DEC.13.10.emb", "DEC.13.11.emb", "DEC.13.12.emb", "DEC.14.0.emb", "DEC.14.1.emb", "DEC.14.2.emb", "DEC.14.3.emb", "DEC.14.4.emb", "DEC.14.5.emb", "DEC.14.6.emb", "DEC.14.7.emb", "DEC.14.8.emb", "DEC.14.9.emb", "DEC.14.10.emb", "DEC.14.11.emb", "DEC.14.12.emb", "DEC.14.13.emb", "DEC.15.0.emb", "DEC.15.1.emb", "DEC.15.2.emb", "DEC.15.3.emb", "DEC.15.4.emb", "DEC.15.5.emb", "DEC.15.6.emb", "DEC.15.7.emb", "DEC.15.8.emb", "DEC.15.9.emb", "DEC.15.10.emb", "DEC.15.11.emb", "DEC.15.12.emb", "DEC.15.13.emb", "DEC.15.14.emb", "DEC.16.0.emb", "DEC.16.1.emb", "DEC.16.2.emb", "DEC.16.3.emb", "DEC.16.4.emb", "DEC.16.5.emb", "DEC.16.6.emb", "DEC.16.7.emb", "DEC.16.8.emb", "DEC.16.9.emb", "DEC.16.10.emb", "DEC.16.11.emb", "DEC.16.12.emb", "DEC.16.13.emb", "DEC.16.14.emb", "DEC.16.15.emb", "DEC.17.0.emb", "DEC.17.1.emb", "DEC.17.2.emb", "DEC.17.3.emb", "DEC.17.4.emb", "DEC.17.5.emb", "DEC.17.6.emb", "DEC.17.7.emb", "DEC.17.8.emb", "DEC.17.9.emb", "DEC.17.10.emb", "DEC.17.11.emb", "DEC.17.12.emb", "DEC.17.13.emb", "DEC.17.14.emb", "DEC.17.15.emb", "DEC.17.16.emb", "DEC.18.0.emb", "DEC.18.1.emb", "DEC.18.2.emb", "DEC.18.3.emb", "DEC.18.4.emb", "DEC.18.5.emb", "DEC.18.6.emb", "DEC.18.7.emb", "DEC.18.8.emb", "DEC.18.9.emb", "DEC.18.10.emb", "DEC.18.11.emb", "DEC.18.12.emb", 
"DEC.18.13.emb", "DEC.18.14.emb", "DEC.18.15.emb", "DEC.18.16.emb", "DEC.18.17.emb", "DEC.19.0.emb", "DEC.19.1.emb", "DEC.19.2.emb", "DEC.19.3.emb", "DEC.19.4.emb", "DEC.19.5.emb", "DEC.19.6.emb", "DEC.19.7.emb", "DEC.19.8.emb", "DEC.19.9.emb", "DEC.19.10.emb", "DEC.19.11.emb", "DEC.19.12.emb", "DEC.19.13.emb", "DEC.19.14.emb", "DEC.19.15.emb", "DEC.19.16.emb", "DEC.19.17.emb", "DEC.19.18.emb". 
	size mismatch for MA_EMB: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([190]).
	size mismatch for PTW.conv2.weight: copying a param with shape torch.Size([10, 128, 1]) from checkpoint, the shape in current model is torch.Size([20, 128, 1]).
	size mismatch for PTW.conv2.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([20]).
	size mismatch for PT_L.weight: copying a param with shape torch.Size([10, 10]) from checkpoint, the shape in current model is torch.Size([20, 20]).
	size mismatch for PT_L.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([20]).
	size mismatch for MA_L.weight: copying a param with shape torch.Size([45, 256]) from checkpoint, the shape in current model is torch.Size([190, 256]).
	size mismatch for MA_L.bias: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([190]).

In [9]:

# Full-resolution clouds, their network inputs and the source file paths.
ori_data, kpn_ds, paths = prepare_data(pcd_path)

In [11]:
# Run the network over the prepared clouds in mini-batches and keep one
# keypoint array per cloud.
out_kpcd = []
for i in tqdm.tqdm(range(0, len(kpn_ds), batch), unit_scale=batch):
    # Slicing clamps at the end of the list, so the last batch may be short.
    Q = list(kpn_ds[i:i + batch])
    n_real = len(Q)
    if n_real == 1:
        # A lone sample is duplicated so the forward pass sees two items
        # (kept from the original code; presumably the net cannot handle a
        # batch of one - TODO confirm).
        Q.append(Q[-1])
    with torch.no_grad():
        recon, key_points, kpa, emb, null_activation = net(torch.Tensor(np.array(Q)).to(device))
    # Drop the prediction for the padding duplicate. Keeping it made out_kpcd
    # one entry longer than ori_data/paths and broke the save loop below.
    for kp in key_points[:n_real]:
        # Move each result to host memory right away instead of holding
        # every batch output on the device until the end.
        out_kpcd.append(kp.cpu().numpy())

# The three sequences are consumed position by position; refuse to write
# files if they do not describe the same meshes.
if not (len(out_kpcd) == len(ori_data) == len(paths)):
    raise RuntimeError(
        "misaligned results: %d keypoint sets, %d original clouds, %d paths"
        % (len(out_kpcd), len(ori_data), len(paths))
    )

for mesh_path, full_cloud, keypoints in zip(paths, ori_data, out_kpcd):
    # Snap every predicted keypoint to the closest vertex of the original
    # (not down-sampled) mesh.
    kp_id = find_nearest_point(full_cloud, keypoints)
    # visualize_pointcloud(full_cloud, kp_id)
    np.savez(mesh_path.replace('processed.obj','keypoints.npz'),keypoints=keypoints,keypoint_id=kp_id,pointcloud=full_cloud)

100%|██████████| 500/500 [00:36<00:00, 13.68it/s]


KeyboardInterrupt: 