# Neural Artistic Style Transfer - Image Transformation Network

In [1]:
import torch

# Compute on the first CUDA GPU when one is present; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
# Seeing 'cuda:0' printed confirms that a GPU was detected.

cuda:0


## Define Gram matrix layer

In [2]:
import torch.nn as nn

class GramMatrix(nn.Module):
    """Normalized Gram matrix of a batch of feature maps.

    ``forward`` expects a 4-D tensor of shape (N, C, H, W). Each sample's
    feature maps are flattened to (C, H*W), multiplied by their own transpose
    and divided by C*H*W, giving one (C, C) matrix per sample.

    Returns:
        Tensor of shape (N, C, C).
    """

    def forward(self, input):
        N, C, H, W = input.size()
        # reshape (rather than view) also accepts non-contiguous inputs such as
        # permuted or sliced tensors; for contiguous inputs it is still a view.
        features = input.reshape(N, C, H * W)
        # Batched (C x HW) @ (HW x C) product: one Gram matrix per sample.
        G = torch.bmm(features, features.transpose(1, 2))
        # Normalize by the number of elements in one sample's feature maps.
        return G.div(C * H * W)

## Define Image Transformer Net (ITN)

In [3]:
from network.image_transformer_net import TransformerNet

## Define Style CNN network with ITN

In [49]:
import torchvision.models as models
import torch.optim as optim
from torch.nn import Parameter

class StyleCNN(object):
    """Trains an image transformer network (ITN) against a VGG-19 based loss.

    The ITN (``self.itn``) produces a pastiche from a content batch. The
    pastiche, the content batch and the style image are then pushed through
    the convolutional part of a pretrained VGG-19; MSE losses on selected
    convolution outputs (content) and on their Gram matrices (style) are
    back-propagated into the ITN only.

    Relies on names defined elsewhere in the notebook: ``device``,
    ``GramMatrix`` and ``TransformerNet``.
    """

    def __init__(self):
        super(StyleCNN, self).__init__()

        # Initial configurations
        # Layer names are "conv_<k>", where k counts the convolutions of
        # VGG-19's feature extractor in order (see the counter in train()).
        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        self.content_weight = 2
        self.style_weight = 1000
        self.gram = GramMatrix()

        # Image Transformer Net
        self.itn = TransformerNet()
        self.itn.to(device)

        # Loss network
        self.loss_network = models.vgg19(pretrained=True)
        # The loss network is a fixed feature extractor: freeze it so that
        # backward() does not compute and accumulate gradients on its weights
        # (the optimizer only ever updates, and zeroes, the ITN parameters).
        for param in self.loss_network.parameters():
            param.requires_grad_(False)
        self.loss = nn.MSELoss()
        self.optimizer = optim.Adam(self.itn.parameters(), lr=1e-4)

        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.loss.cuda()
            self.gram.cuda()
            # Move the feature layers once here instead of on every train() call.
            self.loss_network.features.cuda()

    def train(self, content, style):
        """Run one optimization step of the ITN.

        Args:
            content: batch of content images fed to the ITN and to the loss
                network.
            style: style image batch; its Gram matrices are expanded to the
                pastiche batch size before comparison.

        Returns:
            Tuple ``(content_loss, style_loss, pastiche_saved)``: the two
            accumulated loss tensors and a clone of the ITN output taken
            before it is passed through the loss network.
        """
        self.optimizer.zero_grad()

        pastiche = self.itn(content)
        # NOTE(review): this clamps to [0, 255] while the training and
        # inference cells clamp the same output to [0, 1] -- confirm which
        # range is intended. Left unchanged to keep training behaviour.
        pastiche.data.clamp_(0, 255)
        pastiche_saved = pastiche.clone()

        content_loss = 0
        style_loss = 0

        # Layers that still have to contribute to the loss; once this is empty
        # the remaining (deeper) VGG layers cannot change the result.
        pending = set(self.content_layers) | set(self.style_layers)
        # VGG's ReLUs are in-place and would overwrite the convolution outputs
        # used in the losses, so a single out-of-place ReLU is used instead.
        relu = nn.ReLU(inplace=False)

        i = 1
        for layer in self.loss_network.features:
            if isinstance(layer, nn.ReLU):
                layer = relu

            pastiche = layer(pastiche)
            # Content and style activations are fixed targets: no autograd graph.
            with torch.no_grad():
                content = layer(content)
                style = layer(style)

            if isinstance(layer, nn.Conv2d):
                name = "conv_" + str(i)

                if name in self.content_layers:
                    content_loss += self.loss(pastiche * self.content_weight, content * self.content_weight)
                if name in self.style_layers:
                    pastiche_g, style_g = self.gram(pastiche), self.gram(style)
                    style_g = style_g.expand_as(pastiche_g)
                    style_loss += self.loss(pastiche_g * self.style_weight, style_g * self.style_weight)

                pending.discard(name)
                if not pending:
                    break

            # The convolution counter advances after each ReLU.
            if isinstance(layer, nn.ReLU):
                i += 1

        total_loss = content_loss + style_loss
        total_loss.backward()
        self.optimizer.step()

        return content_loss, style_loss, pastiche_saved

