# Libraries

In [None]:
import gc

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from tqdm import tqdm

# Hyperparameters

In [None]:
# Batch sizes handed to the train / test DataLoaders.
train_batchsize = 50_000
test_batchsize = 10_000

# Width of every layer, input first (784 = one flattened MNIST image).
Net_layers = [784, 500, 500, 500]

# Goodness threshold used in each layer's loss.
threshold = 3.0

# Step size of each layer's Adam optimizer.
learning_rate = 3e-2

# Number of optimisation steps each layer performs on the batch it is given.
epochs = 1_000


# Load dataset

In [None]:
def Data_loader():
    """Create the MNIST training and test loaders.

    Every image is converted to a tensor, z-score normalised with the mean and
    standard deviation of the training pixels, and flattened to 1-D.

    Returns:
        (train_loader, test_loader): the training loader shuffles and uses
        ``train_batchsize``; the test loader keeps the dataset order and uses
        ``test_batchsize``.
    """
    data_root = './data/'

    # Normalisation statistics come from the raw training pixels, divided by
    # 255 so they are on the same scale as the tensors the transform produces.
    raw_train = MNIST(data_root, train=True, download=True, transform=ToTensor())
    pixels = raw_train.data.float()
    train_mean = pixels.mean() / 255.0
    train_std = pixels.std() / 255.0

    # The same preprocessing (training statistics included) is used for both splits.
    preprocess = Compose([
        ToTensor(),
        Normalize(mean=(train_mean,), std=(train_std,)),
        Lambda(lambda img: img.view(-1)),
    ])

    def make_loader(is_train, batch_size):
        # Only the training split is shuffled.
        split = MNIST(data_root, train=is_train, download=True, transform=preprocess)
        return DataLoader(split, batch_size=batch_size, shuffle=is_train)

    train_loader = make_loader(True, train_batchsize)
    test_loader = make_loader(False, test_batchsize)

    return train_loader, test_loader


# Generate data

In [None]:
def y_OneHot_on_x(x, y):
    """Overlay a one-hot encoded label on the first 10 features of each row of ``x``.

    Args:
        x: 2-D tensor indexed as ``(sample, feature)`` with at least 10 features.
        y: label(s) in ``0..9`` -- either a single integer applied to every
            row, or an integer tensor holding one label per row.

    Returns:
        A copy of ``x`` whose first 10 columns are zero, except column ``y`` of
        each row, which holds ``x.max()`` (the maximum over the whole input).
        ``x`` itself is not modified.
    """
    x_ = x.clone()
    # Assign rather than multiply by zero: `inf * 0` and `nan * 0` are nan, so
    # a multiplication would leave non-finite values in the label slots.
    x_[:, :10] = 0.0
    x_[range(x.shape[0]), y] = x.max()

    return x_

In [None]:
def get_y_neg(y):
    """Draw, for every label in ``y``, a random label in ``0..9`` that differs from it.

    Rejection sampling: all positions are drawn once, then only the positions
    that happen to coincide with the true label are redrawn until none remain.

    Args:
        y: integer tensor of true labels.

    Returns:
        Integer tensor with the same shape and device as ``y`` in which no
        entry equals the corresponding entry of ``y``.
    """
    wrong_labels = torch.randint(0, 10, size=y.size(), device=y.device)

    while True:
        # Positions where the draw accidentally hit the true label
        clashes = wrong_labels.eq(y)
        n_clashes = clashes.sum().item()
        if n_clashes == 0:
            return wrong_labels
        # Redraw only the clashing positions
        wrong_labels[clashes] = torch.randint(0, 10, size=(n_clashes,), device=y.device)

# Define NetWork

In [None]:
class NetWork(torch.nn.Module):
    """Stack of ``Layer`` modules that are trained one after another.

    Each layer is trained on its own objective, using the detached output of
    the previous layer as input. Classification tries every label in turn and
    keeps the one with the highest total goodness.
    """

    def __init__(self, dims):
        """Build one ``Layer`` per consecutive pair of sizes in ``dims``.

        Args:
            dims: layer widths, input size first (e.g. ``[784, 500, 500]``).
        """
        super().__init__()
        # Use the GPU when there is one, otherwise stay on the CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # A ModuleList (rather than a plain list) registers the layers as
        # submodules, so parameters(), state_dict() and .to() see them.
        self.layers = nn.ModuleList(
            [Layer(dims[d], dims[d + 1]).to(device) for d in range(len(dims) - 1)]
        )

    def predict(self, x):
        """Return, for each row of ``x``, the label with the highest goodness.

        Args:
            x: 2-D batch of flattened inputs on the same device as the layers.

        Returns:
            1-D integer tensor of predicted labels in ``0..9``.
        """
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            # Goodness of each candidate label
            g_for_labels = []
            for label in range(10):
                h = y_OneHot_on_x(x, label)
                goodness = []
                # Pass the input through each layer and collect its goodness
                for layer in self.layers:
                    h = layer(h)
                    goodness += [h.pow(2).mean(1)]
                # Total goodness of all layers for the current label
                g_for_labels += [sum(goodness)]

            # Label with the maximum goodness
            return torch.stack(g_for_labels, dim=1).argmax(dim=1)

    def train(self, x_pos=True, x_neg=None):
        """Train every layer in order on positive / negative samples.

        This method shadows ``nn.Module.train(mode)``, which ``eval()`` relies
        on. To keep that working, a call without ``x_neg`` -- i.e.
        ``train()``, ``train(False)`` or ``eval()`` -- is forwarded to the
        base implementation.

        Args:
            x_pos: batch with the correct label embedded (or the boolean
                ``mode`` when called as ``nn.Module.train``).
            x_neg: batch with a wrong label embedded.

        Returns:
            ``None`` after training; ``self`` when used as the mode switch.
        """
        if x_neg is None:
            return super().train(x_pos)

        # Initialize positive and negative inputs
        h_pos, h_neg = x_pos, x_neg
        # Train each layer on the output of the previous one
        for i, layer in enumerate(self.layers):
            print('training layer', i+1, '...')
            h_pos, h_neg = layer.train(h_pos, h_neg)


