# Neural Style Transfer

Implementation of https://arxiv.org/abs/1705.06830 with an improved U-Net architecture for style transfer.

In [0]:
# Pin scipy to 1.3.3: needed due to a performance regression on InceptionNet
# with newer scipy versions.
!pip install --upgrade scipy==1.3.3

In [0]:
import torch
from torch import nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

import pickle

from pathlib import Path

import matplotlib.pyplot as plt

import pandas as pd

import numpy as np

from fastai.basic_train import Learner, Callback
from fastai.basic_data import DataBunch
from fastai.vision.data import ImageList
from fastai.vision.models import resnet50, resnet18, unet
from fastai.vision.transform import get_transforms
from fastai.metrics import accuracy
from fastai.layers import PixelShuffle_ICNR
from typing import List

from collections import namedtuple

from PIL import Image, UnidentifiedImageError

import torchvision.transforms as tfms
from torchvision import models

import os
import glob
import shutil
import random
import math

from inspect import signature

from copy import deepcopy

In [0]:
# Device every model and tensor in this notebook is moved to. Prefer the GPU,
# but fall back to the CPU so cells that do not need CUDA still run without one.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Data

In [0]:
# Per-channel normalisation statistics: applied by tfms.Normalize in
# UnlabelledImageDS and undone again in make_output_image.
# From https://pytorch.org/hub/pytorch_vision_vgg/
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

In [0]:
# Side length (in pixels) every image is resized to by UnlabelledImageDS.
image_size = 299

In [0]:
def split_list(l, frac_first: float, seed: int = None):
    """Randomly partition ``l`` into two lists.

    Args:
        l: List to split. It is deep-copied, so the input is never mutated.
        frac_first: Fraction of the elements that go into the first list;
            the count is ``int(frac_first * len(l))``.
        seed: Optional seed for the global ``random`` module. ``None`` leaves
            the generator untouched; any integer (including ``0``) reseeds it,
            which makes the split reproducible.

    Returns:
        ``(first, second)``: the shuffled elements cut at the computed index.
    """
    # `if seed:` would silently ignore a seed of 0, so compare against None.
    if seed is not None:
        random.seed(seed)

    copied = deepcopy(l)
    random.shuffle(copied)

    num_first = int(frac_first * len(copied))
    return copied[:num_first], copied[num_first:]

In [0]:
class UnlabelledImageDS(Dataset):
    """Dataset that yields transformed image tensors (no labels) from file paths."""

    def __init__(self, fnames: List[str]):
        """
        Args:
            fnames: Paths of the image files backing the dataset.
        """
        self.fnames = fnames

        # Resize to a fixed square, convert to a tensor and normalise with the
        # module-level `mean` / `std`.
        self.transforms = tfms.Compose([
            tfms.Resize((image_size, image_size)), 
            tfms.ToTensor(),
            tfms.Normalize(mean, std)
        ])
    
    def __len__(self): return len(self.fnames)

    def __getitem__(self, index):
        """Load image `index`, force it to RGB and return the transformed tensor."""
        # The context manager closes the file handle deterministically;
        # convert() returns a loaded copy that outlives the `with` block.
        with Image.open(self.fnames[index]) as img:
            rgb = img.convert('RGB')

        return self.transforms(rgb)

## Content Data

In [0]:
# Courtesy of https://forums.fast.ai/t/how-does-one-download-imagenet/40660/9
!wget http://files.fast.ai/data/imagenet-sample-train.tar.gz

In [0]:
%%capture
# Unpack the content archive; %%capture hides tar's verbose file listing.
!tar -zxvf imagenet-sample-train.tar.gz

In [0]:
# glob returns paths in filesystem-dependent order; sort so that anything
# derived from this list starts from the same ordering on every machine.
all_content = sorted(glob.glob("train/*/*"))
len(all_content)

