## This code is largely adapted from the STROTSS handout provided by Peter Schaldenbrand. Modifications include processing local images, generating masks from the training images, and applying those masks to the style-transfer output.

# STROTSS Style Transfer Notebook

#### [Style Transfer by Relaxed Optimal Transport and Self-Similarity (STROTSS)](https://arxiv.org/abs/1904.12785)

Code from: https://github.com/futscdav/strotss

Notebook by: Peter Schaldenbrand

In [None]:
#@title Download the strotss code from GitHub
# Clone the upstream STROTSS implementation (skipped if already present this
# session), then make it the working directory so `from strotss import *`
# can find it.
# NOTE(review): assumes the working directory is /content (the Colab default)
# when the clone runs, so the repo lands in /content/strotss -- confirm.
# The star import presumably supplies the helpers this notebook uses but never
# defines (e.g. pil_to_np, np_to_pil, pil_resize_long_edge_to, Vgg16_Extractor).

import os
if not os.path.exists('/content/strotss'):
    !git clone https://github.com/futscdav/strotss.git
os.chdir('/content/strotss')
from strotss import *


In [None]:
#@title Mount drive for convenient file access
# Mount Google Drive at /content/gdrive. Later cells copy inputs from, and
# results back to, "/content/gdrive/MyDrive/10615/STROTSS Starry Night Chair".
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
#@title Copy and unzip training data
os.chdir('/')
!cp "/content/gdrive/MyDrive/10615/STROTSS Starry Night Chair/*" /content/
!unzip /content/chair_train.zip
!unzip /content/chair_styled.zip
!unzip /content/chair_styled_sharp.zip


In [89]:
#@title Helper Functions
import torch
import requests
import PIL.Image
from io import BytesIO
import matplotlib.pyplot as plt
from torchvision.transforms.functional import adjust_sharpness
import glob
import numpy as np
from tqdm import tqdm

# Pick the compute device once. Without CUDA we fall back to the CPU, but warn
# loudly, since the optimization below is very slow there.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
    print("YOU ARE NOT USING A GPU.  IT'S GONNA BE REAAALLLLY SLOW")
    print('Go to the top of the page.  Click Runtime -> Change Runtime Type -> Hardware accelerator')
    print('From the dropdown, select GPU and rerun all this stuff')

def pil_loader_internet(url, timeout=30):
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled server.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status. Failing
            here is clearer than letting PIL choke on an HTML error page.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    img = PIL.Image.open(BytesIO(response.content))
    return img.convert('RGB')

def pil_loader_local(path, sharpness=1, encoding="RGB"):
    """Load an image file from disk, convert its mode, and optionally sharpen it.

    Args:
        path: Filesystem path of the image.
        sharpness: Factor handed to torchvision's ``adjust_sharpness``. The
            default of 1 skips sharpening entirely.
        encoding: PIL mode to convert to (e.g. "RGB" or "RGBA").

    Returns:
        The converted (and possibly sharpened) PIL image.
    """
    # Read the whole file up front so the image no longer depends on the file handle.
    with open(path, "rb") as handle:
        buffer = BytesIO(handle.read())
    converted = PIL.Image.open(buffer).convert(encoding)
    if sharpness != 1:
        converted = adjust_sharpness(converted, sharpness)
    return converted

def show_img(img):
    """Display an image array at roughly its native pixel size.

    The figure size in inches is the array's width/height divided by a fixed
    dpi of 80.

    Args:
        img: NumPy array shaped (height, width) or (height, width, channels),
            in any form ``plt.imshow`` accepts.
    """
    # Code for displaying at actual resolution from:
    # https://stackoverflow.com/questions/28816046/displaying-different-images-with-actual-size-in-matplotlib-subplot
    dpi = 80
    # Only the first two axes matter. Slicing also accepts 2-D grayscale/mask
    # arrays, which a strict 3-way unpack would reject.
    height, width = img.shape[:2]
    figsize = width / float(dpi), height / float(dpi)
    plt.figure(figsize=figsize)

    # cmap only affects single-channel data; RGB(A) arrays ignore it.
    plt.imshow(img, cmap='gray' if img.ndim == 2 else None)
    plt.xticks([])
    plt.yticks([])
    plt.show()


def plot_style_and_content(style, content):
    """Show the content image (left) and the style image (right) side by side."""
    _, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = ((content, 'Content'), (style, 'Style'))
    for axis, (picture, label) in zip(axes, panels):
        axis.imshow(picture)
        axis.set_xticks([])
        axis.set_yticks([])
        axis.set_title(label)
    plt.show()

