In [50]:
import torch
import torch.nn  as nn


from src.operators import get_gaussian

from src.cleansed_cube import SourceCube, sides_dict
from src.discrete_laplacian import DiscreteLaplacian
from src.discrete_gaussian import DiscreteGaussian
from src.padding import pad_side

class SimpleCube(nn.Module):
    """Cube mesh whose faces are deformed by learnable per-vertex offsets.

    The face parameters come from ``sides_dict(n)`` (a ParameterDict; in this
    notebook its repr shows one ``1x3xn x n`` tensor per face). ``forward``
    flattens them into a ``(num_faces * n * n, 3)`` offset table and hands it
    to ``SourceCube``.

    Args:
        n: grid resolution, forwarded to ``sides_dict`` and ``SourceCube``.
        kernel: kernel size forwarded to ``get_gaussian``.
        sigma: currently UNUSED. Kept only for backward compatibility; its
            sole consumer (``DiscreteGaussian(kernel, sigma=sigma)``) is
            disabled.
        clip_value: every face gradient is passed through
            ``torch.nan_to_num`` and then clamped to
            ``[-clip_value, clip_value]`` during backward.
        start, end: forwarded to ``SourceCube``.
    """

    def __init__(self, n, kernel=5, sigma=3, clip_value=0.1, start=-0.5, end=0.5):
        super().__init__()
        self.n = n
        self.params = sides_dict(n)
        self.source = SourceCube(n, start, end)
        # NOTE(review): `sigma` is ignored because get_gaussian is only given the
        # kernel size here; the sigma-aware alternative was
        # `DiscreteGaussian(kernel, sigma=sigma)`.
        self.gaussian = get_gaussian(kernel)
        self.laplacian = DiscreteLaplacian()
        for p in self.params.values():
            # Scrub NaN/inf first so the clamp always yields finite gradients.
            p.register_hook(lambda grad: torch.clamp(
                torch.nan_to_num(grad), -clip_value, clip_value))

    def make_vert(self):
        """Return the face offsets as a ``(num_vertices, 3)`` table.

        Only batch element 0 of each face parameter is used. Row order matches
        ``forward``: row-major over each face grid, faces in ``self.params``
        order.
        """
        return torch.cat([p[0].reshape(3, -1).t()
                          for p in self.params.values()])

    def forward(self):
        """Deform the source cube by the current face parameters.

        Returns:
            ``(mesh, reg)``. ``mesh`` is whatever ``SourceCube`` returns for the
            offsets. ``reg`` is the regularisation term: the Laplacian term is
            currently disabled, so it is always the integer ``0`` (kept so
            callers can still unpack a pair).
        """
        ps = torch.cat(list(self.params.values()))
        # (faces, 3, H, W) -> (faces, H, W, 3) -> (faces * H * W, 3)
        deform_verts = ps.permute(0, 2, 3, 1).reshape(-1, 3)
        new_src_mesh = self.source(deform_verts)
        # TODO: re-enable `self.laplacian(ps)` as the regularisation term.
        return new_src_mesh, 0

    def smooth(self):
        """Placeholder; currently a no-op that returns None.

        TODO: presumably meant to smooth the face gradients; the notebook cells
        after this class prototype it with ``pad_side`` plus a Gaussian filter.
        """
    
# Build a cube with a 9x9 grid per face and echo its module tree.
cube = SimpleCube(n=9)
cube

SimpleCube(
  (params): ParameterDict(
      (back): Parameter containing: [torch.FloatTensor of size 1x3x9x9]
      (down): Parameter containing: [torch.FloatTensor of size 1x3x9x9]
      (front): Parameter containing: [torch.FloatTensor of size 1x3x9x9]
      (left): Parameter containing: [torch.FloatTensor of size 1x3x9x9]
      (right): Parameter containing: [torch.FloatTensor of size 1x3x9x9]
      (top): Parameter containing: [torch.FloatTensor of size 1x3x9x9]
  )
  (source): SourceCube()
  (gaussian): Conv2d(3, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False, padding_mode=replicate)
  (laplacian): DiscreteLaplacian(
    (seq): Sequential(
      (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), groups=3, bias=False)
    )
  )
)

In [51]:
# Call the module instance (not `.forward()` directly) so nn.Module hooks run.
# Named throwaways instead of `_`: assigning `_` in IPython stops it from
# updating its `_` / `__` / `___` output history for the rest of the session.
(vertices, _unused_a, _unused_b), _reg = cube()
vertices.shape

torch.Size([1, 486, 4])

In [52]:
# Scalar toy loss: mean over every entry of the vertex tensor.
loss = vertices.mean()
loss

tensor(0.2500, grad_fn=<MeanBackward0>)

In [53]:
# Before backward() the gradient slot is still empty (None -> no output).
cube.params["back"].grad

In [54]:
loss.backward()  # fills .grad on every face parameter (passed through the clip hooks)
cube.params["back"].grad

tensor([[[[0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005]],

         [[0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.000

In [55]:
# Per-face gradient as an (H, W, 3) channels-last view of batch element 0.
sides = {
    name: cube.params[name].grad[0].permute(1, 2, 0)
    for name in cube.params
}
sides['front'].shape

torch.Size([9, 9, 3])

In [46]:
# Pad each face gradient with pad_side (presumably using neighbouring faces'
# borders); for kernel_size=3 this grows 9x9 to 11x11, per the output below.
kernel_size = 3
padded = {name: pad_side(sides, name, kernel_size) for name in sides}
padded['back'].shape

torch.Size([11, 11, 3])

In [49]:
# Sanity check: zero-width reflection padding leaves the spatial size unchanged.
nn.ReflectionPad2d(padding=0)(
    torch.rand(1, 3, 4, 4)
).shape

torch.Size([1, 3, 4, 4])

In [57]:
# Gaussian filter without built-in padding (its ReflectionPad2d is 0-wide);
# the inputs were already padded with pad_side above.
# NOTE(review): the constructor appears to print something (stray "3" in output).
gaussian = DiscreteGaussian(
    kernel_size,
    sigma=1,
    padding=False,
)
gaussian

3


DiscreteGaussian(
  (seq): Sequential(
    (0): ReflectionPad2d((0, 0, 0, 0))
    (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), groups=3, bias=False)
  )
)

In [59]:
# Filter each padded face: (H, W, 3) -> (1, 3, H, W) for the conv; the unpadded
# filter shrinks it back to the original grid size.
res = {
    name: gaussian(face.permute(2, 0, 1).unsqueeze(0))
    for name, face in padded.items()
}
res['back'].shape

torch.Size([1, 3, 9, 9])

In [61]:
# Same tensor as the 'back' gradient shown after backward(): the filtered
# results in `res` are not written back to the parameters.
t = cube.params["back"].grad
t

tensor([[[[0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005]],

         [[0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005, 0.0005,
           0.0005],
          [0.0005, 0.0005, 0.0005, 0.0005, 0.000

In [None]:
t  # re-display the 'back' face gradient