In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm
%matplotlib notebook 
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
# Render and save figures at 80 dpi.
mpl.rcParams.update({'savefig.dpi': 80, 'figure.dpi': 80})


from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)

from src.mesh.cube import NetCube, to_vertices
from argparse import Namespace


# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Target vertex positions for the MSE fitting stage, flattened to a (N, 3)
# vertex list with `to_vertices`.
fitted = to_vertices(torch.load('./data/centered_32_32.pt')).to(device)

# Load the target mesh.
trg_obj = './data/scenes/centered/meshes/centered.obj'

# load_obj returns (verts, faces, aux). verts is a FloatTensor of shape (V, 3);
# faces holds the LongTensors verts_idx, normals_idx and textures_idx.
# Normals and textures are ignored here.
verts, faces, aux = load_obj(trg_obj)
faces_idx = faces.verts_idx.to(device)
verts = verts.to(device)

# NOTE(review): the PyTorch3D deform tutorial centers the target and rescales it to
# the unit sphere at this point. That step is deliberately skipped here, presumably
# because the 'centered' asset is already prepared and must stay in the same frame
# as `fitted`. Confirm before re-enabling it.

# Construct a Meshes structure for the target mesh (its tensors are already on `device`).
trg_mesh = Meshes(verts=[verts], faces=[faces_idx])
trg_mesh

<pytorch3d.structures.meshes.Meshes at 0x7f8db1d0a750>

In [8]:
# Generator hyper-parameters consumed by NetCube.
opt = Namespace(
    n=82,
    nfc=64,
    min_nfc=64,
    ker_size=3,
    num_layer=5,
    stride=1,
    padd_size=0,
    nc_im=3,
)

# NOTE(review): `n` (32) is what is passed to NetCube while opt.n is 82.
# Confirm which of the two NetCube actually reads.
n = 32
cube = NetCube(n, opt, kernel=7, sigma=2).to(device)

# cube() yields per-vertex offsets and the face index tensor.
v, f = cube()
print([tensor.shape for tensor in (v, f)])

# The source mesh is the undeformed starting shape; offsets are applied later.
start = cube.get_start().to(device)
src_mesh = Meshes(verts=[start], faces=[f])
src_mesh

[torch.Size([6144, 3]), torch.Size([12284, 3])]


<pytorch3d.structures.meshes.Meshes at 0x7f8d279738d0>

In [9]:
# Offsets are added vertex-wise to the start shape, so the two must line up exactly.
start_shape, offset_shape = cube.get_start().shape, v.shape
assert start_shape == offset_shape, (start_shape, offset_shape)
start_shape, offset_shape

(torch.Size([6144, 3]), torch.Size([6144, 3]))

In [10]:
# Sanity check: sample 5000 surface points from the target mesh and from the
# source mesh deformed by the current offsets `v`. The points are drawn at random,
# so the displayed values differ between runs. Statement order also fixes the
# order in which the RNG is consumed.
sample_trg = sample_points_from_meshes(trg_mesh, 5000)
new_src_mesh = src_mesh.offset_verts(v)
sample_src = sample_points_from_meshes(new_src_mesh, 5000)
sample_trg, sample_src

(tensor([[[-0.7598,  0.2454, -0.1011],
          [ 0.4518,  0.3388, -0.5362],
          [-0.4601,  0.0573, -0.4956],
          ...,
          [-0.2578,  0.9072, -0.4133],
          [ 0.2198, -0.4969, -0.5312],
          [ 0.6108, -0.6324,  0.1658]]], device='cuda:0'),
 tensor([[[-0.7310,  0.4013,  0.0498],
          [ 0.1120, -0.7909,  0.2986],
          [-0.3991,  0.4612, -0.6683],
          ...,
          [-0.4673,  0.4728, -0.8830],
          [-0.3563,  0.0799, -0.0345],
          [-0.2652, -0.2571,  1.1032]]], device='cuda:0',
        grad_fn=<IndexPutBackward0>))

In [11]:
# Adam over every NetCube parameter (learning rate 3e-4).
optimizer = torch.optim.Adam(params=cube.parameters(), lr=3e-4)
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0003
    weight_decay: 0
)

In [12]:
# Stage 1: regress the network's per-vertex offsets so that start + offsets
# matches the target vertex positions `fitted` under an MSE loss.
Niter = 5000  # number of optimization steps

