In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.cube import make_cube_mesh
from src.operators import get_gaussian

def sides_dict(n):
    return nn.ParameterDict({
        'front': nn.Parameter(torch.zeros((1, 3, n, n))),
        'back' : nn.Parameter(torch.zeros((1, 3, n, n))),
        'left' : nn.Parameter(torch.zeros((1, 3, n, n))),
        'right': nn.Parameter(torch.zeros((1, 3, n, n))),
        'top'  : nn.Parameter(torch.zeros((1, 3, n, n))),
        'down' : nn.Parameter(torch.zeros((1, 3, n, n))),
    })

class Cube(nn.Module):
    def __init__(self, n, kernel=21, sigma=7, start=-0.5, end=0.5):
        super(Cube, self).__init__()
        assert math.log2(n).is_integer(), f"n must be power of 2, n={n}"
        
        self.n = n
        self.params = nn.ModuleList([sides_dict(2**i) for i in range(1, int(math.log2(n)))])      
        self.source = make_cube_mesh(n, start, end)
        #self.gaussian = get_gaussian(kernel)
        #self.gaussian = GaussianLayer(kernel, sigma=sigma)
        clip_value = 1. / n
        for d in self.params:
            print(d)
            for p in d.values():
                #p.register_hook(lambda grad: torch.nan_to_num(grad))
                p.register_hook(lambda grad: self.gaussian(torch.nan_to_num(grad)))
                #p.register_hook(lambda grad: torch.clamp(self.gaussian(grad), -clip_value, clip_value))
                #p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))
    
    def make_vert(self):
        return torch.cat([torch.sigmoid(p)[0].reshape(3, -1).t()
                          for p in self.params.values()])           

    def forward(self):
        deform_verts = self.make_vert()
        new_src_mesh = self.source.offset_verts(deform_verts)
        return new_src_mesh
    
    def to(self, device):
        module = super(Cube, self).to(device)        
        module.source = self.source.to(device)        
        return module
    
    def export(self, f):
        mesh = self.forward().detach()
        save_obj(f, mesh.verts_packed(), mesh.faces_packed())   

cube = Cube(16, kernel=3)
cube

ParameterDict(
    (back): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
    (down): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
    (front): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
    (left): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
    (right): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
    (top): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
)
ParameterDict(
    (back): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
    (down): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
    (front): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
    (left): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
    (right): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
    (top): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
)
ParameterDict(
    (back): Parameter containing: [torch.FloatTensor of size 1x3x8x8]
    (down): Parameter containing: [torc

Cube(
  (params): ModuleList(
    (0): ParameterDict(
        (back): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (down): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (front): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (left): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (right): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
        (top): Parameter containing: [torch.FloatTensor of size 1x3x2x2]
    )
    (1): ParameterDict(
        (back): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (down): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (front): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (left): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (right): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
        (top): Parameter containing: [torch.FloatTensor of size 1x3x4x4]
    )
    (2): ParameterDi

In [37]:
seq = nn.Sequential(
    nn.ReflectionPad2d(10), 
    nn.Conv2d(3, 3, 21, stride=1, padding=0, bias=None, groups=3)
)

In [38]:
seq.requires_grad_(False)


Sequential(
  (0): ReflectionPad2d((10, 10, 10, 10))
  (1): Conv2d(3, 3, kernel_size=(21, 21), stride=(1, 1), groups=3, bias=False)
)

In [27]:
int(math.log2(8))

3