In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import io
import math
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, List, Tuple

import matplotlib.pyplot as plt
import PIL
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from IPython import display
from torchvision.models.vgg import vgg19

import fastai.vision
import fastai.vision.transform as fastTF
from fastai.callbacks import hook_outputs

In [None]:
from abc import ABC, abstractmethod

class ProgressTracker(ABC):
    """Interface for objects that receive textual progress notifications."""

    @abstractmethod
    def notify(self, message:str):
        """Handle one progress `message`; concrete subclasses must override this."""


class PrinterProgressTracker(ProgressTracker):
    """`ProgressTracker` that reports every message by printing it."""

    def notify(self, message:str):
        """Print `message` to standard output."""
        print(message)

In [None]:
# Side length of the square patches extracted from feature maps by `style_loss`
patch_sz = 3
# Number of image channels (not referenced by the code visible in this notebook)
n_ch = 3
# Default side length, in pixels, of the square images produced by the input helpers
img_sz = 224
# Indices into `vgg.features` of the layers whose outputs feed the content loss
vgg_content_layers_idx = [22]
# Indices into `vgg.features` of the layers whose outputs feed the style loss
vgg_style_layers_idx = [11, 20]

In [None]:
@dataclass
class LossWeights:
    """Multipliers applied to each term of the total loss."""
    style: float = 1.0     # weight of the patch-based style term
    content: float = 1.0   # weight of the content term
    reg: float = 1e-3      # weight of the smoothness regularizer

In [None]:
def content_loss(output, target):
    """Return the sum of squared element-wise differences between two tensors.

    Args:
        output: tensor computed from the generated image.
        target: tensor of the same (or broadcastable) shape to compare against.

    Returns:
        A 0-dim tensor holding `sum((output - target)**2)`.
    """
    # Summing the squares directly avoids the sqrt-then-square round trip of
    # `torch.norm(...)**2`, which loses precision for no benefit.
    diff = output - target
    return (diff * diff).sum()

In [None]:
def smoothness_reg(gen_img:torch.Tensor) -> torch.Tensor:
    """Penalize differences between vertically and horizontally adjacent pixels.

    Equivalent to sum_i,j[(x[i][j+1] - x[i][j])**2 + (x[i+1][j] - x[i][j])**2],
    accumulated over every leading (channel and, if present, batch) dimension.

    Args:
        gen_img: tensor whose last two dimensions are the image rows and columns,
            e.g. a rank 3 (channels, rows, cols) or rank 4 (batch, channels, rows, cols) tensor.

    Returns:
        A 0-dim tensor with the sum of the squared neighbour differences.
    """
    # Index from the end so the spatial dims are picked correctly for both rank 3
    # and rank 4 inputs (indexing dims 1 and 2 would compare channels with each
    # other when a batch dimension is present).
    rows_diff = ((gen_img[..., 1:, :] - gen_img[..., :-1, :])**2).sum()
    cols_diff = ((gen_img[..., :, 1:] - gen_img[..., :, :-1])**2).sum()
    return rows_diff + cols_diff

In [None]:
def split_in_patches(t, patch_sz=3) -> torch.Tensor:
    """Extract every `patch_sz` x `patch_sz` patch (stride 1) from a feature map.

    Args:
        t: rank 3 (n_ftrs, rows, cols) or rank 4 (bs, n_ftrs, rows, cols) tensor.
        patch_sz: side length of the square patches.

    Returns:
        Tensor of size (n_patches, n_ftrs, patch_sz, patch_sz). Patches are ordered
        by batch element first, then row-major by their top-left position.
    """
    n_dims = t.dim()
    assert n_dims in (3, 4), 'Input must be a rank 3 or 4 tensor'
    batch = t.unsqueeze(0) if n_dims == 3 else t
    n_ftrs = batch.size(1)
    step = 1
    # size: (bs, n_ftrs, n_patch_rows, n_patch_cols, patch_sz, patch_sz)
    windows = batch.unfold(2, patch_sz, step).unfold(3, patch_sz, step)
    # Put the channel dim right before the patch dims, then flatten
    # (batch element, patch row, patch col) into a single patch index
    per_patch = windows.permute(0, 2, 3, 1, 4, 5)
    return per_patch.reshape(-1, n_ftrs, patch_sz, patch_sz)

