In [None]:
%load_ext autoreload
%autoreload 2

# **Learning Gradients of Convex Functions with Monotone Gradient Networks**

Thomas Gravier, Emilio Picard

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import os, sys

In [None]:
SEED = 7

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed every RNG the notebook draws from so Restart & Run All is reproducible.
# torch.manual_seed seeds the CPU generator *and* all CUDA devices; the CPU
# generator is the one used by torch.rand / torch.randn / random_split below.
# np.random.seed covers the np.random.* calls used for the transport datasets.
torch.manual_seed(SEED)
np.random.seed(SEED)

### **Toy Example of the paper**
The purpose here is to approximate the gradient of a convex function $f$.

Let $f$ be $f(x) = x_1^4 + \frac{x_2}{2} + \frac{x_1x_2}{2} + \frac{3x_2^2}{2} - \frac{x_2^3}{3}$, where $x = [x_1, x_2]^T$.

We want to approximate $\nabla f(x) = [4x_1^3 + \frac{x_2}{2}, \frac{1}{2}+ \frac{x_1}{2} + 3x_2 - x_2^2]$ with several methods of the paper.

In [None]:
def f(x1, x2):
    """Toy objective f(x) = x1^4 + x2/2 + x1*x2/2 + 3*x2^2/2 - x2^3/3 (elementwise).

    The second term is ``x2 / 2`` (not ``x2**2 / 2``), so ``f`` is the
    antiderivative of ``grad_f`` below and matches the formula in the markdown.
    """
    return x1**4 + x2 / 2 + (x1 * x2) / 2 + 3 * (x2**2) / 2 - (x2**3) / 3


def grad_f(z):
    """Analytic gradient of ``f``.

    Args:
        z: tensor of shape (N, 2) whose columns are x1 and x2.

    Returns:
        Tensor of shape (N, 2) whose rows are [df/dx1, df/dx2].
    """
    x1 = z[:, 0]
    x2 = z[:, 1]
    return torch.stack([4 * x1**3 + x2 / 2, 1 / 2 + x1 / 2 + 3 * x2 - x2**2], dim=1)

In [None]:
# WHAT WE WANT TO APPROXIMATE
# Evaluate f and its analytic gradient on a 15x15 grid over the unit square.

x = torch.linspace(0, 1, 15)
y = torch.linspace(0, 1, 15)

# indexing='ij' (X varies along dim 0) is torch's historical default; passing it
# explicitly silences the meshgrid deprecation warning. Flattened grid values are
# therefore in (x, y) order: view them as (len(x), len(y)) and transpose before
# contourf, which expects Z[y_index, x_index].
X, Y = torch.meshgrid(x, y, indexing='ij')
space = torch.cat([torch.reshape(X, (-1, 1)), torch.reshape(Y, (-1, 1))], 1)
grad = grad_f(space)

cmap = 'plasma'
levels = 30

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=300)

f_grid = f(space[:, 0], space[:, 1]).view(x.numel(), y.numel()).T
contour = axs[0].contourf(x, y, f_grid, levels=levels, cmap=cmap)
axs[0].set_title('f(x)')
axs[0].set_xticks([0, 0.5, 1.0])
axs[0].set_yticks([0, 0.5, 1.0])
axs[0].tick_params(axis='both', length=20, which='major')
cbar = fig.colorbar(contour, ax=axs[0])
cbar.set_label('Function Value')

quiver = axs[1].quiver(space[:, 0], space[:, 1], grad[:, 0], grad[:, 1], grad.norm(dim=-1), cmap=cmap)
axs[1].set_title('grad f(x)')
axs[1].set_xticks([0, 0.5, 1.0])
axs[1].set_yticks([0, 0.5, 1.0])
axs[1].tick_params(axis='both', length=20, which='major')
cbar = fig.colorbar(quiver, ax=axs[1])
cbar.set_label('Gradient Norm')

plt.tight_layout()
plt.show()
plt.close(fig)

In [None]:
from torch.utils.data import TensorDataset, random_split, DataLoader
from models import CascadeGradNet, ModularGradNet

# Training data: one million points drawn uniformly from the unit square,
# labelled with the analytic gradient grad_f and split 80/20 into train/val.
sampled_points = torch.rand((1000000, 2))
DATASET = TensorDataset(sampled_points, grad_f(sampled_points))

split = int(0.8 * len(DATASET))
TRAIN_SET, VAL_SET = random_split(DATASET, [split, len(DATASET) - split])
TRAIN_LOADER = DataLoader(TRAIN_SET, batch_size=5000, shuffle=True)
VAL_LOADER = DataLoader(VAL_SET, batch_size=5000, shuffle=False)

# Architecture hyper-parameters shared by both toy-example networks.
NUM_LAYERS = 3
NUM_MODULES = 3
IN_DIM = 2
EMBED_DIM = 8


# Zero-argument activation factory. The misspelled name is kept because later
# cells refer to it.
def ACTIVATTION():
    return nn.Tanh()


# CascadeGradNet
model1 = CascadeGradNet(num_layers=NUM_LAYERS, in_dim=IN_DIM, embed_dim=EMBED_DIM, activation=ACTIVATTION)
model1 = model1.to(DEVICE)
print(f"number of learnable parameters: {sum(param.numel() for param in model1.parameters() if param.requires_grad)}")

Let's train the gradient networks on the paper's toy example:

In [None]:
from tqdm import tqdm

def train_model(model, train_loader, val_loader, grad_f, model_name, num_epochs=None):
    """Fit ``model`` to regress target gradients with an MSE loss.

    The model is moved to the notebook-global ``DEVICE``.

    Args:
        model: network mapping a batch of points to predicted gradients.
        train_loader: yields ``(x, grads)`` training batches.
        val_loader: yields ``(x, grads)`` validation batches.
        grad_f: unused (the targets come from the loaders); kept so existing
            calls keep working.
        model_name: checkpoint stem. The weights with the lowest mean
            validation loss are saved to ``f"{model_name}.pt"``.
        num_epochs: number of epochs. Defaults to the notebook-level
            ``NUM_EPOCHS`` when omitted.

    Returns:
        (train_losses, val_losses): per-epoch mean batch losses.
    """
    if num_epochs is None:
        num_epochs = NUM_EPOCHS
    optimizer = optim.Adam(model.parameters(), lr=3e-4)
    criterion = nn.MSELoss()

    model.to(DEVICE)
    train_losses, val_losses = [], []
    # +inf guarantees the first epoch writes a checkpoint. A fixed threshold of 1
    # (previously compared against a *summed* loss) could mean none was ever saved.
    best_val_loss = float('inf')
    for epoch in tqdm(range(num_epochs), desc="Training"):
        model.train()
        train_loss = 0.0
        for x, grads in train_loader:
            x, grads = x.to(DEVICE), grads.to(DEVICE)
            outputs = model(x)
            loss = criterion(outputs, grads)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        mean_train = train_loss / len(train_loader)
        train_losses.append(mean_train)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, grads in val_loader:
                x, grads = x.to(DEVICE), grads.to(DEVICE)
                outputs = model(x)
                val_loss += criterion(outputs, grads).item()
        mean_val = val_loss / len(val_loader)
        val_losses.append(mean_val)

        if mean_val < best_val_loss:
            best_val_loss = mean_val
            torch.save(model.state_dict(), f"{model_name}.pt")

        # Parenthesised: `epoch+1 % 100` parses as `epoch + (1 % 100)` and never fired.
        if (epoch + 1) % 100 == 0:
            print("Epoch: ", epoch, "Train Loss: ", mean_train, "Val Loss: ", mean_val)

    return train_losses, val_losses

In [None]:
# Fit CascadeGradNet to grad_f; the best weights go to CascadeGradNet.pt.
NUM_EPOCHS = 100
print("Training CascadeGradNet")
print("device used: ", DEVICE)
train_losses_C, val_losses_C = train_model(
    model=model1,
    train_loader=TRAIN_LOADER,
    val_loader=VAL_LOADER,
    grad_f=grad_f,
    model_name='CascadeGradNet',
)

In [None]:
def compute_l2_error_map(model, space, grad_f):
    """Per-point squared L2 error between ``model`` and the analytic gradient.

    Puts ``model`` in eval mode.

    Args:
        model: network mapping an (N, d) batch to (N, d) predicted gradients.
        space: (N, d) tensor of evaluation points.
        grad_f: callable returning the true gradients for ``space``.

    Returns:
        CPU tensor of shape (N,) containing ``||model(x) - grad_f(x)||^2`` for
        each point. ``mse_loss(reduction='none')`` is the elementwise squared
        error; it is summed over the gradient dimensions.
    """
    # Evaluate on the model's own device instead of the notebook-global DEVICE.
    # A parameter-free model falls back to the CPU.
    device = next((p.device for p in model.parameters()), torch.device('cpu'))
    model.eval()
    with torch.no_grad():
        true_grads = grad_f(space).to(device)
        predicted_grads = model(space.to(device))
        l2_error = F.mse_loss(predicted_grads, true_grads, reduction='none').sum(dim=1).cpu()
    return l2_error

