# Imports

In [None]:
import os
import sys

import numpy as np
import random
import time

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

import PIL
from PIL import Image
from IPython import display

import torch
import torchvision.transforms as transforms

from ImageTransformer import ImageTransformer
from trainer import Trainer
from datasets import InputDataset

# Paths & Model

In [None]:
# Dataset locations: replace the placeholder paths before running.
# Both directory paths must end with "/".
main_path = "PATH/TO/IMAGE/DIR/"          # content (training/validation) images
style_dir = "PATH/TO/STYLE/IMAGE/DIR/"    # style images
test_image_path = "/content/Bacchus.jpg"  # preview image (Colab-style absolute path)

# Suffix appended to saved model file names, e.g. "scream_bench" + IDtail.
IDtail = "_0.pth"

In [None]:
def reload_model(**overrides):
    """Build a fresh ImageTransformer with the benchmark architecture.

    Calling with no arguments reproduces the fixed configuration below.
    Any keyword argument replaces the matching default and is forwarded
    unchanged to ``ImageTransformer``, so variants can be built without
    editing this function.

    Returns:
        A newly constructed ``ImageTransformer`` instance.
    """
    # The lists are literals inside the function, so every call gets new
    # list objects; callers cannot mutate a shared default.
    config = dict(leak=0,
                  norm_type='batch',
                  DWS=True,
                  DWSFL=False,
                  outerK=3,
                  resgroups=4,
                  filters=[8, 12, 16],
                  shuffle=True,
                  blocks=[2, 2, 2, 1, 1, 1],
                  endgroups=(1, 1),
                  upkern=3,
                  bias_ll=True)
    config.update(overrides)
    return ImageTransformer(**config)

# Functions

In [None]:
# Select the compute device: GPU when available (recommended), otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# List the image files of the content dataset (NOTE: dataset not linked).
# Hidden entries such as ".DS_Store" or ".ipynb_checkpoints" are skipped.
# os.listdir returns names in arbitrary order; sort them so the train/val
# split is reproducible across machines and runs.
contentims_raw = os.listdir(main_path)
contentims = sorted(name for name in contentims_raw if not name.startswith("."))

# Put ~85% of the files in the training split, rounded down to a multiple
# of 16 (presumably the batch size - TODO confirm). The value must be an
# int: slicing a list with a float index raises TypeError.
cutoff = int(0.85 * len(contentims)) // 16 * 16
contenttrain = contentims[:cutoff]
contentval = contentims[cutoff:]

In [None]:
# load various functions and transformations for image I/O
# PIL image -> tensor. ToTensor yields values in [0, 1]; the per-channel
# normalisation (x - 0.5) / 0.5 then maps them to [-1, 1].
transformPILtoTensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

# Tensor -> PIL image: the inverse mapping (x - (-1)) / 2 = (x + 1) / 2
# takes [-1, 1] back to [0, 1] before the PIL conversion.
transformTensortoPIL = transforms.Compose([
    transforms.Normalize(mean=(-1.0,) * 3, std=(2.0,) * 3),
    transforms.ToPILImage(),
])


def load_img_x(path_to_img, max_dim=512):
    """Load an image (e.g. the style image) as a normalised batch on ``device``.

    The image is converted to RGB and resized so that its *shorter* side is
    ``max_dim`` pixels. The aspect ratio is preserved and no cropping is done.

    Args:
        path_to_img: path of the image file.
        max_dim: target length, in pixels, of the shorter side.

    Returns:
        A float tensor of shape (1, 3, H, W) with values in [-1, 1], on ``device``.
    """
    # The context manager closes the file handle. convert("RGB") makes RGBA
    # and grayscale files match the 3-channel Normalize in transformPILtoTensor.
    with Image.open(path_to_img) as src:
        img = src.convert("RGB")
    width, height = img.size
    scale = max_dim / min(width, height)
    # round() rather than int(): floating-point error can otherwise truncate
    # the short side to max_dim - 1.
    img = img.resize((round(width * scale), round(height * scale)))
    return transformPILtoTensor(img).unsqueeze(0).to(device, torch.float)


