# Generalized Sliced-Wasserstein Flows with Neural Networks

The goal of this experiment is to illustrate the behavior of the Generalized Sliced-Wasserstein (GSW) and maximum Generalized Sliced-Wasserstein (max-GSW) distances when their defining function is learned through a neural network.

## Experiment details

We consider the following problem:
$$\operatorname{min}_{p_Y} GSW_p(p_X, p_Y),$$

where $p_X$ is a target distribution and $p_Y$ is the source distribution, which is initialized to the normal distribution. 

The optimization is solved iteratively via
$$ \partial_t (p_Y)_t= -\nabla GSW_p(p_X, (p_Y)_t), ~~(p_Y)_0=\mathcal{N}(0, (0.25)^2).$$

We also consider $\operatorname{min}_{p_Y} \{ \max\text{-}GSW_p(p_X, p_Y) \},$ and we use the same optimization scheme to solve it (with $\max\text{-}GSW_p$ in place of $GSW_p$). 

We use 5 well-known distributions as the target: the 25-Gaussians, 8-Gaussians, Swiss Roll, Half Moons and Circle distributions. 

The defining function is learned through a neural network. We compare different configurations: we use a multilayer perceptron of depth 1, 2 or 3. 

We analyze the results (i) qualitatively, by plotting samples drawn from $p_X$ and $(p_Y)_t$ at each iteration $t$ of the optimization process, and (ii) quantitatively, by computing and reporting the 2-Wasserstein distance between $p_X$ and $(p_Y)_t$ at each $t$.

## Requirements

* NumPy
* Scikit-learn
* PyTorch
* POT
* Matplotlib
* tqdm
* IPython
* imageio and scikit-image (only needed to assemble the final GIF)

In [None]:
import sys
sys.path.append('../gsw')

import numpy as np
from gswnn import GSW_NN
from gsw_utils import w2,load_data

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.nn.parameter import Parameter
from torch import optim

from tqdm import tqdm
from IPython import display
import time
import pickle 
import matplotlib.pyplot as plt
import random
import os

In [None]:
# Seed every random number generator the experiment draws from so that a
# run can be reproduced. NumPy drives the initial particle positions;
# PyTorch is seeded as well because the networks built later are
# presumably initialized from torch's global RNG (TODO confirm in gswnn).
SEED = 10
np.random.seed(SEED)
torch.manual_seed(SEED)

### We choose a dataset and load it
### The dataset name must be 'swiss_roll', 'half_moons', 'circle', '8gaussians' or '25gaussians'

In [None]:
# Name of the target distribution; it is passed to load_data() and reused in
# the names of the output folder, pickle file and GIF.
# Accepted values according to this notebook: 'swiss_roll', 'half_moons',
# 'circle', '8gaussians' or '25gaussians'.
dataset_name = 'swiss_roll'

In [None]:
# Draw N samples from the target distribution p_X
N = 1000
X = load_data(name=dataset_name, n_samples=N)

# Center the samples in place: subtract the per-coordinate mean (kept as a
# 1 x d row so that it broadcasts over all samples)
X -= X.mean(dim=0, keepdim=True)

# After centering, the target mean is the origin; it is used as the mean of
# the initial source distribution
meanX = 0

In [None]:
# Record the data dimension d and preview the target samples
_, d = X.shape
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(X[:, 0], X[:, 1])  # first two coordinates of every sample
plt.show()

### We create the different folders to store the results

In [None]:
# Root folder for everything this notebook writes (frames and pickled results).
results_folder = './saved_results_GSW_flows_NN'
# exist_ok=True replaces the check-then-create pattern: no race between the
# test and the mkdir, and re-running the cell is a no-op.
os.makedirs(results_folder, exist_ok=True)

In [None]:
# Sub-folder holding the per-dataset frame folders.
foldername = os.path.join(results_folder, 'Gifs')
# makedirs also creates missing parents and tolerates an existing folder,
# so this cell works regardless of which cells ran before it.
os.makedirs(foldername, exist_ok=True)

In [None]:
# Folder receiving one PNG frame per optimization step for the chosen dataset.
foldername = os.path.join(results_folder, 'Gifs', dataset_name + '_Comparison_NN')
# makedirs also creates missing parents and tolerates an existing folder,
# so this cell works regardless of which cells ran before it.
os.makedirs(foldername, exist_ok=True)

### We solve the two optimization problems for different neural network configurations and plot the results at each step

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Number of iterations for the optimization process. Each iteration performs
# one optimizer step for every configuration and saves one figure, so this is
# also the number of rows in the distance arrays and the number of frames.
nofiterations = 250

In [None]:
# Buffers for the 2-Wasserstein distance between p_X and the current p_Y:
# one row per iteration, one column per network depth. They start out as NaN
# so that iterations that have not run yet are simply not drawn.
w2_dist = np.full((nofiterations, 3), np.nan)      # GSW problem
maxw2_dist = np.full((nofiterations, 3), np.nan)   # max-GSW problem

In [None]:
# Network depths to compare. The same three depths are listed twice:
# entries 0-2 belong to the GSW problem, entries 3-5 to the max-GSW problem.
depth = [1, 2, 3] * 2