# Squared gradient error of CascadeGradNet on the evaluation grid.
l2_error_map_C = compute_l2_error_map(model1, space, grad_f)
# Values arrive in meshgrid 'ij' order: view as (len(x), len(y)), then transpose
# so rows index y, as contourf expects.
l2_error_map_C = l2_error_map_C.view(x.numel(), y.numel()).T

fig, ax = plt.subplots(figsize=(6, 3), dpi=300)
contour = ax.contourf(x, y, l2_error_map_C, levels=levels, cmap=cmap)
ax.set_title('CascadeGradNet vs. grad f')
ax.set_xticks([0, 0.5, 1.0])
ax.set_yticks([0, 0.5, 1.0])
ax.tick_params(axis='both', length=10)
cbar = fig.colorbar(contour, ax=ax)
# compute_l2_error_map returns squared errors, so label them as such.
cbar.set_label('Squared L2 Error')

plt.tight_layout()
plt.show()
plt.close(fig)

In [None]:
# Predicted gradient field of CascadeGradNet on the grid. This is inference
# only, so skip building an autograd graph.
model1.eval()
with torch.no_grad():
    predicted_grad = model1(space.to(DEVICE)).cpu()

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=300)

quiver_true = axs[0].quiver(space[:, 0], space[:, 1], grad[:, 0], grad[:, 1], grad.norm(dim=-1), cmap=cmap)
axs[0].set_title('True Gradients')
axs[0].set_xticks([0, 0.5, 1.0])
axs[0].set_yticks([0, 0.5, 1.0])
axs[0].tick_params(axis='both', length=20, which='major')
cbar_true = fig.colorbar(quiver_true, ax=axs[0])
cbar_true.set_label('Gradient Norm')

quiver_pred = axs[1].quiver(
    space[:, 0].cpu().numpy(), space[:, 1].cpu().numpy(),
    predicted_grad[:, 0].numpy(), predicted_grad[:, 1].numpy(),
    predicted_grad.norm(dim=-1).numpy(), cmap=cmap,
)
axs[1].set_title('Predicted Gradients')
axs[1].set_xticks([0, 0.5, 1.0])
axs[1].set_yticks([0, 0.5, 1.0])
axs[1].tick_params(axis='both', length=20, which='major')
cbar_pred = fig.colorbar(quiver_pred, ax=axs[1])
cbar_pred.set_label('Gradient Norm')

plt.tight_layout()
plt.show()
plt.close(fig)

## Toy Example 2: another convex function

$h(x) = \log(\exp(5x_1) + \exp(2x_2))$

$\nabla h(x) = [\frac{5\exp(5x_1)}{\exp(5x_1) + \exp(2x_2)}, \frac{2\exp(2x_2)}{\exp(5x_1) + \exp(2x_2)}]$

In [None]:
def h(x1, x2):
    """Log-sum-exp objective h(x) = log(exp(5 x1) + exp(2 x2)), elementwise.

    torch.logaddexp evaluates this without forming exp(5 x1) explicitly, so large
    inputs no longer overflow to inf. Inputs broadcast as in the naive formula.
    """
    return torch.logaddexp(5 * x1, 2 * x2)


def grad_h(z):
    """Analytic gradient of ``h``.

    Args:
        z: tensor of shape (N, 2) whose columns are x1 and x2.

    Returns:
        Tensor of shape (N, 2): [5 * s, 2 * (1 - s)] with
        s = exp(5x1) / (exp(5x1) + exp(2x2)) = sigmoid(5x1 - 2x2).
        The sigmoid form is stable where the ratio of exponentials gives inf/inf = nan.
    """
    diff = 5 * z[:, 0] - 2 * z[:, 1]
    return torch.stack([5 * torch.sigmoid(diff), 2 * torch.sigmoid(-diff)], dim=1)

In [None]:
# WHAT WE WANT TO APPROXIMATE
# Evaluate h and its analytic gradient on a 15x15 grid over the unit square.

x = torch.linspace(0, 1, 15)
y = torch.linspace(0, 1, 15)

# indexing='ij' (torch's historical default) stated explicitly to silence the
# meshgrid deprecation warning. Grid values are in (x, y) order and are
# transposed before contourf, which expects Z[y_index, x_index].
X, Y = torch.meshgrid(x, y, indexing='ij')
space = torch.cat([torch.reshape(X, (-1, 1)), torch.reshape(Y, (-1, 1))], 1)
grad = grad_h(space)

cmap = 'plasma'
levels = 30

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=300)

h_grid = h(space[:, 0], space[:, 1]).view(x.numel(), y.numel()).T
contour = axs[0].contourf(x, y, h_grid, levels=levels, cmap=cmap)
axs[0].set_title('h(x)')
axs[0].set_xticks([0, 0.5, 1.0])
axs[0].set_yticks([0, 0.5, 1.0])
axs[0].tick_params(axis='both', length=20, which='major')
cbar = fig.colorbar(contour, ax=axs[0])
cbar.set_label('Function Value')

quiver = axs[1].quiver(space[:, 0], space[:, 1], grad[:, 0], grad[:, 1], grad.norm(dim=-1), cmap=cmap)
axs[1].set_title('grad h(x)')
axs[1].set_xticks([0, 0.5, 1.0])
axs[1].set_yticks([0, 0.5, 1.0])
axs[1].tick_params(axis='both', length=20, which='major')
cbar = fig.colorbar(quiver, ax=axs[1])
cbar.set_label('Gradient Norm')

plt.tight_layout()
plt.show()
plt.close(fig)

In [None]:
# ModularGradNet (M-MGN) for toy example 2. It uses the same input size,
# embedding width and activation factory as the CascadeGradNet above.
mmgn_kwargs = dict(
    num_modules=NUM_MODULES,
    in_dim=IN_DIM,
    embed_dim=EMBED_DIM,
    activation=ACTIVATTION,
)
model2 = ModularGradNet(**mmgn_kwargs).to(DEVICE)
print(f"number of learnable parameters: {sum(param.numel() for param in model2.parameters() if param.requires_grad)}")

In [None]:
NUM_EPOCHS = 100

# Toy example 2 must be fitted to grad_h. TRAIN_LOADER / VAL_LOADER carry
# grad_f targets, so build h-specific loaders from the same sampled points.
H_DATASET = TensorDataset(sampled_points, grad_h(sampled_points))
h_split = int(0.8 * len(H_DATASET))
H_TRAIN_SET, H_VAL_SET = random_split(H_DATASET, [h_split, len(H_DATASET) - h_split])
H_TRAIN_LOADER = DataLoader(H_TRAIN_SET, batch_size=5000, shuffle=True)
H_VAL_LOADER = DataLoader(H_VAL_SET, batch_size=5000, shuffle=False)

print("Training ModularGradNet")
print("device used: ", DEVICE)
# Stored separately so the CascadeGradNet curves (train_losses_C / val_losses_C) survive.
train_losses_M, val_losses_M = train_model(
    model2,
    train_loader=H_TRAIN_LOADER,
    val_loader=H_VAL_LOADER,
    grad_f=grad_h,
    model_name='ModularGradNet',
)

In [None]:
# Squared gradient error of ModularGradNet against grad_h on the grid.
l2_error_map_M = compute_l2_error_map(model2, space, grad_h)
# Reshape *this* model's errors. Previously l2_error_map_C was reshaped here,
# so the CascadeGradNet map was shown for ModularGradNet.
l2_error_map_M = l2_error_map_M.view(x.numel(), y.numel()).T

# Plot the error map
fig, ax = plt.subplots(figsize=(6, 3), dpi=300)
contour = ax.contourf(x, y, l2_error_map_M, levels=levels, cmap=cmap)
ax.set_title('ModularGradNet vs. grad h')
ax.set_xticks([0, 0.5, 1.0])
ax.set_yticks([0, 0.5, 1.0])
ax.tick_params(axis='both', length=10)
cbar = fig.colorbar(contour, ax=ax)
# compute_l2_error_map returns squared errors.
cbar.set_label('Squared L2 Error')

plt.tight_layout()
plt.show()
plt.close(fig)

In [None]:
# Predicted gradient field of ModularGradNet, compared against grad_h (`grad`).
# This is inference only, so no autograd graph is needed.
model2.eval()
with torch.no_grad():
    predicted_grad = model2(space.to(DEVICE)).cpu()

fig, axs = plt.subplots(1, 2, figsize=(12, 4), dpi=300)

