# Labeled Dataset Regression

This notebook is used to train a regression model to predict the gaze location from a labeled dataset using PyTorch.

Sources:
- [1] http://pytorch.org/tutorials/beginner/data_loading_tutorial.html#sphx-glr-beginner-data-loading-tutorial-py
- [2] http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#sphx-glr-beginner-transfer-learning-tutorial-py

## Load dataset into Class Object

In [None]:
from pathlib import Path
import time
import copy
import re
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import pandas as pd
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms

# Ignore warnings
import warnings
warnings.filterwarnings('ignore')

# Plot fit width
%matplotlib notebook
import matplotlib.pyplot as plt
# Default figure size as [width, height] in inches. The list has to be
# created here: indexing an undefined `fig_size` raises a NameError.
fig_size = [12, 6]
plt.rcParams["figure.figsize"] = fig_size

# Plotting in interactive mode
# plt.ion()

# Query CUDA availability once; later cells branch on this flag
use_gpu = torch.cuda.is_available()
print('Gpu is enabled: %s' % (use_gpu,))

In [None]:
# Lets make a dataset class that inherits from PyTorch's dataset class
class GazeDataset(Dataset):
    """
    Gaze Dataset Class

    Every ``*.png`` file inside ``data_dir / dataset_name`` is one sample.
    The regression target is read from the filename, which must contain
    ``<gaze_x>_<gaze_y>.png`` with both values written as decimal numbers
    (e.g. ``frame_0.25_0.75.png``).
    """

    # Local dataset directory
    root_dir = Path.cwd()
    data_dir = root_dir / '..' / '..' / 'local' / 'data'

    # Matches the "<gaze_x>_<gaze_y>.png" part of a filename. Raw string with
    # escaped dots so that '.' only matches a literal decimal point.
    _TARGET_PATTERN = re.compile(r'(\d+\.\d+)_(\d+\.\d+)\.png')

    def __init__(self, dataset_name, transform=None):
        """
        :param dataset_name: (string) name of directory containing all the images
        :param transform: (optional callable) transform to be applied to image
        """
        super().__init__()
        self.name = dataset_name
        self.data_path = GazeDataset.data_dir / dataset_name
        self.transform = transform
        self.dataset = self._load_dataset()

    def _extract_target_from_gazefilename(self, imagepath):
        """
        Extract the label from the image path name
        :param imagepath: (Path) image path (contains target)
        :return: tuple(float, float) gaze target
        :raises ValueError: if the filename does not contain a gaze target
        """
        match = self._TARGET_PATTERN.search(imagepath.name)
        if match is None:
            raise ValueError(
                'Cannot extract gaze target from filename %r, '
                'expected "<gaze_x>_<gaze_y>.png"' % imagepath.name)
        return float(match.group(1)), float(match.group(2))

    def _load_dataset(self):
        """
        Loads dataset into a pandas dataframe
        :return: (pd) dataset with (imagepath, gaze_x, gaze_y) as header columns
        """
        # Use pathlib glob to get images; sorted because glob order depends on
        # the filesystem and sample indices should be reproducible
        image_list = sorted(self.data_path.glob('*.png'))
        print('Loading dataset %s, there are %s images.' % (self.data_path.name, len(image_list)))
        # Collect all rows first and build the dataframe in one go instead of
        # growing it row by row
        records = []
        for imagepath in image_list:
            gaze_x, gaze_y = self._extract_target_from_gazefilename(imagepath)
            records.append((imagepath, gaze_x, gaze_y))
        return pd.DataFrame(records, columns=['imagepath', 'gaze_x', 'gaze_y'])

    def _get_datapoint(self, idx):
        """
        Returns a single datapoint at a given index
        :param idx: (int) index of datapoint to retrieve
        :return: tuple(PIL.Image, float, float) RGB image, gaze_x, gaze_y
        """
        # The context manager closes the file handle once the pixels are read.
        # convert('RGB') drops an alpha channel if there is one and also
        # handles greyscale / palette images.
        with Image.open(self.dataset.iloc[idx, 0]) as raw_image:
            image = raw_image.convert('RGB')
        gaze_x = self.dataset.iloc[idx, 1]
        gaze_y = self.dataset.iloc[idx, 2]
        return image, gaze_x, gaze_y

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Gets a single item of dataset based on index
        :param idx: (int) index of datapoint to retrieve
        :return sample: (dict) image, gaze target
        """
        image, gaze_x, gaze_y = self._get_datapoint(idx)
        # Apply transform if necessary
        if self.transform is not None:
            image = self.transform(image)
        # Put info into a dictionary
        return {'image': image, 'gaze_x': gaze_x, 'gaze_y': gaze_y}

    def plot_samples(self, num_images=3):
        """
        Plots random sample images
        :param num_images: (int) number of images to plot per dataset
        :raises ValueError: if the dataset is empty (nothing to sample from)
        """
        plt.figure()
        for i in range(num_images):
            # randrange excludes the upper bound; randint(0, len) could return
            # len itself, which is one past the last valid index
            sample_idx = random.randrange(len(self))
            image, gaze_x, gaze_y = self._get_datapoint(sample_idx)
            # Use subplots to plot all of them
            ax = plt.subplot(1, num_images, i + 1)
            plt.tight_layout()
            plt.imshow(image)
            ax.set_title('Image %s: (%.2f, %.2f)' % (sample_idx, gaze_x, gaze_y))
            ax.axis('off')
        plt.show()

## Create Dataset Loader

In [None]:
# Per-channel (R, G, B) normalization constants. These are the ImageNet
# statistics that torchvision documents for its pretrained models.
# TODO: Compute the mean and std of the gaze dataset itself
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

data_transforms = {
    'train': transforms.Compose([
        # Colour augmentation is applied to the training split only
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ]),
    'test': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ]),
}

# Params
GAZE_DATASET_NAME = Path('100118_fixedhead')
BATCH_SIZE = 4

# One dataset, one loader and one size entry per split in data_transforms
image_datasets = {
    split: GazeDataset(str(GAZE_DATASET_NAME / split), split_transform)
    for split, split_transform in data_transforms.items()
}
dataloaders = {
    split: DataLoader(split_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    for split, split_dataset in image_datasets.items()
}
dataset_sizes = {split: len(split_dataset) for split, split_dataset in image_datasets.items()}

# Show a few random samples of every split
for dataset in image_datasets.values():
    print('Example images for %s' % dataset.name)
    dataset.plot_samples()

## Training Function

In [None]:
def _extract_inputs(data):
    """
    Turn one batch from the dataloader into model inputs and regression labels
    :param data: (dict) batch with 'image', 'gaze_x' and 'gaze_y' entries
    :return: tuple(Variable, Variable) images and float labels, where the labels
             are gaze_x and gaze_y stacked along dimension 1
    """
    images = data['image']
    targets = torch.stack((data['gaze_x'], data['gaze_y']), 1).float()
    # Move the batch to the GPU when one is available (module-level flag)
    if use_gpu:
        images = images.cuda()
        targets = targets.cuda()
    return Variable(images), Variable(targets)

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """
    Train a model and return it loaded with the weights of its best epoch.

    Every epoch runs a 'train' phase followed by a 'test' phase over the
    module-level ``dataloaders``. The weights of the epoch with the LOWEST
    test loss are kept and restored at the end.

    :param model: (nn.Module) model to train
    :param criterion: (callable) loss, called as criterion(outputs, labels)
    :param optimizer: optimizer, stepped once per training batch
    :param scheduler: learning rate scheduler, stepped once per epoch
    :param num_epochs: (int) number of epochs to run
    :return: the model with the best weights loaded
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    # Lower is better for a loss, so start at +inf and keep the minimum
    best_loss = float('inf')

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'test']:
            is_training = phase == 'train'
            # Training mode for 'train', evaluation mode for 'test'
            model.train(is_training)

            running_loss = 0.0

            # Iterate over data.
            for data in dataloaders[phase]:
                inputs, labels = _extract_inputs(data)
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                # backward + optimize only if in training phase
                if is_training:
                    loss.backward()
                    optimizer.step()

                # statistics
                # Indexing a 0-dim loss tensor raises on current PyTorch;
                # fall back to the old accessor where .item() does not exist.
                batch_loss = loss.item() if hasattr(loss, 'item') else loss.data[0]
                # Weight by batch size (assumes the criterion returns a batch mean)
                running_loss += batch_loss * inputs.size(0)

            # Step the schedule after the optimizer updates of this epoch
            if is_training:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))

            # deep copy the model whenever the test loss improves
            if phase == 'test' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best test Loss: {:.4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

## Train using Pre-trained Feature Extractor

In [None]:
# Load a pretrained ResNet-18 and freeze every one of its parameters
model_conv = models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Replace the head: adaptive pooling allows for different input sizes and the
# new linear layer outputs the two gaze coordinates. Freshly created layers
# are trainable.
model_conv.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(in_features=num_ftrs, out_features=2)

if use_gpu:
    model_conv = model_conv.cuda()

criterion = nn.MSELoss()

# Only the parameters of the final layer are handed to the optimizer;
# the frozen feature extractor is left untouched.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

# Train the model
model_conv = train_model(
    model_conv,
    criterion,
    optimizer_conv,
    exp_lr_scheduler,
    num_epochs=25,
)

## Test and Visualize Model

In [None]:
class UnNormalize(object):
    """Reverse a channel-wise normalization, i.e. compute ``t * std + mean``."""

    def __init__(self, mean, std, inplace=False):
        """
        Args:
            mean (sequence): Per-channel mean that the normalization subtracted.
            std (sequence): Per-channel std that the normalization divided by.
            inplace (bool): Modify the given tensor instead of a copy.
        """
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Normalized tensor image of size (C, H, W).
        Returns:
            Tensor: Un-normalized image.
        """
        # Work on a copy by default so the caller's tensor is left untouched
        if not self.inplace:
            tensor = tensor.clone()
        # Iterating yields one view per channel, so the in-place ops below
        # write straight into `tensor`
        for channel, m, s in zip(tensor, self.mean, self.std):
            channel.mul_(s).add_(m)
            # The normalize code -> t.sub_(m).div_(s)
        return tensor

def visualize_model(model, num_images=4):
    """
    Plot test images in one row, each titled with its label and the model prediction.

    Batches are drawn from the module-level ``dataloaders['test']`` until
    ``num_images`` images have been plotted or the loader is exhausted. The
    model is switched to evaluation mode for the predictions and its previous
    mode is restored afterwards.

    :param model: (nn.Module) model whose predictions are shown
    :param num_images: (int) maximum number of images to plot
    """
    if num_images < 1:
        # Nothing to draw; a subplot grid needs at least one column
        return

    # Both helpers do not depend on the image, so build them once
    unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    to_PIL = transforms.ToPILImage(mode='RGB')

    was_training = model.training
    model.train(False)  # Predict in evaluate mode
    images_so_far = 0
    plt.figure()

    try:
        for data in dataloaders['test']:
            inputs, labels = _extract_inputs(data)
            outputs = model(inputs)
            batch_images = inputs.cpu().data

            for j in range(inputs.size(0)):
                images_so_far += 1
                ax = plt.subplot(1, num_images, images_so_far)
                ax.axis('off')
                ax.set_title('label: (%.2f, %.2f) \n pred: (%.2f, %.2f)' % (labels[j, 0], labels[j, 1],
                                                                          outputs[j, 0], outputs[j, 1]))
                # Undo the input normalization before converting to an image
                image = to_PIL(unorm(batch_images[j]))
                plt.imshow(image)
                if images_so_far == num_images:
                    return
    finally:
        # Leave the model in the mode the caller had it in
        model.train(was_training)

# Visualize the predictions of the trained model on test images
visualize_model(model_conv)