# Neural Artistic Style Transfer - Image Transformation Network

In [1]:
import torch
# Run on the first CUDA device when one is present; otherwise fall back to CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
# Seeing 'cuda:0' printed below confirms that a GPU was found.
print(device)

cuda:0


## Define Gram matrix layer

In [2]:
import torch.nn as nn

class GramMatrix(nn.Module):
    """Compute normalised Gram matrices for a batch of feature maps."""

    def forward(self, input):
        """Return the per-sample Gram matrix of ``input``.

        Args:
            input: 4-D tensor, unpacked as (N, C, H, W).

        Returns:
            Tensor of shape (N, C, C) holding the channel-by-channel inner
            products of the flattened feature maps, divided by C * H * W.
        """
        N, C, H, W = input.size()  # N = batch size
        # reshape (unlike view) also accepts non-contiguous inputs and only
        # copies when it has to.
        features = input.reshape(N, C, H * W)
        G = torch.bmm(features, features.transpose(1, 2))
        return G.div(C * H * W)

## Define Image Transformer Net (ITN)

In [3]:
class TransformerNet(torch.nn.Module):
    """Image transformation network.

    Three convolutions (the second and third with stride 2) are followed by
    five 128-channel residual blocks, two upsample-by-2 convolutions and a
    final 9x9 convolution back to 3 channels. Every convolution except the
    last is followed by affine instance normalisation and ReLU.
    """

    def __init__(self):
        super().__init__()

        def affine_norm(channels):
            return torch.nn.InstanceNorm2d(channels, affine=True)

        # Downsampling stage
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = affine_norm(32)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = affine_norm(64)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = affine_norm(128)

        # Residual stage
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)

        # Upsampling stage
        self.deconv1 = UpsampleConvLayer(
            128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = affine_norm(64)
        self.deconv2 = UpsampleConvLayer(
            64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = affine_norm(32)
        self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)

        # Shared activation
        self.relu = torch.nn.ReLU()

    def forward(self, X):
        """Run ``X`` through the network and return the 3-channel output."""
        down_stages = (
            (self.conv1, self.in1),
            (self.conv2, self.in2),
            (self.conv3, self.in3),
        )
        residual_stages = (self.res1, self.res2, self.res3, self.res4, self.res5)
        up_stages = (
            (self.deconv1, self.in4),
            (self.deconv2, self.in5),
        )

        y = X
        for conv, norm in down_stages:
            y = self.relu(norm(conv(y)))
        for block in residual_stages:
            y = block(y)
        for conv, norm in up_stages:
            y = self.relu(norm(conv(y)))
        # The last convolution has no normalisation or activation.
        return self.deconv3(y)


class ConvLayer(torch.nn.Module):
    """Reflection-pad the input by ``kernel_size // 2``, then convolve it."""

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        half_kernel = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(half_kernel)
        self.conv2d = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Return the convolution of the reflection-padded ``x``."""
        return self.conv2d(self.reflection_pad(x))


class ResidualBlock(torch.nn.Module):
    """Two-convolution residual block with an identity skip connection.

    Residual learning was introduced in: https://arxiv.org/abs/1512.03385
    Recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
    """

    def __init__(self, channels):
        super().__init__()
        # Both convolutions keep the channel count and use stride 1, so the
        # block's output can be added to its input.
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return ``x`` plus the output of the two conv/norm stages."""
        branch = self.conv1(x)
        branch = self.relu(self.in1(branch))
        branch = self.in2(self.conv2(branch))
        # No activation is applied after the skip connection is added.
        return branch + x


class UpsampleConvLayer(torch.nn.Module):
    """Optional nearest-neighbour upsampling followed by a padded convolution.

    Upsampling first and convolving afterwards gives better results than
    ConvTranspose2d.
    ref: http://distill.pub/2016/deconv-checkerboard/
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        self.upsample = upsample
        # The upsampling module only exists when a scale factor was given.
        if upsample:
            self.upsample_layer = torch.nn.Upsample(
                scale_factor=upsample, mode='nearest')
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Upsample ``x`` if configured, then reflection-pad and convolve it."""
        if self.upsample:
            x = self.upsample_layer(x)
        return self.conv2d(self.reflection_pad(x))

## Define Style CNN network with ITN

In [14]:
import torchvision.models as models
import torch.optim as optim
from torch.nn import Parameter

class StyleCNN(object):
    """Train the image transformation network against a VGG-19 based loss.

    The content loss compares VGG-19 feature maps of the ITN output (the
    "pastiche") with those of the content images; the style loss compares
    Gram matrices of pastiche features with those of the style image.
    Only the ITN's parameters are handed to the optimizer.
    """

    def __init__(self):
        super(StyleCNN, self).__init__()

        # Initial configurations
        # "conv_k" names the k-th convolution of vgg19.features; see train().
        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        self.content_weight = 5
        self.style_weight = 1000
        self.gram = GramMatrix()

        # Image Transformer Net
        self.itn = TransformerNet()
        self.itn.to(device)

        # Loss network
        self.loss_network = models.vgg19(pretrained=True)
        # VGG is a fixed feature extractor here (the optimizer only receives
        # the ITN parameters), so freeze it: autograd then skips weight
        # gradients and does not track the content/style forward passes.
        self.loss_network.eval()
        for param in self.loss_network.parameters():
            param.requires_grad_(False)
        self.loss = nn.MSELoss()
        self.optimizer = optim.Adam(self.itn.parameters(), lr=1e-4)

        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            # Only the feature layers are used by train(); move them once
            # here instead of on every training step.
            self.loss_network.features.cuda()
            self.loss.cuda()
            self.gram.cuda()

    def train(self, content, style):
        """Run one optimisation step of the ITN.

        Args:
            content: batch of content images fed to the ITN.
            style: style image tensor; its Gram matrices are expanded to the
                batch size of the pastiche, so a batch of one works.

        Returns:
            Tuple ``(content_loss, style_loss, pastiche_saved)`` where
            ``pastiche_saved`` is a clone of the clamped ITN output.
        """
        self.optimizer.zero_grad()

        pastiche = self.itn(content)
        # Clamping through .data changes the values without being recorded by
        # autograd.
        # NOTE(review): the file's `loader` ends in ToTensor, which presumably
        # yields values in [0, 1]; an upper bound of 255 looks inconsistent
        # with that -- confirm the intended range.
        pastiche.data.clamp_(0, 255)
        pastiche_saved = pastiche.clone()

        content_loss = 0
        style_loss = 0

        # Layers that still have to contribute to a loss; once this is empty
        # the remaining VGG layers cannot affect the result and are skipped.
        pending = set(self.content_layers) | set(self.style_layers)
        # VGG's own ReLUs are swapped for an out-of-place one so the
        # convolution outputs used in the losses are not overwritten.
        out_of_place_relu = nn.ReLU(inplace=False)

        i = 1  # index of the current convolution; advanced after each ReLU
        for layer in self.loss_network.features:
            if not pending:
                break

            is_relu = isinstance(layer, nn.ReLU)
            if is_relu:
                layer = out_of_place_relu

            pastiche, content, style = layer(pastiche), layer(content), layer(style)

            if isinstance(layer, nn.Conv2d):
                name = "conv_" + str(i)

                if name in self.content_layers:
                    content_loss += self.loss(pastiche * self.content_weight, content.detach() * self.content_weight)
                if name in self.style_layers:
                    pastiche_g, style_g = self.gram(pastiche), self.gram(style)
                    style_g = style_g.expand_as(pastiche_g)
                    style_loss += self.loss(pastiche_g * self.style_weight, style_g.detach() * self.style_weight)
                pending.discard(name)

            if is_relu:
                i += 1

        total_loss = content_loss + style_loss
        total_loss.backward()
        self.optimizer.step()

        return content_loss, style_loss, pastiche_saved

## Utility Functions

In [15]:
import torchvision.transforms as transforms
from torch.autograd import Variable

from PIL import Image
import imageio

# Every image is resized to imsize x imsize pixels before use.
imsize = 256

# PIL image -> tensor: fixed-size resize followed by ToTensor.
_loader_steps = [
    transforms.Resize((imsize, imsize)),
    transforms.ToTensor(),
]
loader = transforms.Compose(_loader_steps)

# Tensor -> PIL image, used when writing results to disk.
unloader = transforms.ToPILImage()

def load_image(image_name):
    """Load an image file and return it as a batch of one tensor.

    Args:
        image_name: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The output of ``loader`` for the RGB-converted image with a leading
        batch dimension of size 1 added.
    """
    # Image.open keeps the file open until the pixel data is read; the
    # context manager closes it, and convert("RGB") reads the data first.
    # Converting also guarantees three channels for grayscale/RGBA files.
    with Image.open(image_name) as source:
        image = source.convert("RGB")
    return loader(image).unsqueeze(0)

def save_images(input, paths):
    """Write every image of a batch to its matching path.

    Args:
        input: tensor whose first dimension indexes the images; each slice
            is converted with ``unloader`` and written with imageio.
        paths: sequence of output paths; ``paths[n]`` receives image ``n``.
            It may hold more entries than the batch has images.

    Raises:
        IndexError: if ``paths`` has fewer entries than the batch has images.
    """
    images = input.detach().cpu()
    if images.is_floating_point():
        # ToPILImage scales float images by 255 and casts to uint8, so values
        # outside [0, 1] would overflow instead of saturating. clamp() is
        # out-of-place and leaves the caller's tensor untouched.
        images = images.clamp(0, 1)
    # Each slice is already an individual image, so no reshape to a fixed
    # size is needed and any image size can be saved.
    for n, image in enumerate(images):
        imageio.imwrite(paths[n], unloader(image))

In [18]:
import torch.utils.data
import torchvision.datasets as datasets

import os

# CUDA Configurations
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# Batch Size
N = 4

# Contents
coco = datasets.ImageFolder(root='contents/', transform=loader)
content_loader = torch.utils.data.DataLoader(coco, batch_size=N, shuffle=True)

# Style
style = load_image("styles/starry_night.jpg").type(dtype)

# Declare the network
style_cnn = StyleCNN()

num_epochs = 100
agg_content_loss = 0
agg_style_loss = 0
style_cnn.itn.train()
# Losses are reported once per epoch, so average over the number of batches
# in an epoch instead of a hard-coded count tied to one dataset size.
interval = len(content_loader)
# The sample images below are written here; create the folder up front.
os.makedirs("outputs", exist_ok=True)
for epoch in range(num_epochs):
    for i, content_batch in enumerate(content_loader):
        # ImageFolder yields (images, labels); only the images are used.
        content_batch = content_batch[0].type(dtype)
        content_loss, style_loss, pastiches = style_cnn.train(content_batch, style)

        agg_content_loss += content_loss.item()
        agg_style_loss += style_loss.item()

        # Report and save samples on the last batch of every epoch.
        if i == len(content_loader)-1:
            print("Epoch and Iter: %d, %d"% (epoch, i))
            print("Content loss: %f" % (agg_content_loss/interval))
            print("Style loss: %f" % (agg_style_loss/interval))

            path = "outputs/pastiche_%d_" % (epoch)
            paths = [path + str(n) + ".png" for n in range(N)]
            save_images(pastiches, paths)

            path = "outputs/content_%d_" % (epoch)
            paths = [path + str(n) + ".png" for n in range(N)]
            save_images(content_batch, paths)

            # Start the next epoch's running averages from zero.
            agg_content_loss = 0
            agg_style_loss = 0
            style_cnn.itn.train()

Epoch and Iter: 0, 84
Content loss: 122.517135
Style loss: 281.907546
Epoch and Iter: 1, 84
Content loss: 103.274354
Style loss: 87.909144
Epoch and Iter: 2, 84
Content loss: 97.030778
Style loss: 63.381110
Epoch and Iter: 3, 84
Content loss: 92.693492
Style loss: 49.649481
Epoch and Iter: 4, 84
Content loss: 89.359270
Style loss: 41.967311
Epoch and Iter: 5, 84
Content loss: 86.607182
Style loss: 36.540488
Epoch and Iter: 6, 84
Content loss: 84.176983
Style loss: 32.160616
Epoch and Iter: 7, 84
Content loss: 81.729950
Style loss: 30.049597
Epoch and Iter: 8, 84
Content loss: 79.787665
Style loss: 27.948957
Epoch and Iter: 9, 84
Content loss: 77.734398
Style loss: 26.588030
Epoch and Iter: 10, 84
Content loss: 76.185359
Style loss: 25.075062
Epoch and Iter: 11, 84
Content loss: 74.452669
Style loss: 24.084234
Epoch and Iter: 12, 84
Content loss: 73.065024
Style loss: 22.849250
Epoch and Iter: 13, 84
Content loss: 71.639664
Style loss: 22.003104
Epoch and Iter: 14, 84
Content loss: 70.3

In [19]:
content = load_image("contents/building.jpg").type(dtype)
# Inference only: no_grad avoids building an autograd graph for this pass.
with torch.no_grad():
    pastiche = style_cnn.itn(content)
# ToPILImage scales float images by 255 and casts to uint8, so the values
# must lie in [0, 1]; anything larger would overflow rather than saturate.
image = pastiche.clamp(0, 1).cpu()
# Drop the batch dimension of size 1 without assuming a fixed image size.
image = unloader(image.squeeze(0))
imageio.imwrite("outputs/pastiche_building.png", image)