In [None]:
def style_loss(gen_ftrs, style_ftrs, precalc_style_patches=None):
    """Patch-based style loss between generated and style feature maps.

    Every patch of `gen_ftrs` is matched with the style patch that has the highest
    normalized cross-correlation with it; the loss is the sum of squared differences
    between each generated patch and its best match.

    Args:
        gen_ftrs: feature map of the generated image, rank 3 or 4.
        style_ftrs: feature map(s) of the style image(s); passed to `F.conv2d`, so it
            must be a rank 4 tensor.
        precalc_style_patches: optional result of `split_in_patches(style_ftrs, patch_sz)`,
            to avoid recomputing it on every call.

    Returns:
        A 0-dim tensor with the loss value.
    """
    style_patches = (split_in_patches(style_ftrs, patch_sz) if precalc_style_patches is None
                     else precalc_style_patches)
    # n_style_patches will be greater than n_gen_patches when there are several versions
    # (probably transforms) of the style image, so that style_ftrs.size()[0] > 1
    n_style_patches = style_patches.size(0)
    gen_patches = split_in_patches(gen_ftrs, patch_sz)
    n_gen_patches = gen_patches.size(0)

    # Choosing the best match is a discrete decision: no gradient flows through it,
    # so there is no point in recording these operations in the autograd graph.
    with torch.no_grad():
        # size: (n_patches, 1)
        gen_patches_norm = gen_patches.reshape(n_gen_patches, -1).norm(dim=1, keepdim=True)
        style_patches_norm = style_patches.reshape(n_style_patches, -1).norm(dim=1, keepdim=True)
        # An all-zero patch has norm 0; without the clamp its correlations would be
        # 0/0 = NaN, and a NaN entry can end up being selected by `argmax`.
        norms_prod = (gen_patches_norm @ style_patches_norm.t()).clamp(min=1e-8)

        # size: (n_gen_patches, n_style_patches)
        # row `i` contains a measure of the similarity between patch `i` of
        # ftrs(generated image) and every patch of ftrs(style image[s])
        # (conv out size is (n_style_imgs, n_gen_patches, n_patch_rows, n_patch_cols),
        # so we need to swap the first two dimensions and resize)
        raw_correlations = (F.conv2d(style_ftrs, weight=gen_patches)
                            .permute(1, 0, 2, 3)
                            .reshape(n_gen_patches, n_style_patches))
        idx_best_matches = (raw_correlations / norms_prod).argmax(dim=-1)

    # referred to as NN(i) in the paper
    # row `i` contains the best match found, between the patches extracted
    # from ftrs(style image), for the patch `i` extracted from ftrs(gen image)
    best_matches = style_patches[idx_best_matches]

    # Sum the squares directly instead of `norm()**2`, which would take a square
    # root only to undo it right after
    return ((gen_patches - best_matches)**2).sum()

