In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor,ToPILImage,Normalize
from utils.dataset import StyleTransferDataset
from utils.model import construct_style_loss_model,construct_decoder_from_encoder
from utils.losses import content_gatyes,style_gatyes,style_mmd_polynomial,adaIN
import cv2
from copy import deepcopy
from torchvision.models import vgg19,VGG19_Weights
from PIL import Image
import torch.optim as optim
from tqdm import tqdm
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("On device: ", device)

On device:  cuda


# 1. Simple style transfer

In this section we apply style transfer directly to the image: the image pixels are treated as trainable parameters and optimized. First, we define all the hyperparameters for this section:

In [5]:
# Hyperparameters for the direct (pixel-optimisation) style transfer.
# Layer indices are positions in the `vgg` Sequential that is passed to
# construct_style_loss_model, where index 0 is the prepended Normalize layer.
hp = {
    # Input images.
    "CONTENT_IMAGE_PATH": "./frame0.jpg",
    "STYLE_IMAGE_PATH": "./anime.jpg",
    # Multiplier intended for the style term of the total loss.
    "STYLE_WEIGHT": 100000.0,
    # Loss callables for the content and style terms.
    "LOSS_CONTENT": content_gatyes,
    "LOSS_STYLE": style_gatyes,
    # Layers whose features are collected as content features.
    "CONTENT_LAYERS": [6],
    "CONTENT_LAYERS_WEIGHTS": [1.0],
    # Layers whose features are collected as style features.
    "STYLE_LAYERS": [8, 11],
    # One weight per entry in STYLE_LAYERS. Four weights used to be listed
    # here for only two style layers; the surplus entries were dropped so the
    # two sequences line up.
    "STYLE_LAYERS_WEIGHTS": (1.0, 1.0),
}

After this, we load our content and style images:

In [6]:
def _load_rgb_tensor(path, size=(512, 512)):
    """Open the image at `path`, force RGB, resize to `size`, return a tensor."""
    picture = Image.open(path).convert('RGB').resize(size)
    return ToTensor()(picture)


# Both images go through the exact same loading pipeline.
content_image = _load_rgb_tensor(hp["CONTENT_IMAGE_PATH"])
style_image = _load_rgb_tensor(hp["STYLE_IMAGE_PATH"])

Next, we load the model. We use the standard VGG19 model provided by torchvision, without the classification head, and prepend a normalization layer so that inputs match the distribution of the model's training data:

In [7]:
# Load VGG19 with the default pretrained weights. `weights` is passed by
# keyword: torchvision has deprecated the positional form since 0.13 and
# emits a warning for it.
vgg = vgg19(weights=VGG19_Weights.DEFAULT)

# Keep only the feature extractor, i.e. drop the classification head.
vgg = vgg.features

# Prepend a normalization layer so that inputs in [0, 1] are rescaled with
# the per-channel mean/std used here before reaching the first convolution.
# Note that this shifts every VGG layer index up by one.
vgg = nn.Sequential(
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    *vgg,
)

# Display the resulting layer stack.
vgg



