# In search of the stars...

Let's try to reproduce the iconic Van Gogh style transfer from the 2016 Gatys et al. paper. See the [Readme](README.md) for details.

This notebook borrows heavily from the [Style Transfer Exercise notebook](https://github.com/udacity/deep-learning-v2-pytorch/blob/master/style-transfer/Style_Transfer_Exercise.ipynb) of [Udacity/deep-learning-v2-pytorch](https://github.com/udacity/deep-learning-v2-pytorch).

In [1]:
# import resources
%matplotlib inline

import os
import re
import datetime

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms, models

## Image manipulation helpers

In [2]:
def load_image(img_path, max_size=400, shape=None):
    """Load an image file and turn it into a normalized, batched tensor.

    Parameters:

    img_path = path of the image file to open.

    max_size = cap on the size handed to ``transforms.Resize``: when the
               image's largest side exceeds it, ``max_size`` is used,
               otherwise the largest side itself is used.

    shape = optional explicit size passed to ``transforms.Resize`` instead
            of the value derived from ``max_size`` (e.g. the (H, W) of
            another image so both end up with matching dimensions).

    Returns:

    Tensor of shape (1, 3, H, W): RGB channels, normalized with the
    mean/std constants below, with a leading batch dimension.
    """

    # Image.open keeps the file handle open until the pixel data is loaded;
    # convert() forces the load, and the context manager then releases the
    # handle instead of leaving that to the garbage collector.
    with Image.open(img_path) as img_file:
        image = img_file.convert('RGB')

    # large images will slow down processing
    if shape is not None:
        size = shape
    else:
        size = min(max(image.size), max_size)

    # NOTE(review): when `size` is a single int, torchvision's Resize is
    # documented to match the image's *smaller* edge, so the longer edge can
    # still end up above max_size -- confirm this is the intended behavior.
    in_transform = transforms.Compose([
                        transforms.Resize(size),
                        transforms.ToTensor(),
                        transforms.Normalize((0.485, 0.456, 0.406), 
                                             (0.229, 0.224, 0.225))])

    # keep the first three channels (guards against an alpha channel) and add the batch dimension
    image = in_transform(image)[:3,:,:].unsqueeze(0)
    
    return image

In [3]:
# helper function for un-normalizing an image 
# and converting it from a Tensor image to a NumPy image for display
def im_convert(tensor):
    """ Convert a normalized image tensor into a NumPy image ready for display.

        The tensor is copied to the CPU, squeezed, reordered from
        channels-first to channels-last, un-normalized with the same
        mean/std used when loading, and clipped to the [0, 1] range.
    """
    channel_std = np.array((0.229, 0.224, 0.225))
    channel_mean = np.array((0.485, 0.456, 0.406))

    # work on a detached CPU copy so the caller's tensor is left untouched
    pixels = tensor.to("cpu").clone().detach().numpy().squeeze()
    # (channels, height, width) -> (height, width, channels)
    pixels = pixels.transpose(1, 2, 0)

    return (pixels * channel_std + channel_mean).clip(0, 1)

## Style Transfer functions

In [4]:
# mapping of layer names to the names found in the paper for the content representation and the style representation.
def get_features(image, model, layers=None):
    """ Push an image through a model one child module at a time and collect
        the outputs of the requested modules.

        `layers` maps a module's name inside `model._modules` to the key the
        output is stored under. When omitted, the mapping below is used
        (VGGNet module indices -> layer names from Gatys et al (2016)).

        Returns a dict {mapped name: output tensor}.
    """

    selected = layers
    if selected is None:
        # PyTorch VGGNet module index -> layer name used in the paper
        selected = {
            '0': 'conv1_1', '2': 'conv1_2',
            '5': 'conv2_1', '7': 'conv2_2',
            '10': 'conv3_1', '12': 'conv3_2', '14': 'conv3_3', '16': 'conv3_4',
            '19': 'conv4_1', '21': 'conv4_2', '23': 'conv4_3', '25': 'conv4_4',
            '28': 'conv5_1', '30': 'conv5_2', '32': 'conv5_3', '34': 'conv5_4',
        }

    collected = {}
    activation = image

    # model._modules is a dictionary holding each module in the model
    # NOTE(review): the stored tensor is the very object passed on to the next
    # module; if that module works in place, the stored value changes with it.
    for module_name, module in model._modules.items():
        activation = module(activation)
        if module_name not in selected:
            continue
        collected[selected[module_name]] = activation

    return collected

In [5]:
# The Gram matrix function
def gram_matrix(tensor):
    """ Calculate the Gram Matrix of a given tensor 
        Gram Matrix: https://en.wikipedia.org/wiki/Gramian_matrix

        Parameters:

        tensor = 4-D tensor of shape (batch_size, depth, height, width).

        Returns:

        For a batch of one image: a (depth, depth) matrix, as before.
        For a larger batch: a (batch_size, depth, depth) tensor with one
        Gram matrix per image.
    """

    ## get the batch_size, depth, height, and width of the Tensor
    batch_size, d, h, w = tensor.size()

    if batch_size == 1:
        ## flatten the spatial dims so each row holds one channel's features;
        ## reshape (unlike view) also accepts non-contiguous input
        flat = tensor.reshape(d, h * w)
        return torch.mm(flat, flat.t())

    ## batched case: one (depth x depth) Gram matrix per image
    flat = tensor.reshape(batch_size, d, h * w)
    return torch.bmm(flat, flat.transpose(1, 2))

In [6]:
# Helper to initialize target image
def initialize_target(source, random=None):
    """Create the starting image that will be optimized into the target.

    Parameters:

    source = tensor to start from (content or style image); it is not modified.

    random = optional blend factor. When given, the result is
             uniform noise * random + source * (1 - random), with the noise
             taking its dimensions from source. When None, source is cloned.

    Returns:

    The new tensor with requires_grad switched on.

    """

    if random is None:
        # plain copy of the source image
        start = source.clone()
    else:
        # mix uniform [0, 1) noise with the source according to `random`
        noise = torch.rand_like(source)
        start = noise * random + source * (1 - random)

    return start.requires_grad_(True)

## Loss calculation

In [7]:
# Define content, style and total losses
def style_transfer_loss(target_features, content_features, style_grams, 
                        style_weights, content_weight, style_weight, 
                        return_all_losses=False):
    """Compute the weighted content + style loss for the current target.

    Parameters:

    target_features / content_features = dicts of layer name -> feature tensor;
        the content loss is taken on the 'conv4_2' entry of both.

    style_grams = dict of layer name -> Gram matrix of the style image.

    style_weights = dict of layer name -> weight; its keys decide which
        layers contribute to the style loss.

    content_weight, style_weight = multipliers applied to the two losses.

    return_all_losses = when True, return (total, content, style) instead of
        just the total.
    """

    # content loss: mean squared difference of the conv4_2 features
    content_gap = target_features['conv4_2'] - content_features['conv4_2']
    content_loss = torch.mean(content_gap ** 2)

    # style loss: accumulated layer by layer, starting from 0
    style_loss = 0
    for layer_name in style_weights:
        layer_feature = target_features[layer_name]
        _, depth, height, width = layer_feature.shape

        # mean squared difference between target and style Gram matrices
        gram_gap = gram_matrix(layer_feature) - style_grams[layer_name]
        weighted_layer_loss = style_weights[layer_name] * torch.mean(gram_gap ** 2)

        # scale by the size of the layer's feature map before adding
        style_loss += weighted_layer_loss / (depth * height * width)

    total_loss = content_weight * content_loss + style_weight * style_loss

    if return_all_losses:
        return (total_loss, content_loss, style_loss)
    return total_loss

In [8]:
#helper to report loss during iterations
def report_loss(target_features, content_features, style_grams, 
                style_weights, content_weight, style_weight):
    """Return a one-line text summary of the total, content and style losses,
    each formatted with two decimals."""
    all_losses = style_transfer_loss(target_features, content_features, style_grams,
                                     style_weights, content_weight, style_weight,
                                     return_all_losses=True)

    # style_transfer_loss returns the losses in this order
    labels = ("total", "content", "style")
    formatted = {}
    for label, loss in zip(labels, all_losses):
        formatted[label] = "{0:.2f}".format(torch.Tensor.item(loss))

    return 'Loss: {}'.format(formatted)

In [9]:
# This is a metric of similarity between a generated image 
# and the desired reference image (true target)
# This is MAE
def true_target_loss(true_target, target):
    """Mean absolute error between the reference image and the generated one."""
    difference = true_target - target
    return difference.abs().mean()

## Helpers for Ray Tune

In [10]:
# helper to generate style_weights for grid search
# adapted from https://stackoverflow.com/questions/51908760/rewriting-a-function-that-creates-combinations-of-numbers-with-a-fixed-sum-as-a
# it generates lists with length_of_list integer values, with a condition that the sum is fixed_sum.
# for our case of style_weights for 5 layers, we want length_of_list=5, fixed_sum=10 (and then divide by 10)

def combinations_fixed_sum(fixed_sum, length_of_list, lst=None):
    """Yield every list of `length_of_list` non-negative integers summing to `fixed_sum`.

    Parameters:

    fixed_sum = the total every yielded list must add up to.

    length_of_list = number of integers still to generate (must be >= 1).

    lst = prefix placed in front of the generated integers; used by the
          recursion and normally left at its default. It is never modified.

    Yields:

    A fresh list per combination, ordered so that the first generated
    integer goes from fixed_sum down to 0.

    Raises:

    ValueError (on first iteration) if length_of_list is smaller than 1.
    """
    if length_of_list < 1:
        # without this guard the recursion below would never reach its base case
        raise ValueError("length_of_list must be at least 1")

    # a mutable default would be shared between calls, so start from a new list
    prefix = [] if lst is None else lst

    if length_of_list == 1:
        # build a new list rather than extending the prefix in place
        yield prefix + [fixed_sum]
    else:
        for i in range(fixed_sum+1):
            yield from combinations_fixed_sum(i, length_of_list-1, prefix + [fixed_sum-i])

# exclude combinations containing a 0 weight and divide the rest by 10
#print(list(map(lambda x: [y/10 for y in x], filter(lambda x: not 0 in x, combinations_fixed_sum(10, 5)))))
#[[0.6, 0.1, 0.1, 0.1, 0.1], [0.5, 0.2, 0.1, 0.1, 0.1], [0.5, 0.1, 0.2, 0.1, 0.1], ...

## Implementation of the style transfer method in one function

In [11]:
# a complete image generation routine to pass to the Experiment run.

# Each Tune trial is executed with the current directory changed to the
# trial's own directory in ~/ray_results/my_experiment.
# The input images live in the repo directory, so remember that directory
# (with a trailing slash) now, while this notebook is still running in it;
# generate_image_with_config prefixes the image file names with it.
cwd = os.getcwd() + '/'

def generate_image_with_config(config, reporter):
    """Run one complete style transfer and report how close it got.

    Parameters:

    config = dict with the keys 'style_weights' (one weight per style layer,
             in the order conv1_1 ... conv5_1), 'content_weight',
             'style_weight' and 'steps' (number of optimizer steps).

    reporter = callable invoked once at the end as
               reporter(true_target_loss=<float>).

    Side effects:

    Saves the generated image as a PNG in the current working directory;
    the file name encodes the config values.
    """

    # config keys, in the order they appear in the output file name
    config_keys = [
        'style_weights',
        'content_weight', 'style_weight',
        'steps'
    ]

    style_weights_values = config['style_weights']
    content_weight = config['content_weight']
    style_weight = config['style_weight']
    steps = config['steps']

    show_every = None
    learning_rate = 0.1

    style_weights_layers = ['conv1_1', 'conv2_1', 'conv3_1','conv4_1','conv5_1']
    style_weights = dict(zip(style_weights_layers, style_weights_values))

    # get the "features" portion of VGG19 (we will not need the "classifier" portion)
    vgg = models.vgg19(pretrained=True).features

    # freeze all VGG parameters since we're only optimizing the target image
    for param in vgg.parameters():
        param.requires_grad_(False)

    # move the model to GPU, if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vgg.to(device)

    # load in content and style image
    content_image = cwd + 'Tuebingen_Neckarfront_gatys_paper_002.jpg'
    content = load_image(content_image).to(device)

    # Resize style to match content, makes code easier
    style_image = cwd + 'Van_Gogh_Starry_Night_gatys_paper_004.jpg'
    style = load_image(style_image, shape=content.shape[-2:]).to(device)

    # load in true target image, also resized to match content
    true_target_image = cwd + 'Van_Gogh_true_target_gatys_paper_021.jpg'
    true_target = load_image(true_target_image, shape=content.shape[-2:]).to(device)

    # initialize target
    target = initialize_target(content)

    # get content and style features only once before forming the target image
    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)

    # Gram matrices are only needed for the layers that carry a style weight
    style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_weights}

    #optimizer = optim.LBFGS([target], lr=0.2)
    optimizer = optim.Adam([target], lr=learning_rate)

    # closure() is required for LBFGS
    # see example at https://pytorch.org/docs/stable/optim.html
    # The forward pass lives inside the closure so that every evaluation sees
    # the current target; an optimizer that calls the closure several times
    # per step would otherwise keep re-using features of an outdated target.
    # Each call builds a fresh graph, so backward() needs no retain_graph.
    def closure():
        optimizer.zero_grad()
        target_features = get_features(target, vgg)
        total_loss = style_transfer_loss(target_features, content_features, style_grams,
                                         style_weights, content_weight, style_weight)
        total_loss.backward()
        return total_loss

    for ii in range(1, steps+1):

        # optionally print the loss of the current target before the update
        if show_every is not None and ii % show_every == 0:
            with torch.no_grad():
                current_features = get_features(target, vgg)
                print(ii, report_loss(current_features, content_features, style_grams,
                                      style_weights, content_weight, style_weight))

        optimizer.step(closure)

    # save new target image
    target_name = "_".join(
        [ k+"_"+str(config[k]).replace(', ','_').replace('[','').replace(']','') for k in config_keys ]
    )
    rgb = (im_convert(target)*255).astype('uint8')
    Image.fromarray(rgb).save("target"+target_name+".png")

    # this is for tune
    reporter(true_target_loss=true_target_loss(true_target, target).cpu().item())

