# Neural Artistic Style Transfer - Image Transformation Network

In [1]:
import torch
# Pick the first CUDA GPU when one is present; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
# Seeing 'cuda:0' printed confirms that a GPU was found.

cuda:0


## Define Gram matrix layer

In [2]:
import torch.nn as nn

class GramMatrix(nn.Module):
    """Layer that maps a batch of feature maps to normalized Gram matrices."""

    def forward(self, input):
        """Compute one Gram matrix per batch element.

        Args:
            input: 4-D tensor unpacked as (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, channels, channels): the channel-by-channel
            inner products of the flattened feature maps, divided by
            channels * height * width.
        """
        batch, channels, height, width = input.size()
        # Flatten the spatial dimensions so every channel becomes one row vector.
        flat = input.view(batch, channels, height * width)
        # Batched (C x HW) @ (HW x C) product -> (C x C) per batch element.
        gram = torch.bmm(flat, flat.transpose(1, 2))
        return gram / (channels * height * width)

## Define Image Transformer Net (ITN)

In [3]:
from network.image_transformer_net import TransformerNet

## Define Style CNN network with ITN

In [4]:
import torchvision.models as models
import torch.optim as optim
from torch.nn import Parameter

class StyleCNN(object):
    """Trains an image transformer net (ITN) against a fixed VGG-19 loss network.

    The ITN produces a "pastiche" from a content batch. The pastiche, the
    content batch and the style image are pushed through the VGG-19 feature
    layers; MSE between feature maps (content) and between Gram matrices
    (style) at the configured convolution layers forms the training loss.
    Only the ITN parameters are handed to the optimizer.
    """

    def __init__(self):
        super(StyleCNN, self).__init__()

        # Initial configurations
        # Layer names are "conv_<k>", where k counts convolutions in the order
        # they appear in the VGG feature stack (see the counter in `train`).
        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        self.content_weight = 5
        self.style_weight = 2000
        self.gram = GramMatrix()

        # Image Transformer Net
        self.itn = TransformerNet()
        self.itn.to(device)

        # Loss network
        self.loss_network = models.vgg19(pretrained=True)
        # The loss network is a fixed feature extractor: it is never optimized,
        # so freeze it to avoid computing gradients for its weights each step.
        self.loss_network.eval()
        for param in self.loss_network.parameters():
            param.requires_grad_(False)
        self.loss = nn.MSELoss()
        self.optimizer = optim.Adam(self.itn.parameters(), lr=1e-4)

        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.loss.cuda()
            self.gram.cuda()
            # Move the feature layers once here instead of on every training step.
            self.loss_network.features.cuda()

    def train(self, content, style):
        """Run one optimization step of the ITN.

        Args:
            content: batch of content images fed to the ITN.
            style: style image batch; its Gram matrices are expanded to the
                pastiche batch size before comparison.

        Returns:
            (content_loss, style_loss, pastiche_saved): the two accumulated
            loss terms and a detached copy of the clamped ITN output.
        """
        self.optimizer.zero_grad()

        pastiche = self.itn(content)
        # Clamp through .data so the clamp itself is not recorded by autograd.
        pastiche.data.clamp_(0, 255)
        # Detached copy: callers only save/inspect it, so it must not keep the graph alive.
        pastiche_saved = pastiche.detach().clone()

        content_loss = 0
        style_loss = 0

        # Layers that still have to contribute a loss term; once this is empty
        # the remaining (deeper) VGG layers cannot affect the loss.
        pending = set(self.content_layers) | set(self.style_layers)
        # VGG ships in-place ReLUs, which would overwrite the conv outputs used
        # in the losses; a single stateless out-of-place ReLU replaces them all.
        relu = nn.ReLU(inplace=False)

        i = 1
        for layer in self.loss_network.features:
            if isinstance(layer, nn.ReLU):
                layer = relu

            pastiche, content, style = layer(pastiche), layer(content), layer(style)

            if isinstance(layer, nn.Conv2d):
                name = "conv_" + str(i)

                if name in self.content_layers:
                    content_loss += self.loss(pastiche * self.content_weight, content.detach() * self.content_weight)
                if name in self.style_layers:
                    pastiche_g, style_g = self.gram(pastiche), self.gram(style)
                    style_g = style_g.expand_as(pastiche_g)
                    style_loss += self.loss(pastiche_g * self.style_weight, style_g.detach() * self.style_weight)

                pending.discard(name)
                if not pending:
                    # Every configured layer has been handled; skip the rest of VGG.
                    break

            if isinstance(layer, nn.ReLU):
                i += 1

        total_loss = content_loss + style_loss
        total_loss.backward()
        self.optimizer.step()

        return content_loss, style_loss, pastiche_saved

## Utility Functions

In [5]:
import torchvision.transforms as transforms
from torch.autograd import Variable

from PIL import Image
import imageio

# Side length, in pixels, that every image is resized to.
imsize = 256

# Image -> tensor pipeline: resize to imsize x imsize, then convert to a tensor.
loader = transforms.Compose(
    [transforms.Resize((imsize, imsize)), transforms.ToTensor()]
)

# Tensor -> PIL image, used when writing results to disk.
unloader = transforms.ToPILImage()

