In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

#import os
#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
#os.environ["CUDA_VISIBLE_DEVICES"]="2"

import matplotlib
import matplotlib.pyplot as plt

from options.test_options import TestOptions
from options.train_options import TrainOptions
from data import CreateDataLoader, CreateStyleDataLoader
from models import create_model
from util.visualizer import save_images
from util import html
import torch

from collections import OrderedDict

import numpy as np

In [None]:
def renorm_0_1(img, bg = None):
    """Min-max rescale `img` to [0, 1] for display.

    Parameters
    ----------
    img : numpy array or torch tensor
        Image to rescale. It is not modified; a new array/tensor is returned.
    bg : mask used to boolean-index `img`, optional
        When given, min/max are computed only over pixels where ``bg == 0``,
        and pixels where ``bg == 1`` are set to -2.0 afterwards (outside
        [0, 1] -- presumably a background marker; confirm with callers).

    Returns
    -------
    Rescaled copy of `img`. If the reference region is constant (max == min)
    the result is ``img - min`` instead of NaN/inf from a division by zero.
    """
    # Region whose range defines the scaling: whole image, or non-background only.
    ref = img if bg is None else img[bg == 0]
    lo = ref.min()
    span = ref.max() - lo
    if span == 0:
        # Degenerate (constant) input: avoid 0/0.
        span = 1
    out = (img - lo) / span
    if bg is not None:
        out[bg == 1] = -2.0
    return out

### Select which GPUs to use

In [None]:
# GPU id string, passed verbatim to the --gpu_ids option when building the model options.
gpu_ids = "0"

### Initialize and load the model

In [None]:
# Initialize original model
import sys

# The options object is parsed from sys.argv, so temporarily swap in the desired
# command line and restore the kernel's own argv afterwards (avoids leaking a
# fake argv into the rest of the session).
_saved_argv = sys.argv
sys.argv = ['test.py',
            '--checkpoints_dir', './samples/models/',
            '--name', 'GanAuxPretrained',
            '--model', 'gan_aux',
            '--netG', 'resnet_residual',
            '--netD', 'disc_noisy',
            '--epoch', '200',
            '--gpu_ids', gpu_ids,
            '--peer_reg', 'bidir']
try:
    opt_ours = TestOptions().parse()
finally:
    sys.argv = _saved_argv

# hard-code some parameters for test
opt_ours.num_threads = 1   # test code only supports num_threads = 1
opt_ours.batch_size = 1    # test code only supports batch_size = 1
opt_ours.serial_batches = True  # no shuffle
opt_ours.no_flip = True    # no flip
opt_ours.display_id = -1   # no visdom display
opt_ours.num_style_samples = 1
opt_ours.knn = 5
opt_ours.eval = True

model_ours = create_model(opt_ours)
model_ours.setup(opt_ours)

# Test with eval mode. This only affects layers like batchnorm and dropout.
# Check the *option*: `model_ours.eval` is a bound method (it is called below),
# so testing it directly would always be truthy.
if opt_ours.eval:
    model_ours.eval()

### Prepare the content and style images

In [None]:
from PIL import Image
import torchvision.transforms as transforms

def get_transform(loadSize = 512, fineSize = 512, pad = None):
    """Build the image preprocessing pipeline.

    Steps, in order: bicubic resize to `loadSize`, center crop to `fineSize`,
    optional reflection padding of `pad` pixels, conversion to a tensor, and
    per-channel normalization with mean 0.5 / std 0.5.

    Returns a `transforms.Compose` applying those steps.
    """
    steps = [
        transforms.Resize(loadSize, Image.BICUBIC),
        transforms.CenterCrop(fineSize),
    ]
    # Padding is optional; content images get it, style images do not.
    if pad is not None:
        steps.append(transforms.Pad(pad, padding_mode='reflect'))

    normalize = transforms.Normalize((0.5, 0.5, 0.5),
                                     (0.5, 0.5, 0.5))
    steps.extend([transforms.ToTensor(), normalize])
    return transforms.Compose(steps)