The goodness of a layer is the mean of its squared activations $y_j$ over its $M$ units:
$$
g = \frac{1}{M} \sum_{j=1}^{M} y_j^2
$$


The loss is the softplus penalty averaged over the $N$ positive and $N$ negative samples, where $\theta$ is the threshold:

$$
\text{loss} = \frac{1}{2N} \sum_{i=1}^{N} \left[ \log\left(1 + e^{\theta - g_{\text{pos},i}}\right) + \log\left(1 + e^{g_{\text{neg},i} - \theta}\right) \right]
$$


In [None]:
class Layer(nn.Linear):
    """Linear layer with its own optimizer and a local goodness objective.

    ``forward`` length-normalises each input row, applies the linear map and a
    ReLU. The goodness of a sample is the mean of its squared activations;
    training pushes it above ``self.th`` for positive samples and below
    ``self.th`` for negative ones.
    """

    def __init__(self, in_features, out_features,
                 bias=True, device=None, dtype=None):
        """Create the layer and its optimizer.

        The learning rate, goodness threshold and number of training steps are
        read from the module-level ``learning_rate``, ``threshold`` and
        ``epochs`` settings.
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        # Kept for backward compatibility; forward() applies the ReLU functionally.
        self.relu = torch.nn.ReLU()
        self.opt = Adam(self.parameters(), lr=learning_rate)
        self.th = threshold
        self.num_epochs = epochs

    def forward(self, x):
        """Return ``relu(linear(x / (||x|| + 1e-4)))`` without modifying ``x``.

        Args:
            x: 2-D tensor ``(sample, in_features)``.

        Returns:
            Non-negative activations of shape ``(sample, out_features)``.
        """
        # Out-of-place on purpose: an in-place division would rescale the
        # caller's tensor again on every call.
        x_unit = x / (x.norm(2, 1, keepdim=True) + 1e-4)

        # functional.linear also covers layers built with bias=False
        out = nn.functional.linear(x_unit, self.weight, self.bias)

        return torch.relu(out)

    def train(self, x_pos=True, x_neg=None):
        """Fit this layer on positive / negative samples for ``self.num_epochs`` steps.

        This method shadows ``nn.Module.train(mode)``, which ``eval()`` relies
        on. To keep that working, a call without ``x_neg`` -- i.e.
        ``train()``, ``train(False)`` or ``eval()`` -- is forwarded to the
        base implementation.

        Args:
            x_pos: batch with the correct label embedded (or the boolean
                ``mode`` when called as ``nn.Module.train``).
            x_neg: batch with a wrong label embedded.

        Returns:
            ``(h_pos, h_neg)``: this layer's outputs for both batches after
            training, detached from the autograd graph; ``self`` when used as
            the mode switch.
        """
        if x_neg is None:
            return super().train(x_pos)

        for _ in tqdm(range(self.num_epochs)):
            # Goodness = mean squared activation of each sample
            g_pos = self.forward(x_pos).pow(2).mean(1)
            g_neg = self.forward(x_neg).pow(2).mean(1)

            # softplus(z) == log(1 + exp(z)), but it does not overflow to inf
            # (and then nan gradients) when z is large.
            margins = torch.cat([self.th - g_pos, g_neg - self.th])
            loss = nn.functional.softplus(margins).mean()

            # Standard optimisation step on this layer's parameters only
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

        # Outputs for the next layer; no graph is needed for them
        with torch.no_grad():
            return self.forward(x_pos), self.forward(x_neg)



# Train and Test the Network

In [None]:
# Load data (seed first so the run is repeatable)
torch.manual_seed(1234)
train_loader, test_loader = Data_loader()

# Create network
net = NetWork(Net_layers)
# Put the data wherever the network's parameters live
device = next(net.layers[0].parameters()).device

# Only the first training batch is used; it is trained on in full
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)
x_pos = y_OneHot_on_x(x, y)              # correct label embedded
x_neg = y_OneHot_on_x(x, get_y_neg(y))   # wrong label embedded

# Train the network
net.train(x_pos, x_neg)

# Calculate and print the training accuracy
y_pred_train = net.predict(x).cpu().numpy()
train_accuracy = accuracy_score(y.cpu().numpy(), y_pred_train) * 100
print(f'Train accuracy: {train_accuracy:.2f}')

# Get the test data
x_te, y_te = next(iter(test_loader))
x_te, y_te = x_te.to(device), y_te.to(device)

# Calculate and print the test accuracy
y_pred_test = net.predict(x_te).cpu().numpy()
test_accuracy = accuracy_score(y_te.cpu().numpy(), y_pred_test) * 100
print(f'Test accuracy: {test_accuracy:.2f}')

# Release memory: the references must be dropped *before* collecting and
# emptying the CUDA cache, otherwise the tensors are still alive.
del x, y, x_pos, x_neg, net, x_te, y_te
gc.collect()
torch.cuda.empty_cache()

training layer 1 ...


100%|██████████| 1000/1000 [01:00<00:00, 16.59it/s]


training layer 2 ...


100%|██████████| 1000/1000 [00:40<00:00, 24.84it/s]


Train accuracy: 92.43
Test accuracy: 92.66
