In [1]:
import sys
sys.path.append( '../..' )
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'

import torch
import torch.nn as nn 
import torch.nn.functional as F
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance
import numpy as np
from energy import tangent_kernel

# Pick the first visible CUDA GPU when available, otherwise fall back to the CPU
# (and warn, since the optimization below is long).
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
if device.type == "cpu":
    print("WARNING: CPU only, this will be slow!")
import input_output 

In [2]:
# Select which source/target pair the notebook runs. The loading cell has one
# branch per entry of EXPERIMENTS; any other value would leave the meshes undefined.
EXPERIMENTS = ('cup', 'dolphin', 'hippo', 'bunny')
experiment_name = 'hippo'
assert experiment_name in EXPERIMENTS, \
    f"unknown experiment {experiment_name!r}; choose one of {EXPERIMENTS}"

In [3]:
# Read the source and target meshes for the selected experiment.
# Every branch defines verts1/faces1 (source) and verts2/faces2 (target) as
# tensors; the next cell moves them to `device` with `.to(device)`.
# NOTE(review): low/mid/high are not used by any visible cell — presumably
# resolution presets; confirm before removing.
low = 31
mid = 26
high = 20
if experiment_name == 'cup':
    # Case 1: two different cups, read with the project-local loader.
    VS, FS, FunS = input_output.loadData("../../data/matching/cup2.ply")
    VT, FT, FunT = input_output.loadData("../../data/matching/cup3_broken.ply")
    source = [VS, FS]
    target = [VT, FT]
    # original size
    print(VS.shape)
    print(VT.shape)
    # Decimate both meshes to get a coarse initialization for the multires
    # algorithm (factor=31/32, described by the author as "decimate by a factor of 32").
    # A fresh dict is built per call in case decimate_mesh modifies its argument.
    param_decimation = {'factor':31/32,'Vol_preser':1, 'Fun_Error_Metric': 1, 'Fun_weigth':0.00}
    verts1, faces1 = input_output.decimate_mesh(VS, FS, param_decimation)
    print(verts1.shape)
    param_decimation = {'factor':31/32,'Vol_preser':1, 'Fun_Error_Metric': 1, 'Fun_weigth':0.00}
    verts2, faces2 = input_output.decimate_mesh(VT, FT, param_decimation)
    print(verts2.shape)
    verts1 = torch.FloatTensor(verts1)
    verts2 = torch.FloatTensor(verts2)
    faces1 = torch.LongTensor(faces1)
    faces2 = torch.LongTensor(faces2)
    save_obj('../../results/cup2_mid.obj', verts1, faces1)
    save_obj('../../results/cup3_broken_mid.obj', verts2, faces2)
elif experiment_name == 'dolphin':
    # Case 2: level-4 icosphere (source) to dolphin (target).
    src_mesh = ico_sphere(4)
    # load_obj's faces object carries the triangle indices in `.verts_idx`;
    # its third return value is not used here.
    VT, FT, FunS = load_obj("../../data/matching/dolphin.obj")
    VS, FS = src_mesh.verts_packed(), src_mesh.faces_packed()
    verts1, faces1 = VS, FS
    verts2, faces2 = VT, FT.verts_idx
elif experiment_name == 'hippo':
    # Case 3: source and target (verts1, faces1, verts2, faces2) stored together.
    # NOTE(review): torch.load unpickles the file — only load trusted data.
    verts1, faces1, verts2, faces2 = torch.load('../../data/matching/hippos_red.pt')
elif experiment_name == 'bunny':
    # Case 4: level-4 icosphere (source) to bunny (target).
    src_mesh = ico_sphere(4)
    VT, FT, FunS = load_obj("../../data/matching/bunny.obj")
    VS, FS = src_mesh.verts_packed(), src_mesh.faces_packed()
    # Decimate source and target to compute an initialization for the multires algorithm.
    # NOTE(review): source uses factor=17/32 and target 31/32; under the cup case's
    # description, 31/32 is ~32x decimation — confirm the intended levels.
    param_decimation = {'factor':17/32,'Vol_preser':1, 'Fun_Error_Metric': 1, 'Fun_weigth':0.00}
    verts1, faces1 = input_output.decimate_mesh(VS.numpy(), FS.numpy(), param_decimation)
    print(verts1.shape)
    param_decimation = {'factor':31/32,'Vol_preser':1, 'Fun_Error_Metric': 1, 'Fun_weigth':0.00}
    verts2, faces2 = input_output.decimate_mesh(VT.numpy(), FT.verts_idx.numpy(), param_decimation)
    print(verts2.shape)
    verts1 = torch.FloatTensor(verts1)
    verts2 = torch.FloatTensor(verts2)
    faces1 = torch.LongTensor(faces1)
    faces2 = torch.LongTensor(faces2)
else:
    # Fail here rather than with a NameError on verts1 in the next cell.
    raise ValueError(
        f"Unknown experiment_name {experiment_name!r}; "
        "expected 'cup', 'dolphin', 'hippo' or 'bunny'")