# Takes an image and creates a binary mask (written for RGBA input)
def image_to_mask(img):
    """Turn every pixel that is not all-zero into pure white.

    A pixel counts as foreground when ANY of its channels is non-zero; all of
    its channels are then set to 255. Pixels that are zero in every channel are
    left at zero. For RGBA input, the result is therefore opaque white on a
    fully transparent black background.

    Args:
        img: PIL image. Written for RGBA input; other channel counts and
            single-channel images are handled the same way.

    Returns:
        A new PIL image built by ``np_to_pil`` from the masked array.
    """
    img_arr = np.array(pil_to_np(img))  # private, writable copy
    if img_arr.ndim == 2:
        foreground = img_arr != 0
    else:
        foreground = np.any(img_arr != 0, axis=-1)
    # One vectorized assignment instead of a per-pixel Python loop.
    img_arr[foreground] = 255
    return np_to_pil(img_arr)

# Applies a mask to an image (written for RGBA input)
def mask_image(img, mask):
    """Clear every pixel of ``img`` whose corresponding mask pixel is all zeros.

    Args:
        img: PIL image to mask. Written for RGBA input, where cleared pixels
            become fully transparent black.
        mask: PIL image at least as large as ``img``. Any extra rows/columns
            beyond ``img``'s size are ignored.

    Returns:
        A new PIL image built by ``np_to_pil``: ``img`` with background pixels
        set to 0 in every channel.

    Raises:
        ValueError: If ``mask`` is smaller than ``img`` in either dimension.
    """
    img_arr = np.array(pil_to_np(img))  # private, writable copy
    mask_arr = pil_to_np(mask)

    height, width = img_arr.shape[:2]
    if mask_arr.shape[0] < height or mask_arr.shape[1] < width:
        raise ValueError(
            f"mask size {mask_arr.shape[1]}x{mask_arr.shape[0]} is smaller than "
            f"image size {width}x{height}"
        )
    mask_arr = mask_arr[:height, :width]

    if mask_arr.ndim == 2:
        background = mask_arr == 0
    else:
        background = np.all(mask_arr == 0, axis=-1)
    # One vectorized assignment instead of a per-pixel Python loop.
    img_arr[background] = 0
    return np_to_pil(img_arr)

In [10]:
#@title Define STROTSS function
# Redefine the STROTSS function to put some debugging statements in

def _renormalize_unit(img):
    """Min-max rescale a NumPy image to [0, 1], returning a new array.

    A constant image (max == min) comes back as all zeros instead of being
    divided by zero, which would fill it with NaNs.
    """
    shifted = img - img.min()
    peak = shifted.max()
    if peak > 0:
        shifted = shifted / peak
    return shifted