loop = tqdm(range(Niter))

for i in loop:
    optimizer.zero_grad()

    # cube() returns (offsets, faces); only the offsets are used here.
    v, _ = cube()

    # Use a dedicated name so the target-mesh `verts` loaded earlier is not overwritten.
    pred_verts = v + start
    loss = F.mse_loss(pred_verts, fitted)

    # .item() converts explicitly instead of %-formatting a grad-tracking tensor.
    loop.set_description('total_loss = %.6f' % loss.item())

    loss.backward()
    optimizer.step()

  0%|          | 0/5000 [00:00<?, ?it/s]

In [13]:
# Save only the inner generator's weights (`cube.net`), not the whole NetCube module.
generator_state = cube.net.state_dict()
torch.save(generator_state, './data/net_fitted_n32_c64.pth')

In [14]:
import meshplot

# Absolute vertex positions = start shape + current offsets.
vert = cube.get_start().to(device) + v
vertices_np = vert.detach().cpu().numpy()
faces_np = f.cpu().numpy()
meshplot.plot(vertices_np, faces_np)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(-0.052400…

<meshplot.Viewer.Viewer at 0x7f8d27973510>

In [10]:
# Stage 2: deform the source mesh toward the target surface using a chamfer loss
# plus mesh regularizers (edge length, normal consistency, Laplacian smoothing).
Niter = 1000  # number of optimization steps
n_samples = 5000  # surface points sampled from each mesh per step
w_chamfer = 1.0  # weight for the chamfer loss
w_edge = 1.0  # weight for the mesh edge loss
w_normal = 0.01  # weight for mesh normal consistency
w_laplacian = 0.1  # weight for mesh Laplacian smoothing
plot_period = 250  # NOTE(review): currently unused; mesh plotting is not wired up here
loop = tqdm(range(Niter))

chamfer_losses = []
laplacian_losses = []
edge_losses = []
normal_losses = []

for i in loop:
    optimizer.zero_grad()

    # Deform the source mesh by the network's per-vertex offsets.
    v, _ = cube()
    new_src_mesh = src_mesh.offset_verts(v)

    # Fresh surface samples each step, so the chamfer term is a stochastic estimate.
    sample_trg = sample_points_from_meshes(trg_mesh, n_samples)
    sample_src = sample_points_from_meshes(new_src_mesh, n_samples)

    loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
    loss_edge = mesh_edge_loss(new_src_mesh)
    loss_normal = mesh_normal_consistency(new_src_mesh)
    loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")

    loss = (
        loss_chamfer * w_chamfer
        + loss_edge * w_edge
        + loss_normal * w_normal
        + loss_laplacian * w_laplacian
    )
    loop.set_description('total_loss = %.6f' % loss.item())

    # Record every loss term so each history can be inspected or plotted.
    chamfer_losses.append(loss_chamfer.item())
    edge_losses.append(loss_edge.item())
    normal_losses.append(loss_normal.item())
    laplacian_losses.append(loss_laplacian.item())

    loss.backward()
    optimizer.step()

chamfer_losses

  0%|          | 0/1000 [00:00<?, ?it/s]

  self._edges_packed = torch.stack([u // V, u % V], dim=1)


[0.08580422401428223,
 0.11131545156240463,
 0.11299756169319153,
 0.09337344020605087,
 0.10640192031860352,
 0.06829623878002167,
 0.09871694445610046,
 0.07358501106500626,
 0.08694301545619965,
 0.06982751935720444,
 0.07825282961130142,
 0.06410867720842361,
 0.043873079121112823,
 0.06362723559141159,
 0.029885444790124893,
 0.03428409993648529,
 0.031234625726938248,
 0.02693854086101055,
 0.025703150779008865,
 0.01998584158718586,
 0.01962932199239731,
 0.02028483711183071,
 0.018270572647452354,
 0.017216715961694717,
 0.016716625541448593,
 0.014117913320660591,
 0.012886292301118374,
 0.011267879977822304,
 0.012165969237685204,
 0.010828897356987,
 0.009888475760817528,
 0.009792844764888287,
 0.009122336283326149,
 0.010015508159995079,
 0.0093916617333889,
 0.009002873674035072,
 0.00825746264308691,
 0.007931019179522991,
 0.007783900480717421,
 0.007359719835221767,
 0.006989621557295322,
 0.006963356398046017,
 0.0070494553074240685,
 0.006698760204017162,
 0.00654947

In [11]:
import meshplot

# Absolute vertex positions after refinement = start shape + latest offsets.
vert = start + v
vertices_np = vert.detach().cpu().numpy()
faces_np = f.cpu().numpy()
meshplot.plot(vertices_np, faces_np)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0020848…

<meshplot.Viewer.Viewer at 0x7fcc24ce5d90>

In [28]:
# Persist the final absolute vertex positions. Detach first so the file holds plain
# data and the reloaded tensor does not come back flagged requires_grad=True.
torch.save(vert.detach().cpu(), './data/sampled_32_32.pt')

In [None]:
def to_vertices(stacked):
    """Flatten a stacked per-face tensor into a vertex list.

    The channel axis (dim 1) is moved last and every other axis is collapsed,
    so row ``f*H*W + h*W + w`` holds ``stacked[f, :, h, w]``:
    ``(F, 3, H, W) -> (F*H*W, 3)``.

    NOTE(review): this redefines the ``to_vertices`` imported from
    ``src.mesh.cube`` and shadows it once executed.
    """
    channels_last = stacked.permute(0, 2, 3, 1)
    return channels_last.reshape(-1, 3)

In [20]:
# Round-trip check, step 1: flatten a random stacked tensor into a vertex list.
s = torch.randn(6, 3, 8, 8)
as_vs = to_vertices(s)
as_vs.size()

torch.Size([384, 3])

In [23]:
import math

# Round-trip check, step 2: undo `to_vertices` by reshaping the (6*f_n*f_n, 3)
# vertex list back to (6, 3, f_n, f_n). The 6 is presumably the cube faces.
f_n = int(round(math.sqrt(as_vs.size(0) // 6)))
assert 6 * f_n * f_n == as_vs.size(0), "vertex count is not 6 * f_n**2"
as_st = as_vs.reshape(6, f_n, f_n, 3).permute(0, 3, 1, 2)
as_st.shape

torch.Size([6, 3, 8, 8])

In [24]:
# Round-trip check, step 3: flatten then re-stack must reproduce `s` exactly.
round_trip_ok = torch.equal(s, as_st)
assert round_trip_ok, "re-stacking did not invert to_vertices"
round_trip_ok

True

In [38]:
# Save the state of the whole NetCube module.
# NOTE(review): the other checkpoint cell saves only `cube.net.state_dict()`, so its
# keys lack the `net.` prefix used here. Confirm which layout the loader expects.
cube_state = cube.state_dict()
torch.save(cube_state, './data/net_fitted_32.pth')

In [28]:
# Display the generator network wrapped inside NetCube.
cube.net

Generator(
  (head): ConvBlock(
    (conv): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1))
    (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (body): Sequential(
    (block1): ConvBlock(
      (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
      (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (block2): ConvBlock(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (block3): ConvBlock(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(n

In [None]:
import math


def to_stacked(vertices, n_faces=6):
    """Inverse of ``to_vertices``: re-stack a vertex list into per-face grids.

    Args:
        vertices: 2-D tensor of shape ``(n_faces * n * n, C)``, laid out the way
            ``to_vertices`` produces it (row ``f*n*n + h*n + w`` is grid cell
            ``(f, h, w)``).
        n_faces: number of stacked grids (6 by default).

    Returns:
        Tensor of shape ``(n_faces, C, n, n)``.

    Raises:
        ValueError: if ``vertices`` is not 2-D, or its row count is not
            ``n_faces`` times a perfect square.
    """
    if vertices.dim() != 2:
        raise ValueError(f"expected a 2-D vertex tensor, got shape {tuple(vertices.shape)}")
    n_rows, n_channels = vertices.shape
    # Round rather than truncate so float error in sqrt cannot shrink the side by one.
    side = int(round(math.sqrt(n_rows // n_faces)))
    if n_faces * side * side != n_rows:
        raise ValueError(f"{n_rows} rows cannot be split into {n_faces} square grids")
    return vertices.reshape(n_faces, side, side, n_channels).permute(0, 3, 1, 2)