In [1]:
import numpy as np
import os
import sys
import torch
import pytorch3d

import matplotlib.pyplot as plt

import trimesh
from pytorch3d.io import load_objs_as_meshes, save_obj
import numpy as np
from tqdm.notebook import tqdm

#from src.cleansed_cube import Cube

from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)
from pytorch3d.io import load_objs_as_meshes, save_obj

from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)

# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    OpenGLPerspectiveCameras, 
    PointLights, 
    DirectionalLights, 
    Materials, 
    RasterizationSettings, 
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    SoftPhongShader,
    TexturesVertex,
)

from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)

from src.plot_image_grid import image_grid

In [2]:
import torch
import torch.nn  as nn


from src.operators import get_gaussian

from src.cleansed_cube import SourceCube, sides_dict
from src.discrete_laplacian import DiscreteLaplacian
from src.discrete_gaussian import DiscreteGaussian
from src.padding import pad_side

class SimpleCube(nn.Module):
    """Deformable cube surface parameterised by per-side grids of vertex offsets.

    Each side's offsets live in ``self.params`` (built by ``sides_dict(n)``). ``forward``
    flattens them into one ``(num_vertices, 3)`` offset tensor and feeds it to a
    ``SourceCube``. Every parameter carries a gradient hook that applies
    ``torch.nan_to_num`` and then clamps the gradient to ``[-clip_value, clip_value]``.

    Args:
        n: grid resolution of each side (``final_laplacian`` reshapes the vertices
            to ``(6, n, n, 3)``).
        kernel: kernel size for the ``DiscreteGaussian`` used by ``smooth``; also
            passed as the third argument of ``pad_side``.
        sigma: standard deviation given to ``DiscreteGaussian``.
        clip_value: symmetric bound of the gradient clamp hook.
        start, end: forwarded unchanged to ``SourceCube(n, start, end)``.
    """

    # Order in which ``final_laplacian`` assumes SourceCube lays out its six sides.
    # NOTE(review): must match SourceCube's vertex layout — confirm there.
    _SIDE_ORDER = ('front', 'right', 'back', 'left', 'top', 'down')

    def __init__(self, n, kernel=5, sigma=1, clip_value=0.1, start=-0.5, end=0.5):
        super().__init__()
        self.n = n
        self.kernel = kernel
        # Presumably one learnable (1, 3, n, n) tensor per side — TODO confirm in sides_dict.
        self.params = sides_dict(n)
        self.source = SourceCube(n, start, end)
        self.gaussian = DiscreteGaussian(kernel, sigma=sigma, padding=False)
        self.laplacian = DiscreteLaplacian()

        def sanitize_grad(grad):
            # NaN -> 0 and +/-inf -> largest finite float, then bound the update size.
            return torch.clamp(torch.nan_to_num(grad), -clip_value, clip_value)

        for p in self.params.values():
            p.register_hook(sanitize_grad)

    def make_vert(self):
        """Return every side's values as rows of a ``(total_points, 3)`` tensor.

        Each parameter's first batch entry is flattened channel-first, so row ``k`` of a
        side holds the 3 channel values at flattened grid position ``k``.
        """
        return torch.cat([p[0].reshape(3, -1).t() for p in self.params.values()])

    def forward(self):
        """Deform the source cube by the current offsets.

        Returns:
            ``(new_src_mesh, laplacian)``: whatever ``SourceCube`` returns for the
            flattened offsets, and ``final_laplacian`` of its vertices.
        """
        stacked = torch.cat(list(self.params.values()))
        # Move the channel axis last so each row is one (x, y, z) offset.
        deform_verts = stacked.permute(0, 2, 3, 1).reshape(-1, 3)
        new_src_mesh = self.source(deform_verts)
        laplacian = self.final_laplacian(new_src_mesh.vertices)
        return new_src_mesh, laplacian

    def _mean_side_laplacian(self, sides):
        """Average ``self.laplacian`` over all sides, each padded by ``pad_side``.

        ``sides`` maps side name -> channels-last ``(H, W, C)`` tensor. Sides are visited
        in the dict's insertion order; the sum is scaled by 1/6 (one share per cube side).
        """
        total = 0
        for side_name in sides:
            padded = pad_side(sides, side_name, self.kernel)
            padded = padded.permute(2, 0, 1)[None]  # (1, C, H', W') for the operator
            total += self.laplacian(padded) / 6
        return total

    def final_laplacian(self, vert):
        """Mean discrete Laplacian of the deformed surface.

        ``vert`` is indexed as ``vert[0, :, :3]``: only the first batch entry and the
        first three channels are used, reshaped to ``(6, n, n, 3)`` in ``_SIDE_ORDER``.
        """
        vertices = vert[0, :, :3].reshape(len(self._SIDE_ORDER), self.n, self.n, 3)
        sides = dict(zip(self._SIDE_ORDER, vertices))
        return self._mean_side_laplacian(sides)

    def smooth(self):
        """Replace each side's gradient with a Gaussian-smoothed copy, in place.

        Every smoothed gradient is computed from the *unmodified* gradients before any
        is written back. The per-side tensors are views into ``.grad``, so writing back
        inside the same loop would let later sides pad with already-smoothed neighbours.

        Raises:
            RuntimeError: if any side has no gradient yet (call ``backward()`` first).
        """
        with torch.no_grad():
            sides = {}
            for side_name, param in self.params.items():
                if param.grad is None:
                    raise RuntimeError(
                        f"smooth(): side '{side_name}' has no gradient; call backward() first")
                sides[side_name] = param.grad[0].permute(1, 2, 0)

            smoothed = {}
            for side_name in sides:
                padded = pad_side(sides, side_name, self.kernel)
                smoothed[side_name] = self.gaussian(padded.permute(2, 0, 1)[None])

            for side_name, param in self.params.items():
                param.grad.copy_(smoothed[side_name])

    def laplacian_loss(self):
        """Mean discrete Laplacian of the raw offset parameters (not the deformed mesh)."""
        sides = {name: side[0].permute(1, 2, 0) for name, side in self.params.items()}
        return self._mean_side_laplacian(sides)

    def export(self, f):
        """Write the current deformed mesh to ``f`` via ``trimesh`` (format from its suffix).

        The forward pass runs under ``torch.no_grad()``: only values are needed, so no
        autograd graph is built. Vertices are sliced to their first 3 channels, matching
        ``final_laplacian`` and the training loop.
        """
        with torch.no_grad():
            mesh, _ = self.forward()
        vertices = mesh.vertices[0, :, :3].detach().cpu().numpy()
        faces = mesh.faces.detach().cpu().numpy()
        trimesh.Trimesh(vertices=vertices, faces=faces).export(f)

