In [12]:
from utils.image_treatment import preprocess_image
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from wasserstein.basic_wasserstein import compute_sliced_wass_barycenter
import pyrtools as pt
from tqdm import tqdm 
import pickle

In [7]:
image_path1 = 'data/elephant.jpg'
image_path2 = 'data/gateau.png'

In [8]:
image1 = preprocess_image(image_path1, new_size = (75,75))
image2 = preprocess_image(image_path2, new_size = (75,75))
textures = [image1, image2]

In [189]:
def compute_steerable_pyramid_coeffs(image, num_scales=3, num_orientations=4):
    """
    Compute steerable pyramid coefficients with specified orientations using pyrtools.
    
    Parameters:
    - image: 2D numpy array, input grayscale image.
    - num_scales: int, number of scales.
    - num_orientations: int, number of orientations.

    Returns:
    - coeffs: Dictionary of coefficients organized by scale and orientation.
    """
    # Initialize the steerable pyramid
    pyramid = pt.pyramids.SteerablePyramidFreq(image, height=num_scales, order=num_orientations-1)

    return pyramid.pyr_coeffs, pyramid

In [190]:
coef, pyr = compute_steerable_pyramid_coeffs(image1[:,:,1], num_scales=3, num_orientations=4)

In [10]:
def compute_3D_wavelets_coeffs(image, num_scales=4, num_orientations=4):
    """
    Compute wavelets coefficients (highpass, bandpass, low-residuals) for the 3 channels (R,G,B) of an image
    
    Parameters:
    - image: 2D numpy array, input grayscale image.
    - num_scales: int, number of scales.
    - num_orientations: int, number of orientations.

    Returns:
    - wavelets_coeffs: Dictionary of coefficients organized by channel (R,G,B) and then by bandpass (highpass, bandpass -scale and orientation- and low residual).
    """
    wavelets_coeffs = {}
    rgb = ['R','G','B']
    for channel in range(3):
        wavelets_coeffs[rgb[channel]] = compute_steerable_pyramid_coeffs(image[:, :, channel], num_scales=num_scales, num_orientations=num_orientations)
    return(wavelets_coeffs)

In [11]:
def compute_wavelet_coeffs_barycenter(textures, num_scales=4, num_orientations=4):
    """
    Compute the barycenter of wavelets coefficients --> Y^l (see page 9 in paper)
    
    Parameters:
    - textures: 3D numpy array, input RGB image.
    - num_scales: int, number of scales.
    - num_orientations: int, number of orientations.

    Returns:
    - bar_wavelet_coeffs_RGB: Dictionary of barycenters of wavelets coefficients by channel (R,G,B) and then by highpass/bandpass/lowresidual
    """

    bar_wavelet_coeffs = {}
    bar_wavelet_coeffs_RGB = {}
    RGB = ['R','G','B']
    wavelets_coeffs = [compute_3D_wavelets_coeffs(image) for image in textures]

    for rgb in RGB:
        for k in wavelets_coeffs[0][rgb].keys():
                distributions = [w[rgb][k].reshape(-1,1) for w in wavelets_coeffs] #reshape --> flattens the image to compute the barycenter
                n = int(np.sqrt(distributions[0].shape[0]))
                bar_wavelet_coeffs[k] = (compute_sliced_wass_barycenter(distributions, rho = None)).reshape(n, n)
        bar_wavelet_coeffs_RGB[rgb] = bar_wavelet_coeffs
    return(bar_wavelet_coeffs_RGB)

In [None]:
Y_l = compute_wavelet_coeffs_barycenter(textures, num_scales=4, num_orientations=4)

In [13]:
with open('Yl.pkl', 'rb') as file:
        # Loading the data from the pickle file
        Y_l = pickle.load(file)

In [17]:
def initialize_random_image(size=(256, 256), channels=3):
    """Initialize a random white noise image f^(0)."""
    return np.random.rand(*size, channels)

In [137]:
def projection(X0, Y):
    """
    X0 : image a projeter 
    Y : ce sur quoi on veut projeter 
    """
    Y_distrib = [Y]
    proj = compute_sliced_wass_barycenter(Y_distrib, rho = None, lr = 1e3, k = 200,  nb_iter_max = 50, xbinit = X0)
    return(proj)

In [67]:
def flatten_dict(d, parent_key='', sep='.'):
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

In [82]:
def from_dict_to_RGB_array(dico):

    output_dict = {}

    # Extract unique filter types by parsing keys
    filters = set(key.split('.')[1] for key in dico.keys())

    # Iterate through each unique filter
    for filter_type in filters:
        # Collect the R, G, B arrays for the current filter
        r_key = f'R.{filter_type}'
        g_key = f'G.{filter_type}'
        b_key = f'B.{filter_type}'
        
        if r_key in dico and g_key in dico and b_key in dico:
            # Stack R, G, B arrays along a new third axis to form a (P, Q, 3) array
            rgb_array = np.stack([dico[r_key], dico[g_key], dico[b_key]], axis=-1)
            
            # Add to output dictionary
            output_dict[filter_type] = rgb_array
    return(output_dict)


