<a href="https://colab.research.google.com/github/cesar-claros/data-efficient-gans/blob/master/bures_weighted_flows_comparison.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the third-party packages this notebook relies on.
# celluloid provides the `Camera` class imported further down; POT is not
# imported directly in this notebook (presumably required by the cloned gsw
# helper modules -- TODO confirm).
! pip install POT
! pip install celluloid

In [None]:
# Fetch the gsw repository and make gsw/code_flow/gsw the working directory.
# NOTE(review): presumably this is what lets the local imports of
# `bures_weighted`, `weighted_utils` and `gsw_utils` in the next cell resolve.
!git clone https://github.com/anon-author-dev/gsw
%cd gsw/code_flow/gsw

In [None]:
import numpy as np
from bures_weighted import WGSB
from weighted_utils import w2_weighted
from gsw_utils import w2,load_data

import torch
from torch import optim

from celluloid import Camera
from tqdm import tqdm
from IPython import display
import matplotlib.pyplot as plt

In [None]:
# Target distribution p_X.  Dataset names listed by the original author:
# 'swiss_roll', 'circle', '8gaussians', '25gaussians'.
dataset_name = 'circle'
N = 1000  # number of samples drawn from p_X

# Seed NumPy's global RNG before sampling.
np.random.seed(10)
X = load_data(name=dataset_name, n_samples=N)

# Centre the samples in place by removing the per-coordinate mean
# (assumes load_data returns a 2-D torch tensor, as the `dim=` keyword implies).
X -= X.mean(dim=0, keepdim=True)
meanX = 0  # mean after centring; only referenced by the commented-out Gaussian init

# d is the dimensionality of a single sample.
_, d = X.shape

# Quick visual check of the first two coordinates of the target samples.
fig = plt.figure(figsize=(5, 5))
plt.scatter(X[:, 0], X[:, 1])
plt.show()

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Number of iterations for the optimization process.
# The plotting loop further down runs this many times and records one
# animation frame (camera.snap()) per iteration.
nofiterations = 250

In [None]:
# Panel / legend titles, one per flow, indexed by the same k that the
# optimisation loop uses: entries 0-2 are the max-sliced W2 flows and
# entries 3-5 the max-sliced Bures flows.
titles = [
    'Max Sliced W2: linear',
    'RFB-2000-0.2',
    'RFB-2000-0.1',
    'Max Sliced Bures: linear',
    'RFB-2000-0.2',
    'RFB-2000-0.1',
]

In [None]:
# Initial distribution: N points drawn uniformly from the square [-2, 2)^d.
# Alternative Gaussian initialisation, kept for reference:
#Y = torch.from_numpy(np.random.normal(loc=meanX, scale=0.5, size=(N,d))).float()
Y = torch.from_numpy((np.random.rand(N, d) - 0.5) * 4).float()

# Every flow starts from uniform weights 1/N on the N points.
temp = np.ones((N,)) / N

new_lr = 1e-2   # Adam learning rate shared by all flows
nb = 2000       # number of bases in the random Fourier bases
sigma0 = 0.2    # `sigma` passed to the first kernel variant
sigma1 = 0.1    # `sigma` passed to the second kernel variant

# WARNING the order of the optimizers is HARDCODED below ...
# (the optimisation loop picks the loss function from the flow index, so the
# order of these entries must not change).
flow_specs = [
    dict(ftype='linear'),
    dict(ftype='kernel', nofbases=nb, sigma=sigma0),
    dict(ftype='kernel', nofbases=nb, sigma=sigma1),
    dict(ftype='linear'),
    dict(ftype='kernel', nofbases=nb, sigma=sigma0),
    dict(ftype='kernel', nofbases=nb, sigma=sigma1),
]

# Per-flow state: the trainable weight vector, its Adam optimizer, and the
# WGSB object (wrapped in a one-element list, as the loop indexes wmsd[k][0]).
beta = []
optimizer = []
wmsd = []
for spec in flow_specs:
    weights = torch.tensor(temp, dtype=torch.float, device=device, requires_grad=True)
    beta.append(weights)
    optimizer.append(optim.Adam([weights], lr=new_lr))
    wmsd.append([WGSB(**spec)])

