# Neural Artistic Style Transfer: A Comprehensive Look
- Reference: https://medium.com/artists-and-machine-intelligence/neural-artistic-style-transfer-a-comprehensive-look-f54d8649c199

## Gram matrix layer

In [1]:
import torch
import torch.nn as nn

class GramMatrix(nn.Module):
    """Normalised Gram matrix of a feature map, used as the style representation.

    For an input of size (a, b, c, d) the a*b feature maps are flattened into rows
    of length c*d, and their pairwise inner products are divided by a*b*c*d.
    With a batch of one (a == 1) the result is the (b, b) channel-correlation
    matrix. With a > 1, maps from different images end up in one (a*b, a*b) matrix.
    """

    def forward(self, input):
        a, b, c, d = input.size()
        # reshape (not view): also accepts non-contiguous feature maps (e.g. after
        # a transpose/permute), where view() raises a RuntimeError.
        features = input.reshape(a * b, c * d)
        G = torch.mm(features, features.t())

        return G.div(a * b * c * d)

## Define network

In [2]:
import torch.optim as optim
import torchvision.models as models


class StyleCNN(object):
    """Optimises the pixels of a single pastiche image towards a content and a style target.

    Losses are measured on the convolutional feature maps of a pretrained VGG-19
    (``self.loss_network``). Layers in ``content_layers`` use an MSE on raw
    activations; layers in ``style_layers`` use an MSE on Gram matrices.

    Args:
        style: style image tensor (as returned by ``image_loader``).
        content: content image tensor.
        pastiche: initial image. Its data is copied into an ``nn.Parameter``,
            which L-BFGS updates in place.
    """

    def __init__(self, style, content, pastiche):
        super(StyleCNN, self).__init__()

        self.style = style
        self.content = content
        self.pastiche = nn.Parameter(pastiche.data)

        # Convolutions are named conv_1, conv_2, ... in the order they appear in
        # VGG-19's ``features``; the counter advances on every ReLU.
        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        self.content_weight = 1
        self.style_weight = 1000

        # The loss network is a fixed feature extractor. Freezing it means
        # backward() only produces gradients for the pastiche. Otherwise a
        # gradient for every VGG weight would be computed on each closure call
        # and would keep accumulating in .grad, because the optimizer only
        # zeroes the pastiche's gradient.
        self.loss_network = models.vgg19(pretrained=True)
        for param in self.loss_network.parameters():
            param.requires_grad_(False)
        self.loss_network.eval()

        self.gram = GramMatrix()
        self.loss = nn.MSELoss()
        self.optimizer = optim.LBFGS([self.pastiche])

    def train(self):
        """Run one L-BFGS step (possibly several closure evaluations).

        Returns:
            The pastiche ``nn.Parameter``, updated in place.
        """
        def closure():
            self.optimizer.zero_grad()

            # Losses are evaluated on a [0, 1]-clamped copy. The clamp goes
            # through .data, so autograd treats the copy as an identity.
            pastiche = self.pastiche.clone()
            pastiche.data.clamp_(0, 1)
            content = self.content.clone()
            style = self.style.clone()

            content_loss = 0
            style_loss = 0

            # Loss layers not reached yet. Once it is empty, deeper VGG layers
            # cannot affect the loss, so the forward pass stops there.
            pending = set(self.content_layers) | set(self.style_layers)

            i = 1
            for layer in self.loss_network.features:
                if not pending:
                    break
                # Use a non-inplace ReLU so the activation does not overwrite
                # conv outputs that the losses (and autograd) still need.
                if isinstance(layer, nn.ReLU):
                    layer = nn.ReLU(inplace=False)
                pastiche, content, style = layer(pastiche), layer(content), layer(style)

                if isinstance(layer, nn.Conv2d):
                    name = 'conv_' + str(i)

                    if name in self.content_layers:
                        content_loss += self.loss(pastiche * self.content_weight, content.detach() * self.content_weight)

                    if name in self.style_layers:
                        pastiche_g, style_g = self.gram(pastiche), self.gram(style)
                        style_loss += self.loss(pastiche_g * self.style_weight, style_g.detach() * self.style_weight)

                    pending.discard(name)

                if isinstance(layer, nn.ReLU):
                    i += 1

            total_loss = content_loss + style_loss
            total_loss.backward()

            return total_loss

        self.optimizer.step(closure)
        return self.pastiche

## Utils

In [3]:
import torchvision.transforms as transforms
from torch.autograd import Variable

from PIL import Image
import scipy.misc

# Every image is resized to imsize x imsize pixels (save_image relies on this too).
imsize = 256

# PIL image -> float tensor, resized to exactly (imsize, imsize).
loader = transforms.Compose([
    transforms.Resize((imsize, imsize)),
    transforms.ToTensor(),
])

# Tensor -> PIL image, used when writing results to disk.
unloader = transforms.ToPILImage()

def image_loader(image_name):
    """Load an image file as a (1, 3, imsize, imsize) tensor for the networks.

    The image is converted to RGB first, so greyscale, palette or RGBA files
    (common for PNGs) still give exactly 3 channels. The file is closed as soon
    as the pixels have been read.
    """
    with Image.open(image_name) as img:
        rgb = img.convert('RGB')
    image = Variable(loader(rgb))
    return image.unsqueeze(0)

