In [34]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.cube import make_cube_mesh
from src.discrete_laplacian import DiscreteLaplacian
from src.operators import get_gaussian

def sides_dict(n):
    return nn.ParameterDict({
        'front': nn.Parameter(torch.zeros((1, 3, n, n))),
        'back' : nn.Parameter(torch.zeros((1, 3, n, n))),
        'left' : nn.Parameter(torch.zeros((1, 3, n, n))),
        'right': nn.Parameter(torch.zeros((1, 3, n, n))),
        'top'  : nn.Parameter(torch.zeros((1, 3, n, n))),
        'down' : nn.Parameter(torch.zeros((1, 3, n, n))),
    })


class ProgressiveCube(nn.Module):
    def __init__(self, n, kernel=21, sigma=7, start=-0.5, end=0.5):
        super(ProgressiveCube, self).__init__()        
        self.n = n
        self.side_names = list(sides_dict(1).keys())
        self.params = nn.ModuleList([sides_dict(2**i)
            for i in range(1, int(math.log2(n))+1)])
        
        self.source = make_cube_mesh(n, start, end)
        self.gaussian = get_gaussian(kernel)
        #self.gaussian = DiscreteGaussian(kernel, sigma=sigma)
        self.laplacian = DiscreteLaplacian()
        clip_value = 1. / n
        for d in self.params:
            for p in d.values():
                #p.register_hook(lambda grad: torch.nan_to_num(grad))
                p.register_hook(lambda grad: self.gaussian(torch.nan_to_num(grad)))
                #p.register_hook(lambda grad: torch.clamp(self.gaussian(grad), -clip_value, clip_value))
                #p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))        

    def make_vert(self):
        return torch.cat([p[0].reshape(3, -1).t()
                          for p in self.params.values()])
    
    def scale(self, t):
        return  F.interpolate(t, self.n, mode='bilinear', align_corners=True)

    def forward(self):
        summed = {}
        for d in self.params:
            print(d)
            for key in side_keys:
                if key in summed:
                    summed[key] = summed[key] + self.scale(d[key])
                else:
                    summed[key] = self.scale(d[key])        
        ps = torch.cat([p for p in summed.values()])        
        deform_verts = ps.permute(0, 2, 3, 1).reshape(-1, 3)         
        new_src_mesh = self.source.offset_verts(deform_verts)        
        return new_src_mesh, 0 #self.laplacian(ps)
    
    def to(self, device):
        module = super(ProgressiveCube, self).to(device)        
        module.source = self.source.to(device)        
        return module
    
    def export(self, f):        
        mesh, _ = self.forward()
        mesh = mesh.detach()
        save_obj(f, mesh.verts_packed(), mesh.faces_packed())   

        
cube = Cube(8, kernel=3)

cube.forward()

ParameterDict(
    (back): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
    (down): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
    (front): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
    (left): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
    (right): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
    (top): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
)
ParameterDict(
    (back): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
    (down): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
    (front): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
    (left): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
    (right): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
    (top): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
)
ParameterDict(
    (back): Parameter containing: [torch.FloatTensor of size 1x3x8x8]
    (down): Parameter containing: [torc

(<pytorch3d.structures.meshes.Meshes at 0x7feef620f340>, 0)

In [26]:
cube.source.verts_packed().shape

torch.Size([384, 3])

In [28]:
384 / 6

64.0

In [27]:
cube

Cube(
  (params): ModuleList(
    (0): ParameterDict(
        (back): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (down): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (front): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (left): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (right): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (top): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
    )
    (1): ParameterDict(
        (back): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (down): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (front): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (left): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (right): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (top): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
    )
    (2): ParameterDi

In [13]:
n = 4
params = nn.ModuleList([sides_dict(2**i)
            for i in range(1, int(math.log2(n))+1)])
params

ModuleList(
  (0): ParameterDict(
      (back): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
      (down): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
      (front): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
      (left): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
      (right): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
      (top): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
  )
  (1): ParameterDict(
      (back): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
      (down): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
      (front): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
      (left): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
      (right): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
      (top): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
  )
)

In [38]:
seq.requires_grad_(False)


Sequential(
  (0): ReflectionPad2d((10, 10, 10, 10))
  (1): Conv2d(3, 3, kernel_size=(21, 21), stride=(1, 1), groups=3, bias=False)
)

In [14]:
side_keys = list(sides_dict(1).keys())
side_keys

['back', 'down', 'front', 'left', 'right', 'top']

In [15]:
F.interpolate(torch.rand(1, 3, 2, 2), n).shape

torch.Size([1, 3, 4, 4])

In [18]:
res = {}
for d in params:
    for key in side_keys:
        if key in res:
            res[key] = res[key] + F.interpolate(d[key], n, mode='bilinear', align_corners=True)
        else:
            res[key] = F.interpolate(d[key], n, mode='bilinear', align_corners=True)
res            

{'back': tensor([[[[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]],
 
          [[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]],
 
          [[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]]]], grad_fn=<AddBackward0>),
 'down': tensor([[[[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]],
 
          [[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]],
 
          [[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]]]], grad_fn=<AddBackward0>),
 'front': tensor([[[[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]],
 
          [[0., 0., 0., 0.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.],
     