### Save style, content, and transformed images

In [15]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as td
import torchvision as tv
from PIL import Image
import matplotlib.pyplot as plt
import nntools_RTST as nt
from collections import namedtuple
from architecture import *

In [16]:
# Run the style networks on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [17]:
image_size = (256, 256)
# Preprocessing shared by style and content images:
#   1. resize to a fixed 256x256 (aspect ratio is NOT preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel with (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1].
# myimshow inverts step 3 with (x + 1) / 2 before display.
_preprocessing_steps = [
    tv.transforms.Resize(image_size),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = tv.transforms.Compose(_preprocessing_steps)

### Load style networks

In [18]:
# Load the four trained style-transfer networks. They are presumably whole pickled
# nn.Module objects, since the loop below calls them directly on image batches.
# map_location=device lets a checkpoint that was saved on the GPU load on a CPU-only
# machine (and vice versa), keeping the networks on the device the inputs are sent to.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoint files.
# NOTE(review): on PyTorch >= 2.6, loading a whole module may also need weights_only=False;
# TODO confirm whether .eval() should be called (matters only if the architecture uses
# BatchNorm/Dropout).
transformer_cezanne = torch.load('transformer_cezanne', map_location=device)
transformer_irises = torch.load('transformer_irises', map_location=device)
transformer_starry = torch.load('transformer_starrynight', map_location=device)
transformer_monet = torch.load('transformer_monet', map_location=device)

In [19]:
def myimshow(image, ax=plt):
    """Show a normalized image tensor on `ax` with the axis frame hidden.

    Args:
        image: a 3-D tensor in channels-first layout, presumably (C, H, W) with
            values in roughly [-1, 1] as produced by `transform` -- confirm for
            other callers.
        ax: a matplotlib Axes, or the pyplot module itself (the default), which
            draws onto the current axes.

    Returns:
        The artist returned by `ax.imshow`.
    """
    # Channels-first -> channels-last, which is the layout imshow expects.
    pixels = np.transpose(image.to('cpu').detach().numpy(), (1, 2, 0))
    # Undo the (x - 0.5) / 0.5 normalization and clamp into the displayable [0, 1] range.
    pixels = np.clip((pixels + 1) / 2, 0, 1)
    handle = ax.imshow(pixels)
    ax.axis('off')
    return handle

Edit `content_list` below so it points at your own test images; each content image is saved and then rendered in every style.

In [None]:
# Style reference paintings, keyed by the style name used in the output filenames.
style_list = {'cezanne':'/datasets/ee285f-public/wikiart/wikiart/Post_Impressionism/paul-cezanne_forest.jpg',
              'monet':'/datasets/ee285f-public/wikiart/wikiart/Impressionism/claude-monet_the-bodmer-oak-fontainebleau.jpg',
              'starry_night':'/datasets/ee285f-public/wikiart/wikiart/Post_Impressionism/vincent-van-gogh_the-starry-night-1889(1).jpg',
              'irises':'/datasets/ee285f-public/wikiart/wikiart/Post_Impressionism/vincent-van-gogh_irises-1889.jpg'}
# Content photographs to stylize (duplicate entries are harmless but wasteful).
content_list = ['/datasets/ee285f-public/flickr_landscape/forest/12560797893_2d848c02ab.jpg',
               '/datasets/ee285f-public/flickr_landscape/forest/12137794624_7347336a7d.jpg',
               '/datasets/ee285f-public/flickr_landscape/forest/11032118_8e196a11ea.jpg',
               '/datasets/ee285f-public/flickr_landscape/forest/10392082905_1e568e33dd.jpg',
               '/datasets/ee285f-public/flickr_landscape/forest/1035978583_fde8eef9bc.jpg',
               '/datasets/ee285f-public/flickr_landscape/forest/10334337455_6e27bb99dc.jpg',
               '/datasets/ee285f-public/flickr_landscape/road/10070920934_12369b1ab0.jpg',
               '/datasets/ee285f-public/flickr_landscape/road/12006507436_71206e3fb1.jpg']

# Dataset prefixes stripped from the source paths when building output filenames.
STYLE_PREFIX = '/datasets/ee285f-public/wikiart/wikiart/'
CONTENT_PREFIX = '/datasets/ee285f-public/flickr_landscape/'

# Trained transformer network for each style name; an unknown style name now
# raises KeyError instead of silently falling back to starry night.
style_networks = {
    'cezanne': transformer_cezanne,
    'monet': transformer_monet,
    'irises': transformer_irises,
    'starry_night': transformer_starry,
}


def _load_image(path):
    """Open the image at `path` as RGB and apply the notebook's `transform`."""
    # The context manager closes the underlying file; convert() returns an
    # independent, fully loaded copy, so it stays valid after the close.
    with Image.open(path) as img:
        return transform(img.convert('RGB'))


def _save_tensor_image(image, filename):
    """Render `image` with myimshow on a fresh figure, save it to `filename`, then close the figure.

    Using a new figure per image (instead of the shared pyplot axes) keeps images
    from stacking on top of each other and stops memory from growing across the loop.
    """
    fig, ax = plt.subplots()
    myimshow(image, ax=ax)
    fig.savefig(filename)
    plt.close(fig)


# Content images do not depend on the style, so load, transform and save each one only once.
contents = {}
for content_path in content_list:
    content = _load_image(content_path)
    cfilename = content_path.replace(CONTENT_PREFIX, 'c_').replace('/', '_')
    _save_tensor_image(content, cfilename)
    contents[cfilename] = content

# Inference only: disabling autograd avoids building (and holding) gradient graphs.
with torch.no_grad():
    for style_name, style_path in style_list.items():
        style = _load_image(style_path)
        sfilename = "s_" + style_path.replace(STYLE_PREFIX, style_name).replace('/', '_')
        _save_tensor_image(style, sfilename)
        network = style_networks[style_name]
        for cfilename, content in contents.items():
            # content[None] adds a leading batch dimension; [0] takes the single output back out.
            stylized = network(content[None].to(device))[0]
            tfilename = style_name + cfilename
            _save_tensor_image(stylized, tfilename)