Sequential(
  (0): Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
  (1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (2): ReLU(inplace=True)
  (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): ReLU(inplace=True)
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU(inplace=True)
  (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU(inplace=True)
  (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (11): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (12): ReLU(inplace=True)
  (13): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (14): ReLU(inplace=True)
  (15): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (16): ReLU(inplace=True)
  (17): Conv2d(256, 256, kernel_size=(3, 3),

In [8]:
# Wrap the VGG stack so that the layers listed in the hyperparameters are
# tagged as content / style layers (see the printed module names).
# NOTE(review): the layers right after the selected convs are
# ReLU(inplace=True); if the wrapper stores the conv output tensor itself, the
# in-place ReLU would overwrite it — confirm in utils.model.
content_layer_ids = hp["CONTENT_LAYERS"]
style_layer_ids = hp["STYLE_LAYERS"]
style_loss_model = construct_style_loss_model(vgg, content_layer_ids, style_layer_ids)
style_loss_model

range(0, 12)


Sequential(
  (Model layer: 0 | Content layer: False | Style layer: False): Parallel(
    (layer): Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
  )
  (Model layer: 1 | Content layer: False | Style layer: False): Parallel(
    (layer): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (Model layer: 2 | Content layer: False | Style layer: False): Parallel(
    (layer): ReLU(inplace=True)
  )
  (Model layer: 3 | Content layer: False | Style layer: False): Parallel(
    (layer): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (Model layer: 4 | Content layer: False | Style layer: False): Parallel(
    (layer): ReLU(inplace=True)
  )
  (Model layer: 5 | Content layer: False | Style layer: False): Parallel(
    (layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (Model layer: 6 | Content layer: True | Style layer: False): Parallel(
    (layer): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 

In [9]:
# Freeze the feature extractor and move it to the target device:
#   - eval(): inference mode, in case the model contains e.g. Dropout layers
#   - requires_grad_(False): only the image is optimized, never the network
# Each of these calls returns the module itself, so they can be chained.
style_loss_model = style_loss_model.eval().requires_grad_(False).to(device)

# The images have to live on the same device as the model.
content_image, style_image = content_image.to(device), style_image.to(device)

We now set our initial image to the content image. As an optimizer we will use LBFGS:

In [10]:
# Start the optimization from a copy of the content image; wrapping it in a
# Parameter makes the pixels themselves the quantity being optimized.
initial_pixels = content_image.clone().to(device)
img = nn.Parameter(initial_pixels)

# L-BFGS over the single image parameter.
optimizer = optim.LBFGS(params=[img], lr=1)

In [11]:
# The targets are fixed for the whole optimization, so compute them once and
# without tracking gradients.
with torch.no_grad():
    # The model expects a batch dimension and a (tensor, [], []) input tuple;
    # it returns a 3-tuple whose 2nd / 3rd items are the content / style features.
    content_batch = content_image.unsqueeze(0)
    style_batch = style_image.unsqueeze(0)

    _, content_features_target, _ = style_loss_model((content_batch, [], []))
    _, _, style_features_target = style_loss_model((style_batch, [], []))

In [12]:
def compute_losses():
    """Closure for ``optimizer.step``: evaluate the total loss and its gradient.

    Clamps ``img`` to [0, 1] in place, clears the previous gradients, runs
    ``img`` through ``style_loss_model`` and compares the collected features
    with the precomputed ``content_features_target`` and
    ``style_features_target``.

    Returns:
        float: ``content_loss + hp["STYLE_WEIGHT"] * style_loss``.

    Side effects:
        Modifies ``img`` in place (clamp) and populates its gradient via
        ``loss.backward()``.
    """
    # Keep the pixels inside the valid image range before evaluating them.
    with torch.no_grad():
        img.clamp_(0, 1)

    optimizer.zero_grad()
    _, content_features, style_features = style_loss_model((img.unsqueeze(0), [], []))

    # Sum the per-layer losses; the 0.0 start value mirrors a plain accumulator.
    content_loss = sum(
        (
            content_gatyes(f, f_target).mean()
            for f, f_target in zip(content_features, content_features_target)
        ),
        0.0,
    )

    # NOTE(review): hp["LOSS_STYLE"] names style_gatyes, but this closure uses
    # style_mmd_polynomial. Kept as is to preserve results — confirm which one
    # is intended.
    style_loss = sum(
        (
            style_mmd_polynomial(f, f_target).mean()
            for f, f_target in zip(style_features, style_features_target)
        ),
        0.0,
    )

    # The style weight comes from the hyperparameter dict rather than a
    # duplicated literal, so changing hp["STYLE_WEIGHT"] actually takes effect.
    loss = content_loss + hp["STYLE_WEIGHT"] * style_loss
    loss.backward()

    # Reuse the total that was just backpropagated instead of rebuilding it.
    return loss.item()

In [13]:
# Build the tensor -> PIL converter once instead of on every iteration.
to_pil = ToPILImage()

for i in tqdm(range(500)):
    # One L-BFGS step; the optimizer calls the closure to obtain loss and gradient.
    optimizer.step(compute_losses)

    # The step may push pixels outside [0, 1]; bring them back into range.
    with torch.no_grad():
        img.clamp_(0, 1)

    # Overwrite the preview on disk after every step so progress can be
    # inspected while the loop is still running.
    pil = to_pil(img.squeeze(0))
    pil.save("./result.jpg")

 62%|██████▏   | 311/500 [09:59<06:04,  1.93s/it]


KeyboardInterrupt: 

# 2. Feedforward Model

In [33]:
# Encoder: the first 14 modules of the normalized VGG stack (indices 0..13).
vgg_encoder = vgg[:14]

# NOTE(review): Module.cpu() moves parameters in place and slicing a
# Sequential reuses the same layer objects, so this presumably also moves the
# layers shared with `vgg` off the GPU — confirm this is intended.
# The trailing arguments are presumably (channels, height, width) of the
# input — confirm in utils.model.
decoder = construct_decoder_from_encoder(
    vgg_encoder.cpu(),
    3,
    512,
    512,
)

# 3. AdaIN