In [None]:
%load_ext autoreload
%autoreload 2

from style_transfer import *
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import PIL
import numpy as np
from scipy.misc import imread
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from dataset import CocoStuffDataSet

In [None]:
# Working resolution; later cells derive image_size / resize targets from HEIGHT and WIDTH.
HEIGHT, WIDTH = 128, 128

# NOTE(review): val_dataset is not read by the style-transfer loop in this notebook
# (content images and masks are loaded from ./saved_images_and_masks instead);
# confirm whether it is still needed before keeping this potentially slow load.
val_dataset = CocoStuffDataSet(
    mode='val',
    supercategories=['animal'],
    height=HEIGHT,
    width=WIDTH,
    do_normalize=False,
)


In [None]:
# Use GPU tensors when a GPU is present, otherwise fall back to CPU tensors.
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

# Only VGG16's convolutional feature extractor (`.features`) is used; the classifier head is dropped.
cnn = torchvision.models.vgg16(pretrained=True).features
# Guarded: calling .cuda() on a machine without a GPU raises, even though dtype
# above already selected CPU tensors for that case.
if use_cuda:
    cnn.cuda()
cnn.type(dtype)

# We don't want to train the model any further, so we don't want PyTorch to waste computation
# computing gradients on parameters we're never going to update.
for param in cnn.parameters():
    param.requires_grad = False

# Indices into `cnn` passed to style_transfer as `style_layers`; there is one entry
# per weight returned by get_style_weights.
style_layers = (0, 5, 10, 17, 24)


In [None]:
def get_style_weights(beta):
    """Return the per-layer style weights scaled by `beta`.

    The fixed base weights (500, 100, 10, 10, 1) hold one value per entry of
    `style_layers`. Each is multiplied by `beta` and the result is returned as
    a list of plain Python floats.
    """
    base_weights = np.array([500, 100, 10, 10, 1])
    scaled = base_weights * beta
    return list(map(float, scaled))


In [None]:
def get_savename(content_idx, style_background, style_foreground=None, prefix='ground_truth'):
    """Build the output PNG filename for one style-transfer run.

    Format: ``<prefix>_<content_idx>_<background stem>[_and_<foreground stem>].png``

    Args:
        content_idx: identifier of the content image (formatted with str()).
        style_background: filename (or path) of the background style image.
        style_foreground: optional filename (or path) of the foreground style
            image; omitted from the name when falsy.
        prefix: leading tag, e.g. which mask was used ('ground_truth', 'gan').

    Returns:
        The filename string (no directory component).
    """
    def _stem(filename):
        # basename drops any directory part (a leading '../' would otherwise make
        # split('.') return ''), and splitext removes only the final extension,
        # so 'my.style.jpg' -> 'my.style' instead of 'my'.
        return os.path.splitext(os.path.basename(filename))[0]

    name = "{}_{}_{}".format(prefix, content_idx, _stem(style_background))
    if style_foreground:
        name += "_and_{}".format(_stem(style_foreground))
    return name + ".png"

In [None]:
# Memoizes get_images_and_masks results, keyed by (load_dir, upsample).
cache = {}


def _upsample_mask_2x(mask):
    """Nearest-neighbour 2x upsampling of a mask tensor along its first two axes, returned as float."""
    return torch.from_numpy(mask.numpy().repeat(2, axis=0).repeat(2, axis=1)).float()


def get_images_and_masks(load_folder, upsample=False):
    """
    Load a saved content image and its three background masks from load_folder.

    Assumes directory structure:
    code/
        /saved_images_and_masks
            /<load_folder>
                /img.pk
                /gt_mask.pk
                /baseline_mask.pk
                /gan_mask.pk

    Args:
        load_folder: sub-folder name under ./saved_images_and_masks.
        upsample: if True, double the image size and repeat every mask row and
            column twice so image and masks stay aligned.

    Returns:
        (img, gt_mask, gen_mask, gan_mask) where img is a PIL image and the
        masks are float tensors (gen_mask comes from baseline_mask.pk).
        Results are cached per (folder, upsample).
    """
    load_dir = os.path.join('./saved_images_and_masks', load_folder)
    # `upsample` must be part of the key: the upsampled and original results differ.
    cache_key = (load_dir, bool(upsample))
    if cache_key in cache:
        return cache[cache_key]

    # NOTE(review): torch.load unpickles these files; only load trusted data.
    img = torch.load(os.path.join(load_dir, 'img.pk'))
    # img is channel-first (C, H, W) given the transpose below; presumably values
    # lie in [0, 1], since they are scaled by 255 before the uint8 cast.
    img = PIL.Image.fromarray(np.uint8(img.numpy().transpose(1, 2, 0) * 255.))
    gt_mask = torch.load(os.path.join(load_dir, 'gt_mask.pk')).float()
    gen_mask = torch.load(os.path.join(load_dir, 'baseline_mask.pk')).float()
    gan_mask = torch.load(os.path.join(load_dir, 'gan_mask.pk')).float()

    if upsample:
        # PIL's resize takes (width, height). Doubling the image's own size keeps it
        # aligned with the masks, which are doubled from their own size below.
        width, height = img.size
        img = img.resize((2 * width, 2 * height))
        gt_mask = _upsample_mask_2x(gt_mask)
        gen_mask = _upsample_mask_2x(gen_mask)
        gan_mask = _upsample_mask_2x(gan_mask)

    result = (img, gt_mask, gen_mask, gan_mask)
    cache[cache_key] = result
    return result

