<a href="https://colab.research.google.com/github/bgalerne/mva_generative_models_for_images/blob/main/Maximum_entropy_model_for_CNN_texture_synthesis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Maximum Entropy Model for CNN Texture Synthesis

## Introduction

This practical session implements the texture synthesis algorithm developed in the paper **Maximum entropy methods for texture synthesis: theory and practice** by *V. De Bortoli, A. Desolneux, A. Durmus, B. Galerne, A. Leclaire*.

**References:**

* Paper: Maximum entropy methods for texture synthesis: theory and practice, V. De Bortoli, A. Desolneux, A. Durmus, B. Galerne, A. Leclaire, SIAM Journal on Mathematics of Data Science (SIMODS), 2021

* Public repository: https://gitlab.com/vdeborto/macrocanonical-synthesis/-/tree/master/

**Authors:**

* Bruno Galerne: www.idpoisson.fr/galerne / https://github.com/bgalerne
* Lucía Bouza


**Texture Synthesis:** Given an input texture image, produce an output texture image that is both visually similar to, and pixel-wise different from, the input texture. Ideally, the output image should be perceived as another part of the same large piece of homogeneous material from which the input texture was taken.



## Underlying Principle

**Framework: Macrocanonical model**. 

We look for parameters $\theta\in\mathbb{R}^p$ such that the exponential model
$$
\pi_{\theta}(x) \propto e^{-V(x,\theta)},
$$
where
$$ 
V(x,\theta) = \theta \cdot (f(x)-f(x_0)) + J(x)
$$
with

* $x_0$ the target texture,

* $f:\mathbb{R}^d \to \mathbb{R}^p$ the spatial average of the feature responses of each selected layer (multiplied by $\beta = 128$),

* $J(x) = \frac{\epsilon}{2}  \left\| x \right\|^2$,

solves the macrocanonical problem, that is, 
$$
\mathbb{E}_{\pi_{\theta}}[f(X)] = f(x_0)
$$
and $\pi_{\theta}$ has maximal entropy.
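The $\theta$-update in the pseudo-code below can be read as a stochastic gradient step. Here is a short derivation. It assumes that $Z(\theta)=\int e^{-V(x,\theta)}\,dx$ is finite and that differentiation under the integral sign is allowed. Under these assumptions, the log-partition function satisfies

$$
\nabla_\theta \log Z(\theta) = -\frac{1}{Z(\theta)}\int (f(x)-f(x_0))\, e^{-V(x,\theta)}\,dx = f(x_0) - \mathbb{E}_{\pi_\theta}[f(X)],
\qquad
\nabla^2_\theta \log Z(\theta) = \mathrm{Cov}_{\pi_\theta}\left(f(X)\right) \succeq 0 .
$$

It follows that $\log Z$ is convex, and its critical points are exactly the parameters satisfying the moment constraint. Consider a gradient descent step $\theta \leftarrow \theta - \delta\,\nabla_\theta \log Z(\theta)$, and replace the expectation by Langevin samples. This gives the $\theta$-update of SOUL. The Langevin step itself only needs $\nabla_x V(x,\theta) = Df(x)^{\top}\theta + \epsilon x$, where $Df(x)\in\mathbb{R}^{p\times d}$ is the Jacobian of $f$.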

Let us recall the pseudo-code of the algorithm.




**SOUL algorithm**

* Initialization: $\theta_0 = 0$; $X_0^0 \in \mathbb{R}^d$
* For $n = 0, \ldots, N-1$,

  * $m_n$ steps of Langevin diffusion: for $k=0,\ldots,m_n-1$,

  $$
  X_{k+1}^n = X_k^n - \gamma_{n+1} \nabla_x V(X_k^n,\theta_n) + \sqrt{2\gamma_{n+1}} Z_{k+1}^n
  $$
  with $Z_{k+1}^n \sim \mathcal{N}(0,I)$

  * Update $\theta$ with the Langevin intermediate states:

  $$
  \theta_{n+1} = \mathsf{Proj}_{\Theta}\left( \theta_n + \delta_{n+1} \left( \frac{1}{m_n} \sum_{k=1}^{m_n} f(X_k^n) - f(x_0) \right) \right)
  $$

  * Set warm start for next step: $X_0^{n+1} = X_{m_n}^n$
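The two SOUL steps can be watched on a case where the answer is known. Below is a minimal NumPy sketch on a hypothetical one-dimensional toy problem, not the texture model. It uses:

* $f(x) = x$ and $J(x) = \frac{\epsilon}{2}x^2$,
* $m_n = 1$,
* fixed $\gamma$ and $\delta$,
* no projection.

In this case $\pi_\theta = \mathcal{N}(-\theta/\epsilon, 1/\epsilon)$, and the constraint $\mathbb{E}_{\pi_\theta}[X] = x_0$ gives $\theta^\star = -\epsilon x_0$. All parameter values are illustrative assumptions.

```python
import numpy as np


def soul_toy(x0=2.0, eps=1.0, gamma=0.01, delta=0.001, n_iters=100_000, seed=0):
    """Toy SOUL run with scalar feature f(x) = x and J(x) = eps/2 * x**2.

    V(x, theta) = theta * (x - x0) + eps / 2 * x**2, so pi_theta is the Gaussian
    N(-theta / eps, 1 / eps) and the macrocanonical solution is theta* = -eps * x0.
    """
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(n_iters)   # Gaussian noise Z for each Langevin step
    noise_scale = np.sqrt(2 * gamma)
    x, theta = 0.0, 0.0                # X_0^0 and theta_0
    xs = np.empty(n_iters)
    thetas = np.empty(n_iters)
    for n in range(n_iters):
        # one Langevin step on x, with grad_x V(x, theta) = theta + eps * x
        x = x - gamma * (theta + eps * x) + noise_scale * z[n]
        # one theta step (m_n = 1, no projection): theta + delta * (f(x) - f(x0))
        theta = theta + delta * (x - x0)
        xs[n] = x
        thetas[n] = theta
    return xs, thetas


xs, thetas = soul_toy()
half = len(xs) // 2
print(xs[half:].mean())      # should be close to x0 = 2.0
print(thetas[half:].mean())  # should be close to theta* = -eps * x0 = -2.0
```

With fixed step sizes, $\theta_n$ keeps fluctuating around $\theta^\star$. For this reason the sketch compares time averages rather than the last iterate.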


**In Practice:**
* The initialization $X_0^0$ is an ADSN realization (see below).
* We do not use projections.
* $\epsilon = 0.1$
* $\gamma$ and $\delta$ are fixed.
* $m = 1$, so we do one update of $x$, one update of $\theta$, and so on. 
* Use the layers [1, 3, 6, 8, 11, 13, 15, 24, 26, 31] for optimization.

## Importing Packages ##

Below is the list of packages needed to implement the texture synthesis.

* `torch` (core package for neural networks in PyTorch)
* `torchvision.transforms.functional` (to resize images, convert them to tensors and normalize them)
* `torchvision.models` (to get the pretrained VGG network)
* `PIL.Image, matplotlib.pyplot, os, display, numpy` (to load and display images, plus numerical utilities)

In [None]:
import torch
import torch.nn as nn
from torch.nn.functional import mse_loss
import torchvision.models as models
from torchvision.transforms.functional import resize, to_tensor, normalize, to_pil_image

from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import display
import os
import numpy as np

## Loading Images

In this section we download the texture images and display them as they are, without any modification.

In [None]:
texture_imgnames = ["bois.png", "briques.png", "mur.png", "tissu.png", 
                    "nuages.png","pebbles.jpg","wall1003.png", "osier12.png",
                    "paille17c2.png","bark.png","coffee.png","flower.png",
                    "rock.png","sweet.png"]

for fname in texture_imgnames:
    os.system("wget -c https://www.idpoisson.fr/galerne/mva/"+fname)
    img = Image.open(fname)
    print(fname)
    display(img)

## Set a device

Next, we choose the device on which to run the network. The algorithm is slow on large images and runs much faster on a GPU. `torch.cuda.is_available()` detects whether a GPU is available, and we set the `torch.device` accordingly. The `.to(device)` method moves tensors or modules to the chosen device.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device is", device)
!nvidia-smi
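Here is a minimal sketch of the `.to(device)` pattern on a small hypothetical layer, which is not part of the synthesis pipeline. It runs on CPU when no GPU is available.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# .to(device) on a module moves its parameters and returns the module itself
layer = nn.Conv2d(3, 8, kernel_size=3, padding=1).to(device)

# .to(device) on a tensor returns a tensor on `device` (the original is unchanged)
t_in = torch.randn(1, 3, 16, 16).to(device)

y = layer(t_in)  # input and parameters must live on the same device
print(y.shape, y.device)
```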

## Prepare data ##

The original PIL images have integer values between 0 and 255. When converted to torch tensors with `to_tensor`, their values are rescaled to lie between 0 and 1.

An important detail is that the pretrained networks of torchvision expect input tensors with values in [0, 1]. Moreover, VGG networks were trained on images whose channels were normalized with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225], the per-channel mean and standard deviation of ImageNet. We therefore normalize the image tensor before feeding it to the network.

Here are some auxiliary functions to load, display, and convert images to tensors.

In [None]:
# Utilities
# Functions to manage images

MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def prep_img(image: str, size=None, mean=MEAN, std=STD):
    """Preprocess image.
    1) load as PIL
    2) resize (if size is given)
    3) convert to tensor
    4) remove alpha channel if any
    5) normalize
    """
    im = Image.open(image)
    texture = resize(im, size) if size is not None else im
    texture_tensor = to_tensor(texture).unsqueeze(0)
    if texture_tensor.shape[1]==4:
        print('removing alpha channel')
        texture_tensor = texture_tensor[:,:3,:,:]
    texture_tensor = normalize(texture_tensor, mean=mean, std=std)
    return texture_tensor


def denormalize(tensor: torch.Tensor, mean=MEAN, std=STD, inplace: bool = False):
    """Based on torchvision.transforms.functional.normalize.
    """
    tensor = tensor.clone().squeeze() 
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    tensor.mul_(std).add_(mean)
    return tensor


def to_pil(tensor: torch.Tensor):
    """Converts a normalized tensor to a PIL Image.
    Args: tensor (torch.Tensor): normalized image of size [1, C, H, W] or [C, H, W].
    Returns: PIL Image: denormalized image, with values clipped to [0, 1].
    """
    img = tensor.clone().detach().cpu()
    img = denormalize(img).clip(0, 1)
    img = to_pil_image(img)
    return img

## Model

Now we load a pretrained neural network. We use the 19-layer VGG network (VGG-19) from torchvision as the base.

The torchvision implementation of VGG has two child `Sequential` modules, with an adaptive average pooling layer in between:

* `features`, containing the convolution and pooling layers;
* `classifier`, containing the fully connected layers.

For the texture synthesis task we only need the layers of the `features` module.


In [None]:
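# `weights=` replaces the deprecated `pretrained=True` argument (torchvision >= 0.13)
cnn = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(device).eval()

# Freeze the network parameters: only the image x is optimized
cnn.requires_grad_(False)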

We slightly modify the network so that $f$ is differentiable: ReLU activations are replaced by CELU and max pooling by average pooling.

Printing `cnn` shows the structure of the `features` module. The layer indices are used below to select the layers needed by the algorithm.

In [None]:
def differentiable(cnn):
    """
    Replace the non-differentiable operations of the network
    by differentiable ones: ReLU -> CELU and MaxPool2d -> AvgPool2d.
    """
    for i, layer in cnn.named_children():
        if isinstance(layer, nn.ReLU):
            cnn[int(i)] = nn.CELU(inplace=True)
        elif isinstance(layer, nn.MaxPool2d):
            cnn[int(i)] = nn.AvgPool2d(2, stride=2, padding=0, ceil_mode=False)

# Replace non-differentiable functions
differentiable(cnn)

# Display the modified `features` module with the index of each layer
print(cnn)
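The effect of the two substitutions on gradients can be seen on tiny inputs. This sketch uses illustrative values. It compares slopes just left and right of 0, then the gradient of a 2×2 pooling.

```python
import torch
import torch.nn as nn

# Slopes on either side of 0: ReLU jumps from 0 to 1, CELU (alpha=1) stays close to 1
t = torch.tensor([-1e-3, 1e-3], requires_grad=True)
nn.ReLU()(t).sum().backward()
relu_slopes = t.grad.clone()
t.grad = None
nn.CELU()(t).sum().backward()
celu_slopes = t.grad.clone()
print(relu_slopes, celu_slopes)

# Gradient of a 2x2 pooling: max pooling routes it to the argmax only,
# average pooling spreads it evenly (it is a linear, hence smooth, map)
patch = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True)
nn.MaxPool2d(2)(patch).sum().backward()
max_grad = patch.grad.clone()
patch.grad = None
nn.AvgPool2d(2)(patch).sum().backward()
avg_grad = patch.grad.clone()
print(max_grad, avg_grad)
```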

Following the algorithm described at the beginning of this notebook, we need access to the outputs of some selected intermediate layers.

To access the outputs of the layers of the PyTorch VGG-19 network, we register a hook on each layer we need. Hooks are functions that can be attached to any module and are called each time the module is used. A hook can run before the forward pass, after the forward pass, or when gradients are computed in the backward pass. We define a function `save_output` that builds a hook triggered after the forward pass of each selected layer of the `features` module.

The outputs of the layers are stored in a dictionary. Its keys are the layer indices and its values are the corresponding output tensors.

We must also decide which layers take part in the optimization. We select them by index: [1, 3, 6, 8, 11, 13, 15, 24, 26, 31].

In [None]:
# Initialize outputs dict
outputs = {}

# Hook definition
def save_output(name):
    
    # The hook signature
    def hook(module, module_in, module_out):
        outputs[name] = module_out
    return hook

# Define layers
layers = [1,3, 6, 8, 11, 13, 15, 24, 26, 31]

# Register hook on each layer with index on array "layers"
for layer in layers:
    handle = cnn[layer].register_forward_hook(save_output(layer))
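Here is the same hook mechanism on a small hypothetical network, as a sketch. Unlike the cell above, it keeps the handles returned by `register_forward_hook`, which makes it possible to remove the hooks later. The dictionary and hook factory are renamed (`toy_outputs`, `save_toy_output`) so they do not interfere with the notebook's `outputs`.

```python
import torch
import torch.nn as nn

# A small stand-in network (hypothetical, not VGG-19)
net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.CELU(),
    nn.Conv2d(4, 2, kernel_size=3, padding=1), nn.CELU(),
)

toy_outputs = {}

def save_toy_output(name):
    def hook(module, module_in, module_out):
        toy_outputs[name] = module_out  # called right after module.forward
    return hook

# Keep the handles so that the hooks can be removed later
handles = [net[i].register_forward_hook(save_toy_output(i)) for i in (1, 3)]
net(torch.randn(1, 3, 8, 8))
print({k: tuple(v.shape) for k, v in toy_outputs.items()})  # {1: (1, 4, 8, 8), 3: (1, 2, 8, 8)}

for h in handles:
    h.remove()  # later forward passes no longer fill toy_outputs
```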

The function $f$ corresponds to the spatial means of the selected VGG-19 feature maps (before multiplication by $\beta = 128$), computed by the function `mean_Spatial` below.

In [None]:
# for computing the spatial average of the feature responses of each selected layer
def mean_Spatial (input: torch.Tensor):  
    mean_input = torch.mean(input.squeeze(), axis=(1,2))  
    return mean_input
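Here is a small sketch of $f$ built from spatial means, on hypothetical feature maps of two layers. It applies $\beta = 128$ to one vector per layer, matching the notebook's list representation of $\theta$. The helper `spatial_mean` mirrors `mean_Spatial`, using `dim=` instead of `axis=`.

```python
import torch

def spatial_mean(feat: torch.Tensor):
    # same computation as mean_Spatial: one average over height and width per channel
    return torch.mean(feat.squeeze(), dim=(1, 2))

beta = 128  # scaling factor from the "In Practice" list

# Hypothetical feature maps of two layers for a single image (batch size 1)
feats = [torch.randn(1, 4, 16, 16), torch.randn(1, 8, 8, 8)]

# f(x) as one beta-scaled vector of channel means per layer
f_list = [beta * spatial_mean(a) for a in feats]
print([v.shape for v in f_list])  # [torch.Size([4]), torch.Size([8])]

# On a map that is constant within each channel, the spatial mean recovers the constants
const = torch.arange(4.0).view(1, 4, 1, 1).expand(1, 4, 5, 5)
print(spatial_mean(const))  # tensor([0., 1., 2., 3.])
```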

## Initialization



We use the ADSN (asymptotic discrete spot noise) model to initialize the synthesis.

**ADSN initialization:** 

Let $h\in\mathbb{R}^{M\times N\times 3}$ be an image, $m = (m_r, m_g, m_b)$ be the mean color of $h$, and $X$ be a Gaussian white noise image.
The random image
$$
Y = m + \frac{1}{\sqrt{MN}}
\begin{pmatrix}
\left( h_r - m_r \right) \ast X\\
\left( h_g - m_g \right) \ast X\\
\left( h_b - m_b \right) \ast X
\end{pmatrix},~~~X\in\mathbb{R}^{M\times N}~\text{a Gaussian white noise},
$$
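is the ADSN associated with $h$. Here $\ast$ denotes the periodic (circular) convolution on the $M\times N$ grid, computed in practice with the FFT.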
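The ADSN formula relies on periodic convolution, which the `adsn` code below evaluates with the FFT. This minimal NumPy sketch works on one channel of a tiny random image. It checks two things:

* the FFT product matches the direct periodic sum;
* the ADSN keeps the mean color, because $h - m$ has zero mean and so does its convolution with $X$.

The helper names are illustrative.

```python
import numpy as np

def circular_conv_direct(a, b):
    """Periodic convolution (a * b)[i, j] = sum_{k,l} a[k, l] b[(i-k) mod M, (j-l) mod N]."""
    M, N = a.shape
    out = np.zeros((M, N))
    for i in range(M):
        for j in range(N):
            for k in range(M):
                for l in range(N):
                    out[i, j] += a[k, l] * b[(i - k) % M, (j - l) % N]
    return out

def circular_conv_fft(a, b):
    """Same convolution, computed as a pointwise product in the Fourier domain."""
    return np.real(np.fft.ifft2(np.fft.fft2(a) * np.fft.fft2(b)))

def adsn_channel(h, X):
    """ADSN for one color channel h, given a white noise image X of the same size."""
    M, N = h.shape
    m = h.mean()
    return m + circular_conv_fft(h - m, X) / np.sqrt(M * N)

rng = np.random.default_rng(0)
h = rng.random((4, 5))
X = rng.standard_normal((4, 5))
print(np.allclose(circular_conv_direct(h, X), circular_conv_fft(h, X)))  # True
Y = adsn_channel(h, X)
print(np.isclose(Y.mean(), h.mean()))  # True
```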

In [None]:
# Compute ADSN initialization. 
def adsn(input):
    # input is supposed to be a tensor of size 1 x c x h x w
    tnsr = input.squeeze(0)
    c, h, w = tnsr.size() 
    m = torch.mean(tnsr, axis=(1,2))
    X = torch.randn(h, w).to(device)
    Y = torch.empty_like(tnsr)
    sqrtHW = np.sqrt(h*w)

    for i in range(c):
        tnsrnorm = (tnsr[i,:,:]-m[i])/sqrtHW
        Y[i,:,:] = torch.real(torch.fft.ifft2(torch.fft.fft2(tnsrnorm) * torch.fft.fft2(X))) + m[i]

    return Y.unsqueeze(0)

In [None]:
######################################################
### This section allows you to change target image ###
######################################################
# Select input image: 
#  ["bois.png", "briques.png", "mur.png", "tissu.png", 
#   "nuages.png","pebbles.jpg","wall1003.png", "osier12.png",
#   "paille17c2.png","bark.png","coffee.png","flower.png",
#   "rock.png","sweet.png"]
input_image_name = "coffee.png"#"sweet.png"#"bark.png"##"rock.png"#"paille17c2.png"
img_size = 256

# Prepare texture data
target = prep_img(input_image_name, img_size).to(device)
######################################################

# set seed to reproduce examples
torch.manual_seed(123)

#init image with adsn from target image (normalized)
x = adsn(target)

# print images
display(to_pil(torch.cat((target, x), axis=3)))
plt.pause(0.05)

## Running Texture Synthesis

Finally, we run the texture synthesis. At each iteration, we update the image $x$ and then update the weights $\theta$.

**Exercise:** 
1. Fill in the `#TODO` segments of the code to perform texture synthesis with the SOUL algorithm. The formulas are recalled below.
2. What is the dimension $p$ here? 
3. What happens when $\gamma$ is increased?
4. Is the ADSN initialization important?


**SOUL algorithm**

* Initialization: $\theta_0 = 0$; $X_0^0 \in \mathbb{R}^d$ an ADSN realization
* For $n = 0, \ldots, N-1$,

  * $m_n = 1$ step of Langevin diffusion: for $k=0,\ldots,m_n-1$,

  $$
  X_{k+1}^n = X_k^n - \gamma_{n+1} \nabla_x V(X_k^n,\theta_n) + \sqrt{2\gamma_{n+1}} Z_{k+1}^n
  $$
  with $Z_{k+1}^n \sim \mathcal{N}(0,I)$

  * Update $\theta$ with the Langevin intermediate states:

  $$
  \theta_{n+1} = \mathsf{Proj}_{\Theta}\left( \theta_n + \delta_{n+1} \left( \frac{1}{m_n} \sum_{k=1}^{m_n} f(X_k^n) - f(x_0) \right) \right)
  $$

  * Set warm start for next step: $X_0^{n+1} = X_{m_n}^n$


**In Practice:**
* The initialization $X_0^0$ is an ADSN realization (see above).
* $f:\mathbb{R}^d \to \mathbb{R}^p$ is the spatial average of the feature responses of each selected layer **multiplied by $\beta = 128$**.
* We do not use projections.
* $\epsilon = 0.1$
* $\gamma$ and $\delta$ are fixed (see value in code below).
* $m = 1$, so we do one update of $x$, one update of $\theta$, and so on. 
* Use the layers [1, 3, 6, 8, 11, 13, 15, 24, 26, 31] for optimization.
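One way to obtain $\nabla_x V$ is to build the scalar $V(x,\theta)$ and let autograd differentiate it. The sketch below does this on a hypothetical toy feature map `f_toy`, which takes channel means of $x$ and of $x^2$ and is not the VGG-based $f$. It then checks the result against the closed form $\nabla_x V = Df(x)^\top\theta + \epsilon x$. The variable names are chosen so they do not overwrite the notebook's `x` and `theta`.

```python
import torch

torch.manual_seed(0)

def f_toy(img):
    # Hypothetical smooth feature map (NOT the VGG features):
    # per-channel spatial means of img and of img**2, so p = 2 * channels
    return torch.cat([img.mean(dim=(1, 2)), (img ** 2).mean(dim=(1, 2))])

eps = 0.1
x0_toy = torch.rand(3, 8, 8)                     # plays the role of the target x_0
x_toy = torch.rand(3, 8, 8, requires_grad=True)  # current Langevin state
theta_toy = torch.randn(6)

# V(x, theta) = theta . (f(x) - f(x0)) + eps/2 ||x||^2; autograd then gives grad_x V
V = theta_toy @ (f_toy(x_toy) - f_toy(x0_toy)) + 0.5 * eps * (x_toy ** 2).sum()
V.backward()

# Closed form for this toy f: Df(x)^T theta + eps * x
HW = 8 * 8
xd = x_toy.detach()
expected = (theta_toy[:3].view(3, 1, 1) + 2 * theta_toy[3:].view(3, 1, 1) * xd) / HW + eps * xd
print(torch.allclose(x_toy.grad, expected, atol=1e-5))  # True
```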

In [None]:
n_iters = 3000
log_every = n_iters//10

# step sizes and regularization weight
# (note: 10e-1 == 1.0 and 2*10e-6 == 2e-5)
delta = 10e-1
gamma = 2*10e-6
epsilon = 0.1

# Compute the spatial means of the target activations once.
cnn(target)
meansTargetOutputs = [mean_Spatial(outputs[key]) for key in layers] 

# initialize weights (Theta)
theta = [torch.zeros_like(meansTargetOutputs[i]) for i in range(len(meansTargetOutputs))]

# Initialize list of intermediary images
xpil_list = []

# Forward pass using x. Get activations of selected layers for image x (outputs).
x.requires_grad=True
cnn(x)
x_outputs = [outputs[key] for key in layers]

for iter in range(n_iters):
    
    if x.grad is not None:
        x.grad.zero_()
  
    # Compute V and its gradient with respect to x:

    # TODO

    # update image
    with torch.no_grad():

        # TODO
        pass

    # Forward pass using x. Get activations of selected layers for image x (outputs).
    cnn(x)
    x_outputs = [outputs[key] for key in layers] 

    # update weights thetas:
    with torch.no_grad():
        for i in range(len(layers)):

            # TODO
            pass

    # Display results: print the iteration number and show the image
    if (iter==0 or iter % log_every == log_every-1):
        print('Iteration: ', iter)
        display(to_pil(torch.cat((target, x), axis=3)))
        # Store for comparison:
        xpil_list.append(to_pil(x.clone().detach()))

In [None]:
from google.colab import widgets

def compare_images(imgs):
    labels = ['image ' + str(i) for i in range(len(imgs))]
    tb = widgets.TabBar(labels, location='top')
    for i, img in enumerate(imgs):
        with tb.output_to(i, select=(i == 0)):
            display(img)


compare_images(xpil_list)