def save_image(input, path):
    """Write a single image tensor holding 3 * imsize * imsize values to ``path``.

    ``input`` may carry a leading batch dimension of 1. It is copied to the CPU
    and reshaped to (3, imsize, imsize) before conversion. Values are presumably
    expected in [0, 1] (main clamps the pastiche before saving).
    """
    image = input.data.clone().cpu()
    image = image.view(3, imsize, imsize)
    image = unloader(image)
    # scipy.misc.imsave no longer exists (removed in SciPy 1.2). unloader already
    # returns a PIL image, which picks the file format from the extension.
    image.save(path)

## Main

In [4]:
import torch.utils.data
import torchvision.datasets as datasets
from tqdm import tqdm

def main(style_path='starry_night.jpg', content_path='dancing.jpg', num_epochs=31,
         save_every=10, output_template='dancing_s%d.png', seed=None):
    """Optimise a pastiche of ``content_path`` in the style of ``style_path``.

    Each epoch is one ``StyleCNN.train()`` call (one L-BFGS step). Every
    ``save_every`` epochs the pastiche is clamped to [0, 1] in place and written
    to ``output_template % epoch``.

    Args:
        style_path: path of the style image.
        content_path: path of the content image.
        num_epochs: number of L-BFGS steps.
        save_every: save interval, in epochs.
        output_template: %-format string taking the epoch index.
        seed: optional torch RNG seed, for a reproducible noise initialisation.
    """
    dtype = torch.FloatTensor
    if seed is not None:
        torch.manual_seed(seed)

    # Content and style
    style = image_loader(style_path).type(dtype)
    content = image_loader(content_path).type(dtype)

    # Start from Gaussian noise shaped like the content image; only the shape
    # is needed, so the file is not read a second time.
    pastiche = torch.randn(content.size()).type(dtype)

    style_cnn = StyleCNN(style, content, pastiche)
    for i in tqdm(range(num_epochs)):
        pastiche = style_cnn.train()

        if i % save_every == 0:
            path = output_template % (i)
            # Clamps the optimised parameter itself (not a copy) before saving.
            pastiche.data.clamp_(0, 1)
            save_image(pastiche, path)


main()

100%|██████████| 31/31 [32:46<00:00, 65.83s/it]


## Image Transformation Network

In [40]:
class ResidualBlock(nn.Module):
    """Residual block used in the middle of the image transformation network.

    Both 3x3 convolutions are unpadded, so each block shrinks height and width
    by 4. The shortcut is center-cropped by 2 pixels per side to match before
    the addition.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels=128):
        super(ResidualBlock, self).__init__()
        self.body = nn.Sequential(nn.Conv2d(channels, channels, 3, stride=1, padding=0),
                                  nn.BatchNorm2d(channels),
                                  nn.ReLU(),
                                  nn.Conv2d(channels, channels, 3, stride=1, padding=0),
                                  nn.BatchNorm2d(channels))

    def forward(self, x):
        return self.body(x) + x[:, :, 2:-2, 2:-2]


class StyleCNN(object):
    """Feed-forward style transfer: trains an image transformation network so
    that a single forward pass turns a content image into a pastiche of ``style``.

    The pastiche is scored by a frozen, pretrained VGG-19 (``self.loss_network``).
    It uses an MSE on raw activations at ``content_layers`` and an MSE on Gram
    matrices at ``style_layers``. Only the transformation network is trained
    (Adam).

    Args:
        style: style image tensor, e.g. the output of ``image_loader``.
    """

    def __init__(self, style):
        super(StyleCNN, self).__init__()

        self.style = style

        # Convolutions are named conv_1, conv_2, ... in the order they appear in
        # VGG-19's ``features``; the counter advances on every ReLU.
        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        self.content_weight = 1
        self.style_weight = 1000

        # Layout: downsample -> 5 residual blocks -> upsample.
        # - The unpadded residual convolutions remove 20 px at 1/4 resolution
        #   (80 px at full resolution).
        # - The 40 px reflection padding on each side compensates for this, so a
        #   256x256 input gives a 256x256 output.
        # - The final sigmoid keeps pixels in [0, 1], the range ToTensor produces.
        self.transform_network = nn.Sequential(
            nn.ReflectionPad2d(40),
            nn.Conv2d(3, 32, 9, stride=1, padding=4), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            *[ResidualBlock(128) for _ in range(5)],
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 3, 9, stride=1, padding=4),
            nn.Sigmoid(),
        )

        # Frozen feature extractor for the losses. It is never optimised, so it
        # must not accumulate gradients.
        self.loss_network = models.vgg19(pretrained=True)
        for param in self.loss_network.parameters():
            param.requires_grad_(False)
        self.loss_network.eval()

        self.gram = GramMatrix()
        self.loss = nn.MSELoss()
        self.optimizer = optim.Adam(self.transform_network.parameters(), lr=1e-3)

    def train(self, content_batch):
        """Run one Adam step of the transformation network on a batch of content images.

        Args:
            content_batch: batch tensor of 3-channel images, or the
                ``[images, labels]`` pair a DataLoader over an ImageFolder yields
                (labels are ignored).

        Returns:
            (content_loss, style_loss, pastiches), all detached from the graph:
            the two 0-dim loss tensors of this step and the network outputs for
            the batch.
        """
        if isinstance(content_batch, (list, tuple)):
            content_batch = content_batch[0]

        self.optimizer.zero_grad()

        pastiches = self.transform_network(content_batch)
        batch_size = pastiches.size(0)

        pastiche, content, style = pastiches, content_batch, self.style
        content_loss = pastiches.new_zeros(())
        style_loss = pastiches.new_zeros(())

        # Loss layers not reached yet; deeper VGG layers are skipped once it is empty.
        pending = set(self.content_layers) | set(self.style_layers)

        i = 1
        for layer in self.loss_network.features:
            if not pending:
                break
            # Non-inplace ReLU so conv outputs used by the losses stay intact.
            if isinstance(layer, nn.ReLU):
                layer = nn.ReLU(inplace=False)
            pastiche, content, style = layer(pastiche), layer(content), layer(style)

            if isinstance(layer, nn.Conv2d):
                name = 'conv_' + str(i)

                if name in self.content_layers:
                    content_loss = content_loss + self.loss(pastiche * self.content_weight,
                                                            content.detach() * self.content_weight)

                if name in self.style_layers:
                    # GramMatrix mixes batch items together, so compare each
                    # pastiche separately against the single style image and
                    # average over the batch.
                    style_g = self.gram(style).detach()
                    for n in range(batch_size):
                        pastiche_g = self.gram(pastiche[n:n + 1])
                        style_loss = style_loss + self.loss(pastiche_g * self.style_weight,
                                                            style_g * self.style_weight) / batch_size

                pending.discard(name)

            if isinstance(layer, nn.ReLU):
                i += 1

        total_loss = content_loss + style_loss
        total_loss.backward()

        self.optimizer.step()

        return content_loss.detach(), style_loss.detach(), pastiches.detach()

    def save(self, path='transform_network.pth'):
        """Save the transformation network's ``state_dict`` to ``path``."""
        torch.save(self.transform_network.state_dict(), path)

