In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, utils
from torch.utils import data

import matplotlib.pyplot as plt
import numpy as np

In [2]:
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")

In [3]:
EPOCHS     = 40
BATCH_SIZE = 64

In [4]:
transform=transforms.Compose([
                       transforms.RandomHorizontalFlip(),
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,), (0.5,))
                   ])

<h1>dataset 

In [5]:
trainset = datasets.MNIST(
    root      = '../data/', 
    train     = True,
    download  = True,
    transform = transform
)
testset = datasets.MNIST(
    root      = '../data/', 
    train     = False,
    download  = True,
    transform = transform
)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [6]:
train_loader = data.DataLoader(
    dataset     = trainset,
    batch_size  = BATCH_SIZE, shuffle=True
)
test_loader = data.DataLoader(
    dataset     = testset,
    batch_size  = BATCH_SIZE, shuffle=True
)

In [7]:
dataiter = iter(train_loader)
images, labels = next(dataiter)

In [8]:
images[0].shape

torch.Size([1, 28, 28])

In [9]:
print(labels)

tensor([6, 1, 1, 5, 4, 0, 3, 6, 1, 6, 5, 6, 3, 4, 6, 4, 7, 0, 2, 8, 0, 9, 9, 6,
        8, 0, 5, 4, 5, 9, 9, 0, 9, 1, 2, 6, 1, 7, 8, 1, 8, 6, 5, 4, 1, 9, 7, 6,
        0, 1, 6, 7, 6, 1, 6, 9, 2, 0, 3, 7, 8, 3, 7, 4])


In [10]:
CLASSES = {
    0: 'zero',
    1: 'one',
    2: 'two',
    3: 'three',
    4: 'four',
    5: 'five',
    6: 'six',
    7: 'seven',
    8: 'eight',
    9: 'nine'
}


for label in labels:
    index = label.item()
    print(CLASSES[index])

six
one
one
five
four
zero
three
six
one
six
five
six
three
four
six
four
seven
zero
two
eight
zero
nine
nine
six
eight
zero
five
four
five
nine
nine
zero
nine
one
two
six
one
seven
eight
one
eight
six
five
four
one
nine
seven
six
zero
one
six
seven
six
one
six
nine
two
zero
three
seven
eight
three
seven
four


In [None]:
idx = 1

item_img = images[idx]
item_npimg = item_img.squeeze().numpy()
plt.title(CLASSES[labels[idx].item()])
print(item_npimg.shape)
plt.imshow(item_npimg, cmap='gray')
plt.show()

<h1>Linear model

In [11]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 784)  
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)
        return x

In [12]:
model     = Net().to(DEVICE)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

In [13]:
def train(model, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 200 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [14]:
def evaluate(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            test_loss += F.cross_entropy(output, target,
                                         reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    return test_loss, test_accuracy

In [15]:
for epoch in range(1, EPOCHS + 1):
    train(model, train_loader, optimizer, epoch)
    test_loss, test_accuracy = evaluate(model, test_loader)
    
    print('[{}] Test Loss: {:.4f}, Accuracy: {:.2f}%'.format(
          epoch, test_loss, test_accuracy))

[1] Test Loss: 2.2467, Accuracy: 32.41%
[2] Test Loss: 1.7541, Accuracy: 56.66%
[3] Test Loss: 1.0972, Accuracy: 68.34%
[4] Test Loss: 0.8502, Accuracy: 73.79%
[5] Test Loss: 0.7362, Accuracy: 76.72%
[6] Test Loss: 0.6704, Accuracy: 78.52%
[7] Test Loss: 0.6259, Accuracy: 80.10%
[8] Test Loss: 0.5959, Accuracy: 80.62%
[9] Test Loss: 0.5674, Accuracy: 81.88%
[10] Test Loss: 0.5504, Accuracy: 82.50%
[11] Test Loss: 0.5330, Accuracy: 82.67%
[12] Test Loss: 0.5158, Accuracy: 83.32%
[13] Test Loss: 0.4988, Accuracy: 83.86%
[14] Test Loss: 0.4884, Accuracy: 84.24%
[15] Test Loss: 0.4706, Accuracy: 84.81%
[16] Test Loss: 0.4548, Accuracy: 85.22%
[17] Test Loss: 0.4458, Accuracy: 85.80%
[18] Test Loss: 0.4303, Accuracy: 86.18%
[19] Test Loss: 0.4235, Accuracy: 86.40%
[20] Test Loss: 0.4096, Accuracy: 86.96%
[21] Test Loss: 0.3974, Accuracy: 87.21%
[22] Test Loss: 0.3930, Accuracy: 87.48%
[23] Test Loss: 0.3795, Accuracy: 88.01%
[24] Test Loss: 0.3721, Accuracy: 88.23%
[25] Test Loss: 0.3656, A

[29] Test Loss: 0.3359, Accuracy: 89.29%
[30] Test Loss: 0.3275, Accuracy: 89.76%
[31] Test Loss: 0.3224, Accuracy: 89.96%
[32] Test Loss: 0.3156, Accuracy: 89.96%
[33] Test Loss: 0.3099, Accuracy: 90.42%
[34] Test Loss: 0.3035, Accuracy: 90.54%
[35] Test Loss: 0.2990, Accuracy: 90.99%
[36] Test Loss: 0.2926, Accuracy: 90.96%
[37] Test Loss: 0.2857, Accuracy: 91.26%
[38] Test Loss: 0.2851, Accuracy: 91.13%
[39] Test Loss: 0.2776, Accuracy: 91.45%
[40] Test Loss: 0.2708, Accuracy: 91.67%


In [None]:
columns = 6
rows = 6
fig = plt.figure(figsize=(10,10))
 
model.eval()
for i in range(1, columns*rows+1):
    data_idx = np.random.randint(len(testset))
    input_img = testset[data_idx][0].unsqueeze(dim=0).to(DEVICE) 
 
    output = model(input_img)
    _, argmax = torch.max(output, 1)
    pred = CLASSES[argmax.item()]
    label = CLASSES[testset[data_idx][1]]
    
    fig.add_subplot(rows, columns, i)
    if pred == label:
        plt.title(pred + ', right')
        cmap = 'Blues'
    else:
        plt.title('N ' + pred + ' B ' +  label)
        cmap = 'Reds'
    plot_img = testset[data_idx][0][0,:,:]
    plt.imshow(plot_img, cmap=cmap)
    plt.axis('off')
    
plt.show() 