In [4]:
# Move the loaded meshes to the compute device. faces1/faces2 are index tensors
# and verts1/verts2 vertex-position tensors, as produced by the loading cell.
faces_idx1 = faces1.to(device)
faces_idx2 = faces2.to(device)
verts1 = verts1.to(device)
verts2 = verts2.to(device)

# Center each mesh at the origin and divide by its largest absolute coordinate,
# so it fits in the cube [-1, 1]^3 (not a unit sphere).
# (scale2, center2) are used later to bring the predicted mesh into the
# target's original frame. Normalizing speeds up the optimization but is not
# strictly necessary.
center1 = verts1.mean(0)
center2 = verts2.mean(0)
verts1 = verts1 - center1
verts2 = verts2 - center2
# Global max of |coordinate| as one torch reduction (0-d tensor), instead of
# Python's builtin max iterating over a tensor.
scale1 = verts1.abs().max()
scale2 = verts2.abs().max()
verts1 = verts1 / scale1
verts2 = verts2 / scale2

# Meshes structures for the normalized source and target.
src_mesh = Meshes(verts=[verts1], faces=[faces_idx1])
trg_mesh = Meshes(verts=[verts2], faces=[faces_idx2])
print(verts1.shape)
print(verts2.shape)


torch.Size([1654, 3])
torch.Size([1654, 3])


In [5]:
class testnet(nn.Module):
    """Small MLP mapping 6 input features per row to a 3-D output per row.

    In this notebook each row is a vertex position concatenated with its
    normal, and the output is used as that vertex's offset.
    Architecture: Linear(6, 64) -> ReLU -> Linear(64, 128) -> ReLU -> Linear(128, 3),
    held in ``self.net`` (state_dict keys net.0 / net.2 / net.4).
    """

    def __init__(self):
        super().__init__()
        widths = [6, 64, 128, 3]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
        layers.pop()  # no activation after the output layer
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

In [6]:
# Optimizer setting.
# Seed torch's RNG so the MLP initialization (and hence the run) is reproducible.
SEED = 0
torch.manual_seed(SEED)
# Use the device selected in the first cell rather than hard-coding CUDA, so
# the notebook also runs (slowly) on CPU-only machines.
models = testnet().to(device)
optimizer = torch.optim.Adam(models.parameters(), lr=.001)
# Number of optimization steps
Niter = 200001

# loss parameters
# Weight for the chamfer loss
w_chamfer = 1.0
# Period (in iterations) at which the training loop prints the losses
plot_period = 1000

# NOTE(review): not appended to by the visible training loop.
chamfer_losses = []

In [7]:
def compute_engine(V1, V2, L1, L2, K, n_batch=10000):
    """Return ``sum((K(V1, V2) @ L2) * L1)``, evaluated in row chunks of V1.

    The full K(V1, V2) matrix is never built at once: V1 and the matching rows
    of L1 are processed ``n_batch`` rows at a time.

    Args:
        V1: first point set; sliced along dim 0.
        V2: second point set, passed whole to every kernel call.
        L1: per-row factors for V1 (2-D, sliced with the same row ranges as V1).
        L2: right-hand factor multiplied by each K(chunk, V2).
        K: callable kernel; K(a, b) must return a matrix with len(a) rows.
        n_batch: number of V1 rows per kernel call (default 10000).

    Returns:
        0-d tensor holding the total. For an empty V1 this is zero and K is
        not called.
    """
    n_rows = len(V1)
    if n_rows == 0:
        return V1.new_zeros(())
    chunks = []
    # Stepping by start offset means K is never called on an empty slice,
    # even when n_rows is an exact multiple of n_batch.
    for start in range(0, n_rows, n_batch):
        stop = start + n_batch
        chunks.append(torch.matmul(K(V1[start:stop, :], V2), L2) * L1[start:stop, :])
    return torch.sum(torch.cat(chunks, 0))

def CompCLNn(F, V):
    """Per-element centers, sizes and unit directions/normals of a curve or mesh.

    Args:
        F: element index tensor into V. If it has 2 columns, each row is a
           segment (i, j). Otherwise each row is a triangle (i, j, k).
        V: vertex positions, indexed along dim 0. Triangles require 3-D points
           (a cross product is taken).

    Returns:
        C: element centers (segment midpoints / triangle centroids).
        L: (M, 1) element sizes — segment length or triangle area.
        N / L: unit segment directions or unit triangle normals. Degenerate
           (zero-size) elements give NaN here because of the 0/0 division.
        1: constant placeholder in the 4th slot (a trailing code note names
           it `Fun_F`, presumably a per-face function value).
    """
    if F.shape[1] == 2:
        V0, V1 = V.index_select(0, F[:, 0]), V.index_select(0, F[:, 1])
        C, N = (V0 + V1) / 2, V1 - V0
    else:
        V0, V1, V2 = (V.index_select(0, F[:, k]) for k in range(3))
        C = (V0 + V1 + V2) / 3
        # dim=1 must be explicit: without it torch.cross uses the FIRST
        # dimension of size 3, i.e. dim 0 whenever there are exactly 3 faces,
        # which silently yields wrong normals (and is deprecated).
        N = 0.5 * torch.cross(V1 - V0, V2 - V0, dim=1)

    L = (N ** 2).sum(dim=1)[:, None].sqrt()
    return C, L, N / L, 1

