In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision

import numpy as np
import pandas as pd

In [2]:
def embedding(data, y, num_classes):
    input = data.clone().detach()
    one_hot = F.one_hot(y, num_classes=num_classes)
    input[:,:,0,:num_classes] = one_hot.view(one_hot.shape[0], 1, one_hot.shape[1])
    return input

class MNISTDataset(Dataset):
    def __init__(self, images, labels, size=(28, 28), num_classes=10, transforms=None, train=True):
        self.X = images
        self.y = labels

        self.w = size[0]
        self.h = size[1]

        if transforms is not None:
            self.transforms = transforms
        else:
            self.transforms = torchvision.transforms.Compose(
                                  [
                                  torchvision.transforms.ToPILImage(),
                                  torchvision.transforms.ToTensor(),
                                  torchvision.transforms.Normalize((0.5, ), (0.5, ))
                              ])
            
        self.train = train
         
    def __len__(self):
        return (len(self.X))

    def __getitem__(self, i):
        data = self.X.iloc[i, :]
        data = np.asarray(data).astype(np.uint8).reshape(self.h, self.w, 1)
        
        if self.transforms:
            data = self.transforms(data)

        return data, self.y[i]

In [3]:
class FFLinear(nn.Module):
    def __init__(self, in_features, out_features,
                 num_epochs = 1000, bias=True):
        super(FFLinear, self).__init__()

        self.linear = nn.Linear(in_features = in_features, 
                                out_features = out_features, 
                                bias = bias)
        
        nn.init.xavier_uniform_(self.linear.weight.data, gain=1.0)
        nn.init.zeros_(self.linear.bias.data)

        self.relu = nn.ReLU()
        self.optimizer = torch.optim.Adam(self.linear.parameters(), lr=0.03)
        self.threshold = 2.0
        self.num_epochs = num_epochs

    def forward(self, x):
        x_norm = x.norm(2, 1, keepdim=True)
        x_dir = x / (x_norm + 1e-4)
        res = self.linear(x_dir)
        return self.relu(res)

    def forward_forward(self, x_pos, x_neg):

        for i in range(self.num_epochs):
            x_pos.requires_grad = True
            x_neg.requires_grad = True

            g_pos = torch.mean(torch.pow(self.forward(x_pos), 2), 1)
            g_neg = torch.mean(torch.pow(self.forward(x_neg), 2), 1)

            loss = self.criterion(g_pos, g_neg)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
        with torch.no_grad():
            return self.forward(x_pos), self.forward(x_neg), loss

    def criterion(self, g_pos, g_neg):
        return torch.mean(
            torch.log(
                1 + torch.exp(torch.cat([-g_pos + self.threshold, g_neg - self.threshold], 0))
            )
        )
            

In [4]:
class FFNetwork(nn.Module):
    def __init__(self, in_features=784, num_classes=10):
        super(FFNetwork, self).__init__()

        self.in_features = in_features
        self.num_classes = num_classes

        self.ff_1 = FFLinear(in_features=self.in_features, out_features=512)
        self.ff_2 = FFLinear(in_features=512, out_features=512)
        self.ff_3 = FFLinear(in_features=512, out_features=512)
        self.ff_4 = FFLinear(in_features=512, out_features=512)

        self.layers = [self.ff_1, self.ff_2, self.ff_3, self.ff_4]

    def train(self, data_pos, data_neg):
        h_pos = data_pos.view(-1, self.in_features)
        h_neg = data_neg.view(-1, self.in_features)

        total_loss = 0
        total_cnt = 0

        for idx, layer in enumerate(self.layers):
            if isinstance(layer, FFLinear):
                print(f"Training layer {idx+1} now")
                h_pos, h_neg, loss = layer.forward_forward(h_pos, h_neg)
                
                total_loss += loss
                total_cnt += 1

            else:
                print(f"Passing layer {idx+1} now")
                x = layer(x)

        print('Loss: ', total_loss / total_cnt)

    def predict(self, data):
        with torch.no_grad():
            goodness_per_label = []
            for cls in range(self.num_classes):
                lbl = torch.tensor([cls] * data.shape[0])
                input = embedding(data, lbl, self.num_classes)

                h = input.view(-1, self.in_features)

                goodness = []
                for layer in self.layers:
                    h = layer(h)
                    goodness += [h.pow(2).mean(1)]
                goodness_per_label += [sum(goodness).unsqueeze(1)]

            goodness_per_label = torch.cat(goodness_per_label, 1)
            return goodness_per_label.argmax(1)

In [5]:
num_classes = 10
batch_size = 10000

df_train = pd.read_csv('./sample_data/mnist_train_small.csv', header=None)
df_test = pd.read_csv('./sample_data/mnist_test.csv', header=None)

train_labels = df_train.iloc[:, 0]
train_images = df_train.iloc[:, 1:]
test_labels = df_test.iloc[:, 0]
test_images = df_test.iloc[:, 1:]

custom_transform = torchvision.transforms.Compose(
                    [
                    torchvision.transforms.ToPILImage(),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize((0.5, ), (0.5, ))
                ])

train_data = MNISTDataset(train_images, train_labels, transforms=custom_transform)
test_data = MNISTDataset(test_images, test_labels, transforms=custom_transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1, shuffle=True)

model = FFNetwork().cuda() if torch.cuda.is_available() else FFNetwork()

In [6]:
print("Start training...")
for data, lbl in train_loader:

    data_pos = data.clone().detach()
    lbl_pos = lbl.clone().detach()
    data_pos = embedding(data_pos, lbl, num_classes=num_classes)
    
    data_neg = data.clone().detach()
    lbl_neg = torch.from_numpy(np.random.choice(num_classes, data.shape[0]))
    data_neg = embedding(data_neg, lbl_neg, num_classes=num_classes)

    if torch.cuda.is_available():
        data_pos, data_neg = data_pos.cuda(), data_neg.cuda()

    model.train(data_pos, data_neg)

Start training...
Training layer 1 now
Training layer 2 now
Training layer 3 now
Training layer 4 now
Loss:  tensor(0.4684, device='cuda:0', grad_fn=<DivBackward0>)
Training layer 1 now
Training layer 2 now
Training layer 3 now
Training layer 4 now
Loss:  tensor(0.3640, device='cuda:0', grad_fn=<DivBackward0>)


In [7]:
print("Start testing...")
predictions = []
groundtruths = []

for i, (data_test, lbl_test) in enumerate(test_loader):
    if torch.cuda.is_available():
        data_test = data_test.cuda()

    prediction = model.predict(data_test).item()
    groundtruth = lbl_test.item()

    predictions.append(prediction)
    groundtruths.append(groundtruth)

from sklearn.metrics import f1_score
print("F1-score: ", f1_score(groundtruths, predictions, average='macro'))

Start testing...
F1-score:  0.9109228917932397
