In [1]:
import torch
# Prefer the first CUDA device when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Prints 'cuda:0' when a GPU was selected, 'cpu' otherwise.
print(device)

cuda:0


## Define Gram matrix layer

In [2]:
import torch.nn as nn

class GramMatrix(nn.Module):
    """Normalized Gram matrix of a 4-D feature tensor.

    An input of size (a, b, c, d) is flattened to an (a*b, c*d) matrix F and
    the layer returns F @ F.T divided by the total element count a*b*c*d.
    In this notebook the input is a conv feature map, so the axes are
    presumably (batch, channels, height, width).
    """

    def forward(self, input):
        """Compute the normalized Gram matrix.

        Args:
            input: 4-D tensor; unpacking its size into four values fails
                for any other rank.

        Returns:
            Tensor of shape (a*b, a*b).
        """
        a, b, c, d = input.size()
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced feature maps, are accepted; for contiguous inputs it is
        # still a zero-copy view.
        features = input.reshape(a * b, c * d)
        G = torch.mm(features, features.t())
        return G.div(a * b * c * d)

## Define style CNN network

In [33]:
import torchvision.models as models
import torch.optim as optim

class StyleCNN(object):
    """Style-transfer driver that optimizes a pastiche image with L-BFGS.

    The pastiche, content and style images are pushed through the feature
    layers of a pre-trained VGG-19. Content loss compares raw conv outputs,
    style loss compares their Gram matrices, and only the pastiche pixels
    are optimized.
    """

    # Model parameters
    def __init__(self, style, content, pastiche):
        """Set up the network, losses and optimizer.

        Args:
            style: Style image tensor fed to the VGG feature layers.
            content: Content image tensor fed to the VGG feature layers.
            pastiche: Initial image; its ``.data`` becomes the optimized
                parameter.

        The three tensors must live on the same device as the network,
        which is moved to CUDA whenever CUDA is available.
        """
        super(StyleCNN, self).__init__()

        self.style = style
        self.content = content
        # The pastiche is the only tensor handed to the optimizer.
        self.pastiche = nn.Parameter(pastiche.data)

        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        # Both operands are scaled by the weight *before* the MSE, so the
        # effective multiplier on each loss term is the weight squared.
        self.content_weight = 1
        self.style_weight = 1000

        self.cnn = models.vgg19(pretrained=True)  # pre-trained CNN
        # The network is a fixed feature extractor: freeze its weights so
        # backward() does not compute (and accumulate) gradients for them.
        self.cnn.eval()
        for param in self.cnn.parameters():
            param.requires_grad_(False)

        self.gram = GramMatrix()  # Gram matrix for computing style loss
        self.loss = nn.MSELoss()  # Loss function
        self.optimizer = optim.LBFGS([self.pastiche])

        # Cuda device
        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.cnn.cuda()
            self.gram.cuda()

        # Built once instead of on every step. In-place ReLUs are swapped
        # for out-of-place ones so the conv outputs used in the losses are
        # not overwritten before backward().
        self.feature_layers = [
            nn.ReLU(inplace=False) if isinstance(layer, nn.ReLU) else layer
            for layer in self.cnn.features
        ]

    # Training step
    def step(self):
        """Run one forward pass and return ``(content_loss, style_loss)``.

        Gradients are cleared first. The losses are accumulated at the conv
        layers named in ``content_layers`` / ``style_layers``, where
        ``conv_<i>`` refers to the conv layer reached after ``i - 1`` ReLU
        layers have been passed.
        """
        self.optimizer.zero_grad()

        # Clamp a copy so the forward pass sees pixel values in [0, 1]
        # while the parameter itself is left untouched.
        pastiche = self.pastiche.clone()
        pastiche.data.clamp_(0, 1)
        content = self.content
        style = self.style

        # Zero tensors (not Python ints) so the results always support
        # .item(), even if a layer list matches nothing.
        content_loss = self.pastiche.new_zeros(())
        style_loss = self.pastiche.new_zeros(())

        # Layers still to be visited; once empty the remaining (deeper)
        # layers cannot contribute to either loss and are skipped.
        pending = set(self.content_layers) | set(self.style_layers)

        i = 1
        for layer in self.feature_layers:
            pastiche = layer(pastiche)
            # Targets are constants: no graph is needed for them.
            with torch.no_grad():
                content = layer(content)
                style = layer(style)

            if isinstance(layer, nn.Conv2d):
                name = "conv_" + str(i)

                # Increment content loss at certain conv layers
                if name in self.content_layers:
                    content_loss = content_loss + self.loss(
                        pastiche * self.content_weight,
                        content.detach() * self.content_weight)

                # Increment style loss at certain conv layers
                if name in self.style_layers:
                    pastiche_g, style_g = self.gram(pastiche), self.gram(style)
                    style_loss = style_loss + self.loss(
                        pastiche_g * self.style_weight,
                        style_g.detach() * self.style_weight)

                pending.discard(name)
                if not pending:
                    break

            # Increment conv layer counter
            if isinstance(layer, nn.ReLU):
                i += 1

        return content_loss, style_loss

    # Closure for LBFGS
    def closure(self):
        """Evaluate the total loss and back-propagate it.

        Stores the two partial losses on ``self`` so ``train`` can report
        them, and returns their sum as required by ``optimizer.step``.
        """
        self.content_loss, self.style_loss = self.step()
        total_loss = self.content_loss + self.style_loss
        total_loss.backward()

        return total_loss

    # Training procedure
    def train(self):
        """Perform one ``optimizer.step`` call.

        The optimizer may invoke ``closure`` several times per call; the
        returned losses are those of the last evaluation.

        Returns:
            Tuple ``(content_loss, style_loss, pastiche)`` where
            ``pastiche`` is the optimized parameter itself (not a copy).
        """
        self.optimizer.step(self.closure)
        return self.content_loss, self.style_loss, self.pastiche

