In [6]:
from fastai.vision.all import *
from torchvision import transforms as vis_tfms
import random
from torchvision.datasets import MNIST

In [7]:
# Prefer the GPU whenever PyTorch reports one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(device)

cuda


In [30]:
# Forwarded as the `pin_memory` option of the DataLoaders when running on CUDA.
pin_memory = False

In [31]:
# Per-image preprocessing: convert to a tensor, normalise with mean 0.1307 and
# std 0.3081, then flatten the image into a vector.
orig_tf = vis_tfms.Compose([
    vis_tfms.ToTensor(),
    vis_tfms.Normalize((0.1307), (0.3081)),
    vis_tfms.Lambda(torch.flatten),
])

# Very large batch sizes: each loader hands out its data in a few huge batches.
orig_train_kwargs = dict(batch_size=50000)
orig_test_kwargs = dict(batch_size=10000)
if device == 'cuda':
    # Extra loader options that are only applied on the GPU path
    # (note that this includes shuffling, for both loaders).
    orig_cuda_kwargs = dict(num_workers=1, shuffle=True, pin_memory=pin_memory)
    orig_train_kwargs = {**orig_train_kwargs, **orig_cuda_kwargs}
    orig_test_kwargs = {**orig_test_kwargs, **orig_cuda_kwargs}

orig_train_loader = torch.utils.data.DataLoader(
    dataset=MNIST(root='./data/', train=True, download=True, transform=orig_tf),
    **orig_train_kwargs,
)

orig_test_loader = torch.utils.data.DataLoader(  # may be unnecessary
    dataset=MNIST(root='./data/', train=False, download=True, transform=orig_tf),
    **orig_test_kwargs,
)

In [32]:
def label(x, y, n_labels=10):
    """Return a copy of `x` with a label marker written into its first columns.

    In the copy, the first `n_labels` columns of every row are zeroed and then,
    for row i, column `y[i]` is set to the largest value found anywhere in `x`.
    `x` itself is left untouched.

    Args:
        x: 2-D tensor, one sample per row.
        y: per-row column indices (the labels) used to place the marker.
        n_labels: how many leading columns are cleared before marking.

    Returns:
        The labelled copy of `x`.
    """
    peak = x.max()  # taken from the unmodified input, label columns included
    rows = range(len(x))

    labelled = x.clone()
    labelled[:, :n_labels].mul_(0.0)  # in-place on a view of the copy
    labelled[rows, y] = peak
    return labelled

In [33]:
def get_neg(x, y, n_labels=10, ratio=1): # constructed myself
    """Build negative samples by writing wrong labels into copies of `x`.

    For every sample, `ratio` distinct labels different from its true label are
    drawn with `random.sample`.  Copy number k of `x` is then labelled (via
    `label`) with every sample's k-th wrong label, and the `ratio` copies are
    concatenated along dim 0.

    The labels are moved to the host once and the wrong-label tensor is built in
    a single call, instead of comparing and writing tensor elements one by one.
    The sequence of `random.sample` calls is unchanged, so a seeded `random`
    yields the same negatives as the element-wise version.

    Args:
        x: 2-D tensor, one sample per row.
        y: 1-D tensor of true labels, one per row of `x`
            (assumed to be class ids in ``range(n_labels)``).
        n_labels: number of classes.
        ratio: number of negatives generated per sample
            (between 1 and ``n_labels - 1``).

    Returns:
        Tensor with ``ratio * len(x)`` rows: the labelled copies stacked on dim 0.

    Raises:
        ValueError: if `ratio` is outside ``1 .. n_labels - 1`` or a value in
            `y` is not one of ``range(n_labels)``.
    """
    if not 1 <= ratio <= n_labels - 1:
        raise ValueError(
            f"ratio must be between 1 and n_labels - 1 ({n_labels - 1}), got {ratio}"
        )

    candidates = list(range(n_labels))
    true_labels = y.tolist()  # one device-to-host copy instead of a sync per element

    wrong = []  # wrong[i] holds the `ratio` wrong labels drawn for sample i
    for true_label in true_labels:
        pool = candidates.copy()
        pool.remove(true_label)  # ValueError if the label is not a known class
        wrong.append(random.sample(pool, k=ratio))

    # Shape (len(y), ratio): column k is the k-th wrong label of every sample.
    y_neg = torch.tensor(wrong, dtype=y.dtype, device=y.device).reshape(len(true_labels), ratio)

    return torch.cat([label(x, y_neg[:, k], n_labels) for k in range(ratio)])

In [34]:
# Assumes the training loader yields exactly two batches (unpacking fails
# otherwise); with batch_size 50000 these are presumably 50,000 + 10,000 MNIST
# samples. The first batch is used for training, the second for validation.
a, b = orig_train_loader

# Move inputs and labels of both batches onto the selected device.
x, y = (t.to(device) for t in a)
x_val, y_val = (t.to(device) for t in b)

