In [None]:
import torch
import torchvision

import matplotlib.pyplot as plt

# Turn on matplotlib's interactive mode so figures are drawn without blocking
# execution of the cell that creates them.
plt.ion()

import train_helper
import numpy as np

# Reload imported modules (such as train_helper) automatically before each
# cell runs, so edits to the external code are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Load the data through the project helper.
# Later cells index `dataloaders` with the keys 'train' and 'test', and index
# `class_names` with label / predicted-class indices.
dataloaders, class_names = train_helper.load_data()

In [None]:
def imshow(inp, title=None):
    """Display a normalized image tensor with matplotlib.

    Args:
        inp: Image tensor laid out channels-first (3, H, W); axis 0 is moved
            last before plotting. It may live on any device and may require
            grad: it is detached and copied to the CPU first.
        title: Optional title for the current axes. Skipped when ``None``.

    The per-channel ``mean``/``std`` below undo a normalization step
    (presumably the one applied in ``train_helper.load_data`` -- TODO confirm
    the values match), and the result is clipped to [0, 1] before display.
    """
    # Tensor.numpy() raises for CUDA tensors and for tensors that require
    # grad, so detach and move to the CPU before converting.
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Undo the normalization (x * std + mean) and clamp to the displayable range.
    inp = np.clip(std * inp + mean, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

# Preview up to four training batches, each as one image grid titled with the
# class names of the images in that batch.
# The loader is iterated once: calling next(iter(loader)) inside the loop would
# build a fresh iterator on every pass (restarting the loader each time) and,
# for a loader that does not shuffle, would show the same first batch four
# times. zip() with range(4) stops after four batches, or earlier if the
# loader yields fewer.
for _, (inputs, classes) in zip(range(4), dataloaders['train']):
    # Make a grid from batch
    out = torchvision.utils.make_grid(inputs)

    imshow(out, title=[class_names[x] for x in classes])

In [None]:
# Train a new model: build the model, loss, optimizer and LR scheduler through
# train_helper, then run training for n_epochs.
# NOTE(review): the next cell rebinds `model` from 'mask_model.pt', so on a
# "Run All" the model trained here is replaced. Run only one of the two cells,
# or confirm that train_model writes its result to that path.
n_epochs = 30
model, criterion, optimizer, scheduler = train_helper.get_model(dataloaders, n_epochs)
model = train_helper.train_model(model, criterion, optimizer, scheduler, dataloaders, n_epochs=n_epochs)

In [None]:
# Load an existing model from disk instead of training one; this rebinds
# `model`, replacing whatever the training cell produced.
# `map_location=device` is forwarded to train_helper.load_model -- presumably
# so the weights are mapped onto the selected device; confirm in train_helper.
model = train_helper.load_model('mask_model.pt', map_location=device)

In [None]:
from matplotlib.pyplot import figure

def visualize_model(model, device, dataloaders, num_images=6):
    """Show test images together with the class the model predicts for them.

    Args:
        model: Network to evaluate. It is switched to eval mode for the
            duration of the call and its previous training mode is restored
            afterwards, even if an error occurs.
        device: Device the input batches are moved to before the forward pass.
        dataloaders: Mapping whose 'test' entry yields (inputs, labels) batches.
        num_images: Maximum number of images to show; nothing is drawn when it
            is not positive.

    Each image gets its own 5x4 figure. The figure is created *before* its
    axes so that the 'predicted: ...' title and the image end up on the same
    axes; ``imshow`` draws into the current axes and then calls ``plt.pause``,
    which is why a shared multi-panel figure is not used here.

    Relies on the notebook-level ``class_names`` and ``imshow``.
    """
    if num_images <= 0:
        return

    was_training = model.training
    model.eval()
    images_so_far = 0

    try:
        with torch.no_grad():
            for inputs, _ in dataloaders['test']:
                inputs = inputs.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                # Plotting happens on the CPU; move the batch back once.
                images = inputs.cpu()

                for j in range(images.size(0)):
                    fig = plt.figure(figsize=(5, 4))
                    # add_subplot makes the new axes current, so the
                    # plt.imshow call inside imshow() draws into `ax`.
                    ax = fig.add_subplot(1, 1, 1)
                    ax.axis('off')
                    ax.set_title('predicted: {}'.format(class_names[preds[j].item()]))
                    imshow(images[j])

                    images_so_far += 1
                    if images_so_far == num_images:
                        return
    finally:
        # Restore the caller's training/eval mode on every exit path.
        model.train(mode=was_training)


In [None]:
def test_performance(model, device, dataloaders):
    corrects = 0
    total = 0
    test_data = dataloaders['test']
    misclass = 0
    with torch.no_grad():
        for i, (images,labels) in enumerate (test_data):
            
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0) 
            correct = (predicted == labels).sum().item()
            if correct == 1:
                corrects += correct
            else:
                misclass += 1

    print('Accuracy of the network on',total,'test images of which misclass',misclass,': %d %%' % (
        100 * corrects / total))

In [None]:
# Print the accuracy of the current `model` on the 'test' split.
test_performance(model, device, dataloaders)

In [None]:
# Show predictions for up to 10 images from the 'test' split.
visualize_model(model, device, dataloaders, 10)