In [None]:
%load_ext autoreload
%autoreload 2

from style_transfer import *
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import PIL
import numpy as np
from scipy.misc import imread
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from dataset import CocoStuffDataSet

In [None]:
# Square content resolution; the style-transfer loop later works at 2x this size.
HEIGHT, WIDTH = 128, 128
# Validation split limited to the 'animal' supercategory, with normalization disabled.
val_dataset = CocoStuffDataSet(
    mode='val',
    supercategories=['animal'],
    height=HEIGHT,
    width=WIDTH,
    do_normalize=False,
)


In [None]:
# Use GPU tensors when CUDA is available, otherwise fall back to CPU tensors.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
# Only the convolutional feature extractor of VGG-16 is used.
cnn = torchvision.models.vgg16(pretrained=True).features
# Cast/move the network to whatever `dtype` selected. An unconditional `cnn.cuda()`
# raises on CPU-only machines even though `dtype` already has a CPU fallback.
cnn.type(dtype)
# We don't want to train the model any further, so we don't want PyTorch to waste computation
# computing gradients on parameters we're never going to update.
for param in cnn.parameters():
    param.requires_grad = False
# Indices into `cnn` that are handed to style_transfer as its `style_layers` argument.
style_layers = (0, 5, 10, 17, 24)


In [None]:
def get_style_weights(beta, base_weights=(100, 100, 10, 10, 1)):
    """Scale relative per-layer style-loss weights by a global factor.

    Args:
        beta: Global style weight multiplied into every per-layer weight.
        base_weights: Relative weight for each style layer. It should have one
            entry per element of the `style_layers` passed alongside the result.
            Defaults to the previously hard-coded values (largest weights first).

    Returns:
        list[float]: ``float(beta * w)`` for each ``w`` in ``base_weights``.
    """
    return [float(x) for x in beta * np.asarray(base_weights)]


In [None]:
def get_savename(content_idx, style_background, style_foreground=None, ground_truth=True):
    """Build the output file name for one style-transfer result.

    Args:
        content_idx: Identifier of the content image (formatted with ``str.format``).
        style_background: File name (or path) of the background style image.
        style_foreground: Optional file name (or path) of the foreground style image.
        ground_truth: Selects the ``"ground_truth_"`` prefix when True and the
            ``"generated_"`` prefix when False.

    Returns:
        str: e.g. ``"ground_truth_77_starry_night_and_muse.png"``.
    """
    def stem(filename):
        # Drop directories and only the final extension, so stems containing dots
        # ("van.gogh.jpg") or relative paths ("../styles/x.jpg") are kept intact.
        return os.path.splitext(os.path.basename(filename))[0]

    prefix = "ground_truth_" if ground_truth else "generated_"
    name = prefix + "{}_{}".format(content_idx, stem(style_background))
    if style_foreground:
        name += "_and_{}".format(stem(style_foreground))
    return name + ".png"

In [None]:
# Memoized results of get_images_and_masks, keyed by (load directory, upsample flag).
cache = {}

def get_images_and_masks(load_folder, upsample=False):
    """
    Return a tuple (image, ground truth mask, generated mask)
    for use in style transfer from input load_folder.

    Assumes directory structure:
    code/
        /saved_images_and_masks
            /<load_folder>
                /img.pk
                /gt_mask.pk
                /gen_mask.pk

    Args:
        load_folder: Name of the sub-directory under ./saved_images_and_masks.
        upsample: If True, resize the image to (2*WIDTH, 2*HEIGHT) pixels and
            repeat every mask element twice along axes 0 and 1.

    Returns:
        (PIL image, float mask tensor, float mask tensor). Results are cached per
        (folder, upsample) pair and the cached objects are shared between
        callers, so do not mutate them in place.
    """
    load_dir = os.path.join('./saved_images_and_masks', load_folder)
    # `upsample` must be part of the key; keying on the directory alone would
    # hand back a cached non-upsampled result for upsample=True (and vice versa).
    cache_key = (load_dir, upsample)
    if cache_key in cache:
        return cache[cache_key]

    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    img = torch.load(os.path.join(load_dir, 'img.pk'))
    # Assumes img is a (C, H, W) tensor with values in [0, 1] — TODO confirm with the saving code.
    img = PIL.Image.fromarray(np.uint8(img.numpy().transpose(1, 2, 0) * 255.))
    gt_mask = torch.load(os.path.join(load_dir, 'gt_mask.pk')).float()
    gen_mask = torch.load(os.path.join(load_dir, 'gen_mask.pk')).float()

    if upsample:
        # PIL's resize expects (width, height).
        img = img.resize((2 * WIDTH, 2 * HEIGHT))
        # 2x nearest-neighbour upsampling: each element is repeated along axes 0 and 1
        # (assumes masks are (H, W) — TODO confirm).
        gt_mask = torch.from_numpy(gt_mask.numpy().repeat(2, axis=0).repeat(2, axis=1)).float()
        gen_mask = torch.from_numpy(gen_mask.numpy().repeat(2, axis=0).repeat(2, axis=1)).float()

    result = (img, gt_mask, gen_mask)
    cache[cache_key] = result
    return result