In [3]:
# Use the first GPU when available; every tensor below is created on `device`.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

# Target mesh to fit (alternative used earlier: "./data/bunny.obj").
# NOTE: the variable is called `bunny` whatever file is loaded; later cells refer to it.
obj_filename = "./data/nefertiti.obj"
bunny = load_objs_as_meshes([obj_filename], device=device)

# Rebuild the mesh with constant all-ones per-vertex features instead of the OBJ's textures.
textures = TexturesVertex(verts_features=[torch.ones_like(bunny.verts_packed())])
bunny = Meshes(verts=[bunny.verts_packed()],
               faces=[bunny.faces_packed()],
               textures=textures)

# Centre at the origin, then scale so the largest |coordinate| becomes 1.
verts = bunny.verts_packed()
N = verts.shape[0]
center = verts.mean(0)
scale = (verts - center).abs().max()
if scale == 0:
    raise ValueError(f"{obj_filename} is degenerate: all vertices coincide")
bunny.offset_verts_(-center)
bunny.scale_verts_(1.0 / float(scale));

In [4]:
# Preview 5000 points sampled from the target surface; the training loop draws
# its own, larger samples every iteration.
sample_trg = sample_points_from_meshes(bunny, num_samples=5000)
sample_trg

tensor([[[-0.2139, -0.9915,  0.0452],
         [ 0.3887,  0.5209,  0.1039],
         [-0.1458, -0.7027, -0.0133],
         ...,
         [-0.2696,  0.6226,  0.0899],
         [-0.1215,  0.6017,  0.4080],
         [ 0.3772,  0.4916, -0.1434]]], device='cuda:0')