def load_img_reshape(path_to_img, max_dim=512):
    """Load an image, scale it and centre-crop it to a ``max_dim`` square.

    The image is converted to RGB and resized so its shorter side is
    ``max_dim`` pixels (aspect ratio preserved). The centred
    ``max_dim x max_dim`` region is then cut out.

    Args:
        path_to_img: path of the image file.
        max_dim: side length, in pixels, of the output square.

    Returns:
        A float tensor of shape (1, 3, max_dim, max_dim) with values in
        [-1, 1], on the CPU.
    """
    # The context manager closes the file handle; RGB keeps the channel
    # count at 3 for the Normalize in transformPILtoTensor.
    with Image.open(path_to_img) as src:
        img = src.convert("RGB")
    width, height = img.size
    scale = max_dim / min(width, height)
    # round() guarantees the short side is exactly max_dim. With int(), float
    # error could give max_dim - 1, and the crop below would then run past
    # the image edge (PIL pads that area with black).
    img = img.resize((round(width * scale), round(height * scale)))
    new_width, new_height = img.size  # PIL sizes are (width, height)
    left = (new_width - max_dim) // 2
    top = (new_height - max_dim) // 2
    img = img.crop((left, top, left + max_dim, top + max_dim))
    return transformPILtoTensor(img).unsqueeze(0).to(torch.float)


def load_prepped_img(path_to_img):
    """Load an already-prepared image file without resizing it.

    Args:
        path_to_img: path of the image file.

    Returns:
        A float tensor of shape (1, 3, H, W) with values in [-1, 1], on the CPU.
    """
    # The context manager closes the file handle; RGB keeps the channel
    # count at 3 for the Normalize in transformPILtoTensor.
    with Image.open(path_to_img) as src:
        img = src.convert("RGB")
    return transformPILtoTensor(img).unsqueeze(0).to(torch.float)


def load_data(content, resize=False):
    """Load image files from ``main_path`` into a single batched tensor.

    Args:
        content: file names relative to ``main_path``.
        resize: if True, each image is scaled and centre-cropped to a square
            by ``load_img_reshape``. Otherwise images are loaded unchanged by
            ``load_prepped_img`` and must already share the same size.

    Returns:
        A float tensor with the images stacked along dim 0. Its shape is
        printed as a sanity check.

    Raises:
        ValueError: if ``content`` is empty.
    """
    if not content:
        raise ValueError("load_data() needs at least one image file name")
    load_func = load_img_reshape if resize else load_prepped_img
    # Load every image first and concatenate once. Calling torch.cat inside
    # the loop would re-copy the growing tensor each time (quadratic work).
    images = [load_func(os.path.join(main_path, name)) for name in content]
    x = torch.cat(images, 0)
    print(x.shape)
    return x


def prepandclip(img):
    """Return a detached CPU copy of ``img``, squeezed and clamped to [-1, 1].

    All size-1 dimensions are removed (e.g. the batch dimension of a
    single-image batch). The input tensor itself is left unmodified, and the
    result does not track gradients.
    """
    # clamp (not clamp_ on .data) produces a new tensor instead of writing
    # into the caller's storage.
    return img.detach().squeeze().clamp(-1, 1).cpu()


def show_test_image_quality(model, image, device=device):
    """Show ``image`` side by side with ``model``'s output for a visual check.

    Args:
        model: torch module mapping a normalised image batch to an image batch.
        image: normalised input in [-1, 1]. Assumed to be (1, 3, H, W), as
            returned by ``load_img_x`` - TODO confirm for other callers.
        device: device to run the model on; the input is copied there.

    The forward pass runs under ``torch.no_grad()`` with the model in eval
    mode, so normalisation/dropout layers behave as at inference and any
    batch-norm running statistics are not updated. The model's previous
    train/eval mode is restored afterwards, even if the forward pass raises.
    """
    # Undo the [-1, 1] normalisation and move to the CPU. load_img_x places
    # the tensor on ``device``, and matplotlib cannot read GPU tensors.
    shown_input = ((image.detach().squeeze(0).permute(1, 2, 0) + 1.) / 2).cpu()

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            model_output = model(image.clone().to(device))
    finally:
        model.train(was_training)
    # prepandclip returns a clamped CPU tensor; rescale it to [0, 1] for imshow.
    shown_output = (prepandclip(model_output).permute(1, 2, 0) + 1.) / 2

    fig, axes = plt.subplots(1, 2)
    for ax, img, title in zip(axes, (shown_input, shown_output), ('input', 'output')):
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(title)
    fig.tight_layout()
    plt.show()