In [0]:
# 80/20 train/validation split of the content images. The fixed seed keeps the
# shuffle reproducible across kernel restarts.
train_content, valid_content = split_list(all_content, 0.8, seed=42)

In [0]:
# Wrap the two content path lists in datasets.
content_train_ds = UnlabelledImageDS(train_content)
content_valid_ds = UnlabelledImageDS(valid_content)

# Single-image variant (train == valid), presumably for quick debugging runs.
# content_train_ds = UnlabelledImageDS(train_content[:1])
# content_valid_ds = content_train_ds

len(content_train_ds), len(content_valid_ds)

In [0]:
# Sanity-check one sample; show only its shape instead of dumping the tensor.
content_train_ds[0].shape

## Style Data

In [0]:
# Upload Kaggle token first.
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/kaggle.json

In [0]:
# Download the art-images dataset used as the source of style images.
!kaggle datasets download -d thedownhill/art-images-drawings-painting-sculpture-engraving

In [0]:
%%capture
# Unpack the style archive; %%capture hides unzip's per-file output.
!unzip art-images-drawings-painting-sculpture-engraving.zip

In [0]:
# Collect the painting files and keep only those PIL can identify; files that
# raise UnidentifiedImageError are dropped. Sorted because glob order is
# filesystem-dependent.
all_style_unfiltered = sorted(glob.glob("musemart/dataset_updated/training_set/painting/*"))
all_style = []

for path in all_style_unfiltered:
    try:
        # Opening is the whole check; the context manager closes the file
        # handle straight away instead of leaving one open per image.
        with Image.open(path):
            all_style.append(path)
    except UnidentifiedImageError:
        pass

len(all_style)

In [0]:
# 80/20 train/validation split of the style images. The fixed seed keeps the
# shuffle reproducible across kernel restarts.
train_style, valid_style = split_list(all_style, 0.8, seed=42)

In [0]:
# Wrap the two style path lists in datasets.
style_train_ds = UnlabelledImageDS(train_style)
style_valid_ds = UnlabelledImageDS(valid_style)

# SHORT
# Single-image variant (train == valid), presumably for quick debugging runs.
# style_train_ds = UnlabelledImageDS(train_style[:1])
# style_valid_ds = style_train_ds

len(style_train_ds), len(style_valid_ds)

In [0]:
# Sanity-check one sample; show only its shape instead of dumping the tensor.
style_train_ds[0].shape

## Dataset

In [0]:
class StyleTransferDS(Dataset):
  """Pairs each content item with a randomly drawn style item."""

  def __init__(self, content_ds, style_ds):
    self.content, self.style = content_ds, style_ds

  def __len__(self):
    # One sample per content item; styles are sampled with replacement.
    return len(self.content)

  def __getitem__(self, index):
    content_img = self.content[index]
    last_style_idx = len(self.style) - 1
    style_img = self.style[random.randint(0, last_style_idx)]

    # hack so there's a 'inp' and 'label' for fastai to split
    pair = (content_img, style_img)
    return pair, pair

In [0]:
# Combine content and style datasets into (content, style) pair datasets.
train_ds = StyleTransferDS(content_train_ds, style_train_ds)
valid_ds = StyleTransferDS(content_valid_ds, style_valid_ds)

In [0]:
# Batch size passed to DataBunch.create.
bs = 8

In [0]:
# fastai data container built from the train/validation pair datasets.
data = DataBunch.create(train_ds, valid_ds, bs=bs)

# Model

## Code

### Transformer

In [0]:
class Print(nn.Module):
    """Debugging wrapper: announces `name` on stdout, then delegates to `mod`."""

    def __init__(self, name, mod):
        super().__init__()
        self.name, self.mod = name, mod

    def forward(self, x):
        label = self.name
        print("STARTING: ", label)
        out = self.mod(x)
        return out