In [None]:
# Loss weights for style_transfer: alpha -> content_weight, beta -> global scale
# expanded per layer by get_style_weights, gamma -> tv_weight.
alpha = 1e-4
beta = 1e3
gamma = 1e-2
# Output directory for stylized images; exist_ok avoids a check-then-create race.
savedir = './saved_style_transfers'
os.makedirs(savedir, exist_ok=True)
# Directory holding the style images, relative to the notebook's working directory.
style_dir = '../styles/'


In [None]:
# Style images passed to style_transfer as `style_image` (background) and
# `second_style_image` (foreground). Other combinations that have been tried:
# backgrounds = ['the_scream.jpg', 'starry_night.jpg', 'tsunami.jpg']
# foregrounds = ['guernica.jpg', 'guitar.jpg']
backgrounds = ['starry_night.jpg']
foregrounds = ['muse.jpg']

# Open each style image once and reuse it for every content image, instead of
# re-opening (and leaking a file handle for) every file on each loop iteration.
background_images = [PIL.Image.open(os.path.join(style_dir, n)) for n in backgrounds]
foreground_images = [PIL.Image.open(os.path.join(style_dir, n)) for n in foregrounds]

# Content images (sub-folders of ./saved_images_and_masks) to stylize.
# indices = [16, 27, 29, 37, 38, 39, 40, 42, 52, 55, 77, 514]
indices = [77]

# Also rerun each transfer with the generated mask instead of the ground-truth mask.
RUN_GENERATED_MASK = False
# Number of leading loss entries left out of the plot (presumably so the large
# early losses don't flatten the rest of the curve).
LOSS_PLOT_SKIP = 100


def _run_and_show(params, savepath):
    """Run style_transfer, pass the result to display_style_transfer, plot the losses.

    Returns (final_img, final_loss, trimmed_losses), where trimmed_losses drops
    the first LOSS_PLOT_SKIP entries, i.e. exactly what is plotted.
    """
    img, loss, losses = style_transfer(**params)
    display_style_transfer(img, savepath)
    losses = losses[LOSS_PLOT_SKIP:]
    fig, ax = plt.subplots()
    # x is the position in the full loss history (presumably one entry per iteration).
    ax.plot(range(LOSS_PLOT_SKIP, LOSS_PLOT_SKIP + len(losses)), losses)
    ax.set(title=os.path.basename(savepath), xlabel='step', ylabel='loss')
    plt.show()
    plt.close(fig)
    return img, loss, losses


for style_background_name, style_background_image in zip(backgrounds, background_images):
    for style_foreground_name, style_foreground_image in zip(foregrounds, foreground_images):
        for idx in indices:
            # Alternative content source:
            # content_image, background_mask = get_image_from_dataset(val_dataset, idx)
            content_image, background_mask, generated_background_mask = get_images_and_masks(str(idx), upsample=True)
            savename = get_savename(idx, style_background_name, style_foreground_name, ground_truth=True)
            savepath = os.path.join(savedir, savename)
            transfer_params = {
                'cnn' : cnn,
                'content_image' : content_image,
                'style_image' : style_background_image,
                'content_mask': background_mask,
                'image_size' : 2*HEIGHT,  # since we did upsampling
                'style_size' : 512,
                'content_layer' : 12,
                'content_weight' : alpha,
                'style_layers' : style_layers,
                'style_weights' : get_style_weights(beta),
                'tv_weight' : gamma,
                'max_iters' : 3000,
                'init_random' : False,
                'mask_layer' : True,
                'second_style_image' : style_foreground_image
            }

            final_img, final_loss, loss_list = _run_and_show(transfer_params, savepath)

            if RUN_GENERATED_MASK:
                savename = get_savename(idx, style_background_name, style_foreground_name, ground_truth=False)
                savepath = os.path.join(savedir, savename)
                transfer_params['content_mask'] = generated_background_mask
                final_img, final_loss, loss_list = _run_and_show(transfer_params, savepath)