In [92]:
f_0 = initialize_random_image(size=(75, 75), channels=3)
f_0_pyr = compute_3D_wavelets_coeffs(f_0, num_scales=4, num_orientations=4)
f_0_pyr = flatten_dict(f_0_pyr)
f_0_pyr = from_dict_to_RGB_array(f_0_pyr)
#Y_l = from_dict_to_RGB_array(Y_l)

In [141]:
c_l_n = {}
for key in Y_l.keys():
    Y_l_filter = Y_l[key].reshape(-1,3)
    f_0_filter = f_0_pyr[key].reshape(-1,3)
    c_l_n[key] = projection(f_0_filter, Y_l_filter)

In [172]:
import ast

def from_str_to_tuple_dict(data_dict):

    # New dictionary to hold the modified keys
    modified_dict = {}

    # Iterate over the original dictionary
    for key, value in data_dict.items():
        # Try to convert keys that look like tuples into actual tuple types
        try:
            # Use ast.literal_eval to safely evaluate keys that are tuples as strings
            evaluated_key = ast.literal_eval(key)
            if isinstance(evaluated_key, tuple):
                modified_dict[evaluated_key] = value
            else:
                modified_dict[key] = value
        except (ValueError, SyntaxError):
            # If the key isn't a tuple-like string, keep it as-is
            modified_dict[key] = value
    return(modified_dict)

In [174]:
c_l_n = from_str_to_tuple_dict(c_l_n)

In [185]:
dict_r = {}
dict_g = {}
dict_b = {}

for key in c_l_n.keys():
    size = int(np.sqrt(c_l_n[key].shape[0]))
    dict_r[key] = c_l_n[key][:,0].reshape(size, size)
    dict_g[key] = c_l_n[key][:,1].reshape(size, size)
    dict_b[key] = c_l_n[key][:,2].reshape(size, size)


In [195]:
dict_r[(0,0)].shape

(75, 75)

In [197]:
pyr.pyr_coeffs = dict_r
pyr.pyr_coeffs

{(3,
  1): array([[  14.19230186,   46.70405665,   71.27683316,  -57.35268989,
           13.46408791,   49.98501312,  -33.37068928, -105.67072642,
          -43.23587344,   79.46617974],
        [  21.58992101,   12.80287062,  -21.16833046,  -10.77452295,
           86.86817636,   85.20291108,   48.99416248,   69.34879306,
           68.54319429,   38.92192985],
        [  16.55827236,  -15.51027766,   34.31417205,   19.0407221 ,
          112.01523795,   82.74732316,   42.21511011,  -75.67416764,
         -125.65902971,   27.45722446],
        [ -33.24243271,  -48.17982245,   23.71950311,   64.00167771,
          147.91754396,  -15.7987565 , -134.57589892, -221.12465117,
          -79.36403197,  118.51880892],
        [ -39.74454486,   -3.14199841,   13.91118492,   67.51149057,
          -38.77515347, -187.3571599 , -147.22091793,  -19.78122738,
          156.9788858 ,   77.81293078],
        [  14.42448848,  -15.70466101,   -3.47550313,  -68.74937402,
         -175.63200416, -106.78

In [139]:
def compute_sliced_wass_barycenter(distributions, rho = None, lr = 1e3, k = 200,  nb_iter_max = 50, xbinit = None):

    device = "cuda" if torch.cuda.is_available() else "cpu"

    x_torch = [torch.tensor(x).to(device=device) for x in distributions]

    if rho is None: 
        n = len(distributions)
        rho = n*[1/n]
    
    if xbinit is None:
        #xbinit = np.random.randn(500, 2) * 10 + 16 #initialization
        xbinit = np.random.normal(0., 1., distributions[0].shape)
    xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True)


    x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], xbary_torch.shape[1]))

    loss_iter = []

    # generator for random permutations
    gen = torch.Generator(device=device)
    gen.manual_seed(42)


    for i in range(nb_iter_max):

        loss = 0
        for i, x in enumerate(x_torch):
            loss += rho[i] * ot.sliced_wasserstein_distance(xbary_torch, x, n_projections=50, seed=gen)
        loss_iter.append(loss.clone().detach().cpu().numpy())
        loss.backward()

        # performs a step of projected gradient descent
        with torch.no_grad():
            grad = xbary_torch.grad
            xbary_torch -= grad * lr  # / (1 + i / 5e1)  # step
            xbary_torch.grad.zero_()
            x_all[i, :, :] = xbary_torch.clone().detach().cpu().numpy()

    xb = xbary_torch.clone().detach().cpu().numpy()
    return(xb)