In [1]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from PIL import Image
from torch import optim
from torchvision import transforms
from torchvision.utils import save_image

warnings.simplefilter('ignore')

In [2]:
def gram(tensor):
    """Return the Gram matrix ``F @ F^T`` of a 2-D feature matrix ``F``.

    For an unrolled activation of shape (C, H*W) the result is the (C, C)
    matrix of inner products between the channel responses.
    """
    transposed = tensor.transpose(0, 1)
    return tensor.mm(transposed)

In [3]:
def gram_loss(noise_img_gram, style_img_gram, N, M):
    """Per-layer style loss between two Gram matrices.

    Computes ``sum((G - A)^2) / (2 * N * M)^2``, i.e. the squared Frobenius
    distance scaled by ``1 / (4 N^2 M^2)``.

    Args:
        noise_img_gram: Gram matrix of the generated image for this layer.
        style_img_gram: Gram matrix of the style image for the same layer.
        N: number of feature maps (rows of the unrolled activation).
        M: number of positions per feature map (columns of the unrolled activation).
    """
    squared_error = (noise_img_gram - style_img_gram).pow(2).sum()
    normaliser = float(2 * N * M) ** 2
    return squared_error / normaliser

In [4]:
def total_variation_loss(image):
    """Anisotropic total-variation penalty for a (B, C, H, W) image batch.

    Returns the mean absolute difference between horizontally adjacent pixels
    plus the mean absolute difference between vertically adjacent pixels,
    which discourages high-frequency noise in the generated image.
    """
    horizontal_diff = image[:, :, :, 1:] - image[:, :, :, :-1]
    vertical_diff = image[:, :, 1:, :] - image[:, :, :-1, :]
    return horizontal_diff.abs().mean() + vertical_diff.abs().mean()

In [5]:
# Read the images. Converting to RGB guarantees three channels (grayscale,
# CMYK or RGBA files would otherwise break the 3-channel Normalize), and
# ``convert`` loads the pixels, so the file handle can be closed right away.
def load_rgb(path):
    """Open the image at ``path``, return it as an in-memory RGB PIL image."""
    with Image.open(path) as img:
        return img.convert('RGB')


cont_img = load_rgb('./content_img_1.jpg')
style_img = load_rgb('./style_img_3.jpg')

In [6]:
# Standard ImageNet per-channel statistics used for normalisation
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Resize to a fixed 1024x1024 square (aspect ratio is not preserved),
# convert to a float tensor, then normalise each channel.
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

In [7]:
# Transform each PIL image, add a leading batch dimension of size 1,
# and move the result to the GPU.
content_image = transform(cont_img)[None].cuda()
style_image = transform(style_img)[None].cuda()

In [8]:
# define the VGG
class VGG(nn.Module):
    """Wrapper around torchvision's pretrained VGG-19 convolutional features.

    Exposes the intermediate activations used for neural style transfer.
    Layer indices refer to positions in ``models.vgg19().features`` (see the
    printed module structure for the layout).
    """

    # ``features[:k]`` returns the output of layer ``k - 1``. End 22 stops at
    # index 21 (conv4_2 in torchvision's VGG-19 layout), i.e. the content
    # representation is taken *before* that layer's ReLU.
    CONTENT_LAYER_END = 22

    # Slice ends for the style representation: outputs of indices 3, 6, 11,
    # 20 and 29, the ReLUs after conv1_2, conv2_1, conv3_1, conv4_1 and
    # conv5_1 in torchvision's VGG-19 layout. Must be in ascending order.
    # NOTE(review): Gatys et al. use relu1_1 (end 2) for the first layer; end
    # 4 (relu1_2) is what this notebook has always used, so it is kept.
    STYLE_LAYER_ENDS = (4, 7, 12, 21, 30)

    def __init__(self):
        super(VGG, self).__init__()

        # load the pretrained vgg model's convolutional feature stack
        self.vgg = models.vgg19(pretrained=True).features

    def get_content_activations(self, x):
        """Return the content-layer activations of ``x``."""
        return self.vgg[:self.CONTENT_LAYER_END](x)

    def get_style_activations(self, x):
        """Return a list with the activations of ``x`` at every style layer.

        The network is run once and the outputs are recorded as it passes
        each end in ``STYLE_LAYER_ENDS``. Slicing the prefix separately for
        each layer would recompute the shared early layers five times.
        """
        last = max(self.STYLE_LAYER_ENDS)
        activations = []
        for idx in range(last):
            x = self.vgg[idx](x)
            if idx + 1 in self.STYLE_LAYER_ENDS:
                # An in-place follower (e.g. ReLU(inplace=True)) would
                # overwrite the recorded tensor when it runs, so copy it.
                follower_in_place = (idx + 1 < last
                                     and getattr(self.vgg[idx + 1], 'inplace', False))
                activations.append(x.clone() if follower_in_place else x)
        return activations

    def forward(self, x):
        """Run the full feature stack."""
        return self.vgg(x)

In [9]:
# Build the network, move it to the GPU and switch it to evaluation mode.
vgg = VGG()
vgg = vgg.cuda()
vgg = vgg.eval()

In [10]:
# Display the module structure of the wrapped VGG-19 feature stack
# (the pooling layers are still MaxPool2d at this point).
vgg

