In [1]:
import torch
from torch.nn import Parameter
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.optim import SGD

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
from tqdm import tqdm_notebook

In [2]:
# Preprocessing applied to every image: convert it to a tensor, then
# normalise the single channel with mean 0.1307 and std 0.3081.
transforms_func = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split of MNIST, stored under ./data (fetched on first use).
mnist_train = MNIST('./data', train=True, download=True, transform=transforms_func)

# Test split of MNIST, read from the same ./data directory.
mnist_test = MNIST('./data', train=False, transform=transforms_func)

In [3]:
# Keep 90% of the training images for fitting and hold out the rest for validation.
# NOTE: this rebinds `mnist_train` to the training subset, so running the cell
# a second time would split the already-reduced subset again.
train_len = int(0.9 * len(mnist_train))
valid_len = len(mnist_train) - train_len
mnist_train, mnist_valid = torch.utils.data.random_split(
    mnist_train,
    lengths=[train_len, valid_len],
)

In [4]:
# Report how many samples ended up in each split (one line per split).
print(
    "Size of:",
    "- Training-set:\t\t{}".format(len(mnist_train)),
    "- Validation-set:\t{}".format(len(mnist_valid)),
    "- Test-set:\t\t{}".format(len(mnist_test)),
    sep="\n",
)

Size of:
- Training-set:		54000
- Validation-set:	6000
- Test-set:		10000


In [None]:
# (height, width) used by plot_images to reshape each image before display.
img_shape = (28, 28)

In [None]:
def plot_images(images, cls_true, cls_pred=None):
    """Draw exactly nine images in a 3x3 grid, labelled with their classes.

    Args:
        images: Indexable collection of nine images; each entry must support
            ``.reshape(img_shape)`` (``img_shape`` is a notebook-level global).
        cls_true: The nine true class labels, indexable in the same order.
        cls_pred: Optional predicted labels. When given, each caption shows
            both the true and the predicted class.

    Raises:
        AssertionError: If ``images`` or ``cls_true`` does not hold nine items.
    """
    assert len(images) == len(cls_true) == 9

    # 3x3 grid with a little padding so the captions do not overlap.
    fig, axes = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    for idx, ax in enumerate(axes.flat):
        ax.imshow(images[idx].reshape(img_shape), cmap='binary')

        # Caption: true class, plus the prediction when one was supplied.
        if cls_pred is not None:
            caption = "True: {0}, Pred: {1}".format(cls_true[idx], cls_pred[idx])
        else:
            caption = "True: {0}".format(cls_true[idx])
        ax.set_xlabel(caption)

        # Tick marks carry no meaning for an image.
        ax.set_xticks([])
        ax.set_yticks([])

    # Explicit show so several figures can be produced from one notebook cell.
    plt.show()

In [None]:
# `mnist_train` is the Subset produced by random_split, which has no
# `train_data` / `train_labels` attributes. Go through the wrapped dataset,
# selecting the rows that this subset's first nine indices point at.
first_idx = torch.as_tensor(mnist_train.indices[:9])

# Stored pixel tensors of the first nine training images (`.data` holds the
# images as stored; the transform is only applied when items are fetched).
images = mnist_train.dataset.data[first_idx]

# Get the true classes for those images.
cls_true = mnist_train.dataset.targets[first_idx]

# Plot the images and labels using our helper-function above.
plot_images(images=images, cls_true=cls_true)

In [2]:
class LinearModel(nn.Module):
    """Linear classifier: one fully connected layer (784 -> 10) plus log-softmax."""

    def __init__(self):
        super(LinearModel, self).__init__()
        self.fc = nn.Linear(784,10)
        self.init_weights_and_bias()

    def init_weights_and_bias(self):
        """Re-initialise the layer in place: Xavier-normal weights, zero bias."""
        torch.nn.init.xavier_normal_(self.fc.weight)
        torch.nn.init.zeros_(self.fc.bias)

    def get_weights(self):
        """Return the layer weights as a detached CPU tensor of shape (784, 10).

        ``nn.Linear`` stores its weight as (out_features, in_features), i.e.
        (10, 784). The result is transposed so that column ``i`` holds the
        weights of class ``i``, and detached so callers can call ``.numpy()``
        on it directly.
        """
        return self.fc.weight.detach().cpu().t()

    def forward(self,x):
        """Map a (batch, 784) input to (batch, 10) log-probabilities."""
        out = self.fc(x)
        out = F.log_softmax(out,1)
        return out

