In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision
import os
import nbimporter
from dataset import VocDataset

In [None]:
def get_data_loaders(data_dir, train_transforms, validation_transforms, batch_size):
    """Build the training and validation DataLoaders over ``VocDataset``.

    Args:
        data_dir: Root directory handed to ``VocDataset``.
        train_transforms: Transforms for the training split.
        validation_transforms: Transforms for the validation split.
        batch_size: Samples per batch, shared by both loaders.

    Returns:
        ``(train_loader, validation_loader)``. Only the training loader
        shuffles; both drop a trailing incomplete batch and load in the main
        process (``num_workers=0``).
    """
    # Loader options identical for both splits.
    shared = {'batch_size': batch_size, 'drop_last': True, 'num_workers': 0}

    # The third VocDataset argument is True for the training split, False for validation.
    train_loader = DataLoader(VocDataset(data_dir, train_transforms, True), shuffle=True, **shared)

    # NOTE(review): drop_last=True here means up to batch_size - 1 validation
    # samples are never evaluated; confirm this is intended.
    validation_loader = DataLoader(VocDataset(data_dir, validation_transforms, False), shuffle=False, **shared)

    return train_loader, validation_loader

In [None]:
# RGB palette for segmentation masks: VOC_COLORMAP[i] is the colour used for
# class index i (index-aligned with VOC_CLASSES below).
VOC_COLORMAP = [
    [0, 0, 0],        # 0  background
    [128, 0, 0],      # 1  aeroplane
    [0, 128, 0],      # 2  bicycle
    [128, 128, 0],    # 3  bird
    [0, 0, 128],      # 4  boat
    [128, 0, 128],    # 5  bottle
    [0, 128, 128],    # 6  bus
    [128, 128, 128],  # 7  car
    [64, 0, 0],       # 8  cat
    [192, 0, 0],      # 9  chair
    [64, 128, 0],     # 10 cow
    [192, 128, 0],    # 11 diningtable
    [64, 0, 128],     # 12 dog
    [192, 0, 128],    # 13 horse
    [64, 128, 128],   # 14 motorbike
    [192, 128, 128],  # 15 person
    [0, 64, 0],       # 16 potted plant
    [128, 64, 0],     # 17 sheep
    [0, 192, 0],      # 18 sofa
    [128, 192, 0],    # 19 train
    [0, 64, 128],     # 20 tv/monitor
]

# Human-readable class names; position i corresponds to VOC_COLORMAP[i].
VOC_CLASSES = [
    'background',
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor',
]


def predict(test_iterator, model, device, image):
    """Predict a 2-D map of class indices for a single image.

    Args:
        test_iterator: A DataLoader (e.g. for the test set) whose ``dataset``
            exposes ``normalize_image(image)``; it is used only to apply the
            dataset's normalisation to ``image``.
        model: Network whose output holds per-class scores along dim 1.
            This function does not call ``model.eval()``; callers should.
        device: Device the normalised image is moved to before the forward pass.
        image: A single, un-batched image accepted by ``normalize_image``.

    Returns:
        The argmax over dim 1 of the model output for this one image, with the
        batch dimension removed (a 2-D tensor of class indices).
    """
    batch = test_iterator.dataset.normalize_image(image).unsqueeze(0)
    # Pure inference: don't record an autograd graph for the forward pass.
    with torch.no_grad():
        scores = model(batch.to(device))
    # Batch of one -> select it to drop the leading batch dimension.
    return scores.argmax(dim=1)[0]

def label_to_image(prediction, device, colormap=None):
    """Map a tensor of class indices to RGB colours.

    Args:
        prediction: Tensor of class indices of any shape and numeric dtype;
            it is cast to ``long`` before being used as an index.
        device: Device on which the colour lookup table, and therefore the
            result, is created.
        colormap: Optional sequence of ``[R, G, B]`` rows indexed by class id.
            Defaults to ``VOC_COLORMAP``.

    Returns:
        Tensor of shape ``prediction.shape + (3,)`` with the colour of each
        element; its dtype is whatever ``torch.tensor`` infers from the
        palette (int64 for integer palettes).
    """
    palette = torch.tensor(VOC_COLORMAP if colormap is None else colormap, device=device)
    # Advanced indexing with the label tensor picks one palette row per element.
    return palette[prediction.long()]

In [None]:
def save_checkpoint(state_dict, checkpoint_file='model/model.pt'):
    '''Saves the model state to disk with ``torch.save``.

    Args:
        state_dict: Object to persist (typically ``model.state_dict()``).
        checkpoint_file: Destination path or file-like object. For paths, any
            missing parent directories (e.g. ``model/`` for the default) are
            created first so the save does not fail on a fresh directory tree.
    '''
    # File-like objects have no directory to create; only touch real paths.
    if isinstance(checkpoint_file, (str, os.PathLike)):
        directory = os.path.dirname(os.fspath(checkpoint_file))
        # dirname is '' for a bare filename, and os.makedirs('') would raise.
        if directory:
            os.makedirs(directory, exist_ok=True)
    torch.save(state_dict, checkpoint_file)

def load_checkpoint(model, checkpoint_file='model/model.pt', map_location=None):
    '''Loads a saved state dict from disk into ``model`` in place.

    Args:
        model: Module whose state is overwritten via ``load_state_dict``.
        checkpoint_file: Path to a file written with ``torch.save(state_dict, ...)``
            (for example by ``save_checkpoint``).
        map_location: Forwarded to ``torch.load``; e.g. ``'cpu'`` lets a
            checkpoint saved from GPU tensors be loaded on a CPU-only machine.

    Returns:
        True if the state was loaded. False if ``checkpoint_file`` does not
        exist, in which case an error is printed and ``model`` is untouched.
    '''
    if not os.path.exists(checkpoint_file):
        print(f'Error. The path {checkpoint_file} does not exist. Could not retrieve model state.')
        return False
    # load_state_dict needs the deserialised mapping, not the file path.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(state_dict)
    return True

In [None]:
def save_predictions(loader, model, folder='predictions/', device='cuda'):
    '''Run ``model`` over ``loader`` and save every predicted mask as a colour PNG.

    Files are written to ``folder`` as ``0.png``, ``1.png``, ... in loader
    order, coloured with ``label_to_image``.

    Args:
        loader: Iterable of ``(feature, label)`` batches whose features can be
            passed to ``model`` directly; labels are ignored.
        model: Segmentation network with per-class scores along dim 1. It is
            switched to eval mode and left in eval mode.
        folder: Output directory; created if it does not already exist.
        device: Device the features are moved to before the forward pass.

    Returns:
        The number of images written.
    '''
    os.makedirs(folder, exist_ok=True)
    model.eval()
    count = 0
    for feature, _label in loader:
        with torch.no_grad():
            # Features arrive batched and model-ready, so take the argmax here
            # instead of predict(), which normalises a single raw image.
            # Assumes a (batch, classes, H, W) output, as predict() also does.
            predictions = model(feature.to(device)).argmax(dim=1)
        for prediction in predictions:
            # (H, W, 3) palette colours -> (3, H, W) floats in [0, 1] for save_image.
            rgb = label_to_image(prediction, prediction.device)
            image = rgb.permute(2, 0, 1).float().div(255)
            torchvision.utils.save_image(image, os.path.join(folder, f'{count}.png'))
            count += 1
    return count
            