In [25]:
# Cube resolution (grid points per side edge) and the Gaussian used by `cube.smooth()`.
n = 32
kernel = 7
sigma = 2

cube = SimpleCube(n, kernel=kernel, sigma=sigma, clip_value=1.0).to(device)

# Plain SGD without momentum. To experiment, try
# torch.optim.Adam(cube.parameters(), lr=0.001) instead.
optimizer = torch.optim.SGD(cube.parameters(), lr=0.1, momentum=0.0)
optimizer

7


SGD (
Parameter Group 0
    dampening: 0
    lr: 0.1
    momentum: 0.0
    nesterov: False
    weight_decay: 0
)

In [None]:
# Optimisation schedule.
Niter = 20001
checkpoint_every = 500
num_views_per_iteration = 2  # currently unused by this loop
laplace_weight = 1.  # currently unused by this loop

# Weights of the PyTorch3D mesh losses combined below.
w_chamfer = 1.0
w_laplacian = 0.1
w_normal = 0.01
w_edge = 1.0

loop = tqdm(range(Niter))

for i in loop:
    optimizer.zero_grad()

    # `laplace_loss` (the cube's own grid Laplacian) is not part of the objective;
    # the PyTorch3D uniform Laplacian below is used instead.
    new_src_mesh, laplace_loss = cube.forward()
    src_verts = new_src_mesh.vertices[:, :, :3]
    src_textures = TexturesVertex(verts_features=torch.ones_like(src_verts))
    p3d_mesh = Meshes(verts=src_verts,
                      faces=new_src_mesh.faces[None],
                      textures=src_textures)

    # Both point clouds are re-sampled each step, so the Chamfer term is a stochastic estimate.
    sample_src = sample_points_from_meshes(p3d_mesh, 15000)
    sample_trg = sample_points_from_meshes(bunny, 15000)

    loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
    loss_laplacian = mesh_laplacian_smoothing(p3d_mesh, method="uniform")
    loss_normal = mesh_normal_consistency(p3d_mesh)
    loss_edge = mesh_edge_loss(p3d_mesh)

    sum_loss = (loss_chamfer * w_chamfer
                + loss_laplacian * w_laplacian
                + loss_normal * w_normal
                + loss_edge * w_edge)

    # Format a plain Python float rather than the autograd tensor.
    loop.set_description("total_loss = %.6f" % sum_loss.item())

    sum_loss.backward()
    # To Gaussian-smooth the per-side gradients before stepping, call cube.smooth() here.
    optimizer.step()

    if i % checkpoint_every == 0:
        f = f'./data/cube_mesh_{n}_{i}.obj'
        cube.export(f)
        print(f)

# Final mesh, tagged with "kernel" so it cannot overwrite an iteration checkpoint
# (the old `cube_mesh_{n}_{kernel}` name clobbered `cube_mesh_{n}_0.obj` when kernel=0).
f = f'./data/cube_mesh_{n}_kernel{kernel}.obj'
cube.export(f)
f

HBox(children=(FloatProgress(value=0.0, max=20001.0), HTML(value='')))

./data/cube_mesh_32_0.obj
./data/cube_mesh_32_500.obj
./data/cube_mesh_32_1000.obj
./data/cube_mesh_32_1500.obj
./data/cube_mesh_32_2000.obj
./data/cube_mesh_32_2500.obj
./data/cube_mesh_32_3000.obj