## Define `save_images` for saving a batch of images

In [6]:
def save_images(images, paths):
    """Write each image of a batch to its own file.

    Defined under its own name so that it does not shadow the single-image
    ``save_image``.

    Args:
        images: batch tensor whose items each hold 3 * imsize * imsize values
            (e.g. shape (N, 3, imsize, imsize)), presumably in [0, 1].
        paths: one output path per image. Pairs are formed with zip, so extra
            images or extra paths are ignored (e.g. a short final DataLoader batch).
    """
    batch = images.detach().cpu()
    for image, path in zip(batch, paths):
        unloader(image.view(3, imsize, imsize)).save(path)

## Main for the image transformation network (ITN)

In [None]:
def main(style_path='starry_night.jpg', content_root='./', num_epochs=3, batch_size=4,
         log_every=10, save_every=500, output_dir='outputs'):
    """Train the image transformation network on a folder of content images.

    Args:
        style_path: path of the style image.
        content_root: ImageFolder root containing the content images
            (class labels are ignored).
        num_epochs: passes over the content dataset.
        batch_size: content images per training step.
        log_every: print losses every this many batches within an epoch.
        save_every: every this many batches, save sample outputs and inputs plus
            the network weights to ``output_dir``.
        output_dir: directory for images and weights; created if missing.
    """
    import os

    style = image_loader(style_path).type(torch.FloatTensor)
    style_cnn = StyleCNN(style)

    # Contents
    coco = datasets.ImageFolder(root=content_root, transform=loader)
    content_loader = torch.utils.data.DataLoader(coco, batch_size=batch_size, shuffle=True)
    os.makedirs(output_dir, exist_ok=True)

    for epoch in range(num_epochs):
        for i, (images, _labels) in enumerate(content_loader):
            # Global step counter, continuous across epochs.
            iteration = epoch * len(content_loader) + i
            content_loss, style_loss, pastiches = style_cnn.train(images)

            if i % log_every == 0:
                print("Iteration: %d" % (iteration))
                # .item(): indexing a 0-dim tensor with .data[0] raises an IndexError.
                print("Content loss: %f" % (content_loss.item()))
                print("Style loss: %f" % (style_loss.item()))

            if i % save_every == 0:
                # The final batch may hold fewer than batch_size images.
                n_images = images.size(0)

                path = os.path.join(output_dir, "%d_" % (iteration))
                paths = [path + str(n) + ".png" for n in range(n_images)]
                save_images(pastiches, paths)

                path = os.path.join(output_dir, "content_%d_" % (iteration))
                paths = [path + str(n) + ".png" for n in range(n_images)]
                save_images(images, paths)

                torch.save(style_cnn.transform_network.state_dict(),
                           os.path.join(output_dir, "transform_network.pth"))


main()