In [None]:
def train(model,device,train_loader,valid_loader,optimizer,epoch):
    """Run one optimisation pass over ``train_loader``, then evaluate on ``valid_loader``.

    Args:
        model: Module returning log-probabilities for a (batch, 784) input.
        device: Device that every batch is moved to before the forward pass.
        train_loader: Iterable of ``(data, target)`` batches; must support ``len``.
        valid_loader: Loader handed to ``test`` once the pass is finished.
        optimizer: Optimiser that steps the model's parameters.
        epoch: Accepted for the caller's convenience; not used here.
    """
    model.train()
    n_batches = len(train_loader)
    for _, (batch, labels) in tqdm_notebook(enumerate(train_loader), total=n_batches):
        # Flatten every image to a 784-vector, then move the batch to `device`.
        batch = batch.reshape(-1, 784).to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        log_probs = model(batch)
        batch_loss = F.nll_loss(log_probs, labels)
        batch_loss.backward()
        optimizer.step()

    # Report loss/accuracy on the validation loader after the pass.
    test(model, device, valid_loader)

In [None]:
def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    Args:
        model: Module returning log-probabilities for a (batch, 784) input.
        device: Device that every batch is moved to before the forward pass.
        test_loader: Loader of ``(data, target)`` batches. Its ``dataset``
            length is used as the denominator for both reported figures.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    with torch.no_grad():
        for _, (data, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):
            data = data.to(device)
            target = target.to(device)
            output = model(data.reshape(-1, 784))
            # Accumulate the summed (not averaged) loss of this batch.
            loss_sum += F.nll_loss(output, target, reduction='sum').item()
            # Position of the largest log-probability is the predicted class.
            _, pred = output.max(1, keepdim=True)
            n_correct += (pred == target.view_as(pred)).sum().item()

    n_samples = len(test_loader.dataset)
    avg_loss = loss_sum / n_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avg_loss, n_correct, n_samples,
        100. * n_correct / n_samples))

In [None]:
# Mini-batches of 64 training samples, reshuffled every epoch.
train_loader = DataLoader(
    mnist_train,
    batch_size=64,
    shuffle=True,
)

In [None]:
# Validation batches of 1024 samples (shuffling is enabled here as well).
valid_loader = DataLoader(
    mnist_valid,
    batch_size=1024,
    shuffle=True,
)

In [None]:
# Keep the test set in dataset order. The error-plotting code builds a mask
# from this loader's predictions and applies it to `mnist_test` directly,
# which is only valid when batches arrive unshuffled. The summed loss and
# accuracy reported by test() do not depend on the order.
test_loader = DataLoader(mnist_test, batch_size=1024, shuffle=False)

In [None]:
# Build the classifier and keep its parameters on the CPU.
model = LinearModel()
model = model.to("cpu")

In [None]:
# Plain SGD with step size 0.5, run for a single pass over the training data.
epochs = 1
optimizer = SGD(model.parameters(), lr=0.5)
for epoch in range(epochs):
    # train() already reports on the validation loader; follow it with the test loader.
    train(model, 'cpu', train_loader, valid_loader, optimizer, epoch)
    test(model, 'cpu', test_loader)

In [None]:
def plot_weights():
    """Show the weights of the global ``model`` as one 28x28 image per class.

    Reads ``model.fc.weight`` and draws the ten per-class weight vectors in a
    3x4 grid (the last two cells stay empty), all on a shared colour scale.
    """
    # nn.Linear stores its weight as (out_features, in_features) = (10, 784).
    # detach() is needed before numpy() because the parameter tracks
    # gradients; the transpose makes column i hold the weights of class i.
    w = model.fc.weight.detach().cpu().numpy().T
    img_shape = (28,28)

    # Get the lowest and highest values for the weights.
    # This is used to correct the colour intensity across
    # the images so they can be compared with each other.
    w_min = np.min(w)
    w_max = np.max(w)

    # Create figure with 3x4 sub-plots,
    # where the last 2 sub-plots are unused.
    fig, axes = plt.subplots(3, 4)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    for i, ax in enumerate(axes.flat):
        # Only use the weights for the first 10 sub-plots.
        if i<10:
            # Get the weights for the i'th digit and reshape it.
            # Note that w.shape == (784, 10) after the transpose above.
            image = w[:, i].reshape(img_shape)

            # Set the label for the sub-plot.
            ax.set_xlabel("Weights: {0}".format(i))

            # Plot the image.
            ax.imshow(image, vmin=w_min, vmax=w_max, cmap='seismic')

        # Remove ticks from each sub-plot.
        ax.set_xticks([])
        ax.set_yticks([])

    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()