def strotss(content_pil, style_pil, content_weight=1.0*16.0, device='cuda:0', space='uniform'):
    """Run coarse-to-fine STROTSS style transfer, showing the result after each scale.

    Args:
        content_pil: Content image (PIL).
        style_pil: Style image (PIL).
        content_weight: Content-loss weight for the first (coarsest) scale.
            It is halved after every scale.
        device: Torch device that the image tensors and the VGG extractor are
            moved to.
        space: Colour space passed to ``np_to_tensor`` / ``Vgg16_Extractor``.
            It also selects the clamp range: +-1.0 for 'uniform', +-1.7 otherwise.

    Returns:
        PIL image, resampled to the full content resolution, min-max
        renormalized and scaled by 255.

    Raises:
        ValueError: If the content image's shorter side is below 33 px, since
            then no optimization scale qualifies.
    """
    # Coarse-to-fine pyramid: power-of-two divisors (1..512) whose downscaled
    # shorter side is still >= 33 px, ordered coarsest (largest divisor) first.
    short_side = min(content_pil.width, content_pil.height)
    scales = [2**k for k in reversed(range(10)) if short_side // 2**k >= 33]
    if not scales:
        raise ValueError(
            f'content image is too small for STROTSS: its shorter side is '
            f'{short_side} px, but at least 33 px is required'
        )

    content_full = np_to_tensor(pil_to_np(content_pil), space).to(device)
    style_full = np_to_tensor(pil_to_np(style_pil), space).to(device)

    lr = 2e-3
    extractor = Vgg16_Extractor(space=space).to(device)

    # Clamp bounds applied before turning results into images.
    clow = -1.0 if space == 'uniform' else -1.7
    chigh = 1.0 if space == 'uniform' else 1.7

    for scale in scales:
        # rescale content and style to the current scale
        content = tensor_resample(content_full, [content_full.shape[2] // scale, content_full.shape[3] // scale])
        style = tensor_resample(style_full, [style_full.shape[2] // scale, style_full.shape[3] // scale])
        print(f'Optimizing at resoluton [{content.shape[2]}, {content.shape[3]}]')

        if scale == scales[0]:
            # Coarsest scale: content Laplacian plus the style's per-channel mean.
            result = laplacian(content) + style.mean(2, keepdim=True).mean(3, keepdim=True)
        elif scale == scales[-1]:
            # Finest scale: upsample only (no Laplacian added) and lower the lr.
            result = tensor_resample(result, [content.shape[2], content.shape[3]])
            lr = 1e-3
        else:
            result = tensor_resample(result, [content.shape[2], content.shape[3]]) + laplacian(content)

        # do the optimization on this scale
        result = optimize(result, content, style, scale, content_weight=content_weight, lr=lr, extractor=extractor)

        # Show the intermediate result (debugging aid added to the upstream version)
        show_img(_renormalize_unit(tensor_to_np(torch.clamp(result, clow, chigh))))

        # next scale gets a lower content weight
        content_weight /= 2.0

    final = tensor_resample(torch.clamp(result, clow, chigh), [content_full.shape[2], content_full.shape[3]])
    return np_to_pil(_renormalize_unit(tensor_to_np(final)) * 255.)

In [None]:
# This is the starry night image
style_url = 'https://m.media-amazon.com/images/I/91iS91eizUL._AC_SL1500_.jpg'
style_pil = pil_loader_internet(style_url)

max_width = 512

content_weight = 0.7
content_weight *= 16.0

sharpness = 3

# The style image is identical for every content image, so resize it only once.
style_resized = pil_resize_long_edge_to(style_pil, max_width)

for fname in glob.glob("/content/chair_train/*"):
    # Output paths come from plain substring replacement on the input path.
    styled_fname = fname.replace("train", "styled")
    if os.path.exists(styled_fname):
        # In case Colab kills the instance before finishing, skip finished images.
        print(f"Already have a styled {fname} computed")
        continue

    content_pil = pil_loader_local(fname, sharpness)
    result = strotss(pil_resize_long_edge_to(content_pil, max_width),
            style_resized,
            content_weight, device, "vgg")
    print('Final Result')
    show_img(pil_to_np(result))
    # No visible cell creates the output directories, so make sure they exist.
    os.makedirs(os.path.dirname(styled_fname), exist_ok=True)
    result.save(styled_fname, format="png")
    if sharpness != 1:
        # Also keep the sharpened (full-resolution) content image.
        sharp_fname = fname.replace("train", "train_sharp")
        os.makedirs(os.path.dirname(sharp_fname), exist_ok=True)
        content_pil.save(sharp_fname, format="png")


In [86]:
# Build one binary mask per training image. Each image is resized to the same
# 512-px long edge used for the styled outputs, so the masks line up with them.
# NOTE(review): image_to_mask treats any pixel with a non-zero channel as
# foreground, so this assumes the training images have an all-zero (e.g.
# transparent) background. An opaque RGB/JPEG source converted to RGBA would
# yield an all-white mask -- confirm.
for fname in glob.glob("/content/chair_train/*"):
    content_pil = pil_loader_local(fname, encoding="RGBA")
    content_pil = pil_resize_long_edge_to(content_pil, 512)
    mask_pil = image_to_mask(content_pil)
    mask_fname = fname.replace("train", "mask")
    # No visible cell creates /content/chair_mask, so make sure it exists.
    os.makedirs(os.path.dirname(mask_fname), exist_ok=True)
    mask_pil.save(mask_fname, format="png")


In [90]:
# Apply each saved mask to its styled image, clearing the background.
# Mask and output paths come from plain substring replacement on the styled
# path ("styled" -> "mask" / "styled_maskedafter").
for fname in glob.glob("/content/chair_styled/*"):
    styled_content_pil = pil_loader_local(fname, encoding="RGBA")
    mask_pil = pil_loader_local(fname.replace("styled", "mask"), encoding="RGBA")
    styled_masked_content_pil = mask_image(styled_content_pil, mask_pil)

    out_fname = fname.replace("styled", "styled_maskedafter")
    # No visible cell creates the output directory, so make sure it exists.
    os.makedirs(os.path.dirname(out_fname), exist_ok=True)
    styled_masked_content_pil.save(out_fname, format="png")


In [None]:
# Archive the masked outputs for upload. The commented-out line would also
# archive the generated masks.
!zip -r /content/chair_styled_maskedafter.zip /content/chair_styled_maskedafter
# !zip -r /content/chair_mask.zip /content/chair_mask

In [92]:
# Copy the masked-output archive back into the project Drive folder.
!cp /content/chair_styled_maskedafter.zip "/content/gdrive/MyDrive/10615/STROTSS Starry Night Chair"
# (Disabled) also copy the mask archive, to the parent 10615 folder:
# !cp /content/chair_mask.zip  /content/gdrive/MyDrive/10615