class SameConv(nn.Module):
    """Conv with same padding.

    The input is reflect-padded so that the spatial output size is
    ``ceil(size / stride)`` along each axis, for odd and even kernel sizes.
    """
    def __init__(self, cin: int, cout: int, kernel_size: int, stride=1):
        """
        Args:
            cin: Number of input channels.
            cout: Number of output channels.
            kernel_size: Side length of the square kernel.
            stride: Convolution stride (same along both axes).
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.conv = nn.Conv2d(cin, cout, kernel_size=kernel_size, stride=stride)

    def _total_pad(self, size: int) -> int:
        """Total padding along one axis so the conv output is ceil(size / stride)."""
        rounded_up = int(math.ceil(size / self.stride)) * self.stride
        # kernel_size - 1 equals 2 * (kernel_size // 2) for odd kernels and
        # avoids one pixel of over-padding for even kernels.
        return rounded_up - size + (self.kernel_size - 1)

    def forward(self, x):
        """Assumes has NCHW format"""
        height = x.shape[2]
        width = x.shape[3]

        pad_h = self._total_pad(height)
        pad_w = self._total_pad(width)

        pad_top = pad_h // 2
        pad_bot = pad_h - pad_top

        pad_left = pad_w // 2
        pad_right = pad_w - pad_left

        # F.pad lists padding starting from the LAST dimension:
        # (left, right) for width first, then (top, bottom) for height.
        padded = F.pad(x, [pad_left, pad_right, pad_top, pad_bot], mode='reflect')

        return self.conv(padded)


def conv3(cin, cout, stride=1):
  """Build a 3x3 `SameConv` mapping `cin` channels to `cout` channels."""
  return SameConv(cin, cout, kernel_size=3, stride=stride)


class StyleBatchNorm(nn.Module):
    """Instance-normalise the input, then apply a caller-supplied affine map.

    Despite the name, the normalisation is `nn.InstanceNorm2d` with its own
    affine parameters disabled; scale and shift are passed in on every call.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.norm = nn.InstanceNorm2d(channels, affine=False)

    def forward(self, inp, scale, shift):
        """Return `norm(inp) * scale + shift`."""
        out = self.norm(inp)
        out = out * scale
        return out + shift


class StyleConv(nn.Module):
    """Convolution -> activation -> style-conditioned instance normalisation.

    The per-channel scale and shift applied after normalisation are bias-free
    linear projections of the `scale` / `shift` style vectors.
    """

    def __init__(self, cin, cout, kernel_size, dim_scale, dim_shift, 
                 stride=1, act=nn.ReLU):
        super().__init__()

        # Project the style vectors down to one value per output channel.
        self.shift_dense = nn.Linear(dim_shift, cout, bias=False)
        self.scale_dense = nn.Linear(dim_scale, cout, bias=False)

        self.conv = SameConv(cin, cout, kernel_size, stride)
        # Named `relu` for historical reasons; it is whatever `act` builds.
        self.relu = act()
        self.norm = StyleBatchNorm(cout)

    def forward(self, x, scale, shift):
        """Apply the block to `x` using style vectors `scale` and `shift`."""
        n_scale, n_shift = scale.shape[0], shift.shape[0]

        # Reshape to (batch, channels, 1, 1) so they broadcast over H and W.
        channel_scale = self.scale_dense(scale).view([n_scale, -1, 1, 1])
        channel_shift = self.shift_dense(shift).view([n_shift, -1, 1, 1])

        activated = self.relu(self.conv(x))
        return self.norm(activated, channel_scale, channel_shift)


def sconv3(cin, cout, dim_scale, dim_shift, stride=1, act=nn.ReLU):
    """Shorthand for a `StyleConv` with a 3x3 kernel."""
    kernel_size = 3
    return StyleConv(cin, cout, kernel_size, dim_scale, dim_shift,
                     stride=stride, act=act)


