In [1]:
%matplotlib notebook
%config InlineBackend.figure_format = 'retina'

# native
import sys
import os
from os import listdir
from collections import defaultdict
from PIL import Image
import pprint as pp
import functools
import pickle
import re

# math
import numpy as np
from sklearn.metrics import accuracy_score

# plotting
import matplotlib
from matplotlib import pyplot as plt

# extra
from tqdm import tqdm
import logging

# pytorch
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils import model_zoo

import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models


In [2]:
# Expected major version of each library this notebook was written against.
requirements = dict([
    (torch, '1'),
    (matplotlib, '3'),
])

def check_requirements(requirements):
    """Verify that each module's installed major version matches the expectation.

    Args:
        requirements: mapping of imported module -> expected major version
            (e.g. ``{torch: '1'}``).

    Raises:
        AssertionError: if a module's installed major version differs.
    """
    for requirement, expected_major in requirements.items():
        # Compare the whole leading version component, not just its first
        # character: '10.0'[0] == '1' would wrongly satisfy a '1' requirement.
        installed_major = requirement.__version__.split('.')[0]
        if installed_major != str(expected_major):
            error_message = '{} environment does not match requirement'.format(requirement.__name__)
            # Raise explicitly (same exception type as before) so the check
            # still runs when Python is started with -O.
            raise AssertionError(error_message)

# Fail fast when the installed libraries do not match the expected versions.
check_requirements(requirements=requirements)

In [3]:
# Choose the compute device once for the whole notebook.
cuda = torch.cuda.is_available()
device = 'cpu'

if cuda:
    device = 'cuda'
    # Let cuDNN benchmark and cache the fastest convolution algorithms.
    torch.backends.cudnn.benchmark = True
    # From here on, newly created float tensors live on the GPU by default.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

device

'cuda'

In [4]:
class PlotGrid:
    """Figure wrapper that (re)draws individual subplots in place.

    Subplots are addressed by a ``position_id`` tuple that is passed straight
    to ``Figure.add_subplot`` (e.g. ``(1, 2, 1)``). Per-subplot options such as
    axis limits, fill and grid are remembered between calls, so later updates
    only need to pass new data.
    """

    def __init__(self, figsize=None):
        self.fig = plt.figure(figsize=figsize)
        self.ax = {}       # position_id -> Axes
        self.xlim = {}     # position_id -> (left, right)
        self.ylim = {}     # position_id -> (bottom, top)
        self.filled = {}   # position_id -> bool: fill the area under line plots
        self.grid = {}     # position_id -> bool: draw grid lines

    def _get_axes(self, position_id):
        """Return the cached Axes for ``position_id``, creating it on first use."""
        if position_id in self.ax:
            return self.ax[position_id]
        return self.fig.add_subplot(*position_id)

    @staticmethod
    def _to_numpy(data):
        """Move a (possibly CUDA) tensor to host memory and convert it to numpy.

        Objects without ``is_cuda``/``numpy`` attributes are returned unchanged.
        """
        if hasattr(data, 'is_cuda') and data.is_cuda:
            data = data.cpu()
        if hasattr(data, 'numpy'):
            data = data.numpy()
        return data

    def _finish(self, position_id, ax, title):
        """Apply the title, redraw the figure and cache the Axes."""
        if title is not None:
            ax.set_title(title)
        self.fig.tight_layout()
        self.fig.canvas.draw()
        self.ax[position_id] = ax

    def plot(self, position_id, data, title=None, xlim=None, ylim=None, filled=None, grid=None):
        """Draw ``data`` into the subplot at ``position_id``, replacing its contents.

        Args:
            position_id: tuple passed to ``Figure.add_subplot``.
            data: a PIL image (shown with ``imshow``), or a 1-D sequence,
                numpy array, scalar or torch tensor (shown as a line plot).
            title: subplot title; ``None`` keeps the current title.
            xlim, ylim: axis limits remembered for this subplot (line plots only).
            filled: fill the area under the line (remembered; default True).
            grid: draw grid lines (remembered; default True).
        """
        ax = self._get_axes(position_id)

        # Keep the existing title when none is given.
        if title is None:
            title = ax.get_title()

        if xlim is not None:
            self.xlim[position_id] = xlim
        if ylim is not None:
            self.ylim[position_id] = ylim
        if filled is not None:
            self.filled[position_id] = filled
        self.filled.setdefault(position_id, True)
        if grid is not None:
            self.grid[position_id] = grid
        self.grid.setdefault(position_id, True)

        # Axes.clear() already resets everything Axes.cla() does.
        ax.clear()
        if isinstance(data, Image.Image):
            ax.imshow(data)
        else:
            # atleast_1d turns a scalar (e.g. a running mean) into a one-point
            # series so len() and fill_between work on it.
            data = np.atleast_1d(self._to_numpy(data))
            ax.plot(data)

            if self.filled[position_id]:
                ax.fill_between(range(len(data)), data)
            if self.grid[position_id]:
                ax.grid(True)
            if position_id in self.xlim:
                ax.set_xlim(*self.xlim[position_id])
            if position_id in self.ylim:
                ax.set_ylim(*self.ylim[position_id])

        self._finish(position_id, ax, title)

    def prediction_plot(self, position_id, data, title=None, grid=None):
        """Draw a horizontal bar chart of prediction scores.

        Args:
            position_id: tuple passed to ``Figure.add_subplot``.
            data: indexable whose element ``[1]`` holds the bar labels and
                element ``[2]`` the bar values (tensor, array or sequence).
                The first value is drawn at the top.
            title: subplot title; ``None`` keeps the current title.
            grid: draw grid lines (remembered; default True).
        """
        ax = self._get_axes(position_id)

        if title is None:
            title = ax.get_title()

        if grid is not None:
            self.grid[position_id] = grid
        self.grid.setdefault(position_id, True)

        ax.clear()
        plot_labels = data[1]
        plot_data = self._to_numpy(data[2])

        # Reverse the y positions so the first entry ends up at the top.
        ticks = range(len(plot_data) - 1, -1, -1)
        ax.barh(ticks, plot_data, align='center')

        if self.grid[position_id]:
            ax.grid(True)

        # Scores are presumably probabilities, hence the fixed [0, 1] range.
        ax.set_xlim(0, 1)
        ax.set_yticks(ticks)
        ax.set_yticklabels(plot_labels)

        self._finish(position_id, ax, title)

    def savefig(self, filename):
        """Save the figure to ``results/activation-plots/<filename>``, creating the directory."""
        figure_directory = os.path.join('results', 'activation-plots')
        os.makedirs(figure_directory, exist_ok=True)
        figure_path = os.path.join(figure_directory, filename)
        self.fig.savefig(figure_path, bbox_inches='tight')


