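## Setup

Fine-tune a Places365-pretrained WideResNet-18 as a binary (0/1) image classifier: the pretrained backbone is frozen and only a new classification head is left trainable.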

In [107]:
import torch
from torch.autograd import Variable as V
import torchvision.models as models
from torchvision import transforms as trn
from torch.nn import functional as F
import os
import numpy as np
from scipy.misc import imresize as imresize
from PIL import Image

# Custom imports
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn

import sys
sys.path.append("../../Places365")
from run_placesCNN_unified import load_labels, load_model, returnCAM

## Data Parameters

In [108]:
# Data location: one sub-folder per split, each holding one sub-folder per class
# (the layout datasets.ImageFolder reads further down).
datadir = '../data/images/'
traindir = datadir + 'train/'
validdir = datadir + 'val/'
testdir = datadir + 'test/'

# Batch size
batch_size = 128

# Seed torch's global RNG so DataLoader shuffling (and, in recent torchvision
# versions, the random transforms) is reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)

# GPU Settings
train_on_gpu = torch.cuda.is_available()
print(f'Train on gpu: {train_on_gpu}')
if train_on_gpu:
    gpu_count = torch.cuda.device_count()
    print(f'{gpu_count} gpus detected.')

# Images per set. Class folders are listed from disk rather than hard-coded,
# so every class sub-folder ImageFolder will see is reported.
for group in ['train', 'val', 'test']:
    group_dir = os.path.join(datadir, group)
    for label in sorted(os.listdir(group_dir)):
        path = os.path.join(group_dir, label)
        if os.path.isdir(path):
            print(group, label, '-', len(os.listdir(path)))

Train on gpu: False
train 0 - 4392
train 1 - 5970
val 0 - 1541
val 1 - 1913
test 0 - 1454
test 1 - 2000


## Image Preprocessing

In [109]:
# Image transformations

# Per-channel normalization statistics (the standard ImageNet values).
# NOTE(review): presumably what the Places365 weights were trained with --
# confirm against the transform used in run_placesCNN_unified.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]


def _eval_transform():
    """Deterministic preprocessing: resize to 256, center-crop 224, normalize."""
    return trn.Compose([
        trn.Resize(size=256),
        trn.CenterCrop(size=224),
        trn.ToTensor(),
        trn.Normalize(norm_mean, norm_std),
    ])


image_transforms = {
    # Train uses data augmentation. ColorJitter() with no arguments is a no-op
    # (every factor defaults to 0), so give it explicit, modest strengths.
    # TODO: tune the augmentation strengths.
    'train': trn.Compose([
        trn.Resize(size=256),
        trn.RandomRotation(degrees=15),
        trn.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        trn.RandomHorizontalFlip(),
        trn.CenterCrop(size=224),
        trn.ToTensor(),
        trn.Normalize(norm_mean, norm_std),
    ]),

    # Validation and test do not use augmentation; they share one pipeline.
    'val': _eval_transform(),
    'test': _eval_transform(),
}

## Data Iterators

In [114]:
# One ImageFolder dataset per split; labels come from the class sub-folder names.
split_dirs = {'train': traindir, 'val': validdir, 'test': testdir}
data = {
    split: datasets.ImageFolder(root=root, transform=image_transforms[split])
    for split, root in split_dirs.items()
}

# Dataloader iterators. Only the training set needs shuffling; a fixed order for
# val/test keeps evaluation reproducible and lets batch predictions be matched
# back to data[split].samples by position.
dataloaders = {
    split: DataLoader(dataset, batch_size=batch_size, shuffle=(split == 'train'))
    for split, dataset in data.items()
}

In [115]:
# Sanity check: draw one batch from the training loader and report the tensor
# shapes, plus how many classes ImageFolder discovered in the training folder.
trainiter = iter(dataloaders['train'])
features, labels = next(trainiter)
print(*(tensor.shape for tensor in (features, labels)), 'batch dimensions')
print(f"{len(data['train'].classes)} classes")
print(len(data['train'].classes), 'classes')

torch.Size([128, 3, 224, 224]) torch.Size([128]) batch dimensions
2 classes


In [116]:
# Set update criteria and optimizer - TODO: Tune which are used
criterion = torch.nn.CrossEntropyLoss() # TODO: is 'cuda' right here?
optimizer = torch.optim.Adam(model.parameters())

## Create the model

In [117]:
def load_model():
    
    # Load the wideresnet model
    import wideresnet
    model = wideresnet.resnet18(num_classes=365)
    
    # Load in the pretrained weights
    model_file = 'wideresnet18_places365.pth.tar'
    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
    state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)
    
    # Add hooks for attributes and CAM
    features_names = ['layer4','avgpool']
    for name in features_names:
        model._modules.get(name).register_forward_hook(hook_feature)

    # Freeze model weights
    for param in model.parameters():
        param.requires_grad = False

    # Replace the final layer with a FC mapping to binary classes
    n_inputs = model.fc.in_features
    model.fc = torch.nn.Sequential(
        torch.nn.Linear(n_inputs, 256), torch.nn.ReLU(), torch.nn.Dropout(0.2),
        torch.nn.Linear(256, 2), torch.nn.LogSoftmax(dim=1))
    
    return model

model = load_model()

In [118]:
# Compare the full network size with the number of parameters that will
# actually be trained (those with requires_grad=True).
total_params = sum(map(torch.Tensor.numel, model.parameters()))
total_trainable_params = sum(
    t.numel() for t in filter(lambda t: t.requires_grad, model.parameters())
)
print('{:,} total parameters.'.format(total_params))
print('{:,} training parameters.'.format(total_trainable_params))

11,308,354 total parameters.
131,842 training parameters.