class StyleResidualBlock(nn.Module):
    """Two 3x3 style convolutions wrapped in an additive skip connection.

    The second convolution uses `Identity` as its activation.
    """

    def __init__(self, cin: int, dim_scale: int, dim_shift: int):
        super().__init__()

        # Channel count is preserved so the input can be added back on.
        self.conv1 = sconv3(cin, cin, dim_scale, dim_shift)
        self.conv2 = sconv3(cin, cin, dim_scale, dim_shift, act=Identity)

    def forward(self, inp, scale, shift):
        hidden = self.conv1(inp, scale, shift)
        residual = self.conv2(hidden, scale, shift)
        return residual + inp


class PixelShuffleWrapper(nn.Module):
    """Adapts `PixelShuffle_ICNR` to the (input, scale, shift) call convention."""

    def __init__(self, ni, nf, scale):
        super().__init__()
        self.mod = PixelShuffle_ICNR(ni, nf, scale)

    def forward(self, inp, ignore_a, ignore_b):
        # The two extra arguments exist only for interface parity and are unused.
        upsampler = self.mod
        return upsampler(inp)


class UpsampleBlock(nn.Module):
  """One decoder stage: upsample, merge with a stored skip activation, convolve.

  A forward hook on `left` captures its most recent output. `forward` brings
  `right` to `left_cin` channels (pixel shuffle when `scale_up != 1`, a 3x3
  style conv otherwise), resizes it to the skip activation's spatial size if
  needed, concatenates both along the channel axis and applies a 3x3 style
  conv that outputs `left_cin` channels.
  """
  def __init__(self, left: nn.Module, left_cin: int, right_cin: int, 
               scale_up: int, dim_scale: int, dim_shift: int):
    """
    Args:
        left: Module whose output is used as the skip connection. It is only
            hooked, not registered as a submodule of this block.
        left_cin: Channel count of `left`'s output (and of this block's output).
        right_cin: Channel count of the tensor passed to `forward`.
        scale_up: Upsampling factor; `1` selects a plain style conv instead.
        dim_scale: Length of the style scale vector.
        dim_shift: Length of the style shift vector.
    """
    super().__init__()
    
    if scale_up != 1:
        self.upsample = PixelShuffleWrapper(right_cin, left_cin, scale_up)
    else:
        self.upsample = sconv3(right_cin, left_cin, dim_scale, dim_shift)
    
    self.conv1 = sconv3(left_cin * 2, left_cin, dim_scale, dim_shift)

    self.stored_left = None

    def hook_fn(m,i,o):
        self.stored_left = o

    self.hook = left.register_forward_hook(hook_fn)

  def forward(self, right, scale, shift):
    """Decode `right` using the activation last captured from `left`.

    Raises:
        RuntimeError: If `left` has not produced an output yet.
    """
    left = self.stored_left
    if left is None:
        raise RuntimeError(
            "UpsampleBlock has no stored skip activation; run the hooked "
            "module before calling this block.")

    upsampled = self.upsample(right, scale, shift)

    # Compare height AND width: checking only the height lets a width-only
    # mismatch through, which then fails in torch.cat.
    if upsampled.shape[2:] != left.shape[2:]:
        upsampled = F.interpolate(upsampled, size=left.shape[2:])

    cat = torch.cat([left, upsampled], dim=1)

    return self.conv1(cat, scale, shift)


class Identity(nn.Module):
    """No-op module: `forward` hands back its argument unchanged."""

    def forward(self, x):
        passthrough = x
        return passthrough