def load_image(image_name):
    """Load an image file as a single-image batch tensor.

    Args:
        image_name: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The output of ``loader`` for the image with a leading batch
        dimension of size 1 added.
    """
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(image_name) as image:
        # Force three channels so grayscale / RGBA / palette files produce the
        # same layout as the RGB content batches.
        tensor = loader(image.convert("RGB"))
    return tensor.unsqueeze(0)

def save_images(input, paths):
    """Write every image of a batch tensor to its own file.

    Args:
        input: batch tensor whose first dimension indexes images; each
            ``input[n]`` is passed to ``unloader``.
        paths: sequence of output file paths; ``paths[n]`` receives
            ``input[n]``. It may be longer than the batch (extra paths are
            ignored) but must not be shorter.
    """
    # Work on a detached CPU tensor so the caller's tensor is never modified.
    images = input.detach().cpu()
    if images.is_floating_point():
        # ToPILImage scales float tensors by 255 and casts to bytes, so values
        # outside [0, 1] would overflow; saturate them instead. clamp() returns
        # a new tensor, leaving the input untouched.
        images = images.clamp(0, 1)
    for n, image in enumerate(images):
        # Each slice is already (C, H, W); no reshape to a fixed size is needed.
        imageio.imwrite(paths[n], unloader(image))

In [6]:
import torch.utils.data
import torchvision.datasets as datasets

import os

# CUDA Configurations
# Legacy tensor-type object: `.type(dtype)` casts to float and, when CUDA is
# available, also moves the tensor to the GPU.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# Batch Size
N = 4

# Contents
coco = datasets.ImageFolder(root='contents/', transform=loader)
content_loader = torch.utils.data.DataLoader(coco, batch_size=N, shuffle=True)

# Style
style = load_image("styles/starry_night.jpg").type(dtype)

# Declare the network
style_cnn = StyleCNN()

# Create the output directory up front: the first save only happens at the end
# of an epoch, and a missing directory would abort the run at that point.
os.makedirs("outputs", exist_ok=True)

num_epochs = 20
agg_content_loss = 0
agg_style_loss = 0
style_cnn.itn.train()
# Number of batches per epoch; also the denominator of the per-epoch loss averages.
interval = len(content_loader)
for epoch in range(num_epochs):
    for i, content_batch in enumerate(content_loader):
        # ImageFolder yields (images, labels); only the images are used.
        content_batch = content_batch[0].type(dtype)
        content_loss, style_loss, pastiches = style_cnn.train(content_batch, style)

        agg_content_loss += content_loss.item()
        agg_style_loss += style_loss.item()

        # Report and save samples once per epoch, on the last batch.
        if i == interval - 1:
            print("Epoch: %d" % (epoch))
            print("Content loss: %f" % (agg_content_loss/interval))
            print("Style loss: %f" % (agg_style_loss/interval))

            path = "outputs/pastiche_%d_" % (epoch)
            paths = [path + str(n) + ".png" for n in range(N)]
            save_images(pastiches, paths)

            path = "outputs/content_%d_" % (epoch)
            paths = [path + str(n) + ".png" for n in range(N)]
            save_images(content_batch, paths)

            agg_content_loss = 0
            agg_style_loss = 0

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /tmp/xdg-cache/torch/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:11<00:00, 50.8MB/s] 


Epoch: 0
Content loss: 119.060495
Style loss: 178.111163
Epoch: 1
Content loss: 99.535315
Style loss: 39.903877
Epoch: 2
Content loss: 87.794246
Style loss: 31.542174
Epoch: 3
Content loss: 79.560889
Style loss: 27.653427
Epoch: 4
Content loss: 72.944315
Style loss: 24.977127
Epoch: 5
Content loss: 67.680464
Style loss: 23.354009
Epoch: 6
Content loss: 63.401695
Style loss: 21.814648
Epoch: 7
Content loss: 61.791762
Style loss: 22.422623
Epoch: 8
Content loss: 59.005068
Style loss: 20.446124
Epoch: 9
Content loss: 56.704083
Style loss: 19.748710
Epoch: 10
Content loss: 54.935675
Style loss: 19.187141
Epoch: 11
Content loss: 53.630494
Style loss: 19.137463
Epoch: 12
Content loss: 52.595712
Style loss: 18.646203
Epoch: 13
Content loss: 52.116055
Style loss: 18.865801
Epoch: 14
Content loss: 51.528850
Style loss: 18.595173
Epoch: 15
Content loss: 50.952658
Style loss: 18.565839
Epoch: 16
Content loss: 51.316952
Style loss: 19.063968
Epoch: 17
Content loss: 50.851666
Style loss: 18.767341


In [7]:
content = load_image("contents/building.jpg").type(dtype)
# Inference only: skip building an autograd graph for the forward pass.
with torch.no_grad():
    pastiche = style_cnn.itn(content)
pastiche.data.clamp_(0, 255)
# Reuse the shared helper instead of duplicating its tensor-to-file logic.
save_images(pastiche, ["outputs/pastiche_building.png"])

## Save the Model

In [10]:
import os

# Make sure the checkpoint directory exists so the trained weights are not
# lost to a missing-directory error after training has finished.
os.makedirs("models", exist_ok=True)
torch.save(style_cnn.itn.state_dict(), "models/styleCNN_ITN_ckpt")