In [7]:
import torch
import torch.nn  as nn


from src.operators import get_gaussian

from src.cleansed_cube import SourceCube, sides_dict
from src.discrete_laplacian import DiscreteLaplacian
from src.discrete_gaussian import DiscreteGaussian
from src.padding import pad_side

class SimpleCube(nn.Module):
    """Deformable cube whose faces are learnable per-vertex offset grids.

    ``self.params`` is built by ``sides_dict(n)``. For ``n=9`` its repr shows
    six faces (back/down/front/left/right/top), each a parameter of shape
    ``(1, 3, 9, 9)``. Presumably that is ``(1, 3, n, n)`` in general; TODO
    confirm in ``sides_dict``.

    Args:
        n: Face grid resolution, forwarded to ``sides_dict`` and ``SourceCube``.
        kernel: Gaussian kernel size used by :meth:`smooth`; also passed to
            ``pad_side`` when padding each face.
        sigma: Standard deviation of the Gaussian kernel.
        clip_value: Every face gradient is NaN-scrubbed and clamped to
            ``[-clip_value, clip_value]`` by a hook registered here.
        start, end: Forwarded to ``SourceCube`` (presumably the coordinate
            extent of the undeformed cube; TODO confirm).
    """

    def __init__(self, n, kernel=5, sigma=1, clip_value=0.1, start=-0.5, end=0.5):
        super().__init__()
        self.n = n
        self.kernel = kernel
        self.params = sides_dict(n)
        self.source = SourceCube(n, start, end)
        # Built without padding: smooth() pads each face with pad_side
        # before the convolution is applied.
        self.gaussian = DiscreteGaussian(kernel, sigma=sigma, padding=False)
        self.laplacian = DiscreteLaplacian()
        for p in self.params.values():
            # nan_to_num replaces NaN with 0 and +/-inf with large finite
            # values, so the clamp always yields finite gradients.
            p.register_hook(lambda grad: torch.clamp(
                torch.nan_to_num(grad), -clip_value, clip_value))

    def make_vert(self):
        """Return all face offsets flattened into one ``(num_vertices, 3)`` tensor.

        Faces are concatenated in ``self.params`` order. Within a face, rows
        follow the row-major order of the face's two grid dimensions, which
        is the same ordering :meth:`forward` produces. Only batch entry 0 of
        each face is used.
        """
        return torch.cat([p[0].reshape(3, -1).t()
                          for p in self.params.values()])

    def forward(self):
        """Deform the source cube by the current face offsets.

        Returns:
            A pair ``(mesh, reg)``:
            - ``mesh`` is whatever ``self.source`` returns for the flattened
              ``(num_vertices, 3)`` offsets.
            - ``reg`` is currently the constant ``0`` because the Laplacian
              regulariser is disabled.
        """
        faces = torch.cat(list(self.params.values()))  # faces stacked on dim 0
        deform_verts = faces.permute(0, 2, 3, 1).reshape(-1, 3)
        new_src_mesh = self.source(deform_verts)
        # TODO: the Laplacian term is disabled; re-enable with self.laplacian(faces).
        return new_src_mesh, 0

    def smooth(self):
        """Replace every face gradient with its Gaussian-blurred version, in place.

        Call this after ``backward()``. ``pad_side`` receives every face's
        gradient, presumably so that it can pad a face with its neighbours
        and let the blur cross cube edges.

        All gradients are snapshotted before any of them is overwritten. The
        result therefore does not depend on the iteration order of
        ``self.params``.

        Raises:
            RuntimeError: if some face has no gradient yet.
        """
        with torch.no_grad():  # gradient post-processing must not build a graph
            sides = {}
            for side_name, param in self.params.items():
                if param.grad is None:
                    raise RuntimeError(
                        f"no gradient for side '{side_name}'; call backward() first")
                # clone(): .grad is overwritten below. Keeping a view would feed
                # already-smoothed values into the padding of later faces.
                sides[side_name] = param.grad[0].permute(1, 2, 0).clone()

            for side_name, param in self.params.items():
                padded = pad_side(sides, side_name, self.kernel)
                padded = padded.permute(2, 0, 1)[None]
                param.grad.copy_(self.gaussian(padded))
                
cube = SimpleCube(n=9)  # the repr below shows 1x3x9x9 parameters per face
cube

5


SimpleCube(
  (params): ParameterDict(
      (back): Parameter containing: [torch.FloatTensor of size 1x3x9x9]
      (down): Parameter containing: [torch.FloatTensor of size 1x3x9x9]
      (front): Parameter containing: [torch.FloatTensor of size 1x3x9x9]
      (left): Parameter containing: [torch.FloatTensor of size 1x3x9x9]
      (right): Parameter containing: [torch.FloatTensor of size 1x3x9x9]
      (top): Parameter containing: [torch.FloatTensor of size 1x3x9x9]
  )
  (source): SourceCube()
  (gaussian): DiscreteGaussian(
    (seq): Sequential(
      (0): ReflectionPad2d((0, 0, 0, 0))
      (1): Conv2d(3, 3, kernel_size=(5, 5), stride=(1, 1), groups=3, bias=False)
    )
  )
  (laplacian): DiscreteLaplacian(
    (seq): Sequential(
      (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), groups=3, bias=False)
    )
  )
)

In [8]:
# .grad accumulates across backward() calls. Without clearing it, re-running
# this cell would blur the sum of stale (already smoothed) and fresh gradients.
cube.zero_grad()
(vertices, _, _), _ = cube()
loss = vertices.mean()
loss.backward()
cube.smooth()

torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 9, 9])
torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 9, 9])
torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 9, 9])
torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 9, 9])
torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 9, 9])
torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 9, 9])