quiver_true = axs[0].quiver(space[:, 0], space[:, 1], grad[:, 0], grad[:, 1], grad.norm(dim=-1), cmap=cmap)
axs[0].set_title('True Gradients')
axs[0].set_xticks([0, 0.5, 1.0])
axs[0].set_yticks([0, 0.5, 1.0])
axs[0].tick_params(axis='both', length=20, which='major')
cbar_true = fig.colorbar(quiver_true, ax=axs[0])
cbar_true.set_label('Gradient Norm')

quiver_pred = axs[1].quiver(
    space[:, 0].cpu().numpy(), space[:, 1].cpu().numpy(),
    predicted_grad[:, 0].numpy(), predicted_grad[:, 1].numpy(),
    predicted_grad.norm(dim=-1).numpy(), cmap=cmap,
)
axs[1].set_title('Predicted Gradients')
axs[1].set_xticks([0, 0.5, 1.0])
axs[1].set_yticks([0, 0.5, 1.0])
axs[1].tick_params(axis='both', length=20, which='major')
cbar_pred = fig.colorbar(quiver_pred, ax=axs[1])
cbar_pred.set_label('Gradient Norm')

plt.tight_layout()
plt.show()
plt.close(fig)

## Training for distribution transport

In [None]:
def train_model(model, train_loader, val_loader, CRITERION, NUM_EPOCHS, DEVICE, model_name='model'):
    """Fit a transport map by minimising ``CRITERION(model(x), y)`` over batches.

    NOTE: this redefines the earlier ``train_model`` with a different
    signature. Every call below this cell uses this version.

    Args:
        model: transport map to train. It is not moved here; callers place it on DEVICE.
        train_loader: yields ``(x, y)`` source/target training batches.
        val_loader: yields ``(x, y)`` source/target validation batches.
        CRITERION: loss taking ``(model(x), y)``.
        NUM_EPOCHS: number of epochs.
        DEVICE: device each batch is moved to.
        model_name: checkpoint stem. Weights are saved to ``f"{model_name}.pt"``
            whenever the summed validation loss improves.

    Prints the mean losses every 20 epochs and returns None.
    """
    optimizer = optim.Adam(model.parameters(), lr=3e-4)
    train_losses, val_losses = [], []
    best_val_loss = float('inf')

    for epoch in tqdm(range(NUM_EPOCHS), desc=f"Training {model_name}"):
        # --- optimisation pass ---
        model.train()
        epoch_train_total = 0
        for batch_src, batch_tgt in train_loader:
            batch_src = batch_src.to(DEVICE)
            batch_tgt = batch_tgt.to(DEVICE)
            batch_loss = CRITERION(model(batch_src), batch_tgt)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            epoch_train_total += batch_loss.item()
        train_losses.append(epoch_train_total / len(train_loader))

        # --- validation pass (no gradients) ---
        model.eval()
        with torch.no_grad():
            epoch_val_total = sum(
                CRITERION(model(v_src.to(DEVICE)), v_tgt.to(DEVICE)).item()
                for v_src, v_tgt in val_loader
            )
        val_losses.append(epoch_val_total / len(val_loader))

        if epoch_val_total < best_val_loss:
            best_val_loss = epoch_val_total
            torch.save(model.state_dict(), f"{model_name}.pt")

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Train Loss: {train_losses[-1]:.6f} - Val Loss: {val_losses[-1]:.6f}")

## **Transport Distributions: from Gaussian to Banana Shape**

In [None]:
n_samples = 2000
b = 1
c = 0.5

# Target ("banana") samples: bend a standard 2-D Gaussian along
# y = b * (x^2 - c), then rotate the cloud by pi/5.
X_torch = torch.randn(n_samples)
Y_torch = torch.randn(n_samples)
Y_banana_torch = Y_torch + b * (X_torch**2 - c)

theta_torch = torch.tensor(np.pi / 5, dtype=torch.float32, device=X_torch.device)
cos_t, sin_t = torch.cos(theta_torch), torch.sin(theta_torch)
rotation_matrix_torch = torch.stack([
    torch.stack([cos_t, -sin_t]),
    torch.stack([sin_t, cos_t]),
]).to(dtype=torch.float32, device=X_torch.device)

points_torch = torch.stack((X_torch, Y_banana_torch))
rotated_points_torch = rotation_matrix_torch @ points_torch
X_rotated_torch, Y_rotated_torch = rotated_points_torch.cpu().numpy()

# Source samples: isotropic Gaussian centred at (-6, 4).
mu = np.array([-6, 4])
cov = np.array([[1, 0], [0, 1]])
data2 = np.random.multivariate_normal(mu, cov, n_samples)
X_gaussian_torch, Y_gaussian_torch = (
    torch.tensor(column, dtype=torch.float32, device=X_torch.device)
    for column in (data2[:, 0], data2[:, 1])
)

banana_fig, banana_ax = plt.subplots(figsize=(8, 6))
banana_ax.scatter(X_gaussian_torch, Y_gaussian_torch, alpha=0.5, s=10, color='red')
banana_ax.scatter(X_rotated_torch, Y_rotated_torch, alpha=0.5, s=10, color='blue')
banana_ax.set_title("Prior and expected posterior distributions")
banana_ax.set_xlabel("X")
banana_ax.set_ylabel("Y")
banana_ax.axhline(0, color='gray', linestyle='--', linewidth=0.8)
banana_ax.axvline(0, color='gray', linestyle='--', linewidth=0.8)
banana_ax.grid(True, linestyle='--', alpha=0.6)
plt.show(block=True)
plt.close(banana_fig)


In [None]:
tensor_data_gaussian = torch.tensor(data2, dtype=torch.float32)
tensor_data_banana = rotated_points_torch.T

# The first 1600 of the 2000 samples are used for training and the last 400 for
# validation. Rows are paired index-wise (source i <-> target i).
n_train = 1600
train_dataset = TensorDataset(tensor_data_gaussian[:n_train], tensor_data_banana[:n_train])
val_dataset = TensorDataset(tensor_data_gaussian[n_train:], tensor_data_banana[n_train:])
# os.cpu_count() returns None when the count cannot be determined, which made
# min(4, None) raise TypeError.
# NOTE(review): num_workers is not passed to the DataLoaders below.
num_workers = min(4, os.cpu_count() or 1)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)


### CascadeGradNet

In [None]:
from geomloss import SamplesLoss
# Sinkhorn divergence between transported and target batches.
CRITERION = SamplesLoss("sinkhorn", p=2, blur=0.05, backend="tensorized")

NUM_EPOCHS = 1000

NUM_LAYERS = 6
IN_DIM = 2
EMBED_DIM = 16
ACTIVATION = lambda : nn.Tanh()

model1 = CascadeGradNet(
    num_layers=NUM_LAYERS,
    in_dim=IN_DIM,
    embed_dim=EMBED_DIM,
    activation=ACTIVATION
).to(DEVICE)

# The checkpoint name must match what the evaluation cell loads
# ("CascadeGradNetBanana.pt"). 'CascadeGradNet' broke that load and also
# overwrote the toy-example checkpoint.
train_model(model1, train_loader, val_loader, CRITERION, NUM_EPOCHS, DEVICE, model_name='CascadeGradNetBanana')

In [None]:
# map_location lets a checkpoint written on a GPU load on a CPU-only machine.
model1.load_state_dict(torch.load("CascadeGradNetBanana.pt", map_location=DEVICE))
model1.eval()

# Push the first 1000 source samples through the learned map. These indices
# (< 1600) belong to the training split.
source_samples = tensor_data_gaussian[:1000].to(DEVICE)

with torch.no_grad():
    transported_samples = model1(source_samples).cpu().numpy()

source_samples = source_samples.cpu().numpy()
target_samples = tensor_data_banana[:1000].cpu().numpy()

# Draw displacement segments for a random 20% of points to keep the plot readable.
sampling_rate = 0.2
num_points = len(source_samples)
num_sampled = max(1, int(sampling_rate * num_points))

sampled_indices = np.random.choice(num_points, num_sampled, replace=False)

plt.figure(figsize=(8, 6))

for i in sampled_indices:
    plt.plot([source_samples[i, 0], transported_samples[i, 0]],
             [source_samples[i, 1], transported_samples[i, 1]],
             color='gray', alpha=0.5, linewidth=0.8)

plt.scatter(source_samples[:, 0], source_samples[:, 1], alpha=0.6, s=10, color='blue', label='Source (Gaussienne)')
plt.scatter(transported_samples[:, 0], transported_samples[:, 1], alpha=0.6, s=10, color='red', label='Transported')
plt.scatter(target_samples[:, 0], target_samples[:, 1], alpha=0.6, s=10, color='green', label='Target (Banane)')
plt.title("Transport de la Gaussienne vers les points transportés (20% des points)")
plt.xlabel("X")
plt.ylabel("Y")
plt.axhline(0, color='gray', linestyle='--', linewidth=0.8)
plt.axvline(0, color='gray', linestyle='--', linewidth=0.8)
plt.legend()
plt.grid(True, linestyle='--', alpha=0.6)
plt.show()
plt.close()


