In [77]:
import torch

# Print tensors with 2 decimals and without scientific notation (compact vertex dumps).
torch.set_printoptions(precision=2, sci_mode=False)

def get_neighbors(sides, side_name, size):
    """Collect the strips of the four neighbouring sides used to pad ``side_name``.

    Args:
        sides: mapping from side name ('front', 'right', 'back', 'left',
            'top', 'down') to a grid tensor indexed ``[row, col, ...]``
            (the notebook builds ``(n, n, 3)`` xyz grids).
        side_name: the side being padded.
        size: pad width, i.e. how many rows/columns to take from each neighbour.

    Returns:
        ``[l, r, t, d]``: the last ``size`` columns of the left neighbour, the
        first ``size`` columns of the right neighbour, the last ``size`` rows
        of the top neighbour and the first ``size`` rows of the down
        neighbour. With ``size == 0`` every strip is empty.

    Raises:
        KeyError: if ``side_name`` or one of its neighbours is unknown.
    """
    # Neighbour order per side: [left, right, top, down].
    # NOTE(review): the strips are returned exactly as sliced, without being
    # rotated/flipped into this side's orientation -- confirm these neighbour
    # assignments are the intended cube layout.
    pad_map = {
        'front': ['left', 'right', 'top', 'down'],
        'right': ['front', 'back', 'top', 'down'],
        'back': ['right', 'left', 'top', 'down'],
        'left': ['back', 'front', 'top', 'down'],
        'top': ['left', 'right', 'back', 'front'],
        'down': ['left', 'right', 'front', 'back'],
    }
    l, r, t, d = [sides[name] for name in pad_map[side_name]]
    # Use explicit start indices: ``x[-0:]`` would select the whole tensor
    # instead of an empty strip when size == 0. Clamping at 0 keeps the old
    # "take everything" behaviour when size exceeds the side length.
    l_start = max(l.shape[1] - size, 0)
    t_start = max(t.shape[0] - size, 0)
    return [
        l[:, l_start:, :],
        r[:, :size, :],
        t[t_start:, :, :],
        d[:size, :, :],
    ]

def make_tris(size, device):
    """Build complementary triangular blending masks for the padded corners.

    Each corner of a padded side is filled from two neighbouring patches. The
    masks come in pairs (``tl``/``lt``, ``tr``/``rt``, ``ld``/``dl``,
    ``dr``/``rd``) whose elementwise sum is all ones, so weighting the two
    patches with a pair yields a convex blend. Cells strictly inside one
    triangle take weight 1 from that patch; diagonal cells take 0.5 from each.

    Args:
        size: side length of the square corner patch.
        device: torch device (or device string) to create the masks on.

    Returns:
        dict mapping the eight mask names to ``(size, size)`` float tensors.
    """
    ones = torch.ones(size, size, device=device)
    half_eye = torch.eye(size, device=device) * 0.5
    tl = torch.triu(ones) - half_eye   # upper triangle, 0.5 on the diagonal
    dr = torch.tril(ones) - half_eye   # lower triangle, 0.5 on the diagonal
    # The previous debug loop replaced every mask with ones_like(...), which
    # discarded these weights and made each corner the *sum* of two patches.
    return {
        'tl': tl,
        'lt': tl.t().clone(),
        'tr': tl.t().flip(dims=(0,)).t(),
        'rt': tl.flip(dims=(0,)).t(),
        'dr': dr.t(),
        'rd': dr,
        'ld': dr.flip(dims=(0,)).t(),
        'dl': dr.t().flip(dims=(0,)).t(),
    }

# Corner layout (first letter = neighbour strip the patch is taken from,
# second letter = the other strip meeting at that corner):
# lt/tl tr/rt
# ld/dl dr/rd
def get_corners(size, l, r, t, d):
    """Fill the four ``size x size`` corners of a padded side.

    Each corner is the sum of a patch from each of the two neighbour strips
    that meet there, weighted by the masks returned by ``make_tris``.

    Args:
        size: pad width.
        l, r: left/right neighbour strips, presumably ``(rows, size, C)`` as
            returned by ``get_neighbors``.
        t, d: top/down neighbour strips, presumably ``(size, cols, C)``.

    Returns:
        ``(top_left, top_right, down_left, down_right)`` corner patches,
        each ``(size, size, C)``.
    """
    tris = make_tris(size, l.device)
    # Masks are (size, size); add a trailing axis so they scale every channel
    # of the (size, size, C) patches. Without it, broadcasting lines the mask
    # up against the channel axis (error for size == 2, wrong weights for 3).
    w = {name: mask.unsqueeze(-1) for name, mask in tris.items()}

    # Explicit start indices: ``x[-0:]`` would select everything when size == 0.
    l_tail = max(l.shape[0] - size, 0)
    r_tail = max(r.shape[0] - size, 0)
    t_tail = max(t.shape[1] - size, 0)
    d_tail = max(d.shape[1] - size, 0)

    lt = l[l_tail:, :, :] * w['lt']
    tl = t[:, t_tail:, :] * w['tl']

    tr = t[:, :size, :] * w['tr']
    rt = r[r_tail:, :, :] * w['rt']

    ld = l[:size, :, :] * w['ld']
    dl = d[:, d_tail:, :] * w['dl']

    rd = r[:size, :, :] * w['rd']
    dr = d[:, :size, :] * w['dr']

    return lt + tl, tr + rt, ld + dl, dr + rd