In [5]:
def pathJoin(*args):
    """Join the path components and return the absolute, normalised result."""
    joined = os.path.join(*args)
    return os.path.abspath(joined)


def pprint(*args):
    """Pretty-print the arguments using the standard-library ``pprint`` module."""
    printer = pp.pprint
    printer(*args)


def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path such as ``'features.0.weight'``.

    Equivalent to applying ``getattr`` to each dot-separated component in
    turn. An optional default in ``args`` is forwarded to every ``getattr``
    call, so lookups continue on the default once a component is missing.
    """
    current = obj
    for name in attr.split('.'):
        current = getattr(current, name, *args)
    return current


def smooth(x, span=10):
    """Moving average of ``x`` over windows of ``span`` consecutive values.

    Returns ``len(x) - span + 1`` window means. When ``x`` is shorter than
    ``span``, the window shrinks to ``len(x)``, so a single overall mean is
    returned instead of an empty list; this keeps live plots populated early in
    training. An empty ``x`` yields ``[]``.

    Raises:
        ValueError: if ``span`` is smaller than 1.
    """
    if span < 1:
        raise ValueError('span must be at least 1, got {}'.format(span))
    if len(x) == 0:
        return []
    window = min(span, len(x))
    return [np.mean(x[i:i + window]) for i in range(len(x) - window + 1)]


# Shared helpers used by later cells.
softmax = torch.nn.Softmax(dim=1)        # normalises over dim 1 (classes for (batch, classes) logits)
toPILImage = transforms.ToPILImage()     # tensor image -> PIL image


In [6]:
class BaseDataset(Dataset):
    """Abstract dataset rooted at ``<directory>/<split>``.

    Subclasses implement ``loadDataset`` (returning parallel lists of
    datapoints and ground truths) and ``loadDatapoint`` (loading one item by
    index).
    """

    def __init__(self, directory, split='train', transforms=None):
        self.split = split
        self.directory = pathJoin(directory, split)
        # Set before loading so subclass loadDataset implementations may use it.
        self.transforms = transforms
        self.datapoints, self.groundtruths = self.loadDataset()

    def __len__(self):
        return len(self.datapoints)

    def __getitem__(self, idx):
        return self.loadDatapoint(idx)

    def loadDatapoint(self, idx):
        raise NotImplementedError('Function "loadDatapoint" is not implemented')

    def loadDataset(self, name=None):
        # __init__ calls this without arguments. ``name`` is optional and kept
        # only for backward compatibility, so a missing override raises
        # NotImplementedError rather than a TypeError about the argument.
        raise NotImplementedError('Function "loadDataset" is not implemented')


In [7]:
class TinyImageNetDataset(BaseDataset):
    """Tiny ImageNet ``train`` or ``val`` split read from ``<directory>/<split>``.

    Items are ``(filepath, image, wnid, description, class_index)`` tuples,
    where ``class_index`` is the position of ``wnid`` in ``wnids.txt``.
    ``words.txt``, ``wnids.txt`` and ``synsets.txt`` are read from the parent
    of the split directory.
    """

    def __init__(self, directory, split='train', transforms=None):
        super().__init__(directory, split, transforms)
        self.descriptions = self.loadDescriptions()
        self.classes = self.loadClasses()
        self.imagenet_classes = self.loadImageNetClasses()

    def loadDatapoint(self, idx):
        """Load item ``idx``; the image is read as RGB and passed through ``self.transforms``."""
        filepath = self.datapoints[idx]
        # The context manager releases the file handle once the pixels are converted.
        with Image.open(filepath) as raw_image:
            image = raw_image.convert('RGB')
        groundtruth = self.groundtruths[idx]
        if self.transforms:
            image = self.transforms(image)
        return (filepath, image, groundtruth, self.descriptions[groundtruth], self.classes.index(groundtruth))

    def loadDataset(self):
        """Collect image paths and their wnid labels for ``self.split``.

        Returns:
            ``(datapoints, groundtruths)``: parallel lists of absolute image
            paths and wnid strings.

        Raises:
            ValueError: for splits other than 'train' and 'val'.
        """
        datapoints = []
        groundtruths = []

        # Compare with == rather than `is`: identity of equal strings is an
        # interning accident and fails for strings built at runtime.
        if self.split == 'train':
            # One sub-directory per class, each holding an 'images' folder.
            for classname in tqdm(os.listdir(self.directory)):
                class_path = pathJoin(self.directory, classname, 'images')
                for filename in os.listdir(class_path):
                    datapoints.append(pathJoin(class_path, filename))
                    groundtruths.append(classname)
        elif self.split == 'val':
            dataset_file_list_path = os.path.join(self.directory, 'val_annotations.txt')
            # Count lines first only so tqdm can show a total; close the file afterwards.
            with open(dataset_file_list_path, 'r') as counting_file:
                line_count = sum(1 for _ in counting_file)
            with open(dataset_file_list_path, 'r') as dataset_file_list_file:
                for line in tqdm(dataset_file_list_file, total=line_count):
                    # Tab-separated: filename, wnid, then extra columns that are ignored.
                    filename, annotation, *_ = line.split('\t')
                    file_path = pathJoin(self.directory, 'images', self.sanitizeFilename(filename))
                    datapoints.append(file_path)
                    groundtruths.append(annotation)
        else:
            raise ValueError("Unsupported split {!r}; expected 'train' or 'val'".format(self.split))

        return datapoints, groundtruths

    def sanitizeFilename(self, filename):
        """Drop double quotes and surrounding whitespace from an annotation filename."""
        return filename.replace('"', '').strip()

    def _loadLines(self, filename):
        """Return the stripped lines of ``filename`` located next to the split directory."""
        path = pathJoin(self.directory, '..', filename)
        with open(path, 'r') as lines_file:
            return [line.strip() for line in lines_file]

    def loadDescriptions(self):
        """Map each wnid in ``words.txt`` to its tab-separated description."""
        descriptions = {}
        for line in self._loadLines('words.txt'):
            if not line:
                # Tolerate blank lines (e.g. a trailing newline) instead of raising IndexError.
                continue
            description_breakdown = line.split('\t')
            descriptions[description_breakdown[0]] = description_breakdown[1]
        return descriptions

    def loadClasses(self):
        """Return the wnids listed in ``wnids.txt``; list order defines the class indices."""
        return self._loadLines('wnids.txt')

    def loadImageNetClasses(self):
        """Return the entries of ``synsets.txt`` in file order."""
        return self._loadLines('synsets.txt')

    def imagenetidx2class(self, idx):
        """Return the ``synsets.txt`` entry at position ``idx``."""
        return self.imagenet_classes[idx]

    def imagenetidx2label(self, idx):
        """Return the description of the ``synsets.txt`` entry at position ``idx``."""
        return self.descriptions[self.imagenetidx2class(idx)]

    def idx2class(self, idx):
        """Return the wnid for class index ``idx`` (position in ``wnids.txt``)."""
        return self.classes[idx]

    def idx2label(self, idx):
        """Return the description for class index ``idx``."""
        return self.descriptions[self.idx2class(idx)]


In [8]:
class DeNormalize(object):
    """Invert ``transforms.Normalize`` by mapping each channel back via ``x * std + mean``.

    Source: https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/3
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        """Return a de-normalised copy of ``image``; the input tensor is not modified.

        Args:
            image (Tensor): normalised image whose first dimension indexes
                channels, e.g. (C, H, W).
        Returns:
            Tensor: copy with each channel multiplied by its ``std`` and
            shifted by its ``mean`` (the inverse of ``(x - mean) / std``).
        """
        restored = image.clone()
        channel_stats = zip(self.mean, self.std)
        for channel, (channel_mean, channel_std) in zip(restored, channel_stats):
            # ``channel`` is a view into ``restored``, so the in-place ops update the copy.
            channel.mul_(channel_std).add_(channel_mean)
        return restored


