#### IMPORT THE LIBRARIES

In [1]:
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
from torch import optim
import torchvision
import torch.nn as nn
import numpy as np

#### DEFINING THE VGG16 MODEL

In [1]:

# define the VGG
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        
        # load the vgg model's features
        self.vgg = models.vgg19(pretrained=True).features
    
    def get_content_activations(self, x: torch.Tensor) -> torch.Tensor:
        """
            Extracts the features for the content loss from the block4_conv2 of VGG19
            Args:
                x: torch.Tensor - input image we want to extract the features of
            Returns:
                features: torch.Tensor - the activation maps of the block4_conv2 layer
        """
        features = self.vgg[:23](x)
        return features
    
    def get_style_activations(self, x):
        """
            Extracts the features for the style loss from the block1_conv1, 
                block2_conv1, block3_conv1, block4_conv1, block5_conv1 of VGG19
            Args:
                x: torch.Tensor - input image we want to extract the features of
            Returns:
                features: list - the list of activation maps of the block1_conv1, 
                    block2_conv1, block3_conv1, block4_conv1, block5_conv1 layers
        """
        features = [self.vgg[:4](x)] + [self.vgg[:7](x)] + [self.vgg[:12](x)] + [self.vgg[:21](x)] + [self.vgg[:30](x)] 
        return features
    
    def forward(self, x):
        return self.vgg(x)

#### DEFINING THE LOSS FUNCTIONS

In [2]:
def gram(tensor):
    """
        Constructs the Gramian matrix out of the tensor
    """
    return torch.mm(tensor, tensor.t())


def gram_loss(noise_img_gram, style_img_gram, N, M):
    """
        Gramian loss: the SSE between Gramian matrices of a layer
            arXiv:1508.06576v2 - equation (4)
    """
    return torch.sum(torch.pow(noise_img_gram - style_img_gram, 2)).div((np.power(N*M*2, 2, dtype=np.float64)))


def total_variation_loss(image):
    """
        Variation loss makes the images smoother, defined over spacial dimensions
    """
    loss = torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \
        torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))
    return loss


def content_loss(noise: torch.Tensor, image: torch.Tensor):
    """
        Simple SSE loss over the generated image and the content image
            arXiv:1508.06576v2 - equation (1)
    """
    return 1/2. * torch.sum(torch.pow(noise - image, 2))

In [3]:
def main(style_img_path: str,
         content_img_path: str, 
         img_dim: int,
         num_iter: int,
         style_weight: int,
         content_weight: int,
         variation_weight: int,
         print_every: int,
         save_every: int):

    assert style_img_path is not None
    assert content_img_path is not None

    # define the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # read the images
    style_img = Image.open(style_img_path)
    cont_img = Image.open(content_img_path)
    
    # define the transform
    transform = transforms.Compose([transforms.Resize((img_dim, img_dim)),
                                    transforms.ToTensor(), 
                                    transforms.Normalize([0.485, 0.456, 0.406],
                                                         [0.229, 0.224, 0.225])])
    
    # get the tensor of the image
    content_image = transform(cont_img).unsqueeze(0).to(device)
    style_image = transform(style_img).unsqueeze(0).to(device)
    
    # init the network
    vgg = VGG().to(device).eval()
    
    # replace the MaxPool with the AvgPool layers
    for name, child in vgg.vgg.named_children():
        if isinstance(child, nn.MaxPool2d):
            vgg.vgg[int(name)] = nn.AvgPool2d(kernel_size=2, stride=2)
            
    # lock the gradients
    for param in vgg.parameters():
        param.requires_grad = False
    
    # get the content activations of the content image and detach them from the graph
    content_activations = vgg.get_content_activations(content_image).detach()
    
    # unroll the content activations
    content_activations = content_activations.view(512, -1)
    
    # get the style activations of the style image
    style_activations = vgg.get_style_activations(style_image)
    
    # for every layer in the style activations
    for i in range(len(style_activations)):

        # unroll the activations and detach them from the graph
        style_activations[i] = style_activations[i].squeeze().view(style_activations[i].shape[1], -1).detach()

    # calculate the gram matrices of the style image
    style_grams = [gram(style_activations[i]) for i in range(len(style_activations))]
    
    # generate the Gaussian noise
    noise = torch.randn(1, 3, img_dim, img_dim, device=device, requires_grad=True)
    
    # define the adam optimizer
    # pass the feature map pixels to the optimizer as parameters
    adam = optim.Adam(params=[noise], lr=0.01, betas=(0.9, 0.999))

    # run the iteration
    for iteration in range(num_iter):

        # zero the gradient
        adam.zero_grad()

        # get the content activations of the Gaussian noise
        noise_content_activations = vgg.get_content_activations(noise)

        # unroll the feature maps of the noise
        noise_content_activations = noise_content_activations.view(512, -1)

        # calculate the content loss
        content_loss_ = content_loss(noise_content_activations, content_activations)

        # get the style activations of the noise image
        noise_style_activations = vgg.get_style_activations(noise)

        # for every layer
        for i in range(len(noise_style_activations)):

            # unroll the the noise style activations
            noise_style_activations[i] = noise_style_activations[i].squeeze().view(noise_style_activations[i].shape[1], -1)

        # calculate the noise gram matrices
        noise_grams = [gram(noise_style_activations[i]) for i in range(len(noise_style_activations))]

        # calculate the total weighted style loss
        style_loss = 0
        for i in range(len(style_activations)):
            N, M = noise_style_activations[i].shape[0], noise_style_activations[i].shape[1]
            style_loss += (gram_loss(noise_grams[i], style_grams[i], N, M) / 5.)

        # put the style loss on device
        style_loss = style_loss.to(device)
            
        # calculate the total variation loss
        variation_loss = total_variation_loss(noise).to(device)

        # weight the final losses and add them together
        total_loss = content_weight * content_loss_ + style_weight * style_loss + variation_weight * variation_loss

        if iteration % print_every == 0:
            print("Iteration: {}, Content Loss: {:.3f}, Style Loss: {:.3f}, Var Loss: {:.3f}".format(iteration, 
                                                                                                     content_weight * content_loss_.item(),
                                                                                                     style_weight * style_loss.item(), 
                                                                                                     variation_weight * variation_loss.item()))

        # create the folder for the generated images
        if not os.path.exists('./generated/'):
            os.mkdir('./generated/')
        
        # generate the image
        if iteration % save_every == 0:
            save_image(noise.cpu().detach(), filename='./generated/iter_{}.png'.format(iteration))

        # backprop
        total_loss.backward()
        
        # update parameters
        adam.step()

