In [17]:
import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from tqdm import tqdm

# Data

In [18]:
from data import AVA

# Mini-batch size shared by both loaders (lower it if memory runs out)
batch_size = 64

# Per-split preprocessing.
#   train: resize, random 224 crop and random horizontal flip (augmentation),
#          then tensor conversion and per-channel normalization.
#   val:   resize and centre 224 crop (no randomness), then the same
#          tensor conversion and normalization.
data_transforms = dict(
    train=transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)

# One AVA dataset per split, each built from its own image-list file
image_datasets = {
    'train': AVA(
        'data/AVA_dataset/aesthetics_image_lists/{}'.format('generic_ls_train.jpgl'),
        data_transforms['train'],
    ),
    'val': AVA(
        'data/AVA_dataset/aesthetics_image_lists/{}'.format('generic_test.jpgl'),
        data_transforms['val'],
    ),
}

# One loader per split; note that both splits are shuffled
dataloaders_dict = {
    'train': torch.utils.data.DataLoader(
        image_datasets['train'], batch_size=batch_size, shuffle=True, num_workers=4
    ),
    'val': torch.utils.data.DataLoader(
        image_datasets['val'], batch_size=batch_size, shuffle=True, num_workers=4
    ),
}

# Use the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model

In [19]:
# Number of classes in the dataset; this is the output width of the
# replacement classifier layer built in the model cell
num_classes = 10

# Checkpoint file: loaded before training if it exists, and written by
# train_model during training
weight_path = 'weights/dense121_generic.pt'

In [21]:
# DenseNet-121 backbone with pretrained weights
model_ft = models.densenet121(pretrained=True)

# Swap the stock classifier for a num_classes-wide linear layer followed by a
# softmax over dim 1, so each row of the output sums to 1
num_ftrs = model_ft.classifier.in_features
model_ft.classifier = nn.Sequential(nn.Linear(num_ftrs, num_classes), nn.Softmax(dim=1))

In [22]:
# Move the model's parameters and buffers to the selected device
# (`device` is "cuda" when available, otherwise "cpu")
model_ft = model_ft.to(device)

# Train

In [23]:
# Number of passes over the training set; passed to train_model below
num_epochs = 5

In [24]:
from scipy import stats
def train_model(model, dataloaders, criterion, optimizers, num_epochs=25, phases=('train',)):
    """Train ``model`` and return it loaded with the best weights observed.

    For every epoch and every phase the function reports the dataset-averaged
    loss and the Pearson correlation between the true and predicted mean
    scores, where a mean score is the dot product of a distribution with the
    weights 1..10.

    Relies on names defined elsewhere in the notebook: ``device`` (where
    batches are moved), ``weight_path`` (checkpoint written after each
    training phase), ``tqdm`` and ``scipy.stats``.

    Args:
        model: network to optimise; its output is passed to ``criterion``
            together with the labels.
        dataloaders: mapping from phase name to a loader yielding
            ``(inputs, labels)`` batches.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        optimizers: the optimizer used for the update step. Despite the plural
            name (kept for compatibility) this is a single optimizer object.
        num_epochs: number of passes over each phase's loader.
        phases: phases to run each epoch, in order. Defaults to training only;
            pass ``('train', 'val')`` to also evaluate on
            ``dataloaders['val']``.

    Returns:
        ``(model, val_loss_history)``. ``model`` holds the weights of the
        epoch with the lowest monitored loss (the ``'val'`` loss when that
        phase runs, otherwise the ``'train'`` loss). ``val_loss_history`` has
        one entry per epoch in which the ``'val'`` phase ran, so it is empty
        for training-only runs.
    """
    since = time.time()

    val_loss_history = []
    monitor_phase = 'val' if 'val' in phases else 'train'

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    # Score weights 1..10 used to turn a 10-bin distribution into its mean;
    # constant for the whole run, so built once instead of once per batch.
    weight_votes = torch.arange(10, dtype=torch.float, device=device) + 1

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        # Flush so the header appears before the progress bar starts drawing.
        print('-' * 10, flush=True)

        for phase in phases:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            true_means = []
            pred_means = []

            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Use the optimizer that was passed in, not a notebook global.
                optimizers.zero_grad()

                # Track gradients only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizers.step()

                # Mean scores for the correlation report; detached because
                # they are statistics only and must not extend the graph.
                true_means.extend(torch.matmul(labels, weight_votes).tolist())
                pred_means.extend(torch.matmul(outputs.detach(), weight_votes).tolist())

                # Weight by batch size so a smaller final batch is averaged correctly.
                running_loss += loss.item() * inputs.size(0)

            rho, pval = stats.pearsonr(true_means, pred_means)
            epoch_loss = running_loss / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f}, (rho, pval): ({:.3f},{:.3f})'.format(phase, epoch_loss, rho, pval))

            if phase == 'val':
                val_loss_history.append(epoch_loss)

            # Remember the best epoch so the final load_state_dict restores
            # trained weights rather than the pre-training snapshot.
            if phase == monitor_phase and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

            if is_train:
                # Checkpoint the latest weights after every training phase.
                torch.save(model.state_dict(), weight_path)
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_loss_history

