# Neural Style Transfer

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torchvision import transforms
import torchvision

import matplotlib.pyplot as plt
import random
import numpy as np
import copy
import time
import os
import cv2
from PIL import Image
from torchvision.models import vgg19
from torchvision.utils import save_image


# 1. Load data

In [None]:
def set_seed(seed, use_gpu = True):
    """
    Seed every random number generator used in this notebook.

    Python's `random`, NumPy and PyTorch are always seeded. When `use_gpu`
    is true the CUDA generators are seeded as well, and cuDNN is switched to
    deterministic mode with benchmarking disabled.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not use_gpu:
        return

    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed value used when a reproducible run is requested.
SEED = 44

# Switch to True to seed all random number generators through set_seed().
USE_SEED = False

if USE_SEED:
    # The CUDA generators are only seeded when a GPU is actually available.
    set_seed(SEED, torch.cuda.is_available())

In [None]:
def load_image(image_path, device, output_size=None):
    """Load an image file as a float tensor of shape (1, 3, height, width).

    Args:
        image_path: Path of the image file to open.
        device: Device the resulting tensor is moved to.
        output_size: Target size. ``None`` keeps the original image size, an
            ``int`` resizes to a square of that side, and a 2-tuple of ints
            is interpreted as ``(height, width)``.

    Returns:
        The image as a tensor with a leading batch dimension of 1, on `device`.

    Raises:
        ValueError: If `output_size` is neither ``None``, an int, nor a
            2-tuple of ints.
    """
    # Validate the requested size first so that a bad argument fails with a
    # clear message instead of an obscure error inside the resize transform.
    if output_size is None:
        output_dim = None
    elif isinstance(output_size, int):
        output_dim = (output_size, output_size)
    elif (
        isinstance(output_size, tuple)
        and len(output_size) == 2
        and all(isinstance(side, int) for side in output_size)
    ):
        output_dim = output_size
    else:
        raise ValueError("ERROR: output_size must be an integer or a 2-tuple of (height, width) if provided.")

    # The context manager closes the underlying file; convert() returns a
    # fully loaded copy. Forcing RGB guarantees the 3 channels VGG expects,
    # even for grayscale, palette or RGBA files.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")

    if output_dim is None:
        # PIL reports (width, height) while Resize expects (height, width).
        output_dim = (img.size[1], img.size[0])

    torch_loader = transforms.Compose(
        [
            transforms.Resize(output_dim),
            transforms.ToTensor()
        ]
    )

    # Add the batch dimension expected by the network.
    img_tensor = torch_loader(img).unsqueeze(0)
    return img_tensor.to(device)

In [None]:

"""Gloria´s paths"""
#content_path = "/home/gloria/Scrivania/Vision_and_cognitive_system/content_style/content.jpg"
#style_path = "/home/gloria/Scrivania/Vision_and_cognitive_system/content_style/style1.jpg"

"""Sara´s paths"""
# NOTE(review): absolute, machine-specific paths; adjust them before running elsewhere.
content_path = "/home/sara/Scrivania/Physics_of_Data/2nd Year/Vision_cognitive_sys/Projects/neural_style_transfer/taj_mahal.jpg"
style_path = "/home/sara/Scrivania/Physics_of_Data/2nd Year/Vision_cognitive_sys/Projects/neural_style_transfer/vg_starry_night.jpg"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# An int makes load_image resize the content image to a square of this side.
output_size = 512

content_tensor = load_image(content_path, device, output_size=output_size)
# Reuse the content tensor's (height, width) so the style image is resized to the same shape.
output_size = (content_tensor.shape[2], content_tensor.shape[3])
style_tensor = load_image(style_path, device, output_size=output_size)

# 2. Load VGG model

In [None]:
class VGG19(nn.Module):
    """Pre-trained VGG19 used as a fixed feature extractor.

    The forward pass returns the activations recorded at the layers listed in
    `chosen_features`, keyed by their conventional names.
    """

    def __init__(self):
        super().__init__()

        # Indices into `vgg19().features` mapped to the names used by the
        # losses: five style layers (conv1_1 ... conv5_1) plus conv4_2, the
        # content layer, i.e. six layers in total.
        self.chosen_features = {0: 'conv1_1', 5: 'conv2_1', 10: 'conv3_1', 19: 'conv4_1', 21: 'conv4_2', 28: 'conv5_1'}
        self.vgg = torchvision.models.vgg19(weights='DEFAULT').features[:37]

        # The network is never trained here: only the input image is
        # optimized. Freezing the weights stops autograd from computing and
        # accumulating gradients for them on every backward pass.
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Run `x` through the network and collect the chosen feature maps.

        Returns:
            dict mapping layer name (e.g. ``'conv4_2'``) to the tensor
            recorded right after the layer with the corresponding index.
        """
        # NOTE(review): the recorded tensor is the same object that is passed
        # to the next layer; if that layer operates in place (torchvision's
        # VGG is believed to use in-place ReLU) the stored map holds the
        # activated values - confirm before relying on pre-activation maps.
        feature_maps = {}
        for idx, layer in enumerate(self.vgg):
            x = layer(x)
            if idx in self.chosen_features:
                feature_maps[self.chosen_features[idx]] = x

        return feature_maps



In [None]:
# Load the model: build the VGG19 feature extractor, move it to `device`
# and switch it to evaluation mode.
vgg = VGG19().to(device).eval()

# 3. Loss function

The overall loss is made up of the loss of the target image with respect to the content image and the loss of the target image with respect to the style image: $$L_{tot}=L_{content}+L_{style}$$
For this process it wouldn't make sense to compare the images pixel by pixel: for example, if the content image contains a house and the predominant style of the style image is to have diagonal lines, we would want the target image to be a house which is slanted diagonally; comparing pixel by pixel an image with a diagonal house and an image with a house would return a much higher loss than we expect, because the pixel-by-pixel comparison doesn't take into account more 'generic' features.
In order to perform a more accurate comparison, both these losses are evaluated between **feature maps**, which take into account the more generic features of both images.

### 3.1 Content loss
The content loss is computed on a deep layer of the CNN (`conv4_2` in this notebook): it is the mean squared error between the target feature map and the content feature map.

In [None]:
def get_content_loss(target_map, content_map):
    """Return the mean squared error between the two feature maps."""
    return F.mse_loss(target_map, content_map, reduction='mean')

### 3.2 Style loss
For the style loss, the procedure is more complicated.
We are interested in co-occurrences of pairs of features to highlight important stylistic combinations.    

If we have a feature map of height and width $h,w$ and depth $k$ (the number of filters applied, whose outputs are also called *channels*), we want to compute the co-occurrences between each pair of maps $i,j$ with $i,j \in \{1,\dots,k\}$: we obtain a $k \times k$ matrix in which each entry is the dot product between two flattened maps, a scalar.    
Given the feature map of an image, this matrix, called the **Gram matrix**, can be computed easily as the matrix product between the feature map reshaped to $k \times hw$ and its transpose.
   
This is done both with the feature map of the STYLE IMAGE and the feature map of the TARGET IMAGE.
We compute a Gram matrix for both images for each convolutional layer considered $l$, and end up with:
   - 5 Gram matrices of the style image feature maps $G_{style}^l$
   - 5 Gram matrices of the target image feature maps  $G_{target}^l$    
   
The loss of each layer $l$ is computed via MSE between the two gram matrices, and the overall style loss will be the average of these values over the number of layers (in our case 5).

In [None]:
def _gram_matrix(feature_map):
    """Return the normalized Gram matrices, shape (batch, channel, channel), of a 4-D feature map."""
    batch, channel, height, width = feature_map.shape
    # reshape (unlike view) also accepts non-contiguous inputs.
    flat = feature_map.reshape(batch, channel, height * width)
    gram = torch.bmm(flat, flat.transpose(1, 2))
    # Out-of-place division keeps the autograd graph free of in-place edits.
    return gram / (channel * height * width)


def get_style_loss(target_map, style_map):
    """Compute the style loss as the MSE between normalized Gram matrices.

    Each Gram matrix is built from its own feature map, so the two maps may
    have different spatial sizes, and any batch size is supported. For two
    maps of identical shape with batch size 1 the value matches the plain
    ``mean((G_target - G_style) ** 2)`` over the ``channel x channel`` matrix.

    Args:
        target_map: Feature map of shape (batch, channel, height, width).
        style_map: Feature map with the same number of channels.

    Returns:
        Scalar tensor holding the mean squared difference of the Gram matrices.
    """
    target_gram = _gram_matrix(target_map)
    style_gram = _gram_matrix(style_map)
    return torch.mean((target_gram - style_gram) ** 2)



### 3.3 Total variation loss

To regularize the loss function and encourage smoothness in the output image, we also introduce a total variation loss term to the total loss. This additional term will have to be weighted appropriately.

In [None]:
def total_variation_loss(image):
    """Sum the absolute differences between neighbouring entries of a 4-D tensor.

    Neighbours are taken along the last axis and along the third axis
    (presumably width and height of an image batch).
    """
    along_last_axis = (image[:, :, :, :-1] - image[:, :, :, 1:]).abs().sum()
    along_third_axis = (image[:, :, :-1, :] - image[:, :, 1:, :]).abs().sum()
    return along_last_axis + along_third_axis

# 4. Testing

### 4.1 Initialize random (target) image

In [None]:
# Alias of the content tensor; only its shape is used below.
img=content_tensor
img.shape 

#gaussian_noise_img = np.random.normal(loc=0, scale=90., size=img.shape).astype(np.float32)
# Draw uniform noise in [-90, 90) with the same shape as the content tensor.
white_noise_img = np.random.uniform(-90., 90., img.shape).astype(np.float32)
init_img = torch.from_numpy(white_noise_img).float().to(device)
# Min-max rescale so the starting image lies in [0, 1].
init_img = (init_img - init_img.min()) / (init_img.max() - init_img.min())
init_img.shape

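When renormalized images are fed to the CNN, they have to be denormalized before the actual images can be visualized. In this notebook the tensors are kept in the $[0,1]$ range (only `ToTensor` is applied when loading, and the noise image is min-max rescaled), so they can be displayed directly.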

In [None]:
# Convert the three tensors to PIL images (make_grid collapses the batch dimension).
content_show = transforms.ToPILImage()(torchvision.utils.make_grid(content_tensor.cpu()))
style_show = transforms.ToPILImage()(torchvision.utils.make_grid(style_tensor.cpu()))
rnd_show = transforms.ToPILImage()(torchvision.utils.make_grid(init_img.cpu()))

# Show content, style and random-noise images side by side.
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
panels = (
    (content_show, "Content Image"),
    (style_show, "Style Image"),
    (rnd_show, "Random Noise Image"),
)
for axis, (picture, caption) in zip(axs, panels):
    axis.imshow(picture)
    axis.set_title(caption)

plt.show()

###  4.2 Function to save intermediate feature maps
Used during tests to check the behaviour of the feature maps.

In [None]:

def save_content_features(content, model, layer_indices, intermediate_dir):
    """Save the first channel of selected feature maps of `content` as images.

    Args:
        content: Input tensor passed to `model`.
        model: Feature extractor returning a dict of feature maps and exposing
            a `chosen_features` mapping from layer index to layer name.
        layer_indices: Layer indices (keys of `model.chosen_features`) to dump.
        intermediate_dir: Directory in which ``content_<layer_idx>.jpg`` files
            are written; it is created if missing.
    """
    # Make sure the destination exists so that savefig does not fail.
    os.makedirs(intermediate_dir, exist_ok=True)

    with torch.no_grad():
        content_features = model(content)

        for layer_idx in layer_indices:
            layer_name = model.chosen_features[layer_idx]
            feature_map = content_features[layer_name].squeeze(0)

            # Min-max normalize the tensor. A constant map has zero range and
            # would turn into NaN, so it is rendered as all zeros instead.
            lowest = feature_map.min()
            value_range = feature_map.max() - lowest
            if value_range > 0:
                normalized_feature_map = (feature_map - lowest) / value_range
            else:
                normalized_feature_map = torch.zeros_like(feature_map)

            # Render the first channel of the feature map with matplotlib.
            plt.imshow(normalized_feature_map[0].cpu().numpy(), cmap='viridis')
            plt.axis('off')

            # Save the image.
            save_path = os.path.join(intermediate_dir, f'content_{layer_idx}.jpg')
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
            plt.close()



### 4.3 Set parameters

In [None]:
# Names of the feature maps (as returned by VGG19.forward) used for the style loss.
style_layers = ['conv1_1','conv2_1','conv3_1','conv4_1','conv5_1']
# Name of the feature map used for the content loss.
content_layers = ['conv4_2']

In [None]:
content = content_tensor
style = style_tensor
target = init_img.requires_grad_(True)  #requires_grad is needed to make sure that the image is updated


# Learning rate handed to the optimizer in train_image.
learn_rate=0.1
# Weights of the content, style and total-variation terms of the total loss.
alpha=5.0
beta=1e7
tv_weight=1e-3

# Directory where train_image writes the intermediate images.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
intermediate_dir="/home/sara/Scrivania/Physics_of_Data/2nd Year/Vision_cognitive_sys/Projects/neural_style_transfer/intermediate"

In [None]:
def train_image(content, style, target, device, output_img_fmt, content_img_name, style_img_name, num_epochs,
               learn_rate):
    """Optimize `target` to match the content of `content` and the style of `style`.

    Uses the notebook-level settings `content_layers`, `style_layers`,
    `alpha`, `beta`, `tv_weight` and `intermediate_dir`.

    Args:
        content: Content image tensor.
        style: Style image tensor.
        target: Image tensor to optimize; it must require gradients and is
            updated in place by the optimizer.
        device: Device on which the VGG19 feature extractor is placed.
        output_img_fmt: File extension of the saved intermediate images.
        content_img_name: Content label used in the output file names.
        style_img_name: Style label used in the output file names.
        num_epochs: Number of optimization steps.
        learn_rate: Learning rate of the Adam optimizer.

    Returns:
        Always 1.
    """
    model = VGG19().to(device).eval()
    # eval() only switches layer behaviour; the weights are frozen explicitly
    # so that backward() does not compute or accumulate gradients for them.
    for param in model.parameters():
        param.requires_grad_(False)

    optimizer = torch.optim.Adam([target], lr=learn_rate)

    # The content and style images never change, so their feature maps are
    # computed once, outside the loop and without building an autograd graph.
    with torch.no_grad():
        content_features = model(content)
        style_features = model(style)

    # Make sure the folder for the intermediate images exists.
    os.makedirs(intermediate_dir, exist_ok=True)

    for epoch in range(num_epochs):
        # Start every step from clean gradients.
        optimizer.zero_grad()

        target_features = model(target)

        content_loss = 0.0
        style_loss = 0.0

        # Accumulate the losses over the chosen layers.
        for layer, target_feature in target_features.items():
            if layer in content_layers:
                # Content loss (conv4_2 with the notebook settings).
                content_loss += get_content_loss(target_feature, content_features[layer])

            if layer in style_layers:
                # Style loss on each of the style layers.
                style_loss += get_style_loss(target_feature, style_features[layer])

        # Average the style loss over the style layers.
        style_loss = style_loss / len(style_layers)

        tv_loss = total_variation_loss(target)

        # Weighted total loss.
        total_loss = alpha * content_loss + beta * style_loss + tv_weight * tv_loss

        # Compute the gradient with respect to the target image and update it.
        total_loss.backward()
        optimizer.step()

        # Save an intermediate image every 50 steps.
        if (epoch + 1) % 50 == 0:
            save_image(target, os.path.join(intermediate_dir, f'nst-{content_img_name}-{style_img_name}-{epoch + 1}.{output_img_fmt}'))
            # NOTE: if the inputs were normalized before entering the network,
            # the target would have to be denormalized (and clamped to [0, 1])
            # before being saved for visualization.

        # .item() prints the plain number instead of the tensor representation.
        print(f"\tEpoch {epoch + 1}/{num_epochs}, loss = {total_loss.item()}")

    return 1

In [None]:
# Run 500 optimization steps; intermediate images are written as .jpeg files
# whose names contain the 'taj' and 'vangogh' labels.
train_image(content, style, target, device,'jpeg', 'taj', 'vangogh',500,learn_rate)