## Utility Functions

In [50]:
import torchvision.transforms as transforms
from torch.autograd import Variable

from PIL import Image
import imageio

# Side length, in pixels, that every image is resized to (square).
imsize = 256

# Image -> tensor pipeline: resize to imsize x imsize, then convert to a tensor.
loader = transforms.Compose(
    [transforms.Resize(size=(imsize, imsize)), transforms.ToTensor()]
)

# Tensor -> PIL image, used when writing results to disk.
unloader = transforms.ToPILImage()

def load_image(image_name):
    """Load one image file as a batch of size 1.

    The file is opened with PIL, converted to RGB, passed through the
    notebook's ``loader`` transform and given a leading batch dimension.

    Args:
        image_name: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The transformed image tensor with an extra dimension inserted at
        position 0.
    """
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(image_name) as img:
        # Force three channels so grayscale / RGBA / palette files produce the
        # same channel count as ordinary RGB images.
        image = loader(img.convert("RGB"))
    return image.unsqueeze(0)

def save_images(input, paths):
    """Write every image of a batch to the matching entry of ``paths``.

    Args:
        input: tensor whose first dimension indexes the images; each
            ``input[n]`` is handed to ``unloader`` as-is, so images of any
            resolution are accepted (no dependency on the global ``imsize``).
        paths: sequence of output file names; ``paths[n]`` receives image n.
            It may be longer than the batch, extra entries are ignored.

    Raises:
        IndexError: if ``paths`` has fewer entries than the batch has images.
    """
    # Work on a detached CPU copy so the caller's tensor is left untouched.
    images = input.detach().clone().cpu()
    for n, image in enumerate(images):
        imageio.imwrite(paths[n], unloader(image))

In [51]:
import torch.utils.data
import torchvision.datasets as datasets

import os  # used only in this cell, to make sure the output directory exists

# CUDA Configurations
# Tensor type used to cast inputs; the inference cells further down reuse it.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# Batch Size
N = 4

# Contents
coco = datasets.ImageFolder(root='contents/', transform=loader)
content_loader = torch.utils.data.DataLoader(coco, batch_size=N, shuffle=True)

# Style
style = load_image("styles/mosaic.jpg").type(dtype)

# Declare the network
style_cnn = StyleCNN()

num_epochs = 20
agg_content_loss = 0
agg_style_loss = 0
style_cnn.itn.train()
# Number of batches per epoch; losses are averaged over this many steps.
interval = len(content_loader)

# Images are first written only after a whole epoch of training, so make sure
# the target directory exists before starting rather than failing at that point.
os.makedirs("outputs", exist_ok=True)