In [None]:
# Loss weights handed to style_transfer in the experiment cell:
#   alpha -> content_weight, beta -> scales get_style_weights(...), gamma -> tv_weight.
alpha = 1e-4
beta = 1e3
gamma = 1e-2

# Output folder for stylized images; exist_ok avoids the exists()/makedirs() race
# and makes the cell idempotent on re-run.
savedir = './saved_style_transfers'
os.makedirs(savedir, exist_ok=True)

# Folder holding the style image files referenced by filename in style_pairs.
style_dir = '../styles/'


In [None]:

# Larger sweeps used in earlier experiments, kept for reference:
# backgrounds = ['starry_night.jpg', 'wave.jpg', 'the_scream.jpg']
# foregrounds = ['muse.jpg', 'guernica.jpg', 'composition_vii.jpg']
# style_pairs = [('starry_night.jpg', 'muse.jpg'),
#                ('wave.jpg', 'guernica.jpg'),
#                ('the_scream.jpg', 'muse.jpg'),
#                ('starry_night.jpg', 'composition_vii.jpg'),
#                ('muse.jpg', 'wave.jpg'),
#                ('wave.jpg', 'hot_space.jpg'),
#                ('mika.jpg', 'starry_night.jpg'),
#                ('tubingen.jpg', 'holi.jpg'),
#                ('holi.jpg', 'guernica.jpg'),
#                ('mika.jpg', 'the_scream.jpg'),
#                ('guernica.jpg', 'holi.jpg'),
#                ('hot_space.jpg', 'starry_night.jpg'),
#               ]
# indices = [29, 39, 42, 52, 55, 77, 417, 514]

# (background style, foreground style) filename pairs, resolved against style_dir.
style_pairs = [('starry_night.jpg', 'composition_vii.jpg')]
# Folder names under ./saved_images_and_masks to stylize.
indices = [55, 77, 417, 514]

# Which background mask to stylize with; each name is also the output filename prefix.
# Choose from 'ground_truth' (gt_mask.pk), 'baseline' (baseline_mask.pk), 'gan' (gan_mask.pk).
mask_variants = ['gan']
UPSAMPLE = True     # double content image and masks before style transfer
SHOW_MASKS = False  # plot the content image next to all three masks
PLOT_LOSS = False   # plot (and save) the loss curve of each run

# image_size must follow the content resolution; assumes saved images are HEIGHT pixels tall.
image_size = 2 * HEIGHT if UPSAMPLE else HEIGHT

for style_background_name, style_foreground_name in style_pairs:
    # Context managers close the style image files once every index has been processed.
    with PIL.Image.open(os.path.join(style_dir, style_background_name)) as style_background_image, \
            PIL.Image.open(os.path.join(style_dir, style_foreground_name)) as style_foreground_image:
        for idx in indices:
            content_image, gt_background_mask, baseline_background_mask, gan_background_mask = \
                get_images_and_masks(str(idx), upsample=UPSAMPLE)
            masks_by_prefix = {
                'ground_truth': gt_background_mask,
                'baseline': baseline_background_mask,
                'gan': gan_background_mask,
            }

            if SHOW_MASKS:
                fig, axes = plt.subplots(1, 4, figsize=(16, 4))
                panels = [
                    (content_image, 'content'),
                    (gt_background_mask, 'ground truth mask'),
                    (baseline_background_mask, 'baseline mask'),
                    (gan_background_mask, 'gan mask'),
                ]
                for ax, (panel, title) in zip(axes, panels):
                    ax.imshow(panel)
                    ax.set_title('{} ({})'.format(title, idx))
                    ax.axis('off')
                plt.show()
                plt.close(fig)

            for prefix in mask_variants:
                transfer_params = {
                    'cnn': cnn,
                    'content_image': content_image,
                    'style_image': style_background_image,
                    'content_mask': masks_by_prefix[prefix],
                    'image_size': image_size,
                    'style_size': 512,
                    'content_layer': 12,
                    'content_weight': alpha,
                    'style_layers': style_layers,
                    'style_weights': get_style_weights(beta),
                    'tv_weight': gamma,
                    'max_iters': 2000,
                    'init_random': False,
                    'mask_layer': True,
                    'second_style_image': style_foreground_image,
                }
                final_img, final_loss, loss_list = style_transfer(**transfer_params)
                savename = get_savename(idx, style_background_name, style_foreground_name, prefix=prefix)
                display_style_transfer(final_img, os.path.join(savedir, savename))

                if PLOT_LOSS:
                    fig, ax = plt.subplots()
                    ax.semilogy(range(len(loss_list)), loss_list)
                    ax.set(xlabel='Iterations', ylabel='Total Loss',
                           title='{} mask, image {}'.format(prefix, idx))
                    fig.savefig('{}_loss.png'.format(prefix))
                    plt.show()
                    plt.close(fig)



In [None]:
# Show the VGG16 feature extractor's layer list; its indices are what style_layers and content_layer refer to.
cnn