In [25]:
# Resume from the checkpoint at weight_path when it exists.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine still loads on a CPU-only one.
if os.path.exists(weight_path):
    model_ft.load_state_dict(torch.load(weight_path, map_location=device))
    print("Loaded saved weights from {}".format(weight_path))
else:
    print("Starting weights from scratch")

Loaded saved weights


In [26]:
# Adam over all parameters of model_ft, with a learning rate of 1e-4
opt = optim.Adam(model_ft.parameters(),lr=1e-4)

In [27]:
# Setup the loss fxn
def EMDLoss(input, target, r=2):
    """Earth Mover's Distance between predicted and target distributions.

    For each sample the cumulative sums of ``input`` and ``target`` along
    dim 1 are compared, and the r-norm of their difference (averaged over the
    bins) is taken. The per-sample distances are then averaged over the batch.

    Args:
        input: predicted distributions, one row per sample (bins along dim 1).
        target: target distributions with the same layout as ``input``.
        r: exponent of the norm; the default of 2 gives the root-mean-square
            difference of the two cumulative distributions.

    Returns:
        Scalar tensor holding the batch-averaged distance.
    """
    cdf_diff = torch.cumsum(input, dim=1) - torch.cumsum(target, dim=1)
    # Reduce over the bin dimension only, so each sample gets its own distance
    # before the batch average; reducing over every element first would make
    # the final mean a no-op.
    samplewise_emd = torch.mean(torch.abs(cdf_diff) ** r, dim=1) ** (1.0 / r)
    return samplewise_emd.mean()

In [None]:
# Run training with the EMD loss for num_epochs epochs.
# train_model returns the model and a validation-loss history list.
model_ft, hist = train_model(model_ft, dataloaders_dict, EMDLoss, opt, num_epochs=num_epochs)

Epoch 0/4
----------


 70%|██████▉   | 217/312 [04:57<02:10,  1.37s/it]

# Mean Correlation

In [None]:
# Collect the true and predicted mean score of every validation image.
# A mean score is the dot product of a 10-bin distribution with the weights 1..10.

# Switch to inference mode first: train_model leaves the network in train()
# mode, which would make BatchNorm use (and update) per-batch statistics here.
model_ft.eval()

# Built once on the CPU; labels come off the loader on the CPU and the
# predictions are moved there before the product.
vote_values = (torch.arange(10) + 1).float()

true_means = []
pred_means = []
with torch.no_grad():
    for inputs, labels in tqdm(dataloaders_dict['val']):
        inputs = inputs.to(device)

        true_means.extend(torch.matmul(labels, vote_values).tolist())

        preds = model_ft(inputs)
        pred_means.extend(torch.matmul(preds.cpu(), vote_values).tolist())

In [None]:
# Pearson correlation between the true and predicted mean scores, shown as
# the title of a scatter plot of one against the other.
rho, pval = stats.pearsonr(true_means, pred_means)
plt.title("Linear Correlation: {:.3f}".format(rho))
# Label the axes so the figure can be read without the code that produced it.
plt.xlabel("True mean score")
plt.ylabel("Predicted mean score")
plt.scatter(true_means, pred_means)