# Load Packages

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opti
from torch.autograd import Variable

import PIL
import numpy as np
import tqdm
from PIL import Image
from matplotlib import pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Set Config

In [2]:
# Notebook-wide options.
#   'random': when truthy the optimised image starts from Gaussian noise,
#             otherwise it starts from a copy of the content image.
config = {'random': False}

# Define Functions

In [27]:
def load_img(img_path, resize=None):
    """Load an image from disk as a 3-channel RGB PIL image.

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.
        resize: Optional ``(width, height)`` size passed to ``Image.resize``;
            when falsy the image keeps its original size.

    Returns:
        A ``PIL.Image.Image`` in mode ``'RGB'``.
    """
    # The context manager closes the underlying file once the pixels have been
    # copied out by `convert`, instead of leaving the handle open.
    with Image.open(img_path) as source:
        # Force three channels: grayscale, palette or RGBA inputs would
        # otherwise break the per-channel normalisation done downstream.
        img = source.convert('RGB')
    if resize:
        img = img.resize(resize, PIL.Image.BICUBIC)
    return img
def preproc4torch(img):
    """Convert an RGB image into a normalised ``(1, 3, H, W)`` float tensor.

    Args:
        img: Anything ``np.array`` turns into an ``(H, W, 3)`` array with
            values in ``[0, 255]`` and channels in RGB order (e.g. a PIL image).

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 3, H, W)`` on the
        module-level ``device``, normalised per channel with the mean/std below.
    """
    # Per-channel statistics, listed in RGB order.
    mean = np.array([[[0.485, 0.456, 0.406]]])
    std = np.array([[[0.229, 0.224, 0.225]]])
    result = np.array(img) / 255.
    # Channels are deliberately left in RGB order: the statistics above are
    # RGB-ordered and torchvision's pretrained models expect RGB input, so
    # flipping to BGR here would normalise each channel with the wrong values.
    result = np.transpose((result - mean) / std, [2, 0, 1])
    result = torch.as_tensor(result, dtype=torch.float32).unsqueeze(0)
    return result.to(device)


def deproc4plot(img):
    """Invert ``preproc4torch`` so a tensor can be shown with matplotlib.

    Args:
        img: Tensor of shape ``(1, 3, H, W)`` (or ``(3, H, W)``) in the
            normalised space produced by ``preproc4torch``.

    Returns:
        An ``(H, W, 3)`` ``uint8`` RGB array with values in ``[0, 255]``.
    """
    mean = np.array([[[0.485, 0.456, 0.406]]])
    std = np.array([[[0.229, 0.224, 0.225]]])
    # Only drop the batch axis; a bare squeeze() would also remove a height or
    # width of 1 and break the transpose below.
    result = img.detach().cpu().squeeze(0).numpy()
    result = np.transpose(result, [1, 2, 0])
    result = (result * std + mean) * 255.
    # Round before casting: a plain uint8 cast truncates, so float32 round-off
    # (e.g. 254.9999) would come back one intensity level too dark.
    result = np.clip(np.rint(result), 0, 255)
    return result.astype(np.uint8)

