# 1. Simple style transfer

In this section we apply style transfer directly to an image by treating its pixels as trainable weights and optimizing them. First, we load the content image and the style image:

In [72]:
# INFERENCE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor,ToPILImage,Normalize
from utils.model import construct_style_loss_model,construct_decoder_from_encoder
from utils.losses import content_gatyes,style_gatyes,style_mmd_polynomial,adaIN,style_mmd_gaussian
from utils.utility import normalize,normalize_cw
import cv2
from copy import deepcopy
from torchvision.models import vgg19,VGG19_Weights
from PIL import Image,ImageOps
import torch.optim as optim
from tqdm import tqdm
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# INFERENCE

# prefer the GPU whenever CUDA is available; otherwise run on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("On device: ", device)

On device:  cuda


In [123]:
# INFERENCE

# file the content image is read from
CONTENT_IMAGE_PATH = "./dragon.jpg"
# file the style image is read from
STYLE_IMAGE_PATH = "./wave.jpg"
# size both images are resized to after the center crop
IMAGE_SIZE = (256, 256)

In [124]:
# INFERENCE

def _load_square_image(path, size):
    """Load an image, center-crop it to a square and return it as a [C x W x H] tensor.

    Args:
        path: location of the image file.
        size: size tuple passed to PIL's ``resize`` after the crop.

    Returns:
        The tensor produced by ``ToTensor()`` with its last two axes swapped.
    """
    # the context manager releases the underlying file handle once the pixels are converted
    with Image.open(path) as raw_img:
        rgb_img = raw_img.convert('RGB')

    # center crop to the largest square that fits into the image, then scale it
    side = min(rgb_img.size)
    square_img = ImageOps.fit(rgb_img, (side, side)).resize(size)

    # ToTensor() returns [C x H x W]; swapping the last two axes yields [C x W x H].
    # The result is permuted back the same way before it is converted to a PIL image again.
    return ToTensor()(square_img).permute(0, 2, 1)


content_img = _load_square_image(CONTENT_IMAGE_PATH, IMAGE_SIZE)
style_img = _load_square_image(STYLE_IMAGE_PATH, IMAGE_SIZE)

In [112]:
# INFERENCE

# Next we load the model. We will use the standard vgg19 model by pytorch.
# We will use the model without the classification head and add a normalization layer to match the distribution of the model's training data:
# load model (the weights are passed by keyword; passing them positionally is deprecated since torchvision 0.13)
vgg = vgg19(weights=VGG19_Weights.DEFAULT)

# remove classification head
vgg = vgg.features

# prepend a normalization layer; this shifts every index of vgg.features by one
vgg = nn.Sequential(Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), *vgg)

# let's print the model
vgg



