In [1]:
import os
import ipywidgets as widgets
import matplotlib.pyplot as plt
import torch

from IPython.display import display

from time import time
from PIL import ImageFile

from csfnst.fastneuralstyle.networks import TransformerNet, BottleneckType
from csfnst.utils import load_image, plot_image_tensor

# Global PIL switch: allow PIL to decode image files whose data is truncated instead of
# raising an error. NOTE(review): presumably relevant because csfnst.utils.load_image opens
# images through PIL — confirm.
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
# Use the GPU only when CUDA exists and CPU mode has not been forced (set False to allow CUDA).
force_cpu = True
device_type = 'cpu' if not torch.cuda.is_available() or force_cpu else 'cuda'
device = torch.device(device_type)

# Image to stylize; moved onto the chosen device once, up front.
content_image_path = '../images/content/htw.jpg'
content_image = load_image(content_image_path).to(device)

checkpoints_dir = '../checkpoints/'
# Offer only regular, non-hidden files as models. This drops the `.gitkeep` placeholder
# and also hidden clutter such as `.DS_Store` or an `.ipynb_checkpoints` directory, plus
# any sub-directory, none of which torch.load could open.
models = sorted(
    name
    for name in os.listdir(checkpoints_dir)
    if not name.startswith('.') and os.path.isfile(os.path.join(checkpoints_dir, name))
)

# Output area shared by the hint text, progress messages, timing and the stylized image.
out = widgets.Output()
# Model picker (checkpoint file names) and the button that triggers stylization, side by side.
dropdown_models = widgets.Dropdown(description='Model:', options=models)
button_stylize = widgets.Button(button_style='Info', description='Stylize')
hbox = widgets.HBox(children=[dropdown_models, button_stylize])

def _load_transformer(checkpoint_path):
    """Rebuild a TransformerNet from a checkpoint file, on `device`, in eval mode.

    The checkpoint is read as a dict holding the network's constructor arguments
    plus its weights under 'model_state_dict' (the keys read below).

    Args:
        checkpoint_path: Path of the checkpoint file to load.

    Returns:
        The restored TransformerNet, moved to the module-level `device` and set to eval().
    """
    # map_location=device remaps every saved storage (from whichever GPU it was saved on)
    # onto the selected device; the previous {'cuda:0': 'cpu'} mapping only covered cuda:0.
    # NOTE(review): torch.load unpickles arbitrary objects - only open trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model = TransformerNet(
        channel_multiplier=checkpoint['channel_multiplier'],
        expansion_factor=checkpoint['expansion_factor'],
        # The bottleneck type is stored as a string that may carry a 'BottleneckType.'
        # prefix; strip it to look the member up by name.
        bottleneck_type=BottleneckType[checkpoint['bottleneck_type'].replace('BottleneckType.', '')],
        bottleneck_size=checkpoint['bottleneck_size'],
        intermediate_activation_fn=checkpoint['intermediate_activation_fn'],
        final_activation_fn=checkpoint['final_activation_fn']
    ).to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


def perform_style_transfer(button):
    """Stylize `content_image` with the checkpoint selected in `dropdown_models`.

    Click handler for `button_stylize`; `button` is the clicked widget and is unused.
    Progress, timing, the rendered result and any error are all written to `out`.
    """
    out.clear_output()
    # All work happens inside `out`. Output raised from widget callbacks may otherwise
    # not be shown in the notebook at all, and a failure used to leave 'Calculating...'
    # on screen forever. The Output context manager renders the traceback in the widget.
    with out:
        print('Calculating...')
        start = time()

        model = _load_transformer(os.path.join(checkpoints_dir, dropdown_models.value))

        # Pure inference: no_grad skips building the autograd graph (saves memory and time)
        # and returns a tensor that does not require grad, so it can be plotted directly.
        with torch.no_grad():
            # Add a batch dimension for the network, then remove only that dimension.
            output_image = model(content_image.unsqueeze(0)).squeeze(0)
        end = time()

        out.clear_output()
        print(f'Took {(end - start):.5} sec')
        fig, ax = plt.subplots(1, 1, dpi=200)
        fig.set_size_inches(24, 24)
        plot_image_tensor(output_image, ax=ax)
        plt.show()
        
# Put an initial hint in the output area before the UI is shown.
with out:
    print('Select model and click Stylize!')

# Every click on the button runs the style transfer for the selected model.
button_stylize.on_click(perform_style_transfer)

# Render the controls row followed by the output area.
display(hbox, out)

HBox(children=(Dropdown(description='Model:', options=('experiment2__net01__s5__m32__residual_block.pth', 'exp…

Output()