In [13]:
import math
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import trimesh

from src.operators import get_gaussian
from src.cleansed_cube import SourceCube, sides_dict
from src.discrete_laplacian import DiscreteLaplacian
from src.discrete_gaussian import DiscreteGaussian
from src.padding import pad_side

class ProgressiveCube(nn.Module):
    """Cube mesh deformed by a multi-resolution pyramid of per-face offsets.

    For every resolution ``2**i`` with ``1 <= i <= int(log2(n))`` a
    ``sides_dict(2**i)`` of per-face parameters is created. ``forward`` upsamples
    every level to ``n x n`` (bilinear, ``align_corners=True``), sums the levels
    face by face, and passes the flattened offsets to ``SourceCube``.

    NOTE(review): with ``n < 2`` no level is created and ``forward`` fails on an
    empty ``torch.cat``.

    Args:
        n: upsampling size used by ``scale``; also forwarded to ``SourceCube``.
        kernel: size forwarded to ``DiscreteGaussian`` and ``pad_side``.
        sigma: forwarded to ``DiscreteGaussian`` as ``sigma``.
        clip: absolute bound applied to parameter gradients by a backward hook;
            a falsy value (``None`` or ``0``) falls back to ``1 / n``.
        start, end: forwarded to ``SourceCube``.
    """

    def __init__(self, n, kernel=3, sigma=1, clip=None, start=-0.5, end=0.5):
        super().__init__()
        self.n = n
        self.kernel = kernel
        # Face order used by forward() when concatenating faces.
        self.side_names = list(sides_dict(1).keys())
        # One set of per-face parameters per resolution 2, 4, ..., 2**int(log2(n)).
        self.params = nn.ModuleList([sides_dict(2 ** i)
                                     for i in range(1, int(math.log2(n)) + 1)])

        self.source = SourceCube(n, start, end)
        self.gaussian = DiscreteGaussian(kernel, sigma=sigma, padding=False)
        self.laplacian = DiscreteLaplacian()  # not used by forward() at the moment

        clip = clip or 1. / n
        for level in self.params:
            for p in level.values():
                # Sanitize non-finite gradient entries, then clamp to [-clip, clip].
                p.register_hook(lambda grad:
                    torch.clamp(torch.nan_to_num(grad), -clip, clip))

    def make_vert(self):
        """Stack every raw per-face parameter into a ``(V, 3)`` tensor.

        Each parameter ``p`` contributes ``p[0].reshape(3, -1).t()``; levels are
        visited coarse to fine and faces in each level's dict order.
        """
        # self.params is an nn.ModuleList (it has no .values()), so walk each
        # level's ParameterDict explicitly.
        return torch.cat([p[0].reshape(3, -1).t()
                          for level in self.params
                          for p in level.values()])

    def scale(self, t):
        """Bilinearly resize a 4-D tensor ``t`` to ``n x n`` (align_corners=True)."""
        return F.interpolate(t, self.n, mode='bilinear', align_corners=True)

    def forward(self):
        """Build the deformed mesh from the summed multi-resolution offsets.

        Returns:
            ``(mesh, 0)`` where ``mesh`` is whatever ``self.source`` returns for the
            flattened offsets. The second element is a placeholder for a
            Laplacian regularization term, which is currently disabled.
        """
        summed = {}
        for level in self.params:
            for key in self.side_names:
                upsampled = self.scale(level[key])
                summed[key] = summed[key] + upsampled if key in summed else upsampled
        # (num_faces, 3, n, n) -> (num_faces * n * n, 3)
        ps = torch.cat(list(summed.values()))
        deform_verts = ps.permute(0, 2, 3, 1).reshape(-1, 3)
        new_src_mesh = self.source(deform_verts)
        return new_src_mesh, 0

    def smooth(self):
        """Gaussian-smooth the accumulated gradients of every level in place.

        Must be called after ``backward()`` so every parameter has a ``.grad``.
        For each level, each face's gradient is padded with data from the other
        faces (``pad_side``), blurred with ``self.gaussian`` and written back.
        """
        with torch.no_grad():
            for level in self.params:
                # Snapshot all face gradients first, as (H, W, C) like the original
                # permute. Cloning is essential: indexing/permuting .grad yields
                # views, so without a copy the in-place copy_ below would make later
                # faces pad from neighbours that have already been smoothed.
                sides = {name: level[name].grad[0].clone().permute(1, 2, 0)
                         for name in level}
                for name in level:
                    padded = pad_side(sides, name, self.kernel)
                    padded = padded.permute(2, 0, 1)[None]
                    level[name].grad.copy_(self.gaussian(padded))

    def export(self, f):
        """Run ``forward()`` without autograd and write the mesh via trimesh.

        Args:
            f: destination passed to ``trimesh.Trimesh.export`` (e.g. a path).
        """
        with torch.no_grad():
            mesh, _ = self.forward()
        vertices = mesh.vertices[0].detach().cpu().numpy()
        faces = mesh.faces.detach().cpu().numpy()
        trimesh.Trimesh(vertices=vertices, faces=faces).export(f)
                
# Build the model with an 8x8 upsampling target per face and show its module tree.
cube = ProgressiveCube(n=8)
cube

3


ProgressiveCube(
  (params): ModuleList(
    (0): ParameterDict(
        (back): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (down): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (front): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (left): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (right): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (top): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
    )
    (1): ParameterDict(
        (back): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (down): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (front): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (left): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (right): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (top): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
    )
    (2): 

In [14]:
# Clear previously accumulated gradients so re-running this cell is idempotent.
cube.zero_grad()
# Avoid binding `_`: it would shadow IPython's output-history variable.
(vertices, _unused1, _unused2), _unused_reg = cube.forward()
loss = torch.mean(vertices)
loss.backward()
cube.smooth()

In [15]:
EXPORT_PATH = './data/f.obj'
# Make sure the output directory exists so the export works on a fresh checkout.
Path(EXPORT_PATH).parent.mkdir(parents=True, exist_ok=True)
cube.export(EXPORT_PATH)