In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm
%matplotlib notebook 
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80


from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)

from src.mesh.cube import (
    NetCube, 
    to_vertices, 
    to_stacked,
)
from argparse import Namespace


device = torch.device('cuda')

In [2]:
fitted = to_vertices(torch.load('./data/centered_32_32.pt')).to(device)
# Load the dolphin mesh.
trg_obj = os.path.join('./data/scenes/centered/meshes/centered.obj')

# We read the target 3D model using load_obj
verts, faces, aux = load_obj(trg_obj)

# verts is a FloatTensor of shape (V, 3) where V is the number of vertices in the mesh
# faces is an object which contains the following LongTensors: verts_idx, normals_idx and textures_idx
# For this tutorial, normals and textures are ignored.
faces_idx = faces.verts_idx.to(device)
verts = verts.to(device)

# We scale normalize and center the target mesh to fit in a sphere of radius 1 centered at (0,0,0). 
# (scale, center) will be used to bring the predicted mesh to its original center and scale
# Note that normalizing the target mesh, speeds up the optimization but is not necessary!
# center = verts.mean(0)
# verts = verts - center
# scale = max(verts.abs().max(0)[0])
# verts = verts / scale

# We construct a Meshes structure for the target mesh
trg_mesh = Meshes(verts=[verts], faces=[faces_idx])#.to(device)
trg_mesh

<pytorch3d.structures.meshes.Meshes at 0x7f4b64038090>

In [3]:
st = to_stacked(fitted)
st = F.interpolate(st, size=16, mode='bilinear', align_corners=True)
print(st.shape)
fitted = to_vertices(st)
fitted.shape

torch.Size([6, 3, 16, 16])


torch.Size([1536, 3])

In [5]:
opt = Namespace(
    n = 82,
    nfc=128, 
    min_nfc=128, 
    ker_size=3,
    num_layer=5,
    stride=1,
    padd_size=0,
    nc_im=3,
)

n = 16
cube = NetCube(n, opt, kernel=7, sigma=2).to(device)

#radii = torch.rand(3, device=v_ref.device, requires_grad=True)
v, f = cube()
# n_ref = compute_face_normals(v_ref, f_ref)
print([f.shape for f in [v, f]])

start = cube.get_start().to(device)
src_mesh =  Meshes(verts=[start], faces=[f])
src_mesh

[torch.Size([1536, 3]), torch.Size([3068, 3])]


<pytorch3d.structures.meshes.Meshes at 0x7f4b3c4e8610>

In [6]:
cube.get_start().shape, v.shape

(torch.Size([1536, 3]), torch.Size([1536, 3]))

In [7]:
sample_trg = sample_points_from_meshes(trg_mesh, 5000)
new_src_mesh = src_mesh.offset_verts(v)
sample_src = sample_points_from_meshes(new_src_mesh, 5000)
sample_trg, sample_src

(tensor([[[-0.2805,  0.7066, -0.5080],
          [ 0.7037, -0.4157, -0.2628],
          [-0.8382, -0.5431,  0.0517],
          ...,
          [-0.7235, -0.4388,  0.2055],
          [-0.7066,  0.6313,  0.2232],
          [-0.0475, -0.3480,  0.4957]]], device='cuda:0'),
 tensor([[[ 0.4095,  0.0105,  0.3809],
          [-0.0312, -0.2941,  0.1147],
          [-0.2752, -0.2036, -0.2301],
          ...,
          [ 0.3063, -0.1639,  0.2615],
          [-0.1654, -0.3827, -0.2337],
          [ 0.4029,  0.3780,  0.6790]]], device='cuda:0',
        grad_fn=<IndexPutBackward0>))

In [8]:
#optimizer = torch.optim.SGD(cube.parameters(), lr=1.0, momentum=0.9)
optimizer = torch.optim.Adam(cube.parameters(), lr=0.0003)
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0003
    weight_decay: 0
)

In [9]:
# Number of optimization steps
Niter = 5000

loop = tqdm(range(Niter))


for i in loop:
    # Initialize optimizer
    optimizer.zero_grad()
    
    v, _ = cube()
    
    verts = v + start
    loss = F.mse_loss(verts, fitted)
    #loss = loss_chamfer
    # Print the losses
    loop.set_description('total_loss = %.6f' % loss)
    
    # Save the losses for plotting
#     chamfer_losses.append(float(loss_chamfer.detach().cpu()))
#     edge_losses.append(float(loss_edge.detach().cpu()))
#     normal_losses.append(float(loss_normal.detach().cpu()))
#     laplacian_losses.append(float(loss_laplacian.detach().cpu()))
    
    # Plot mesh
    #if i % plot_period == 0:
    #    plot_pointcloud(new_src_mesh, title="iter: %d" % i)
        
    # Optimization step
    loss.backward()
    optimizer.step()
    
#chamfer_losses    

  0%|          | 0/5000 [00:00<?, ?it/s]

In [11]:
torch.save(cube.net.state_dict(), './data/net_fitted_n128_c128.pth')

In [10]:
import meshplot

vert =  cube.get_start().to(device) + v
meshplot.plot(vert.detach().cpu().numpy(), f.cpu().numpy())

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0178797…