In [None]:
class FeaturesCalculator:
    """Run images through a VGG network and collect the outputs of selected layers.

    Forward hooks are registered on `vgg.features[idx]` for every requested index.
    Call `clean()` (or use the instance as a context manager) to remove them once
    the calculator is no longer needed, which matters when `vgg` is shared.
    """

    def __init__(self, vgg_style_layers_idx:List[int], vgg_content_layers_idx:List[int],
                 vgg:nn.Module=None, normalize_inputs=False):
        """
        Args:
            vgg_style_layers_idx: indices into `vgg.features` of the style layers.
            vgg_content_layers_idx: indices into `vgg.features` of the content layers.
            vgg: network to use; a pretrained `vgg19` is created when omitted.
                It is switched to eval mode.
            normalize_inputs: whether to normalize the images with the ImageNet
                stats before every forward pass.
        """
        self.vgg = vgg19(pretrained=True) if vgg is None else vgg
        self.vgg.eval()
        modules_to_hook = [self.vgg.features[idx] for idx in (*vgg_style_layers_idx, *vgg_content_layers_idx)]
        # detach=False: the stored outputs must stay attached to the graph so that
        # losses computed from them can be backpropagated to the input image
        self.hooks = hook_outputs(modules_to_hook, detach=False)
        n_style_layers = len(vgg_style_layers_idx)
        self.style_ftrs_hooks = self.hooks[:n_style_layers]
        self.content_ftrs_hooks = self.hooks[n_style_layers:]
        self.normalize_inputs = normalize_inputs

    def clean(self):
        """Remove the forward hooks registered on the network."""
        self.hooks.remove()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.clean()

    def _get_hooks_out(self, hooks):
        """Return the output stored by each hook during the last forward pass."""
        return [h.stored for h in hooks]

    def _forward(self, img_t:torch.Tensor):
        """Run the network on `img_t` (normalizing it first if configured) to fill the hooks."""
        if self.normalize_inputs:
            mean, std = fastai.vision.imagenet_stats
            img_t = fastai.vision.normalize(img_t,
                                            torch.tensor(mean, device=img_t.device),
                                            torch.tensor(std, device=img_t.device))
        self.vgg(img_t)

    def calc_style(self, img_t:torch.Tensor) -> List[torch.Tensor]:
        """Return the outputs of the style layers for `img_t`, one tensor per layer."""
        self._forward(img_t)
        return self._get_hooks_out(self.style_ftrs_hooks)

    def calc_content(self, img_t:torch.Tensor) -> List[torch.Tensor]:
        """Return the outputs of the content layers for `img_t`, one tensor per layer."""
        self._forward(img_t)
        return self._get_hooks_out(self.content_ftrs_hooks)

    def calc_style_and_content(self, img_t:torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Return `(style outputs, content outputs)` for `img_t` using a single forward pass."""
        self._forward(img_t)
        style_ftrs = self._get_hooks_out(self.style_ftrs_hooks)
        content_ftrs = self._get_hooks_out(self.content_ftrs_hooks)
        return style_ftrs, content_ftrs


def calc_loss(gen_img_t:torch.Tensor, gen_style_ftrs:List[torch.Tensor], gen_content_ftrs:List[torch.Tensor],
              target_style_ftrs:List[torch.Tensor], target_content_ftrs:List[torch.Tensor],
              style_patches:List[torch.Tensor]=None, loss_weights:LossWeights=None) -> torch.Tensor:
    """Weighted sum of the style loss, the content loss and the smoothness regularizer.

    Args:
        gen_img_t: image being generated.
        gen_style_ftrs: style feature maps of the generated image, one per hooked layer.
        gen_content_ftrs: content feature maps of the generated image, one per hooked layer.
        target_style_ftrs: style feature maps of the style image, in the same layer order.
        target_content_ftrs: content feature maps of the content image, in the same layer order.
        style_patches: precomputed patches of each `target_style_ftrs` entry; when None,
            `style_loss` extracts them itself.
        loss_weights: weight of each term; `LossWeights()` defaults are used when None.
            A term whose weight is not positive is skipped altogether.

    Returns:
        A 0-dim tensor with the total loss. The function asserts that it requires grad.
    """
    if loss_weights is None: loss_weights = LossWeights()
    # 0-dim zero with the dtype and device of the generated image, so the
    # accumulators also work when the tensors don't live on the CPU
    zero = gen_img_t.new_zeros(())

    # Iterate over feature maps produced by different cnn layers
    s_loss = zero
    if loss_weights.style > 0.:
        for i, gen_style_ftr_map in enumerate(gen_style_ftrs):
            layer_patches = None if style_patches is None else style_patches[i]
            s_loss = s_loss + style_loss(gen_style_ftr_map, target_style_ftrs[i], layer_patches)
        assert s_loss.requires_grad

    c_loss = zero
    if loss_weights.content > 0.:
        for i, gen_content_ftr_map in enumerate(gen_content_ftrs):
            c_loss = c_loss + content_loss(gen_content_ftr_map, target_content_ftrs[i])
        assert c_loss.requires_grad

    reg = smoothness_reg(gen_img_t) if loss_weights.reg > 0 else zero

    loss = loss_weights.style * s_loss + loss_weights.content * c_loss + loss_weights.reg * reg
    assert loss.requires_grad
    return loss

# INPUTS MANAGEMENT

In [None]:
# !git clone https://github.com/mf1024/ImageNet-Datasets-Downloader.git "C:/Users/blabla/ImageNetDownloader"
def download_imagenet_subset(downloader_dir='C:/Users/blabla/ImageNetDownloader',
                             data_root='C:/Users/blabla/imagenet',
                             n_classes:int=5, imgs_per_class:int=10):
    """Run the ImageNet-Datasets-Downloader script to fetch a small ImageNet subset.

    Args:
        downloader_dir: directory containing the cloned downloader (`downloader.py`).
        data_root: directory passed to the script as `-data_root`.
        n_classes: value passed as `-number_of_classes`.
        imgs_per_class: value passed as `-images_per_class`.

    Raises:
        FileNotFoundError: if `downloader.py` is not found in `downloader_dir`.
        subprocess.CalledProcessError: if the script exits with a non-zero status.
    """
    script = Path(downloader_dir) / 'downloader.py'
    if not script.is_file():
        raise FileNotFoundError(f'ImageNet downloader script not found: {script}')
    # An argument list (no shell) keeps paths with spaces intact; `sys.executable`
    # runs the script with the same interpreter as the kernel
    cmd = [sys.executable, str(script),
           '-data_root', str(data_root),
           '-number_of_classes', str(n_classes),
           '-images_per_class', str(imgs_per_class)]
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                          universal_newlines=True)
    # Echo the captured output so it shows up in the notebook cell
    print(proc.stdout)
    proc.check_returncode()
    
def get_img_from_url(url) -> PIL.Image.Image:
    """Download the image at `url` and open it with PIL.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not answer within 30 seconds.
    """
    response = requests.get(url, timeout=30)
    # Fail with the HTTP error itself instead of a confusing PIL
    # "cannot identify image file" raised on an error page
    response.raise_for_status()
    # `response.content` is the fully downloaded, content-decoded body
    # (the undecoded `response.raw` stream may still be gzip-compressed)
    return PIL.Image.open(io.BytesIO(response.content))

In [None]:
def get_transformed_style_imgs(img:PIL.Image.Image, target_sz:int=img_sz) -> torch.Tensor:
    """Build a batch with zoomed and rotated versions of a style image.

    One version is generated for every combination of 7 zoom scales and 5 rotation
    angles; each is resized to `target_sz` with zero padding.

    Args:
        img: style image.
        target_sz: size passed to fastai's `apply_tfms` for every generated version.

    Returns:
        The versions concatenated along the first (batch) dimension. They are NOT
        normalized.
    """
    scales = (0.85, 0.9, 0.95, 1, 1.05, 1.1, 1.15)
    rotations = (-15, -7.5, 0, 7.5, 15)
    # Force 3 channels: grayscale, palette or RGBA inputs would otherwise yield
    # tensors with a different number of channels
    fast_img = fastTF.Image(TF.to_tensor(img.convert('RGB')))
    imgs = []
    for scale in scales:
        for rotation in rotations:
            new_img = fast_img.apply_tfms([fastTF.rotate(degrees=rotation), fastTF.zoom(scale=scale)],
                                          size=target_sz,
                                          resize_method=fastai.vision.ResizeMethod.PAD,
                                          padding_mode='zeros')
            imgs.append(fast_img_to_tensor(new_img))
    return torch.cat(imgs)

In [None]:
def img_t_from_url(url, target_sz:int=img_sz) -> torch.Tensor:
    """Download the image at `url` and convert it with `img_to_tensor` using `target_sz`."""
    return img_to_tensor(get_img_from_url(url), target_sz)
    
def img_t_from_path(path, target_sz:int=img_sz) -> torch.Tensor:
    """Open the image stored at `path` and convert it with `img_to_tensor`.

    Args:
        path: location of the image file.
        target_sz: side length of the resulting square image.
    """
    # The context manager closes the underlying file once the pixel data has
    # been converted to a tensor
    with PIL.Image.open(path) as img:
        return img_to_tensor(img, target_sz)

def img_to_tensor(img:PIL.Image.Image, target_sz:int=img_sz) -> torch.Tensor:
    """Convert a PIL image into a square image tensor with a batch dimension.

    The image is converted to RGB, padded on its shorter dimension until it is
    square, and resized to `target_sz` x `target_sz` when needed.

    Args:
        img: input image.
        target_sz: side length of the output image.

    Returns:
        Tensor of size (1, 3, target_sz, target_sz).
    """
    # The rest of the pipeline (ImageNet stats, VGG) expects exactly 3 channels
    if img.mode != 'RGB': img = img.convert('RGB')
    width, height = img.size
    missing = abs(width - height)
    if missing:
        # Split the padding between both sides; the second side takes the extra
        # pixel of an odd difference so the result is exactly square
        before, after = missing // 2, missing - missing // 2
        # torchvision 4-tuple padding order: (left, top, right, bottom)
        padding = (0, before, 0, after) if width > height else (before, 0, after, 0)
        img = TF.pad(img, padding=padding)
    target_sz_2d = (target_sz, target_sz)
    if img.size != target_sz_2d: img = TF.resize(img, target_sz_2d)
    return TF.to_tensor(img).unsqueeze(0)

def fast_img_to_tensor(img:fastai.vision.Image) -> torch.Tensor:
    """Return the pixel tensor of a fastai image with an extra leading dimension of size 1."""
    pixels = img.px
    return pixels.unsqueeze(0)

def check_norm_is_needed_vgg(x, vgg:nn.Module=None):
    """Compare a VGG's output for `x` with and without ImageNet normalization.

    Args:
        x: batch of images to feed to the network.
        vgg: network to evaluate; a pretrained `vgg19` is created when omitted.
            It is switched to eval mode.

    Returns:
        `(max, argmax)` of the output for the normalized input followed by
        `(max, argmax)` of the output for the raw input, as a 4-tuple.
    """
    if vgg is None: vgg = vgg19(pretrained=True)
    vgg.eval()
    stats = (torch.Tensor([0.485, 0.456, 0.406]),
             torch.Tensor([0.229, 0.224, 0.225]))
    # Only the outputs are inspected, no gradients are needed
    with torch.no_grad():
        out = vgg(fastai.vision.normalize(x, *stats))
        out2 = vgg(x)
    return out.max(), out.argmax(), out2.max(), out2.argmax()
    

In [None]:
def normalize(img_t:torch.Tensor):
    """Normalize `img_t` with the mean and std in `fastai.vision.imagenet_stats`."""
    imagenet_mean, imagenet_std = fastai.vision.imagenet_stats
    mean_t = torch.tensor(imagenet_mean)
    std_t = torch.tensor(imagenet_std)
    return fastai.vision.normalize(img_t, mean_t, std_t)
    
def denormalize(img_t:torch.Tensor):
    """Undo `normalize`, using the mean and std in `fastai.vision.imagenet_stats`."""
    imagenet_mean, imagenet_std = fastai.vision.imagenet_stats
    mean_t = torch.tensor(imagenet_mean)
    std_t = torch.tensor(imagenet_std)
    return fastai.vision.denormalize(img_t, mean_t, std_t)

# DATA

In [None]:
class ImageURLs:
    """URLs of the sample images used as style and content inputs in this notebook."""
    # Used as the style image in the training cells
    PAINTING = 'https://www.moma.org/media/W1siZiIsIjQ2NzUxNyJdLFsicCIsImNvbnZlcnQiLCItcXVhbGl0eSA5MCAtcmVzaXplIDIwMDB4MjAwMFx1MDAzZSJdXQ.jpg?sha=314ebf8cc678676f'
    # Used as the content image in the training cells
    BASKET_BALL = 'https://miro.medium.com/proxy/1*BDE-SkJBCG_7P4chK4vKnw.jpeg'
    # The remaining URLs are not referenced by the cells visible in this notebook
    FERRARI_F1 = 'https://upload.wikimedia.org/wikipedia/commons/thumb/a/a5/Ferrari_F1_2006_EMS.jpg/1024px-Ferrari_F1_2006_EMS.jpg'
    RENAULT_F1 = 'https://upload.wikimedia.org/wikipedia/commons/3/31/Renault_F1_front_IAA_2005.jpg'
    ELON_MUSK = 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/49/Elon_Musk_2015.jpg/408px-Elon_Musk_2015.jpg'
    MARK_ZUCK = 'https://live.staticflickr.com/6156/6198197101_9d7a685618_b.jpg'

In [None]:
# Download the style (painting) and content (basket ball) images, convert them to
# square tensors of the default size and normalize them with the ImageNet stats
painting_img_t = normalize(img_t_from_url(ImageURLs.PAINTING))
basket_ball_img_t = normalize(img_t_from_url(ImageURLs.BASKET_BALL))

In [None]:
# Visual check of both inputs: take the first element of each batch and undo the
# normalization before showing it
fastai.vision.show_image(denormalize(basket_ball_img_t[0]))
fastai.vision.show_image(denormalize(painting_img_t[0]))

In [None]:
# Normalized batch with the zoomed/rotated versions of the painting
# (not referenced by the later cells visible in this notebook)
painting_imgs_t = normalize(get_transformed_style_imgs(get_img_from_url(ImageURLs.PAINTING)))

# TRAINING

In [None]:
@dataclass
class HyperParams:
    """Settings of the Adam optimizer created by `train`."""
    lr: float = 1e-4                                  # learning rate
    wd: float = 0.0                                   # weight decay
    adam_betas: Tuple[float, float] = (0.9, 0.999)    # Adam `betas`

In [None]:
def train(style_img_t:torch.Tensor, content_img_t:torch.Tensor, init_gen_img_t:torch.Tensor=None,
          n_iters=100, hyperparams:HyperParams=None, loss_weights:LossWeights=None, 
          progress_tracker:ProgressTracker=None, callbacks:List[Callable]=None) -> torch.Tensor:
    """Optimize an image so that it combines the given content with the given style.

    Args:
        style_img_t: normalized style image(s), as a batch.
        content_img_t: normalized content image, as a batch.
        init_gen_img_t: starting point of the optimization. It is copied, never
            modified. When None, a normalized random image with the size of
            `content_img_t` is used.
        n_iters: number of optimization steps.
        hyperparams: optimizer settings; `HyperParams()` defaults when None.
        loss_weights: forwarded to `calc_loss`.
        progress_tracker: notified after every iteration, if given.
        callbacks: callables invoked after every iteration as
            `callback(iteration index, generated image, loss)`.

    Returns:
        The generated image tensor (it still requires grad).
    """
    if init_gen_img_t is None:
        gen_img_t = normalize(torch.rand(content_img_t.size()))
    else:
        # Optimize a detached copy: the optimizer needs a leaf tensor and the
        # caller's tensor must not be altered as a side effect
        gen_img_t = init_gen_img_t.detach().clone()
    gen_img_t.requires_grad_(True)
    if hyperparams is None: hyperparams = HyperParams()
    opt = torch.optim.Adam([gen_img_t], lr=hyperparams.lr, betas=hyperparams.adam_betas, 
                           weight_decay=hyperparams.wd)
    ftrs_calc = FeaturesCalculator(vgg_style_layers_idx, 
                                   vgg_content_layers_idx)
    # Only the image is optimized. Freezing the network created above avoids
    # computing (and endlessly accumulating) gradients for its weights.
    for param in ftrs_calc.vgg.parameters(): param.requires_grad_(False)

    try:
        with torch.no_grad():
            target_style_ftrs = ftrs_calc.calc_style(style_img_t)
            target_content_ftrs = ftrs_calc.calc_content(content_img_t)
        style_patches = [split_in_patches(ftr_map, patch_sz) for ftr_map in target_style_ftrs]

        for i in range(n_iters):
            gen_style_ftrs, gen_content_ftrs = ftrs_calc.calc_style_and_content(gen_img_t)

            loss = calc_loss(gen_img_t, gen_style_ftrs, gen_content_ftrs, target_style_ftrs, 
                             target_content_ftrs, style_patches, loss_weights)
            loss.backward()
            opt.step()
            opt.zero_grad()

            if callbacks is not None: 
                for c in callbacks: c(i, gen_img_t, loss)
            if progress_tracker is not None: progress_tracker.notify(f'Completed iteration {i}')
    finally:
        # Release the forward hooks (and the feature maps they keep alive)
        ftrs_calc.hooks.remove()

    return gen_img_t

In [None]:
def train_progressive_growing(style_img:PIL.Image.Image, content_img:PIL.Image.Image, target_sz:int,
                              init_sz:int=16, upsample_mode='bilinear', transform_style_img=True,
                              n_iters_by_sz:int=200, **train_kwargs) -> torch.Tensor:
    """Run `train` at increasing resolutions, from `init_sz` up to `target_sz`.

    The size is doubled after every stage and the result of a stage, upsampled,
    is the starting point of the next one. The last stage always runs at exactly
    `target_sz`, even when `target_sz` is not `init_sz` times a power of two.

    Args:
        style_img: style image.
        content_img: content image.
        target_sz: side length of the final image.
        init_sz: side length used by the first stage.
        upsample_mode: `mode` passed to `F.interpolate` between stages.
        transform_style_img: use the batch produced by `get_transformed_style_imgs`
            instead of the single style image.
        n_iters_by_sz: number of iterations of every stage.
        **train_kwargs: forwarded to `train`.

    Returns:
        The generated image tensor produced by the last stage.
    """
    assert 0 < init_sz <= target_sz
    # `align_corners` is only accepted by the interpolating modes; passing it
    # together with e.g. 'nearest' makes `F.interpolate` raise
    interp_kwargs = ({'align_corners': False}
                     if upsample_mode in ('linear', 'bilinear', 'bicubic', 'trilinear') else {})
    cur_sz = init_sz
    gen_img_t = None
    while True:
        if gen_img_t is not None:
            gen_img_t = F.interpolate(gen_img_t.detach(), cur_sz, mode=upsample_mode, **interp_kwargs)
        if transform_style_img:
            # NOTE(review): relies on `get_transformed_style_imgs` accepting the output
            # size as its second positional argument. Its result is not normalized,
            # so it's done here, like in the other branch.
            style_img_t = normalize(get_transformed_style_imgs(style_img, cur_sz))
        else:
            style_img_t = normalize(img_to_tensor(style_img, cur_sz))
        content_img_t = normalize(img_to_tensor(content_img, cur_sz))
        gen_img_t = train(style_img_t, content_img_t, gen_img_t, n_iters=n_iters_by_sz,
                          **train_kwargs)
        if cur_sz >= target_sz: break
        # Double the size, but never overshoot the requested final size
        cur_sz = min(cur_sz * 2, target_sz)

    return gen_img_t

In [None]:
def get_print_results_callback(n_iters_between:int, n_total_iters:int, show_sz=False):
    """Create a `train` callback that shows the generated image every few iterations.

    A grid with room for `n_total_iters // n_iters_between` snapshots (3 per row) is
    created up front; every `n_iters_between` calls, the current image is drawn in
    the next free cell and the displayed figure is refreshed.

    Args:
        n_iters_between: number of iterations between two consecutive snapshots.
        n_total_iters: total number of iterations the callback will be called for.
        show_sz: whether to include the image size in the title of every snapshot.

    Returns:
        A callable with signature `(iteration index, generated image, loss)`.
    """
    n_imgs = n_total_iters // n_iters_between
    n_cols = 3
    # At least one row, so the grid can be created even if no snapshot is expected
    n_rows = max(1, math.ceil(n_imgs / n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(16, 16 * n_rows/n_cols))
    axs = axs.flatten()
    for ax in axs: ax.axis('off')
    n_shown = 0
    display_id = None

    def _print_result(i, gen_img_t, loss):
        nonlocal n_shown, display_id
        if (i+1) % n_iters_between != 0: return
        # More snapshots than grid cells (`n_total_iters` was too small): skip the
        # extra ones instead of aborting the training with an IndexError
        if n_shown >= len(axs): return

        gen_img_t = denormalize(gen_img_t.detach().cpu()).squeeze(dim=0).clamp(0, 1)
        gen_img = fastai.vision.Image(gen_img_t)
        # If using progressive growing, it's only right if the number of iterations 
        # by size is a multiple of `n_iters_between`
        iter_idx = (n_shown + 1) * n_iters_between
        title = f'Iteration {iter_idx}, loss = {loss.detach().cpu()}'
        if show_sz: title += f', size = {gen_img.size}'
        # Earlier snapshots are already drawn on their axes: only the new one is
        # added, instead of redrawing all of them on every call
        gen_img.show(ax=axs[n_shown], title=title)
        n_shown += 1

        if display_id is None: display_id = display.display(fig, display_id=True)
        display_id.update(fig)
        display.clear_output(wait=True)

    return _print_result

In [None]:
# Single-resolution run: 200 iterations starting from a random image, without the
# smoothness regularizer (reg=0.), showing a snapshot every 20 iterations
n_iters=200
gen_img_t = train(painting_img_t, 
                  basket_ball_img_t,
                  hyperparams=HyperParams(lr=0.1), 
                  n_iters=n_iters, 
                  loss_weights=LossWeights(style=0.1, content=1., reg=0.),
                  callbacks=[get_print_results_callback(20, n_iters)])

In [None]:
# Undo the normalization, drop the batch dimension and clamp the values to [0, 1]
denorm_clamped_gen_img_t = denormalize(gen_img_t.detach().cpu()).squeeze(dim=0).clamp(0, 1)
# Wrap the tensor in a fastai `Image`; as the last expression of the cell, it is
# shown as the cell output
gen_img = fastai.vision.Image(denorm_clamped_gen_img_t)
gen_img

In [None]:
# Save the result of the single-resolution run
gen_img.save('mrfcnn_tr1_200it.jpg')

In [None]:
# Sanity check of the (still normalized) result: value range, size and grad tracking
gen_img_t.min(), gen_img_t.max(), gen_img_t.size(), gen_img_t.requires_grad

In [None]:
# Distribution of the denormalized pixel values: key 0 counts the values inside
# the displayable range [0, 1]; any other key `k` counts the out-of-range values
# `v` with floor(v) == k
denorm_gen_img_t = denormalize(gen_img_t.detach())
in_range = (denorm_gen_img_t >= 0.) & (denorm_gen_img_t <= 1.)
# Vectorized bucketing: looping over the pixels one by one in Python is very slow
buckets = torch.where(in_range, torch.zeros_like(denorm_gen_img_t), denorm_gen_img_t.floor()).long()
keys, counts = torch.unique(buckets, return_counts=True)
pix_dist = {0: 0}
pix_dist.update(zip(keys.tolist(), counts.tolist()))

pix_dist

Some pixels of the denormalized result may fall outside the displayable [0, 1] range. They would most likely end up black/white anyway, so a dedicated solution (like penalizing out-of-range values in the loss function) is probably not worth it.

## Training with progressive growing

In [None]:
# Progressive growing run from 64 to 256 pixels, using the single (not transformed)
# style image and showing a snapshot every 40 iterations
style_img = get_img_from_url(ImageURLs.PAINTING)
content_img = get_img_from_url(ImageURLs.BASKET_BALL)
n_iters_by_sz = 200
init_sz, target_sz = 64, 256
# One stage per size, the size being doubled between stages (64, 128, 256 -> 3 stages)
n_total_iters = n_iters_by_sz * int(1 + math.log2(target_sz//init_sz))
gen_img_t = train_progressive_growing(style_img, 
                                      content_img, 
                                      target_sz,
                                      init_sz=init_sz,
                                      transform_style_img=False,
                                      n_iters_by_sz=n_iters_by_sz, 
                                      hyperparams=HyperParams(lr=0.1),
                                      loss_weights=LossWeights(style=0.1, content=1., reg=1e-3),
                                      callbacks=[get_print_results_callback(40, n_total_iters)])

In [None]:
# Undo the normalization, drop the batch dimension and clamp the values to [0, 1]
denorm_clamped_gen_img_t = denormalize(gen_img_t.detach().cpu()).squeeze(dim=0).clamp(0, 1)
# Wrap the tensor in a fastai `Image`; as the last expression of the cell, it is
# shown as the cell output
gen_img = fastai.vision.Image(denorm_clamped_gen_img_t)
gen_img

In [None]:
# Save the result of the progressive growing run
gen_img.save('mrfcnn_tr1_200it_pg64-256.jpg')

# TESTS

In [None]:
def test_content_loss():
    """Check `content_loss` on constant tensors, for which the expected value is known."""
    x = torch.zeros(3, 4, 4)
    y = torch.ones(3, 4, 4)
    # Every one of the 3*4*4 elements differs by 1, 0.5 and 0.25 respectively
    cases = [
        (content_loss(x, y), 4*4*3),
        (content_loss(x + 0.5, y), 4*3),
        (content_loss(x + 0.5, y - 0.25), 3),
    ]
    for actual, expected in cases:
        # The loss is a float computation: compare with a tolerance, not with `==`
        assert math.isclose(float(actual), expected, rel_tol=1e-5), f'{actual} != {expected}'

def test_smoothness_reg():
    """Check `smoothness_reg` on 3x4x4 images with known neighbour differences."""
    # All pixels equal: no differences at all
    uniform_img = torch.Tensor([[[1]*4]*4]*3)
    # Rows alternate between 1 and 0: 3 transitions per column
    diffy_x_img = torch.Tensor([[[1]*4, [0]*4]*2]*3)
    # Columns alternate between 1 and 0: 3 transitions per row
    diffy_y_img = torch.Tensor([[[1, 0]*2]*4]*3)
    # Checkerboard: every pair of neighbours differs, in both directions
    diffy_x_y_img = torch.Tensor([[[1, 0]*2, [0, 1]*2]*2]*3)

    assert smoothness_reg(uniform_img) == 0
    assert smoothness_reg(diffy_x_img) == 3*3*4
    assert smoothness_reg(diffy_y_img) == 3*4*3
    assert smoothness_reg(diffy_x_y_img) == 2*3*4*3

def test_split_in_patches():
    """Check `split_in_patches` against patches cut out of the image by plain slicing."""
    # 3 channels of 4x4 holding the values 0..47, e.g. the first channel is
    # [[ 0.,  1.,  2.,  3.],
    #  [ 4.,  5.,  6.,  7.],
    #  [ 8.,  9., 10., 11.],
    #  [12., 13., 14., 15.]]
    img = torch.Tensor([i for i in range(3*4*4)]).view(3, 4, 4)

    def _expected_patches(sz):
        # Patches in row-major order of their top-left corner, all channels each
        n_positions = 4 - sz + 1
        return torch.stack([img[:, row:row+sz, col:col+sz]
                            for row in range(n_positions)
                            for col in range(n_positions)])

    for sz in (2, 3):
        actual = split_in_patches(img, patch_sz=sz)
        assert(torch.equal(actual, _expected_patches(sz)))

def test_style_loss(vgg19):
    """Check that two Spanish flags are closer in style than a Spanish and an Austrian one.

    `vgg19` is the network instance handed to `FeaturesCalculator`. The images are
    downloaded, so this test needs network access.
    """
    elon_photo_url = 'https://wonderfulengineering.com/wp-content/uploads/2018/09/musk5.jpg'
    elon_drawing_url = 'https://i.redd.it/ofljkrzi82r21.jpg'
    rectangular_spain_flag_url = 'https://upload.wikimedia.org/wikipedia/commons/d/d5/Flag_of_Spain_%28WFB_2000%29.jpg'
    #'http://icons.iconarchive.com/icons/wikipedia/flags/256/ES-Spain-Flag-icon.png'
    #hand_spain_flag_url = 'https://cdn.pixabay.com/photo/2015/02/16/00/20/spain-637843_960_720.jpg'
    circular_spain_flag_url = 'https://cdn.pixabay.com/photo/2017/10/04/10/44/spain-2815785_960_720.jpg'
    #'http://files.softicons.com/download/web-icons/world-cup-flags-icons-by-custom-icon-design/png/64x64/Spain.png'
    austria_flag_url = 'https://pixnio.com/free-images/flags-of-the-world/flag-of-austria-725x483.jpg'
    #'https://www.publicdomainpictures.net/pictures/120000/velka/austria-flag.jpg'
    ftrs_calc = FeaturesCalculator(vgg_style_layers_idx, vgg_content_layers_idx, vgg19)

    def _first_style_ftr_map(url):
        # Output of the first style layer for the image found at `url`
        return ftrs_calc.calc_style(img_t_from_url(url))[0]

    rand_ftrs_a = torch.rand(1, 16, 8, 8)
    rand_ftrs_b = torch.rand(1, 16, 8, 8)
    sp_rect_ftrs = _first_style_ftr_map(rectangular_spain_flag_url)
    sp_circ_ftrs = _first_style_ftr_map(circular_spain_flag_url)
    austria_ftrs = _first_style_ftr_map(austria_flag_url)
    sp_rect_patches = split_in_patches(sp_rect_ftrs)

    loss_rect_circ = style_loss(sp_circ_ftrs, sp_rect_ftrs, sp_rect_patches)
    loss_rect_sp_aust = style_loss(austria_ftrs, sp_rect_ftrs, sp_rect_patches)
    loss_equal = style_loss(rand_ftrs_a, rand_ftrs_a)
    loss_different = style_loss(rand_ftrs_a, rand_ftrs_b)
    #assert loss_equal == 0, f'{loss_equal}'
    #assert loss_different > 0, f'{loss_different}'
    assert loss_rect_circ < loss_rect_sp_aust, f'{loss_rect_circ}, {loss_rect_sp_aust}'

# Run the tests that need neither a network connection nor a pretrained model
test_content_loss()
test_smoothness_reg()
test_split_in_patches()
# This isn't so easy to test
#test_style_loss(vgg)

# PENDING

* Prepare for GPU and test if it improves anything in terms of speed
* Penalize gen_img being out of range (be aware range is different for each channel)
* Check better if normalized cross-correlation is ok
* Don't forget to release the figures once they are no longer needed, e.g. with `plt.close(fig)`
* How to deal with rectangular images? Right now, I'm resizing the larger dim to 224 and filling the rest with black padding. Other options are:
  * Crop (even a random crop probably wouldn't make much sense)
  * Check/think if vgg19 is actually capable of dealing with input sizes other than (224, 224) while preserving evaluation quality
* Could the "fit one cycle" policy be used/adapted here?
* Add a reference to the paper https://arxiv.org/pdf/1601.04589.pdf and authors
  
# TRAINING TEST STEPS
1. Use only content loss
2. Check adding style loss improves the results
    - Check `requires_grad` of the tensors in the middle of the process
3. Check adding regularizer smoothes the result

# IMPROVEMENTS
* Precalc as much stuff related to content and style image as possible