In [None]:
from scipy.stats import wasserstein_distance

# Mean of the two per-axis 1-D Wasserstein distances. This compares marginals
# only; it is not the full 2-D transport cost.
wd_x, wd_y = (
    wasserstein_distance(transported_samples[:, axis], target_samples[:, axis])
    for axis in (0, 1)
)

wasserstein_dist = (wd_x + wd_y) / 2
wasserstein_dist

### M-MGN

In [None]:
from geomloss import SamplesLoss
# Sinkhorn divergence between transported and target batches.
CRITERION = SamplesLoss("sinkhorn", p=2, blur=0.05, backend="tensorized")

NUM_EPOCHS = 1000

NUM_MODULES = 6
IN_DIM = 2
EMBED_DIM = 16
ACTIVATION = lambda : nn.Tanh()

# Kept in `model1` because the evaluation cell below reads model1.
model1 = ModularGradNet(
    num_modules=NUM_MODULES,  # was the undefined name NUM_BLOCKS (NameError)
    in_dim=IN_DIM,
    embed_dim=EMBED_DIM,
    activation=ACTIVATION
).to(DEVICE)

# Save under the name the evaluation cell loads ("M-GradNetBanana.pt").
train_model(model1, train_loader, val_loader, CRITERION, NUM_EPOCHS, DEVICE, model_name='M-GradNetBanana')

In [None]:
# map_location lets a checkpoint written on a GPU load on a CPU-only machine.
model1.load_state_dict(torch.load("M-GradNetBanana.pt", map_location=DEVICE))
model1.eval()

# Push the first 1000 source samples (part of the training split) through the map.
source_samples = tensor_data_gaussian[:1000].to(DEVICE)

with torch.no_grad():
    transported_samples = model1(source_samples).cpu().numpy()

source_samples = source_samples.cpu().numpy()
target_samples = tensor_data_banana[:1000].cpu().numpy()

# Draw displacement segments for a random 20% of points.
sampling_rate = 0.2
num_points = len(source_samples)
num_sampled = max(1, int(sampling_rate * num_points))

sampled_indices = np.random.choice(num_points, num_sampled, replace=False)

plt.figure(figsize=(8, 5))

for i in sampled_indices:
    plt.plot([source_samples[i, 0], transported_samples[i, 0]],
             [source_samples[i, 1], transported_samples[i, 1]],
             color='gray', alpha=0.5, linewidth=0.8)

plt.scatter(source_samples[:, 0], source_samples[:, 1], alpha=0.6, s=10, color='blue', label='Source (Gaussienne)')
plt.scatter(transported_samples[:, 0], transported_samples[:, 1], alpha=0.6, s=10, color='red', label='Transported')
plt.scatter(target_samples[:, 0], target_samples[:, 1], alpha=0.6, s=10, color='green', label='Target (Banane)')

plt.title("Distribution transport from a Gaussian to a Banana-shaped distribution")
plt.xlabel("X")
plt.ylabel("Y")
plt.axhline(0, color='gray', linestyle='--', linewidth=0.8)
plt.axvline(0, color='gray', linestyle='--', linewidth=0.8)
plt.legend()
plt.grid(True, linestyle='--', alpha=0.6)
plt.savefig("M-GradNet_Transport.pdf",dpi=300)
plt.show()
plt.close()

In [None]:
from scipy.stats import wasserstein_distance

# Mean of the per-axis 1-D Wasserstein distances (marginals only).
wd_x = wasserstein_distance(transported_samples[:, 0], target_samples[:, 0])
wd_y = wasserstein_distance(transported_samples[:, 1], target_samples[:, 1])
# Assign the name that is displayed below. A typo ("asserstein_dist") made the
# cell show the stale CascadeGradNet score from the earlier cell.
wasserstein_dist = (wd_x + wd_y) / 2
wasserstein_dist

## **Transport Distributions: from Gaussian to another Gaussian**

In [None]:
n_samples = 2000

# Source: isotropic Gaussian at (-6, 4). Target: correlated Gaussian at (2, -3).
mu1, cov1 = np.array([-6, 4]), np.array([[1, 0], [0, 1]])
mu2, cov2 = np.array([2, -3]), np.array([[1, 0.5], [0.5, 1]])

data_gaussian1 = np.random.multivariate_normal(mu1, cov1, n_samples)
X_gaussian1, Y_gaussian1 = data_gaussian1[:, 0], data_gaussian1[:, 1]
data_gaussian2 = np.random.multivariate_normal(mu2, cov2, n_samples)
X_gaussian2, Y_gaussian2 = data_gaussian2[:, 0], data_gaussian2[:, 1]

gauss_fig, gauss_ax = plt.subplots(figsize=(8, 6))
for xs, ys, colour, name in (
    (X_gaussian1, Y_gaussian1, 'red', 'Gaussian 1'),
    (X_gaussian2, Y_gaussian2, 'green', 'Gaussian 2'),
):
    gauss_ax.scatter(xs, ys, alpha=0.5, s=10, color=colour, label=name)
gauss_ax.set_title("Prior and expected posterior distributions")
gauss_ax.set_xlabel("X")
gauss_ax.set_ylabel("Y")
gauss_ax.axhline(0, color='gray', linestyle='--', linewidth=0.8)
gauss_ax.axvline(0, color='gray', linestyle='--', linewidth=0.8)
gauss_ax.legend()
gauss_ax.grid(True, linestyle='--', alpha=0.6)
plt.show()
plt.close(gauss_fig)


In [None]:
tensor_data_gaussian1 = torch.tensor(data_gaussian1, dtype=torch.float32)
tensor_data_gaussian2 = torch.tensor(data_gaussian2, dtype=torch.float32)

# The first 1600 samples are used for training and the remaining 400 for
# validation, paired index-wise.
n_train = 1600
train_dataset = TensorDataset(tensor_data_gaussian1[:n_train], tensor_data_gaussian2[:n_train])
val_dataset = TensorDataset(tensor_data_gaussian1[n_train:], tensor_data_gaussian2[n_train:])
# os.cpu_count() may return None, which made min(4, None) raise TypeError.
# NOTE(review): num_workers is not passed to the DataLoaders below.
num_workers = min(4, os.cpu_count() or 1)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)


### CascadeGradNet

In [None]:
from geomloss import SamplesLoss
# Sinkhorn divergence between the transported batch and the target batch.
CRITERION = SamplesLoss("sinkhorn", p=2, blur=0.05, backend="tensorized")

NUM_EPOCHS = 500

NUM_LAYERS, IN_DIM, EMBED_DIM = 3, 2, 8


def ACTIVATION():
    # Zero-argument activation factory passed to the network.
    return nn.Tanh()


model1 = CascadeGradNet(num_layers=NUM_LAYERS, in_dim=IN_DIM, embed_dim=EMBED_DIM, activation=ACTIVATION).to(DEVICE)

# The best weights are written to CascadeGradNetGaussian.pt, which the next cell loads.
train_model(
    model1, train_loader, val_loader, CRITERION, NUM_EPOCHS, DEVICE,
    model_name='CascadeGradNetGaussian',
)

In [None]:
import matplotlib.pyplot as plt

# map_location lets a checkpoint written on a GPU load on a CPU-only machine.
model1.load_state_dict(torch.load("CascadeGradNetGaussian.pt", map_location=DEVICE))
model1.eval()

# Transport the first 1000 source samples (part of the training split).
source_samples = tensor_data_gaussian1[:1000].to(DEVICE)
with torch.no_grad():
    transported_samples = model1(source_samples).cpu().numpy()

source_samples = source_samples.cpu().numpy()
target_samples = tensor_data_gaussian2[:1000].cpu().numpy()

In [None]:
# Displacement segments are drawn for a random 20% subset of the points.
sampling_rate = 0.2
num_points = len(source_samples)
num_sampled = max(1, int(sampling_rate * num_points))
sampled_indices = np.random.choice(num_points, num_sampled, replace=False)

plt.figure(figsize=(8, 5))

for idx in sampled_indices:
    start, end = source_samples[idx], transported_samples[idx]
    plt.plot([start[0], end[0]], [start[1], end[1]], color='gray', alpha=0.5, linewidth=0.8)

for cloud, colour, name in (
    (source_samples, 'blue', 'Source'),
    (transported_samples, 'red', 'Transported'),
    (target_samples, 'green', 'Target'),
):
    plt.scatter(cloud[:, 0], cloud[:, 1], alpha=0.6, s=10, color=colour, label=name)

