In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
n_epochs = 20
batch_size_train = 64
batch_size_test = 1000

train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST(root='./data', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST(root='./data', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [3]:
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_data.shape)

torch.Size([1000, 1, 28, 28])


In [None]:
import matplotlib.pyplot as plt

fig = plt.figure()
for i in range(3):
    plt.subplot(1,3,i+1)
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title("Ground Truth: {}".format(example_targets[i]))
fig

In [4]:
#network flowchart:
#input vector --- random initialized weights ---> hyperdimensional vecotor ---> one_hot_net l1 ---> ...ln... ---> sigmoid/softmax output
#One hot net

class One_hot_op(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, W, epsilon):
        Z = torch.matmul(W, A)
        ctx.Z = Z
        ctx.A = A
        ctx.W = W
        ret = Z > epsilon
        #print(ret[1:10][1:10])
        return ret.float()

    @staticmethod
    def backward(ctx, dL_dA):
        step = ctx.Z > 0
        drZ_dW = torch.matmul(step,torch.transpose(ctx.A, 0,1))
        drZ_dw = drZ_dw / torch.abs(drZ_dw)
        drZ_dA = torch.matmul(torch.transpose(ctx.W), step)
        drZ_dA = drZ_dA / torch.abs(drZ_dA)
        dA = torch.matmul(dL_dA, drZ_dA)
        dW = torch.matmul(dL_dA, drZ_dW)
        return dA, dW, None


class One_hot_layer(nn.Module):
    def __init__(self, in_dim, out_dim, initialization_f, epsilon):
        super(One_hot_layer, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.epsilon = epsilon
        self.initialization_f = initialization_f
        initialized_weight = initialization_f(out_dim, in_dim)
        self.weight = nn.Parameter(initialized_weight, requires_grad = True)
        self.op = One_hot_op
        
    def __str__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_dim) + ',' \
               + str(self.out_dim) + ')'
        
    def forward(self, A):
        return self.op.apply(A, self.weight, self.epsilon)
    

class One_hot_net(nn.Module):
    def __init__(self, in_dim, n_class, f_encoder, encoder_multiplier, f_initializer, epsilon, n_layers=2, layer_size_factor=[1, 5], dropout=[-1, 0.5]):
        super(One_hot_net, self).__init__()
        self.layers = nn.ModuleList()
        self.in_dim = in_dim
        feature_len = in_dim * encoder_multiplier
        self.feature_len = feature_len
        self.n_layers=n_layers
        self.layer_size_factor=layer_size_factor
        self.dropout=dropout
        self.n_class = n_class
        self.f_encoder = f_encoder
        self.f_initializer = f_initializer
        for i in range(n_layers):
            if dropout[i] > 0:
                self.layers.append(nn.Dropout(dropout[i]))
            if i < n_layers - 1:
                self.layers.append(
                    One_hot_layer(int(feature_len // layer_size_factor[i]), int(feature_len // layer_size_factor[i + 1]), f_initializer, epsilon))
        self.tail = nn.Linear(int(feature_len // layer_size_factor[-1]), n_class)
        self.out = nn.LogSoftmax(dim = n_class)
    
    def flatten(self, X):
        return X.view(X.shape[0], X.shape[1]*X.shape[2]*X.shape[3])

    def forward(self, X):
        X = self.flatten(X)
        X = torch.transpose(X, 0, 1)
        X = self.f_encoder.apply(X)
        for layer in self.layers:
            X = layer(X)
        X = torch.transpose(X, 0, 1)
        X = self.tail(X)
        return self.out(X)

In [5]:
#initializers and encoders
def uniform_initializer(out_dim, in_dim):
    tensor = torch.empty(out_dim, in_dim)
    return torch.nn.init.uniform_(tensor, a=-2, b=2).cuda() 

class simple_encoder():
    def __init__(self, out_dim, in_dim):
        self.W = uniform_initializer(out_dim, in_dim)
        
    def apply(self, X):
        #print(X.shape)
        #print(self.W.shape)
        return torch.matmul(self.W, X)

In [6]:
parameters = {
    'in_dim': 784,
    'n_class': 10,
    'f_encoder': simple_encoder(784*20, 784),
    'f_initializer': uniform_initializer,
    'encoder_multiplier': 20,
    'epsilon': 10e-3,
    'n_layers': 2,
    'layer_size_factor': [1, 1],
    'dropout': [-1, -1]
}

In [7]:
device = torch.device("cuda:0")
model1 = One_hot_net(parameters['in_dim'], parameters['n_class'], parameters['f_encoder'], parameters['encoder_multiplier'], 
                     parameters['f_initializer'], parameters['epsilon'], parameters['n_layers'], 
                     parameters['layer_size_factor'], parameters['dropout']).to(device)

optimizer1 = torch.optim.SGD(model1.parameters(), lr=0.001, momentum=0.5)

In [8]:
from torchsummary import summary
summary(model1, (1, 28, 28))

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 10)

In [None]:
def train(epoch, model, optimizer, trainloader, log_interval = 10):
    model.train()
    train_losses = []
    train_counter = []
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append(
            (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))
              torch.save(model.state_dict(), '/results/model.pth')
            torch.save(optimizer.state_dict(), '/results/optimizer.pth')
        return model, optimizer, train_losses, train_counter