In [9]:
# Loader batch sizes for training and evaluation.
TRAIN_BATCH_SIZE, TEST_BATCH_SIZE = 64, 512

# ImageNet per-channel mean and standard deviation.
imagenet_normalization_values = dict(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Forward and inverse of the same per-channel normalisation.
normalize = transforms.Normalize(**imagenet_normalization_values)
denormalize = DeNormalize(**imagenet_normalization_values)


def toImage(tensor_image):
    """Undo the ImageNet normalisation of ``tensor_image`` and convert it to a PIL image."""
    restored = denormalize(tensor_image)
    return toPILImage(restored)

# Un-normalised 64x64 tensors, handy for inspecting raw images.
raw_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

# Training adds random horizontal and vertical flips as augmentation.
train_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Evaluation is deterministic: resize, convert and normalise only.
test_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    normalize,
])

tinyimagenet_dataset_path = os.path.join('datasets', 'tiny-imagenet-200')

# Only the 'train' and 'val' splits are loaded here. Swap in
# raw_transforms for the training split to look at un-augmented images.
tinyimagenet_train_dataset = TinyImageNetDataset(
    tinyimagenet_dataset_path, split='train', transforms=train_transforms)
tinyimagenet_val_dataset = TinyImageNetDataset(
    tinyimagenet_dataset_path, split='val', transforms=test_transforms)