Sequential(
  (0): Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
  (1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (2): ReLU(inplace=True)
  (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): ReLU(inplace=True)
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU(inplace=True)
  (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU(inplace=True)
  (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (11): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (12): ReLU(inplace=True)
  (13): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (14): ReLU(inplace=True)
  (15): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (16): ReLU(inplace=True)
  (17): Conv2d(256, 256, kernel_size=(3, 3),

In [114]:
# INFERENCE

# Next we choose which layers are used as content and style layers. The indices match the indices of the printed vgg model,
# so index (6) means using the layer "Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))".
# In theory not only conv layers can be used; some papers also use the ReLU layers between the convs.
CONTENT_LAYERS = [3, 6]
STYLE_LAYERS = [6, 8, 11]

# Every layer gets a weight (default 1.0). Each weight list must be as long as the matching layer list.
CONTENT_LAYERS_WEIGHTS = [1.0, 1.0]
STYLE_LAYERS_WEIGHTS = [1.0, 1.0, 1.0]

# fail early when a layer list and its weight list got out of sync
n_content, n_content_weights = len(CONTENT_LAYERS), len(CONTENT_LAYERS_WEIGHTS)
if n_content != n_content_weights:
    raise AssertionError("CONTENT_LAYERS and CONTENT_LAYERS_WEIGHTS have to have the same length but were {0} and {1} respectively".format(n_content, n_content_weights))

n_style, n_style_weights = len(STYLE_LAYERS), len(STYLE_LAYERS_WEIGHTS)
if n_style != n_style_weights:
    raise AssertionError("STYLE_LAYERS and STYLE_LAYERS_WEIGHTS have to have the same length but were {0} and {1} respectively".format(n_style, n_style_weights))

In [115]:
# INFERENCE

# Based on this information we construct our style loss model. As input it takes a tuple containing an image and two empty lists.
# It returns a tuple containing the output and two lists holding the features from the chosen content and style layers respectively.
# NOTE(review): construct_style_loss_model is imported from utils.model and not visible here; the contract above
# is taken from how the model is called later in this notebook — confirm against the helper.
style_loss_model = construct_style_loss_model(vgg,CONTENT_LAYERS,STYLE_LAYERS)
# last expression of the cell, so the wrapped model is displayed
style_loss_model

range(0, 12)


Sequential(
  (Model layer: 0 | Content layer: False | Style layer: False): Parallel(
    (layer): Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
  )
  (Model layer: 1 | Content layer: False | Style layer: False): Parallel(
    (layer): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (Model layer: 2 | Content layer: False | Style layer: False): Parallel(
    (layer): ReLU(inplace=True)
  )
  (Model layer: 3 | Content layer: True | Style layer: False): Parallel(
    (layer): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (Model layer: 4 | Content layer: False | Style layer: False): Parallel(
    (layer): ReLU(inplace=True)
  )
  (Model layer: 5 | Content layer: False | Style layer: False): Parallel(
    (layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (Model layer: 6 | Content layer: True | Style layer: True): Parallel(
    (layer): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1)

In [125]:
# INFERENCE

# switch to eval mode (in case the model contains e.g. Dropout layers) and freeze all of its parameters
style_loss_model = style_loss_model.eval()
style_loss_model.requires_grad_(False)

# move the model and both images onto the selected device
style_loss_model = style_loss_model.to(device)
content_img, style_img = content_img.to(device), style_img.to(device)

In [126]:
# INFERENCE

# Starting from the content image gives better results than starting from noise.
# Random noise would also work: torch.rand_like(content_img)
initial_img = content_img.clone().contiguous().to(device)

# the pixels themselves are the parameters that LBFGS optimizes
img = nn.Parameter(initial_img)
optimizer = optim.LBFGS([img], lr=1.0)

In [127]:
# INFERENCE

# The targets never change during optimization, so they are computed once up front:
# content features come from the content image, style features from the style image.
with torch.no_grad():
    content_batch = content_img.unsqueeze(0)
    style_batch = style_img.unsqueeze(0)
    _, content_features_target, _ = style_loss_model((content_batch, [], []))
    _, _, style_features_target = style_loss_model((style_batch, [], []))

In [128]:
# INFERENCE

# loss applied to the content features
LOSS_CONTENT = content_gatyes

# Loss applied to the style features.
# Possible values are style_gatyes, style_mmd_polynomial, style_mmd_gaussian, adaIN
# (style_mmd_gaussian does not work well).
LOSS_STYLE = adaIN

# Factor the summed style loss is multiplied with.
# You might have to lower STYLE_WEIGHT when choosing adaIN.
STYLE_WEIGHT = 100000.0

In [129]:
# INFERENCE

# LBFGS works a bit differently than other pytorch optimizers: optimizer.step() takes a closure
# that re-evaluates the loss and its gradients.
def compute_losses():
    """Closure passed to ``optimizer.step``.

    Clamps ``img`` to [0, 1], resets the gradients, runs ``style_loss_model`` on the
    image and backpropagates the combined loss.

    Returns:
        float: content loss plus ``STYLE_WEIGHT`` times the style loss.
    """

    # Clip all values of the image to the range [0,1]
    with torch.no_grad():
        img.clamp_(0, 1)

    # reset the gradients accumulated by the previous evaluation
    optimizer.zero_grad()

    # get the features from the chosen content and style layers for our image
    _, content_features, style_features = style_loss_model((img.unsqueeze(0), [], []))

    # weighted sum of the content loss over all content layers
    content_loss = 0.0
    for f, f_target, weight in zip(content_features, content_features_target, CONTENT_LAYERS_WEIGHTS):
        content_loss += weight * LOSS_CONTENT(*normalize_cw(f, f_target)).mean()

    # weighted sum of the style loss over all style layers
    style_loss = 0.0
    for f, f_target, weight in zip(style_features, style_features_target, STYLE_LAYERS_WEIGHTS):
        style_loss += weight * LOSS_STYLE(*normalize(f, f_target)).mean()

    style_loss *= STYLE_WEIGHT

    loss = content_loss + style_loss
    loss.backward()

    # reuse the tensor that was just backpropagated instead of summing both terms a second time
    return loss.item()

In [131]:
# INFERENCE

# run 10 LBFGS steps; each step calls compute_losses to evaluate the loss and its gradients
for step_idx in tqdm(range(10)):

    optimizer.step(compute_losses)

    # the step may have moved pixels outside of [0,1], so clip them again
    with torch.no_grad():
        img.clamp_(0, 1)
    

100%|██████████| 10/10 [00:04<00:00,  2.22it/s]


In [None]:
# INFERENCE

# swap the last two axes back (the images were permuted the same way after ToTensor()), then convert and display.
# Note: this rebinds img from the optimized parameter to a PIL image.
restored = img.squeeze(0).permute(0, 2, 1)
img = ToPILImage()(restored)
img.show()