In [1]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F
import torchvision

np.random.seed(42)
PERMUTATIONS = [np.random.permutation(784) for _ in range(10)]

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [23]:
A = torch.randn(10, 9, 8, 7, 6)



for i in range(10):
    for j in range(9):
        A[i, j] = i*9 + j
        
        
B = A.view((A.shape[0]*A.shape[1],) + A.shape[2:])
B.shape

torch.Size([90, 8, 7, 6])

In [25]:
B[87]

tensor([[[87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.]],

        [[87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.]],

        [[87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.]],

        [[87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [87., 87., 87., 87., 87., 87.],
         [

In [43]:
class ConvVectorizedLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, category_dim, stride=1, padding=0, dilation=1,
                 groups=1, padding_mode='zeros', nonneg=False, nonlin=True, expanded_input=False, pool=True, pool_size=2, pool_stride=2):
        super(ConvVectorizedLayer, self).__init__()
        
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False, padding_mode=padding_mode)
        
        k = groups / (in_channels * np.prod(kernel_size))
        if expanded_input:
            k = k * category_dim
        with torch.no_grad():
            if nonneg: self.conv.weight.uniform_(0, np.sqrt(k))
            else: self.conv.weight.uniform_(-np.sqrt(k), np.sqrt(k))
        self.bias = nn.Parameter(torch.zeros(category_dim, out_channels))
        
        if pool:
            self.avgpool = nn.AvgPool2d(kernel_size=pool_size, stride=pool_stride)
        
        
        self.pool = pool
        self.nonneg = nonneg
        self.nonlin = nonlin
        
    def forward(self, input):
        #(batch_dim, category_dim, channels, width, height)
        input_reshaped = input.view((input.shape[0]*input.shape[1],) + input.shape[2:])
        conv_out = self.conv(input_reshaped)
        saved_shape = conv_out.shape
        conv_out = conv_out.view((input.shape[0], input.shape[1]) + conv_out.shape[1:]) #?
        conv_out = conv_out + self.bias[None, :, :, None, None]
        if self.nonlin:
            conv_out_sum = conv_out.sum(dim=1).detach()
            mask = (conv_out_sum > 0.).float()
            conv_out = conv_out * mask[:, None, :, :, :]
        if self.pool:
            conv_out = conv_out.view(saved_shape)
            conv_out = self.avgpool(conv_out)
            conv_out = conv_out.view((input.shape[0], input.shape[1]) + conv_out.shape[1:])
        return conv_out
    
    def post_step_callback(self):
        if self.nonneg:
            with torch.no_grad():
                self.conv.weight.clamp_(min=0)
                



In [44]:
conv = ConvVectorizedLayer(in_channels=10, out_channels=3, kernel_size=6,
                           category_dim=10, nonneg=True, nonlin=True, pool=True, pool_size=3, pool_stride=2)

In [45]:
def expand_input(input, category_dim):
    #input = (batch, channels, width, height)
    batch_size, in_channels = input.shape[:2]
    expanded_input = torch.zeros((batch_size,) + (category_dim,) + (in_channels*category_dim,) + input.shape[2:])
    for i in range(category_dim):
        expanded_input[:, i, i*in_channels:(i+1)*in_channels] = input
    return expanded_input

    

In [46]:
import layers
from imp import reload
reload(layers)

<module 'layers' from '/home/davidclark/Projects/VectorizedNets/layers.py'>

In [62]:
class VectorizedLeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_part = nn.Sequential(ConvVectorizedLayer(30, 6, 5, 10, expanded_input=True, nonlin=True, nonneg=False),
                                  ConvVectorizedLayer(6, 16, 5, 10, nonlin=True, nonneg=False),
                                  ConvVectorizedLayer(16, 120, 5, 10, nonlin=True, pool=False, nonneg=False))
        self.fc_part = nn.Sequential(layers.VectorizedLayer(120, 84, 10, nonneg=False, nonlin=True),
                                layers.VectorizedLayer(84, 1, 10, nonneg=False, nonlin=False))
        
    def post_step_callback(self):
        for i in range(len(self.conv_part)):
            self.conv_part[i].post_step_callback()
        for i in range(len(self.fc_part)):
            self.fc_part[i].post_step_callback()
        
    def forward(self, input):
        conv_out = self.conv_part(input)
        conv_out = conv_out.view(conv_out.shape[:2] + (120,))
        fc_out = self.fc_part(conv_out)
        return fc_out
        

In [51]:
input = torch.randn(128, 10, 30, 32, 32)

model = VectorizedLeNet()

#model(input).shape

In [49]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_set = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [82]:
model = VectorizedLeNet()
loss_fn = nn.CrossEntropyLoss(reduction="mean")
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch_idx in range(1000):
    print(epoch_idx)
    epoch_loss = 0.
    for batch_idx, (data, labels) in enumerate(train_loader):
        input = expand_input(data, 10)
        out = model(input)[..., 0]
        loss = loss_fn(out, labels)
        epoch_loss += loss.item()
        acc = (out.detach().argmax(dim=1) == labels).float().mean()
        print(acc)
        #optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #model.post_step_callback()
    print(epoch_loss / (batch_idx + 1))


0
tensor(0.1172)
tensor(0.1328)
tensor(0.1250)
tensor(0.1172)
tensor(0.1250)
tensor(0.1094)
tensor(0.0703)
tensor(0.1719)
tensor(0.1328)
tensor(0.1250)
tensor(0.1328)
tensor(0.1641)
tensor(0.1406)
tensor(0.1484)
tensor(0.1562)
tensor(0.1172)
tensor(0.1328)
tensor(0.1641)
tensor(0.1641)
tensor(0.1719)
tensor(0.1016)
tensor(0.1484)
tensor(0.1797)
tensor(0.1641)
tensor(0.1172)
tensor(0.1953)
tensor(0.1562)
tensor(0.1484)
tensor(0.1406)
tensor(0.1875)
tensor(0.1797)
tensor(0.1953)
tensor(0.1250)
tensor(0.1406)
tensor(0.1484)
tensor(0.1562)
tensor(0.1875)
tensor(0.2266)
tensor(0.1641)
tensor(0.1562)
tensor(0.1562)
tensor(0.1562)
tensor(0.1484)
tensor(0.1328)
tensor(0.1641)
tensor(0.1484)
tensor(0.2109)
tensor(0.1562)
tensor(0.1484)
tensor(0.2188)
tensor(0.2109)
tensor(0.1328)
tensor(0.1797)
tensor(0.1562)
tensor(0.1641)
tensor(0.1406)
tensor(0.1172)
tensor(0.1797)
tensor(0.1562)
tensor(0.1875)
tensor(0.1484)
tensor(0.2344)
tensor(0.1797)
tensor(0.1562)
tensor(0.2109)
tensor(0.1562)
tensor(0

KeyboardInterrupt: 

In [80]:
model.conv_part[0].conv.weight.grad

tensor([[[[ 3.8833e-01,  2.5953e-01,  2.9085e-01,  3.6096e-01,  3.4685e-01],
          [ 3.1685e-02, -1.3424e-01, -8.6265e-02,  4.4567e-02,  9.5919e-02],
          [-1.4994e-01, -3.2541e-01, -3.0814e-01, -1.9294e-01, -1.3343e-01],
          [-2.0272e-01, -3.5444e-01, -3.8051e-01, -2.9681e-01, -2.7025e-01],
          [-3.0795e-01, -4.4734e-01, -4.7512e-01, -3.8065e-01, -3.4627e-01]],

         [[ 4.8603e-01,  3.2542e-01,  3.2897e-01,  3.9225e-01,  3.8920e-01],
          [ 1.2913e-01, -6.9508e-02, -4.6146e-02,  7.8796e-02,  1.4004e-01],
          [-6.6802e-02, -2.7440e-01, -2.7874e-01, -1.7278e-01, -1.1063e-01],
          [-1.5374e-01, -3.3059e-01, -3.7922e-01, -3.0701e-01, -2.8014e-01],
          [-3.0846e-01, -4.6694e-01, -5.1109e-01, -4.1700e-01, -3.7184e-01]],

         [[ 2.9880e-01,  1.5014e-01,  1.5748e-01,  2.3186e-01,  2.3975e-01],
          [ 1.1446e-02, -1.7486e-01, -1.4560e-01, -1.2637e-02,  4.5705e-02],
          [-1.3347e-01, -3.2964e-01, -3.3304e-01, -2.2949e-01, -1.7978e-

In [29]:
x = torch.randn(5, 5)
A = torch.randn(5, 5)
B = torch.randn(5, 5)

A.requires_grad = True
B.requires_grad = False

y = A.mm(B).mm(x)
l = y.sum()
l.backward()

In [31]:
B.grad