## Hyperparameter search with Ray Tune

In [12]:
# pip install ray
# Ray and its Tune module drive the hyperparameter search in the cells below.
import ray
import ray.tune as tune

# Initialize Ray with its default arguments before any experiment is submitted.
ray.init()

Process STDOUT and STDERR is being redirected to /tmp/ray/session_2019-01-07_21-20-18_22662/logs.
Waiting for redis server at 127.0.0.1:31567 to respond...
Waiting for redis server at 127.0.0.1:40475 to respond...
Starting the Plasma object store with 20.0 GB memory using /dev/shm.

View the web UI at http://localhost:8890/notebooks/ray_ui.ipynb?token=5dd2224099751223de4f5574b808b18d7345cdad994ce5d1



{'node_ip_address': '192.168.2.7',
 'redis_address': '192.168.2.7:31567',
 'object_store_addresses': ['/tmp/ray/session_2019-01-07_21-20-18_22662/sockets/plasma_store'],
 'raylet_socket_names': ['/tmp/ray/session_2019-01-07_21-20-18_22662/sockets/raylet'],
 'webui_url': 'http://localhost:8890/notebooks/ray_ui.ipynb?token=5dd2224099751223de4f5574b808b18d7345cdad994ce5d1'}

In [13]:
# Name of the folder where the "experiment" results will be stored:
# a fixed prefix followed by the current date and time, to the minute.
exp_name = "van_gogh_{:%Y%m%d-%H%M}".format(datetime.datetime.now())
exp_name

