# Art Generation with Neural Style Transfer

In [None]:
# Only for Google Colab
# !pip3 install -U torch==1.12+cu113 torchvision==0.13.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
import torch
import os
from torch import nn
from torch.utils import data
from torch.utils.data import DataLoader, Dataset
from torchvision import models, io
from torchvision.transforms import functional as F
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using {device} device")


#### We are going to use a VGG-11 ConvNet with weights pre-trained on ImageNet (`IMAGENET1K_V1`)

In [None]:
# Pretrained VGG-11 weights (ImageNet, V1) and their matching preprocessing preset.
# Note: setting vgg_weights = None would make `.transforms()` below fail.
vgg_weights = models.VGG11_Weights.IMAGENET1K_V1
vgg = models.vgg11(weights=vgg_weights)
vgg_preprocess = vgg_weights.transforms()

# The network is only a fixed feature extractor: switch to eval mode and
# freeze every parameter so that only the generated image gets gradients.
vgg = vgg.eval()
vgg.requires_grad_(False)

# Keep just the convolutional `features` Sequential (its sub-layers are
# addressed by integer index later) and move it to the chosen device.
vgg = vgg.features.to(device)

# A torch.jit.script(vgg) speed-up was tried; it is left disabled.

print(vgg)


In [None]:
def content_cost(C, G):
    """Return the normalized squared error between two activation maps.

    G is unpacked as (channels, height, width). The element-wise squared
    difference C - G is summed and scaled by 1 / (channels * height * width).
    """
    n_channels, height, width = G.shape
    scale = 1 / (n_channels * height * width)
    squared_error = torch.sum(torch.square(C - G))
    return scale * squared_error


def gram_matrix(X):
    """Return the Gram matrix ``X @ X.T``.

    For a 2-D (rows, features) tensor, entry (i, j) is the dot product of
    row i with row j.
    """
    return X @ X.T


def layer_style_cost(S, G):
    """Style cost of a single layer: scaled distance between Gram matrices.

    G is unpacked as (channels, height, width). Both S and G are flattened
    to (channels, -1), so that each row holds one channel's activations.
    The squared difference of their Gram matrices is then summed and
    multiplied by 1 / (4 * (channels * height * width) ** 2), using G's
    dimensions.
    """
    n_c, n_h, n_w = G.shape

    style_features = S.reshape(n_c, -1)
    generated_features = G.reshape(n_c, -1)

    gram_diff = gram_matrix(style_features) - gram_matrix(generated_features)
    scale = 1 / (4 * (n_c * n_h * n_w) ** 2)
    return scale * torch.sum(torch.square(gram_diff))


def style_cost(S, G, weights):
    """Return the weighted sum of per-layer style costs.

    Args:
        S: Sequence of style-image activations, one per style layer.
        G: Sequence of generated-image activations, aligned with ``S``.
        weights: Per-layer weights; ``weights[i]`` scales
            ``layer_style_cost(S[i], G[i])``.

    Returns:
        The weighted sum (``0`` when ``weights`` is empty).
    """
    # Built-in sum() replaces the manual accumulator, which shadowed `sum`;
    # it folds left from 0 exactly as the original loop did.
    return sum(
        weight * layer_style_cost(S[i], G[i]) for i, weight in enumerate(weights)
    )


def total_cost(Jc, Js, alpha=10, beta=40):
    """Combine the content cost and the style cost linearly.

    Returns ``alpha * Jc + beta * Js``.
    """
    weighted_content = alpha * Jc
    weighted_style = beta * Js
    return weighted_content + weighted_style


In [None]:
class PartialOutputsModule:
    """Run a module and collect the outputs of selected sub-layers.

    A forward hook is registered on ``module[layer]`` for each distinct entry
    in ``layers``. Calling the instance runs the whole module on the input and
    returns the captured outputs in the order given by ``layers``. Duplicate
    entries return the same tensor object. The instance can also be used as a
    context manager, which removes the hooks on exit.

    Args:
        module: An indexable module (for example ``nn.Sequential``); each
            sub-module is looked up with ``module[layer]``.
        layers: Iterable of indices/keys of the sub-modules to capture.
    """

    def __init__(self, module: nn.Module, layers):
        super().__init__()
        self.outputs = dict()  # layer -> output captured during the last call
        self.handlers = []  # hook handles, kept so the hooks can be removed
        self.module = module
        # Materialize once: a generator would otherwise be exhausted here and
        # __call__ would silently return an empty list.
        self.layers = list(layers)

        # One hook per distinct layer; a repeated layer only needs one capture.
        for layer in dict.fromkeys(self.layers):
            self.handlers.append(
                module[layer].register_forward_hook(self.get_activation(layer))
            )

    def get_activation(self, name):
        """Return a forward hook that stores the hooked layer's output under ``name``."""

        def hook(_module, _inputs, output):
            self.outputs[name] = output

        return hook

    def unregister(self):
        """Remove every registered hook; calling it more than once is harmless."""
        for handler in self.handlers:
            handler.remove()
        self.handlers.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.unregister()
        return False

    def __call__(self, x):
        """Run ``self.module`` on ``x`` and return the hooked outputs in ``self.layers`` order."""
        # Forget the previous call's activations, so that a layer which did not
        # run this time raises KeyError instead of returning a stale tensor.
        self.outputs.clear()
        self.module(x)
        return [self.outputs[layer] for layer in self.layers]