# Face centers, areas and unit normals of the normalized source mesh
# (c, l, n are not read by any later visible cell).
c, l, n, _ = CompCLNn(F=faces_idx1, V=verts1)

In [8]:

# Train the deformation MLP. It maps each source vertex (position + normal) to
# an offset. The loss is the chamfer distance between deformed source and
# target, each described by face centers concatenated with unit face normals.
best = None               # detached offsets from the best iteration so far
best_loss = float('inf')  # inf sentinel: any finite first loss is an improvement
best_iter = 0

# Loop-invariant work, hoisted: src_mesh and the target are never modified in
# the loop (offset_verts builds a new mesh).
sv, sf = src_mesh.get_mesh_verts_faces(0)
sn = src_mesh.verts_normals_packed()
# Respect `device` instead of hard-coding .cuda() (crashes on CPU-only machines).
inputs = torch.cat([sv.to(device), sn.to(device)], 1)
c2, l2, n2, _ = CompCLNn(faces_idx2, verts2)
c2 = torch.cat([c2, n2], 1)

for i in range(Niter):
    optimizer.zero_grad()
    deform_verts = models(inputs)

    # Deform the mesh and describe it by face centers + unit normals.
    new_src_mesh = src_mesh.offset_verts(deform_verts)
    f1 = new_src_mesh.faces_packed()
    v1 = new_src_mesh.verts_packed()
    c1, l1, n1, _ = CompCLNn(f1, v1)
    c1 = torch.cat([c1, n1], 1)

    loss_chamfer, _ = chamfer_distance(c1.unsqueeze(0), c2.unsqueeze(0))

    # Weighted sum of the losses
    loss = w_chamfer * loss_chamfer

    current_loss = loss.item()
    if current_loss < best_loss:
        # Detach + clone so `best` does not keep this iteration's autograd graph alive.
        best = deform_verts.detach().clone()
        best_loss = current_loss
        best_iter = i
        # Saved before optimizer.step(): these are the weights that produced best_loss.
        torch.save(models.state_dict(), '../../results/chamfer_%s_n.pth' % experiment_name)

    # Print the losses
    if i % plot_period == 0:
        print('%d Iter: total_loss %.6f Chamfer_loss %.6f' % (i, current_loss, loss_chamfer.item()))
        print('current best loss is %d: %.6f' % (best_iter, best_loss))

    # Optimization step
    loss.backward()
    optimizer.step()

0 Iter: total_loss 0.153925 Chamfer_loss 0.153925
current best loss is 0: 0.153925
1000 Iter: total_loss 0.025799 Chamfer_loss 0.025799
current best loss is 997: 0.025590
2000 Iter: total_loss 0.023101 Chamfer_loss 0.023101
current best loss is 1995: 0.022932
3000 Iter: total_loss 0.021752 Chamfer_loss 0.021752
current best loss is 2988: 0.021454
4000 Iter: total_loss 0.021222 Chamfer_loss 0.021222
current best loss is 3891: 0.020798
5000 Iter: total_loss 0.020432 Chamfer_loss 0.020432
current best loss is 4939: 0.020292
6000 Iter: total_loss 0.019983 Chamfer_loss 0.019983
current best loss is 5994: 0.019860
7000 Iter: total_loss 0.019922 Chamfer_loss 0.019922
current best loss is 6838: 0.019542
8000 Iter: total_loss 0.019444 Chamfer_loss 0.019444
current best loss is 7701: 0.019353
9000 Iter: total_loss 0.019252 Chamfer_loss 0.019252
current best loss is 8956: 0.019029
10000 Iter: total_loss 0.018857 Chamfer_loss 0.018857
current best loss is 9858: 0.018843
11000 Iter: total_loss 0.01

In [9]:
# Build the final predicted mesh from the best offsets found during training.
if best is None:
    raise RuntimeError('No best offsets recorded; run the training cell with Niter > 0 first.')
# detach(): `best` may still carry an autograd graph; the exported mesh does not need it.
new_src_mesh = src_mesh.offset_verts(best.detach())
final_verts, final_faces = new_src_mesh.get_mesh_verts_faces(0)

# Undo the normalization with the TARGET's scale and center: the deformed
# source was fitted to the normalized target, so it is mapped into the target's frame.
final_verts = final_verts * scale2 + center2

# Store the predicted mesh using save_obj
save_obj('../../results/chamfer_%s_low.obj' % (experiment_name), final_verts, final_faces)
print('Done!')

Done!


In [10]:
# Vertex-level chamfer distance between the final mesh and the target vertices,
# measured in the normalized target frame and in double precision.
final_normalized = (final_verts.unsqueeze(0).double() - center2) / scale2
final_chamfer, _ = chamfer_distance(final_normalized, verts2.unsqueeze(0).double())
print('final Chamfer distance: %.6f' % (final_chamfer.detach().cpu().numpy()))


final Chamfer distance: 0.001417
