In [7]:
from style_transfer import *
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import PIL
import numpy as np
from scipy.misc import imread
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from dataset import CocoStuffDataSet

In [8]:
# Tensor type used for the VGG feature extractor: CUDA floats when a GPU is visible, CPU floats otherwise.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# --- Content image ----------------------------------------------------------
HEIGHT = WIDTH = 256
# Index of the validation sample used as the content image
# (mirrors `args.content_index` from the argparse-driven version of this code).
CONTENT_INDEX = 417

val_dataset = CocoStuffDataSet(mode='val', supercategories=['animal'], height=HEIGHT, width=WIDTH, do_normalize=False)
content_image, background_mask = get_image(val_dataset, CONTENT_INDEX)
# Complement of the background mask. It is not referenced in this cell; it is kept
# for later cells / experimentation.
# NOTE(review): assumes background_mask is a float tensor/array; `1.0 - mask` would fail on a bool tensor.
foreground_mask = 1.0 - background_mask

# --- Feature extractor ------------------------------------------------------
cnn = torchvision.models.vgg16(pretrained=True).features
# Indices into vgg16.features whose activations feed the style (Gram) losses.
style_layers = (0, 5, 10, 17, 24)
style_weights = (.02, .02, .02, .02, .02)
assert len(style_weights) == len(style_layers), "need exactly one style weight per style layer"

cnn.type(dtype)
# We don't want to train the model any further, so we don't want PyTorch to waste computation
# computing gradients on parameters we're never going to update.
for param in cnn.parameters():
    param.requires_grad = False

# --- Style images -----------------------------------------------------------
style_dir = '../styles/'
# Mirrors `args.background_style` from the argparse-driven version.
style_background_name = 'starry_night.jpg'
# Mirrors `args.foreground_style`; None means "no second style image".
style_foreground_name = None


def _load_style_image(directory, name):
    """Open ``directory/name`` and return it as a fully-loaded RGB PIL image.

    The file handle is closed before returning. Converting to RGB guarantees
    three channels (e.g. drops an alpha channel) for the 3-channel VGG input.
    """
    # `import PIL` alone does not import the PIL.Image submodule; import it
    # explicitly instead of relying on torchvision having done so as a side effect.
    from PIL import Image
    with Image.open(os.path.join(directory, name)) as img:
        return img.convert('RGB')


style_background_image = _load_style_image(style_dir, style_background_name)
style_foreground_image = (
    _load_style_image(style_dir, style_foreground_name) if style_foreground_name else None
)

# --- Parameters for style_transfer ------------------------------------------
transfer_params = {
    'content_image' : content_image,
    'style_image' : style_background_image,
    # The background mask is what gets passed as the content mask here.
    'content_mask': background_mask,
    'image_size' : HEIGHT,
    'content_layer' : 12,
    'content_weight' : 1e-3,
    'style_layers' : style_layers,
    'style_weights' : style_weights,
    # 1e-2 is the previously tried value; currently set to 0.
    'tv_weight' : 0,
    'init_random' : False,
    'mask_layer' : True,
    'second_style_image' : style_foreground_image,
}


loading annotations into memory...
Done (t=0.60s)
creating index...
index created!
Loaded 1016 samples: 


In [None]:
# Run the optimisation with the parameters assembled in the setup cell, then hand
# the result to display_style_transfer together with the output filename.
output_path = 'test.png'
final_img = style_transfer(**transfer_params)
display_style_transfer(final_img, output_path)