In [5]:
class Layer(nn.Linear):
    """Linear + ReLU layer that is trained on its own, with its own optimizer.

    The layer is optimised locally: the "goodness" of a sample (mean of its
    squared activations) is pushed above `threshold` for positive samples and
    below it for negative samples.  This appears to follow the Forward-Forward
    training scheme.

    Args:
        f_in, f_out, bias, device, dtype: forwarded to `nn.Linear`.
        threshold: goodness value separating positive from negative samples.
        num_epochs: number of optimisation steps run by `train(x_pos, x_neg)`.
        bs: batch size; only stored, not used by the methods of this class.
        n_labels: number of classes, passed to `label` / `get_neg`.
        log_interval: print the loss every `log_interval` steps.  ``None``
            (default) uses a module-level `log_interval` variable if the
            notebook defines one, else `DEFAULT_LOG_INTERVAL`; ``0`` disables
            logging.
    """

    # Used when neither the constructor nor the notebook provides a log interval.
    DEFAULT_LOG_INTERVAL = 10

    def __init__(self, f_in, f_out, bias=True, device=None, dtype=None,
                threshold=2., num_epochs=60, bs=256, n_labels=10, log_interval=None):
        super().__init__(f_in, f_out, bias, device, dtype)
        self.relu = nn.ReLU()
        self.opt = Adam(self.parameters(), lr=0.03)
        self.threshold = threshold
        self.num_epochs = num_epochs
        self.bs = bs
        self.n_labels = n_labels
        self.log_interval = log_interval

    def forward(self, x):
        """Return ReLU(x_dir @ W.T + b), where x_dir is `x` with rows scaled to ~unit L2 norm.

        Only the direction of each input row is used; presumably this keeps the
        magnitude (goodness) of a previous layer's output from leaking into this
        layer.  The 1e-4 avoids division by zero for all-zero rows.
        """
        x_direction = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        # F.linear handles bias=None, so the layer also works with bias=False.
        return self.relu(torch.nn.functional.linear(x_direction, self.weight, self.bias))

    def _goodness(self, x):
        """Mean squared activation of each sample in `x`."""
        return self.forward(x).pow(2).mean(1)

    def _loss(self, x_pos, x_neg):
        """Mean softplus of the 'wrongness' of all positive and negative samples."""
        wrongness = torch.cat([
            self.threshold - self._goodness(x_pos),  # positives should exceed the threshold
            self._goodness(x_neg) - self.threshold,  # negatives should stay below it
        ])
        # softplus(z) == log(1 + exp(z)) but does not overflow for large z.
        return torch.nn.functional.softplus(wrongness).mean()

    def _step(self, x_pos, x_neg):
        """Run one optimizer step on this layer's parameters and return the loss."""
        loss = self._loss(x_pos, x_neg)
        self.opt.zero_grad()
        loss.backward()  # gradients only reach this layer's own weight and bias
        self.opt.step()
        return loss

    def _resolve_log_interval(self):
        """Pick the logging interval: constructor value, notebook global, or default."""
        if self.log_interval is not None:
            return self.log_interval
        return globals().get('log_interval', self.DEFAULT_LOG_INTERVAL)

    def train_one_batch(self, x, y):
        """Perform a single optimisation step on one labelled batch.

        Positive samples are `x` with the true labels embedded (`label`),
        negative samples are `x` with wrong labels embedded (`get_neg`).

        Returns:
            The detached loss tensor of this step.
        """
        x_pos = label(x, y, self.n_labels)
        x_neg = get_neg(x, y, self.n_labels)
        return self._step(x_pos, x_neg).detach()

    def train(self, x_pos=True, x_neg=None):
        """Fit the layer on positive/negative samples, or switch training mode.

        This method shadows `nn.Module.train(mode)`, which `eval()` and parent
        modules call with a single boolean.  To keep those calls working,
        `train()` / `train(True)` / `train(False)` are forwarded to
        `nn.Module.train`; `train(x_pos, x_neg)` runs the local training loop.

        Args:
            x_pos: positive samples (or the boolean training mode).
            x_neg: negative samples.

        Returns:
            `(h_pos, h_neg)`: this layer's activations for both inputs after
            training, without gradient history, ready to feed the next layer.
            In mode-switching form, returns `self`.

        Raises:
            TypeError: if `x_neg` is missing and `x_pos` is not a boolean.
        """
        if x_neg is None:
            if isinstance(x_pos, bool):
                return super().train(x_pos)
            raise TypeError("train() expects either a boolean mode or both x_pos and x_neg")

        interval = self._resolve_log_interval()
        for i in range(self.num_epochs):
            loss = self._step(x_pos, x_neg)
            if interval and i % interval == 0:
                print('Loss: ', loss.item())

        self.opt.zero_grad() # to save memory
        # No autograd graph is needed for the values handed to the next layer.
        with torch.no_grad():
            return self.forward(x_pos), self.forward(x_neg)