# Style Transfer with Deep Neural Networks

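This notebook implements the neural style transfer method of [Gatys et al., "Image Style Transfer Using Convolutional Neural Networks" (CVPR 2016)](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf) in PyTorch.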

In [None]:
from PIL import Image
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms, models

In [None]:
# Only the convolutional "features" part of a pretrained VGG19 is used here.
vgg = models.vgg19(pretrained=True).features
# VGG serves as a fixed feature extractor: the optimizer only updates the
# target image, so switch off gradient tracking for every VGG weight at once.
vgg.requires_grad_(False)

In [None]:
def load_image(img_path, max_size=300, shape=None):
    """Load an image file as a normalized (1, 3, H, W) float tensor.

    Args:
        img_path: Path to the image file.
        max_size: Upper bound, in pixels, on the longer edge when ``shape`` is
            not given. Images already within the bound keep their size.
        shape: Optional target size that overrides ``max_size``. An ``(h, w)``
            sequence resizes to exactly that size; an int is passed straight
            to ``transforms.Resize`` (which matches the shorter edge to it).

    Returns:
        A tensor of shape (1, 3, H, W), normalized with the ImageNet
        mean/std used below.
    """
    # Convert to RGB so there are always exactly 3 channels: this drops an
    # alpha channel and expands grayscale/palette images, which would
    # otherwise break the 3-channel Normalize. The `with` closes the file;
    # convert() has already loaded the pixel data by then.
    with Image.open(img_path) as img:
        image = img.convert('RGB')

    if shape is not None:
        size = shape
    else:
        # Large images will kill your computer, so cap the longer edge.
        # An explicit (h, w) tuple is needed: Resize(int) would only set the
        # shorter edge, letting the longer edge exceed max_size and
        # upscaling small non-square images.
        width, height = image.size
        scale = min(1.0, max_size / max(width, height))
        size = (max(1, round(height * scale)), max(1, round(width * scale)))

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])

    # ToTensor yields (3, H, W); add the leading batch dimension.
    return in_transform(image).unsqueeze(0)

In [None]:
def get_features(image, model, layers=None):
    """Collect intermediate activations of ``model`` for ``image``.

    Args:
        image: Input tensor fed to the first child module of ``model``.
        model: Module whose direct children (``model._modules``) are applied
            in order, each one's output feeding the next.
        layers: Mapping from child-module name (e.g. ``'0'``) to the key used
            in the returned dict. Defaults to the torchvision VGG19
            ``features`` indices of conv1_1 ... conv5_1, the style layers
            used by Gatys et al. (2016).

    Returns:
        Dict mapping each requested layer's key to that child's output.
        Every child is still executed, even after the last requested layer.
        The stored tensors are the objects the children returned, so a later
        child that modifies its input in place (e.g. an inplace ReLU) also
        changes the stored value.
    """
    # Child index in PyTorch's VGG19 `features` -> layer name from the paper
    name_map = layers if layers is not None else {
        '0': 'conv1_1',
        '5': 'conv2_1',
        '10': 'conv3_1',
        '19': 'conv4_1',
        '28': 'conv5_1',
    }

    activations = {}
    out = image
    for idx, module in model._modules.items():
        out = module(out)
        if idx in name_map:
            activations[name_map[idx]] = out
    return activations

In [None]:
def gram_matrix(tensor):
    """Calculate the normalized Gram matrix of a feature map.

    Gram Matrix: https://en.wikipedia.org/wiki/Gramian_matrix

    Args:
        tensor: Feature map of shape (1, c, h, w). Only a batch size of 1 is
            supported (the flattening below assumes it).

    Returns:
        A (c, c) tensor whose entry (i, j) is the inner product of channels
        i and j, divided by c * h * w.
    """
    # Batch, channels, height, width
    _, c, h, w = tensor.size()

    # Flatten each channel's spatial map into one row. reshape (unlike view)
    # also accepts non-contiguous inputs such as channels_last or permuted
    # tensors, copying only when it has to.
    features = tensor.reshape(c, h * w)

    gram = torch.mm(features, features.t())

    # Normalize by the total number of feature-map elements
    return gram / (c * h * w)

In [None]:
content = load_image('octavia.jpg')
# Load the style image at the content image's (H, W) so both share one size
style = load_image('greenhouse.png', shape=content.shape[-2:])

# Content and style features stay fixed during optimization: compute them once
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)
style_grams = {name: gram_matrix(feat) for name, feat in style_features.items()}

In [None]:
# Per-layer weights for the style loss; later VGG layers get larger weights.
style_weights = dict(
    conv1_1=0.5,
    conv2_1=0.75,
    conv3_1=1,
    conv4_1=1.5,
    conv5_1=2,
)

# Optimization starts from uniform noise shaped like the content image; the
# target is a leaf tensor so the optimizer can update it directly.
target = torch.rand_like(content, requires_grad=True)

In [None]:
# plt is otherwise only imported in a later cell; import it here so this cell
# works when the notebook is run top to bottom.
import matplotlib.pyplot as plt

steps = 10
optimizer = optim.Adam([target], lr=0.003, betas=(0.5, 0.999))
loss_ratio = 1e-3  # alpha / beta: weight of the content loss relative to style
show_every = 1

for ii in range(1, steps + 1):

    target_features = get_features(target, vgg)

    # Content loss (Gatys et al., eq. 1): half the SUM OF SQUARED differences
    # between target and content feature maps at conv4_1 -- not the square
    # of the summed differences, whose positive/negative terms cancel out.
    content_diff = target_features['conv4_1'] - content_features['conv4_1']
    content_loss = 0.5 * (content_diff ** 2).sum()

    # Style loss (eqs. 4-5) compares Gram matrices, not raw feature maps.
    # gram_matrix already divides by c*h*w, so the paper's 1 / (4 N^2 M^2)
    # factor reduces to 1/4 here.
    target_grams = {layer: gram_matrix(target_features[layer]) for layer in target_features}
    layer_losses = {
        layer: 0.25 * ((target_grams[layer] - style_grams[layer]) ** 2).sum()
        for layer in target_grams
    }
    style_loss = sum(style_weights[layer] * layer_losses[layer] for layer in layer_losses)

    total_loss = loss_ratio * content_loss + style_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if ii % show_every == 0:
        print(total_loss.item())
        # NOTE: imshow is defined in a later cell; execute that cell first.
        imshow(target)
        plt.show()

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt

In [None]:
def imshow(tensor):
    """Draw a normalized image tensor with matplotlib.

    The tensor is copied to the CPU and detached, so the caller's tensor and
    its autograd graph are left untouched. All size-1 dimensions (such as the
    batch dimension) are squeezed away, the CHW layout is transposed to HWC,
    the ImageNet mean/std normalization is undone, and values are clipped to
    [0, 1] before plotting.
    """
    mean = np.array((0.485, 0.456, 0.406))
    std = np.array((0.229, 0.224, 0.225))

    pixels = tensor.to("cpu").clone().detach().numpy().squeeze()
    # Channels-first -> channels-last, the layout plt.imshow expects
    pixels = pixels.transpose(1, 2, 0)
    pixels = np.clip(pixels * std + mean, 0, 1)

    plt.imshow(pixels)