# Setup

## Repository and Libraries

In [None]:
if 'google.colab' in str(get_ipython()):
    # Only on Google Colab: fetch the project repository into /content.
    # Move to the root
    import os
    os.chdir('/content')
    # IPython shell escapes. If the folder already exists (e.g. the cell is
    # re-run), git refuses to clone into it; the pull then updates the
    # existing checkout instead.
    !git clone https://gitlab.com/jemaro/wut/neural-networks/style-transfer
    !cd style-transfer; git pull

In [None]:
if 'google.colab' in str(get_ipython()):
    from google.colab import files

    # Move to the repository
    import os
    os.chdir('/content/style-transfer')
else:
    import tkinter as tk

from style_transfer import Experiment, config_logger
from style_transfer.experiment import (
    DEFAULTS, CONTENT, STYLE, CONTENT_LAYERS, STYLE_LAYERS, PRE_TRAINING,
    LEARNING_RATE, BETA_1, BETA_2, EPSILON, AMSGRAD, CONTENT_WEIGHT, 
    STYLE_WEIGHT, NUM_ITERATIONS,
    )

# Plotting setup: inline figures, a shared style sheet and large defaults.
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (20,20)
# Grid off globally; the loss plot drawn in `replot` enables its own grid.
mpl.rcParams['axes.grid'] = False
mpl.rcParams['font.size'] = 22

# Widgets used by the interactive tool further down.
from ipywidgets import (
    interact, fixed, Layout, HBox,
    Checkbox, SelectionSlider, Button,
    )
from pathlib import Path
# Natural log, used by `replot` to size the log-scaled loss axis.
from math import log

# Project logger used by later cells to report options and widget events.
logger = config_logger()

## Select experiment

In [None]:
folder = Path('results') / 'parameters_experiments'
exp = Experiment(folder)

# Log the CONTENT_WEIGHT values `exp.options` reports for the default
# settings, then show the default result. It must stay the last expression
# of the cell so the notebook renders it.
weights = exp.options(CONTENT_WEIGHT, **DEFAULTS)
logger.info(f'{CONTENT_WEIGHT}s = {weights}')
exp.image(**DEFAULTS)

## Interactive tool

In [None]:
%%capture
# Styling
def layout(visible):
    """Return the common widget Layout; the widget is hidden when *visible* is falsy."""
    visibility = 'visible' if visible else 'hidden'
    return Layout(
        width='100%',
        max_width='800px',
        height='20px',
        visibility=visibility,
        )
s = {'description_width': '100px'}

# Initialize the widgets: one per experiment parameter. A widget is hidden
# when `exp.options` reports at most one value for its parameter.
checkbox_params = (PRE_TRAINING, AMSGRAD)
widgets = {}
for param, default in DEFAULTS.items():
    options = exp.options(param, **DEFAULTS)
    shared = dict(layout=layout(len(options) > 1), style=s)
    if param not in checkbox_params:
        widgets[param] = SelectionSlider(
            value=default, options=options, continuous_update=False,
            **shared,
            )
    else:
        widgets[param] = Checkbox(value=default, **shared)

# Dynamically change the ranges on new values
def update_ranges(*args):
    """Re-filter every widget's options after any widget value changes.

    Registered below as a 'value' observer on every widget. For each
    parameter it asks ``exp.options`` which values remain available given
    the current selection of all widgets. It replaces the slider options and
    shows or hides each widget with the same "more than one option" rule
    used when the widgets are created.
    """
    logger.info(args)
    kwargs = {k: w.value for k, w in widgets.items()}
    for p, w in widgets.items():
        # Detach while mutating this widget so that value changes we cause
        # ourselves do not re-enter update_ranges.
        w.unobserve(update_ranges, 'value')
        try:
            o = exp.options(p, **kwargs)
            if p not in [PRE_TRAINING, AMSGRAD]:
                # Replacing the options may move the selection, so restore
                # the value the user had chosen.
                w.options = o
                w.value = kwargs[p]
            # Apply visibility to every widget, sliders included, matching
            # the rule used at initialisation.
            w.layout = layout(len(o) > 1)
        finally:
            # Always re-attach; otherwise a failure above would leave this
            # widget permanently unable to update the others.
            w.observe(update_ranges, 'value')


for w in widgets.values():
    w.observe(update_ranges, 'value')