# One plot title / legend entry per configuration, in the same order as `depth`
titles = ['%s NN - Depth=%d' % (problem, layers)
          for problem in ('GSW', 'MaxGSW')
          for layers in depth[:3]]

In [None]:
# Define the initial distribution: N points drawn from an isotropic Gaussian
# with standard deviation 0.25, centered on meanX
temp = np.random.normal(loc=meanX, scale=.25, size=(N, d))

# One independent set of particles, one optimizer and one GSW object per
# configuration listed in `depth`
Y = []
optimizer = []
gsw = []

for model_depth in depth:
    # torch.tensor copies `temp`, so every configuration starts from the same
    # positions but is optimized independently of the others
    particles = torch.tensor(temp, dtype=torch.float, device=device, requires_grad=True)
    Y.append(particles)
    # The particles themselves are the parameters being optimized
    optimizer.append(optim.Adam([particles], lr=1e-2))
    # The input dimension follows the data instead of being hard-coded to 2
    gsw.append(GSW_NN(din=d, nofprojections=1, model_depth=model_depth))


In [None]:
fig = plt.figure(figsize=(15, 15))
grid = plt.GridSpec(3, 3, wspace=0.4, hspace=0.3)

# The target samples are never modified inside the loop, so move them to the
# device and convert them to NumPy once instead of on every iteration.
X_dev = X.to(device)
X_np = X.detach().cpu().numpy()


def _plot_samples(cell, samples, title):
    """Scatter the target samples and the current particles (red) in one grid cell."""
    plt.subplot(cell)
    plt.scatter(X_np[:, 0], X_np[:, 1])
    plt.scatter(samples[:, 0], samples[:, 1], c='r')
    plt.title(title, fontsize=22)


for i in range(nofiterations):
    # We loop over the different neural networks configurations for the GSW problem
    for k in range(3):
        # Loss computation (here, GSW)
        loss = gsw[k].gsw(X_dev, Y[k].to(device))

        # Optimization step on the particles Y[k]
        optimizer[k].zero_grad()
        loss.backward()
        optimizer[k].step()

        # Compute the 2-Wasserstein distance to compare the distributions
        temp = Y[k].detach().cpu().numpy()
        w2_dist[i, k] = w2(X_np, temp)

        # Plot samples from the target and the current solution (first row)
        _plot_samples(grid[0, k], temp, titles[k])

    # We loop over the different neural networks configurations for the max-GSW problem
    for k in range(3, 6):
        # Loss computation (here, max-GSW)
        loss = gsw[k].max_gsw(X_dev, Y[k].to(device), iterations=250, lr=1e-4)

        # Optimization step on the particles Y[k]
        optimizer[k].zero_grad()
        loss.backward()
        optimizer[k].step()

        # Compute the 2-Wasserstein distance to compare the distributions
        temp = Y[k].detach().cpu().numpy()
        maxw2_dist[i, k - 3] = w2(X_np, temp)

        # Plot samples from the target and the current solution (second row)
        _plot_samples(grid[1, k - 3], temp, titles[k])

    # Plot the 2-Wasserstein distance of all six configurations (third row);
    # rows that are still NaN (iterations not run yet) are not drawn
    plt.subplot(grid[2, 0:3])
    plt.plot(np.log10(w2_dist), linewidth=3)
    plt.plot(np.log10(maxw2_dist), linewidth=3)
    plt.title('2-Wasserstein Distance', fontsize=22)
    plt.ylabel(r'$Log_{10}(W_2)$', fontsize=22)
    plt.legend(titles, fontsize=22, loc='lower left')

    # Refresh the notebook output with the current figure
    display.clear_output(wait=True)
    display.display(plt.gcf())
    time.sleep(1e-5)

    # Save the figure as one frame of the animation
    fig.savefig(os.path.join(foldername, 'img%03d.png' % i))

    # Remove every axes before the next iteration. This is done explicitly
    # because recent Matplotlib releases no longer delete existing axes when
    # plt.subplot() creates an overlapping one, which would otherwise let
    # scatter points and curves pile up from frame to frame.
    fig.clf()


### We save the results

In [None]:
# Persist both distance histories (GSW first, max-GSW second) in one pickle
filename = os.path.join(results_folder, dataset_name + '_comparison_NN_1run.pkl')
results = [w2_dist, maxw2_dist]
with open(filename, 'wb') as fh:
    pickle.dump(results, fh)

In [None]:
import imageio
from glob import glob
from skimage.transform import resize

In [None]:
# glob() returns paths in arbitrary order; sort them so the frames are
# assembled in iteration order (the names are zero-padded: img000.png, ...).
filenames = sorted(glob(foldername + '/*.png'))
images = []
for filename in filenames:
    # Rescale to [0, 1], resize every frame to 750x750 with 4 channels, then
    # convert back to 8-bit for the GIF writer
    frame = imageio.imread(filename).astype(float) / 255.
    images.append((resize(frame, (750, 750, 4)) * 255).astype('uint8'))
imageio.mimsave(dataset_name + '_comparison_NN.gif', images)