In [None]:
# Draw the per-class weight images for the current global `model`.
plot_weights()

In [None]:
def plot_example_errors():
    """Plot the first nine test images that the global ``model`` misclassifies.

    Predictions are computed over ``mnist_test`` in dataset order and compared
    with the stored labels, so the mask lines up with ``mnist_test.data``.
    Relies on the notebook globals ``model``, ``mnist_test`` and
    ``test_loader`` (the latter only for its batch size).
    """
    # Iterate in dataset order regardless of how `test_loader` was configured;
    # otherwise the predictions could not be matched to rows of `mnist_test`.
    ordered_loader = DataLoader(mnist_test,
                                batch_size=test_loader.batch_size,
                                shuffle=False)

    # Only the predictions are used; the mask is derived here from the labels
    # so it is unambiguous that True means "misclassified".
    cls_pred, _ = get_incorrect_samples(model, 'cpu', ordered_loader)
    cls_pred = np.asarray(cls_pred, dtype=np.int64)

    cls_true = mnist_test.targets.numpy()
    incorrect = cls_pred != cls_true

    # Get the images from the test-set that have been
    # incorrectly classified.
    images = mnist_test.data.numpy()[incorrect]

    # Get the predicted classes for those images.
    cls_pred = cls_pred[incorrect]

    # Get the true classes for those images.
    cls_true = cls_true[incorrect]

    # Plot the first 9 images.
    plot_images(images=images[0:9],
                cls_true=cls_true[0:9],
                cls_pred=cls_pred[0:9])

In [None]:
def get_incorrect_samples(model, device, test_loader):
    """Predict every sample in ``test_loader`` and flag the misclassified ones.

    Args:
        model: Module returning per-class scores for a (batch, 784) input.
        device: Device that every batch is moved to before the forward pass.
        test_loader: Iterable of ``(data, target)`` batches; must support ``len``.

    Returns:
        Tuple ``(predictions, incorrect)`` of NumPy arrays in loader order:
        the predicted class per sample, and a boolean mask that is True where
        the prediction differs from the target.
    """
    model.eval()
    pred_batches = []
    target_batches = []
    with torch.no_grad():
        for _, (data, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):
            data, target = data.to(device), target.to(device)
            output = model(torch.reshape(data, (-1, 784)))
            # Index of the largest score per row. Kept 1-D (no squeeze) so a
            # final batch holding a single sample is handled like any other.
            pred_batches.append(output.max(1)[1].cpu())
            target_batches.append(target.reshape(-1).cpu())

    if not pred_batches:
        # Nothing to classify: return empty arrays of the usual dtypes.
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=bool)

    prediction = torch.cat(pred_batches)
    truth = torch.cat(target_batches)
    # `ne`, not `eq`: the mask marks the errors, as the function name promises.
    incorrect = torch.ne(prediction, truth)
    return prediction.numpy(), incorrect.numpy()

In [3]:
# Rebind `model` to a freshly initialised LinearModel; whatever instance the
# name referred to before this cell ran is no longer reachable through it.
model = LinearModel()

In [5]:
# Materialise the model's parameters into a list; as the last expression of
# the cell, the list is what the notebook displays.
list(model.parameters())

[Parameter containing:
 tensor([[-0.0590, -0.0213,  0.1101,  ..., -0.1071, -0.0308, -0.0339],
         [ 0.0345,  0.0200,  0.0790,  ..., -0.0147, -0.0371, -0.0519],
         [ 0.0305, -0.0293,  0.0666,  ..., -0.1485, -0.0646, -0.0241],
         ...,
         [-0.0134, -0.0260, -0.0246,  ..., -0.0028,  0.0191, -0.0078],
         [ 0.0059,  0.0380,  0.0092,  ..., -0.0378, -0.0561, -0.0341],
         [ 0.0601,  0.1145, -0.0237,  ..., -0.0166,  0.0450, -0.0317]],
        requires_grad=True), Parameter containing:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)]