plt.title("Distribution transport between Gaussians")
plt.xlabel("X")
plt.ylabel("Y")
plt.axhline(0, color='gray', linestyle='--', linewidth=0.8)
plt.axvline(0, color='gray', linestyle='--', linewidth=0.8)
plt.legend()
plt.grid(True, linestyle='--', alpha=0.6)

plt.savefig("transport_gaussian012.pdf", dpi=300)
plt.show()


In [None]:
from scipy.stats import wasserstein_distance

# Average of the per-axis 1-D Wasserstein distances (marginals only).
per_axis = [
    wasserstein_distance(transported_samples[:, axis], target_samples[:, axis])
    for axis in (0, 1)
]
wd_x, wd_y = per_axis

wasserstein_dist = (wd_x + wd_y) / 2
wasserstein_dist

## MNIST Model - Experiments

In [None]:
# FIXME: `model`, `X_data` and `Y_data` are not defined by any cell of this
# notebook. They presumably came from a since-deleted MNIST '1' -> '7' cell, so
# this cell cannot survive Restart & Run All. Fail with an explicit message
# rather than an anonymous NameError.
_missing = [name for name in ("model", "X_data", "Y_data") if name not in globals()]
if _missing:
    raise NameError(
        f"MNIST evaluation needs {', '.join(_missing)}, which no cell in this notebook defines; "
        "re-create the '1' -> '7' transport model and its data first."
    )

# Inference only: no autograd graph is needed.
with torch.no_grad():
    X_transformed = model(X_data).detach().cpu().numpy()
Y_original = Y_data.detach().cpu().numpy()

In [None]:
# Mean over pixels of the 1-D Wasserstein distance between the transformed
# images and the reference images, computed one pixel (column) at a time.
wasserstein_distances = []
for pixel in range(X_transformed.shape[1]):
    wasserstein_distances.append(
        wasserstein_distance(X_transformed[:, pixel], Y_original[:, pixel])
    )
wasserstein_score = np.mean(wasserstein_distances)
print(f"Distance de Wasserstein moyenne entre les distributions des '1' transformés et des '7' : {wasserstein_score:.4f}")

# **Transport each digit to the next one (0→1, 1→2, …, 9→0)**

## Same experiment with M-MGN

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from geomloss import SamplesLoss
from torch.utils.data import random_split, DataLoader, TensorDataset

def init_W(embed_dim, in_dim):
    """Return an (embed_dim, in_dim) weight drawn from N(0, 0.01^2)."""
    return torch.randn(embed_dim, in_dim) * 0.01

def init_b(size):
    """Return a zero-initialised bias of the given size."""
    return torch.zeros(size)

# ----- ModularGradNet model -----
class Module_ModularGN(nn.Module):
    """One M-MGN module computing ``W^T act(softplus(beta) * (W x + b))``.

    With s = softplus(beta) > 0, its Jacobian is
    ``W^T diag(s * act'(s * (W x + b))) W``, which is symmetric. It is positive
    semidefinite whenever ``act`` is nondecreasing (nn.Tanh is used below).
    """
    def __init__(self, in_dim, embed_dim, activation):
        super().__init__()
        self.beta = nn.Parameter(torch.rand(1), requires_grad=True)
        self.W = nn.Parameter(init_W(embed_dim, in_dim), requires_grad=True)
        self.b = nn.Parameter(init_b(embed_dim), requires_grad=True)
        self.act = activation()

    def forward(self, x):
        z = F.linear(x, weight=self.W, bias=self.b)   # embed: W x + b
        z = self.act(z * F.softplus(self.beta))       # softplus keeps the scale positive
        z = F.linear(z, weight=self.W.T)              # project back with W^T
        return z

class ModularGradNet(nn.Module):
    """Positively weighted sum of ``num_modules`` Module_ModularGN blocks plus a bias.

    The module weights are ``softplus(alpha)``, so they are strictly positive.
    A positive combination of PSD-Jacobian terms stays PSD, which is what makes
    the network a monotone gradient map. The raw ``alpha`` parameter, and hence
    the state_dict layout, is unchanged.
    """
    def __init__(self, num_modules, in_dim, embed_dim, activation):
        super().__init__()
        self.num_modules = num_modules
        self.mmgn_modules = nn.ModuleList([Module_ModularGN(in_dim, embed_dim, activation) for _ in range(num_modules)])
        self.alpha = nn.Parameter(torch.randn(num_modules), requires_grad=True)
        self.bias = nn.Parameter(init_b(in_dim), requires_grad=True)

    def forward(self, x):
        # Raw alpha ~ N(0, 1) can be negative, which would turn that module's
        # Jacobian negative semidefinite. softplus keeps every weight > 0.
        weights = F.softplus(self.alpha)
        out = 0
        for weight, module in zip(weights, self.mmgn_modules):
            out = out + weight * module(x)
        return out + self.bias

# ----- Utils -----
def show_comparison_grid(inputs, outputs, title=''):
    """Show ``inputs`` (top row) above ``outputs`` (bottom row) as 28x28 images.

    Args:
        inputs: tensors of flattened 784-pixel images.
        outputs: tensors of flattened 784-pixel images, same length as ``inputs``.
        title: optional figure title.
    """
    n = len(inputs)
    # squeeze=False keeps `axes` 2-D even when n == 1.
    fig, axes = plt.subplots(2, n, figsize=(n * 2, 4), squeeze=False)
    for col in range(n):
        for row, images in enumerate((inputs, outputs)):
            ax = axes[row, col]
            ax.imshow(images[col].detach().reshape(28, 28).cpu().numpy(), cmap="gray")
            # Hide ticks and frame but keep the axis itself. axis("off") would
            # also hide the "Input"/"Output" row labels set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
    axes[0, 0].set_ylabel("Input", fontsize=12)
    axes[1, 0].set_ylabel("Output", fontsize=12)
    if title:
        fig.suptitle(title)
    fig.tight_layout()
    plt.show()
    plt.close(fig)

def extract_digits(dataset, digit, n):
    """Return the first ``n`` images labelled ``digit``, each flattened.

    Args:
        dataset: iterable of ``(image_tensor, label)`` pairs.
        digit: label to select.
        n: maximum number of images. Fewer are returned if fewer exist.

    Returns:
        Tensor of shape (k, image.numel()) with k = min(n, number of matches).

    Raises:
        ValueError: if no image labelled ``digit`` is found.
    """
    from itertools import islice

    # Filter lazily and stop after n matches, instead of flattening every
    # matching image in the whole dataset and slicing afterwards.
    matches = (img.view(-1) for img, label in dataset if label == digit)
    selected = list(islice(matches, n))
    if not selected:
        raise ValueError(f"no samples with label {digit!r} in dataset")
    return torch.stack(selected)

def get_single_digit_samples(dataset, digits):
    """Return one flattened image per requested digit, in the order of ``digits``.

    Scans ``dataset`` (an iterable of ``(image_tensor, label)``) once, keeping the
    first image seen for each requested label, and stops as soon as every
    requested label has been found. Repeated digits reuse the same image.

    Raises:
        ValueError: if a requested digit never appears. Previously such digits
            were skipped silently, which misaligned the output with ``digits``.
    """
    wanted = set(digits)
    first_seen = {}
    for img, label in dataset:
        if len(first_seen) == len(wanted):
            break
        key = label.item() if isinstance(label, torch.Tensor) else label
        if key in wanted and key not in first_seen:
            first_seen[key] = img.view(-1)
    missing = [d for d in digits if d not in first_seen]
    if missing:
        raise ValueError(f"no samples found for digits {missing}")
    return torch.stack([first_seen[d] for d in digits])

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_full = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_train, mnist_val = random_split(mnist_full, [50000, 10000])

# Transport each digit to the next one: 0->1, 1->2, ..., 9->0.
digit_pairs = [(i, (i + 1) % 10) for i in range(10)]
samples_per_digit = 2000

# Every digit is a source once and a target once. Extract each class a single
# time and reuse it, instead of scanning the training split twice per pair.
# The data is identical because extraction (and the transform) is deterministic.
images_by_digit = {}
for src, tgt in digit_pairs:
    for d in (src, tgt):
        if d not in images_by_digit:
            images_by_digit[d] = extract_digits(mnist_train, d, samples_per_digit)

all_source_images = [images_by_digit[src] for src, _ in digit_pairs]
all_target_images = [images_by_digit[tgt] for _, tgt in digit_pairs]

X_all = torch.cat(all_source_images, dim=0)
Y_all = torch.cat(all_target_images, dim=0)

# Row k of X_all is paired with row k of Y_all (source digit d, target d+1).
train_dataset = TensorDataset(X_all, Y_all)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