def pad_side(sides, side_name, kernel_size):
    """Return ``sides[side_name]`` surrounded by a border taken from its neighbours.

    The border is ``(kernel_size - 1) // 2`` cells wide on every edge. For an
    odd ``kernel_size``, an unpadded convolution of that size therefore maps
    the result back to the side's original spatial size. Edge strips come
    from ``get_neighbors`` and the four corners from ``get_corners``.
    """
    center = sides[side_name]
    pad = (kernel_size - 1) // 2
    left, right, top, down = get_neighbors(sides, side_name, pad)
    c_lt, c_tr, c_ld, c_dr = get_corners(pad, left, right, top, down)

    # Assemble a 3x3 arrangement: corners/edges around the original side.
    bands = (
        torch.cat((c_lt, top, c_tr), dim=1),
        torch.cat((left, center, right), dim=1),
        torch.cat((c_ld, down, c_dr), dim=1),
    )
    return torch.cat(bands, dim=0)

In [78]:
from collections import (
    namedtuple,
    OrderedDict,
)
import trimesh
from src.cleansed_cube import (
    make_cube_faces,
    SourceCube,
    SimpleCube,
)

# Test cube: six n x n grids of xyz vertices, one per face. Each face sits at
# distance `end + 1/n` from the origin along its own axis and spans
# [start, end] along the other two.
n, start, end = 7, -0.5, 0.5

# Fall back to the CPU so the notebook also runs without a CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Default ('ij') meshgrid indexing: d1 varies along rows, d2 along columns.
d1, d2 = torch.meshgrid(
    torch.linspace(start, end, steps=n),
    torch.linspace(start, end, steps=n))
d1, d2 = d1.to(device), d2.to(device)

d3 = torch.full_like(d1, end) + 1 / n
sides = OrderedDict({
    'front': torch.stack((+d3,  d1,  d2), dim=-1),
    'right': torch.stack(( d1, +d3,  d2), dim=-1),
    'back' : torch.stack((-d3,  d1,  d2), dim=-1),
    'left' : torch.stack(( d1, -d3,  d2), dim=-1),
    'top'  : torch.stack(( d1,  d2, +d3), dim=-1),
    'down' : torch.stack(( d1,  d2, -d3), dim=-1),
})
faces = make_cube_faces(n)
# Topology for sides padded by one cell on every edge (kernel_size == 3).
faces_raw = make_cube_faces(n + 2)
sides, faces.shape