In [None]:
# Preview image, loaded with its shorter side at 300 px and placed on `device`.
test_image = load_img_x(test_image_path, max_dim=300)

# Train/validation tensors. With resize=False the files are loaded as-is,
# so they must already be cropped to a common size with the correct aspect.
xtrain = load_data(contenttrain, resize=False)
xval = load_data(contentval, resize=False)

In [None]:
def run_trainer(image_transformer,
                xtrain,
                xval,
                content_layers,
                style_layers,
                style_path,
                outfile,
                content_style_layers=None,
                epochs=50,
                patience=5,
                style_weight=(0, 10),
                content_weight=1,
                cs_weight=50,
                tv_weight=1000,
                stable_weight=2000,
                pretrained_filename=None,
                test_image=None,
                best_path="best.pth"):
    """Train ``image_transformer`` for one style image and save its best weights.

    Args:
        image_transformer: network to train (e.g. from ``reload_model()``).
        xtrain, xval: training / validation image tensors, each wrapped in
            an ``InputDataset``.
        content_layers, style_layers, content_style_layers: feature-layer
            names passed to ``Trainer``.
        style_path: path of the style image (loaded with its short side at 512 px).
        outfile: path where the best weights' ``state_dict`` is saved.
        epochs, patience: maximum epochs and early-stopping patience passed
            to ``Trainer.train``.
        style_weight, content_weight, cs_weight, tv_weight, stable_weight:
            loss weights forwarded to ``Trainer.train``.
        pretrained_filename: optional ``state_dict`` file to start from.
        test_image: optional image forwarded to ``Trainer.train``
            (presumably for progress previews - TODO confirm).
        best_path: checkpoint path ``Trainer.train`` is told to save its best
            model to. It is reloaded after training.
    """
    # device.type also matches explicit indices such as "cuda:0".
    on_cuda = device.type == "cuda"
    # Optionally start from a pretrained image transformer. map_location lets
    # a checkpoint saved on GPU be loaded on a CPU-only machine (and vice versa).
    if pretrained_filename is not None:
        image_transformer.load_state_dict(
            torch.load(pretrained_filename, map_location=device))
    style_image = load_img_x(style_path, max_dim=512)
    trainer = Trainer(image_transformer, content_layers, style_layers,
                      style_image, content_style_layers)
    datasettrain = InputDataset(xtrain)
    datasetval = InputDataset(xval)
    if on_cuda:
        print(torch.cuda.memory_summary(abbreviated=True))
    trainer.train(datasettrain,
                  val=datasetval,
                  epochs=epochs,
                  epoch_show=1,
                  style_weight=style_weight,
                  content_weight=content_weight,
                  stable_weight=stable_weight,
                  tv_weight=tv_weight,
                  cs_weight=cs_weight,
                  es_patience=patience,
                  best_path=best_path,
                  test_image=test_image)
    # Revert to the best checkpoint written during training, so outfile holds
    # the best weights rather than the last ones.
    image_transformer.load_state_dict(torch.load(best_path, map_location=device))
    torch.save(image_transformer.state_dict(), outfile)
    # Drop local references so their memory can be reclaimed. The caller's
    # own reference to the model keeps it alive.
    del trainer, datasettrain, datasetval, image_transformer
    if on_cuda:
        torch.cuda.empty_cache()

# Train

In [None]:
# Benchmark run: a fresh model trained on Munch's "The Scream".
content_layers = ['relu_7']
style_layers = ['relu_2', 'relu_4', 'relu_7', 'relu_11', 'relu_15']
content_style_layers = None
style_path = style_dir + "Munch-The_Scream.jpg"
outfile = "scream_bench" + IDtail
image_transformer = reload_model()

run_trainer(
    image_transformer=image_transformer,
    xtrain=xtrain,
    xval=xval,
    content_layers=content_layers,
    style_layers=style_layers,
    style_path=style_path,
    outfile=outfile,
    content_style_layers=content_style_layers,
    pretrained_filename=None,  # no checkpoint is loaded: train from scratch
    epochs=50,
    patience=5,                # early-stopping patience
    test_image=test_image,
    # Loss weights; style_weight overrides run_trainer's default of (0, 10).
    style_weight=(20, 0),
    cs_weight=0,
    content_weight=1,
    stable_weight=2000,
    tv_weight=1000,
)