In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from IPython import display
import math
import os
import PIL
import requests
import sys
from typing import Callable, List, Tuple
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import fastai.vision

Set the following option to `True` if the notebook is running standalone, i.e., it isn't located inside a clone of the git repo it belongs to (where the needed Python modules are available).

In [None]:
# False: project modules are made importable by `import local_lib_import` (next cell).
# True: the next cell clones the repo into the working directory and adds it to `sys.path`.
run_as_standalone_nb = False

In [None]:
# This cell needs to be executed before importing local project modules, like import genlab.core.gan
if run_as_standalone_nb:
    # Standalone mode: rely on a clone of the repo placed in the current working directory
    root_lib_path = os.path.abspath('generative-lab')
    if not os.path.exists(root_lib_path):
        # Clone only when the folder doesn't exist yet; later runs reuse it
        !git clone https://github.com/davidleonfdez/generative-lab.git
    if root_lib_path not in sys.path:
        # Insert at the front so the clone is searched before any other location
        sys.path.insert(0, root_lib_path)
else:
    # Presumably this helper module makes the project packages importable -- TODO confirm
    import local_lib_import

In [None]:
# Local project modules. Must be imported after local_lib_import or cloning git repo.
from genlab.style_transfer import (denormalize, get_transformed_style_imgs, HyperParams,
                                   img_to_tensor, LossWeights, normalize, train, 
                                   train_progressive_growing, TransformSpecs)
from genlab.core.gen_utils import PrinterProgressTracker

In [None]:
# Notebook-wide defaults.
# NOTE(review): `patch_sz` and `n_ch` aren't read by any other cell shown here -- confirm they are still needed
patch_sz = 3
n_ch = 3
# Default target size handed to `img_to_tensor` by the image loading helpers
img_sz = 224

# INPUTS MANAGEMENT

In [None]:
# !git clone https://github.com/mf1024/ImageNet-Datasets-Downloader.git "C:/Users/blabla/ImageNetDownloader"
# Runs the downloader script of the repo referenced above through the shell, requesting
# 5 classes with 10 images per class, stored under the folder given by `-data_root`.
# NOTE(review): both paths are absolute and machine-specific; adjust them before calling.
def download_imagenet_subset():
    !python C:/Users/blabla/ImageNetDownloader/downloader.py \
        -data_root C:/Users/blabla/imagenet \
        -number_of_classes 5 \
        -images_per_class 10
    
def get_img_from_url(url, timeout:float=30.) -> PIL.Image.Image:
    """Download the image located at `url` and open it as a PIL image.

    Args:
        url: HTTP(S) address of the image file.
        timeout: seconds to wait for the server before giving up; forwarded to `requests.get`.

    Returns:
        The image opened by `PIL.Image.open` from an in-memory copy of the response body.

    Raises:
        requests.HTTPError: if the server answers with an error status code.
        requests.Timeout: if the server doesn't answer within `timeout` seconds.
    """
    import io
    response = requests.get(url, timeout=timeout)
    # Fail with the HTTP status instead of letting PIL choke on an error page
    response.raise_for_status()
    # `content` holds the decoded body (gzip/deflate are undone), unlike the raw socket stream
    return PIL.Image.open(io.BytesIO(response.content))

In [None]:
def img_t_from_url(url, target_sz:int=img_sz) -> torch.Tensor:
    """Fetch the image at `url` with `get_img_from_url` and convert it with `img_to_tensor`.

    Args:
        url: address of the image to download.
        target_sz: size forwarded to `img_to_tensor`; defaults to the notebook-wide `img_sz`.

    Returns:
        Whatever `img_to_tensor` returns for the downloaded image.
    """
    return img_to_tensor(get_img_from_url(url), target_sz)
    
def img_t_from_path(path, target_sz:int=img_sz) -> torch.Tensor:
    """Open the image file at `path` and convert it with `img_to_tensor`.

    Args:
        path: location of the image file in the file system.
        target_sz: size forwarded to `img_to_tensor`; defaults to the notebook-wide `img_sz`,
            mirroring the signature of `img_t_from_url`.

    Returns:
        Whatever `img_to_tensor` returns for the opened image.
    """
    # The context manager releases the file handle as soon as the conversion is done
    with PIL.Image.open(path) as img:
        return img_to_tensor(img, target_sz)