# Shuffle the training data each epoch; keep validation order fixed.
tinyimagenet_train_loader = DataLoader(
    tinyimagenet_train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
tinyimagenet_val_loader = DataLoader(
    tinyimagenet_val_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False)


100%|██████████| 200/200 [00:03<00:00, 58.25it/s]
100%|██████████| 10000/10000 [00:00<00:00, 22479.13it/s]


In [10]:
# Wide, short figure used below to preview validation images one at a time.
image_grid = PlotGrid(figsize=(9, 2))

<IPython.core.display.Javascript object>

In [11]:
# Show the first nine validation images in turn in the same panel, titled
# with the human-readable class description (item element 3).
for index in range(min(9, len(tinyimagenet_val_dataset))):
    image = tinyimagenet_val_dataset[index]
    img = image[1]
    image_grid.plot((1, 1, 1), toImage(img), title=image[3])

# Keep one batch of validation images around for quick experiments.
for index, batch in enumerate(tinyimagenet_val_loader):
    img_batch = batch[1]
    break


In [12]:
# def imshow(img):
#     img = denormalize(img)
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()

# # get some random training images
# dataiter = iter(tinyimagenet_train_loader)
# _, images, _, labels, _ = dataiter.next()

# # show images
# imshow(torchvision.utils.make_grid(images))
# # print labels
# print(' '.join('%5s' % labels[j] for j in range(4)))

In [13]:
class CNN(torch.nn.Module):
    """Small baseline CNN for 3x64x64 inputs.

    Three conv(kernel 2) -> ReLU -> maxpool(2) stages shrink a 64x64 input to
    a 64x7x7 feature map (64 -> 63 -> 31 -> 30 -> 15 -> 14 -> 7). Three fully
    connected layers then map it to ``num_classes`` logits.
    """

    # Flattened size of the last feature map for 64x64 inputs (64 channels x 7 x 7).
    FLAT_FEATURES = 64 * 7 * 7

    def __init__(self, num_classes=200):
        super(CNN, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 2)
        self.conv2 = torch.nn.Conv2d(16, 32, 2)
        self.conv3 = torch.nn.Conv2d(32, 64, 2)
        # NOTE(review): conv4/conv5 are not used in forward(); they are kept so
        # existing checkpoints (state_dict keys) still load.
        self.conv4 = torch.nn.Conv2d(64, 64, 1)
        self.conv5 = torch.nn.Conv2d(64, 128, 2)

        self.pool = torch.nn.MaxPool2d(2, 2)
        self.relu = torch.nn.ReLU(inplace=True)

        self.fc1 = torch.nn.Linear(self.FLAT_FEATURES, 1024)
        self.fc2 = torch.nn.Linear(1024, 1024)
        self.fc3 = torch.nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        # Flatten per sample. view(-1, 3136) could silently fold spatial
        # positions into the batch dimension for other input sizes; keeping
        # the batch size makes such a mismatch fail loudly in fc1 instead.
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


class VGG_IN(torch.nn.Module):
    """VGG19 with an ``InstanceNorm2d`` inserted before ``features[layer_index]``.

    The torchvision VGG19 feature stack is split at ``layer_index``. The new
    classifier expects 2048 flattened features (512 channels x 2 x 2 for 64x64
    inputs) and outputs 200 logits.
    """

    def __init__(self, layer_index, pretrained=False):
        super(VGG_IN, self).__init__()
        vgg19 = models.vgg19(pretrained=pretrained)
        self.features1 = vgg19.features[:layer_index]
        # Normalise over the channels that actually leave features1, i.e. the
        # out_channels of its last convolution. features[layer_index] may be a
        # ReLU/pool without out_channels, or a conv whose out_channels differ
        # from its input channels.
        self.instance_normalization = torch.nn.InstanceNorm2d(self._output_channels(self.features1))
        self.features2 = vgg19.features[layer_index:]
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=2048, out_features=1024, bias=True),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(in_features=1024, out_features=512, bias=True),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(in_features=512, out_features=200, bias=True)
        )

    @staticmethod
    def _output_channels(layers):
        """Return ``out_channels`` of the last ``Conv2d`` in ``layers``."""
        convs = [module for module in layers if isinstance(module, torch.nn.Conv2d)]
        if not convs:
            raise ValueError('layer_index must leave at least one convolution before the instance norm')
        return convs[-1].out_channels

    def forward(self, x):
        x = self.features1(x)
        x = self.instance_normalization(x)
        x = self.features2(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)