from google.colab import files

uploaded = files.upload()

In [4]:
!ls

StyleTransfer.ipynb
content_img_1.jpg
style_img_1.jpeg
style_img_2.jpg


#### TRAINING THE MODEL

In [5]:
import os
from torchvision.utils import save_image


style_img = 'style_img_1.jpeg'
content_img = 'content_img_1.jpg'

main(style_img, content_img, 512, 12000, 10e6, 10e-4, 10e3, 500, 1000)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to C:\Users\Harshal/.cache\torch\checkpoints\vgg19-dcbb9e9d.pth
100%|███████████████████████████████████████████████████████████████████████████████| 548M/548M [02:03<00:00, 4.65MB/s]


Iteration: 0, Content Loss: 416.564, Style Loss: 3927154.839, Var Loss: 22578.158
Iteration: 500, Content Loss: 926.612, Style Loss: 142313.084, Var Loss: 18363.991
Iteration: 1000, Content Loss: 979.907, Style Loss: 57351.720, Var Loss: 16788.577
Iteration: 1500, Content Loss: 966.062, Style Loss: 16623.667, Var Loss: 14469.128
Iteration: 2000, Content Loss: 937.888, Style Loss: 8590.085, Var Loss: 11433.748
Iteration: 2500, Content Loss: 917.420, Style Loss: 5647.683, Var Loss: 8440.204
Iteration: 3000, Content Loss: 897.237, Style Loss: 4076.416, Var Loss: 5977.107
Iteration: 3500, Content Loss: 876.495, Style Loss: 3080.994, Var Loss: 4314.184
Iteration: 4000, Content Loss: 856.655, Style Loss: 2387.100, Var Loss: 3394.766
Iteration: 4500, Content Loss: 836.788, Style Loss: 1880.576, Var Loss: 2936.953
Iteration: 5000, Content Loss: 817.842, Style Loss: 1498.656, Var Loss: 2695.118
Iteration: 5500, Content Loss: 799.770, Style Loss: 1203.540, Var Loss: 2547.415
Iteration: 6000, Con

In [0]:
!ls /content/generated/

iter_0.png     iter_200.png   iter_400.png   iter_600.png
iter_1000.png  iter_3000.png  iter_5000.png  iter_7000.png
iter_100.png   iter_300.png   iter_500.png   iter_8000.png
iter_2000.png  iter_4000.png  iter_6000.png  iter_9000.png


In [0]:
from google.colab import files
files.download('/content/generated/iter_8000.png')