def check_norm_is_needed_vgg(x, model=None):
    """Compare the top output of a network for `x` with and without input normalization.

    Args:
        x: image tensor accepted by both `fastai.vision.normalize` and the network.
        model: network to evaluate. When omitted, a notebook-level variable named `vgg` is used,
            so it must have been defined before calling.

    Returns:
        Tuple `(max_norm, argmax_norm, max_raw, argmax_raw)`: highest output value and its flat index
        for the normalized input, followed by the same pair for the unmodified input.
    """
    if model is None:
        # Backward compatible path: the function used to read the global `vgg` unconditionally
        model = vgg
    # Per-channel (mean, std); presumably the usual ImageNet statistics -- TODO confirm
    stats = (torch.tensor([0.485, 0.456, 0.406]),
             torch.tensor([0.229, 0.224, 0.225]))
    # Only the output values are inspected, so there is no need to build the autograd graph
    with torch.no_grad():
        out = model(fastai.vision.normalize(x, *stats))
        out2 = model(x)
    return out.max(), out.argmax(), out2.max(), out2.argmax()
    

# DATA

In [None]:
class ImageURLs:
    """Namespace with the URLs of the sample images available to the notebook."""
    # Only PAINTING and BASKET_BALL are used by the cells shown in this notebook
    PAINTING = 'https://www.moma.org/media/W1siZiIsIjQ2NzUxNyJdLFsicCIsImNvbnZlcnQiLCItcXVhbGl0eSA5MCAtcmVzaXplIDIwMDB4MjAwMFx1MDAzZSJdXQ.jpg?sha=314ebf8cc678676f'
    BASKET_BALL = 'https://miro.medium.com/proxy/1*BDE-SkJBCG_7P4chK4vKnw.jpeg'
    FERRARI_F1 = 'https://upload.wikimedia.org/wikipedia/commons/thumb/a/a5/Ferrari_F1_2006_EMS.jpg/1024px-Ferrari_F1_2006_EMS.jpg'
    RENAULT_F1 = 'https://upload.wikimedia.org/wikipedia/commons/3/31/Renault_F1_front_IAA_2005.jpg'
    ELON_MUSK = 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/Elon_Musk_2015.jpg/408px-Elon_Musk_2015.jpg'
    # By Guillaume Paumier: https://www.flickr.com/photos/gpaumier/6198197101
    MARK_ZUCK = 'https://live.staticflickr.com/6156/6198197101_9d7a685618_b.jpg'
    PICASSO_RETRATO = 'https://live.staticflickr.com/4003/4407941037_9718d307da_b.jpg'
    PAPER_GIRL = 'https://live.staticflickr.com/8595/16020558165_1ed9f5af8c_b.jpg'
    PAPER_GIRL_PIC = 'https://live.staticflickr.com/1/2281680_656225393e_c.jpg'

In [None]:
# Download both inputs, convert them to tensors of the default size (`img_sz`) and normalize them.
# They are later passed to `train`, presumably as style (painting) and content (basket ball) -- TODO confirm
painting_img_t = normalize(img_t_from_url(ImageURLs.PAINTING))
basket_ball_img_t = normalize(img_t_from_url(ImageURLs.BASKET_BALL))

In [None]:
# Visual check of the inputs; `[0]` takes the first item along dim 0 (presumably the batch dim -- TODO confirm)
# and `denormalize` presumably reverts `normalize` so that colors display properly
fastai.vision.show_image(denormalize(basket_ball_img_t[0]))
fastai.vision.show_image(denormalize(painting_img_t[0]))

In [None]:
# Normalized output of `get_transformed_style_imgs` for the painting.
# NOTE(review): `painting_imgs_t` isn't read by any other cell shown here -- confirm it is still needed
painting_imgs_t = normalize(get_transformed_style_imgs(get_img_from_url(ImageURLs.PAINTING)))

