In [None]:
import os
import csv
import torch
from torch import nn
from torch.utils import data
from torch.autograd import Variable
from torchvision import datasets, transforms

In [None]:
# --- Optimisation hyper-parameters -----------------------------------------
epochs = 10
batch_size = 32
learning_rate = 1e-3
shuffle = True

# --- Logging ----------------------------------------------------------------
# A log row is written every `log_interval` training batches.
log_interval = 1000
log_filename = 'logs/training.csv'

# --- Device selection -------------------------------------------------------
# Tensors and the model are moved to the GPU further down when this is True.
use_gpu = torch.cuda.is_available()
if use_gpu:
    print("Use CUDA.")

In [None]:
class Logger():
    """
    CSV logger for training metrics, used as a context manager.

    Entering the context creates the parent directory of ``filename`` when it
    does not exist yet, truncates/creates the file and writes the header row
    (``epoch, batch_index, loss, accuracy``). Leaving the context closes the
    file.
    """

    def __init__(self, filename):
        """
        Args:
            filename: Path of the CSV file to write. The file is only opened
                once the context is entered.
        """
        self.filename = filename
        # Populated by __enter__; kept as None so __exit__ is always safe.
        self.file = None
        self.writer = None

    def __enter__(self):
        # Create the directory the log file actually lives in (not a fixed
        # "logs/" folder), and skip the call for bare file names whose
        # dirname is the empty string.
        directory = os.path.dirname(self.filename)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Mode 'w' already truncates an existing file. newline='' lets the csv
        # module control line endings, which avoids blank rows on Windows.
        self.file = open(self.filename, 'w', newline='')
        try:
            fieldnames = [
                'epoch',
                'batch_index',
                'loss',
                'accuracy',
            ]
            self.writer = csv.DictWriter(self.file, fieldnames=fieldnames)
            self.writer.writeheader()
        except Exception:
            # Do not leak the handle if the header cannot be written.
            self.file.close()
            raise
        return self

    def append(self, epoch, batch_index, loss, accuracy):
        """
        Append one training log row to the CSV file and force it to disk.

        Args:
            epoch: Zero-based epoch index; it is stored one-based (epoch + 1).
            batch_index: Index of the batch within the epoch.
            loss: Loss value, written with nine decimal places.
            accuracy: Accuracy value, written as given.
        """
        self.writer.writerow({
            'epoch': epoch + 1,
            'batch_index': batch_index,
            'loss': "{:.9f}".format(loss),
            'accuracy': accuracy,
        })
        # Flush Python's buffer, then ask the OS to persist this file only.
        # A system-wide os.sync() is not needed and is unavailable on Windows.
        self.file.flush()
        os.fsync(self.file.fileno())

    def __exit__(self, exc_type, exc_value, traceback):
        # Exceptions raised inside the with-block are not suppressed.
        if self.file is not None and not self.file.closed:
            self.file.close()

In [None]:
class Net(nn.Module):
    """
    Small convolutional classifier producing 10 logits per sample.

    Two stride-2 5x5 convolution stages (each followed by batch norm, ReLU
    and channel dropout) feed a two-layer fully connected head. The head
    expects 32 * 4 * 4 features per sample, which is what the convolution
    stack yields for single-channel 28x28 inputs (28 -> 12 -> 4).
    """

    def __init__(self):
        super(Net, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, 5, 2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Dropout2d(),
            nn.Conv2d(16, 32, 5, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d()
        )

        self.linear = nn.Sequential(
            nn.Linear(32 * 4 * 4, 512),
            # The activations here are 2-D (batch, features), so the 1-D
            # batch norm variant is required; BatchNorm2d rejects them.
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(512, 10),
        )

    def forward(self, input_):
        """
        Compute class logits.

        Args:
            input_: Batch of images; the first convolution takes 1 input
                channel and the head expects 32 * 4 * 4 flattened features.

        Returns:
            Tensor with one row of 10 logits per input sample.
        """
        output = self.conv(input_)
        # Keep the batch dimension explicit so an unexpected spatial size is
        # reported by the linear layer instead of silently changing the
        # number of rows.
        output = output.view(output.size(0), -1)
        output = self.linear(output)
        return output

In [None]:
# Convert images to tensors, then standardise them. The constants are
# presumably the mean and standard deviation of the MNIST training pixels
# (verify if the dataset changes).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.13066047740239478,), (0.3081078087569972,))
])

# Training split; the batch order follows the `shuffle` setting.
train_dataloader = data.DataLoader(
    datasets.MNIST('data', train=True, transform=transform, download=True),
    batch_size=batch_size,
    shuffle=shuffle,
)

# Test split. Accuracy does not depend on batch order, so this loader is never
# shuffled: evaluation stays deterministic and does not consume random state.
test_dataloader = data.DataLoader(
    datasets.MNIST('data', train=False, transform=transform, download=True),
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Loss used for both the training objective and the logged value.
cross_entropy = nn.CrossEntropyLoss()

# Build the model and move its parameters to the GPU when one is available.
# The optimizer is created afterwards so it is handed the final parameters.
net = Net()
if use_gpu:
    net = net.cuda()

optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

In [None]:
def train(epoch, logger):
    """
    Run one pass over ``train_dataloader``, updating ``net`` after each batch.

    Every ``log_interval`` batches (starting with batch 0) the current loss is
    printed, the test accuracy is measured with ``evaluate()`` and both are
    handed to ``logger.append``.

    Args:
        epoch: Index of the current epoch, used in the printed message and
            passed through to the logger.
        logger: Object exposing ``append(epoch, batch_index, loss, accuracy)``.
    """
    for batch_index, (inputs, targets) in enumerate(train_dataloader):
        # evaluate() switches the network to eval mode, so training mode is
        # re-enabled at the start of every step.
        net.train()

        inputs, targets = Variable(inputs), Variable(targets)
        if use_gpu:
            inputs, targets = inputs.cuda(), targets.cuda()

        # Standard optimisation step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        predictions = net(inputs)
        loss = cross_entropy(predictions, targets)
        loss.backward()
        optimizer.step()

        if batch_index % log_interval != 0:
            continue

        loss_value = float(loss.data)
        print('Train epoch: {}, batch index: {}, loss: {}.'.format(epoch, batch_index, loss_value))
        test_acc = evaluate()
        logger.append(epoch, batch_index, loss_value, test_acc)
            
def evaluate():
    """
    Measure the classification accuracy of ``net`` on ``test_dataloader``.

    The network is switched to evaluation mode and is left in that mode when
    the function returns. The accuracy is printed and returned.

    Returns:
        Number of correctly classified samples divided by the size of the
        test dataset.
    """
    net.eval()

    correct = 0
    # Disable gradient tracking for the whole evaluation pass. The former
    # `volatile=True` flag has no effect on PyTorch >= 0.4, so without this
    # context an autograd graph would be built for every batch.
    with torch.no_grad():
        for test_x, test_y in test_dataloader:
            if use_gpu:
                test_x = test_x.cuda()
                test_y = test_y.cuda()

            # Index of the largest logit along dim 1 is the predicted class.
            _, max_indices = net(test_x).max(1)
            correct += int((max_indices == test_y).sum())

    test_acc = correct / len(test_dataloader.dataset)
    print('Test accuracy: {}.'.format(float(test_acc)))
    return test_acc

In [None]:
# Run the full training schedule. The Logger context opens the CSV file at
# `log_filename` on entry and closes it on exit, including when training
# raises.
with Logger(log_filename) as logger:
    # `epoch` is zero-based here; train() forwards it to the logger.
    for epoch in range(epochs):
        train(epoch, logger)