sz = 512                    # resize target handed to get_transform as loadSize
fsz = sz                    # center-crop size (same as the resize target)
padsz = 8 * (sz // 256)     # reflect-pad width: 8 px per 256 px of resolution

# Style images are used as-is; content images are reflect-padded, and that
# border is cropped off again after stylization.
transform_fn_style = get_transform(loadSize=sz, fineSize=fsz)
transform_fn = get_transform(loadSize=sz, fineSize=fsz, pad=padsz)

def get_image(A_path, transf_fn):
    """Load the image file at `A_path` as RGB and return ``transf_fn(image)``.

    The file is opened in a ``with`` block, so its handle is released as soon
    as the RGB copy has been made. It is not left until the PIL object is
    garbage-collected.
    """
    with Image.open(A_path) as src:
        # convert() loads the pixel data and returns an independent image,
        # so it stays valid after `src` is closed.
        A_img = src.convert('RGB')
    return transf_fn(A_img)

def load_image_stack(path_pattern, transf_fn, count=8):
    """Load files ``path_pattern % k`` for k = 1..count and stack them.

    Each file goes through ``get_image(path, transf_fn)``. The results are
    stacked with ``torch.stack`` along a new leading axis of length `count`.
    """
    return torch.stack([get_image(path_pattern % k, transf_fn)
                        for k in range(1, count + 1)])


# Content images use the padded transform; style images use the unpadded one.
imgs = load_image_stack('samples/data/content_imgs/img%i.jpg', transform_fn)
styles = load_image_stack('samples/data/style_imgs/style%i.jpg', transform_fn_style)
print(imgs.shape)
print(styles.shape)

_, ax = plt.subplots()
ax.set_title('First content image (padded)')
ax.imshow(renorm_0_1(imgs[0].permute(1, 2, 0)))

_, ax = plt.subplots()
ax.set_title('First style image')
ax.imshow(renorm_0_1(styles[0].permute(1, 2, 0)));



In [None]:
num_runs = len(imgs)
num_styles = len(styles)

# Test our model
model = model_ours
# The generator is reached through `.module` when it is wrapped (presumably
# DataParallel when GPUs are used); fall back to netG itself otherwise.
netG = getattr(model.netG, 'module', model.netG)
# Send inputs to whatever device the generator's weights live on.
device = next(netG.parameters()).device


def crop_padding(t, pad):
    """Strip `pad` pixels from each border of the last two axes of `t`.

    A pad of 0/None is a no-op. Slicing with ``pad:-pad`` when pad == 0 would
    instead return an empty tensor.
    """
    if not pad:
        return t
    return t[:, :, pad:-pad, pad:-pad]


def to_numpy(t):
    """Detach `t`, move it to host memory and return it as a numpy array."""
    return t.detach().cpu().numpy()


outs_ours = []

for i in range(num_runs):
    for j in range(num_styles):
        print('Processing sample %i.%i ...' % (i, j))
        real_A = imgs[i:i+1]
        style_B = styles[j:j+1]
        with torch.no_grad():
            fake_B, z_cont_real_A, z_style_real_A, z_cont_style_B, z_style_B = \
                netG.stylize_image([real_A.to(device), style_B.to(device)])

        # Content images were reflect-padded by transform_fn; remove that border
        # from both the input and the stylized output (styles were not padded).
        real_A = crop_padding(real_A, padsz)
        fake_B = crop_padding(fake_B, padsz)

        out_dict = {
            'real_A': to_numpy(real_A)[0].transpose((1, 2, 0)),
            'fake_B': to_numpy(fake_B)[0].transpose((1, 2, 0)),
            'z_cont_real_A': to_numpy(z_cont_real_A),
            'z_cont_style_B': to_numpy(z_cont_style_B),
            'z_style_real_A': to_numpy(z_style_real_A),
            'z_style_B': to_numpy(z_style_B),
            # Unlike real_A/fake_B, style_B keeps its leading batch axis (length 1).
            'style_B': to_numpy(style_B).transpose(0, 2, 3, 1),
        }
        outs_ours.append(out_dict)

    

### Visualize some examples (three columns: source image, stylized image, target style)

In [None]:
import scipy as sp
for idx, out_dict in enumerate(outs_ours):
    # style_B keeps its leading batch axis, so take its first (only) entry.
    panels = (('Real A', out_dict['real_A']),
              ('Fake B', out_dict['fake_B']),
              ('Style B', out_dict['style_B'][0]))

    fig, axes = plt.subplots(1, 3, figsize=(16, 8))
    for ax, (title, image) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(renorm_0_1(image))
        ax.axis('off')
    fig.suptitle('Sample %i' % idx)

    # Render this figure now and release it, so the loop does not keep one
    # open figure per sample until the end of the cell.
    plt.show()
    plt.close(fig)

# 