In [1]:
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(
    precision=1,    
    sci_mode=False)

In [2]:
n, start, end = 2, -0.5, 0.5

d1, d2 = torch.meshgrid(
        torch.linspace(start, end, steps=n),
        torch.linspace(start, end, steps=n))
d3 = torch.full_like(d1, end) + 1 / n
sides =  OrderedDict({
    'front': torch.stack((+d3,  d1,  d2), dim=-1),
    'right': torch.stack(( d1, +d3,  d2), dim=-1),    
    'back' : torch.stack((-d3,  d1,  d2), dim=-1),         
    'left' : torch.stack(( d1, -d3,  d2), dim=-1),
    'top'  : torch.stack(( d1,  d2, +d3), dim=-1),
    'down' : torch.stack(( d1,  d2, -d3), dim=-1),
})

sides

OrderedDict([('front',
              tensor([[[ 1.0, -0.5, -0.5],
                       [ 1.0, -0.5,  0.5]],
              
                      [[ 1.0,  0.5, -0.5],
                       [ 1.0,  0.5,  0.5]]])),
             ('right',
              tensor([[[-0.5,  1.0, -0.5],
                       [-0.5,  1.0,  0.5]],
              
                      [[ 0.5,  1.0, -0.5],
                       [ 0.5,  1.0,  0.5]]])),
             ('back',
              tensor([[[-1.0, -0.5, -0.5],
                       [-1.0, -0.5,  0.5]],
              
                      [[-1.0,  0.5, -0.5],
                       [-1.0,  0.5,  0.5]]])),
             ('left',
              tensor([[[-0.5, -1.0, -0.5],
                       [-0.5, -1.0,  0.5]],
              
                      [[ 0.5, -1.0, -0.5],
                       [ 0.5, -1.0,  0.5]]])),
             ('top',
              tensor([[[-0.5, -0.5,  1.0],
                       [-0.5,  0.5,  1.0]],
              
                   

In [3]:
t0 = torch.arange(0, 4, 1).reshape(2, 2)
t1 = t0 + 10
t2 = t0 + 20
t3 = t0 + 30
t4 = t0 + 40
t5 = t0 + 50

t0, t1, t2, t3, t4, t5

(tensor([[0, 1],
         [2, 3]]),
 tensor([[10, 11],
         [12, 13]]),
 tensor([[20, 21],
         [22, 23]]),
 tensor([[30, 31],
         [32, 33]]),
 tensor([[40, 41],
         [42, 43]]),
 tensor([[50, 51],
         [52, 53]]))

In [4]:
torch.cat((t3, t0, t1))

tensor([[30, 31],
        [32, 33],
        [ 0,  1],
        [ 2,  3],
        [10, 11],
        [12, 13]])

In [5]:
pad_map = {    
    'front': ['left', 'right', 'top', 'bottom'],
    'right': ['front', 'back', 'top', 'bottom'],
    'back': ['right', 'left', 'top', 'bottom'],
    'left': ['back', 'front', 'top', 'bottom'],
    'top': ['left', 'right', 'back', 'front'],
    'bottom': ['left', 'right', 'front', 'back'],
}
pad_map

{'front': ['left', 'right', 'top', 'bottom'],
 'right': ['front', 'back', 'top', 'bottom'],
 'back': ['right', 'left', 'top', 'bottom'],
 'left': ['back', 'front', 'top', 'bottom'],
 'top': ['left', 'right', 'back', 'front'],
 'bottom': ['left', 'right', 'front', 'back']}

In [6]:
kernel = 5 
pad_size = (kernel - 1) // 2
pad_size

2

In [7]:
torch.eye(2) * 0.5

tensor([[0.5, 0.0],
        [0.0, 0.5]])

In [10]:
top_left = torch.triu(torch.ones(3, 3)) - (torch.eye(3)* 0.5)
left_top = top_left.t() 
top_left, left_top

top_right = top_left.flip(dims=(0,)).t()
right_top = left_top.flip(dims=(0,)).t()
top_right, right_top

down_right = torch.tril(torch.ones(3, 3)) - (torch.eye(3)* 0.5)
right_down = down_right.t()
down_right, right_down

(tensor([[0.5, 0.0, 0.0],
         [1.0, 0.5, 0.0],
         [1.0, 1.0, 0.5]]),
 tensor([[0.5, 1.0, 1.0],
         [0.0, 0.5, 1.0],
         [0.0, 0.0, 0.5]]))

In [11]:
def get_neighbors(sides,  side_name, size):    
    pad_map = {    
        'front': ['left', 'right', 'top', 'down'],
        'right': ['front', 'back', 'top', 'down'],
        'back': ['right', 'left', 'top', 'down'],
        'left': ['back', 'front', 'top', 'down'],
        'top': ['left', 'right', 'back', 'front'],
        'down': ['left', 'right', 'front', 'back'],
    }
    l, r, t, d = [sides[n] for n in pad_map[side_name]]
    return [
        l[:, -size:, :],
        r[:, :size, :],
        t[-size:, :, :],
        d[:size, :, :],
    ]
[p.shape for p in get_neighbors(sides,  'front', 1)]

[torch.Size([2, 1, 3]),
 torch.Size([2, 1, 3]),
 torch.Size([1, 2, 3]),
 torch.Size([1, 2, 3])]

In [47]:
def make_tris(size):
    tl = torch.triu(torch.ones(size, size)) - (torch.eye(size) * 0.5)
    dr = torch.tril(torch.ones(size, size)) - (torch.eye(size) * 0.5)    
    return {
        'tl': tl,
        'lt': tl.t().clone(),        
        'tr': tl.t().flip(dims=(0,)).t(),
        'rt': tl.flip(dims=(0,)).t(),
        'dr': dr.t(),
        'rd': dr,
        'ld': dr.flip(dims=(0,)).t(),
        'dl': dr.t().flip(dims=(0,)).t(),        
    }
tris = make_tris(3)
for key in tris:
    print(key)
    print(tris[key])
    

tl
tensor([[0.5, 1.0, 1.0],
        [0.0, 0.5, 1.0],
        [0.0, 0.0, 0.5]])
lt
tensor([[0.5, 0.0, 0.0],
        [1.0, 0.5, 0.0],
        [1.0, 1.0, 0.5]])
tr
tensor([[1.0, 1.0, 0.5],
        [1.0, 0.5, 0.0],
        [0.5, 0.0, 0.0]])
rt
tensor([[0.0, 0.0, 0.5],
        [0.0, 0.5, 1.0],
        [0.5, 1.0, 1.0]])
dr
tensor([[0.5, 1.0, 1.0],
        [0.0, 0.5, 1.0],
        [0.0, 0.0, 0.5]])
rd
tensor([[0.5, 0.0, 0.0],
        [1.0, 0.5, 0.0],
        [1.0, 1.0, 0.5]])
ld
tensor([[1.0, 1.0, 0.5],
        [1.0, 0.5, 0.0],
        [0.5, 0.0, 0.0]])
dl
tensor([[0.0, 0.0, 0.5],
        [0.0, 0.5, 1.0],
        [0.5, 1.0, 1.0]])


In [40]:
# lt t tr
# l  o r
# lb b br

# lt/tl tr/rt
# ld/dl dr/rd
def get_corners(size, l, r, t, d):
    tris = make_tris(size)
    lt = l[-size:, :, :] * tris['lt']
    tl = t[:, -size:, :] * tris['tl']
    
    tr = t[:, :size, :] * tris['tr']
    rt = r[-size:, :, :] * tris['rt']
    
    ld = l[:size, :, :] * tris['ld']
    dl = d[:, -size:, :] * tris['dl']
        
    rd = r[:size, :, :] * tris['rd']
    dr = d[:, :size, :] * tris['dr']
    
    ltc = lt + tl
    trc = tr + rt
    ldc = ld + dl
    drc = dr + rd
    return ltc, trc, ldc, drc

side_name, kernel_size = 'front', 3
size = (kernel_size - 1) // 2
o = sides[side_name]
l, r, t, d = get_neighbors(sides, side_name, size)
lt, tr, ld, dr = get_corners(size, l, r, t, n)
[p.shape for p in get_corners(size, l, r, t, d)]

[torch.Size([1, 1, 3]),
 torch.Size([1, 1, 3]),
 torch.Size([1, 1, 3]),
 torch.Size([1, 1, 3])]

In [53]:
# lt t tr
# l  o r
# lb b br
def pad_side(sides, side_name, kernel_size):
    o = sides[side_name]
    size = (kernel_size - 1) // 2
    l, r, t, n = get_neighbors(sides, side_name, size)
    lt, tr, lb, br = get_corners(size, l, r, t, n)

    top = torch.cat((lt, t, tr), dim=1)
    down = torch.cat((ld, d, dr), dim=1)
    middle = torch.cat((l, o, r), dim=1)
    
    return torch.cat((top, middle, down), dim=0)
    
pad_side(sides, 'left', 5).shape

RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 2

In [52]:
sides['left'].shape

torch.Size([2, 2, 3])

{'tl': tensor([[0.5]]),
 'lt': tensor([[0.5]]),
 'tr': tensor([[0.5]]),
 'rt': tensor([[0.5]]),
 'dr': tensor([[0.5]]),
 'rd': tensor([[0.5]]),
 'lb': tensor([[0.5]]),
 'bl': tensor([[0.5]])}