In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import matplotlib.pyplot as plt
from torch.autograd import Variable
import torchvision.models as models
import random as random

In [2]:
n_labels = 10
cuda = torch.cuda.is_available()
device = "cuda:0" if cuda else "cpu"

def onehot(k):
    """
    Converts a number to its one-hot or 1-of-k representation
    vector.
    :param k: (int) length of vector
    :return: onehot function
    """
    def encode(label):
        y = torch.zeros(k)
        if label < k:
            y[label] = 1
        return y
    return encode

def get_mnist(location="./", batch_size=128):
    from functools import reduce
    from operator import __or__
    from torch.utils.data.sampler import SubsetRandomSampler
    from torchvision.datasets import MNIST
    import torchvision.transforms as transforms
    
    compose = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((.5, .5, .5), (.5, .5, .5))
        ])
    
    img_transform = lambda x: compose(x).view(-1)
    

    mnist_train = MNIST(location, train=True, download=True,transform=img_transform)
    mnist_valid = MNIST(location, train=False, download=True,transform=img_transform)

    # Dataloaders for MNIST
    labelled = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size)
    
    validation = torch.utils.data.DataLoader(mnist_valid, batch_size=batch_size)

    return labelled, validation

In [3]:
labelled, valid = get_mnist()

In [4]:
class Classifier(nn.Module):
    def __init__(self, in_dim=784):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_features=in_dim, out_features=256),
                                 nn.ReLU(),
#                                  nn.Dropout(0.3),
                                 nn.Linear(in_features=256, out_features=256),
                                 nn.ReLU(),
#                                  nn.Dropout(0.3),
                                 nn.Linear(in_features=256, out_features=256),
                                 nn.ReLU(),
#                                  nn.Dropout(0.3),
                                 nn.Linear(in_features=256, out_features=10),
                                 nn.LogSoftmax(dim=-1))
        
    def forward(self,x):        
        return self.net(x)

In [5]:
classifier = Classifier().to(device)

In [6]:
optimizerC = torch.optim.Adam(classifier.parameters(), lr=1e-5)

In [7]:
BCE = nn.BCELoss()
NLL = nn.NLLLoss()

MSE = nn.MSELoss()
NUM_EPOCH = 100

def train():
    loss_c = []
    for epoch in range(NUM_EPOCH):
        for idx, batch_data in enumerate(labelled):
            batch_size = batch_data[0].shape[0]            
            
            # retrieve real data from loader
            real_images =batch_data[0].cuda()
            classes_real = batch_data[1].cuda()
            
            # Train classifier on fake and real data
            prediction_real = classifier(real_images)
            
            
            
#             lossC = MSE(prediction_real, classes_real)
            lossC = NLL(prediction_real, classes_real)
            
            lossC.backward()
            
            optimizerC.step()           
            loss_c.append(lossC.item())

        print("Epoch: ", epoch, " Closs: ", np.mean(loss_c[-20:]))
        print("Accuracy:", test())
            

In [8]:
def test():
    r =0
    length = 0 
    for idx, batch_data in enumerate(valid):
        batch_size = batch_data[0].shape[0]
        real_images = batch_data[0].cuda()
        classes_real = batch_data[1].cuda()
#         print(classes_real.argmax(dim=-1))

        predictions_real = classifier(real_images).argmax(dim=-1)
#         print(predictions_real)
#         r += torch.sum(abs(predictions_real-classes_real)).item()
        r += predictions_real.eq(classes_real).sum().item()
        length += batch_size
    return r/length

In [9]:
train()

Epoch:  0  Closs:  1.127143558859825
Accuracy: 0.6754
Epoch:  1  Closs:  1.2023566335439682
Accuracy: 0.5879
Epoch:  2  Closs:  0.7634003609418869
Accuracy: 0.7385
Epoch:  3  Closs:  0.4577137365937233
Accuracy: 0.831
Epoch:  4  Closs:  0.482887626439333
Accuracy: 0.8238
Epoch:  5  Closs:  0.48293308094143866
Accuracy: 0.8006
Epoch:  6  Closs:  0.4049548789858818
Accuracy: 0.8297
Epoch:  7  Closs:  0.38354877270758153
Accuracy: 0.8445
Epoch:  8  Closs:  0.3232653960585594
Accuracy: 0.8564
Epoch:  9  Closs:  0.29146419167518617
Accuracy: 0.8825
Epoch:  10  Closs:  0.2869791928678751
Accuracy: 0.8821
Epoch:  11  Closs:  0.3081853505223989
Accuracy: 0.8779
Epoch:  12  Closs:  0.28983214870095253
Accuracy: 0.8848
Epoch:  13  Closs:  0.2339403409510851
Accuracy: 0.9062
Epoch:  14  Closs:  0.23218420408666135
Accuracy: 0.9098
Epoch:  15  Closs:  0.19537619296461345
Accuracy: 0.9183
Epoch:  16  Closs:  0.20013975352048874
Accuracy: 0.9211
Epoch:  17  Closs:  0.20421941792592407
Accuracy: 0.91