In [2]:
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from sym import models
from torch.nn.modules.utils import _pair

def activation_func(activation):
    return  nn.ModuleDict([
        ['relu', nn.ReLU(inplace=True)],
        ['leaky_relu', nn.LeakyReLU(negative_slope=0.01, inplace=True)],
        ['selu', nn.SELU(inplace=True)],
    ])[activation]

In [37]:
class LCN(nn.Module):
    def __init__(self, in_channels=1, out_channels=10, h=280, w=280, f=10, ks=28, s=28, p=0, activation='relu', bias=True):
        super(LCN, self).__init__()
        width_span = int((w-ks+2*p)/s) + 1
        height_span = int((h-ks+2*p)/s) + 1
        self.weight =  torch.ones(f, height_span, width_span, in_channels, ks, ks)
        self.weight = torch.einsum("ijklmn,k->ijklmn",self.weight,
                                   torch.arange(width_span)
                                  )
        self.weight = nn.Parameter(self.weight.flatten(0,2))
#         self.weight = nn.Parameter(
#             torch.ones(width_span*height_span*f, in_channels, ks, ks)
#         )
        #self.weight[width_span*height_span*(f-1):]*=2
        if bias:
            self.bias = nn.Parameter(
                torch.zeros(width_span*height_span*f, in_channels)
            )
        else:
            self.register_parameter('bias', None)
        self.kernel_size = _pair(ks)
        self.stride = _pair(s)
        self.activation = activation_func(activation)
        self.decoder = nn.Linear(width_span*height_span*f, out_channels)
        self.in_channels = in_channels
        self.ws,self.hs = width_span,height_span
        self.f = f
        self.pad = p

    def forward(self, x):
        _, c, h, w = x.size()
        print("x.size", x.size(),"x", x,sep="\n")
        print("Weights")
        print("size", self.weight.data.size(),"w", self.weight.data,sep="\n")
        x = nn.functional.pad(x,(self.pad,self.pad,self.pad,self.pad),'constant',0)
        kh, kw = self.kernel_size
        dh, dw = self.stride
        x = x.unfold(2, kh, dh)
        x = x.unfold(3, kw, dw)
        x = x.reshape(x.size(0),-1,self.in_channels,kh,kw)
        #print(x.size(),'\n',x)
        x = x.repeat(1,self.f,1,1,1)
        #print(x.size(),'\n',x)
        x = (x * self.weight).sum([-1, -2])
        print("Convolve with weights")
        print("x.size", x.size(),"x", x,sep="\n")
        if self.bias is not None:
            x += self.bias
        x = x.reshape(x.size(0),self.f,self.hs,self.ws)
        #x = x.reshape(x.size(0),self.f,x.size(1)//self.f)
        print("Final")
        print("x.size", x.size(),"x", x,sep="\n")
        x = self.activation(x)
        x = x.view(x.size(0),-1)
        x = self.decoder(x)
        return x

In [42]:
lcn = LCN(h=3,w=3,f=2,ks=2,s=1,p=0)

In [43]:
x = torch.arange(18)
x = x.reshape(2,1,3,3)

In [45]:
lcn.weight.unflatten(0,(2,2,2))[:,:,0]

tensor([[[[[0., 0.],
           [0., 0.]]],


         [[[0., 0.],
           [0., 0.]]]],



        [[[[0., 0.],
           [0., 0.]]],


         [[[0., 0.],
           [0., 0.]]]]], grad_fn=<SelectBackward>)

In [44]:
lcn(x)

x.size
torch.Size([2, 1, 3, 3])
x
tensor([[[[ 0,  1,  2],
          [ 3,  4,  5],
          [ 6,  7,  8]]],


        [[[ 9, 10, 11],
          [12, 13, 14],
          [15, 16, 17]]]])
Weights
size
torch.Size([8, 1, 2, 2])
w
tensor([[[[0., 0.],
          [0., 0.]]],


        [[[1., 1.],
          [1., 1.]]],


        [[[0., 0.],
          [0., 0.]]],


        [[[1., 1.],
          [1., 1.]]],


        [[[0., 0.],
          [0., 0.]]],


        [[[1., 1.],
          [1., 1.]]],


        [[[0., 0.],
          [0., 0.]]],


        [[[1., 1.],
          [1., 1.]]]])
Convolve with weights
x.size
torch.Size([2, 8, 1])
x
tensor([[[ 0.],
         [12.],
         [ 0.],
         [24.],
         [ 0.],
         [12.],
         [ 0.],
         [24.]],

        [[ 0.],
         [48.],
         [ 0.],
         [60.],
         [ 0.],
         [48.],
         [ 0.],
         [60.]]], grad_fn=<SumBackward1>)
Final
x.size
torch.Size([2, 2, 2, 2])
x
tensor([[[[ 0., 12.],
          [ 0., 24.]],

 

tensor([[-10.9365,   2.1441,   6.2785,   9.1374,  -4.5469,  -0.3482,  -2.6181,
           7.1429,   2.8883,  -7.8163],
        [-32.4127,   9.5459,  26.0804,  17.8768, -11.9783,  -4.8390,  -7.0343,
          25.1942,   7.7695, -23.9097]], grad_fn=<AddmmBackward>)