In [39]:
from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import cv2
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np

In [40]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [76]:
class VGGNet(nn.Module):
    """Feature extractor built on the ``features`` stack of a pretrained VGG-19.

    ``forward`` pushes the input through the layers in order and returns the
    activations produced by the layers whose names are listed in
    ``self.select``.
    """

    def __init__(self):
        super(VGGNet, self).__init__()
        # Names (string indices inside ``vgg19.features``) of the layers
        # whose outputs forward() collects.
        self.select = ['0', '5', '10', '19', '28']
        self.vgg = models.vgg19(pretrained=True).features

    def forward(self, x):
        """Return the activations of the selected layers for input ``x``.

        Args:
            x: input tensor fed to the first layer of ``self.vgg``.

        Returns:
            list of tensors, one per selected layer, in layer order.
        """
        features = []
        # named_children() is the public equivalent of iterating
        # ``_modules.items()`` and yields (name, layer) pairs in order.
        for name, layer in self.vgg.named_children():
            x = layer(x)
            if name in self.select:
                features.append(x)
                # Everything after the last selected layer would be computed
                # and thrown away, so stop as soon as all outputs are in.
                if len(features) == len(self.select):
                    break
        return features

In [78]:
# Scratch cell: grab the `features` sub-module of a pretrained VGG-19 so its
# layers can be inspected interactively.
features = models.vgg19(
    pretrained=True,
).features

In [79]:
# Scratch cell: name -> layer mapping of the VGG-19 feature stack.
# Reuses `features` from the previous cell instead of building and loading
# another full pretrained model just to read the same mapping.
modules = features._modules

In [80]:
# Scratch cell: (name, layer) pairs of the VGG-19 feature stack.
# Reuses the already-loaded `features` rather than loading the pretrained
# weights yet again.
items = features._modules.items()

In [54]:
# --- Input images -----------------------------------------------------------
content, style = 'png/content.png', 'png/style.png'
max_size = 400        # longest side of the content image after resizing

# --- Optimisation settings --------------------------------------------------
lr = 0.003            # learning rate handed to the optimizer
style_weight = 100    # intended weight of the style term in the total loss
total_step = 2000     # intended number of optimisation steps

# --- Reporting intervals (in steps) -----------------------------------------
log_step = 10         # intended interval between loss printouts
sample_step = 500     # intended interval between saved output images

In [55]:
# Preprocessing applied to every PIL image before it is fed to VGG:
# convert to a tensor, then normalise each channel with the mean/std that
# torchvision documents for its ImageNet-pretrained models.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

In [57]:
# Load the content image. Force 3-channel RGB so that RGBA / greyscale PNGs
# still match the 3-channel mean/std used by `transform`.
image = Image.open(content).convert('RGB')

# Rescale so that the longest side equals max_size (aspect ratio preserved).
scale = max_size / max(image.size)
size = np.array(image.size) * scale
# PIL expects a (width, height) pair of plain ints. Image.LANCZOS is the same
# filter that used to be exposed as Image.ANTIALIAS (removed in Pillow 10).
image = image.resize(tuple(int(side) for side in size), Image.LANCZOS)

# Normalise and add a leading batch dimension.
image = transform(image).unsqueeze(0)

In [58]:
# Tensor.to() returns a new tensor instead of moving the tensor in place, so
# the result has to be assigned back for `image` to actually live on `device`.
image = image.to(device)
image

