# Style Transfer

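This notebook applies neural style transfer in PyTorch to make simulated (synthetic) gaze images look like real gaze images. No network weights are trained: a pretrained VGG19 is used as a fixed feature extractor and the pixels of the input image are optimized directly.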

Sources:
- [1] http://pytorch.org/tutorials/advanced/neural_style_tutorial.html

In [None]:
from pathlib import Path
import copy
import sys
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
from PIL import Image
from torchvision import datasets, models, transforms


# Plotting options for notebook
# NOTE: the next line is an IPython magic, so this cell only runs inside a
# notebook / IPython session, not as a plain Python script.
%matplotlib notebook
import matplotlib.pyplot as plt
plt.ion() # Plotting in interactive mode
# fig_size = [12, 6]
# plt.rcParams["figure.figsize"] = fig_size

# Make sure we can use GPU
use_gpu = torch.cuda.is_available()
print('Gpu is enabled: %s' % use_gpu)
# Tensor type handed to `.type(...)` when images are loaded: CUDA float
# tensors when a GPU is available, CPU float tensors otherwise.
dtype = torch.cuda.FloatTensor if use_gpu else torch.FloatTensor

# Add local directories to path, get dataset utils
# ROOT_DIR is the current working directory at the time this cell runs;
# presumably that is the folder containing dataset_utils.py -- confirm when
# launching the notebook from a different location.
ROOT_DIR = Path.cwd()
sys.path.append(str(ROOT_DIR))
from dataset_utils import load_single_image

## Load synthetic and real image

In [None]:
# Local dataset directory
data_dir = ROOT_DIR / '..' / 'local' / 'data'

# Real and synthetic datasets
synthetic_dataset = data_dir / '100118_fixedhead' / 'train'
real_dataset = data_dir / '150118_gaze' / 'train'

# Size passed to load_single_image for both images.
# NOTE(review): presumably (height, width) -- confirm against dataset_utils.
image_size = (96, 128)


def _first_png(dataset_dir):
    """Return, as a string, the first ``*.png`` path that glob yields in dataset_dir.

    Only the first match is pulled from the glob generator, so the directory
    listing is not materialised. The order is whatever the filesystem
    returns; it is arbitrary, not random.

    Raises:
        FileNotFoundError: if the directory contains no ``*.png`` files.
    """
    first_match = next(dataset_dir.glob('*.png'), None)
    if first_match is None:
        raise FileNotFoundError(
            'No .png images found in {}'.format(dataset_dir))
    return str(first_match)


# Grab the first image found in each dataset
synthetic_image_path = _first_png(synthetic_dataset)
real_image_path = _first_png(real_dataset)

# Load images as variables (convert to gpu tensors using dtype variable)
synth_image = load_single_image(synthetic_image_path, image_size).type(dtype)
real_image = load_single_image(real_image_path, image_size).type(dtype)

In [None]:
# Plot images
def plot_images(image_paths):
    """Show the images stored at ``image_paths`` side by side in one figure row.

    Args:
        image_paths: sized sequence of image file paths; one subplot is
            created per path, in order.
    """
    plt.figure()
    n_images = len(image_paths)
    for i, image_path in enumerate(image_paths, 1):
        # Use subplots to plot all of them
        ax = plt.subplot(1, n_images, i)
        # Open inside a context manager so the file handle is released as
        # soon as imshow has read the pixel data.
        with Image.open(image_path) as image:
            ax.imshow(image)
        ax.axis('off')
    # Lay the figure out once, after every subplot exists, rather than
    # recomputing the layout on each iteration.
    plt.tight_layout()
    plt.show()

plot_images([synthetic_image_path, real_image_path])

## Define Custom Losses

In [None]:
class ContentLoss(nn.Module):
    """Pass-through layer that measures how far its input is from a content target.

    The layer returns its input unchanged; as a side effect of ``forward`` it
    stores the weighted MSE between the input and the target in ``self.loss``.
    """

    def __init__(self, target, weight):
        super().__init__()
        self.weight = weight
        # detach() drops the target's autograd history, so it is treated as
        # a constant and not as something to differentiate through.
        # The weight is folded into the target once, up front.
        self.target = target.detach() * weight
        self.criterion = nn.MSELoss()

    def forward(self, input):
        # Scale the input the same way the target was scaled in __init__
        weighted_input = input * self.weight
        self.loss = self.criterion(weighted_input, self.target)
        # Hand the input on untouched so the layer is transparent
        self.output = input
        return input

    def backward(self, retain_graph=True):
        """Backpropagate the loss stored by the last forward pass and return it."""
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
    
class GramMatrix(nn.Module):
    """Compute the normalized Gram matrix of a 4-D feature tensor (used by the style loss)."""

    def forward(self, input):
        # The input must be 4-D: the first two axes become the rows of the
        # flattened feature matrix, the last two become its columns.
        dim0, dim1, dim2, dim3 = input.size()
        flattened = input.view(dim0 * dim1, dim2 * dim3)
        # Inner products between every pair of rows
        gram = flattened.mm(flattened.t())
        # Normalize by the total number of elements in the input
        return gram / (dim0 * dim1 * dim2 * dim3)
    
class StyleLoss(nn.Module):
    """Pass-through layer that measures how far the input's Gram matrix is from a style target.

    ``forward`` returns a copy of its input; as a side effect it stores the
    weighted Gram matrix in ``self.G`` and the MSE against the target in
    ``self.loss``.
    """

    def __init__(self, target, weight):
        super().__init__()
        self.weight = weight
        # The target is a constant: detach it from autograd and pre-scale it
        self.target = target.detach() * weight
        self.gram = GramMatrix()
        self.criterion = nn.MSELoss()

    def forward(self, input):
        passthrough = input.clone()
        # Gram matrix of the incoming features, scaled in place by the weight
        weighted_gram = self.gram(input)
        weighted_gram.mul_(self.weight)
        self.G = weighted_gram
        self.loss = self.criterion(weighted_gram, self.target)
        self.output = passthrough
        return passthrough

    def backward(self, retain_graph=True):
        """Backpropagate the loss stored by the last forward pass and return it."""
        self.loss.backward(retain_graph=retain_graph)
        return self.loss

## Build Model

In [None]:
# Load feature extractor component of VGG19
feature_extractor = models.vgg19(pretrained=True).features

# The extractor is only used as a fixed feature function: the optimizer
# updates the image, never these weights. Put it in eval mode (defensive)
# and freeze the parameters so backward passes do not compute and
# accumulate weight gradients that nothing reads or zeroes.
feature_extractor.eval()
for param in feature_extractor.parameters():
    param.requires_grad = False

if use_gpu:
    feature_extractor = feature_extractor.cuda()

# Style and content losses are calculated on specific layers
content_layers_default = ['conv_1']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

def get_style_model_and_losses(cnn, style_img, content_img,
                               style_weight=1000, content_weight=1,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    """Rebuild ``cnn`` as a Sequential with content/style loss layers spliced in.

    Layers are renamed ``conv_i`` / ``relu_i`` / ``pool_i`` while walking a
    deep copy of ``cnn``; ``i`` is incremented after every ReLU. Right after
    a layer whose name is listed in ``content_layers`` / ``style_layers`` a
    ``ContentLoss`` / ``StyleLoss`` module is inserted. Its target is the
    output of the model built so far on ``content_img`` / ``style_img``.
    Layers that are not Conv2d, ReLU or MaxPool2d are dropped.

    Construction stops once every requested layer name has been handled,
    because layers after the last loss module cannot affect any loss. If a
    requested name never occurs, or no names are requested, the whole
    network is kept.

    Args:
        cnn: iterable module (e.g. ``nn.Sequential``) of layers; it is
            deep-copied, so the caller's instance is not modified.
        style_img: input whose features provide the style targets.
        content_img: input whose features provide the content targets.
        style_weight: weight passed to every ``StyleLoss``.
        content_weight: weight passed to every ``ContentLoss``.
        content_layers: names of layers followed by a content loss.
        style_layers: names of layers followed by a style loss.

    Returns:
        Tuple ``(model, style_losses, content_losses)``: the new Sequential
        and the lists of inserted loss modules, in insertion order.
    """
    cnn = copy.deepcopy(cnn)

    # Kept as plain lists so the caller can iterate over the loss modules
    content_losses = []
    style_losses = []

    model = nn.Sequential()  # the new Sequential module network
    gram = GramMatrix()  # needed to compute the style targets

    # move these modules to the GPU if possible:
    if use_gpu:
        model = model.cuda()
        gram = gram.cuda()

    def _add_loss_module(prefix, index, layer_name, loss_module):
        """Register loss_module under ``prefix + index`` without clobbering an existing entry."""
        module_name = prefix + str(index)
        if module_name in dict(model.named_children()):
            # conv_i and relu_i share the index i. Reusing the name would
            # silently replace the loss added after conv_i, so fall back to
            # a name derived from the layer itself.
            module_name = prefix + layer_name
        model.add_module(module_name, loss_module)

    def _attach_losses(layer_name, index):
        """Insert the content/style losses requested for layer_name (content first)."""
        if layer_name in content_layers:
            target = model(content_img).clone()
            content_loss = ContentLoss(target, content_weight)
            _add_loss_module("content_loss_", index, layer_name, content_loss)
            content_losses.append(content_loss)

        if layer_name in style_layers:
            target_feature = model(style_img).clone()
            target_feature_gram = gram(target_feature)
            style_loss = StyleLoss(target_feature_gram, style_weight)
            _add_loss_module("style_loss_", index, layer_name, style_loss)
            style_losses.append(style_loss)

    # Requested layer names that have not been reached yet
    pending = set(content_layers) | set(style_layers)

    i = 1
    for layer in cnn:
        if isinstance(layer, nn.Conv2d):
            name = "conv_" + str(i)
        elif isinstance(layer, nn.ReLU):
            name = "relu_" + str(i)
        else:
            # Pooling layers are kept but never get losses attached;
            # every other layer type is skipped.
            if isinstance(layer, nn.MaxPool2d):
                model.add_module("pool_" + str(i), layer)
            continue

        model.add_module(name, layer)
        _attach_losses(name, i)

        # A ReLU closes the current conv/relu pair
        if isinstance(layer, nn.ReLU):
            i += 1

        if name in pending:
            pending.discard(name)
            if not pending:
                # Nothing after this point can contribute to a loss
                break

    return model, style_losses, content_losses

## Gradient Descent

In [None]:
def run_style_transfer(cnn, content_img, style_img, input_img, num_steps=800, style_weight=1000, content_weight=1):
    """Optimize ``input_img`` so it keeps the content image's content and takes on the style image's style.

    A loss-instrumented model is built with ``get_style_model_and_losses``
    and the image is optimized with L-BFGS. The image is clamped to [0, 1]
    before each evaluation and once more at the end.

    Args:
        cnn: feature extractor passed to ``get_style_model_and_losses``.
        content_img: image providing the content targets.
        style_img: image providing the style targets.
        input_img: starting image. The optimized parameter is created from
            ``input_img.data``, so pass a clone if the original must be kept
            (presumably the storage is shared -- confirm).
        num_steps: optimization continues while the number of loss
            evaluations is ``<= num_steps``. One optimizer step may
            evaluate the loss several times.
        style_weight: weight of the style losses.
        content_weight: weight of the content losses.

    Returns:
        The optimized image data, clamped to [0, 1].
    """
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, style_img, content_img, style_weight, content_weight)

    # Show that input is a parameter that requires a gradient
    input_param = nn.Parameter(input_img.data)
    optimizer = optim.LBFGS([input_param])

    def _as_float(score):
        """Convert a loss total to a Python float for printing.

        Works for tensors of any one-element shape (including 0-dim, where
        ``score.data[0]`` fails) and for the plain ``0`` left behind when a
        loss list is empty.
        """
        if torch.is_tensor(score):
            return score.detach().item()
        return float(score)

    print('Optimizing..')
    run = 0  # number of loss evaluations so far

    def closure():
        nonlocal run
        # correct the values of updated input image
        input_param.data.clamp_(0, 1)

        optimizer.zero_grad()
        # The forward pass refreshes .loss on every loss module
        model(input_param)

        # Each backward() accumulates gradients into input_param and
        # returns that module's loss.
        style_score = sum(sl.backward() for sl in style_losses)
        content_score = sum(cl.backward() for cl in content_losses)

        run += 1
        if run % 50 == 0:
            print("run {}:".format(run))
            print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                _as_float(style_score), _as_float(content_score)))
            print()

        return style_score + content_score

    while run <= num_steps:
        optimizer.step(closure)

    # a last correction...
    input_param.data.clamp_(0, 1)

    return input_param.data

In [None]:
# Start the optimization from a copy of the synthetic image: the synthetic
# image supplies the content, the real image supplies the style
input_img = synth_image.clone()
output = run_style_transfer(
    cnn=feature_extractor,
    content_img=synth_image,
    style_img=real_image,
    input_img=input_img,
)

In [None]:
class UnNormalize(object):
    """Undo a per-channel ``(x - mean) / std`` normalization, in place."""

    def __init__(self, mean, std):
        # Per-channel statistics that were used for the normalization
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Normalized image of size (C, H, W), or a batch
                of size (N, C, H, W). It is modified in place.
        Returns:
            Tensor: The same tensor object, with ``t * std + mean`` applied
            to each channel. Only as many channels as there are mean/std
            entries are touched.
        """
        # Up to 3 dimensions the channels are the first axis. With a leading
        # batch axis they are the third-from-last axis; iterating the first
        # axis there would apply one channel's statistics to a whole image.
        channel_dim = max(tensor.dim() - 3, 0)
        n_channels = tensor.size(channel_dim)
        for c, (m, s) in zip(range(n_channels), zip(self.mean, self.std)):
            # select() returns a view, so the in-place ops update `tensor`
            tensor.select(channel_dim, c).mul_(s).add_(m)
            # The normalize code -> t.sub_(m).div_(s)
        return tensor

# Convert output tensor back to image
unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
to_PIL = transforms.ToPILImage(mode='RGB')
# `output` is a batch holding one image. Take that image out of the batch
# before un-normalizing so the per-channel mean/std line up with the channel
# axis. Clone it first: un-normalization is in place, and re-running this
# cell must not un-normalize `output` a second time.
image_tensor = output[0].cpu().clone()
image = unorm(image_tensor)
image = to_PIL(image)

# Show final image
plt.figure()
plt.imshow(image)
plt.show()