class VGG_IN_BS(torch.nn.Module):
    """VGG19 whose intermediate features are re-styled with batch statistics.

    At ``layer_index`` each sample's feature map is normalised with its own
    per-channel mean/std over spatial positions. It is then re-scaled with the
    mean/std of the sample at the same position in a random permutation of the
    batch, before continuing through the remaining VGG19 layers.
    """

    def __init__(self, layer_index, pretrained=False, eps=1e-05):
        super(VGG_IN_BS, self).__init__()
        vgg19 = models.vgg19(pretrained=pretrained)
        self.features1 = vgg19.features[:layer_index]
        self.features2 = vgg19.features[layer_index:]
        # Expects 2048 flattened features (512 channels x 2 x 2 for 64x64 inputs).
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=2048, out_features=1024, bias=True),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(in_features=1024, out_features=512, bias=True),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(in_features=512, out_features=200, bias=True)
        )
        self.eps = eps  # added to the variance before the square root for stability

    def forward(self, x):
        x = self.features1(x)
        # Per-sample, per-channel statistics over all spatial positions, shaped (B, C, 1, 1).
        flat = x.view(x.size(0), x.size(1), -1)
        mean = flat.mean(2, keepdim=True).unsqueeze(-1)
        std = flat.std(2, keepdim=True).unsqueeze(-1)
        normalized = (x - mean) / torch.sqrt(std.pow(2) + self.eps)
        # Build the permutation on x's device so index_select never mixes
        # CPU indices with CUDA tensors.
        # NOTE(review): the shuffle also runs in eval mode, so validation is
        # stochastic; confirm that is intended.
        indices = torch.randperm(x.size(0), device=x.device)
        y = normalized * std.index_select(0, indices) + mean.index_select(0, indices)
        x = self.features2(y)
        x = x.view(x.size(0), -1)
        return self.classifier(x)



In [14]:
def init_weights(m):
    """Kaiming-normal init for ``Conv2d`` weights and zero init for their biases.

    Meant for ``module.apply``; modules of any other type are left untouched.
    """
    if not isinstance(m, torch.nn.Conv2d):
        return
    torch.nn.init.kaiming_normal_(m.weight.data)
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)


def create_vgg19_in_pretrained():
    """Build a pretrained VGG_IN split at layer 21 with both feature halves frozen.

    Only the freshly created classifier remains trainable.
    """
    vgg = VGG_IN(21, pretrained=True)
    for frozen_part in (vgg.features1, vgg.features2):
        for param in frozen_part.parameters():
            param.requires_grad = False
    return vgg


def create_vgg19_in_batch_stats_pretrained():
    """Build a pretrained VGG_IN_BS split at layer 21 with both feature halves frozen.

    Only the freshly created classifier remains trainable.
    """
    vgg = VGG_IN_BS(21, pretrained=True)
    for frozen_part in (vgg.features1, vgg.features2):
        for param in frozen_part.parameters():
            param.requires_grad = False
    return vgg


def create_vgg19_pretrained():
    """Pretrained VGG19 with a frozen backbone and a new, trainable 200-way classifier.

    The classifier expects 2048 = 512 x 2 x 2 inputs, the VGG19 feature-map
    size for 64x64 images.
    """
    # load model from pytorch
    vgg19 = models.vgg19(pretrained=True)
    # always load in eval mode
    vgg19.eval()

    # freeze the pretrained layers; the classifier created below stays trainable
    for param in vgg19.parameters():
        param.requires_grad = False

    # torchvision VGGs that define `avgpool` pool features to a fixed 7x7
    # (25088 values) before the classifier, which cannot feed in_features=2048.
    # Pool to 2x2 instead: a no-op for 64x64 inputs, and always 2048 values otherwise.
    if hasattr(vgg19, 'avgpool'):
        vgg19.avgpool = torch.nn.AdaptiveAvgPool2d((2, 2))

    vgg19.classifier = torch.nn.Sequential(
        torch.nn.Linear(in_features=2048, out_features=1024, bias=True),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(in_features=1024, out_features=512, bias=True),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(in_features=512, out_features=200, bias=True)
    )
    return vgg19


