In [None]:
import torch
from torch import nn
from torch.utils import data
from torch.autograd import Variable
from torchvision import datasets, transforms

In [None]:
# Hyperparameters and runtime switches shared by the cells below.
epochs = 10            # number of train/evaluate rounds in the final loop
learning_rate = 0.001  # step size handed to the optimizer
batch_size = 128       # mini-batch size used by both data loaders
shuffle = True         # forwarded to both data loaders
# Run on the GPU whenever PyTorch reports that CUDA is usable.
use_cuda = torch.cuda.is_available()

In [None]:
class Net(nn.Module):
    """Small convolutional classifier: two conv stages plus a two-layer head.

    Each conv stage is ``Conv2d -> BatchNorm2d -> ReLU -> Dropout2d`` with an
    unpadded 5x5 kernel and stride 2. The head expects ``64 * 4 * 4`` features
    per sample, which is what the conv stack yields for a 1x28x28 input
    (28 -> 12 -> 4 along each spatial axis). The final layer emits 10 raw
    scores (no softmax).
    """

    def __init__(self):
        super(Net, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(),
            nn.Conv2d(32, 64, 5, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d()
        )

        self.linear = nn.Sequential(
            nn.Linear(64 * 4 * 4, 1024),
            # The Linear output is 2-D (batch, features), so this must be
            # BatchNorm1d; BatchNorm2d rejects anything that is not 4-D.
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(1024, 10),
        )

    def forward(self, input_):
        """Return a ``(batch, 10)`` tensor of class scores for ``input_``.

        ``input_`` must be a 4-D batch whose conv output flattens to
        ``64 * 4 * 4`` features per sample (e.g. ``(batch, 1, 28, 28)``);
        other spatial sizes raise a shape error in the first Linear layer.
        """
        output = self.conv(input_)
        # Flatten per sample while keeping the batch axis fixed. Using
        # view(-1, 64 * 4 * 4) would silently regroup features across samples
        # whenever the input size is wrong.
        output = output.view(output.size(0), -1)
        output = self.linear(output)
        return output

In [None]:
# Preprocessing shared by both splits: convert to a tensor, then standardise
# with fixed statistics (presumably the MNIST pixel mean/std -- TODO confirm).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.13066047740239478,),
            std=(0.3081078087569972,),
        ),
    ]
)

# Training split; downloaded into ./data on first use.
train_dataloader = data.DataLoader(
    dataset=datasets.MNIST(
        root='data',
        train=True,
        transform=transform,
        download=True,
    ),
    batch_size=batch_size,
    shuffle=shuffle,
)

# Held-out test split. It reuses the same `shuffle` flag as the training loader.
test_dataloader = data.DataLoader(
    dataset=datasets.MNIST(
        root='data',
        train=False,
        transform=transform,
        download=True,
    ),
    batch_size=batch_size,
    shuffle=shuffle,
)

In [None]:
# Instantiate the model and move it to the GPU *before* creating the optimizer,
# so the optimizer is built from the parameters on their final device.
net = Net()
if use_cuda:
    net = net.cuda()  # Module.cuda() returns the module itself

# Classification loss applied to the raw scores produced by `net`.
cross_entropy = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=learning_rate)

In [None]:
def train():
    """Run one optimisation pass over ``train_dataloader``.

    Puts ``net`` in training mode, performs one optimizer step per mini-batch
    using ``cross_entropy``, and returns the mean of the per-batch losses as a
    0-dim tensor (callers may wrap it in ``float(...)``).

    Written for PyTorch >= 0.4, where tensors carry autograd state directly
    and the ``Variable`` wrapper is unnecessary.

    Raises:
        RuntimeError: if ``train_dataloader`` yields no batches.
    """
    net.train()
    losses = []
    for train_x, train_y in train_dataloader:
        if use_cuda:
            train_x = train_x.cuda()
            train_y = train_y.cuda()

        # Clear gradients *before* backward so that stale gradients left by an
        # interrupted previous run of this cell cannot leak into this step.
        optimizer.zero_grad()
        loss = cross_entropy(net(train_x), train_y)
        loss.backward()
        optimizer.step()

        # Keep only the value: detaching drops the reference to the autograd
        # graph, which would otherwise stay alive for the whole epoch.
        losses.append(loss.detach())

    if not losses:
        raise RuntimeError('train_dataloader produced no batches')

    # Per-batch losses are 0-dim tensors, which torch.cat cannot concatenate;
    # torch.stack builds the 1-D tensor to average.
    return torch.stack(losses).mean()

def evaluate():
    """Return the classification accuracy of ``net`` on ``test_dataloader``.

    Puts ``net`` in evaluation mode and counts how many predictions (index of
    the largest score along dim 1) equal the labels. The count is divided by
    ``len(test_dataloader.dataset)``, giving a float in ``[0, 1]``.

    Written for PyTorch >= 0.4: ``Variable(..., volatile=True)`` no longer
    disables gradient tracking there, so ``torch.no_grad()`` is used instead.
    """
    net.eval()

    correct = 0
    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for test_x, test_y in test_dataloader:
            if use_cuda:
                test_x = test_x.cuda()
                test_y = test_y.cuda()

            _, max_indices = net(test_x).max(1)
            correct += int((max_indices == test_y).sum())

    testing_accuracy = correct / len(test_dataloader.dataset)

    return testing_accuracy

In [None]:
# Alternate one training pass with one evaluation pass and log both metrics.
for epoch in range(epochs):
    train_avg_loss = train()
    test_acc = evaluate()
    print(
        'Epoch: {}, train loss: {}, test acc: {}.'.format(
            epoch,
            float(train_avg_loss),  # convert the returned loss to a plain float
            test_acc,
        )
    )