In [1]:
%matplotlib inline

In [2]:
from __future__ import print_function 
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from tqdm import tqdm

In [3]:
from data import AVA

# Architecture name handed to initialize_model. Candidate names listed by the
# original author: [resnet, alexnet, vgg, squeezenet, densenet, inception].
# NOTE(review): initialize_model in this notebook builds VGG-16 regardless of this value.
model_name = "vgg"

# Number of classes in the dataset (width of the model's replaced output layer)
num_classes = 10

# Batch size for training (change depending on how much memory you have)
batch_size = 64

# Number of epochs to train for
num_epochs = 10

# NOTE(review): a feature-extraction flag used to be documented here
#   (False = finetune the whole model, True = only update the reshaped layer
#   params), but no such variable is defined in this cell.

# Checkpoint file: loaded before training if it exists, rewritten while training
weight_path = 'weights/vgg16_generic.pt'

In [4]:
from scipy import stats
def train_model(model, dataloaders, criterion, optimizers, num_epochs=25, phases=('train',)):
    """Train ``model`` and return it restored to its lowest-loss epoch.

    Relies on two notebook-level globals: ``device`` (every batch is moved
    there) and ``weight_path`` (a checkpoint is written after each training
    phase).

    Args:
        model: network whose output for a batch can be multiplied by a
            10-element score vector, i.e. shape ``(batch, 10)``.
        dataloaders: mapping from phase name to an iterable of
            ``(inputs, labels)`` batches; each value must expose ``.dataset``
            with a length. ``labels`` must have the same ``(batch, 10)`` layout
            as the model output.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        optimizers: iterable of optimizers; all of them are zeroed and stepped
            on every training batch.
        num_epochs: number of epochs to run.
        phases: phase names to run in each epoch, in order. ``'train'`` runs in
            training mode and updates the weights; any other phase runs in
            eval mode with gradients disabled. Defaults to training only.

    Returns:
        ``(model, val_loss_history)``: the model loaded with the weights of
        the epoch that had the lowest loss on the selection phase (``'val'``
        when it is run, otherwise the last entry of ``phases``), and the list
        of per-epoch ``'val'`` losses (empty when ``'val'`` is not run).

    Raises:
        ValueError: if ``phases`` is empty.
    """
    import sys  # local import: only needed to flush stdout ahead of tqdm's bar

    phases = tuple(phases)
    if not phases:
        raise ValueError("phases must contain at least one phase name")
    # Prefer held-out loss for model selection; fall back to what was run.
    selection_phase = 'val' if 'val' in phases else phases[-1]

    since = time.time()

    val_loss_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    # Score values 1..10; the dot product with a distribution is its mean score.
    # Loop-invariant, so build it once instead of once per batch.
    weight_votes = torch.arange(10, dtype=torch.float, device=device) + 1

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in phases:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0

            true_means = []
            pred_means = []

            # Flush pending prints so they appear before the progress bar.
            sys.stdout.flush()
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                for opt in optimizers:
                    opt.zero_grad()

                # forward; track history only in the training phase
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        for opt in optimizers:
                            opt.step()

                # Mean scores are only used for the correlation report, so
                # compute them outside the autograd graph.
                true_means.extend(torch.matmul(labels, weight_votes).tolist())
                pred_means.extend(torch.matmul(outputs.detach(), weight_votes).tolist())

                # statistics
                running_loss += loss.item() * inputs.size(0)

            rho, pval = stats.pearsonr(true_means, pred_means)
            epoch_loss = running_loss / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f}, (rho, pval): ({:.3f},{:.3f})'.format(phase, epoch_loss, rho, pval))

            if phase == 'val':
                val_loss_history.append(epoch_loss)

            if is_train:
                # Checkpoint every epoch so an interrupted run can be resumed.
                torch.save(model.state_dict(), weight_path)

            # Snapshot the weights whenever the selection loss improves, so the
            # load_state_dict below restores the best epoch instead of the
            # weights the model had before training started.
            if phase == selection_phase and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_loss_history

In [5]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    Does nothing when the flag is falsy. Parameters are modified in place and
    nothing is returned.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

In [7]:
def initialize_model(model_name, num_classes, use_pretrained=True):
    """Build a VGG-16 whose classifier ends in a ``num_classes``-way softmax.

    Args:
        model_name: accepted for interface compatibility but currently unused;
            a VGG-16 is always built.
        num_classes: output width of the replacement final linear layer.
        use_pretrained: when true, start from torchvision's pretrained
            weights; otherwise from random initialisation.

    Returns:
        ``(model_ft, input_size)``: the model and the input side length (224)
        that the notebook's transforms crop images to.
    """
    # Honour the caller's flag instead of always loading pretrained weights.
    model_ft = models.vgg16(pretrained=use_pretrained)

    # Swap the final fully connected layer for one with the requested width.
    num_ftrs = model_ft.classifier[6].in_features
    model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)

    # Softmax over the class dimension so the output is a distribution.
    model_ft.classifier.add_module("softmax", nn.Softmax(dim=1))
    input_size = 224

    return model_ft, input_size

# Build the network for this run, along with the crop size the transforms use
model_ft, input_size = initialize_model(
    model_name=model_name,
    num_classes=num_classes,
    use_pretrained=True,
)

