**First order statistical mixing**

In [1]:
from utils.image_treatment import preprocess_image
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from wasserstein.texture_mixing import compute_steerable_pyramid_coeffs
from wasserstein.basic_wasserstein import compute_sliced_wass_barycenter


In [131]:
image_path1 = 'data/elephant.jpg'
image_path2 = 'data/gateau.png'
image1 = np.array(preprocess_image(image_path1, new_size = (200, 200)))
image2 = np.array(preprocess_image(image_path2, new_size = (70, 70)))

In [132]:
import pyrtools as pt
def compute_steerable_pyramid_coeffs(image, num_scales=3, num_orientations=4):
    """
    Compute steerable pyramid coefficients with specified orientations using pyrtools.
    
    Parameters:
    - image: 2D numpy array, input grayscale image.
    - num_scales: int, number of scales.
    - num_orientations: int, number of orientations.

    Returns:
    - coeffs: Dictionary of coefficients organized by scale and orientation.
    """
    # Initialize the steerable pyramid
    pyramid = pt.pyramids.SteerablePyramidFreq(image, height=num_scales, order=num_orientations-1)

    return pyramid.pyr_coeffs

In [133]:
def compute_3D_wavelets_coeffs(image, num_scales=4, num_orientations=4):
    """
    Compute wavelets coefficients (highpass, bandpass, low-residuals) for the 3 channels (R,G,B) of an image
    
    Parameters:
    - image: 2D numpy array, input grayscale image.
    - num_scales: int, number of scales.
    - num_orientations: int, number of orientations.

    Returns:
    - wavelets_coeffs: Dictionary of coefficients organized by channel (R,G,B) and then by bandpass (highpass, bandpass -scale and orientation- and low residual).
    """
    wavelets_coeffs = {}
    rgb = ['R','G','B']
    for channel in range(3):
        wavelets_coeffs[rgb[channel]] = compute_steerable_pyramid_coeffs(image[:, :, channel], num_scales=num_scales, num_orientations=num_orientations)
    return(wavelets_coeffs)

In [135]:
def compute_wavelet_coeffs_barycenter(textures, num_scales=4, num_orientations=4):
    """
    Compute the barycenter of wavelets coefficients --> Y^l (see page 9 in paper)
    
    Parameters:
    - textures: 3D numpy array, input RGB image.
    - num_scales: int, number of scales.
    - num_orientations: int, number of orientations.

    Returns:
    - bar_wavelet_coeffs_RGB: Dictionary of barycenters of wavelets coefficients by channel (R,G,B) and then by highpass/bandpass/lowresidual
    """

    bar_wavelet_coeffs = {}
    bar_wavelet_coeffs_RGB = {}
    RGB = ['R','G','B']
    wavelets_coeffs = [compute_3D_wavelets_coeffs(image) for image in textures]

    for rgb in RGB:
        for k in wavelets_coeffs[0][rgb].keys():
                distributions = [w[rgb][k].reshape(-1,1) for w in wavelets_coeffs] #reshape --> flattens the image to compute the barycenter
                n = int(np.sqrt(distributions[0].shape[0]))
                bar_wavelet_coeffs[k] = (compute_sliced_wass_barycenter(distributions, rho = None)).reshape(n, n)
        bar_wavelet_coeffs_RGB[rgb] = bar_wavelet_coeffs
    return(bar_wavelet_coeffs_RGB)

In [136]:
textures = [image1, image2]
bar_wavelet_coeffs_RGB = compute_wavelet_coeffs_barycenter(textures, num_scales=4, num_orientations=4)

{'R': {'residual_highpass': array([[-1.14479059, -0.37007428,  0.39465748, ...,  6.00513359,
          -0.63553682,  0.13611842],
         [-0.2801428 , -0.85469131,  1.17545808, ..., -0.13249826,
          -0.39719129, -0.23343904],
         [-0.45518237,  0.43088376, -0.11677042, ...,  0.58350216,
          -0.61393761, -0.65791012],
         ...,
         [ 0.56779313,  0.20122172,  4.74828202, ..., -0.28108558,
           2.60629657, -0.03016736],
         [-1.04752916, -5.43705959, -0.97245676, ...,  0.01720182,
          -1.32783668, -0.73037246],
         [-0.33240537, -2.50300734, -0.15890916, ...,  0.0325533 ,
          -0.725984  , -0.40818318]]),
  (0,
   0): array([[-0.5174053 , -6.5924972 ,  0.62550412, ...,  0.0908223 ,
          -0.86508505, -1.29834396],
         [ 2.24425856, -4.30669182,  0.35940535, ...,  0.74088539,
           1.4323344 , -0.05480647],
         [-0.09834923,  1.82413755,  0.34454769, ..., -2.71581447,
           1.82970852, -0.60624482],
         ..

In [None]:
bar_wavelet_coeffs_RGB
#il ne manque plus que le barycentre "classique" i.e. sur les pixels
#next step --> la projection (13) du papier