VGG(
  (vgg): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    

In [None]:
# Swap every MaxPool2d in the feature stack for a 2x2, stride-2 AvgPool2d.
# Positions are collected first so the container is not modified mid-iteration.
pool_positions = [int(name) for name, layer in vgg.vgg.named_children()
                  if isinstance(layer, nn.MaxPool2d)]
for pos in pool_positions:
    vgg.vgg[pos] = nn.AvgPool2d(kernel_size=2, stride=2)

In [None]:
# Freeze the pretrained weights; only the generated image gets optimised.
for weight in vgg.parameters():
    weight.requires_grad_(False)

In [None]:
# Content target: activations of the content image, computed without
# recording an autograd graph so they act as a fixed target.
with torch.no_grad():
    content_activations = vgg.get_content_activations(content_image)

In [None]:
# Unroll the content activations from (1, C, H, W) into a (C, H*W) matrix.
# The channel count is read from the tensor rather than hard-coded to 512,
# so this stays correct if the content layer changes.
assert content_activations.shape[0] == 1, 'expected a single content image'
content_F = content_activations.view(content_activations.shape[1], -1)

In [None]:
# Style features of the style image, computed without building an autograd graph.
with torch.no_grad():
    style_activations = vgg.get_style_activations(style_image)

In [None]:
# shape of the first style activation before unrolling
style_activations[0].size()

In [None]:
# Unroll each style activation from (1, C, H, W) into a detached (C, H*W)
# matrix. Tensors that are already 2-D are left untouched, so re-running this
# cell is harmless. The previous in-place version reshaped (C, H*W) into
# (H*W, C) on a second run, producing huge, wrong Gram matrices.
style_activations = [
    act.squeeze(0).view(act.shape[1], -1).detach() if act.dim() == 4 else act
    for act in style_activations
]

# Gram matrices (C x C) of the style image: the fixed targets for the style loss
gram_matrices = [gram(act) for act in style_activations]

In [None]:
# shape of the first style activation after unrolling to (C, H*W)
style_activations[0].size()

In [None]:
# shape of the first Gram matrix, (C, C)
gram_matrices[0].size()

### Training

In [None]:
# Seed the RNGs so the starting noise (and therefore the result) is reproducible.
SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Gaussian noise image that is optimised directly. Its shape and device are
# taken from the content image so they always match the content target.
noise = torch.randn_like(content_image, requires_grad=True)

In [None]:
# Optimise the pixels of `noise` directly. Its content activations are pulled
# towards the content image and its Gram matrices towards the style image,
# plus a total-variation term for smoothness.

# Loss weights (note: 10e-4 == 1e-3, 10e6 == 1e7, 10e3 == 1e4).
# Set CONTENT_WEIGHT = 0 to reproduce the style alone.
CONTENT_WEIGHT = 10e-4
STYLE_WEIGHT = 10e6
TV_WEIGHT = 10e3

NUM_ITERATIONS = 20000
LOG_EVERY = 1000
SAVE_EVERY = 100
OUTPUT_DIR = './generated'

# `noise` lives in the Normalize()d space produced by `transform`; these
# ImageNet statistics undo that before saving so the PNGs have correct colours.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=noise.device).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=noise.device).view(1, 3, 1, 1)

# save_image does not create missing directories
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Adam over the image pixels only (the VGG weights are frozen)
adam = optim.Adam(params=[noise], lr=0.01, betas=(0.9, 0.999))

for iteration in range(NUM_ITERATIONS):
    adam.zero_grad()

    # content loss: half the summed squared error of the unrolled activations
    noise_content_activations = vgg.get_content_activations(noise)
    noise_content_F = noise_content_activations.view(noise_content_activations.shape[1], -1)
    content_loss = 0.5 * torch.sum(torch.pow(noise_content_F - content_F, 2))

    # style loss: equally weighted Gram-matrix losses over all style layers
    noise_style_activations = [act.squeeze(0).view(act.shape[1], -1)
                               for act in vgg.get_style_activations(noise)]
    noise_gram_matrices = [gram(act) for act in noise_style_activations]
    num_style_layers = len(gram_matrices)
    style_loss = 0
    for act, noise_gram, target_gram in zip(noise_style_activations, noise_gram_matrices, gram_matrices):
        N, M = act.shape
        style_loss += gram_loss(noise_gram, target_gram, N, M) / num_style_layers

    # smoothness regulariser
    variation_loss = total_variation_loss(noise)

    total_loss = CONTENT_WEIGHT * content_loss + STYLE_WEIGHT * style_loss + TV_WEIGHT * variation_loss

    if iteration % LOG_EVERY == 0:
        print("Iteration: {}, Content Loss: {:.3f}, Style Loss: {:.3f}, Var Loss: {:.3f}".format(
            iteration,
            CONTENT_WEIGHT * content_loss.item(),
            STYLE_WEIGHT * style_loss.item(),
            TV_WEIGHT * variation_loss.item()))

    if iteration % SAVE_EVERY == 0:
        # denormalise back to [0, 1] RGB before writing the snapshot
        with torch.no_grad():
            snapshot = (noise * IMAGENET_STD + IMAGENET_MEAN).clamp(0, 1)
        # path passed positionally: the keyword is `filename` in old
        # torchvision releases and `fp` in newer ones
        save_image(snapshot.cpu(), os.path.join(OUTPUT_DIR, 'iter_{}.png'.format(iteration)))

    total_loss.backward()
    adam.step()

In [None]:
# Show the final result. `noise` is in the Normalize()d space used by
# `transform`, so undo the ImageNet normalisation and clamp to [0, 1] first;
# imshow would otherwise clip the raw values and distort the colours.
display_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
display_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
result_image = (noise.detach().cpu().squeeze(0) * display_std + display_mean).clamp(0, 1)

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(result_image.permute(1, 2, 0).numpy())
ax.set_title('Stylised image')
ax.axis('off');