(OrderedDict([('front',
               tensor([[[     0.64,     -0.50,     -0.50],
                        [     0.64,     -0.50,     -0.33],
                        [     0.64,     -0.50,     -0.17],
                        [     0.64,     -0.50,     -0.00],
                        [     0.64,     -0.50,      0.17],
                        [     0.64,     -0.50,      0.33],
                        [     0.64,     -0.50,      0.50]],
               
                       [[     0.64,     -0.33,     -0.50],
                        [     0.64,     -0.33,     -0.33],
                        [     0.64,     -0.33,     -0.17],
                        [     0.64,     -0.33,     -0.00],
                        [     0.64,     -0.33,      0.17],
                        [     0.64,     -0.33,      0.33],
                        [     0.64,     -0.33,      0.50]],
               
                       [[     0.64,     -0.17,     -0.50],
                        [     0.64,     -0.17,     -0.33]

In [82]:
def export(f, sides, faces):
    """Save all side grids as one triangle mesh at path ``f``.

    Each side in ``sides`` is flattened row-major into xyz triples. The
    results are concatenated in the mapping's iteration order, and ``faces``
    must index into that combined vertex list. All sides must share one
    shape so they can be stacked. The file is written by
    ``trimesh.Trimesh.export`` (the notebook passes ``.stl`` paths).
    """
    flat_vertices = torch.stack([grid for grid in sides.values()]).reshape(-1, 3)
    trimesh.Trimesh(
        vertices=flat_vertices.cpu().detach(),
        faces=faces,
    ).export(f)
# Export the original cube sides, before any padding or convolution.
f = './cube.stl'
export(f, sides, faces);   

In [83]:
# Depthwise (groups=3) 3x3 box filter: each xyz channel is averaged over its
# own 3x3 neighbourhood; bias=False keeps it a pure weighted average.
conv = torch.nn.Conv2d(3, 3, 3, groups=3, bias=False).to(device)
# Initialise through torch.nn.init instead of assigning to `.data`. Deriving
# the weight from the kernel keeps it a mean if the kernel size changes.
torch.nn.init.constant_(conv.weight, 1.0 / (conv.kernel_size[0] * conv.kernel_size[1]))
conv.weight.detach()

tensor([[[[0.11, 0.11, 0.11],
          [0.11, 0.11, 0.11],
          [0.11, 0.11, 0.11]]],


        [[[0.11, 0.11, 0.11],
          [0.11, 0.11, 0.11],
          [0.11, 0.11, 0.11]]],


        [[[0.11, 0.11, 0.11],
          [0.11, 0.11, 0.11],
          [0.11, 0.11, 0.11]]]], device='cuda:0')

In [84]:
# Pad every side from its neighbours, then smooth it with `conv`. Keep both
# the padded grids (`raw`) and the smoothed grids (`res`) for export.
res, kernel_size = {}, 3
raw = {}
pad = (kernel_size - 1) // 2
# pad_side derives the border width from kernel_size, so conv must match it.
assert conv.kernel_size == (kernel_size, kernel_size), conv.kernel_size
for side_name, side in sides.items():
    padded = pad_side(sides, side_name, kernel_size)
    raw[side_name] = padded.clone()
    batch = padded.permute(2, 0, 1)[None]          # (H, W, 3) -> (1, 3, H, W)
    smoothed = conv(batch)                         # one forward pass per side
    print(side.shape, batch.shape, smoothed.shape)
    res[side_name] = smoothed[0].permute(1, 2, 0)  # back to (H, W, 3)

f = f'./cube_{kernel_size}.stl'
export(f, res, faces);
# The padded grids have side length n + 2 * pad, so build matching faces.
f = f'./cube_raw_{kernel_size}.stl'
export(f, raw, make_cube_faces(n + 2 * pad));
f

torch.Size([7, 7, 3]) torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 7, 7])
torch.Size([7, 7, 3]) torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 7, 7])
torch.Size([7, 7, 3]) torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 7, 7])
torch.Size([7, 7, 3]) torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 7, 7])
torch.Size([7, 7, 3]) torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 7, 7])
torch.Size([7, 7, 3]) torch.Size([1, 3, 9, 9]) torch.Size([1, 3, 7, 7])


'./cube_raw_3.stl'

In [75]:
# Inspect the smoothed (padded + convolved) 'left' side.
res['left']

tensor([[[    -0.22,     -0.42,      0.13],
         [    -0.11,     -0.54,     -0.01],
         [    -0.11,     -0.48,      0.10],
         [    -0.11,     -0.43,      0.21],
         [    -0.11,     -0.37,      0.33],
         [    -0.11,     -0.32,      0.44],
         [     0.13,     -0.29,      0.22]],

        [[    -0.44,     -0.54,     -0.11],
         [    -0.33,     -0.64,     -0.33],
         [    -0.33,     -0.64,     -0.17],
         [    -0.33,     -0.64,     -0.00],
         [    -0.33,     -0.64,      0.17],
         [    -0.33,     -0.64,      0.33],
         [    -0.01,     -0.54,      0.11]],

        [[    -0.33,     -0.48,     -0.11],
         [    -0.17,     -0.64,     -0.33],
         [    -0.17,     -0.64,     -0.17],
         [    -0.17,     -0.64,     -0.00],
         [    -0.17,     -0.64,      0.17],
         [    -0.17,     -0.64,      0.33],
         [     0.10,     -0.48,      0.11]],

        [[    -0.21,     -0.43,     -0.11],
         [     0.00,     -

In [76]:
# Inspect the padded 'front' side before convolution (border cells come from its neighbours).
raw['front']

tensor([[[     0.50,     -0.07,      0.57],
         [     0.50,     -0.50,      0.64],
         [     0.50,     -0.33,      0.64],
         [     0.50,     -0.17,      0.64],
         [     0.50,     -0.00,      0.64],
         [     0.50,      0.17,      0.64],
         [     0.50,      0.33,      0.64],
         [     0.50,      0.50,      0.64],
         [     0.50,      0.07,      0.07]],

        [[    -0.50,     -0.64,      0.50],
         [     0.64,     -0.50,     -0.50],
         [     0.64,     -0.50,     -0.33],
         [     0.64,     -0.50,     -0.17],
         [     0.64,     -0.50,     -0.00],
         [     0.64,     -0.50,      0.17],
         [     0.64,     -0.50,      0.33],
         [     0.64,     -0.50,      0.50],
         [    -0.50,      0.64,     -0.50]],

        [[    -0.33,     -0.64,      0.50],
         [     0.64,     -0.33,     -0.50],
         [     0.64,     -0.33,     -0.33],
         [     0.64,     -0.33,     -0.17],
         [     0.64,     -0.