<meshplot.Viewer.Viewer at 0x7f4b3c795610>

In [12]:
#optimizer = torch.optim.SGD(cube.parameters(), lr=1.0, momentum=0.9)
optimizer = torch.optim.Adam(cube.parameters(), lr=0.0003)
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0003
    weight_decay: 0
)

In [13]:
# Number of optimization steps
Niter = 1000
# Weight for the chamfer loss
w_chamfer = 1.0 
# Weight for mesh edge loss
w_edge = 1.0 
# Weight for mesh normal consistency
w_normal = 0.01 
# Weight for mesh laplacian smoothing
w_laplacian = 0.1 
# Plot period for the losses
plot_period = 250
loop = tqdm(range(Niter))

chamfer_losses = []
laplacian_losses = []
edge_losses = []
normal_losses = []
no_samples = 100000


for i in loop:
    # Initialize optimizer
    optimizer.zero_grad()
    
    v, _ = cube()
    # Deform the mesh
    new_src_mesh = src_mesh.offset_verts(v)
    
    # We sample 5k points from the surface of each mesh 
    sample_trg = sample_points_from_meshes(trg_mesh, no_samples)
    sample_src = sample_points_from_meshes(new_src_mesh, no_samples)
    
    # We compare the two sets of pointclouds by computing (a) the chamfer loss
    loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
    
    #if False:
    # and (b) the edge length of the predicted mesh
    loss_edge = mesh_edge_loss(new_src_mesh)

    # mesh normal consistency
    loss_normal = mesh_normal_consistency(new_src_mesh)

    # mesh laplacian smoothing
    loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")

    # Weighted sum of the losses
    loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian
    #loss = loss_chamfer
    # Print the losses
    loop.set_description('total_loss = %.6f' % loss)
    
    # Save the losses for plotting
    chamfer_losses.append(float(loss_chamfer.detach().cpu()))
#     edge_losses.append(float(loss_edge.detach().cpu()))
#     normal_losses.append(float(loss_normal.detach().cpu()))
#     laplacian_losses.append(float(loss_laplacian.detach().cpu()))
    
    # Plot mesh
    #if i % plot_period == 0:
    #    plot_pointcloud(new_src_mesh, title="iter: %d" % i)
        
    # Optimization step
    loss.backward()
    optimizer.step()
    
chamfer_losses    

  0%|          | 0/1000 [00:00<?, ?it/s]

  self._edges_packed = torch.stack([u // V, u % V], dim=1)


[0.09579381346702576,
 0.06859998404979706,
 0.061568960547447205,
 0.039552006870508194,
 0.03657432645559311,
 0.024924900382757187,
 0.015025515109300613,
 0.025563552975654602,
 0.020120719447731972,
 0.016099104657769203,
 0.013637600466609001,
 0.00899301003664732,
 0.012289734557271004,
 0.009848378598690033,
 0.007349794264882803,
 0.006681983359158039,
 0.00563233345746994,
 0.007141944020986557,
 0.006042477674782276,
 0.0032548531889915466,
 0.0035218545235693455,
 0.00493550393730402,
 0.004230385646224022,
 0.0030941765289753675,
 0.0031305216252803802,
 0.0035638045519590378,
 0.0032821334898471832,
 0.0027386320289224386,
 0.002413013018667698,
 0.0022477011661976576,
 0.0022777928970754147,
 0.002523296047002077,
 0.0024915568064898252,
 0.002086709253489971,
 0.001829467830248177,
 0.0018398400861769915,
 0.0018663805676624179,
 0.0017546997405588627,
 0.001707918243482709,
 0.001737711252644658,
 0.0017403742531314492,
 0.0016361859161406755,
 0.0014367098920047283,
 

In [20]:
import meshplot

vert =  start + v
meshplot.plot(vert.detach().cpu().numpy(), f.cpu().numpy())

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(-0.019426…

<meshplot.Viewer.Viewer at 0x7f17c000cb50>

In [28]:
torch.save(vert.cpu(), './data/sampled_32_32.pt')

In [None]:
def to_vertices(stacked):
    return stacked.permute(0, 2, 3, 1).reshape(-1, 3)

In [20]:
s = torch.randn(6, 3, 8, 8)
as_vs = to_vertices(s)
as_vs.shape

torch.Size([384, 3])

In [23]:
import math

f_n = int(math.sqrt(as_vs.size(0) // 6))
f_n
as_st = as_vs.reshape(6, f_n, f_n, 3).permute(0, 3, 1, 2)
as_st.shape

torch.Size([6, 3, 8, 8])

In [24]:
torch.equal(s, as_st)

True

In [38]:
torch.save(cube.state_dict(), './data/net_fitted_32.pth')

In [28]:
cube.net

Generator(
  (head): ConvBlock(
    (conv): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1))
    (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (body): Sequential(
    (block1): ConvBlock(
      (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
      (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (block2): ConvBlock(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (block3): ConvBlock(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(n

In [2]:
import math
import torch

from src.mesh.cube import to_vertices




vs = torch.randn(6*8**2, 3)

torch.equal(vs, to_vertices(to_stacked(vs)))

8


True