In [None]:
import os
from collections import OrderedDict

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import nn, optim
from torchvision import datasets, transforms, models
from tqdm.notebook import tqdm

In [None]:
def train(model, train_loader, loss_fn, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Network to optimise; switched to training mode before the loop.
        train_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        loss_fn: Callable mapping ``(outputs, labels)`` to a scalar loss tensor.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive each update.
        device: Device every batch is moved to before the forward pass.

    Returns:
        The sum of the batch losses divided by ``len(train_loader)``.
    """
    model.train()
    num_batches = len(train_loader)
    loss_total = 0

    # tqdm renders an interactive progress bar that advances once per batch
    with tqdm(total=num_batches) as progress:
        for images, targets in train_loader:
            images = images.to(device)
            targets = targets.to(device)
            # Concatenate the batch with itself along dim 1 so the image has 3 channels
            images = torch.cat((images, images, images), dim=1)

            # One optimisation step: clear old grads, forward, backward, update
            optimizer.zero_grad()
            batch_loss = loss_fn(model(images), targets)
            batch_loss.backward()
            optimizer.step()

            loss_total += batch_loss.item()
            progress.update(1)

    return loss_total / num_batches

In [None]:
# Function for the validation pass
def validation(model, val_loader, loss_fn, device):
    """Evaluate the model on a loader with gradient tracking disabled.

    Args:
        model: Network to evaluate; switched to eval mode before the loop.
        val_loader: Sized iterable yielding ``(inputs, labels)`` batches.
        loss_fn: Callable mapping ``(outputs, labels)`` to a scalar loss tensor.
        device: Device every batch is moved to before the forward pass.

    Returns:
        A ``(mean_loss, accuracy)`` tuple: the summed batch losses divided by
        ``len(val_loader)``, and the fraction of samples whose highest-scoring
        output index equals the label.
    """
    model.eval()
    num_batches = len(val_loader)
    loss_total = 0
    hits = 0
    samples_seen = 0

    with torch.no_grad(), tqdm(total=num_batches) as progress:
        for images, targets in val_loader:
            images = images.to(device)
            targets = targets.to(device)
            # Concatenate the batch with itself along dim 1 so the image has 3 channels
            images = torch.cat((images, images, images), dim=1)

            scores = model(images)
            loss_total += loss_fn(scores, targets).item()

            # The index of the largest score along dim 1 is the predicted class
            guesses = torch.max(scores, 1)[1]
            samples_seen += targets.size(0)
            hits += (guesses == targets).sum().item()

            progress.update(1)

    return loss_total / num_batches, hits / samples_seen

In [None]:
# An impure function to train our model
def fit(epochs):
    """Train and evaluate for ``epochs`` rounds using the notebook's globals.

    Reads ``model``, ``train_loader``, ``test_loader``, ``criterion``,
    ``optimizer`` and ``device`` from module scope, so whatever those names
    are bound to when this is called is what gets trained.

    Args:
        epochs: Number of train-then-validate rounds to run.

    Returns:
        ``(train_losses, test_losses, accuracies)``: three lists holding one
        value per completed epoch.
    """
    report = "Epoch: {}/{}, Training Loss: {:.4f}, Test Loss: {:.4f}, Test Accuracy: {}"
    train_history, test_history, accuracy_history = [], [], []

    for epoch_number in range(1, epochs + 1):
        # One pass over the training data, then one over the test data
        epoch_train_loss = train(model, train_loader, criterion, optimizer, device)
        epoch_test_loss, epoch_accuracy = validation(model, test_loader, criterion, device)

        print(report.format(epoch_number, epochs, epoch_train_loss, epoch_test_loss, epoch_accuracy))
        print('-' * 20)

        train_history.append(epoch_train_loss)
        test_history.append(epoch_test_loss)
        accuracy_history.append(epoch_accuracy)

    print("Finished Training")
    return train_history, test_history, accuracy_history

In [None]:
def plot_metrics(train_losses, test_losses, accuracies):
    """Plot test accuracy and the train/test loss curves side by side.

    Args:
        train_losses: Sequence of training losses, one value per epoch.
        test_losses: Sequence of test losses, one value per epoch.
        accuracies: Sequence of test accuracies, one value per epoch.
    """
    # One figure with two panels: accuracy on the left, losses on the right
    fig, (ax1, ax2) = plt.subplots(1, 2, sharey=False, figsize=(10, 3))

    # Plot the accuracies that were passed in, not a same-named notebook global
    ax1.plot(accuracies)
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Accuracy')
    ax1.set_title('Test Accuracy')

    # Plot NLL losses
    ax2.plot(train_losses, label='Train Losses')
    ax2.plot(test_losses, label='Test Losses')
    ax2.set_xlabel('Epochs')
    ax2.set_ylabel('loss')
    ax2.set_title('NLL Loss')
    # The labels above are only rendered once a legend is requested
    ax2.legend()

    plt.show()

# Defining our Models and Datasets 

In [None]:
# Preprocessing applied to every image: convert to a tensor, then normalise
# with mean 0.5 and std 0.5 (one value each, i.e. a single channel)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Training split: downloaded on first use, served in shuffled batches of 64
trainset = datasets.FashionMNIST(
    root='~/.pytorch/F_MNIST_data/',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=64, shuffle=True)

# Test split: same location and preprocessing, also batched by 64 and shuffled
test_set = datasets.FashionMNIST(
    root='~/.pytorch/F_MNIST_data/',
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=64, shuffle=True)

In [None]:
# Get model and pretrained weights
# Builds torchvision's ResNet-34; `pretrained=True` asks for the published
# pretrained weights instead of a random initialisation.
# NOTE(review): newer torchvision releases may prefer a `weights=` argument over
# `pretrained=` — confirm against the installed version.
model = models.resnet34(pretrained=True)
# Bare expression so the notebook echoes the layer structure as the cell output
model

# Initial Training
We use `pretrained=True` so that we can make use of some weights from ImageNet pretraining that PyTorch makes available to us. We are going to freeze these weights and replace the output layers to retrain this network for the new problem. The frozen layers can be thought of as an 'image feature extractor', which we are using rather than starting with random weights.

In [None]:
# Freeze all layers: switch off gradient tracking for every parameter at once.
# The trailing semicolon stops the notebook from echoing the returned module.
model.requires_grad_(False);

In [None]:
# FashionMNIST labels fall into 10 classes, so the head must emit exactly 10
# scores; a wider output would create predictable classes that do not exist.
NUM_CLASSES = 10

# Create the new output layers and swap them in for the network's `fc` layer.
# The first Linear takes 512 inputs — assumed to match the width of the feature
# vector the backbone feeds into `fc` (true for resnet34; verify if the backbone changes).
fc = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(512, 256)),
    ('relu', nn.ReLU()),
    ('fc2', nn.Linear(256, NUM_CLASSES)),
    # Log-probabilities over dim 1, the form consumed by an NLL loss
    ('output', nn.LogSoftmax(dim=1))
]))
model.fc = fc
# Bare expression so the notebook echoes the updated layer structure
model

In [None]:
# Set criterion and optimizer
# NLLLoss consumes log-probabilities, which is what the LogSoftmax at the end of
# the replacement head produces.
criterion = nn.NLLLoss()
# Adam is handed every parameter of the model; lr=3e-3 is the rate used for
# this first training phase.
optimizer = optim.Adam(model.parameters(), lr=3e-3)

In [None]:
# Pick the first CUDA GPU when one is available, otherwise stay on the CPU,
# then move the model's parameters and buffers onto that device
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model = model.to(device)

In [None]:
# Number of epochs for each training phase. Override by exporting EPOCHS;
# an unset or empty variable falls back to the default instead of crashing.
DEFAULT_EPOCHS = 10
epochs = int(os.environ.get('EPOCHS') or DEFAULT_EPOCHS)

In [None]:
# Run the first training phase for `epochs` epochs, keeping the per-epoch
# metrics that `fit` returns, then chart them.
train_losses, test_losses, acurracies = fit(epochs)
plot_metrics(train_losses, test_losses, acurracies)

# Fine Tuning
We've now retrained this model for our problem, but we can do even better. The convolutional layers are still optimized for extracting features from the kind of images that ImageNet contains. After unfreezing the model we can retrain the whole network so that it adapts better to this problem, but we need to be careful to use a lower learning rate, or we'll lose the pretraining that we already have.

In [None]:
# Unfreeze all layers: switch gradient tracking back on for every parameter.
# The trailing semicolon stops the notebook from echoing the returned module.
model.requires_grad_(True);

In [None]:
# Reset our learning rate to make it slower
# A fresh Adam instance replaces the global `optimizer` that `fit` reads;
# lr=3e-5 is 100x smaller than the 3e-3 used for the first phase.
optimizer = optim.Adam(model.parameters(), lr=3e-5)

In [None]:
# Fine tune the layers at a slower learning rate
# Runs `fit` again for `epochs` epochs; the returned histories overwrite the
# ones recorded by the first phase before being plotted.
train_losses, test_losses, acurracies = fit(epochs)
plot_metrics(train_losses, test_losses, acurracies)