# Plot function on change
def replot(**kwargs):
    """Redraw the result image and its loss curve for the selected parameters.

    Called by ``interact`` with one keyword argument per widget. Draws
    ``exp.image(**kwargs)`` on ``ax_im`` and ``exp.loss_plot(**kwargs)`` on
    a log-scaled ``ax_it``, then displays the whole figure.
    """
    # Image plot: clear first, because each imshow adds a new AxesImage and
    # repeated calls would otherwise stack images on the same axes.
    ax_im.clear()
    ax_im.imshow(exp.image(**kwargs))
    ax_im.axis('off')  # clear() turns the axis decorations back on
    # Loss plot
    ax_it.clear()
    exp.loss_plot(**kwargs, ax=ax_it)
    ax_it.set_yscale('log')
    ratio = 1.8
    xvals = (0, 1000)
    ax_it.set_xlim(xvals)
    yvals = ax_it.get_ylim()
    # Derive the aspect from the current data ranges so the panel's box
    # shape is driven by `ratio` rather than by the x/y limits.
    _xrange = xvals[1] - xvals[0]
    _yrange = log(yvals[1]) - log(yvals[0])
    ax_it.set_aspect(ratio * (_xrange / _yrange), adjustable='box')
    # Major grid lines. `True` is passed positionally because the old `b=`
    # keyword was renamed to `visible` and later removed from Matplotlib.
    ax_it.grid(True, which='major', color='gray', alpha=0.6, ls='-.', lw=1.5)
    # Minor grid lines
    ax_it.minorticks_on()
    ax_it.grid(True, which='minor', color='beige', alpha=0.8, ls='-', lw=1)
    # Draw everything
    fig.canvas.draw()
    display(fig)

# Initialize plots: result image on the left, loss curve on the right.
fig = plt.figure()
ax_im, ax_it = (fig.add_subplot(1, 2, idx) for idx in (1, 2))
ax_im.imshow(exp.image(**DEFAULTS))
ax_im.axis('off')


# Save figure controls
def savefig(fig, **kwargs):
    """Save *fig* to an image file and hand it to the user.

    On Google Colab the figure is written to ``tmp.png`` and passed to
    ``files.download``. Elsewhere a Tk "Save as" dialog asks for the path.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save.
    **kwargs
        Forwarded to ``fig.savefig``. ``bbox_inches`` (or the legacy alias
        ``bbox``) selects the saved region and defaults to ``'tight'``.

    Returns
    -------
    str
        ``'Done'`` after saving, or ``'Cancelled'`` when the save dialog is
        dismissed without choosing a file.
    """
    # Move the bounding box into kwargs exactly once. Passing it both inside
    # **kwargs and as an explicit keyword raised
    # "got multiple values for keyword argument 'bbox_inches'".
    bbox = kwargs.pop('bbox', 'tight')
    kwargs['bbox_inches'] = kwargs.pop('bbox_inches', bbox)
    if 'google.colab' in str(get_ipython()):
        fig.savefig('tmp.png', **kwargs)
        files.download('tmp.png')
        return 'Done'
    # `import tkinter` alone does not load the filedialog submodule.
    import tkinter as tk
    from tkinter import filedialog
    root = tk.Tk()
    try:
        root.withdraw()  # show only the dialog, not an empty root window
        file_path = filedialog.asksaveasfilename()
    finally:
        root.destroy()  # do not leak a hidden Tk root on every save
    if not file_path:
        # Dialog cancelled: nothing to save.
        return 'Cancelled'
    fig.savefig(file_path, **kwargs)
    return 'Done'

def savesubfig(ax, fig, expanded=(1.2, 1.2)):
    """Save only the part of *fig* covered by the axes *ax*.

    The axes' display-space extent is converted to inches, scaled by the
    ``expanded`` (width, height) factors, and passed to ``savefig`` as
    ``bbox_inches``.
    """
    pixels_to_inches = fig.dpi_scale_trans.inverted()
    region = ax.get_window_extent().transformed(pixels_to_inches)
    return savefig(fig, bbox_inches=region.expanded(*expanded))

# Download buttons: image panel only, loss panel only, whole figure.
b_im = Button(description="Save image")
b_it = Button(description="Save graph")
b_all = Button(description="Save all")
b_im.on_click(lambda _btn: savesubfig(ax_im, fig, expanded=(1, 1)))
b_it.on_click(lambda _btn: savesubfig(ax_it, fig))
b_all.on_click(lambda _btn: savefig(ax_im.get_figure()))
buttons = HBox([b_im, b_it, b_all])

# Playground

In [None]:
# Bind every parameter widget to `replot` so the figure is redrawn on change,
# then show the save buttons below it.
interact(replot, **widgets)
display(buttons)