def create_vgg19_scratch():
    """Randomly initialised VGG19 with a 200-way classifier, all layers trainable.

    Conv layers are re-initialised with ``init_weights``. The classifier
    expects 2048 = 512 x 2 x 2 inputs, the VGG19 feature-map size for 64x64 images.
    """
    # load model from pytorch
    vgg19 = models.vgg19(pretrained=False)
    # always load in eval mode
    vgg19.eval()

    # torchvision VGGs that define `avgpool` pool features to a fixed 7x7
    # (25088 values) before the classifier, which cannot feed in_features=2048.
    # Pool to 2x2 instead: a no-op for 64x64 inputs, and always 2048 values otherwise.
    if hasattr(vgg19, 'avgpool'):
        vgg19.avgpool = torch.nn.AdaptiveAvgPool2d((2, 2))

    vgg19.classifier = torch.nn.Sequential(
        torch.nn.Linear(in_features=2048, out_features=1024, bias=True),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(in_features=1024, out_features=512, bias=True),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(in_features=512, out_features=200, bias=True)
    )

    vgg19.apply(init_weights)
    return vgg19


def create_vgg16_pretrained():
    """Pretrained VGG16 with a frozen backbone and a new, trainable 200-way classifier.

    The classifier expects 2048 = 512 x 2 x 2 inputs, the VGG16 feature-map
    size for 64x64 images.
    """
    # load model from pytorch
    vgg16 = models.vgg16(pretrained=True)
    # always load in eval mode
    vgg16.eval()

    # freeze the pretrained layers; the classifier created below stays trainable
    for param in vgg16.parameters():
        param.requires_grad = False

    # torchvision VGGs that define `avgpool` pool features to a fixed 7x7
    # (25088 values) before the classifier, which cannot feed in_features=2048.
    # Pool to 2x2 instead: a no-op for 64x64 inputs, and always 2048 values otherwise.
    if hasattr(vgg16, 'avgpool'):
        vgg16.avgpool = torch.nn.AdaptiveAvgPool2d((2, 2))

    vgg16.classifier = torch.nn.Sequential(
        torch.nn.Linear(in_features=2048, out_features=1024, bias=True),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(in_features=1024, out_features=512, bias=True),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(in_features=512, out_features=200, bias=True)
    )

    return vgg16


def create_vgg16_scratch():
    """Randomly initialised VGG16 with a 200-way classifier, all layers trainable.

    Conv layers are re-initialised with ``init_weights``. The classifier
    expects 2048 = 512 x 2 x 2 inputs, the VGG16 feature-map size for 64x64 images.
    """
    # load model from pytorch
    vgg16 = models.vgg16(pretrained=False)
    # always load in eval mode
    vgg16.eval()

    # torchvision VGGs that define `avgpool` pool features to a fixed 7x7
    # (25088 values) before the classifier, which cannot feed in_features=2048.
    # Pool to 2x2 instead: a no-op for 64x64 inputs, and always 2048 values otherwise.
    if hasattr(vgg16, 'avgpool'):
        vgg16.avgpool = torch.nn.AdaptiveAvgPool2d((2, 2))

    vgg16.classifier = torch.nn.Sequential(
        torch.nn.Linear(in_features=2048, out_features=1024, bias=True),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(in_features=1024, out_features=512, bias=True),
        torch.nn.ReLU(inplace=True),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(in_features=512, out_features=200, bias=True)
    )

    vgg16.apply(init_weights)
    return vgg16


def create_cnn():
    """Build a new, untrained baseline ``CNN``."""
    model = CNN()
    return model


# vanilla_vgg19_pretrained = create_vgg19('vanilla', 'all')
# vanilla_vgg19_scratch = create_vgg19('vanilla', 'none')