In [None]:
class Extractor(nn.Module):
    """Frozen VGG19 feature stack that returns selected intermediate outputs.

    ``style_idx`` and ``content_idx`` hold the child names (string indices in
    the ``features`` Sequential) whose outputs are collected in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.style_idx = ['0', '5', '10', '19', '28']
        self.content_idx = ['20']
        self.extractor = models.vgg19(pretrained=True).features
        # The network is only a fixed feature extractor: freeze its weights so
        # backward() does not compute and accumulate unused parameter
        # gradients. Gradients with respect to the input image still flow.
        for param in self.extractor.parameters():
            param.requires_grad_(False)

    def forward(self, x, mode=None):
        """Extract multiple convolutional feature maps.

        Args:
            x: Input batch fed to the first layer of the feature stack.
            mode: ``'content'`` selects ``content_idx``; any other truthy value
                selects ``style_idx``.

        Returns:
            List of tensors, one per selected layer, in network order.

        Raises:
            AssertionError: If ``mode`` is missing or falsy.
        """
        assert mode, "Please input mode of Extractor"
        feature_idx = self.content_idx if mode == 'content' else self.style_idx
        features = []
        # NOTE(review): if the backbone's activations run in place, a tensor
        # captured right after a convolution may be overwritten by the
        # following activation — confirm which activation is intended.
        for num, layer in self.extractor.named_children():
            x = layer(x)
            if num in feature_idx:
                features.append(x)
        return features

# Prepare the images

In [28]:
# Load the content image, then load the style image resized to the same
# (width, height) so both produce feature maps of matching shape.
content_pil = load_img('../cat.jpg')
style_pil = load_img('../starry_night.jpg', content_pil.size)

content_img = preproc4torch(content_pil)
style_img = preproc4torch(style_pil)
print('Content image shape : ', content_img.shape)
print('Style image shape : ', style_img.shape)

# The optimised image must be a *leaf* tensor for the optimizer to accept it.
# Create it directly on `device` and enable gradients afterwards; calling
# `.to(device)` on a tensor that already requires grad returns a non-leaf copy
# whenever the data actually moves (e.g. CPU noise -> GPU).
if config['random']:
    target_img = torch.randn(content_img.size(), device=device).requires_grad_(True)
else:
    target_img = content_img.detach().clone().requires_grad_(True)

In [None]:
# Build the feature extractor, move it to the compute device and switch it to
# evaluation mode.
extractor = Extractor()
extractor = extractor.to(device)
extractor.eval()

# Adam updates the pixels of `target_img` directly.
# NOTE(review): beta2 = 0.1 is far below the usual 0.999 — confirm it is intentional.
optim = torch.optim.Adam(params=[target_img], lr=1e-3, betas=[0.5, 0.1])

In [None]:
def Content_Loss(content, target):
    """Mean squared difference between the first content and target features.

    Args:
        content: Sequence whose first item is the reference feature tensor.
        target: Sequence whose first item is the feature tensor being optimised.

    Returns:
        Scalar tensor with the mean of the element-wise squared difference.
    """
    difference = content[0] - target[0]
    return difference.pow(2).mean()

def Style_Loss(style, target):
    """Sum of Gram-matrix mean squared errors over paired feature maps.

    Args:
        style: Iterable of ``(b, c, h, w)`` reference feature tensors.
        target: Iterable of feature tensors, paired with ``style`` by position
            and reshaped with the corresponding style tensor's dimensions.

    Returns:
        The accumulated loss: for each pair, the mean squared difference of
        the two ``(b, c, c)`` Gram matrices divided by ``c**2``. Returns the
        integer ``0`` when there are no pairs.
    """
    loss = 0
    for s_f, t_f in zip(style, target):
        b, c, h, w = s_f.size()
        # reshape (not view) so non-contiguous feature maps, e.g. channels_last
        # or permuted tensors, are flattened instead of raising a RuntimeError.
        s_flat = s_f.reshape(b, c, h * w)
        t_flat = t_f.reshape(b, c, h * w)

        # Gram matrices: channel-by-channel inner products over spatial positions.
        s_gram = torch.bmm(s_flat, s_flat.transpose(1, 2))
        t_gram = torch.bmm(t_flat, t_flat.transpose(1, 2))
        loss += torch.mean((s_gram - t_gram) ** 2) / (c ** 2)
    return loss

In [None]:
steps = 5000

# The content and style images never change, so their reference features are
# computed once, without autograd bookkeeping, instead of on every step.
with torch.no_grad():
    content = extractor(content_img, 'content')
    style = extractor(style_img, 'style')

for step in tqdm.tqdm(range(steps)):

    # Features of the image being optimised; gradients flow back to target_img.
    target_content = extractor(target_img, 'content')
    target_style = extractor(target_img, 'style')

    c_loss = Content_Loss(content, target_content)
    s_loss = Style_Loss(style, target_style)

    # The style term is weighted 100x relative to the content term.
    loss = c_loss + 100*s_loss

    optim.zero_grad()
    loss.backward()
    optim.step()


# Convert the optimised tensor back to an image array and display it.
make = deproc4plot(target_img)
plt.imshow(make)
plt.axis('off')
plt.show()