<a href="https://colab.research.google.com/github/bgalerne/mva_generative_models_for_images/blob/main/2_mvagm_CNN_texture_synthesis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Texture Synthesis with CNNs in PyTorch

## Introduction ##

This practical session explains how to implement texture synthesis following the algorithm of Leon A. Gatys, Alexander S. Ecker and Matthias Bethge, described in the paper **[Texture Synthesis Using Convolutional Neural Networks](https://papers.nips.cc/paper/5633-texture-synthesis-using-convolutional-neural-networks)**.


**Sources:**
This practical session is based on several resources:

*   Original code: https://github.com/leongatys/DeepTextures
*   Reimplementation: https://github.com/trsvchn/deep-textures
*   Tutorial used for some explanations: https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

**Authors:**
* Bruno Galerne: www.idpoisson.fr/galerne / https://github.com/bgalerne
* Lucía Bouza


**Texture Synthesis**: Given an input texture image, produce an output texture image being both visually similar to and pixel-wise different from the input texture. The output image should ideally be perceived as another part of the same large piece of homogeneous material the input texture is taken from.



## Underlying Principle ##

Let us recall the algorithm proposed by Gatys et al.
Given an example image $u$ and a random initialization $x=x_0$, 
one optimizes the loss function 
$$
E(x) = \sum_{\text{for selected layers } L} w_L\left\| G^L(x) - G^L(u) \right\|^2_F
$$
where $\|\cdot\|_F$ is the Frobenius norm and, for an image $y$ and a layer index $L$, $G^L(y)$ denotes the Gram matrix of the VGG-19 features at layer $L$:
if $V^L(y)$ is the feature response of $y$ at layer $L$, with spatial size $w\times h$ and $n$ channels, then
$$
G^L(y) = \frac{1}{w h}\sum_{k\in \{0,\dots,w-1\}\times\{0,\dots,h-1\}} V^L(y)_k V^L(y)_k^T \in \mathbb{R}^{n\times n}.
$$
The optimization is done using the L-BFGS algorithm.
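
Written entrywise, the definition above reads as follows. The notation $V^L(y)_{k,i}$ for the value of channel $i$ at spatial position $k$ is introduced here only for illustration:
$$
G^L(y)_{ij} = \frac{1}{w h}\sum_{k} V^L(y)_{k,i}\, V^L(y)_{k,j}, \qquad 1 \le i,j \le n.
$$
Each entry is therefore the spatial average of the product of two feature channels. The matrix is symmetric, and since the sum runs over all positions $k$, it does not depend on where in the image a given feature response occurs.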

## Exercise 1:

1. Go through the notebook and execute each cell.

2. We are using the outputs of 5 VGG-19 layers to define $E$. Verify that the quality of the output texture decreases if one uses fewer layers (e.g. only the first layer or the first three layers).


## Importing Packages ##

Below is a list of the packages needed to implement the texture synthesis.



* `torch` (the core PyTorch package for neural networks)
* `torchvision.transforms.functional` (to transform images into tensors)
* `torchvision.models` (to get the VGG network)
* `mse_loss` (to compute the loss)
* `torch.optim` (to get the L-BFGS optimizer)
* `PIL.Image, matplotlib.pyplot, BytesIO, urlopen` (to load and display images)

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn.functional import mse_loss
import torchvision.models as models
from torchvision.transforms.functional import resize, to_tensor, normalize, to_pil_image
import numpy as np

from PIL import Image
import matplotlib.pyplot as plt
from io import BytesIO
from urllib.request import urlopen
import os

## Loading Images

In the next cell we download and display the example texture images.

In [None]:
texture_imgnames = ["bois.png", "briques.png", "mur.png", "tissu.png", "nuages.png","pebbles.jpg","wall1003.png"]
#import wget
for fname in texture_imgnames:
  os.system("wget -c https://www.idpoisson.fr/galerne/mva/"+fname)
  img = Image.open(fname)
  print(img.size)
  display(img)

## Set a device

Next, we need to choose on which device to run the network. The algorithm is slower on large images and runs much faster on a GPU. We can use `torch.cuda.is_available()` to detect whether a GPU is available, and then set the `torch.device` accordingly. The `.to(device)` method is used to move tensors or modules to the desired device.

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device is", device)
!nvidia-smi

## Prepare data ##

The original PIL RGB images have values between 0 and 255 and, seen as arrays, have size HxWx3. When transformed into torch tensors, their values are rescaled to lie between 0 and 1 and their size becomes 3xHxW. This "channel first" convention is the one PyTorch CNNs expect for input images.

VGG networks are trained on images with each channel normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] (mean and standard deviation of ImageNet). We will have to normalize the image tensor before sending it into the network.
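
As a minimal sketch of what this normalization does, the snippet below applies it by hand to a tiny random tensor and then inverts it. The tensor `img` and the variable names are illustrative assumptions; the notebook itself relies on `torchvision.transforms.functional.normalize`.

```python
import torch

MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

# A tiny 3x2x2 "image" with values in [0, 1], in channel-first layout
img = torch.rand(3, 2, 2)

# Reshape the statistics to (3, 1, 1) so they broadcast over H and W
mean = torch.tensor(MEAN).view(3, 1, 1)
std = torch.tensor(STD).view(3, 1, 1)

normalized = (img - mean) / std      # what the network receives
restored = normalized * std + mean   # inverse operation, used for display

print(torch.allclose(restored, img, atol=1e-6))
```

The inverse step is what a denormalization helper has to perform before a synthesized tensor can be shown as an image.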

Here are some auxiliary functions to load images, display them and transform them to tensors.

In [None]:
# Utilities
# Functions to manage images

MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def prep_img(imagename: str, size=None, mean=MEAN, std=STD):
    """Preprocess image.
    1) load as PIL
    2) resize (skipped if size is None)
    3) convert to tensor
    4) normalize
    """
    im = Image.open(imagename)
    # resize so that minimal side length is size pixels
    texture = resize(im, size) if size is not None else im
    texture_tensor = to_tensor(texture).unsqueeze(0) # add batch dimension
    # remove alpha channel if any
    if texture_tensor.shape[1]==4:
      print('removing alpha channel')
      texture_tensor = texture_tensor[:,:3,:,:]
    texture_tensor = normalize(texture_tensor, mean=mean, std=std)
    return texture_tensor


def denormalize(tensor: torch.Tensor, mean=MEAN, std=STD):
    """Based on torchvision.transforms.functional.normalize.
    """
    tensor = tensor.clone().squeeze() # remove batch dimension
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    tensor.mul_(std).add_(mean)
    return tensor


def to_pil(tensor: torch.Tensor):
    """Converts tensor to PIL Image.
    Args: tensor (torch.Tensor): normalized input tensor of torch.Size([C, H, W]) or torch.Size([1, C, H, W]).
    Returns: PIL Image: converted img.
    """
    img = tensor.clone().detach().cpu()
    img = denormalize(img).clip(0, 1)
    img = to_pil_image(img)
    return img

Now, we load the image, resize it, and convert it to a normalized tensor.

In [None]:
input_image_name = "wall1003.png"
img_size = 256

# Prepare texture data
target = prep_img(input_image_name, img_size).to(device)
target_img = to_pil(target)
plt.imshow(target_img)
display(target_img)

## Model

Now we need to import a pretrained neural network. We will use the 19-layer VGG network provided by torchvision.

The PyTorch implementation of VGG is a module divided into two child Sequential modules: `features` (containing convolution and pooling layers) and `classifier` (containing fully connected layers). For the texture synthesis task we only need the layers of the `features` module. We also freeze the parameters, since the network is already trained.

The output of the next cell shows the structure of the `features` module. The indexes will help to select the layers needed for the algorithm.


In [None]:
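# Note: on torchvision >= 0.13, `pretrained=True` is deprecated in favor of
# `weights=models.VGG19_Weights.IMAGENET1K_V1`.
cnn = models.vgg19(pretrained=True).features.to(device).eval()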
cnn.requires_grad_(False)

According to the algorithm explained at the beginning of this notebook, we need to access the outputs of some selected intermediate layers.

In order to access the outputs of the layers of the PyTorch VGG19 network, we need to register a hook on each layer of interest. Hooks are functions that can be attached to any layer and are called each time the layer is computed. You can register a hook before or after the forward pass, or after the backward pass. We will define a function `save_output` that builds a hook triggered after the forward pass of each selected layer of the `features` module.

The outputs of the layers will be stored in a dictionary where the key is the index of the layer and the value is the output tensor of the layer.

So, we must define which layers will be part of the optimization and define a weight for each one (the weights are used when running the texture synthesis). Using the indexes of the layers, we select the ones to use in the algorithm. The outputs of the first ReLU of each of the five convolutional blocks (after `conv1_1`, `conv2_1`, `conv3_1`, `conv4_1` and `conv5_1`) are a good selection. That's why we choose indexes 1, 6, 11, 20, 29, whose outputs have 64, 128, 256, 512 and 512 channels respectively.
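
The hook mechanism can be tried in isolation on a small network before applying it to VGG-19. The sketch below uses a hypothetical three-layer `toy` model and a `captured` dictionary, both introduced only for illustration. It also shows that the handle returned by `register_forward_hook` can be used to detach the hook again.

```python
import torch
import torch.nn as nn

# Toy network standing in for the `features` module of VGG-19 (hypothetical)
toy = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
)

captured = {}

def make_hook(name):
    # A forward hook receives the module, its input and its output
    def hook(module, module_in, module_out):
        captured[name] = module_out
    return hook

# Attach the hook to the ReLU at index 1
handle = toy[1].register_forward_hook(make_hook(1))

x = torch.randn(1, 3, 8, 8)
toy(x)  # the forward pass fills `captured`
shape_after_first_pass = tuple(captured[1].shape)  # (1, 4, 8, 8)

# Detach the hook: later forward passes no longer store anything
handle.remove()
captured.clear()
toy(x)
hook_still_active = 1 in captured  # False
```

Keeping the handles is useful when experimenting with different layer selections, since registering hooks a second time without removing the old ones leaves both sets active.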

In [None]:
# Initialize outputs dic
outputs = {}

# Hook definition
def save_output(name):
    
    # The hook signature
    def hook(module, module_in, module_out):
        outputs[name] = module_out
    return hook

# Define layers
layers = [1, 6, 11, 20, 29]
# Define weights for layers
layers_weights = [1e9/n**2 for n in [64,128,256,512,512]]

# Register hook on each layer with index on array "layers"
for layer in layers:
    handle = cnn[layer].register_forward_hook(save_output(layer))

## Loss Function and optimizer

Now, we need to define the Loss function $E$ defined at the beginning of the notebook. To do so we define a function to calculate the Gram Matrix of a feature layer, and then a loss function that computes the Mean-Square-Error (MSE) for 2 Gram matrices. 
We also compute the Gram matrices of the target once to save computation.
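
One detail worth keeping in mind when reading the code: `mse_loss` with its default `reduction="mean"` averages the squared differences over all entries instead of summing them. For a single image (batch size 1) and a layer with $n_L$ channels, this gives
$$
\operatorname{MSE}\left(G^L(x), G^L(u)\right) = \frac{1}{n_L^2}\left\| G^L(x) - G^L(u) \right\|^2_F ,
$$
so a value $c_L$ passed as the `weight` argument (the symbol $c_L$ is introduced here only for this remark) corresponds to $w_L = c_L / n_L^2$ in the definition of $E$. With the values $c_L = 10^9/n_L^2$ stored in `layers_weights`, this amounts to $w_L = 10^9 / n_L^4$.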

In [None]:
# Computes Gram matrix for the input batch tensor.
#    Args: tnsr (torch.Tensor): input tensor of the Size([B, C, H, W]).
#    Returns:  G (torch.Tensor): output tensor of the Size([B, C, C]).
def gramm(tnsr: torch.Tensor) -> torch.Tensor:   
    b,c,h,w = tnsr.size() 
    F = tnsr.view(b, c, h*w)
    G = torch.bmm(F, F.transpose(1,2)) 
    G.div_(h*w)
    return G

# Computes MSE Loss for 2 Gram matrices 
def gram_loss(input: torch.Tensor, gramm_target: torch.Tensor, weight: float = 1.0):
    loss = weight * mse_loss(gramm(input), gramm_target)
    return loss
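
As a sanity check, the batched matrix product can be compared with the sum of outer products from the definition of $G^L$. This is only a sketch on a small random tensor; the names `gram_bmm`, `gram_naive` and `feat` are illustrative and not part of the notebook.

```python
import torch

def gram_bmm(feat: torch.Tensor) -> torch.Tensor:
    # Same computation as in the notebook: flatten space, then one batched product
    b, c, h, w = feat.size()
    flat = feat.view(b, c, h * w)
    return torch.bmm(flat, flat.transpose(1, 2)) / (h * w)

def gram_naive(feat: torch.Tensor) -> torch.Tensor:
    # Direct transcription of the definition: average of outer products
    # of the feature vector at each spatial position
    b, c, h, w = feat.size()
    G = torch.zeros(b, c, c)
    for i in range(h):
        for j in range(w):
            v = feat[:, :, i, j]                   # shape (b, c)
            G += v.unsqueeze(2) * v.unsqueeze(1)   # outer product, shape (b, c, c)
    return G / (h * w)

feat = torch.randn(1, 4, 5, 6)
max_gap = (gram_bmm(feat) - gram_naive(feat)).abs().max().item()
print(gram_bmm(feat).shape, max_gap)
```

The two results agree up to floating-point rounding, while the batched version avoids the Python loop over spatial positions.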

## Optimizer and initialization:

Then we compute the random initialization $x_0$. This tensor must have the same size as the original image.

We use the L-BFGS algorithm to minimize the loss. We create a PyTorch L-BFGS optimizer `optim.LBFGS` and pass our `synth` image to it as the tensor to optimize.

In [None]:
# select input image: ["bois.png", "briques.png", "mur.png", "tissu.png", "nuages.png","pebbles.jpg","wall1003.png"]
input_image_name = "wall1003.png"
img_size = 256

# Prepare texture data
target = prep_img(input_image_name, img_size).to(device)
target_img = to_pil(target)
plt.imshow(target_img)

# Forward pass on the target texture to get the activations of the selected layers (outputs), then compute their Gram matrices
cnn(target)
gramm_targets = [gramm(outputs[key]) for key in layers] 

# Random init for image synth
synth = torch.randn_like(target) * 0.5
synth.requires_grad=True

# Set optimizer
optimizer = optim.LBFGS([synth])

## Running Texture Synthesis

Finally, we run the code that performs the texture synthesis.

The activations of the selected layers for the texture image (dictionary `outputs` after a forward pass on the target texture) and their Gram matrices were computed in the previous cell. These values do not change, so it is efficient to compute them just once.

Then, at each iteration, the current `synth` image is fed to the network and new losses are computed between the Gram matrices of the `target` activations and those of the `synth` activations.

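The optimizer requires a “closure” function, which reevaluates the model and returns the loss. Note that `optim.LBFGS` may call the closure several times within a single `optimizer.step`, so the counter `iter_` below counts closure evaluations.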
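
The closure pattern is easier to see on a problem much smaller than texture synthesis. The sketch below minimizes a simple quadratic with `optim.LBFGS`. The vector `target_vec`, the variable `x` and the call counter `calls` are assumptions made for this illustration only.

```python
import torch
import torch.optim as optim

# Toy problem: find x minimizing ||x - target_vec||^2
target_vec = torch.tensor([1.0, -2.0, 3.0])
x = torch.zeros(3, requires_grad=True)

optimizer = optim.LBFGS([x])
calls = {"n": 0}  # counts how many times the closure is evaluated

def closure():
    calls["n"] += 1
    optimizer.zero_grad()                    # clear gradients from the previous evaluation
    loss = ((x - target_vec) ** 2).sum()     # re-evaluate the objective at the current x
    loss.backward()                          # populate x.grad
    return loss                              # L-BFGS needs the loss value

n_steps = 5
for _ in range(n_steps):
    optimizer.step(closure)

final_error = (x.detach() - target_vec).abs().max().item()
print(calls["n"], final_error)
```

The number of closure evaluations exceeds the number of `step` calls, which is why the synthesis loop increments its counter inside the closure and not around `optimizer.step`.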

In [None]:
n_iters = 2000
log_every = n_iters//10
iter_ = 0

while iter_ <= n_iters:

    def closure():
        global iter_

        optimizer.zero_grad()

        # Forward pass using synth. Get activations of selected layers for image synth (outputs). Calculate gram Matrix for those activations
        cnn(synth)
        synth_outputs = [outputs[key] for key in layers] 
        
        # Compute loss for each activation
        losses = []
        for activations in zip(synth_outputs, gramm_targets, layers_weights):
            losses.append(gram_loss(*activations).unsqueeze(0))

        total_loss = torch.cat(losses).sum()
        total_loss.backward()

        # Display results: print Loss value and show image
        if iter_ == 0 or iter_ % log_every == 0:
            print('Iteration: %d, loss: %1.2e'%(iter_, total_loss.item()))
            fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 20))
            axes[0].imshow(target_img)
            axes[0].set_title('original image')
            axes[1].imshow(to_pil(synth))
            axes[1].set_title('synthesis (it. %d)'%( iter_ ))
            fig.tight_layout()
            plt.pause(0.05)

        iter_ += 1

        return total_loss

    optimizer.step(closure)


## Exercise 2:

Pick one of the following problems:

* **A: Color correction** 
  
  Observe that for some images the color is inconsistent (e.g. with `wall1003.png`). A solution to correct the output color distribution is to incorporate the mean color and the color covariance as target statistics in $E$.

  **Hint**: Consider the mean color vector $m$ and the covariance matrix $C_h = C_h(0)$ of an image $h$:

  $$
  m = \begin{pmatrix}
  m_r \\
  m_g \\
  m_b
  \end{pmatrix}
  = \frac{1}{MN}\sum_{t\in\Omega} h(t) \in \mathbb{R}^{3}
  $$
  $$
  C_h = \frac{1}{MN}\sum_{t\in\Omega} 
  \begin{pmatrix}
  h_r(t) - m_r \\
  h_g(t) - m_g \\
  h_b(t) - m_b
  \end{pmatrix}
  \begin{pmatrix}
  h_r(t) - m_r \\
  h_g(t) - m_g \\
  h_b(t) - m_b
  \end{pmatrix}^T
  \in\mathbb{R}^{3\times 3}.
  $$

  Then change $E$ to:
  $ E + \lambda_{mean} \| m(x) - m(u)\|^2 + \lambda_{cov} \| C(x) - C(u)\|^2. $

  Try with $\lambda_{mean}$ and $\lambda_{cov}$ between 1e6 and 1e3.



* **B: Spectral correction** 

  Add a term to the energy that enforces consistency with the original Fourier spectrum of each color channel, that is, change $E$ to:
$$
E + \lambda_{Fourier} \| |\hat{x}| - |\hat{u}|\|^2.
$$
Try with $\lambda_{Fourier}$ between 5 and 0.5 to see the differences. Try with different kinds of textures.
What is the benefit of this approach?
For which textures does it improve or degrade the quality of the result?




* **C: Order one statistics** 

  Replace $E$ so that the spatial average of ALL the VGG-19 layers is preserved, that is, change $E$ to:
$$
  E_{mean} (x) = \sum_{\text{for all layers } L} w_L \left\| \operatorname{mean}(V^L(x)) - \operatorname{mean}(V^L(u)) \right\|^2_F
$$
  Here, consider weighting the layers with the same approach as in the notebook. The mean is computed along the spatial dimensions, so for each layer the mean vector has size "number of channels within the layer".
  Compare with the original model. What is the benefit of this approach?

  
  
