In [6]:
import torch
from torch.utils.data import DataLoader
import torchvision
import os
import nbimporter
from dataset import VocDataset

In [11]:
def get_data_loaders(data_dir, train_transforms, validation_transforms, test_transforms, batch_size):
    """Build the train / validation / test DataLoaders over ``VocDataset``.

    Args:
        data_dir: root directory handed to every ``VocDataset``.
        train_transforms: transforms for the ``'train'`` split.
        validation_transforms: transforms for the ``'val'`` split.
        test_transforms: transforms for the ``'test'`` split.
        batch_size: batch size shared by all three loaders.

    Returns:
        ``(train_loader, validation_loader, test_loader)``. Only the training
        loader shuffles. Train and validation drop the trailing partial batch;
        test keeps it. All loaders load in the main process (``num_workers=0``).
    """
    def _make_loader(transforms, split, shuffle, drop_last):
        # Single place for the settings every split shares.
        return DataLoader(
            VocDataset(data_dir, transforms, split),
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=0,
        )

    # NOTE(review): drop_last=True on the validation split skips up to
    # batch_size - 1 samples during evaluation; confirm this is intended.
    return (
        _make_loader(train_transforms, 'train', shuffle=True, drop_last=True),
        _make_loader(validation_transforms, 'val', shuffle=False, drop_last=True),
        _make_loader(test_transforms, 'test', shuffle=False, drop_last=False),
    )

In [None]:
# Pascal VOC palette: one (display name, RGB colour) row per class id, kept in a
# single table so a class name and its colour cannot drift out of alignment.
_VOC_PALETTE = (
    ('background', (0, 0, 0)),
    ('aeroplane', (128, 0, 0)),
    ('bicycle', (0, 128, 0)),
    ('bird', (128, 128, 0)),
    ('boat', (0, 0, 128)),
    ('bottle', (128, 0, 128)),
    ('bus', (0, 128, 128)),
    ('car', (128, 128, 128)),
    ('cat', (64, 0, 0)),
    ('chair', (192, 0, 0)),
    ('cow', (64, 128, 0)),
    ('diningtable', (192, 128, 0)),
    ('dog', (64, 0, 128)),
    ('horse', (192, 0, 128)),
    ('motorbike', (64, 128, 128)),
    ('person', (192, 128, 128)),
    ('potted plant', (0, 64, 0)),
    ('sheep', (128, 64, 0)),
    ('sofa', (0, 192, 0)),
    ('train', (128, 192, 0)),
    ('tv/monitor', (0, 64, 128)),
)

# RGB colour for each class id, as [R, G, B] lists.
VOC_COLORMAP = [list(rgb) for _, rgb in _VOC_PALETTE]

# Human-readable class name for each class id (index-aligned with VOC_COLORMAP).
VOC_CLASSES = [name for name, _ in _VOC_PALETTE]


def predict(test_iterator, model, device, image, verbose=False):
    """Run ``model`` on ``image`` and return the predicted class index per element.

    Args:
        test_iterator: DataLoader for the test set. Only its
            ``dataset.normalize`` callable is used, so ``image`` receives the
            same normalization as the test data.
        model: callable whose output is reduced with ``argmax`` over dim 1
            (the class-score axis).
        device: device the normalized input is moved to before the forward pass.
        image: input passed to ``dataset.normalize``. NOTE(review): no batch
            axis is added here, so this presumably must already be batched,
            e.g. (N, C, H, W). Confirm against callers.
        verbose: if True, print intermediate shapes (debugging aid).

    Returns:
        ``argmax(dim=1)`` of the model output with an extra leading axis added
        via ``unsqueeze(0)``.
    """
    X = test_iterator.dataset.normalize(image)
    if verbose:
        print(f'X.shape={X.shape}')
    with torch.no_grad():  # inference only; no autograd graph needed
        prediction = model(X.to(device)).argmax(dim=1).unsqueeze(0)
    if verbose:
        print(f'prediction.unsqueeze(0).shape={prediction.shape}')
    return prediction

def prediction_to_image(prediction, device, colormap=None):
    """Map a tensor of class indices to RGB colours.

    Args:
        prediction: tensor of class indices, any shape. Values are cast to
            ``long`` and used to index the colour table.
        device: device on which the colour lookup table is created.
            NOTE(review): presumably this should be the device ``prediction``
            lives on. Confirm.
        colormap: optional sequence of ``[R, G, B]`` rows indexed by class id.
            Defaults to the module-level ``VOC_COLORMAP``.

    Returns:
        Tensor of shape ``prediction.shape + (3,)`` holding the colour of each
        index. Its dtype is whatever ``torch.tensor`` infers from the table.
    """
    if colormap is None:
        # Reuse the single module-level table instead of keeping a private copy.
        colormap = VOC_COLORMAP
    color_map = torch.tensor(colormap, device=device)
    # Advanced indexing with a long tensor gathers one colour row per element.
    return color_map[prediction.long()]

In [None]:
import torch

def prediction_accuracy(prediction, label):
    """Count how many predicted classes match ``label``.

    If ``prediction`` has exactly one more dimension than ``label``, it is
    treated as per-class scores with the class axis at dim 1, for example
    ``(N, C)`` vs ``(N,)`` or ``(N, C, H, W)`` vs ``(N, H, W)``, and is reduced
    with ``argmax``. Otherwise it is assumed to already hold class indices and
    is compared directly. A dense ``(N, H, W)`` class map is therefore never
    argmax-ed over its height axis.

    Returns:
        Number of matching elements as a Python float (a count, not a ratio).
    """
    if prediction.dim() == label.dim() + 1:
        prediction = prediction.argmax(dim=1)
    matches = prediction.to(label.dtype) == label
    # Sum the boolean mask directly (integer accumulation) so large pixel
    # counts stay exact regardless of the label dtype.
    return float(matches.sum())

def accuracy(prediction, label):
    """Return the number of positions where ``prediction`` equals ``label``.

    The arguments are compared element-wise with ``==`` (broadcasting
    applies), so they are expected to hold class indices of matching shape.

    Returns:
        int: count of matching elements.
    """
    return (prediction == label).sum().item()

# Quick sanity check of `accuracy`: these tensors agree at positions 0 and 2.
a, b = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 3, 5])
accuracy(a, b)

2


In [10]:
def save_checkpoint(state_dict, checkpoint_file='model/model.pt'):
    '''Saves the model state to disc.

    Args:
        state_dict: object passed to ``torch.save``, typically
            ``model.state_dict()``.
        checkpoint_file: destination path, or a file-like object accepted by
            ``torch.save``. For paths, missing parent directories are created.
    '''
    if isinstance(checkpoint_file, (str, os.PathLike)):
        parent = os.path.dirname(checkpoint_file)
        if parent:
            # torch.save writes straight to the path, so the directory (e.g. the
            # default 'model/') has to exist first.
            os.makedirs(parent, exist_ok=True)
    torch.save(state_dict, checkpoint_file)

def load_checkpoint(model, checkpoint_file='model/model.pt', map_location=None):
    '''Loads the model state from disc into ``model``.

    Args:
        model: module whose ``load_state_dict`` receives the deserialized state.
        checkpoint_file: path (or file-like object) written by ``torch.save``,
            e.g. via ``save_checkpoint``.
        map_location: forwarded to ``torch.load``. For example, ``'cpu'`` loads
            a checkpoint saved on GPU on a CPU-only machine.

    Raises:
        FileNotFoundError: if ``checkpoint_file`` is a path that does not exist.
    '''
    if isinstance(checkpoint_file, (str, os.PathLike)) and not os.path.exists(checkpoint_file):
        raise FileNotFoundError(
            f'Error. The path {checkpoint_file} does not exist. Could not retrieve model state.')
    # load_state_dict expects the deserialized dict, not a path.
    # torch.load unpickles the file, so only load checkpoints from trusted sources.
    state_dict = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(state_dict)

In [None]:
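# # TODO: complete `save_predictions` so it writes predicted segmentation images to the predictions/ directory.
# # NOTE: the sketch below does not match the current helpers: `predict` takes (test_iterator, model, device, image),
# # colour conversion is done by `prediction_to_image`, and `label_to_image` is not defined in this section.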

# def save_predictions(loader, model, folder='predictions/', device='cuda'):
#     for index, (feature, label) in enumerate(loader):
#         feature = feature.to(device)
#         with torch.no_grad():
#             predictions = label_to_image(predict(model(feature)))
            