In [39]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F

np.random.seed(42)
PERMUTATIONS = [np.random.permutation(784) for _ in range(10)]

In [40]:
import torchvision

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
                                            torchvision.transforms.Lambda(lambda x: torch.flatten(x))])
train_set = torchvision.datasets.MNIST("./data", train=True, download=True, transform=transform)
test_set = torchvision.datasets.MNIST("./data", train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=200, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=200, shuffle=False)

In [94]:
class VectorizedLayer(nn.Module):
    def __init__(self, in_features, out_features, category_dim, nonneg=False, nonlin=True, expanded_input=False):
        super(VectorizedLayer, self).__init__()
        k = 1. / np.sqrt(in_features)
        if expanded_input:
            k = 1. / np.sqrt(in_features / category_dim)
        k = k * 0.25
        self.weight = nn.Parameter(torch.zeros(out_features, in_features))
        with torch.no_grad():
            if nonneg: self.weight.uniform_(0, k)
            else: self.weight.uniform_(-k, k)
        self.bias = nn.Parameter(torch.zeros(category_dim, out_features))
        self.nonneg = nonneg
        self.nonlin = nonlin
        self.mask_weights = torch.randint(0, 2, (category_dim, out_features))*2 - 1
        #self.mask_weights = torch.ones(category_dim, out_features)
        #for i in range(out_features):
        #    if np.random.rand() < 0.5:
        #        self.mask_weights[:, i] = -1
        
    def forward(self, input):
        self.input = input.detach()
        h = torch.matmul(input, self.weight.T) + self.bias
        if self.nonlin:
            mask = ((h.detach() * self.mask_weights).sum(dim=1) >= 0.).float()
            self.mask = mask
            h = h * mask[:, None, :]
        else:
            self.mask = torch.ones(h.shape[0], h.shape[2])
        return h
    
    def update(self, error, eta):
        #i = batch dim
        #j = category dim
        #n = input feature dim
        #m = output feature dim
        dot_prods = torch.einsum("ijn,ij->in", self.input, error.detach())
        delta_weight = torch.einsum("im,in->mn", self.mask, dot_prods) / len(self.input)
        delta_bias = torch.einsum("ij,im->jm", error, self.mask) / len(self.input)
        with torch.no_grad():
            self.weight -= eta*delta_weight
            self.bias -= eta*delta_bias
        self.post_step_callback()
        
    def set_grad(self, error):
        #i = batch dim
        #j = category dim
        #n = input feature dim
        #m = output feature dim
        dot_prods = torch.einsum("ijn,ij->in", self.input, error.detach())
        delta_weight = torch.einsum("im,in->mn", self.mask, dot_prods) / len(self.input)
        delta_bias = torch.einsum("ij,im->jm", error, self.mask) / len(self.input)
        
        if self.weight.grad is None:
            self.weight.grad = delta_weight.detach()
        else:
            self.weight.grad += delta_weight.detach()
        if self.bias.grad is None:
            self.bias.grad = delta_bias.detach()
        else:
            self.bias.grad += delta_bias.detach()
        
    def post_step_callback(self):
        if self.nonneg:
            with torch.no_grad():
                self.weight.clamp_(min=0)
                
def expand_input(input, category_dim):
    batch_dim, input_dim = input.shape
    expanded_input = torch.zeros(batch_dim, category_dim, category_dim*input_dim)
    for i in range(category_dim):
        expanded_input[:, i, i*input_dim:(i+1)*input_dim] = input
    return expanded_input

def permute_input(input, category_dim):
    batch_dim, input_dim = input.shape
    permuted_input = torch.zeros(batch_dim, category_dim, input_dim)
    for i in range(category_dim):
        permuted_input[:, i, :] = input[:, PERMUTATIONS[i]]
    return permuted_input

def eval_test_accuracy(model, input_data_fn=permute_input):
    num_correct = 0
    for batch_idx, (data, labels) in enumerate(test_loader):
        input = input_data_fn(data, 10)
        #with torch.no_grad():
        out = model.forward(input)[..., 0]
        num_correct += (out.argmax(dim=1) == labels).int().sum().item()
    acc = num_correct / 10000.
    return acc



In [114]:
model = nn.Sequential(VectorizedLayer(7840, 900, 10, nonneg=False, nonlin=True, expanded_input=True),
                      VectorizedLayer(900, 500, 10, nonneg=True, nonlin=True),
                      VectorizedLayer(500, 1, 10, nonneg=True, nonlin=False))
model


Sequential(
  (0): VectorizedLayer()
  (1): VectorizedLayer()
  (2): VectorizedLayer()
)

In [115]:
loss_fn = nn.CrossEntropyLoss(reduction="mean")

for epoch_idx in range(1000):
    print(epoch_idx)
    epoch_loss = 0.
    for batch_idx, (data, labels) in enumerate(train_loader):
        input = expand_input(data, 10)
        with torch.no_grad():
            out = model(input)[..., 0]
        epoch_loss += loss_fn(out, labels).item()
        softmax_out = F.softmax(out, dim=1)
        targets = torch.eye(10)[labels]
        error = softmax_out - targets
        for i in range(len(model)):
            model[i].update(error, eta=1e-2)
    print(epoch_loss / (batch_idx + 1))


0
0.26045166566967964
1
0.10555590665588777
2
0.06749829900761445
3
0.04683079943681757
4
0.03346292841869096
5
0.02360150794032961
6


KeyboardInterrupt: 

In [53]:
loss_fn = nn.CrossEntropyLoss(reduction="mean")
optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.6)

for epoch_idx in range(1000):
    print(epoch_idx)
    epoch_loss = 0.
    for batch_idx, (data, labels) in enumerate(train_loader):
        input = permute_input(data, 10)
        with torch.no_grad():
            out = model(input)[..., 0]
        epoch_loss += loss_fn(out, labels).item()
        softmax_out = F.softmax(out, dim=1)
        targets = torch.eye(10)[labels]
        error = softmax_out - targets
        optimizer.zero_grad()
        for i in range(len(model)):
            model[i].set_grad(error)
        optimizer.step()
        for i in range(len(model)):
            model[i].post_step_callback()
    print(epoch_loss / (batch_idx + 1))


0
0.3311158266166846
1
0.17153284413119158
2
0.1339501079544425
3
0.11353618660320838
4


KeyboardInterrupt: 

In [116]:
eval_test_accuracy(model, expand_input)

0.9792

In [56]:
loss_fn = nn.CrossEntropyLoss(reduction="mean")
optimizer = optim.SGD(model.parameters(), lr=1e-1, momentum=0)

for epoch_idx in range(1000):
    epoch_loss = 0.
    for batch_idx, (data, labels) in enumerate(train_loader):
        input = permute_input(data, 10)
        out = model(input)[..., 0]
        loss = loss_fn(out, labels)
        epoch_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        for i in range(len(model)):
            model[i].post_step_callback()
    print(epoch_loss / (batch_idx + 1))

0.27577578199406466
0.1650282567491134
0.13700482554733753
0.11706988518436749


KeyboardInterrupt: 

In [121]:

l = model(input).sum()
l.backward()

In [124]:
model[0].weight.grad

False