class Style_UNET(nn.Module):
    """U-Net style decoder built automatically around a style-conditioned encoder.

    At construction time a dummy batch is pushed through `encoder` while hooks
    on `children` record each child's output shape. Starting from an identity
    module that stands for the raw input, the FIRST module seen at each
    distinct spatial height is kept. All kept modules except the last become
    skip-connection sources, and one `UpsampleBlock` is created per source,
    ordered from the lowest resolution back up to the input.
    """
    def __init__(self, encoder:nn.Module, children, dim_scale: int, dim_shift: int, test_img_shape=(1, 3, 200, 200)):
        """
        Args:
            encoder: Module called as `encoder(img, scale, shift)`.
            children: The encoder's sub-modules, in execution order, whose
                output shapes are probed.
            dim_scale: Length of the style scale vector.
            dim_shift: Length of the style shift vector.
            test_img_shape: NCHW shape of the dummy input used for probing.
        """
        super().__init__()

        ident = Identity()
        class Enc(nn.Module):
            """Routes the input through `ident` first so it can be hooked as a skip source."""
            def __init__(self, ident, enc):
                super().__init__()
                self.ident = ident
                self.enc = enc
            def forward(self, img, scale, shift):
                ident_out = self.ident(img)
                return self.enc(ident_out, scale, shift)

        self.encoder = Enc(ident, encoder).to(DEVICE)

        # (module, output shape) pairs; the identity stands for the raw input.
        all_pairs = [(ident, test_img_shape)]
        hooks = []
        for child in children:
            hook = child.register_forward_hook(lambda m,i,o: all_pairs.append((m, o.shape)))
            hooks.append(hook)

        dummy_inp = torch.ones(test_img_shape).to(DEVICE)
        dummy_scale = torch.ones(test_img_shape[0], dim_scale).to(DEVICE)
        dummy_shift = torch.ones(test_img_shape[0], dim_shift).to(DEVICE)

        try:
            # Only the shapes matter, so skip building an autograd graph.
            with torch.no_grad():
                encoder(dummy_inp, dummy_scale, dummy_shift)
        finally:
            # Always detach the probing hooks, even if the dummy pass fails.
            for h in hooks:
                h.remove()

        assert len(all_pairs) == len(children) + 1

        # `prev` only advances when the spatial height changes, so each kept
        # entry is the first module that produced that height.
        kept_pairs = []
        prev = all_pairs[0]

        for pair in all_pairs[1:]:
            if pair[1][2] != prev[1][2]:
                kept_pairs.append(prev)
                prev = pair

        kept_pairs.append(prev)

        change_modules = [p[0] for p in kept_pairs]
        change_shapes = [p[1] for p in kept_pairs]

        bottom_shape = change_shapes[-1]

        connecting_modules = change_modules[:-1]
        connecting_shapes = change_shapes[:-1]

        # Integer downsampling factor between consecutive kept resolutions.
        strides = []
        prev_shape = change_shapes[0]
        for shape in change_shapes[1:]:
            strides.append(prev_shape[2] // shape[2])
            prev_shape = shape

        upsamples = []

        assert len(connecting_modules) == len(connecting_shapes) == len(strides)

        # Build the decoder from the bottom of the U back up to the input.
        right_cin = bottom_shape[1]
        for down_module, down_shape, scale in zip(reversed(connecting_modules), 
                                                       reversed(connecting_shapes), 
                                                       reversed(strides)):
            
            upsamples.append(UpsampleBlock(down_module, down_shape[1], right_cin, scale, dim_scale, dim_shift).to(DEVICE))
            right_cin = down_shape[1]

        self.upsamples = nn.ModuleList(upsamples)

    def forward(self, x, scale, shift):
        """Encode `x`, then apply every upsample block in order."""
        # Running the encoder also fires the hooks that hand each
        # UpsampleBlock its skip activation.
        encoded = self.encoder(x, scale, shift)
        
        decoded = encoded
        for up in self.upsamples:
            decoded = up(decoded, scale, shift)
        return decoded

### Style Network

In [0]:
class GrabNamedModule(nn.Module):
    """Runs `mod` but returns the output of one of its direct children instead."""

    def __init__(self, mod: nn.Module, name: str):
        super().__init__()

        self.mod = mod
        self.stored = None

        def remember(_module, _inputs, output):
            self.stored = output

        # Raises KeyError when `name` is not a direct child of `mod`.
        target = dict(mod.named_children())[name]
        target.register_forward_hook(remember)

    def forward(self, *args, **kwargs):
        # The wrapped module's own return value is discarded on purpose.
        self.mod(*args, **kwargs)
        return self.stored


class StyleNet(nn.Module):
    """Predicts the style `scale` and `shift` vectors from a style image.

    Features are taken from the `Mixed_6e` child of a pretrained Inception v3,
    average-pooled to one value per channel and passed through a two-layer
    head whose output is split into the two vectors.
    """

    def __init__(self, dim_scale: int, dim_shift: int, bottleneck_size: int = 100):
        """
        Args:
            dim_scale: Length of the returned scale vector.
            dim_shift: Length of the returned shift vector.
            bottleneck_size: Width of the hidden layer in the head.
        """
        super().__init__()

        self.inception = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
        features = GrabNamedModule(self.inception, "Mixed_6e")

        self.dim_scale = dim_scale
        self.dim_shift = dim_shift

        head = [
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(768, bottleneck_size),
            nn.ReLU(),
            nn.Linear(bottleneck_size, dim_scale + dim_shift),
        ]
        self.model = nn.Sequential(features, *head)

    def forward(self, x):
        """Return `(scale, shift)` for the style batch `x`."""
        out = self.model(x)

        # The first `dim_scale` columns are the scale, the rest the shift.
        split_at = self.dim_scale
        return out[:, :split_at], out[:, split_at:]


class NeuralStyleTransfer(nn.Module):
    """Chains a style network with a style-conditioned transfer network."""

    def __init__(self, transfer: nn.Module, style_net: nn.Module):
        super().__init__()

        self.transfer, self.style_net = transfer, style_net

    def forward(self, content, style):
        """Derive (scale, shift) from `style` and apply them to `content`."""
        style_params = self.style_net(style)
        scale, shift = style_params
        return self.transfer(content, scale, shift)


class StyleSequential(nn.Module):
    """Sequential container that forwards style vectors to the modules that take them.

    Modules whose `forward` accepts one argument are called as `mod(out)`;
    modules whose `forward` accepts three are called as `mod(out, scale, shift)`.
    """
    def __init__(self, *modules):
        super().__init__()
        self.mods = nn.ModuleList(modules)
    
    def forward(self, x, scale, shift):
        """Run `x` through every module in order.

        Raises:
            ValueError: If a module's `forward` takes neither 1 nor 3 arguments.
        """
        out = x
        
        for mod in self.mods:
            # signature() of a bound method does not count `self`.
            n_params = len(signature(mod.forward).parameters)
            if n_params == 1:
                out = mod(out)
            elif n_params == 3:
                out = mod(out, scale, shift)
            else:
                # The old message said "2 or 4", which contradicted the check above.
                raise ValueError(
                    f"Expected forward() to take 1 or 3 params (excluding self): "
                    f"{type(mod).__name__}.forward takes {n_params}")
            
        return out

## Instantiation

In [0]:
# Lengths of the style scale and shift vectors: 2758 split evenly between them.
# NOTE(review): 2758 presumably comes from the referenced paper - confirm.
dim_scale = 2758 // 2
dim_shift = 2758 // 2

In [0]:
# Builds the style prediction network (loads Inception v3 via torch.hub).
style_net = StyleNet(dim_scale, dim_shift)

In [0]:
# Encoder: a 9x9 stride-1 style conv, two stride-2 3x3 style convs that take
# the channels from 32 to 128, then five 128-channel residual blocks.
encoder = StyleSequential(
    StyleConv(3, 32, 9, dim_scale, dim_shift, stride=1),
    StyleConv(32, 64, 3, dim_scale, dim_shift, stride=2),
    StyleConv(64, 128, 3, dim_scale, dim_shift, stride=2),
    StyleResidualBlock(128, dim_scale, dim_shift),
    StyleResidualBlock(128, dim_scale, dim_shift),
    StyleResidualBlock(128, dim_scale, dim_shift),
    StyleResidualBlock(128, dim_scale, dim_shift),
    StyleResidualBlock(128, dim_scale, dim_shift)).to(DEVICE)

In [0]:
# The encoder's modules, in execution order; Style_UNET probes their output shapes.
children = encoder.mods

In [0]:
# NOTE(review): this rebinds `unet`, shadowing the `unet` module imported from
# fastai.vision.models at the top of the notebook.
unet = Style_UNET(encoder, children, dim_scale, dim_shift).to(DEVICE)

In [0]:
# Sample passthrough.
img = torch.ones(2, 3, 50, 50).to(DEVICE)
sc = torch.ones(2, dim_scale).to(DEVICE)
sh = torch.ones(2, dim_shift).to(DEVICE)
# Show only the output shape; the full tensor would flood the cell output.
unet(img, sc, sh).shape

In [0]:
# Full model: style network feeding the U-Net transfer network.
model = NeuralStyleTransfer(unet, style_net).to(DEVICE)

# Loss

## VGG Code

In [0]:
# From https://github.com/pytorch/examples/blob/b9f3b2ebb9464959bdbf0c3ac77124a704954828/fast_neural_style/neural_style/vgg.py#L7
class Vgg16(torch.nn.Module):
    """Pretrained VGG-16 feature extractor exposing four intermediate activations.

    The first 23 layers of `vgg16().features` are cut into four consecutive
    slices; `forward` returns the output of each slice.
    """

    def __init__(self, requires_grad=False):
        super(Vgg16, self).__init__()
        features = models.vgg16(pretrained=True).features

        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()

        # (slice, first layer index, one-past-last layer index). Layers keep
        # their original index as their name inside the slice.
        layout = (
            (self.slice1, 0, 4),
            (self.slice2, 4, 9),
            (self.slice3, 9, 16),
            (self.slice4, 16, 23),
        )
        for stage, start, stop in layout:
            for idx in range(start, stop):
                stage.add_module(str(idx), features[idx])

        # Freeze the weights unless the caller explicitly asks for gradients.
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return a `VggOutputs` namedtuple with the four slice outputs."""
        relu1_2 = self.slice1(X)
        relu2_2 = self.slice2(relu1_2)
        relu3_3 = self.slice3(relu2_2)
        relu4_3 = self.slice4(relu3_3)

        VggOutputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
        return VggOutputs(relu1_2, relu2_2, relu3_3, relu4_3)

## Loss Function

## Code

In [0]:
class StoredHooks:
    """Collects the outputs of the given modules via forward hooks.

    Attributes are `hooks` (the hook handles, one per module) and `stored`
    (outputs in the order the modules were executed).
    """

    def __init__(self, modules):
        self.hooks = []
        self.stored = []

        for module in modules:
            handle = module.register_forward_hook(self._record)
            self.hooks.append(handle)

    def _record(self, _module, _inputs, output):
        # Looks up `self.stored` at call time so reset() takes effect;
        # returns None so the module's output is left untouched.
        self.stored.append(output)

    def reset(self):
        """Forget everything captured so far."""
        self.stored = []


class StyleLoss(nn.Module):
    """Mean squared error between the Gram matrices of two feature maps.

    Inputs are expected in NCHW layout. Each Gram matrix has shape
    (batch, channels, channels) and is normalised by the number of spatial
    positions, so the two inputs may differ in height and width.
    """
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
    
    def forward(self, x, y):
        """Return the MSE between `gram(x)` and `gram(y)`."""
        # Contract over the spatial axes (h, w) and keep the channel axes.
        # The previous "bijc,bijd->bcd" pattern assumed NHWC and therefore
        # correlated image columns instead of channels.
        x_gram = torch.einsum("bchw,bdhw->bcd", x, x) / (x.shape[2] * x.shape[3])
        y_gram = torch.einsum("bchw,bdhw->bcd", y, y) / (y.shape[2] * y.shape[3])

        return self.mse(x_gram, y_gram)
        

class NSTLoss(nn.Module):
    """Content loss plus weighted style loss, both measured on VGG-16 features.

    The style loss is the `StyleLoss` averaged over all four VGG outputs; the
    content loss is the MSE between the `relu2_2` activations.
    """
    def __init__(self, style_mult: float, verbose: bool = False):
        """
        Args:
            style_mult: Weight of the style term in the combined loss.
            verbose: When True, print the individual and combined losses on
                every call. Off by default: printing every batch floods the
                notebook output.
        """
        super().__init__()
        
        self.style_mult = style_mult
        self.verbose = verbose
        self.vgg = Vgg16(requires_grad=False).to(DEVICE)
        self.style_loss = StyleLoss()
        self.content_loss = nn.MSELoss()
    
    def forward(self, transferred, content, style):
        """Return `content_loss + style_mult * style_loss` for one batch."""
        content_out = self.vgg(content)
        style_out = self.vgg(style)
        transferred_out = self.vgg(transferred)

        style_loss = 0
        for og, tr in zip(style_out, transferred_out):
            style_loss = style_loss + self.style_loss(og, tr)

        style_loss = style_loss / len(style_out)
        
        content_loss = self.content_loss(content_out.relu2_2, transferred_out.relu2_2)

        comb = content_loss + self.style_mult * style_loss

        if self.verbose:
            print(f"Style loss: {style_loss}, Content loss: {content_loss}")
            print("combined:", comb)

        return comb

## Instantiation

In [0]:
# Weight of the style term relative to the content term in NSTLoss.
style_mult = 1/2

In [0]:
# Combined content + style loss used as the training objective.
loss = NSTLoss(style_mult)

# Training

In [0]:
# fastai Learner tying together the data, the model and the loss.
learner = Learner(data, model, loss_func=loss)

In [0]:
# Show GPU status before training.
!nvidia-smi

In [0]:
# Switch the learner to half-precision (fp16) training.
learner.to_fp16()

In [0]:
# Run fastai's learning-rate finder; results are plotted in the next cell.
learner.lr_find()

In [0]:
# Plot the learning-rate finder results, skipping the last 15 points.
learner.recorder.plot(skip_end=15)

In [0]:
# The paper specified 4M parameter updates - this is a small fraction of that.
# Train for 10 epochs at a learning rate of 1e-3.
learner.fit(10, 1e-3)

# Checking Out Results

In [0]:
# Take the first content and style image and add a batch dimension.
cont_i = content_train_ds[0].unsqueeze(0).to(DEVICE)
style_i = style_train_ds[0].unsqueeze(0).to(DEVICE)

model.eval()
# Inference only: no_grad avoids building an autograd graph for this pass.
with torch.no_grad():
    out = model(cont_i, style_i)

In [0]:
def make_output_image(out):
    """Convert a normalised model output into a PIL image.

    Args:
        out: Tensor holding one 3-channel image, with or without a leading
            batch dimension of size 1, normalised with the module-level
            `mean` / `std`.

    Returns:
        A `PIL.Image` built from the de-normalised image as uint8 data.
    """
    # Build the statistics on the tensor's own device instead of the global DEVICE.
    std_t = torch.tensor(std, device=out.device).view(3, 1, 1)
    mean_t = torch.tensor(mean, device=out.device).view(3, 1, 1)

    # float() keeps the arithmetic in fp32 even if the model produced fp16.
    out_scaled = out.detach().squeeze(0).float() * std_t + mean_t

    # The network output is unbounded; without clamping, values outside
    # [0, 1] wrap around when cast to uint8 and show up as speckle artefacts.
    out_scaled = out_scaled.clamp(0.0, 1.0)

    out_np = out_scaled.permute([1,2,0]).cpu().numpy()
    img_in = (out_np * 255).astype(np.uint8)
    return Image.fromarray(img_in)

In [0]:
# Display the stylised result for the sample above.
make_output_image(out)