'van_gogh_20190107-2120'

In [None]:
# Grid search over the style-layer weights and the content weight.
# style_weights: every 5-element combination of integers summing to 10 that
# contains no 0, scaled down by 10, keeping every 30th combination.
all_trials = tune.run_experiments(
    {
        exp_name: {
            "run": generate_image_with_config,
            "stop": {"true_target_loss": 0.1},
            "config": {
                'style_weights': tune.grid_search(
                    [
                        [value/10 for value in combo]
                        for combo in combinations_fixed_sum(10, 5)
                        if 0 not in combo
                    ][::30]
                ),
                'content_weight': tune.grid_search([1e-1, 1, 10.0]),
                'style_weight': tune.grid_search([1.0]),
                'steps': tune.grid_search([5000]),
            },
            "resources_per_trial": {
                "gpu": 1, 'cpu': 1,
            },
        }
    },
    verbose=False,
)

== Status ==
Using FIFO scheduling algorithm.
Resources requested: 0/6 CPUs, 0/1 GPUs
Unknown memory usage. Please run `pip install psutil` (or ray[debug]) to resolve)

Created LogSyncer for /home/artem/ray_results/van_gogh_20190107-2120/generate_image_with_config_0_content_weight=0.1,steps=5000,style_weight=1.0,style_weights=[0.6, 0.1, 0.1, 0.1, 0.1]_2019-01-07_21-26-16nrs32v2j -> 
== Status ==
Using FIFO scheduling algorithm.
Resources requested: 1/6 CPUs, 1/1 GPUs
Unknown memory usage. Please run `pip install psutil` (or ray[debug]) to resolve)
Result logdir: /home/artem/ray_results/van_gogh_20190107-2120
PENDING trials:
 - generate_image_with_config_1_content_weight=1,steps=5000,style_weight=1.0,style_weights=[0.6, 0.1, 0.1, 0.1, 0.1]:	PENDING
 - generate_image_with_config_2_content_weight=10.0,steps=5000,style_weight=1.0,style_weights=[0.6, 0.1, 0.1, 0.1, 0.1]:	PENDING
 - generate_image_with_config_3_content_weight=0.1,steps=5000,style_weight=1.0,style_weights=[0.3, 0.1, 0.2, 0.1,

Created LogSyncer for /home/artem/ray_results/van_gogh_20190107-2120/generate_image_with_config_3_content_weight=0.1,steps=5000,style_weight=1.0,style_weights=[0.3, 0.1, 0.2, 0.1, 0.3]_2019-01-07_21-35-01nmecazff -> 
== Status ==
Using FIFO scheduling algorithm.
Resources requested: 0/6 CPUs, 0/1 GPUs
Unknown memory usage. Please run `pip install psutil` (or ray[debug]) to resolve)
Result logdir: /home/artem/ray_results/van_gogh_20190107-2120
PENDING trials:
 - generate_image_with_config_4_content_weight=1,steps=5000,style_weight=1.0,style_weights=[0.3, 0.1, 0.2, 0.1, 0.3]:	PENDING
 - generate_image_with_config_5_content_weight=10.0,steps=5000,style_weight=1.0,style_weights=[0.3, 0.1, 0.2, 0.1, 0.3]:	PENDING
 - generate_image_with_config_6_content_weight=0.1,steps=5000,style_weight=1.0,style_weights=[0.2, 0.1, 0.3, 0.1, 0.3]:	PENDING
 - generate_image_with_config_7_content_weight=1,steps=5000,style_weight=1.0,style_weights=[0.2, 0.1, 0.3, 0.1, 0.3]:	PENDING
 - generate_image_with_confi