In [None]:
# Load the content photo and preview it. The image is channels-first, so move
# the channel axis last for matplotlib.
content_image = io.read_image("images/me.jpeg")
plt.imshow(content_image.permute(1, 2, 0))
plt.show()


In [None]:
# Load the style reference painting and preview it (channels moved last for matplotlib).
style_image = io.read_image("images/van-gogh.jpg")
plt.imshow(style_image.permute(1, 2, 0))
plt.show()


In [None]:
# The image to optimize starts as an independent copy of the content photo.
generated_image = content_image.clone()
plt.imshow(generated_image.permute(1, 2, 0))
plt.show()


In [None]:
# Indices into `vgg.features` whose activations feed the losses.
# NOTE(review): for VGG-11 these presumably are the ReLU at the end of each
# conv block; confirm against the printed `vgg` above.
STYLE_LAYERS = dict(
    zip(
        (1, 4, 9, 14, 19),  # layer index in vgg.features
        (1, 0.7, 0.5, 0.2, 0.2),  # weight of that layer in the style cost
    )
)
# The content cost compares activations at this single layer.
CONTENT_LAYER = 19


In [None]:
def transform_image(img, crop_size=224, resize_size=(450, 600)):
    """Resize an image and return it as a float tensor.

    Args:
        img: A tensor image, or a PIL image (converted with ``pil_to_tensor``
            after resizing).
        crop_size: Currently unused, because center cropping is disabled. It is
            kept so that existing calls keep working.
        resize_size: Target size passed to ``F.resize``.

    Returns:
        The result of ``F.convert_image_dtype(..., torch.float)``. For integer
        input this presumably rescales to [0, 1]; the later clamping to
        [0, 1] relies on that.
    """
    resized = F.resize(img, resize_size)
    as_tensor = (
        resized if isinstance(resized, torch.Tensor) else F.pil_to_tensor(resized)
    )
    return F.convert_image_dtype(as_tensor, torch.float)


In [None]:
# Bring all three images to the same size and float dtype, then onto the device.
content_image = transform_image(content_image)
style_image = transform_image(style_image)
generated_image = transform_image(generated_image)

content_image = content_image.to(device)
style_image = style_image.to(device)
generated_image = generated_image.to(device)

# The pixels of the generated image are the only trainable parameters.
generated_image = nn.Parameter(generated_image, requires_grad=True)
optimizer = torch.optim.Adam([generated_image], lr=0.01)
epochs = 10000


def _vgg_input(img):
    """Normalize an image with the mean/std of the pretrained VGG weights' preset.

    The original cell fed raw images to VGG and never used `vgg_preprocess`.
    Only its normalization is applied here; its resize/crop steps are not.
    F.normalize is differentiable, so gradients still reach `generated_image`.
    """
    return F.normalize(img, mean=vgg_preprocess.mean, std=vgg_preprocess.std)


partial = PartialOutputsModule(vgg, list(STYLE_LAYERS.keys()) + [CONTENT_LAYER])
n_style = len(STYLE_LAYERS)
style_weights = list(STYLE_LAYERS.values())

# The target activations never change, so compute them once without autograd.
with torch.no_grad():
    C = partial(_vgg_input(content_image))[-1]
    S = partial(_vgg_input(style_image))[:n_style]

# Optimize the generated image.
# NOTE(review): alpha/beta were chosen for unnormalized inputs; they may need retuning.
for i in range(epochs):
    outputs = partial(_vgg_input(generated_image))
    Js = style_cost(S, outputs[:n_style], style_weights)
    Jc = content_cost(C, outputs[-1])
    loss = total_cost(Jc, Js, alpha=10, beta=80)

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 100 == 0:
        print(f"Epoch {i}: Loss {loss.item()}")
        with torch.no_grad():
            # Clamp only the displayed copy; the parameter itself is left as is.
            Y = torch.clamp(generated_image.cpu(), 0.0, 1.0)
            plt.imshow(Y.permute([1, 2, 0]))
            plt.show()


In [None]:
# Save the final image as a PNG.
# torchvision.io has no `save_image` (the original call raised AttributeError).
# io.write_png expects a uint8 (C, H, W) CPU tensor, so the clamped [0, 1]
# floats are scaled to [0, 255] first.
with torch.no_grad():
    Y = torch.clamp(generated_image.cpu(), 0.0, 1.0)
    Y = Y.mul(255).round().to(torch.uint8)
    io.write_png(Y, "images/generated.png")
    print("Saved image")