In [2]:
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from sym import models
from torch.nn.modules.utils import _pair

def activation_func(activation):
    return  nn.ModuleDict([
        ['relu', nn.ReLU(inplace=True)],
        ['leaky_relu', nn.LeakyReLU(negative_slope=0.01, inplace=True)],
        ['selu', nn.SELU(inplace=True)],
    ])[activation]

In [69]:
class LCN(nn.Module):
    def __init__(self, in_channels=1, out_channels=10, h=280, w=280, f=10, ks=28, s=28, p=0, activation='relu', bias=True):
        super(LCN, self).__init__()
        width_span = int((w-ks+2*p)/s) + 1
        height_span = int((h-ks+2*p)/s) + 1
        self.weight = nn.Parameter(
            torch.ones(width_span*height_span*f, in_channels, ks, ks)
        )
        self.weight[width_span*height_span*(f-1):]*=2
        if bias:
            self.bias = nn.Parameter(
                torch.zeros(width_span*height_span*f, in_channels)
            )
        else:
            self.register_parameter('bias', None)
        self.kernel_size = _pair(ks)
        self.stride = _pair(s)
        self.activation = activation_func(activation)
        self.decoder = nn.Linear(width_span*height_span*f, out_channels)
        self.in_channels = in_channels
        self.ws,self.hs = width_span,height_span
        self.f = f
        self.pad = p

    def forward(self, x):
        _, c, h, w = x.size()
        print(x.size(),'\n',x)
        x = nn.functional.pad(x,(self.pad,self.pad,self.pad,self.pad),'constant',0)
        kh, kw = self.kernel_size
        dh, dw = self.stride
        x = x.unfold(2, kh, dh)
        x = x.unfold(3, kw, dw)
        x = x.reshape(x.size(0),-1,self.in_channels,kh,kw)
        #print(x.size(),'\n',x)
        x = x.repeat(1,self.f,1,1,1)
        #print(x.size(),'\n',x)
        x = (x * self.weight).sum([-1, -2])
        print(x.size(),'\n',x)
        if self.bias is not None:
            x += self.bias
        x = x.reshape(x.size(0),self.f,self.hs,self.ws)
        #x = x.reshape(x.size(0),self.f,x.size(1)//self.f)
        print(x.size(),'\n',x)
        x = self.activation(x)
        x = x.view(x.size(0),-1)
        x = self.decoder(x)
        return x

In [70]:
lcn = LCN(h=3,w=3,f=2,ks=2,s=1,p=0)

In [71]:
x = torch.arange(18)
x = x.reshape(2,1,3,3)

In [72]:
lcn(x)

torch.Size([2, 1, 3, 3]) 
 tensor([[[[ 0,  1,  2],
          [ 3,  4,  5],
          [ 6,  7,  8]]],


        [[[ 9, 10, 11],
          [12, 13, 14],
          [15, 16, 17]]]])
torch.Size([2, 8, 1]) 
 tensor([[[  8.],
         [ 12.],
         [ 20.],
         [ 24.],
         [ 16.],
         [ 24.],
         [ 40.],
         [ 48.]],

        [[ 44.],
         [ 48.],
         [ 56.],
         [ 60.],
         [ 88.],
         [ 96.],
         [112.],
         [120.]]], grad_fn=<SumBackward1>)
torch.Size([2, 2, 2, 2]) 
 tensor([[[[  8.,  12.],
          [ 20.,  24.]],

         [[ 16.,  24.],
          [ 40.,  48.]]],


        [[[ 44.,  48.],
          [ 56.,  60.]],

         [[ 88.,  96.],
          [112., 120.]]]], grad_fn=<ViewBackward>)


tensor([[  3.2281,  -2.4634, -18.3855,  20.1866,   9.4944,   8.0573,   3.8867,
          -9.8622,  -1.4546,   8.0757],
        [ 14.7531, -12.5270, -31.1293,  58.5585,  40.1661,  34.4191,  -9.3737,
         -44.1941, -16.2643,   0.9873]], grad_fn=<AddmmBackward>)