In [1]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F

# Configuration constants (previously inline magic numbers).
SEED = 42
N_PIXELS = 784          # length of each permutation (784 = 28 * 28)
N_PERMUTATIONS = 10     # how many fixed permutations to pre-generate

# Seed every RNG the notebook draws from. NumPy drives PERMUTATIONS; torch
# drives torch.randn and the layer weight initialisation in later cells, which
# were unseeded before and therefore differed on every kernel restart.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fixed index permutations; identical to the previous values because the NumPy
# stream is seeded and consumed exactly as before.
PERMUTATIONS = [np.random.permutation(N_PIXELS) for _ in range(N_PERMUTATIONS)]

# `display` and `HTML` belong to the public `IPython.display` API; importing
# `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML

# Inject a CSS rule forcing elements with class `container` to full width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [23]:
# Sanity check: merging the two leading axes with `view` keeps row-major order,
# so flat index i*9 + j of the merged axis must correspond to A[i, j].
A = torch.randn(10, 9, 8, 7, 6)

# Overwrite every A[i, j] slab with its expected flat index i*9 + j in a single
# broadcast assignment over the three trailing axes.
A[:] = torch.arange(A.shape[0] * A.shape[1], dtype=A.dtype).view(A.shape[0], A.shape[1], 1, 1, 1)

# Collapse (10, 9) into one axis of length 90; the trailing axes are untouched.
B = A.view(-1, *A.shape[2:])
B.shape

torch.Size([90, 8, 7, 6])

In [25]:
# Flat index 87 = 9*9 + 6, so this slice is A[9, 6] and should be filled with 87.
B[87]

