In [None]:
import itertools

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision

# interactive mode
plt.ion()

import train_helper2 as train_helper

# autoreload external code
%load_ext autoreload
%autoreload 2

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Load the data loaders (a dict keyed by split, e.g. 'train' / 'val') and class names.
# batch_size=1 means each preview below shows a single image.
# NOTE(review): these same loaders are passed to get_model further down, so
# training may also run at batch size 1 — confirm that is intended.
dataloaders, class_names = train_helper.load_data(batch_size=1)

In [None]:
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

# Preview a few training batches. Iterate a single loader pass: calling
# iter(...) inside the loop would rebuild the loader iterator (and any worker
# processes) every time, and would repeat the first batch if the loader is not
# shuffled. islice also stops cleanly if the loader has fewer batches.
n_preview = 5
for inputs, classes in itertools.islice(dataloaders['train'], n_preview):
    # Make a grid from the batch
    out = torchvision.utils.make_grid(inputs)

    imshow(out, title=[class_names[x] for x in classes])

In [None]:
# Build a fresh model for training, along with the criterion, optimizer and
# scheduler that get_model returns. The epoch count is passed in too; how
# get_model uses it is not visible here.
n_epochs = 3
(model, criterion, optimizer, scheduler) = train_helper.get_model(
    dataloaders,
    n_epochs,
)

In [None]:
# Train for n_epochs and rebind `model` to whatever train_model hands back.
model = train_helper.train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    n_epochs=n_epochs,
)

In [None]:
# Alternative to training: load a previously saved model from disk (uncomment the line below).
#model = train_helper.load_model('model.pt', map_location=device)

In [None]:
def visualize_model(model, device, dataloader, num_images=6, threshold=0.5):
    """Plot model predictions for the first ``num_images`` samples of ``dataloader``.

    Samples are drawn in a two-column grid. Each panel is titled with the
    predicted label and the raw model score. The model's train/eval mode is
    restored afterwards.

    Args:
        model: PyTorch module. It is assumed to emit one score per sample,
            because ``.item()`` is called on each sample's output. TODO confirm.
        device: device the input batch is moved to before the forward pass.
        dataloader: iterable yielding ``(images, labels)`` batches. Every
            sample in a batch is plotted; labels are not used.
        num_images: number of samples to show. Odd counts are supported.
        threshold: scores above this are labelled ``'no_mask'``, otherwise
            ``'mask'``. The 0.5 default presumes a probability (sigmoid)
            output rather than a raw logit. Verify against the model.
    """
    if num_images <= 0:
        return

    # Ceil division: an odd count still gets a subplot slot for its last image.
    n_rows = (num_images + 1) // 2
    was_training = model.training
    model.eval()
    plt.figure()
    images_so_far = 0

    try:
        with torch.no_grad():
            for imgs, _labels in dataloader:
                batch = imgs.to(device)
                outputs = model(batch)

                for img, output in zip(batch, outputs):
                    score = output.item()
                    images_so_far += 1
                    ax = plt.subplot(n_rows, 2, images_so_far)
                    ax.axis('off')
                    pred = 'no_mask' if score > threshold else 'mask'
                    ax.set_title(f'{pred} ({score:.3f})')
                    imshow(img.cpu())

                    if images_so_far == num_images:
                        return
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(was_training)

In [None]:
# Rebuild loaders with batch_size=1, then plot predictions on the validation split.
dataloaders, class_names = train_helper.load_data(batch_size=1)
visualize_model(model=model, device=device, dataloader=dataloaders['val'])