In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [None]:
# MNIST training split stored under data/; download=True lets torchvision
# fetch the files, and ToTensor is applied as the per-sample transform.
mnist_training_data = torchvision.datasets.MNIST(
    root='data/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

In [5]:
def train(model, use_cuda, n_epochs, train_loader=None, val_loader=None):
    """Train ``model`` with RMSprop and cross-entropy loss, printing metrics per epoch.

    For every epoch the function runs one optimisation pass over the training
    batches, then one gradient-free evaluation pass over the validation
    batches. Mean loss and accuracy for both passes are stored in local NumPy
    arrays of shape ``(n_epochs, 1)`` and printed as the epoch finishes. The
    total wall-clock time is printed at the end.

    Args:
        model: ``nn.Module`` mapping a batch of images to per-class scores.
        use_cuda: If true, the model and every batch are moved to the GPU;
            otherwise they are kept on / moved to the CPU.
        n_epochs: Number of passes over the training data.
        train_loader: Sized iterable of ``(images, labels)`` batches used for
            training. Defaults to the notebook-level ``training_data_loader``.
        val_loader: Sized iterable of ``(images, labels)`` batches used for
            validation. Defaults to the notebook-level ``test_data_loader``.
    """
    # Fall back to the notebook globals so existing train(model, use_cuda, n)
    # calls behave exactly as before.
    if train_loader is None:
        train_loader = training_data_loader
    if val_loader is None:
        val_loader = test_data_loader

    device = torch.device('cuda' if use_cuda else 'cpu')
    model.to(device)

    optimizer = optim.RMSprop(model.parameters())
    criterion = nn.CrossEntropyLoss()

    # store metrics (one row per epoch)
    training_accuracy_history = np.zeros([n_epochs, 1])
    training_loss_history = np.zeros([n_epochs, 1])
    validation_accuracy_history = np.zeros([n_epochs, 1])
    validation_loss_history = np.zeros([n_epochs, 1])

    start_time = time.time()
    for epoch in range(n_epochs):
        # Report progress against the requested epoch count, not a fixed 10.
        print(f'Epoch {epoch+1}/{n_epochs}:', end='')

        # ---- train ----
        train_total = 0
        train_correct = 0
        model.train()
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            output = model(images)
            # categorical cross entropy loss
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            # track training accuracy; detach so the bookkeeping stays out of autograd
            predicted = output.detach().argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            # track training loss (summed here, averaged after the loop)
            training_loss_history[epoch] += loss.item()
            # progress update after 180 batches (~1/10 epoch for batch size 32)
            if i % 180 == 0:
                print('.', end='')
        training_loss_history[epoch] /= len(train_loader)
        training_accuracy_history[epoch] = train_correct / train_total
        print(f'\n\tloss: {training_loss_history[epoch,0]:0.4f}, acc: {training_accuracy_history[epoch,0]:0.4f}',end='')

        # ---- validate ----
        test_total = 0
        test_correct = 0
        model.eval()
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                output = model(images)
                # find accuracy
                predicted = output.argmax(dim=1)
                test_total += labels.size(0)
                test_correct += (predicted == labels).sum().item()
                # find loss
                loss = criterion(output, labels)
                validation_loss_history[epoch] += loss.item()
        validation_loss_history[epoch] /= len(val_loader)
        validation_accuracy_history[epoch] = test_correct / test_total
        print(f', val loss: {validation_loss_history[epoch,0]:0.4f}, val acc: {validation_accuracy_history[epoch,0]:0.4f}')
    elapsed_time = time.time() - start_time
    print(elapsed_time, "seconds elapsed")