# ----- Initialisation -----
generator = ModularGradNet(num_modules=20, in_dim=28*28, embed_dim=256, activation=nn.Tanh)
optimizer = optim.Adam(generator.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
loss_fn = SamplesLoss("sinkhorn", p=2, blur=0.03)

num_epochs = 200
for epoch in range(num_epochs):
    generator.train()
    total_loss = 0
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        g_x = generator(x_batch)
        loss = loss_fn(g_x, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader):.6f}")

    if (epoch + 1) % 20 == 0 or epoch == num_epochs - 1:
        generator.eval()
        with torch.no_grad():
            # First training image of each source digit 0..9.
            test_inputs = torch.cat([imgs[:1] for imgs in all_source_images], dim=0)
            test_outputs = generator(test_inputs)
            show_comparison_grid(test_inputs, test_outputs, title=f"Epoch {epoch+1} - Transport 0→1 ... 9→0")

print("\n Test final (données jamais vues) :")
input_digits = list(range(10))
val_inputs = get_single_digit_samples(mnist_val, input_digits)

generator.eval()
with torch.no_grad():
    val_outputs = generator(val_inputs)

transformations = [f"{i}→{(i+1)%10}" for i in range(10)]
input_labels = [f"[{i}]" for i in range(10)]

print("  ".join(transformations))
print("  ".join(input_labels))

show_comparison_grid(val_inputs, val_outputs, title="Résultat final - Validation (0→1 ... 9→0)")


# **Digit generation from a Gaussian distribution**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from geomloss import SamplesLoss
from torch.utils.data import random_split, DataLoader, TensorDataset

def init_W(embed_dim, in_dim):
    """Return an (embed_dim, in_dim) weight drawn from N(0, 0.01^2)."""
    return torch.randn(embed_dim, in_dim) * 0.01

def init_b(size):
    """Return a zero-initialised bias of the given size."""
    return torch.zeros(size)

class Module_ModularGN(nn.Module):
    """One M-MGN module computing ``W^T act(softplus(beta) * (W x + b))``.

    With s = softplus(beta) > 0, its Jacobian is
    ``W^T diag(s * act'(s * (W x + b))) W``: symmetric, and positive semidefinite
    whenever ``act`` is nondecreasing (nn.Tanh is used below).
    """
    def __init__(self, in_dim, embed_dim, activation):
        super().__init__()
        self.beta = nn.Parameter(torch.rand(1), requires_grad=True)
        self.W = nn.Parameter(init_W(embed_dim, in_dim), requires_grad=True)
        self.b = nn.Parameter(init_b(embed_dim), requires_grad=True)
        self.act = activation()

    def forward(self, x):
        z = F.linear(x, weight=self.W, bias=self.b)   # embed: W x + b
        z = self.act(z * F.softplus(self.beta))       # softplus keeps the scale positive
        z = F.linear(z, weight=self.W.T)              # project back with W^T
        return z

class ModularGradNet(nn.Module):
    """Positively weighted sum of ``num_modules`` Module_ModularGN blocks plus a bias.

    The module weights are ``softplus(alpha)`` (strictly positive), so the sum of
    PSD-Jacobian terms stays PSD and the network is a monotone gradient map.
    The raw ``alpha`` parameter and the state_dict layout are unchanged.
    """
    def __init__(self, num_modules, in_dim, embed_dim, activation):
        super().__init__()
        self.num_modules = num_modules
        self.mmgn_modules = nn.ModuleList([Module_ModularGN(in_dim, embed_dim, activation) for _ in range(num_modules)])
        self.alpha = nn.Parameter(torch.randn(num_modules), requires_grad=True)
        self.bias = nn.Parameter(init_b(in_dim), requires_grad=True)

    def forward(self, x):
        # Raw alpha ~ N(0, 1) can be negative, which would turn that module's
        # Jacobian negative semidefinite. softplus keeps every weight > 0.
        weights = F.softplus(self.alpha)
        out = 0
        for weight, module in zip(weights, self.mmgn_modules):
            out = out + weight * module(x)
        return out + self.bias

def show_grid(images, title=''):
    """Show flattened 28x28 images side by side in a single row.

    Args:
        images: tensors of flattened 784-pixel images.
        title: optional figure title.
    """
    n = len(images)
    # squeeze=False keeps `axes` 2-D with shape (1, n), so a single image also works.
    fig, axes = plt.subplots(1, n, figsize=(n * 2, 2), squeeze=False)
    for ax, img in zip(axes[0], images):
        ax.imshow(img.detach().reshape(28, 28).cpu().numpy(), cmap="gray")
        ax.axis("off")
    if title:
        fig.suptitle(title)
    fig.tight_layout()
    plt.show()
    plt.close(fig)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_full = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_train, mnist_val = random_split(mnist_full, [50000, 10000])

num_train_samples = 10000
num_val_samples = 2000
input_dim = 28 * 28

# Index only the images that are used. list(mnist_train) ran the transform on
# all 50k images just to keep the first 10k. min() keeps the old slicing
# semantics when the requested count exceeds the split size.
train_noise = torch.randn(num_train_samples, input_dim)
train_targets = torch.stack([
    mnist_train[i][0].view(-1) for i in range(min(num_train_samples, len(mnist_train)))
])

val_noise = torch.randn(num_val_samples, input_dim)
val_targets = torch.stack([
    mnist_val[i][0].view(-1) for i in range(min(num_val_samples, len(mnist_val)))
])

# Noise and images are paired by index; only their distributions matter for Sinkhorn.
train_dataset = TensorDataset(train_noise, train_targets)
val_dataset = TensorDataset(val_noise, val_targets)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

