In [2]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F
import torchvision

np.random.seed(42)
PERMUTATIONS = [np.random.permutation(784) for _ in range(10)]

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [3]:
def expand_input(input, expanded_input, category_dim):
    #input = (batch, channels, width, height)
    batch_size, in_channels = input.shape[:2]
    #expanded_input = torch.zeros((batch_size, category_dim, in_channels*category_dim) + input.shape[2:])
    for i in range(category_dim):
        expanded_input[:, i, i*in_channels:(i+1)*in_channels] = input
    return expanded_input
    

In [4]:
import layers
from imp import reload
reload(layers)

<module 'layers' from '/home/davidclark/Projects/VectorizedNets/layers.py'>

In [26]:
reload(layers)
class VectorizedLeNet(nn.Module):
    def __init__(self, device="cpu"):
        super().__init__()
        self.conv_part = nn.Sequential(layers.ConvVectorizedLayer(10, 30, 128, 3, expanded_input=True, nonlin=True, nonneg=False, pool=False, device=device),
                                       layers.ConvVectorizedLayer(10, 128, 128, 3, expanded_input=False, nonlin=True, nonneg=False, pool=True, device=device),
                                       layers.ConvVectorizedLayer(10, 128, 128, 3, expanded_input=False, nonlin=True, nonneg=False, pool=False, device=device),
                                       layers.ConvVectorizedLayer(10, 128, 128, 3, expanded_input=False, nonlin=True, nonneg=False, pool=True, device=device),
                                       layers.ConvVectorizedLayer(10, 128, 256, 3, expanded_input=False, nonlin=True, nonneg=False, pool=False, device=device),
                                       layers.ConvVectorizedLayer(10, 256, 256, 3, expanded_input=False, nonlin=True, nonneg=False, pool=True, device=device),)
        self.fc_part = nn.Sequential(layers.VectorizedLayer(10, 256*4*4, 2048, expanded_input=False, nonlin=True, nonneg=False, device=device),
                                     layers.VectorizedLayer(10, 2048, 1024, expanded_input=False, nonlin=True, nonneg=False, device=device),
                                     layers.VectorizedLayer(10, 1024, 1, expanded_input=False, nonlin=False, nonneg=False, device=device))
        
    def post_step_callback(self):
        for i in range(len(self.conv_part)):
            self.conv_part[i].post_step_callback()
        for i in range(len(self.fc_part)):
            self.fc_part[i].post_step_callback()
        
    def forward(self, input):
        conv_out = self.conv_part(input)
        conv_out = conv_out.view(conv_out.shape[:2] + (256*4*4,))
        fc_out = self.fc_part(conv_out)
        return fc_out
        

In [29]:
model = VectorizedLeNet(device=0)

In [16]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_set = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [17]:
input = torch.zeros(128, 10, 30, 32, 32, device=0)

In [None]:
loss_fn = nn.CrossEntropyLoss(reduction="mean")
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch_idx in range(1000):
    print(epoch_idx)
    epoch_loss = 0.
    epoch_correct = 0.
    for batch_idx, (data, labels) in enumerate(train_loader):
        input *= 0.
        expand_input(data.to(0), input[:len(data)], 10) #.to(0)
        optimizer.zero_grad()
        out = model(input[:len(data)])[..., 0]
        loss = loss_fn(out, labels.to(0))
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_correct += (out.detach().cpu().argmax(dim=1) == labels).float().sum()

    print(epoch_correct / 50000.)
    print(epoch_loss / (batch_idx + 1))


0


In [25]:

epoch_correct = 0.
for batch_idx, (data, labels) in enumerate(test_loader):
    input *= 0.
    expand_input(data.to(0), input[:len(data)], 10) #.to(0)
    out = model(input[:len(data)])[..., 0]
    epoch_correct += (out.detach().cpu().argmax(dim=1) == labels).float().sum()

print(epoch_correct / 10000.)


tensor(0.7026)


In [29]:
x = torch.randn(5, 5)
A = torch.randn(5, 5)
B = torch.randn(5, 5)

A.requires_grad = True
B.requires_grad = False

y = A.mm(B).mm(x)
l = y.sum()
l.backward()

In [31]:
B.grad