# Neural Style Transfer

In [None]:
# Mount Google Drive into the Colab VM so the project's images can be read from it.
import google.colab.drive as drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os

# Work from the project folder on the mounted Drive (mounted at /content/drive).
# An absolute path keeps this cell idempotent: a relative './drive/...' path only
# resolves while the working directory is /content, so a second run would fail.
PROJECT_DIR = '/content/drive/My Drive/Colab Notebooks/pytorch'
os.chdir(PROJECT_DIR)

In [None]:
from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np

In [None]:
# Device configuration
# Prefer the GPU whenever CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
def load_image(image_path, transform=None, max_size=None, shape=None):
    """Load an image from disk, optionally resize it, and convert it to a tensor.

    Args:
        image_path: path of the image file to open.
        transform: optional callable applied to the PIL image (e.g. a
            ``transforms.Compose``). Its result gets a leading batch
            dimension via ``unsqueeze(0)`` and is moved to ``device``.
        max_size: if given, rescale the image so its longer side becomes
            ``max_size`` pixels, preserving the aspect ratio.
        shape: if given, resize the image to exactly this size. PIL sizes
            are ``(width, height)``.

    Returns:
        The transformed, batched tensor on ``device`` when ``transform`` is
        given; otherwise the resized RGB PIL image.
    """
    # Force 3 channels: PNGs are often RGBA/palette images (and JPEGs may be
    # grayscale), which would not match the 3-channel normalization used
    # downstream. The context manager closes the file once the converted
    # copy has been made.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')

    if max_size:
        scale = max_size / max(image.size)
        # Truncate to whole pixels; PIL needs an integer (width, height).
        new_size = tuple(int(dim * scale) for dim in image.size)
        # LANCZOS is the high-quality filter formerly aliased as ANTIALIAS
        # (the alias was removed in Pillow 10).
        image = image.resize(new_size, Image.LANCZOS)

    if shape:
        image = image.resize(tuple(shape), Image.LANCZOS)

    if transform:
        return transform(image).unsqueeze(0).to(device)
    # Without a transform this is still a PIL image, which has no .to().
    return image

    