for epoch in range(num_epochs):
    for i, content_batch in enumerate(content_loader):
        # ImageFolder yields (images, labels); only the images are needed.
        content_batch = content_batch[0].type(dtype)
        content_loss, style_loss, pastiches = style_cnn.train(content_batch, style)

        agg_content_loss += content_loss.item()
        agg_style_loss += style_loss.item()

        # Report and save samples once per epoch, on its last batch.
        if i == interval - 1:
            print("Epoch: %d" % (epoch))
            print("Content loss: %f" % (agg_content_loss/interval))
            print("Style loss: %f" % (agg_style_loss/interval))

            # Only the saved batch needs to be brought into the [0, 1] image range.
            pastiches.data.clamp_(0, 1)

            path = "outputs/pastiche_%d_" % (epoch)
            # The last batch of an epoch can hold fewer than N images.
            paths = [path + str(n) + ".png" for n in range(pastiches.size(0))]
            save_images(pastiches, paths)

            path = "outputs/content_%d_" % (epoch)
            paths = [path + str(n) + ".png" for n in range(content_batch.size(0))]
            save_images(content_batch, paths)

            agg_content_loss = 0
            agg_style_loss = 0

Epoch: 0
Content loss: 34.042993
Style loss: 103.598893
Epoch: 1
Content loss: 31.872838
Style loss: 20.241852
Epoch: 2
Content loss: 29.599655
Style loss: 14.596667
Epoch: 3
Content loss: 27.613654
Style loss: 12.772306
Epoch: 4
Content loss: 25.887567
Style loss: 11.440208
Epoch: 5
Content loss: 24.499462
Style loss: 10.549997
Epoch: 6
Content loss: 23.344501
Style loss: 9.886719
Epoch: 7
Content loss: 22.394160
Style loss: 9.358080
Epoch: 8
Content loss: 21.661714
Style loss: 8.901360
Epoch: 9
Content loss: 21.188854
Style loss: 8.751525
Epoch: 10
Content loss: 20.603273
Style loss: 8.464369
Epoch: 11
Content loss: 20.681054
Style loss: 9.442226
Epoch: 12
Content loss: 21.740107
Style loss: 11.763464
Epoch: 13
Content loss: 20.673017
Style loss: 9.515477
Epoch: 14
Content loss: 22.327216
Style loss: 12.884334
Epoch: 15
Content loss: 22.270103
Style loss: 12.194686
Epoch: 16
Content loss: 20.633679
Style loss: 8.653313
Epoch: 17
Content loss: 22.035872
Style loss: 11.901758
Epoch: 18

In [53]:
# Stylize a single content image with the network trained above.
content = load_image("contents/building.jpg").type(dtype)

# Inference: switch the ITN to eval mode and skip autograd bookkeeping, then
# restore whatever mode it was in so later training is unaffected.
was_training = style_cnn.itn.training
style_cnn.itn.eval()
with torch.no_grad():
    pastiche = style_cnn.itn(content)
style_cnn.itn.train(was_training)

pastiche.clamp_(0, 1)
# Drop the batch dimension instead of assuming an imsize x imsize output.
image = pastiche.squeeze(0).cpu()
image = unloader(image)
imageio.imwrite("outputs/pastiche_building.png", image)

In [56]:
from network.image_transformer_net import TransformerNet

# Requires models/mosaic.ckpt to exist; it is written by the "Save the Model"
# cell at the end of this notebook, so run that cell first on a fresh kernel.
model = TransformerNet().to(device)
# map_location lets a checkpoint saved from a GPU be restored on whatever
# device is in use now (including CPU-only machines).
model.load_state_dict(torch.load("models/mosaic.ckpt", map_location=device))
model.eval()

content = load_image("contents/dog.jpg").type(dtype)
with torch.no_grad():
    pastiche = model(content)
pastiche.clamp_(0, 1)
# Drop the batch dimension instead of assuming an imsize x imsize output.
image = pastiche.squeeze(0).cpu()
image = unloader(image)
imageio.imwrite("outputs/pastiche_dog.png", image)

## Save the Model

In [55]:
import os  # used only in this cell, to create the checkpoint directory

# Make sure the target directory exists before writing the ITN weights.
os.makedirs("models", exist_ok=True)
torch.save(style_cnn.itn.state_dict(), "models/mosaic.ckpt")