In [1]:
import os

from tqdm import tqdm_notebook
import matplotlib.pyplot as plt
import datetime

import torch
import ipywidgets as widgets
import torch.optim as optim
import torch.nn as nn

import torchvision.models as models

from csfnst.utils import load_image, plot_image_tensor, save_image_tensor
from csfnst.utils import rename_network_layers, replace_network_layers, get_criterion
from csfnst.losses import PerceptualLoss
from PIL import ImageFile

# Ask Pillow to tolerate image files whose data is truncated instead of failing
# while loading them (global flag, affects every subsequent PIL image load).
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
# --- Device selection -------------------------------------------------------
# Set `force_cpu` to False to use CUDA whenever it is available.
force_cpu = True
if not torch.cuda.is_available() or force_cpu:
    device_type = 'cpu'
else:
    device_type = 'cuda'
device = torch.device(device_type)

# --- Inputs -----------------------------------------------------------------
content_image_path = '../images/content/htw.jpg'

# (label, value) pairs offered by the image-size dropdowns.
size_options = [
    ('256x256', (256, 256)),
    ('512x512', (512, 512)),
]

# Every file in the style directory except the git placeholder, sorted by name.
style_images_dir = '../images/style/'
style_images = sorted(
    filename
    for filename in os.listdir(style_images_dir)
    if filename != '.gitkeep'
)

# Output area that receives the progress bar and the result figure.
out = widgets.Output()

# Several descriptions below ('Content Image Size', 'Content Weight', ...) are
# longer than the default label width and would be cut off in the rendered
# widget. 'initial' lets each label take the width its text needs.
label_style = {'description_width': 'initial'}

# Style image and optimisation set-up.
dropdown_styles = widgets.Dropdown(
    options=style_images, description='Style', style=label_style
)
radios_optimizer = widgets.RadioButtons(
    options=['L-BFGS', 'Adam'], value='Adam', description='Optimizer', style=label_style
)
radios_input = widgets.RadioButtons(
    options=['Noise', 'Content Image'], value='Noise', description='Input', style=label_style
)

# Loss weights on a log10 scale (slider position is the exponent).
slider_cw = widgets.FloatLogSlider(
    value=1, base=10, min=0, max=8, description='Content Weight', style=label_style
)
slider_sw = widgets.FloatLogSlider(
    value=1e7, base=10, min=0, max=8, description='Style Weight', style=label_style
)
slider_tvw = widgets.FloatLogSlider(
    value=1e-6, base=10, min=-10, max=0, description='TV Weight', style=label_style
)

# Number of optimiser steps.
slider_iterations = widgets.IntSlider(
    value=250, min=1, max=1000, step=1, description='Iterations', style=label_style
)

# Sizes passed to the image loader for the content and style images.
select_cis = widgets.Dropdown(
    options=size_options, description='Content Image Size', style=label_style
)
select_sis = widgets.Dropdown(
    options=size_options, description='Style Image Size', style=label_style
)


button_stylize = widgets.Button(description='Stylize', button_style='Info')
vbox = widgets.VBox([
    dropdown_styles,
    radios_optimizer,
    radios_input,
    select_cis,
    select_sis,
    slider_cw,
    slider_sw,
    slider_tvw,
    slider_iterations,
    button_stylize
])