In [None]:
class VGGNet(nn.Module):
    """Frozen VGG19 feature extractor for the content and style losses.

    ``forward`` runs an input through ``vgg19().features`` and returns the
    outputs of the layers whose names are listed in ``self.select``.
    """

    def __init__(self, select=None):
        """Load the pretrained VGG19 convolutional stack.

        Args:
            select: names (string indices into ``vgg19().features``) of the
                layers whose outputs ``forward`` returns. Defaults to
                ``['0', '5', '10', '19', '28']``, i.e. conv1_1 ~ conv5_1.
        """
        super(VGGNet, self).__init__()
        self.select = ['0', '5', '10', '19', '28'] if select is None else list(select)
        # NOTE: torchvision >= 0.13 deprecates `pretrained=True` in favour of
        # `weights=...`; kept for compatibility with older releases.
        self.vgg = models.vgg19(pretrained=True).features
        # Only the input image is optimized. Freezing the weights keeps
        # backward() from computing and accumulating gradients for them.
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Return the selected intermediate feature maps of ``x``, in layer order."""
        features = []
        for name, layer in self.vgg.named_children():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features

In [None]:
# Instantiate VGG19 with ImageNet weights; the first call downloads and caches
# them. Only the convolutional `features` part is kept in `vgg19`.
_vgg19_full = models.vgg19(pretrained=True)
vgg19 = _vgg19_full.features
del _vgg19_full

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(FloatProgress(value=0.0, max=574673361.0), HTML(value='')))




In [None]:
# Optimization hyperparameters
max_size, total_step = 400, 2000    # longest content-image side (px); number of optimizer steps
log_step, sample_step = 10, 500     # print losses every `log_step`; save an image every `sample_step`
style_weight, lr = 100, 3e-3        # multiplier on the style loss; Adam learning rate

# Input image paths, relative to the working directory
content, style = 'png/eon.jpg', 'png/style.png'

In [None]:
# Image preprocessing
# VGGNet was trained on ImageNet, where images are normalized with the mean
# and std below; the same normalization statistics are used here.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Exact inverse of the normalization above (x * std + mean), expressed as a
# Normalize with mean' = -mean / std and std' = 1 / std. It is derived from
# the same constants so the two can never drift apart.
denorm = transforms.Normalize(
    mean=tuple(-m / s for m, s in zip(imagenet_mean, imagenet_std)),
    std=tuple(1.0 / s for s in imagenet_std),
)

# Load content and style images.
# NOTE: `content` and `style` are rebound from file paths to tensors here, so
# re-running this cell requires re-running the configuration cell first.
content = load_image(content, transform, max_size=max_size)
# Resize the style image to the content image's size. PIL sizes are
# (width, height), while the tensor is laid out (N, C, H, W), so pass (W, H).
style = load_image(style, transform, shape=(content.size(3), content.size(2)))

# Initialize the target image with the content image; it is the only tensor
# the optimizer updates.
target = content.clone().requires_grad_(True)

optimizer = torch.optim.Adam([target], lr=lr, betas=(0.5, 0.999))
# VGG serves only as a fixed feature extractor. Freezing its weights keeps
# backward() from computing gradients that are never used.
vgg = VGGNet().to(device).eval().requires_grad_(False)


def _gram_matrix(features):
    """Return the (C, C) Gram matrix of a single-image (1, C, H, W) feature map."""
    _, c, h, w = features.size()
    flat = features.view(c, h * w)
    return torch.mm(flat, flat.t())


# The content and style images never change, so extract their features (and
# the style Gram matrices) once, outside the loop and without autograd.
with torch.no_grad():
    content_features = vgg(content)
    style_features = vgg(style)
    style_grams = [_gram_matrix(f) for f in style_features]

for step in range(total_step):
    # Extract the selected conv feature maps of the current target image
    target_features = vgg(target)

    style_loss = 0
    content_loss = 0

    for f1, f2, style_gram in zip(target_features, content_features, style_grams):
        # Content loss: mean squared difference of target and content features
        content_loss += torch.mean((f1 - f2) ** 2)

        # Style loss: mean squared difference of the Gram matrices,
        # normalized by the size of the target feature map
        _, c, h, w = f1.size()
        style_loss += torch.mean((_gram_matrix(f1) - style_gram) ** 2) / (c * h * w)

    # Compute total loss, backprop and optimize
    loss = content_loss + style_weight * style_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (step + 1) % log_step == 0:
        print('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}'.format(
            step + 1, total_step, content_loss.item(), style_loss.item()
        ))

    if (step + 1) % sample_step == 0:
        # Save the generated image, mapped back to [0, 1] pixel values.
        # detach(): the snapshot must stay out of the autograd graph.
        img = denorm(target.detach().clone().squeeze(0)).clamp(0, 1)
        torchvision.utils.save_image(img, 'png/output_mburg-{}.png'.format(step + 1))
    

Step [10/2000], Content Loss: 5.5890, Style Loss: 7848.3516
Step [20/2000], Content Loss: 16.8720, Style Loss: 6554.2168
Step [30/2000], Content Loss: 26.1783, Style Loss: 5345.4106
Step [40/2000], Content Loss: 32.0076, Style Loss: 4398.3682
Step [50/2000], Content Loss: 35.9323, Style Loss: 3670.0835
Step [60/2000], Content Loss: 39.1168, Style Loss: 3106.1875
Step [70/2000], Content Loss: 41.8748, Style Loss: 2660.8428
Step [80/2000], Content Loss: 44.2729, Style Loss: 2302.4529
Step [90/2000], Content Loss: 46.4379, Style Loss: 2009.4473
Step [100/2000], Content Loss: 48.4122, Style Loss: 1768.4546
Step [110/2000], Content Loss: 50.1882, Style Loss: 1570.0540
Step [120/2000], Content Loss: 51.7946, Style Loss: 1405.9554
Step [130/2000], Content Loss: 53.2635, Style Loss: 1269.5746
Step [140/2000], Content Loss: 54.6113, Style Loss: 1155.6605
Step [150/2000], Content Loss: 55.8174, Style Loss: 1059.8309
Step [160/2000], Content Loss: 56.9195, Style Loss: 978.4738
Step [170/2000], Co