tensor([[[[ 2.0605,  2.0605,  2.0605,  ...,  2.1119,  2.1119,  2.0777],
          [ 2.0605,  2.0605,  2.0605,  ...,  2.1119,  2.1119,  2.0948],
          [ 2.0605,  2.0605,  2.0605,  ...,  2.1119,  2.1290,  2.0948],
          ...,
          [ 1.9578,  1.9578,  1.9578,  ...,  2.1290,  2.1119,  2.1462],
          [ 1.9578,  1.9578,  1.9578,  ...,  2.1290,  2.1290,  2.1633],
          [ 1.9578,  1.9578,  1.9578,  ...,  2.1119,  2.1633,  2.1804]],

         [[ 2.0959,  2.0959,  2.0784,  ...,  2.0959,  2.0959,  2.1134],
          [ 2.1134,  2.0784,  2.0784,  ...,  2.0959,  2.0959,  2.1134],
          [ 2.0959,  2.0959,  2.0959,  ...,  2.0784,  2.0784,  2.1134],
          ...,
          [ 1.8683,  1.8683,  1.8683,  ...,  2.2185,  2.2010,  2.1835],
          [ 1.8683,  1.8683,  1.8683,  ...,  2.2010,  2.2010,  2.2010],
          [ 1.8683,  1.8683,  1.8683,  ...,  2.1835,  2.1835,  2.2010]],

         [[ 2.0823,  2.0997,  2.0823,  ...,  2.0997,  2.0997,  2.0997],
          [ 2.0823,  2.0823,  

In [65]:
# Load the style image as 3-channel RGB (matches the 3-channel `transform`).
style_image = Image.open(style).convert('RGB')

# Resize the style image to the content image's spatial size.
# `image` is laid out (batch, channel, height, width) while PIL's resize
# expects (width, height), so the two dimensions must be passed in
# width-first order.
style_image = style_image.resize([image.size(3), image.size(2)], Image.LANCZOS)
style_image = transform(style_image).unsqueeze(0)
# Tensor.to() is not in-place; keep the returned tensor.
style_image = style_image.to(device)
style_image

tensor([[[[-1.4500, -1.2617, -1.3130,  ..., -1.3644, -1.3473, -1.3130],
          [-1.3130, -1.1589, -1.1760,  ..., -1.1760, -1.2274, -1.2617],
          [-1.2274, -1.0048, -0.9877,  ..., -0.9877, -1.0219, -1.0048],
          ...,
          [-0.0629,  1.4440,  0.5364,  ...,  1.9407,  1.7865,  1.7180],
          [-0.5082,  1.3755,  2.1804,  ...,  1.9235,  1.7352,  1.7009],
          [-0.3541,  1.0502,  0.8447,  ...,  1.9064,  1.7180,  1.7009]],

         [[-1.3704, -1.4055, -1.3880,  ..., -1.3880, -1.3354, -1.3004],
          [-1.3354, -1.4405, -1.4405,  ..., -1.4055, -1.4230, -1.4230],
          [-1.4580, -1.5280, -1.5630,  ..., -1.5630, -1.5280, -1.4580],
          ...,
          [-0.4076,  1.1331,  0.1877,  ...,  1.1506,  1.1681,  1.1856],
          [-0.8102,  1.1506,  2.0434,  ...,  1.1331,  1.1331,  1.2031],
          [-0.6176,  0.8004,  0.6254,  ...,  1.0980,  1.0980,  1.1856]],

         [[-1.0724, -1.1073, -1.1596,  ..., -1.1596, -1.1596, -1.1421],
          [-1.0724, -1.1596, -

In [66]:
# The image being optimised starts as a copy of the content image.
# It is explicitly placed on `device` (a no-op if `image` is already there)
# so that it lives on the same device as the model, and only then marked as
# requiring gradients so it remains a leaf tensor the optimizer can update.
target = image.clone().to(device).requires_grad_(True)

In [68]:
# Adam optimises the pixels of `target` directly; it is the only parameter.
optimizer = torch.optim.Adam(
    params=[target],
    lr=lr,
    betas=[0.5, 0.999],
)

In [77]:
# Feature extractor on `device`, in evaluation mode.
vgg = VGGNet().to(device).eval()

# Only the target image is optimised. Freezing the network weights stops
# every backward() from computing (and accumulating) weight gradients that
# nothing ever reads or zeroes; gradients w.r.t. the input are unaffected.
for weight in vgg.parameters():
    weight.requires_grad_(False)

In [None]:
def _gram_matrix(feature_map):
    """Return the (c, c) Gram matrix of a feature map of size (1, c, h, w).

    The feature map is flattened to (c, h * w) with ``view``, which requires
    the leading (batch) dimension to be 1.
    """
    _, c, h, w = feature_map.size()
    flat = feature_map.view(c, h * w)
    return torch.mm(flat, flat.t())


# Approximate inverse of the Normalize step in `transform`
# (mean' = -mean / std, std' = 1 / std), used to turn `target` back into an
# image with values in [0, 1] before saving.
denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))

# The content and style images never change, so their features are computed
# once, outside the loop and without autograd bookkeeping.
# NOTE: the style features come from the `style_image` tensor; `style` is
# only the file path string.
with torch.no_grad():
    content_features = vgg(image.to(device))
    style_grams = [_gram_matrix(f) for f in vgg(style_image.to(device))]

for step in range(total_step):
    target_features = vgg(target)

    style_loss = 0
    content_loss = 0
    for f1, f2, g3 in zip(target_features, content_features, style_grams):
        # Content loss: mean squared difference between raw activations.
        content_loss += torch.mean((f1 - f2) ** 2)

        # Style loss: mean squared difference between Gram matrices,
        # scaled by the number of elements in the target feature map.
        _, c, h, w = f1.size()
        g1 = _gram_matrix(f1)
        style_loss += torch.mean((g1 - g3) ** 2) / (c * h * w)

    # Compute total loss, backprop and optimize
    loss = content_loss + style_weight * style_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (step + 1) % log_step == 0:
        print('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}'.format(step+1, total_step, content_loss.item(), style_loss.item()))

    if (step + 1) % sample_step == 0:
        # Work on a detached copy so saving a snapshot neither touches the
        # tensor being optimised nor builds an autograd graph.
        with torch.no_grad():
            img = denorm(target.detach().clone().squeeze(0)).clamp_(0, 1)
        torchvision.utils.save_image(img, 'output-{}.png'.format(step+1))