# ----- Initialisation -----
generator = ModularGradNet(num_modules=50, in_dim=input_dim, embed_dim=512, activation=nn.Tanh)
optimizer = optim.Adam(generator.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
loss_fn = SamplesLoss("sinkhorn", p=2, blur=0.03)

num_epochs = 200
for epoch in range(num_epochs):
    generator.train()
    total_loss = 0
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        g_x = generator(x_batch)
        loss = loss_fn(g_x, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()

    generator.eval()
    total_val_loss = 0
    with torch.no_grad():
        for x_val, y_val in val_loader:
            g_val = generator(x_val)
            val_loss = loss_fn(g_val, y_val)
            total_val_loss += val_loss.item()

    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {total_loss/len(train_loader):.6f} | Val Loss: {total_val_loss/len(val_loader):.6f}")

    if (epoch + 1) % 20 == 0 or epoch == num_epochs - 1:
        with torch.no_grad():
            z = torch.randn(10, input_dim)
            samples = generator(z)
            show_grid(samples, title=f"Generated digits at epoch {epoch+1}")

print("\n Génération finale depuis du bruit (validation)")
generator.eval()
with torch.no_grad():
    z = torch.randn(10, input_dim)
    samples = generator(z)
    show_grid(samples, title="Final generated digits from noise")


# **CIFAR colorization**

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, Subset, random_split
import matplotlib.pyplot as plt
from geomloss import SamplesLoss

def init_W(embed_dim, in_dim):
    """Sample an (embed_dim, in_dim) weight uniformly from [-k, k), k = sqrt(1 / in_dim)."""
    bound = np.sqrt(1 / in_dim)
    unit = torch.rand(embed_dim, in_dim)  # U[0, 1)
    return 2 * bound * unit - bound

def init_b(embed_dim, in_dim):
    """Sample a bias of length embed_dim uniformly from [-k, k), k = sqrt(1 / in_dim)."""
    bound = np.sqrt(1 / in_dim)
    unit = torch.rand(embed_dim)  # U[0, 1)
    return 2 * bound * unit - bound

class CascadeGradNet(nn.Module):
    """Cascaded gradient network mapping R^in_dim -> R^in_dim.

    With (*) the elementwise product and s_k = nonlinearity[k]:
        z_0 = beta_0 * (W x + b_0)
        z_k = beta_k * (W x + b_k) + alpha_{k-1} * s_{k-1}(z_{k-1}),   k = 1 .. L-1
        out = W^T (alpha_{L-1} * s_{L-1}(z_{L-1})) + b_L
    where L = num_layers. The same W is used on the way in and, transposed, on the way out,
    so for elementwise activations the Jacobian has the symmetric form W^T D W with D diagonal.

    NOTE(review): alpha and beta are initialised in [-0.5, 0.5) and are not sign-constrained,
    so D (hence the Jacobian) is not guaranteed PSD; monotonicity is not enforced here.
    Confirm this is intended.
    """

    def __init__(self, num_layers, in_dim, embed_dim, activation):
        super().__init__()
        self.num_layers = num_layers
        self.nonlinearity = nn.ModuleList(activation() for _ in range(num_layers))
        self.W = nn.Parameter(init_W(embed_dim, in_dim))
        # One bias per stage plus the output bias. The first and last are then re-drawn with
        # their proper fan-in/size; the draw order is kept so seeded runs are reproducible.
        biases = [nn.Parameter(init_b(embed_dim, embed_dim)) for _ in range(num_layers + 1)]
        biases[0] = nn.Parameter(init_b(embed_dim, in_dim))
        biases[-1] = nn.Parameter(init_b(in_dim, embed_dim))
        self.bias = nn.ParameterList(biases)
        # Per-stage elementwise scales for the skip branch (beta) and the activation (alpha).
        self.beta = nn.ParameterList(nn.Parameter(torch.rand(embed_dim) - 0.5) for _ in range(num_layers))
        self.alpha = nn.ParameterList(nn.Parameter(torch.rand(embed_dim) - 0.5) for _ in range(num_layers))

    def forward(self, x):
        def stage(k):
            # beta_k * (W x + b_k), broadcast over the batch dimension
            return self.beta[k].view(1, -1) * F.linear(x, self.W, self.bias[k])

        hidden = stage(0)
        for k in range(1, self.num_layers):
            hidden = stage(k) + self.alpha[k - 1].view(1, -1) * self.nonlinearity[k - 1](hidden)
        hidden = self.alpha[-1].view(1, -1) * self.nonlinearity[-1](hidden)
        # Project back with the transposed shared weight.
        return F.linear(hidden, self.W.T, self.bias[-1])

class GrayscaleToColorDataset(Dataset):
    """Turn a (colour image, label) dataset into (flat grayscale input, flat colour target) pairs.

    The single-channel grayscale image is expanded to 3 channels so that input and target
    have the same flattened length. The label is discarded.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # one luminance channel; widened back to 3 channels in __getitem__
        self.to_grayscale = transforms.Grayscale(num_output_channels=1)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        color_img, _label = self.dataset[idx]
        gray_three_channels = self.to_grayscale(color_img).expand(3, -1, -1)
        return gray_three_channels.reshape(-1), color_img.reshape(-1)

# ----- Data: random subset of the CIFAR-10 training set, split 80/20 into train/val -----
data_root = "./data_cifar10"
image_size = 32
subset_size = 5000

# ToTensor gives values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

full_dataset = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)

# Seed both the subset draw and the train/val split so that re-running the notebook selects
# the same images (SEED is defined in the configuration cell at the top of the notebook).
DATA_SEED = SEED
subset_rng = np.random.default_rng(DATA_SEED)
indices = subset_rng.choice(len(full_dataset), size=subset_size, replace=False)
dataset = Subset(full_dataset, indices)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_subset, val_subset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(DATA_SEED)
)
train_dataset = GrayscaleToColorDataset(train_subset)
val_dataset = GrayscaleToColorDataset(val_subset)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

# ----- Model / optimisation setup -----
# Inputs and targets are flattened 3 x image_size x image_size images.
input_dim = image_size * image_size * 3
generator = CascadeGradNet(num_layers=24, in_dim=input_dim, embed_dim=512, activation=nn.Tanh)
optimizer = torch.optim.Adam(generator.parameters(), lr=5e-4)
# Halve the learning rate every 50 scheduler steps (stepped once per epoch in the training loop).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
# Sinkhorn divergence between the output batch and the target batch treated as point clouds
# (an unpaired, distribution-level comparison).
loss_fn = SamplesLoss("sinkhorn", p=2, blur=0.03)

num_epochs = 350
best_val_loss = float("inf")  # lowest mean validation loss seen so far
train_losses = []  # per-epoch mean training loss
val_losses = []  # per-epoch mean validation loss
model_path = "best_generator.pt"  # where the best state_dict is saved

# Train the colourisation generator, checkpoint the model with the lowest validation loss and
# periodically show (grayscale input, generated colours, ground truth) for random validation images.
for epoch in range(num_epochs):
    generator.train()
    total_loss = 0
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        output = generator(x_batch)
        loss = loss_fn(output, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()

    avg_train_loss = total_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation
    generator.eval()
    total_val_loss = 0
    with torch.no_grad():
        for x_val, y_val in val_loader:
            val_output = generator(x_val)
            val_loss = loss_fn(val_output, y_val)
            total_val_loss += val_loss.item()
    avg_val_loss = total_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f}")

    # Keep the checkpoint with the lowest mean validation loss.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(generator.state_dict(), model_path)
        print(f"Nouveau meilleur modèle sauvegardé à epoch {epoch+1} (val_loss = {best_val_loss:.6f})")

    # Visualise every 20 epochs and after the last one. `epoch` is 0-based, so the last
    # epoch is `num_epochs - 1` (`epoch == num_epochs` can never be true inside this loop).
    if (epoch + 1) % 20 == 0 or epoch == num_epochs - 1:
        generator.eval()
        with torch.no_grad():
            val_indices = np.random.choice(len(val_dataset), size=8, replace=False)
            samples = [val_dataset[i] for i in val_indices]
            sample_input = torch.stack([s[0] for s in samples])
            sample_target = torch.stack([s[1] for s in samples])
            sample_output = generator(sample_input)

            fig, axes = plt.subplots(3, 8, figsize=(16, 6))
            for i in range(8):
                # Undo Normalize(0.5, 0.5): tensors are in [-1, 1], imshow expects [0, 1].
                # The grayscale input has 3 identical channels; show one of them so the gray
                # colormap actually applies (cmap is ignored for RGB arrays).
                img_input = sample_input[i].view(3, image_size, image_size)[0].cpu().numpy() * 0.5 + 0.5
                img_output = sample_output[i].view(3, image_size, image_size).permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5
                img_target = sample_target[i].view(3, image_size, image_size).permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5

                axes[0, i].imshow(np.clip(img_input, 0, 1), cmap="gray", vmin=0, vmax=1)
                axes[1, i].imshow(np.clip(img_output, 0, 1))
                axes[2, i].imshow(np.clip(img_target, 0, 1))

                # Hide ticks and spines instead of axis("off"), which would also hide the
                # row labels set below.
                for row in range(3):
                    ax = axes[row, i]
                    ax.set_xticks([])
                    ax.set_yticks([])
                    for spine in ax.spines.values():
                        spine.set_visible(False)

            axes[0, 0].set_ylabel("Noir & Blanc", fontsize=12)
            axes[1, 0].set_ylabel("Généré", fontsize=12)
            axes[2, 0].set_ylabel("Réel", fontsize=12)
            plt.suptitle(f"Epoch {epoch+1} – Résultats aléatoires sur Validation", fontsize=14)
            plt.tight_layout()
            plt.show()

# Loss curves. Plot only the epochs actually recorded, so this cell still works after an
# interrupted training run (fewer than num_epochs entries, or a missing last validation value).
n_recorded = min(len(train_losses), len(val_losses))
recorded_epochs = range(1, n_recorded + 1)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(recorded_epochs, train_losses[:n_recorded], label="Train Loss")
ax.plot(recorded_epochs, val_losses[:n_recorded], label="Validation Loss")
ax.set_xlabel("Épochs")
ax.set_ylabel("Loss")
ax.set_title("Courbes de perte (Train vs Validation)")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

###### Colorization

# Color transport

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from geomloss import SamplesLoss

def init_W(embed_dim, in_dim):
    """Uniform weight init in [-k, k), k = 1/sqrt(in_dim); returns shape (embed_dim, in_dim)."""
    half_width = np.sqrt(1 / in_dim)
    return 2 * half_width * torch.rand(embed_dim, in_dim) - half_width

def init_b(embed_dim, in_dim):
    """Uniform bias init in [-k, k), k = 1/sqrt(in_dim); returns shape (embed_dim,)."""
    half_width = np.sqrt(1 / in_dim)
    return 2 * half_width * torch.rand(embed_dim,) - half_width

class CascadeGradNet(nn.Module):
    """Cascaded gradient network R^in_dim -> R^in_dim (ReLU activations by default).

    Stage k computes beta_k * (W x + b_k), adds alpha_{k-1} * s_{k-1}(previous stage) for
    k >= 1, and the result is read out as W^T (alpha_{L-1} * s_{L-1}(z_{L-1})) + b_L with
    L = num_layers. The shared W gives a symmetric Jacobian W^T D W (D diagonal) for
    elementwise activations.

    NOTE(review): alpha/beta are not sign-constrained, so the Jacobian is not guaranteed PSD
    (the map is not guaranteed monotone). Confirm this is intended.
    """

    def __init__(self, num_layers, in_dim, embed_dim, activation=nn.ReLU):
        super().__init__()
        self.num_layers = num_layers
        self.nonlinearity = nn.ModuleList(activation() for _ in range(num_layers))
        self.W = nn.Parameter(init_W(embed_dim, in_dim), requires_grad=True)
        # Draw order (all hidden-size biases first, then the first and last re-drawn) is kept
        # so that seeded initialisations are unchanged.
        bias_params = [nn.Parameter(init_b(embed_dim, embed_dim)) for _ in range(num_layers + 1)]
        bias_params[0] = nn.Parameter(init_b(embed_dim, in_dim))
        bias_params[-1] = nn.Parameter(init_b(in_dim, embed_dim))
        self.bias = nn.ParameterList(bias_params)
        self.beta = nn.ParameterList(nn.Parameter(torch.rand(embed_dim) - 0.5) for _ in range(num_layers))
        self.alpha = nn.ParameterList(nn.Parameter(torch.rand(embed_dim) - 0.5) for _ in range(num_layers))

    def forward(self, x):
        hidden = None
        for k in range(self.num_layers):
            pre = self.beta[k].view(1, -1) * F.linear(x, self.W, self.bias[k])
            if k == 0:
                hidden = pre
            else:
                hidden = pre + self.alpha[k - 1].view(1, -1) * self.nonlinearity[k - 1](hidden)
        hidden = self.alpha[-1].view(1, -1) * self.nonlinearity[-1](hidden)
        return F.linear(hidden, self.W.T, self.bias[-1])

def sample_batch(x, batch_size):
    """Return up to `batch_size` rows of `x`, chosen uniformly at random without replacement.

    If `batch_size` exceeds the number of rows, all rows are returned in random order.
    """
    permutation = torch.randperm(x.shape[0])
    chosen_rows = permutation[:batch_size]
    return x[chosen_rows]

def train_transport_sinkhorn(model, x_input, x_target, size, epochs=1000, lr=1e-3, batch_size=2048, device='cpu', display_every=100, blur=0.01, log_every=1):
    """Fit `model` so that model(x_input) matches x_target in distribution (Sinkhorn divergence).

    Each "epoch" is a single Adam step on one random mini-batch drawn independently from
    `x_input` and from `x_target` (the two batches are not paired).

    Args:
        model: module applied row-wise to colours; moved to `device`.
        x_input: source colours, presumably (H*W, 3) in [0, 1]; TODO confirm.
        x_target: target samples with the same trailing dimension as x_input.
        size: (width, height) of the source image, used to reshape previews to (H, W, 3).
        epochs: number of optimisation steps.
        lr: Adam learning rate.
        batch_size: number of points drawn from each cloud per step.
        device: device for the model and the data.
        display_every: show a preview every this many steps, and always after the last one.
            A value <= 0 disables periodic previews (only the final one is shown).
        blur: geomloss Sinkhorn blur scale (default 0.01, the previously hard-coded value).
        log_every: print the loss every this many steps and on the last step (default 1,
            i.e. every step). A value <= 0 disables loss logging.
    """
    model.to(device)
    x_input = x_input.to(device)
    x_target = x_target.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=2, blur=blur)
    last_epoch = epochs - 1

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        # Independent random mini-batches from each cloud: the loss is unpaired.
        x_input_batch = sample_batch(x_input, batch_size)
        x_target_batch = sample_batch(x_target, batch_size)

        # NOTE(review): clamp(0, 1) has zero gradient outside [0, 1], so outputs that leave the
        # colour range get no direct signal from the loss, while unbounded targets (e.g.
        # Gaussian samples) may lie outside it. Kept as in the original; confirm intent.
        x_output = model(x_input_batch).clamp(0, 1)
        loss = sinkhorn_loss(x_output, x_target_batch)

        loss.backward()
        optimizer.step()

        if log_every > 0 and (epoch % log_every == 0 or epoch == last_epoch):
            print(f"[{epoch}/{epochs}] Sinkhorn Loss: {loss.item():.6f}")

        # Guard the modulo so display_every <= 0 means "final preview only" instead of
        # raising ZeroDivisionError.
        periodic = display_every > 0 and epoch % display_every == 0
        if periodic or epoch == last_epoch:
            model.eval()
            with torch.no_grad():
                stylized = model(x_input).clamp(0, 1)
            # size is (width, height): rows of the preview are image rows.
            stylized_image = stylized.reshape(size[1], size[0], 3).cpu().numpy()
            plt.figure(figsize=(6, 6))
            plt.title(f"Stylized Image - Epoch {epoch}")
            plt.imshow(stylized_image)
            plt.axis('off')
            plt.show()

def image_to_colors(img_path):
    """Load an image as RGB and return it as a cloud of colour vectors.

    Returns:
        pixels: float tensor of shape (H*W, 3), values scaled to [0, 1] by ToTensor.
        size: the PIL size tuple (width, height).
        tensor: the image as an (H, W, 3) tensor.
    """
    # Open in a context manager so the file handle is released deterministically;
    # convert() returns a new, fully loaded image that stays valid after the file closes.
    with Image.open(img_path) as source:
        image = source.convert('RGB')
    tensor = transforms.ToTensor()(image).permute(1, 2, 0)
    pixels = tensor.reshape(-1, 3)
    return pixels, image.size, tensor

def fit_gaussian(x, eps=0.0):
    """Estimate the mean and unbiased covariance of the rows of `x`.

    Args:
        x: (N, D) CPU tensor of samples; N >= 2 is needed for a finite covariance.
        eps: optional ridge added to the covariance diagonal. A positive value keeps the
            covariance positive definite for degenerate samples (e.g. a grayscale image, where
            R == G == B), which torch.distributions.MultivariateNormal requires. The default
            0.0 returns the plain sample covariance.

    Returns:
        (mu, cov): mean of shape (D,) and float32 covariance of shape (D, D).
    """
    mu = x.mean(dim=0)
    # np.cov treats rows as variables, hence the transpose; it uses ddof=1 by default.
    cov = torch.from_numpy(np.cov(x.T.numpy())).float()
    if eps:
        cov = cov + eps * torch.eye(x.shape[1], dtype=cov.dtype)
    return mu, cov

def sample_from_gaussian(mu, cov, n_samples):
    """Draw `n_samples` i.i.d. samples from N(mu, cov); the result has shape (n_samples, D)."""
    gaussian = torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=cov)
    return gaussian.sample(sample_shape=(n_samples,))

def apply_transport(model, color_tensor, size):
    """Map every colour through `model` and reshape the result into an (H, W, 3) numpy image.

    `size` is (width, height). Outputs are clamped to [0, 1]; no gradients are tracked.
    """
    model.eval()
    with torch.no_grad():
        mapped = model(color_tensor).clamp(0, 1)
    width, height = size[0], size[1]
    return mapped.reshape(height, width, 3).cpu().numpy()

def style_transfer_pipeline_sinkhorn(image_source_path, image_style_path, save_path="stylized_output.jpg", epochs=1000, display_every=100, batch_size=2048):
    """Recolour a source image so that its colours follow a Gaussian fitted to a style image.

    The steps are:
    1. Load both images as pixel clouds.
    2. Fit a Gaussian to the style colours and draw as many samples as the source has pixels.
    3. Train a CascadeGradNet (3 layers, width 128) with a Sinkhorn loss to move the source
       colours onto those samples.
    4. Apply it to every pixel, save the result to `save_path` and display it.
    """
    print("Chargement des images...")
    src_pixels, src_size, _ = image_to_colors(image_source_path)
    style_pixels = image_to_colors(image_style_path)[0]

    print("Estimation de la gaussienne de style...")
    style_mean, style_cov = fit_gaussian(style_pixels)
    gaussian_targets = sample_from_gaussian(style_mean, style_cov, src_pixels.shape[0])

    print("Initialisation du modèle...")
    transport_net = CascadeGradNet(num_layers=3, in_dim=3, embed_dim=128)

    print("Entraînement avec Sinkhorn (geomloss)...")
    train_transport_sinkhorn(
        transport_net, src_pixels, gaussian_targets, src_size,
        epochs=epochs, lr=1e-3,
        batch_size=batch_size,
        display_every=display_every,
    )

    print("Application finale du style...")
    result = apply_transport(transport_net, src_pixels, src_size)

    # [0, 1] floats -> 8-bit image for saving.
    as_uint8 = (result * 255).astype(np.uint8)
    Image.fromarray(as_uint8).save(save_path)
    print(f"Image stylisée sauvegardée dans : {save_path}")

    plt.figure(figsize=(6, 6))
    plt.title("Résultat final")
    plt.imshow(result)
    plt.axis('off')
    plt.show()

# Run the colour-transport pipeline: recolour the Marseille photo with the palette of the style image.
style_transfer_pipeline_sinkhorn(
    "data/marseille.jpeg",
    "data/soleil2.jpg",
    save_path="stylized_output_sinkhorn.jpg",
    epochs=500,
    batch_size=2048,
    display_every=10,
)