In [None]:
# Figure layout: a 3x3 grid whose first two rows hold one scatter panel per
# flow and whose last row is merged into a single wide axis for the W2 curves.
fig, f_axs = plt.subplots(ncols=3, nrows=3, figsize=(15, 15))
camera = Camera(fig)
gs = f_axs[2, 0].get_gridspec()
# remove the underlying axes
for ax in f_axs[-1, :]:
    ax.remove()
axbig = fig.add_subplot(gs[-1, :])
axbig.set_title('Wasserstein-2 Distance', fontsize=22)
axbig.set_ylabel(r'$Log_{10}(W_2)$', fontsize=22)
colors = ['#1f77b4',
          '#ff7f0e',
          '#2ca02c',
          '#d62728',
          '#9467bd',
          '#8c564b']
# 2-Wasserstein distance per iteration (rows) and per flow (columns).
# Pre-filled with NaN so that iterations not yet computed are not drawn.
w2_dist = np.full((nofiterations, 6), np.nan)
fig.suptitle('FLOWS COMPARISON', fontsize=44)

# Loss function used by each flow, indexed by k.  This replaces the former
# `k is 0` / `k is 1 or k is 2` / `k is 3` chain, which compared integers by
# identity instead of by value.
loss_fns = [
    wmsd[0][0].max_gsw_weighted,
    wmsd[1][0].max_kernel_gsw_weighted,
    wmsd[2][0].max_kernel_gsw_weighted,
    wmsd[3][0].max_sliced_bures_weighted,
    wmsd[4][0].max_sliced_kernel_bures_weighted,
    wmsd[5][0].max_sliced_kernel_bures_weighted,
]

# X and Y are never modified inside the loop (only the weights beta[k] are
# optimised), so their device copies and NumPy views are computed once.
X_dev = X.to(device)
Y_dev = Y.to(device)
X_np = X.detach().cpu().numpy()
Y_np = Y.detach().cpu().numpy()

for i in range(nofiterations):
    loss = []
    # We loop over the different flows (max-GSW and max-sliced Bures variants)
    for k in range(6):
        # Loss computation for flow k with its current weights
        loss_ = loss_fns[k](X_dev, Y_dev, beta[k])

        # Optimization step
        loss.append(loss_)
        optimizer[k].zero_grad()
        loss_.backward()
        optimizer[k].step()

        # NOTE: beta[k] itself is not projected onto the simplex after the
        # step; only the detached copy `nu` used for evaluation and plotting
        # is clipped to be non-negative and renormalised to sum to one.
        nu = beta[k].detach().cpu().numpy()
        nu = np.maximum(0, nu).astype(np.float64)
        nu = nu / np.sum(nu)
        # Compute the 2-Wasserstein distance to compare the distributions
        w2_dist[i, k] = w2_weighted(X_np, Y_np, nu)

        # Rescale the weights to marker sizes (largest weight -> size 15)
        nu = nu / np.max(nu) * 15
        # Plot samples from the target and the current solution:
        # flows 0-2 go to the first row, flows 3-5 to the second row.
        row, col = divmod(k, 3)
        f_axs[row, col].scatter(X_np[:, 0], X_np[:, 1], c='b')
        f_axs[row, col].scatter(Y_np[:, 0], Y_np[:, 1], s=nu, c='r')
        f_axs[row, col].set_title(titles[k], fontsize=22)

    # Plot the 2-Wasserstein distance curves accumulated so far
    for p, color in enumerate(colors):
        axbig.plot(np.log10(w2_dist[:, p]), color=color)

    axbig.legend(titles, fontsize=22, bbox_to_anchor=(.1, -.55), loc="lower left",
                 ncol=2, fancybox=True, shadow=True)
    # Capture the current state of the figure as one animation frame
    camera.snap()

plt.close()

In [None]:
# Build an animation from the frames captured with camera.snap(), using the
# default animate() settings, and write it to disk.
animation_full = camera.animate()
animation_full.save('animation_full.mp4')
# Second rendering of the same frames with blitting disabled and interval=10
# (presumably milliseconds between frames -- TODO confirm against celluloid /
# matplotlib's animation documentation).
animation_reduced = camera.animate(blit=False, interval=10)
animation_reduced.save('animation_reduced.mp4')

In [None]:
# Show the full animation in the notebook output as an HTML5 video
# (as the last expression of the cell, its value is what gets displayed).
display.HTML(animation_full.to_html5_video())