# TRAINING

In [None]:
def get_print_results_callback(n_iters_between:int, n_total_iters:int, show_sz=False) -> Callable:
    """Build a training callback that displays the generated image every `n_iters_between` iterations.

    A grid of axes with 3 columns, big enough for `n_total_iters // n_iters_between` snapshots, is created
    up front. Each snapshot is drawn into the next free axis and the figure shown in the notebook is
    refreshed in place.

    Args:
        n_iters_between: number of iterations between two consecutive snapshots; must be positive.
        n_total_iters: total number of iterations the training is expected to run.
        show_sz: if True, the size of the image is appended to the title of each snapshot.

    Returns:
        A function `(i, gen_img_t, loss)` suitable for the `callbacks` list of the training functions.
        It takes a snapshot when `i + 1` is a multiple of `n_iters_between`.

    Raises:
        ValueError: if `n_iters_between` is not positive.
    """
    if n_iters_between <= 0:
        raise ValueError(f'n_iters_between must be a positive integer, got {n_iters_between}')
    n_imgs = n_total_iters // n_iters_between
    n_cols = 3
    # Ceil division; keep at least one row so that `plt.subplots` doesn't fail when no snapshot fits
    n_rows = max(1, (n_imgs + n_cols - 1) // n_cols)
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(16, 16 * n_rows/n_cols))
    axs = axs.flatten()
    for ax in axs: ax.axis('off')
    n_shown = 0
    display_id = None

    def _print_result(i, gen_img_t, loss):
        nonlocal n_shown, display_id
        if (i+1) % n_iters_between != 0: return
        # Training ran for longer than `n_total_iters`: there are no free axes left to draw on
        if n_shown >= len(axs): return
        gen_img_t = denormalize(gen_img_t.detach().cpu()).squeeze(dim=0).clamp(0, 1)
        gen_img = fastai.vision.Image(gen_img_t)
        # If using progressive growing, it's only right if the number of iterations
        # by size is a multiple of `n_iters_between`
        iter_idx = (n_shown+1) * n_iters_between
        title = f'Iteration {iter_idx}, loss = {loss.detach().cpu()}'
        if show_sz: title += f', size = {gen_img.size}'
        # The axes keep what previous calls drew, so only the new snapshot has to be rendered
        gen_img.show(ax=axs[n_shown], title=title)
        n_shown += 1
        if display_id is None: display_id = display.display(fig, display_id=True)
        display_id.update(fig)
        display.clear_output(wait=True)

    return _print_result

In [None]:
# Run 200 iterations with the default device and show a snapshot every 20 iterations.
# The regularization term is disabled here (`reg=0.`).
n_iters=200
gen_img_t = train(painting_img_t, 
                  basket_ball_img_t,
                  hyperparams=HyperParams(lr=0.1), 
                  n_iters=n_iters, 
                  loss_weights=LossWeights(style=0.1, content=1., reg=0.),
                  callbacks=[get_print_results_callback(20, n_iters)])

In [None]:
# Move the result to the CPU, denormalize it, drop dim 0 and clamp to [0, 1] before wrapping it
# in a fastai `Image`; the last expression makes the notebook display it
denorm_clamped_gen_img_t = denormalize(gen_img_t.detach().cpu()).squeeze(dim=0).clamp(0, 1)
gen_img = fastai.vision.Image(denorm_clamped_gen_img_t)
gen_img

In [None]:
# Written relative to the current working directory
gen_img.save('mrfcnn_tr1_200it.jpg')

In [None]:
# Quick sanity check of the raw (still normalized) output: value range, shape and autograd flag
gen_img_t.min(), gen_img_t.max(), gen_img_t.size(), gen_img_t.requires_grad