In [15]:
def validate(model, dataloader, monitor, logger):
    """Compute the top-1 accuracy of ``model`` on ``dataloader``.

    Batch element 1 is the model input, element 3 the true description, and
    element 4 the target class index (the item layout produced by
    ``TinyImageNetDataset``).

    Args:
        model: classifier whose output columns use the same class indexing as
            the targets.
        dataloader: loader whose dataset provides ``idx2label``.
        monitor: if True, live-plot the running accuracy and a sample prediction.
        logger: ``logging.Logger`` for progress messages.

    Returns:
        Mean accuracy over all samples, or NaN if the loader yields no batches.
    """
    logger.debug('Validation Start')
    model.eval()

    correct = []

    if monitor:
        validation_grid = PlotGrid(figsize=(9, 3))
        # Create the accuracy panel empty (avoids np.mean([]) warnings) so its
        # filled/ylim settings are cached for the updates below.
        validation_grid.plot((1, 2, 1), [], title='Validation Accuracy', filled=False, ylim=(0, 1))

    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for batch_index, batch in enumerate(dataloader):
            output = model(batch[1].to(device))
            target = batch[4].to(device)

            # view(-1) keeps a 1-D tensor even for a batch of one.
            predicted_class = output.argmax(dim=1).view(-1)
            correct.extend((target == predicted_class).cpu().numpy())

            if monitor:
                running_accuracy = np.mean(correct)
                # Predictions index dataset.classes, like the targets, so map
                # them with idx2label (not the ImageNet synset mapping).
                predicted_description = dataloader.dataset.idx2label(predicted_class[0].item())
                validation_grid.plot((1, 2, 1), running_accuracy, title='Validation Accuracy {0:.4f}'.format(running_accuracy))
                validation_grid.plot((1, 2, 2), toImage(batch[1][0]), title='True: {}\n Predicted: {}'.format(batch[3][0], predicted_description))

            if (batch_index + 1) % 10 == 0:
                logger.debug('Validation Batch {0}: Accuracy {1:.4f}'.format(batch_index + 1, np.mean(correct)))

    logger.debug('Validation End')
    return np.mean(correct) if correct else float('nan')

# validate(vanilla_vgg19_pretrained, tinyimagenet_val_loader, monitor=True)

In [16]:
def train(model, dataloader, criterion, optimizer, monitor, logger):
    """Run one training epoch over ``dataloader``.

    Batch element 1 is the model input and element 4 the target class index.

    Args:
        model: network to train (switched to train mode here).
        dataloader: training batches.
        criterion: loss function. The epoch loss below assumes it averages over
            the batch (CrossEntropyLoss's default reduction).
        optimizer: optimizer stepping ``model``'s parameters.
        monitor: if True, live-plot smoothed per-batch loss and accuracy.
        logger: ``logging.Logger`` for progress messages.

    Returns:
        ``(accuracy, loss)`` averaged over all samples of the epoch, or
        ``(nan, nan)`` if the loader yields no batches.
    """
    logger.debug('Training Start')
    model.train()

    if monitor:
        train_grid = PlotGrid(figsize=(9, 3))

    # Per-batch values, used for live plots and progress logging.
    loss = []
    accuracy = []
    # Sample-weighted running totals for the epoch averages.
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for batch_index, batch in enumerate(dataloader):
        optimizer.zero_grad()
        output = model(batch[1].to(device))
        target = batch[4].to(device)
        batch_size = target.size(0)

        # top-1 accuracy; view(-1) keeps a 1-D tensor for a batch of one
        predicted_class = output.argmax(dim=1).view(-1)
        batch_correct = (target == predicted_class).sum().item()
        accuracy.append(batch_correct / batch_size)

        # loss
        batch_loss = criterion(output, target)
        loss.append(batch_loss.item())

        # backprop
        batch_loss.backward()
        optimizer.step()

        total_loss += loss[-1] * batch_size
        total_correct += batch_correct
        total_samples += batch_size

        # plot
        if monitor:
            train_grid.plot((1, 2, 1), smooth(loss, 10), title='Train Loss {0:.4f}'.format(loss[-1]), filled=False)
            train_grid.plot((1, 2, 2), smooth(accuracy, 10), title='Train Accuracy {0:.4f}'.format(accuracy[-1]), filled=False, ylim=(0, 1))

        if (batch_index + 1) % 10 == 0:
            logger.debug('Training Batch {0}: Accuracy {1:.4f} Loss {2:.4f}'.format(batch_index + 1, accuracy[-1], loss[-1]))

    logger.debug('Training End')
    if total_samples == 0:
        return float('nan'), float('nan')
    return total_correct / total_samples, total_loss / total_samples