tensor([[[87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.]],

        [[87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.]],

        [[87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.]],

        [[87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [

In [50]:
class ConvVectorizedLayer(nn.Module):
    """2-D convolution applied independently to every slice of a category axis.

    The input is laid out as (batch, category, channels, spatial, spatial).
    One bias-free ``nn.Conv2d`` is shared by all (batch, category) slices. A
    learnable bias of shape (category_dim, out_channels) is then added. When
    ``nonlin`` is set, each activation is kept only where the sum over the
    category axis is positive (the gate itself carries no gradient). An
    optional ``nn.AvgPool2d`` is applied last.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, padding_mode: forwarded unchanged to ``nn.Conv2d``.
        category_dim: size of the category axis; fixes the bias shape and,
            with ``expanded_input``, scales the initialisation range.
        nonneg: initialise weights in ``[0, sqrt(k))`` instead of
            ``[-sqrt(k), sqrt(k))``; ``post_step_callback`` then clamps the
            weights at zero.
        nonlin: enable the category-sum gate described above.
        expanded_input: multiply ``k`` by ``category_dim`` before taking the
            square root for the initialisation bound.
        pool: append ``nn.AvgPool2d(pool_size, pool_stride)``.
        pool_size, pool_stride: pooling window and stride. The ``-1``
            defaults are placeholders and must be overridden when ``pool`` is
            true.

    Raises:
        ValueError: if ``pool`` is true and ``pool_size`` or ``pool_stride``
            is not strictly positive.
    """

    def __init__(self, in_channels, out_channels, kernel_size, category_dim, stride=1, padding=0, dilation=1,
                 groups=1, padding_mode='zeros', nonneg=False, nonlin=True, expanded_input=False, pool=True, pool_size=-1, pool_stride=-1):
        super().__init__()

        # Fail at construction time with a clear message instead of at the
        # first forward pass. np.min accepts both ints and tuples.
        if pool:
            if np.min(pool_size) <= 0:
                raise ValueError(
                    "pool=True requires a positive pool_size, got %r" % (pool_size,))
            if pool_stride is not None and np.min(pool_stride) <= 0:
                raise ValueError(
                    "pool=True requires a positive pool_stride, got %r" % (pool_stride,))

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False, padding_mode=padding_mode)

        # nn.Conv2d normalises an int kernel_size to a (k, k) tuple. Taking the
        # product of the raw argument would give k instead of k*k for an int
        # and make the initialisation bound too large.
        k = groups / (in_channels * np.prod(self.conv.kernel_size))
        if expanded_input:
            k = k * category_dim
        bound = float(np.sqrt(k))
        with torch.no_grad():
            self.conv.weight.uniform_(0.0 if nonneg else -bound, bound)
        self.bias = nn.Parameter(torch.zeros(category_dim, out_channels))

        if pool:
            self.avgpool = nn.AvgPool2d(kernel_size=pool_size, stride=pool_stride)

        self.pool = pool
        self.nonneg = nonneg
        self.nonlin = nonlin

    def forward(self, input):
        """Run the layer on a (batch_dim, category_dim, channels, width, height) tensor.

        Returns a tensor with the same two leading axes, followed by the
        output channels and the convolved (and optionally pooled) spatial axes.
        """
        leading = tuple(input.shape[:2])
        # Fold batch and category into one axis so Conv2d sees a 4-D batch.
        # reshape (rather than view) also accepts non-contiguous inputs.
        folded = input.reshape((leading[0] * leading[1],) + tuple(input.shape[2:]))
        conv_out = self.conv(folded)
        folded_shape = conv_out.shape
        conv_out = conv_out.reshape(leading + tuple(conv_out.shape[1:]))
        # Bias is (category, out_channels); broadcast over batch and space.
        conv_out = conv_out + self.bias[None, :, :, None, None]
        if self.nonlin:
            # Gate on the sign of the category-summed activation; detached so
            # no gradient flows through the gate.
            conv_out_sum = conv_out.sum(dim=1).detach()
            mask = (conv_out_sum > 0.).float()
            conv_out = conv_out * mask[:, None, :, :, :]
        if self.pool:
            # AvgPool2d needs the folded 4-D layout; unfold again afterwards.
            conv_out = self.avgpool(conv_out.reshape(folded_shape))
            conv_out = conv_out.reshape(leading + tuple(conv_out.shape[1:]))
        return conv_out

    def post_step_callback(self):
        """Clamp the convolution weights at zero when ``nonneg`` is set."""
        if self.nonneg:
            with torch.no_grad():
                self.conv.weight.clamp_(min=0)
                



In [56]:
# Example layer: 10 input channels -> 3 output channels with kernel_size=6,
# non-negative weights, the category-sum gate, then average pooling (3, stride 2).
conv = ConvVectorizedLayer(
    in_channels=10,
    out_channels=3,
    kernel_size=6,
    category_dim=10,
    nonneg=True,
    nonlin=True,
    pool=True,
    pool_size=3,
    pool_stride=2,
)

In [64]:
def expand_input(input, category_dim):
    """Copy ``input`` into a separate channel block for each category.

    Args:
        input: tensor of shape (batch, channels, width, height).
        category_dim: number of categories, i.e. the size of the new axis 1.

    Returns:
        A tensor of shape (batch, category_dim, channels * category_dim,
        width, height). For category ``i`` the channel block
        ``[i*channels, (i+1)*channels)`` holds ``input``; every other entry
        is zero.
    """
    #input = (batch, channels, width, height)
    batch_size, in_channels = input.shape[:2]
    expanded_input = torch.zeros((batch_size,) + (category_dim,) + (in_channels*category_dim,) + tuple(input.shape[2:]))
    for i in range(category_dim):
        # Write the input into category i's own channel block. Without the
        # assignment the slice is only read and the result stays all zeros.
        expanded_input[:, i, i*in_channels:(i+1)*in_channels] = input
    return expanded_input

    

(torch.Size([200, 3, 28, 28]), torch.Size([200, 10, 30, 28, 28]))

In [49]:
# NOTE(review): `out` is not defined in any cell visible here -- presumably the
# result of a ConvVectorizedLayer forward pass; confirm it survives a fresh
# Restart & Run All. Shows the slice at index 0 / 4 / 0 of the first three axes.
out[0, 4, 0]

tensor([[ 0.0911,  0.9251,  0.8003,  0.4003,  0.3749,  0.3951,  0.3706,  0.5327,
          0.2019,  0.5012, -0.0698],
        [-0.2160,  1.4210,  0.9609,  0.2281,  0.0295,  0.0000,  0.3862,  0.6340,
          0.6682,  0.3384, -0.2897],
        [-0.4937,  1.2639,  0.9281,  0.0576,  0.0000,  0.0000,  0.3630,  0.2920,
          0.2813,  0.3893, -0.1127],
        [-0.7265, -0.0214, -0.1159, -0.1476, -0.1476, -0.3713, -0.1671, -0.5562,
         -0.8358, -0.6825, -0.9315],
        [-0.0844, -0.2680, -0.1064, -0.1476, -0.2641, -0.4892, -0.2719, -0.2046,
         -0.7462, -1.0276, -0.8504],
        [ 0.2280,  0.0438, -0.1954, -0.0410, -0.2976, -0.1757, -0.0446,  0.0000,
         -0.0449, -0.5120, -0.5180],
        [ 0.0000,  0.0000, -0.2391,  0.3729,  0.0760,  0.2647,  0.0000,  0.0000,
          0.1351, -0.4942, -0.3749],
        [ 0.0000,  0.0458,  0.2932,  0.6550,  0.3292,  0.4564,  0.2740,  0.0000,
         -0.5876, -1.3710, -0.6554],
        [-0.0631,  0.0817,  0.9879,  1.9242,  0.8596,  0