In [None]:
# Histogram of the denormalized pixel values: key 0 counts the values inside [0, 1] (both ends
# included); any other value is counted under the integer part of its floor, e.g. 1.3 -> 1, -0.2 -> -1.
denorm_gen_img_t = denormalize(gen_img_t.detach())
px_values = denorm_gen_img_t.flatten()
px_keys = torch.floor(px_values).long()
# Exactly 1.0 is still in range, although its floor is 1
px_keys[(px_values >= 0.) & (px_values <= 1.)] = 0
# One vectorized pass instead of a Python-level loop over every pixel
unique_keys, key_counts = torch.unique(px_keys, return_counts=True)
pix_dist = {0: 0}
pix_dist.update(zip(unique_keys.tolist(), key_counts.tolist()))

pix_dist

There may be some pixels out of range. They would probably be rendered as black/white anyway, so thinking about a solution (like penalizing them in the loss function) may not be worth it.

## Training on CUDA

In [None]:
# Same settings as the previous training run, but explicitly placed on the GPU.
# `torch.device('cuda')` requires a CUDA capable device to be available.
n_iters=200
gen_img_t = train(painting_img_t, 
                  basket_ball_img_t,
                  hyperparams=HyperParams(lr=0.1), 
                  n_iters=n_iters, 
                  loss_weights=LossWeights(style=0.1, content=1., reg=0.),
                  callbacks=[get_print_results_callback(20, n_iters)],
                  device=torch.device('cuda'))

In [None]:
# `.cpu()` brings the result back from the GPU; then denormalize, drop dim 0 and clamp to [0, 1]
# before wrapping it in a fastai `Image` for display
denorm_clamped_gen_img_t = denormalize(gen_img_t.detach().cpu()).squeeze(dim=0).clamp(0, 1)
gen_img = fastai.vision.Image(denorm_clamped_gen_img_t)
gen_img

In [None]:
# NOTE(review): same file name as the one written after the CPU run, so that file gets overwritten -- confirm it's intended
gen_img.save('mrfcnn_tr1_200it.jpg')

## Training with progressive growing