def run(run_name, model, training, number_of_epochs, monitor, logger,
        train_loader=None, val_loader=None, lr=0.01):
    """Train and validate ``model`` for ``number_of_epochs`` epochs, checkpointing the best one.

    Args:
        run_name: checkpoint filename stem (``models/<run_name>.ckpt``).
        model: network to optimise.
        training: if False, only validation runs each epoch and train metrics
            are logged as NaN.
        number_of_epochs: number of epochs to run.
        monitor: forwarded to ``train``/``validate`` to enable live plots.
        logger: ``logging.Logger`` for progress messages.
        train_loader, val_loader: data loaders. They default to the
            notebook-level ``tinyimagenet_train_loader`` and
            ``tinyimagenet_val_loader``.
        lr: initial SGD learning rate.
    """
    if train_loader is None:
        train_loader = tinyimagenet_train_loader
    if val_loader is None:
        val_loader = tinyimagenet_val_loader

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # The monitored value is validation accuracy, so a plateau means it stopped increasing.
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.2, patience=1, min_lr=1e-5, verbose=True)

    best_validation_accuracy = 0
    # Defined up front so the epoch log line and checkpoint also work when training=False.
    train_accuracy = train_loss = float('nan')

    for epoch in range(1, number_of_epochs + 1):
        if training:
            train_accuracy, train_loss = train(model, train_loader, criterion, optimizer, monitor, logger)
        validation_accuracy = validate(model, val_loader, monitor, logger)
        logger.info('Epoch {0}: Train Loss: {1:.4f} Train Accuracy: {2:.4f} Validation Accuracy: {3:.4f}'.format(epoch, train_loss, train_accuracy, validation_accuracy))

        lr_scheduler.step(validation_accuracy)

        if validation_accuracy > best_validation_accuracy:
            logger.debug('Improved Validation Score, saving new weights')
            model_directory = pathJoin('models')
            os.makedirs(model_directory, exist_ok=True)
            checkpoint = {
                'epoch': epoch,
                'train_accuracy': train_accuracy,
                'train_loss': train_loss,
                'validation_accuracy': validation_accuracy,
                'weights': model.state_dict()
            }
            torch.save(checkpoint, pathJoin(model_directory, '{}.ckpt'.format(run_name)))
            best_validation_accuracy = validation_accuracy


In [17]:
def create_logger(log_directory, filename, stream=False):
    """Configure the root logger to write ``<filename>_info.log`` and ``<filename>_debug.log``.

    Handlers already attached to the root logger are removed *and closed*
    first, so repeated calls (one per run) do not leak open log files.

    Args:
        log_directory: existing directory that receives the log files.
        filename: prefix of the two log files.
        stream: if True, also echo DEBUG-and-above records to stdout.

    Returns:
        The configured root ``logging.Logger``.
    """
    formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s %(message)s')

    logger = logging.getLogger()
    for hdlr in logger.handlers[:]:
        logger.removeHandler(hdlr)
        # Release the handler's resources (e.g. the open file of a FileHandler).
        hdlr.close()

    info_filehandler = logging.FileHandler(os.path.join(log_directory, '{}_info.log'.format(filename)))
    debug_filehandler = logging.FileHandler(os.path.join(log_directory, '{}_debug.log'.format(filename)))

    info_filehandler.setFormatter(formatter)
    debug_filehandler.setFormatter(formatter)

    info_filehandler.setLevel(logging.INFO)
    debug_filehandler.setLevel(logging.DEBUG)

    if stream:
        streamhandler = logging.StreamHandler(sys.stdout)
        streamhandler.setFormatter(formatter)
        streamhandler.setLevel(logging.DEBUG)
        logger.addHandler(streamhandler)

    logger.addHandler(info_filehandler)
    logger.addHandler(debug_filehandler)

    logger.setLevel(logging.DEBUG)

    # Silence PIL's verbose PNG plugin messages on the root-logger hierarchy.
    logging.getLogger('PIL.PngImagePlugin').setLevel(logging.ERROR)

    return logger


In [None]:
# Architecture name -> factory; every entry is trained in turn below.
supported_models = {
    'vgg19_in_pretrained': create_vgg19_in_pretrained,
    'vgg19_in_batch_stats_pretrained': create_vgg19_in_batch_stats_pretrained,
    'vgg19_pretrained': create_vgg19_pretrained,
    'vgg16_pretrained': create_vgg16_pretrained,
    'vgg19_scratch': create_vgg19_scratch,
    'vgg16_scratch': create_vgg16_scratch,
    'cnn': create_cnn
}

training = True
epochs = 100
monitor = False

# A single log directory is shared by all runs; create it once.
log_directory = pathJoin('run_logs')
os.makedirs(log_directory, exist_ok=True)

for model_type, create_model in supported_models.items():
    run_name = model_type
    # Each run gets its own pair of log files; create_logger replaces the previous handlers.
    logger = create_logger(log_directory, run_name)
    logger.info('Run Name {}'.format(run_name))
    model = create_model()
    run(run_name, model, training, epochs, monitor, logger)


Epoch    18: reducing learning rate of group 0 to 2.0000e-03.
Epoch    27: reducing learning rate of group 0 to 4.0000e-04.
Epoch    29: reducing learning rate of group 0 to 8.0000e-05.
Epoch    37: reducing learning rate of group 0 to 1.6000e-05.
Epoch    39: reducing learning rate of group 0 to 1.0000e-05.
Epoch     7: reducing learning rate of group 0 to 2.0000e-03.
Epoch    15: reducing learning rate of group 0 to 4.0000e-04.
Epoch    21: reducing learning rate of group 0 to 8.0000e-05.
Epoch    23: reducing learning rate of group 0 to 1.6000e-05.
Epoch    26: reducing learning rate of group 0 to 1.0000e-05.


In [None]:
# vgg16 = models.vgg16(pretrained=False)
# vgg16.features = vgg16.features[:-7]

# pprint(img.shape)
# vgg16.features(img.unsqueeze(0).cuda()).shape