## Utility Functions

In [34]:
import torchvision.transforms as transforms
from torch.autograd import Variable

from PIL import Image
import imageio

# Side length, in pixels, of the square every loaded image is resized to.
imsize = 256

# PIL image -> tensor pipeline: resize to imsize x imsize, then convert.
loader = transforms.Compose(
    [transforms.Resize((imsize, imsize)), transforms.ToTensor()]
)

# Inverse conversion (tensor -> PIL image), used when saving results.
unloader = transforms.ToPILImage()

def load_image(image_name):
    """Load an image file as a batched tensor.

    The file is opened with PIL, converted to RGB, passed through the
    module-level ``loader`` pipeline and given a leading batch dimension
    of size 1.

    Args:
        image_name: Path (or file object) accepted by ``Image.open``.

    Returns:
        The tensor produced by ``loader`` with an extra leading dimension.
    """
    # The context manager releases the file handle; convert("RGB") reads
    # the pixel data before the file closes and guarantees three channels
    # for grayscale, palette or RGBA sources.
    with Image.open(image_name) as source:
        rgb = source.convert("RGB")
    return loader(rgb).unsqueeze(0)
  
def save_image(input, path):
    """Write an image tensor to ``path``.

    Args:
        input: Image tensor, either 3-D or 4-D with a leading batch
            dimension of size 1. It is copied to the CPU, so the caller's
            tensor is not modified.
        path: Destination file; the parent directory is created if it
            does not exist.
    """
    import os

    image = input.detach().clone().cpu()
    # Drop the batch dimension instead of reshaping to a fixed global
    # size, so images of any resolution can be saved.
    if image.dim() == 4:
        image = image.squeeze(0)
    image = unloader(image)

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    imageio.imwrite(path, image)

In [35]:
import torch.utils.data
import torchvision.datasets as datasets

# CUDA Configurations
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# Seed the random pastiche initialisation so a Restart & Run All starts
# from the same noise image.
SEED = 0
torch.manual_seed(SEED)

# Content and style
style = load_image("styles/starry_night.jpg").type(dtype)
content = load_image("contents/building.jpg").type(dtype)

# Start the pastiche from Gaussian noise with the content image's shape
# (no need to read the content file a second time just for its size).
pastiche = torch.randn(content.size()).type(dtype)

# Declare the network
style_cnn = StyleCNN(style, content, pastiche)

num_epochs = 31
for i in range(num_epochs):
    content_loss, style_loss, pastiche = style_cnn.train()

    # Report progress and save a snapshot every 10th iteration.
    if i % 10 == 0:
        print("Iteration: %d" % (i))
        print("Content loss: %f" % (content_loss.item()))
        print("Style loss: %f" % (style_loss.item()))

        path = "outputs/%d.png" % (i)
        # NOTE(review): `pastiche` is the optimized parameter itself, so this
        # in-place clamp also changes the optimizer's current point every
        # 10th iteration - confirm that projection is intended.
        pastiche.data.clamp_(0, 1)
        save_image(pastiche, path)

Iteration: 0
Content loss: 6.625802
Style loss: 48.774509
Iteration: 10
Content loss: 1.582543
Style loss: 0.273035
Iteration: 20
Content loss: 1.339626
Style loss: 0.185075
Iteration: 30
Content loss: 1.299487
Style loss: 0.175296
