# Neural Style Transfer

[Original video](https://youtu.be/imX4kSKDY7s)

Paper [A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576)

Lectures by Andrew Ng: [part1](https://youtu.be/R39tWYYKNcI), [part2](https://youtu.be/ChoV5h7tw5A), [part3](https://youtu.be/xY-DMAJpIP4), [part4](https://youtu.be/b1I5X3UfEYI), [part5](https://youtu.be/QgkLfjfGul8).

[Neural style transfer](https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/03-advanced/neural_style_transfer) by Yunjey Choi.

In [None]:
import os
import torch
import requests
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from PIL import Image
from torch.utils.data import DataLoader
from torchvision.utils import save_image

In [None]:
# Load only the `features` sub-module of pretrained VGG19 and print it so the
# numeric layer indices can be read off the listing.
model = models.vgg19(pretrained=True).features
print(model)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [None]:
class VGG(nn.Module):
    """Truncated pretrained VGG19 that returns activations of selected layers.

    Only layers 0..28 of ``vgg19.features`` are kept; ``forward`` runs the
    input through them in order and collects the tensor produced at every
    index listed in ``chosen_features``.
    """

    def __init__(self):
        super().__init__()

        # Indices (as strings) of the layers whose outputs are collected.
        self.chosen_features = ['0', '5', '10', '19', '28']

        # Nothing past index 28 contributes to the loss, so drop those layers.
        self.model = models.vgg19(pretrained=True).features[:29]

    def forward(self, x):
        """Return the list of activations captured at the chosen layers.

        NOTE(review): the printed VGG19 listing shows ``ReLU(inplace=True)``
        right after each conv layer. The tensor appended here is the same
        object handed to the next layer, so a captured conv output is
        presumably overwritten by the following in-place ReLU -- confirm
        whether that is intended.
        """
        captured = []
        activation = x

        for index, module in enumerate(self.model):
            activation = module(activation)
            if str(index) not in self.chosen_features:
                continue
            captured.append(activation)

        return captured


def load_image(url):
    """Download an image, save it locally and return it as a batched tensor.

    The response body is written to the current working directory under the
    last ``/``-separated segment of ``url``. The saved file is then opened,
    passed through the module-level ``loader`` transform, given a leading
    batch dimension of size 1 and moved to the module-level ``device``.

    Args:
        url: Address of the image to download.

    Returns:
        The transformed image tensor with a leading batch dimension, on
        ``device``.

    Raises:
        requests.exceptions.RequestException: if the download fails, times
            out, or the server answers with an error status.
    """
    filename = url.split('/')[-1]  # get file name

    # A timeout keeps a stalled server from hanging the notebook forever.
    r = requests.get(url, allow_redirects=True, timeout=30)
    # Fail here rather than saving an HTML error page as if it were an image.
    r.raise_for_status()

    with open(filename, 'wb') as handler:  # save image
        handler.write(r.content)

    # Force three channels: grayscale/RGBA/palette files would otherwise give
    # a tensor whose channel count does not match VGG's first conv layer.
    # The context manager also closes the underlying file handle.
    with Image.open(filename) as img:
        image = img.convert('RGB')

    image = loader(image).unsqueeze(0)
    return image.to(device)


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

image_size = 356

# Resize every image to a fixed square so content/style/generated feature
# maps have matching shapes, then convert to a float tensor.
# NOTE(review): torchvision's pretrained VGG weights are documented as trained
# on ImageNet-normalized inputs; normalization is skipped here -- confirm this
# is intentional.
loader = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
])

original_img = load_image('https://education.okstate.edu/outreach/fcs/Co-parenting-The-unique-role-of-fathers.jpg')
style_img = load_image('https://images.genius.com/198d6092a96e55efa4e65f7119fa4ae7.750x400x1.jpg')

# Starting from a copy of the original image works better than random noise
# generated = torch.randn(original_img.shape, device=device, requires_grad=True)
generated = original_img.clone().requires_grad_(True)  # change image, not model

# Hyperparameters
total_steps = 6000
learning_rate = 0.001
alpha = 1.0  # for content loss, not like in paper
beta = 0.01  # how much style do we want in the image
optimizer = optim.Adam([generated], lr=learning_rate)  # optimize generated image

# eval() only switches layers to inference behaviour; it does not stop autograd
# from computing weight gradients. Freeze the weights explicitly so backward()
# only produces a gradient for the generated image.
model = VGG().to(device).eval()
model.requires_grad_(False)

def _gram_matrix(feature):
    """Return the channel-by-channel Gram matrix of a single-image feature map.

    Expects a 4-D tensor whose first (batch) dimension is 1, because the map
    is flattened to ``(channel, height * width)`` before the matrix product.
    """
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())  # mm - matrix multiply


# The content and style images never change, so their features (and the style
# Gram matrices) are computed once, without autograd bookkeeping, instead of
# being recomputed on every optimisation step.
with torch.no_grad():
    original_img_features = model(original_img)
    style_features = model(style_img)
    style_grams = [_gram_matrix(feature) for feature in style_features]

for step in range(total_steps):
    generated_features = model(generated)

    style_loss = content_loss = 0

    # iterate through all the features for the chosen layers
    for gen_feature, orig_feature, style_gram in zip(
            generated_features, original_img_features, style_grams):

        # Content loss: mean squared difference of the raw feature maps.
        content_loss += torch.mean((gen_feature - orig_feature)**2)

        # Style loss: mean squared difference of the Gram matrices.
        G = _gram_matrix(gen_feature)
        style_loss += torch.mean((G - style_gram)**2)

    total_loss = alpha*content_loss + beta*style_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 200 == 0:
        print(step, total_loss)
        save_image(generated, 'generated.png')

# The last step is not a multiple of 200, so save the final result explicitly.
save_image(generated, 'generated.png')

tensor(3532019.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(125943.5781, device='cuda:0', grad_fn=<AddBackward0>)
tensor(61126.4648, device='cuda:0', grad_fn=<AddBackward0>)
tensor(43779.5586, device='cuda:0', grad_fn=<AddBackward0>)
tensor(34992.3242, device='cuda:0', grad_fn=<AddBackward0>)
tensor(29569.9082, device='cuda:0', grad_fn=<AddBackward0>)
tensor(25805.4590, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22964.2168, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20706.9082, device='cuda:0', grad_fn=<AddBackward0>)
tensor(18855.8223, device='cuda:0', grad_fn=<AddBackward0>)
tensor(17334.9551, device='cuda:0', grad_fn=<AddBackward0>)
tensor(16045.4121, device='cuda:0', grad_fn=<AddBackward0>)
tensor(14939.1455, device='cuda:0', grad_fn=<AddBackward0>)
tensor(13982.4199, device='cuda:0', grad_fn=<AddBackward0>)
tensor(13142.9385, device='cuda:0', grad_fn=<AddBackward0>)
tensor(12402.8896, device='cuda:0', grad_fn=<AddBackward0>)
tensor(11733.3252, device='cuda:0', g