### Once you get this running, 

you can run `tensorboard` as follows:

`tensorboard --logdir=~/ray_results/my_experiment`

and aim your browser at:

`http://localhost:6006/`

We are especially interested in the `true_target_loss` graph: it shows, for each trial, the mean absolute error between the generated image and the reference image. We are looking for the lowest value here.


## Display results

In [None]:
# collect the PNG saved by each trial of this experiment
from glob import glob
from pathlib import Path

image_pattern = "{}/ray_results/{}/*/*png".format(Path.home(), exp_name)
all_images = sorted(glob(image_pattern))

# this should equal the number of trials
len(all_images)

In [None]:
# Show every generated image in a grid, titled with the parameters encoded
# in its file name.
n_cols = 3 # can be the number of content weights in the search grid, for viewability
# just enough rows for all images (at least one, so the figure can be created)
n_rows = max(1, -(-len(all_images) // n_cols))
# squeeze=False keeps axarr two-dimensional even when there is a single row
f, axarr = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 6*n_rows), squeeze=False)

# matches the numbers embedded in the image file name
number_pattern = re.compile(r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?")

for index, im_name in enumerate(all_images):
    # images fill the grid row by row
    curr_row, col = divmod(index, n_cols)
    ax = axarr[curr_row, col]

    # reading by path lets matplotlib open and close the file itself
    ax.imshow(plt.imread(im_name))
    ax.axis('off')

    # make plot title from image name,
    # e.g targetstyle_weights_0.6_0.1_0.1_0.1_0.1_content_weight_1_style_weight_1.0_steps_5000.png
    stem = re.sub(r'\.\w+$', '', os.path.basename(im_name))
    numbers = number_pattern.findall(stem)
    if len(numbers) >= 8:
        title = "{0}-{1}-{2}-{3}-{4}\nstyle/content={6}/{5}\nsteps={7}".format(*numbers)
    else:
        # unexpected file name: fall back to the name itself
        title = stem
    ax.set_title(title)

# hide the frames of grid cells that received no image
for ax in axarr.ravel()[len(all_images):]:
    ax.axis('off')

# TODO


* tune config style weights via chain
* combinations_fixed_sum: add option to exclude combinations with 0.
* add learning rate to grid_search
* redirect output of tune.run_experiments cell, so it doesn't clutter the notebook