In [None]:
# Progressive growing works on PIL images (not tensors), from `init_sz` up to `target_sz`
style_img = get_img_from_url(ImageURLs.PAINTING)
content_img = get_img_from_url(ImageURLs.BASKET_BALL)
n_iters_by_sz = 200
init_sz, target_sz = 64, 256
# 1 + log2(256/64) = 3 stages; presumably the size doubles between stages (64, 128, 256) -- TODO confirm.
# Only needed to size the grid of snapshots of the callback.
n_total_iters = n_iters_by_sz * int(1 + math.log2(target_sz//init_sz))
gen_img_t = train_progressive_growing(style_img, 
                                      content_img, 
                                      target_sz,
                                      init_sz=init_sz,
                                      style_img_tfms=TransformSpecs.none(),
                                      n_iters_by_sz=n_iters_by_sz, 
                                      hyperparams=HyperParams(lr=0.1),
                                      loss_weights=LossWeights(style=0.1, content=1., reg=1e-3),
                                      callbacks=[get_print_results_callback(40, n_total_iters)])

To use several transformed versions of the style image, just omit the param `style_img_tfms` of `train_progressive_growing` or pass the rotations and scales explicitly to the constructor of `TransformSpecs`. For example:

```
   style_img_tfms=TransformSpecs(do_scale=True, 
                                 do_rotate=True, 
                                 scales=(0.975, 1.), 
                                 rotations=(-3., 0, 3.))
```

In [None]:
# Move the result to the CPU, denormalize it, drop dim 0 and clamp to [0, 1] before wrapping it
# in a fastai `Image`; the last expression makes the notebook display it
denorm_clamped_gen_img_t = denormalize(gen_img_t.detach().cpu()).squeeze(dim=0).clamp(0, 1)
gen_img = fastai.vision.Image(denorm_clamped_gen_img_t)
gen_img

In [None]:
# Written relative to the current working directory
gen_img.save('mrfcnn_tr1_200it_pg64-256.jpg')

# TESTS

In [None]:
def test_content_loss():
    """Check `content_loss` on constant (3, 4, 4) tensors whose element-wise gap is 1, 0.5 and 0.25.

    The expected values (48, 12 and 3) are consistent with a sum of squared differences over the
    48 elements.
    """
    generated = torch.zeros(3, 4, 4)
    target = torch.ones(3, 4, 4)
    loss_gap_1 = content_loss(generated, target)
    generated = generated + 0.5
    loss_gap_half = content_loss(generated, target)
    target = target - 0.25
    loss_gap_quarter = content_loss(generated, target)
    # Every loss is computed before the first check
    expectations = ((loss_gap_1, 4*4*3), (loss_gap_half, 4*3), (loss_gap_quarter, 3))
    for actual, expected in expectations:
        assert actual == expected

def test_smoothness_reg():
    """Check `smoothness_reg` on constant, striped and checkerboard (3, 4, 4) images.

    Neighbouring values in the test images are either equal or differ by exactly 1, so the expected
    results are plain counts of differing neighbour pairs (3 channels x 3 transitions x 4 lines for
    each striped direction, and both directions together for the checkerboard).
    """
    uniform_img = torch.Tensor([[[1]*4]*4]*3)
    # Lines of ones and zeros alternate along dim 1: only neighbours along that dim differ
    diffy_x_img = torch.Tensor([[[1]*4, [0]*4]*2]*3)
    # Ones and zeros alternate along dim 2: only neighbours along that dim differ
    diffy_y_img = torch.Tensor([[[1, 0]*2]*4]*3)
    # Checkerboard: every pair of neighbours differs, along both dims
    diffy_x_y_img = torch.Tensor([[[1, 0]*2, [0, 1]*2]*2]*3)

    assert smoothness_reg(uniform_img) == 0
    assert smoothness_reg(diffy_x_img) == 3*3*4
    assert smoothness_reg(diffy_y_img) == 3*4*3
    assert smoothness_reg(diffy_x_y_img) == 2*3*4*3

def test_split_in_patches():
    """Check `split_in_patches` against windows sliced by hand from a (3, 4, 4) ramp image.

    For a patch size k, the expected result holds every k x k window of the image (all channels,
    stride 1), ordered row by row according to the top-left corner of the window; i.e., a tensor
    of size ((4-k+1)**2, 3, k, k).
    """
    channels, side = 3, 4
    # Values 0..47 laid out so that img[c, row, col] == 16*c + 4*row + col:
    # tensor([[[ 0.,  1.,  2.,  3.],
    #          [ 4.,  5.,  6.,  7.],
    #          [ 8.,  9., 10., 11.],
    #          [12., 13., 14., 15.]],
    #         [[16., ..., 31.]],
    #         [[32., ..., 47.]]])
    img = torch.Tensor(list(range(channels * side * side))).view(channels, side, side)

    def _expected_patches(k):
        # Number of valid window positions along each spatial dim
        n_positions = side - k + 1
        return torch.stack([img[:, top:top+k, left:left+k]
                            for top in range(n_positions)
                            for left in range(n_positions)])

    # `torch.stack` copies the data, so the expectations don't alias `img`
    expected_2x2 = _expected_patches(2)
    expected_3x3 = _expected_patches(3)
    actual_2x2 = split_in_patches(img, patch_sz=2)
    actual_3x3 = split_in_patches(img, patch_sz=3)
    assert(torch.equal(actual_2x2, expected_2x2))
    assert(torch.equal(actual_3x3, expected_3x3))

def test_style_loss(vgg19):
    """Check that two versions of the same flag are closer in style than flags of different countries.

    Style features of a rectangular and a circular Spanish flag and of an Austrian flag are extracted
    with `vgg19`; the loss between both Spanish flags must be lower than the loss between the Austrian
    and the Spanish one. Needs network access to download the images.

    NOTE(review): `FeaturesCalculator`, `vgg_style_layers_idx`, `vgg_content_layers_idx`,
    `split_in_patches` and `style_loss` aren't imported in the cells shown -- confirm where they come from.

    Args:
        vgg19: network handed to `FeaturesCalculator` to extract the features.
    """
    # NOTE(review): both Elon URLs are unused
    elon_photo_url = 'https://wonderfulengineering.com/wp-content/uploads/2018/09/musk5.jpg'
    elon_drawing_url = 'https://i.redd.it/ofljkrzi82r21.jpg'
    rectangular_spain_flag_url = 'https://upload.wikimedia.org/wikipedia/commons/d/d5/Flag_of_Spain_%28WFB_2000%29.jpg'
    #'http://icons.iconarchive.com/icons/wikipedia/flags/256/ES-Spain-Flag-icon.png'
    #hand_spain_flag_url = 'https://cdn.pixabay.com/photo/2015/02/16/00/20/spain-637843_960_720.jpg'
    circular_spain_flag_url = 'https://cdn.pixabay.com/photo/2017/10/04/10/44/spain-2815785_960_720.jpg'
    #'http://files.softicons.com/download/web-icons/world-cup-flags-icons-by-custom-icon-design/png/64x64/Spain.png'
    austria_flag_url = 'https://pixnio.com/free-images/flags-of-the-world/flag-of-austria-725x483.jpg'
    #'https://www.publicdomainpictures.net/pictures/120000/velka/austria-flag.jpg'
    ftrs_calc = FeaturesCalculator(vgg_style_layers_idx, vgg_content_layers_idx, vgg19)
    # NOTE(review): unlike the training inputs, the image isn't passed through `normalize` here -- confirm it's intended
    url_to_ftrs = lambda url: ftrs_calc.calc_style(img_t_from_url(url))

    # Random feature maps, only used by the disabled checks below
    ftr_map_1 = torch.rand(1, 16, 8, 8)
    ftr_map_2 = torch.rand(1, 16, 8, 8)    
    # `[0]` keeps the first item returned by `calc_style` (presumably the first style layer -- TODO confirm)
    ftr_map_sp_rect = url_to_ftrs(rectangular_spain_flag_url)[0]
    ftr_map_sp_circ = url_to_ftrs(circular_spain_flag_url)[0]
    ftr_map_austria = url_to_ftrs(austria_flag_url)[0]
    # Patches of the reference (rectangular flag) features are computed once and shared by both comparisons
    ftr_map_sp_rect_patches = split_in_patches(ftr_map_sp_rect)    
    
    loss_rect_circ = style_loss(ftr_map_sp_circ, ftr_map_sp_rect, ftr_map_sp_rect_patches)
    loss_rect_sp_aust = style_loss(ftr_map_austria, ftr_map_sp_rect, ftr_map_sp_rect_patches)
    # These two calls pass no precomputed patches and their results aren't checked (asserts disabled below)
    loss_equal = style_loss(ftr_map_1, ftr_map_1)
    loss_different = style_loss(ftr_map_1, ftr_map_2)
    #assert loss_equal == 0, f'{loss_equal}'
    #assert loss_different > 0, f'{loss_different}'
    assert loss_rect_circ < loss_rect_sp_aust, f'{loss_rect_circ}, {loss_rect_sp_aust}'

# Run the tests right away so that a failure shows up when the cell is executed.
# NOTE(review): `content_loss`, `smoothness_reg` and `split_in_patches` aren't imported in the cells shown;
# presumably they live in `genlab.style_transfer` -- confirm, otherwise these calls raise NameError
test_content_loss()
test_smoothness_reg()
test_split_in_patches()
# This isn't so easy to test
#test_style_loss(vgg)

# PENDING

* Penalize gen_img being out of range (be aware range is different for each channel)
* Check better if normalized cross-correlation is ok
* Don't forget to release plots, be it `plt.close()` or whatever
* How to deal with rectangular images? Right now, I'm resizing the larger dim to 224 and filling the rest with black padding. Other options are:
  * Crop (even a random crop probably wouldn't make much sense)
  * Check/think if vgg19 is actually capable of dealing with input sizes other than (224, 224) while preserving evaluation quality
* Fit one cycle may be used/adapted????
* Add a reference to the paper https://arxiv.org/pdf/1601.04589.pdf and authors
  
# TRAINING TEST STEPS
1. Use only content loss
2. Check adding style loss improves the results
    - Check `requires_grad` of tensors in the middle of the process
3. Check adding regularizer smoothes the result

# IMPROVEMENTS
* Precalc as much stuff related to content and style image as possible