def perform_style_transfer(button):
    """Optimise an image towards the selected style and plot the result.

    Click handler for ``button_stylize``. It reads every setting from the
    module-level widgets, optimises the pixels of an output image against the
    criterion returned by ``get_criterion`` and finally renders, inside the
    ``out`` widget, the content image, the style image, the stylised result
    and the recorded loss curves.

    Args:
        button: The widget instance handed over by ``on_click``; it is not
            used by the handler.

    Returns:
        None. All results are displayed in ``out``.
    """
    out.clear_output()

    # Run everything inside the output widget so that the progress bar, the
    # figure and any error raised along the way (for example an unreadable
    # image file) are rendered there instead of getting lost in the callback.
    with out:
        content_image = load_image(content_image_path, size=select_cis.value).to(device)
        style_image = load_image(
            os.path.join(style_images_dir, dropdown_styles.value),
            size=select_sis.value
        ).to(device)

        # Starting point of the optimisation: uniform noise with the content
        # image's dimensions, or a copy of the content image itself.
        if radios_input.value == 'Noise':
            output_image = torch.rand(content_image.shape[:3], device=device)
        else:
            output_image = content_image.clone()

        config = {
            'loss_network': 'vgg16',
            'content_weight': slider_cw.value,
            'style_weight': slider_sw.value,
            'total_variation_weight': slider_tvw.value,
            'style_image': dropdown_styles.value,
            'style_image_size': select_sis.value,
            'content_layers': ['relu3_3'],
            'style_layers': ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3']
        }

        criterion = get_criterion(config, device=device_type)

        # Add the batch dimension. The content image gets a non-mutating view so
        # the original tensor can be plotted later without squeezing it back.
        content_batch = content_image.unsqueeze(0)

        # The optimised tensor must require gradients *before* it is handed to
        # the optimizer; doing it afterwards only works by accident.
        output_image.unsqueeze_(0).requires_grad_()

        if radios_optimizer.value == 'Adam':
            optimizer = optim.Adam([output_image], lr=1e-1)
        else:
            optimizer = optim.LBFGS([output_image])

        content_loss_history = []
        style_loss_history = []
        total_variation_loss_history = []
        loss_history = []

        progress_bar = tqdm_notebook(range(slider_iterations.value))

        def closure():
            """Evaluate the loss once, back-propagate and record the loss terms."""
            # Keep the pixels in the valid [0, 1] range without tracking the
            # clamp in the autograd graph.
            with torch.no_grad():
                output_image.clamp_(0, 1)
            optimizer.zero_grad()

            loss = criterion(output_image, content_batch)
            loss.backward()

            content_loss_history.append(criterion.content_loss_val)
            style_loss_history.append(criterion.style_loss_val)
            total_variation_loss_history.append(criterion.total_variation_loss_val)
            loss_history.append(criterion.loss_val)

            progress_bar.set_description(f'Loss: {loss.item():,.2f}')

            return loss

        # Both optimizers accept a closure: Adam evaluates it once per step,
        # L-BFGS may evaluate it several times. One loop therefore serves both.
        for _ in progress_bar:
            optimizer.step(closure)

        # Drop the batch dimension only (a bare squeeze would also remove a
        # size-1 channel or spatial dimension) and clamp the final pixels.
        result_image = output_image.detach().squeeze(0).clamp(0, 1)

        fig, axes = plt.subplots(2, 2, dpi=200)
        fig.set_size_inches(18, 20)

        plot_image_tensor(content_image, ax=axes[0, 0])
        axes[0, 0].set_title('Content Image')
        plot_image_tensor(style_image, ax=axes[0, 1])
        axes[0, 1].set_title('Style Image')
        plot_image_tensor(result_image, ax=axes[1, 0])
        axes[1, 0].set_title('Stylized Image')

        loss_axis = axes[1, 1]
        loss_curves = (
            (content_loss_history, 'Content Loss'),
            (style_loss_history, 'Style Loss'),
            (total_variation_loss_history, 'Total Variation Loss'),
            (loss_history, 'Loss'),
        )
        for history, label in loss_curves:
            loss_axis.loglog(history, label=label)
        loss_axis.set_title('Loss History')
        loss_axis.set_xlabel('Loss evaluation')
        loss_axis.set_ylabel('Loss')
        loss_axis.legend()

        plt.show()
        # Release the figure so repeated clicks do not accumulate open figures.
        plt.close(fig)
        

# Run the style transfer whenever the 'Stylize' button is clicked.
button_stylize.on_click(perform_style_transfer)
# Show the controls followed by the output area the handler writes into.
display(vbox, out)

VBox(children=(Dropdown(description='Style', options=('abstract_painting.jpg', 'bathers_in_a_forest.jpg', 'cry…

Output()