In [None]:
# Install (or upgrade, -U) the dessiccate package quietly (-qq); its plotting helpers are imported in the next cell.
!pip install -Uqq dessiccate

In [None]:
from copy import deepcopy
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn.functional as F
from IPython.display import display
from PIL import Image
from dessiccate import plotting as p
from fastai.vision import imagenet_stats
from torch import nn, optim
from torchvision import models, transforms

p.set_plt_defaults()

In [None]:
# Prefer the first CUDA GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
device

In [None]:
def load_img_from_url(url, timeout=30):
    """Download an image over HTTP(S) and open it with PIL.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up, so that a
            stalled connection cannot hang the notebook indefinitely.

    Returns:
        The PIL image decoded from the response body.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or a timeout.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on an error page instead of handing non-image bytes to PIL.
    response.raise_for_status()
    # `response.content` is the fully downloaded, content-decoded body. The raw
    # socket stream is neither: it may still be gzip/deflate encoded, and it
    # would have to stay open until PIL lazily reads the pixel data.
    return Image.open(BytesIO(response.content))

In [None]:
# Content image: it supplies the ContentLoss targets and, by default, the starting point of the optimisation.
# The commented-out line is an alternative picture kept for experimentation.
# content_img = load_img_from_url('https://www.thesprucepets.com/thmb/wpN_ZunUaRQAc_WRdAQRxeTbyoc=/4231x2820/filters:fill(auto,1)/adorable-white-pomeranian-puppy-spitz-921029690-5c8be25d46e0fb000172effe.jpg')
content_img = load_img_from_url('https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iiCWw9wz_VbI/v0/1000x-1.jpg')

In [None]:
# Fix the working width and derive the height that keeps the content image's aspect ratio.
width = 500
src_w, src_h = content_img.size  # PIL reports size as (width, height)
height = int(width * src_h / src_w)
size = (width, height)

In [None]:
# Resize the content image to the working resolution stored in `size`.
content_img = content_img.resize(size)

In [None]:
content_img

In [None]:
# Style image: its activations are turned into the Gram-matrix targets of the StyleLoss modules.
# The commented-out line is an alternative painting kept for experimentation.
# style_img = load_img_from_url('https://images.fineartamerica.com/images/artworkimages/mediumlarge/2/agapantus-by-monet-claude-monet.jpg')
style_img = load_img_from_url('https://media.sanctuarymentalhealth.org/wp-content/uploads/2021/03/04151535/The-Starry-Night.jpg')

In [None]:
# Resize the style image to the same (width, height) as the content image.
# `size` was derived from the content image, so the style image's own aspect ratio is not preserved.
style_img = style_img.resize(size)

In [None]:
style_img

In [None]:
# Preprocessing pipeline: a single step that converts a PIL image into a tensor.
# ImageNet normalisation is not done here; a dedicated layer inside the model handles it.
tfms = transforms.Compose([transforms.ToTensor()])

In [None]:
# Convert both PIL images to tensors and place them on the selected device.
content_tensor, style_tensor = (
    tfms(img).to(device) for img in (content_img, style_img)
)

In [None]:
# Pretrained VGG-19 (batch-norm variant). Only the `features` sub-network is
# kept, and it is switched to eval mode before being moved to the device.
# The weights are frozen because only the input image is optimised: without
# this, every backward pass would also compute (and keep accumulating)
# gradients for all network parameters.
vgg = models.vgg19_bn(pretrained=True).eval().features.to(device).requires_grad_(False)

In [None]:
# Replace every max-pooling layer with average pooling of the same geometry.
# Matching on the layer type (instead of looking for "Pool" in the layer's
# repr) guarantees that only MaxPool2d layers are swapped, and copying the
# layer's own settings avoids hard-coding the 2x2/stride-2 window.
for idx, layer in enumerate(vgg):
    if isinstance(layer, nn.MaxPool2d):
        vgg[idx] = nn.AvgPool2d(
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            ceil_mode=layer.ceil_mode,
        )

In [None]:
class ContentLoss(nn.Module):
    """Pass-through layer that records how far activations are from a fixed target.

    The module returns its input untouched, so it can be inserted anywhere in a
    sequential network. After each forward pass the measured error is available
    as the `loss` attribute.
    """

    def __init__(self, target):
        super().__init__()
        # `target` holds the activations of the original content image at this
        # depth; detaching makes autograd treat them as constants.
        reference = target.detach()
        self.target = reference

    def forward(self, input):
        """Store the MSE between `input` and the target in `self.loss`, then return `input` as is."""
        mse = F.mse_loss(input, self.target)
        self.loss = mse
        return input

In [None]:
# Random tensor with the content image's shape, used as stand-in activations to smoke-test the loss modules.
fake_acts = torch.randn_like(content_tensor)

In [None]:
# Smoke test: ContentLoss with the content image as target and random noise as input.
# The forward call populates `l.loss`, which the last line displays.
l = ContentLoss(content_tensor)
l(fake_acts)
l.loss

In [None]:
def gram_matrix(input):
    """Compute the normalised Gram matrix of a batch of feature maps.

    Args:
        input: 4-D activation tensor, unpacked as (batch, feature maps, and
            two spatial dimensions).

    Returns:
        A (batch * maps, batch * maps) tensor holding the inner products
        between the flattened feature maps, divided by the total number of
        elements in `input`.
    """
    bs, nf, nx, ny = input.shape
    # One row per feature map with the spatial positions flattened. Unlike
    # `view`, `reshape` also accepts non-contiguous activations (for example
    # transposed or channels-last tensors) instead of raising.
    f = input.reshape(bs * nf, nx * ny)
    # Inner product between every pair of flattened feature maps.
    g = f @ f.T
    # Normalise by the element count here (as the PyTorch tutorial does)
    # instead of building the factor into the MSE loss.
    return g.div(bs * nf * nx * ny)

In [None]:
# Style Loss
class StyleLoss(nn.Module):
    """Pass-through layer that compares the Gram matrix of its input with a fixed target.

    The target Gram matrix is computed once from `target_feature` at
    construction time. Each forward pass stores the MSE between the input's
    Gram matrix and that target in `self.loss` and returns the input unchanged.
    """

    def __init__(self, target_feature):
        super().__init__()
        # Detach so the target statistics are constants for autograd.
        target_gram = gram_matrix(target_feature)
        self.target = target_gram.detach()

    def forward(self, input):
        """Refresh `self.loss` from `input` and hand `input` back untouched."""
        input_gram = gram_matrix(input)
        self.loss = F.mse_loss(input_gram, self.target)
        return input

In [None]:
# Smoke test: StyleLoss built from the raw style image and fed random activations.
# A leading batch dimension is added to both because gram_matrix unpacks four dimensions.
l = StyleLoss(style_tensor.unsqueeze(0))
l(fake_acts.unsqueeze(0))
l.loss

In [None]:
# Step 1: choose the conv layers that get a loss module attached.
# The names follow the `conv_{i}` numbering assigned when the model is assembled.
content_layers = ['conv_1', 'conv_4', 'conv_8']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
# One weight per style layer.
# NOTE(review): the training closure has its `* w` factor commented out, so these weights currently have no effect.
style_weights = [1.5, 1., 1., 1., 1.]

In [None]:
# Independent copy of the feature extractor; its layers are the ones added to `model` during assembly.
cnn = deepcopy(vgg)

In [None]:
# Handles to the loss modules inserted into the model, so the training loop can read their `.loss` values.
content_losses = []
style_losses = []

In [None]:
class NormalizationLayer(nn.Module):
    """First layer of the model: standardises an image with the ImageNet channel statistics."""

    def __init__(self):
        super().__init__()
        channel_mean, channel_std = imagenet_stats
        # Reshape to (channels, 1, 1) so the statistics broadcast over the
        # spatial dimensions (and over a leading batch dimension, if present).
        self.mean = torch.tensor(channel_mean).view(-1, 1, 1).to(device)
        self.std = torch.tensor(channel_std).view(-1, 1, 1).to(device)

    def forward(self, inputs):
        """Return `inputs` with the per-channel mean subtracted and divided by the per-channel std."""
        centered = inputs - self.mean
        return centered / self.std

In [None]:
# Step 2: start the model with the normalisation layer only.
# The VGG layers and the loss modules are appended to it in the next step.
model = nn.Sequential(NormalizationLayer()).to(device)

In [None]:
# Step 3: append the VGG layers to `model` one by one. Each layer is named
# after its type and the index of the conv layer it follows, and a loss module
# is inserted directly after every conv layer listed in `content_layers` /
# `style_layers`.
i = 0  # number of conv layers seen so far
for layer in cnn.children():
    if isinstance(layer, nn.Conv2d):
        i += 1
        name = f'conv_{i}'
    elif isinstance(layer, nn.BatchNorm2d):
        name = f'bn_{i}'
    elif isinstance(layer, nn.ReLU):
        name = f'relu_{i}'
    elif isinstance(layer, nn.AvgPool2d):
        name = f'pool_{i}'
    else:
        raise RuntimeError(f"Layer {layer} not recognized")

    # Add the layer to our model.
    model.add_module(name, layer)

    # Content target: the activations of the content image at this depth.
    if name in content_layers:
        # Targets are constants, so there is no need to record an autograd
        # graph while computing them.
        with torch.no_grad():
            target = model(content_tensor.unsqueeze(0))
        content_loss = ContentLoss(target)
        model.add_module(f'content_loss_{i}', content_loss)
        content_losses.append(content_loss)

    # Style target: the activations of the style image at this depth
    # (StyleLoss converts them to a Gram matrix). If a content loss was just
    # added for the same layer, the style image simply passes through it.
    if name in style_layers:
        with torch.no_grad():
            target = model(style_tensor.unsqueeze(0))
        style_loss = StyleLoss(target)
        model.add_module(f'style_loss_{i}', style_loss)
        style_losses.append(style_loss)

In [None]:
# Trim unused layers: nothing after the last loss module contributes to any loss.
loss_positions = [
    pos
    for pos, module in enumerate(model.children())
    if isinstance(module, (StyleLoss, ContentLoss))
]
# Slice end is one past the last loss module (0 when there is none).
max_loss_layer = loss_positions[-1] + 1 if loss_positions else 0

print(max_loss_layer)

model = model[:max_loss_layer].to(device)

In [None]:
# Starting point of the optimisation: either random noise (disabled) or a copy of the content image.
# input_img = torch.randn_like(content_tensor) # from paper
input_img = content_tensor.clone() # from pytorch tutorial

In [None]:
# The image itself is the only thing being optimised: enable gradients on it
# in place, then hand it to L-BFGS as the sole parameter.
input_img.requires_grad_()
opt = optim.LBFGS([input_img], lr=0.1)

In [None]:
def show_tensor(tensor):
    """Return `tensor` as a PIL image (which a notebook cell renders when it is the last expression).

    The tensor is detached from autograd and copied to the CPU before conversion.
    """
    to_pil = transforms.ToPILImage()
    cpu_copy = tensor.detach().cpu()
    return to_pil(cpu_copy)

In [None]:
# Define some hyperparameters
# Multipliers applied to the summed content / style losses inside the training closure.
content_loss_weight = 1
style_loss_weight = content_loss_weight * 1e5
# Upper bound for the closure-call counter that drives the training loop.
N_STEPS = 400

In [None]:
# Run the training loop
run = [0]  # closure-call counter; kept in a list so `closure` can update it in place


def closure():
    """Evaluate the total loss for the current image and backpropagate it.

    The optimiser may call this more than once per `opt.step`, so `run[0]`
    counts loss evaluations rather than optimiser steps.

    Returns:
        The weighted sum of content and style losses.
    """
    # Keep pixel values in the [0, 1] range. Doing it under `no_grad` (rather
    # than through `.data`) keeps the in-place update out of the autograd graph
    # while still letting PyTorch track the modification.
    with torch.no_grad():
        input_img.clamp_(0, 1)
    opt.zero_grad()
    # The forward pass is run only for its side effect: every loss module
    # refreshes its `.loss` attribute.
    model(input_img.unsqueeze(0))

    cl = sum(m.loss for m in content_losses) * content_loss_weight
    # NOTE(review): the per-layer `style_weights` are intentionally not applied
    # here (the original `* w` factor was commented out), so every style layer
    # contributes equally.
    sl = sum(m.loss for m in style_losses) * style_loss_weight

    loss = cl + sl
    loss.backward()

    if run[0] % 50 == 0:
        print(f"""
            Step {run[0]}: Style Loss: {sl.item():.04f}, Content Loss: {cl.item():.04f}
            """)

    # Uncomment below to show output during training
    # if run[0]%200 == 0:
    #     display(show_tensor(input_img))

    run[0] += 1

    return loss


# Keep stepping until the closure has been evaluated more than N_STEPS times.
while run[0] <= N_STEPS:
    opt.step(closure)

In [None]:
# Final stylised image, clamped to the [0, 1] range before conversion for display.
show_tensor(input_img.detach().clamp(0, 1))

In [None]:
content_img