In [8]:
if os.path.exists(weight_path):
    # The model is still on the CPU at this point, so map the checkpoint there.
    # Without map_location, a checkpoint saved from a CUDA model cannot be
    # loaded on a machine that has no GPU.
    model_ft.load_state_dict(torch.load(weight_path, map_location="cpu"))
    print("Loaded saved weights")
else:
    print("Starting from scratch")

Loaded saved weights


In [9]:
# Per-channel mean / std shared by the training and validation pipelines
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Training: resize, then random crop and random horizontal flip as augmentation
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(input_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Validation: deterministic resize + centre crop, no augmentation
val_transform = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

data_transforms = {'train': train_transform, 'val': val_transform}

print("Initializing Datasets and Dataloaders...")

# Image-list file backing each split
list_path_template = 'data/AVA_dataset/aesthetics_image_lists/{}'
list_file_by_split = {'train': 'generic_ls_train.jpgl', 'val': 'generic_test.jpgl'}

# Build both datasets first ...
image_datasets = {}
for split in ['train', 'val']:
    list_path = list_path_template.format(list_file_by_split[split])
    image_datasets[split] = AVA(list_path, data_transforms[split])

# ... then wrap each one in a shuffling loader
dataloaders_dict = {}
for split in ['train', 'val']:
    dataloaders_dict[split] = torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=batch_size,
        shuffle=True,
        num_workers=3,
    )

# Use the first GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Initializing Datasets and Dataloaders...


In [10]:
# Send the model to the selected device (GPU when available)
model_ft = model_ft.to(device)

# Split the parameters into two groups so each can get its own optimizer:
#   last_layer  - weight AND bias of the replaced output layer, classifier[6]
#   base_params - every other parameter of the network
# Taking only the final entry of named_parameters() would capture just the
# bias and leave the output layer's weight matrix in the base group.
last_layer = list(model_ft.classifier[6].parameters())
last_layer_ids = {id(param) for param in last_layer}
base_params = [param for param in model_ft.parameters()
               if id(param) not in last_layer_ids]

In [11]:
# One Adam optimizer per parameter group; together they cover every parameter
base_opt, last_opt = (optim.Adam(group, lr=1e-4) for group in (base_params, last_layer))

In [12]:
# Setup the loss fxn
def EMD(input, target):
    """Batch-averaged distance between cumulative distributions (EMD-style loss).

    For every row, the running (cumulative) sums of ``input`` and ``target``
    are compared bin by bin. The squared differences are averaged over the
    bins and square-rooted, and the per-row values are averaged over the batch.

    Args:
        input: tensor of shape ``(batch, N)``.
        target: tensor of the same shape as ``input``.

    Returns:
        A 0-dim tensor holding the mean loss over the batch. It lives on the
        same device as the arguments; no global ``device`` is consulted.
    """
    # cumsum yields all N prefix sums in one pass, replacing a Python loop
    # that re-summed a growing slice for every bin.
    cdf_gap = torch.cumsum(input, dim=1) - torch.cumsum(target, dim=1)

    per_sample = cdf_gap.pow(2).mean(dim=1).sqrt()

    return per_sample.mean()


criterion = EMD

In [None]:
# Run training; `hist` receives the validation-loss history returned by train_model
model_ft, hist = train_model(
    model=model_ft,
    dataloaders=dataloaders_dict,
    criterion=criterion,
    optimizers=[base_opt, last_opt],
    num_epochs=num_epochs,
)

Epoch 0/9
----------


100%|██████████| 312/312 [12:26<00:00,  1.99s/it]


train Loss: 0.0860, (rho, pval): (0.532,0.000)

Epoch 1/9
----------


100%|██████████| 312/312 [12:23<00:00,  1.99s/it]


train Loss: 0.0753, (rho, pval): (0.585,0.000)

Epoch 2/9
----------


100%|██████████| 312/312 [12:24<00:00,  1.99s/it]


train Loss: 0.0748, (rho, pval): (0.588,0.000)

Epoch 3/9
----------


100%|██████████| 312/312 [12:26<00:00,  1.99s/it]


train Loss: 0.0733, (rho, pval): (0.607,0.000)

Epoch 4/9
----------


 92%|█████████▏| 287/312 [11:26<00:59,  2.39s/it]

# Linear Correlation

In [None]:
# Collect the true and predicted mean score of every validation image.

# Switch to inference mode first: training may leave the network in train
# mode, and VGG's classifier contains dropout layers that would otherwise
# randomly perturb these predictions.
model_ft.eval()

# Score values 1..10 on the CPU; labels come from the loader on the CPU and
# predictions are moved there before the dot product.
score_values = (torch.arange(10) + 1).float()

true_means = []
pred_means = []
with torch.no_grad():
    for inputs, labels in tqdm(dataloaders_dict['val']):
        preds = model_ft(inputs.to(device)).cpu()

        true_means.extend(torch.matmul(labels, score_values).tolist())
        pred_means.extend(torch.matmul(preds, score_values).tolist())

In [None]:
# Pearson correlation between true and predicted mean scores, shown as a scatter plot
rho, pval = stats.pearsonr(true_means, pred_means)

fig, ax = plt.subplots()
ax.scatter(true_means, pred_means)
ax.set_title("Linear Correlation: {:.3f}".format(rho))
# Label the axes so the figure can be read on its own
ax.set_xlabel("True mean score")
ax.set_ylabel("Predicted mean score")
plt.show()