In [9]:
import torch
import torch.nn as nn
from torch import FloatTensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader

# torchvision: popular datasets, model architectures, 
# and common image transformations for computer vision.

from torchvision import datasets
from torchvision.transforms import transforms

from random import randint
from random import shuffle
import numpy as np
import matplotlib.pyplot as plt

label_1, label_2 = 4, 9

# MNIST training data
train_set = datasets.MNIST(root='./mnist_data/', 
                           train=True, 
                           transform=transforms.ToTensor(), download=True)

# Use data with two labels
idx = (train_set.targets == label_1) + (train_set.targets == label_2)
train_set.data = train_set.data[idx]
train_set.targets = train_set.targets[idx]
train_set.targets[train_set.targets == label_1] = -1
train_set.targets[train_set.targets == label_2] = 1

# MNIST testing data
test_set = datasets.MNIST(root='./mnist_data/', 
                          train=False, transform=transforms.ToTensor())

# Use data with two labels
idx = (test_set.targets == label_1) + (test_set.targets == label_2)
test_set.data = test_set.data[idx]
test_set.targets = test_set.targets[idx]
test_set.targets[test_set.targets == label_1] = -1
test_set.targets[test_set.targets == label_2] = 1


In [10]:
class LR(nn.Module) :
    # MNIST data is 28x28 images
    def __init__(self, input_dim=28*28) :
        super().__init__()
        self.linear = nn.Linear(input_dim, 1, bias=False)

    def forward(self, x) :
        return self.linear(x.float().view(-1, 28*28))


model_with_SOQ_loss = LR()
model_with_KLDiv_loss = LR()

def sum_of_squares(output, target):
    output = FloatTensor(output)
    target = target.unsqueeze(-1)
    result = torch.mean(0.5*(1-target)*( 
        (1-torch.sigmoid(-output))**2 + (torch.sigmoid(output))**2 )
                + 0.5*(1+target)*( (1-torch.sigmoid(output))**2 
                                  + (torch.sigmoid(-output))**2 ))

    return result

def logistic_loss(output, target):

    return torch.mean(-torch.nn.functional.logsigmoid(
        target.reshape(-1)*output.reshape(-1)))


loss_function_1 = sum_of_squares
loss_function_2 = logistic_loss

optimizer_1 = torch.optim.SGD(model_with_SOQ_loss.parameters(), 
                              lr=255*1e-4)

optimizer_2 = torch.optim.SGD(model_with_KLDiv_loss.parameters(), 
                              lr=255*1e-4)

batch_size = 64


In [11]:
train_loader = DataLoader(dataset=train_set, batch_size= batch_size , shuffle=True)
import time

start = time.time()
iter_count = 0

for epoch in range(1000):
    for image,label in train_loader :
        iter_count += 1
        if iter_count > 1000:
            break

        # Clear previously computed gradient
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        # then compute gradient with forward and backward passes
        train_loss = loss_function_1(
            model_with_SOQ_loss(image), label.float())
        
        train_loss_2 = loss_function_2(
            model_with_KLDiv_loss(image), label.float())
        
        train_loss.backward()
        train_loss_2.backward()

        # perform SGD step (parameter update)
        optimizer_1.step()
        optimizer_2.step()

end = time.time()
print(f"Time ellapsed in training is: {end-start}")



Time ellapsed in training is: 6.977036237716675


In [12]:
test_loss, correct = 0, 0
test_loss_2, correct_2 = 0, 0
misclassified_ind = []
correct_ind = []
misclassified_ind_2 = []
correct_ind_2 = []

# Test data
test_loader = DataLoader(dataset=test_set, 
                         batch_size=1, shuffle=False)
# no need to shuffle test data

# Evaluate accuracy using test data
for ind, (image, label) in enumerate(test_loader) :

    # Forward pass
    output = model_with_SOQ_loss(image)
    output2 = model_with_KLDiv_loss(image)

    # Calculate cumulative loss
    test_loss += loss_function_1(output, 
                                 label.float()).item()
    
    test_loss_2 += loss_function_2(output2, 
                                   label.float()).item()

    # Make a prediction
    if output.item() * label.item() >= 0 :
        correct += 1
        correct_ind += [ind]
    else:
        misclassified_ind += [ind]

    # Make a prediction
    if output2.item() * label.item() >= 0 :
        correct_2 += 1
        correct_ind_2 += [ind]
    else:
        misclassified_ind_2 += [ind]

# Print out the results
print('---- Using Sum of Squares as loss_func ----')
print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss /len(test_loader), correct, len(test_loader),
        100. * correct / len(test_loader)))
print('---- Using KL Divergence as loss_func ----')
print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss_2 /len(test_loader), correct_2, len(test_loader),
        100. * correct_2 / len(test_loader)))

---- Using Sum of Squares as loss_func ----
[Test set] Average loss: 0.0840, Accuracy: 1909/1991 (95.88%)

---- Using KL Divergence as loss_func ----
[Test set] Average loss: 0.1457, Accuracy: 1905/1991 (95.68%)

