In [None]:
import matplotlib.pyplot as plt
import time
import torch
from geomloss import SamplesLoss
import numpy as np

# Detect CUDA support once and pick the matching (legacy) float tensor type.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

In [None]:
from random import choices
import h5py  
import pandas as pd
from scipy.io import savemat
import time

In [None]:
def load_hdf5_snapshot(fname, var="snapshot_matrix_000000"):
    """Read one dataset from an HDF5 file into an in-memory NumPy array.

    Parameters:
        fname (str or path-like):
            Path of the HDF5 file, opened read-only.
        var (str, default = "snapshot_matrix_000000"):
            Name of the dataset to read from the file.

    Returns:
        numpy.ndarray: a copy of the full contents of the dataset.

    The file is opened in a ``with`` block so the handle is released as soon
    as the data has been read, even if the dataset lookup raises.
    """
    with h5py.File(fname, 'r') as hf:
        # ``[:]`` reads the data while the file is still open; ``np.array``
        # makes the result independent of the (soon closed) file handle.
        snap = np.array(hf[var][:])
    return snap

def find_states(full_sol, N, nvar=4):
    """Split an interleaved state vector into its first four variable fields.

    Entry ``i * nvar + k`` of ``full_sol`` is read as variable ``k`` at node
    ``i``; the first four variables (``k = 0..3``) are extracted for nodes
    ``0..N-1``.

    Parameters:
        full_sol (array-like):
            Flat state vector (a single-column 2D array is also accepted).
        N (int):
            Number of nodes to extract.
        nvar (int, default = 4):
            Number of entries stored per node, i.e. the stride between two
            consecutive values of the same variable.

    Returns:
        tuple: four independent float64 arrays of length ``N``
        ``(var1, var2, var3, var4)``.

    Raises:
        ValueError: if ``N`` is negative, ``nvar`` is smaller than 1, or
            ``full_sol`` is not a vector.
        IndexError: if ``full_sol`` is too short to hold the requested
            entries (same exception type the element-wise loop raised).
    """
    if N < 0:
        raise ValueError("N must be non-negative, got {}".format(N))
    if nvar < 1:
        raise ValueError("nvar must be at least 1, got {}".format(nvar))

    flat = np.asarray(full_sol, dtype=float)
    if flat.ndim == 2 and flat.shape[1] == 1:
        # Accept a column vector, as produced by some HDF5 snapshot layouts.
        flat = flat[:, 0]
    if flat.ndim != 1:
        raise ValueError("full_sol must be a flat vector, got shape {}".format(flat.shape))

    # Largest index touched is (N - 1) * nvar + 3; check explicitly because
    # slicing would otherwise silently return truncated arrays.
    if N > 0 and flat.shape[0] < (N - 1) * nvar + 4:
        raise IndexError(
            "full_sol has {} entries, but {} are needed for N={} and nvar={}".format(
                flat.shape[0], (N - 1) * nvar + 4, N, nvar
            )
        )

    # Strided slices pick every nvar-th entry starting at offset k; ``copy``
    # returns independent, contiguous arrays like the original np.zeros buffers.
    span = N * nvar
    var1, var2, var3, var4 = (flat[k:k + span:nvar].copy() for k in range(4))

    return var1, var2, var3, var4
    
def create_weights(state, psize, mn):
    """Convert state values into rounded weights measured from ``mn``.

    Evaluates ``round((state - mn) * psize / 10)`` element-wise.

    Parameters:
        state: array of state values (must provide a ``.round()`` method).
        psize: scale factor applied to the shifted values.
        mn: offset subtracted from ``state`` before scaling.

    Returns:
        The rounded, scaled values, with the same shape as ``state``.
    """
    shifted = state - mn
    scaled = shifted * psize
    return (scaled / 10).round()

def find_pebble_size(state):
    """Intended to derive a pebble size from ``state`` (currently incomplete).

    NOTE(review): ``psize`` is never assigned inside this function, so unless
    a module-level ``psize`` happens to exist, the ``return`` statement raises
    ``NameError``. The mean and minimum computed below are not used. The
    intended formula cannot be recovered from this file.
    TODO: define ``psize`` from ``avg``/``mn`` or remove the function.
    """
    avg = sum(state) / len(state)  # arithmetic mean of ``state`` (currently unused)
    mn = min(state)  # smallest entry of ``state`` (currently unused)
    return psize  # NOTE(review): not defined in this scope, see docstring
    
    
# remove rows with zero difference
def remove_zeros(x1, x2, mesh):
    """Drop every entry at which ``x1`` and ``x2`` are equal.

    Parameters:
        x1, x2: arrays of matching shape.
        mesh: array indexed in the same way as ``x1``/``x2``.

    Returns:
        tuple: ``(x1, x2, mesh)`` restricted to the positions where
        ``x1 - x2`` is non-zero.
    """
    # The index set is the same for all three arrays, so compute it once.
    keep = np.where((x1 - x2) != 0)
    return x1[keep], x2[keep], mesh[keep]

def create_pebbles(x1, mesh, npebbles):
    """Draw ``npebbles`` mesh locations with probability proportional to ``x1``.

    Mesh point ``mesh[i]`` holds ``int(x1[i])`` pebbles; ``npebbles`` of all
    pebbles are drawn uniformly at random, with replacement (using the
    ``random`` module, so ``random.seed`` controls reproducibility), and each
    drawn pebble is reported by the coordinates of its mesh point.

    Parameters:
        x1 (array-like):
            Pebble count per mesh point. Values are truncated to integers;
            negative counts are treated as zero.
        mesh (array-like, shape (n, ncoords)):
            Coordinates of the mesh points, one row per entry of ``x1``.
        npebbles (int):
            Number of pebbles to draw.

    Returns:
        numpy.ndarray: float64 array of shape ``(npebbles, ncoords)``.

    Raises:
        ValueError: if ``mesh`` is not 2D, if ``x1`` and ``mesh`` differ in
            length, or if there is no pebble to draw from.
    """
    points = np.asarray(mesh, dtype=float)
    if points.ndim != 2:
        raise ValueError("mesh must be 2D (one row per point), got shape {}".format(points.shape))

    counts = np.asarray(x1).reshape(-1)
    if len(counts) != len(points):
        raise ValueError(
            "x1 has {} entries but mesh has {} points".format(len(counts), len(points))
        )

    # Truncate like int() and clamp: a negative count would otherwise move the
    # write cursor backwards and overwrite pebbles of earlier mesh points.
    counts = np.maximum(counts.astype(int), 0)
    total = int(counts.sum())
    if total == 0:
        raise ValueError("no pebbles to sample from: all counts are zero")

    # One row per pebble; np.repeat initialises every row, unlike filling a
    # np.empty buffer whose size was computed separately from the counts.
    pool = np.repeat(points, counts, axis=0)

    # Drawing indices consumes the random stream exactly as drawing rows does.
    picks = choices(range(total), k=npebbles)

    return pool[picks]

def display_samples(ax, x, color):
    """Scatter-plot the rows of ``x`` on the axes ``ax``.

    Parameters:
        ax: object with a matplotlib-style ``scatter`` method.
        x: tensor whose first two columns are used as the horizontal and
            vertical coordinates.
        color: colour specification forwarded to ``scatter``.
    """
    points = x.detach().cpu().numpy()
    # Marker area shrinks as the number of points grows.
    marker_area = 25 * 500 / len(points)
    xs, ys = points[:, 0], points[:, 1]
    ax.scatter(xs, ys, marker_area, color, edgecolors="none")

## Gradient descent

In [None]:
# adapted from tutorials at:
# https://www.kernel-operations.io/geomloss/_auto_examples/
#         optimal_transport/plot_optimal_transport_2D.html
#         #sphx-glr-auto-examples-optimal-transport-plot-optimal-transport-2d-py

def gradient_descent(loss, Nsteps = 11, lr=1, disp=1, x_init=None, y_target=None):
    """Flows along the gradient of the loss function.

    Parameters:
        loss ((x_i,y_j) -> torch float number):
            Real-valued loss function.
        Nsteps (int, default = 11):
            Number of explicit Euler steps to perform.
        lr (float, default = 1):
            Learning rate, i.e. time step.
        disp (int, default = 1):
            If equal to 1, the point clouds are plotted at iterations
            0, 1, 4 and Nsteps-1 and saved as ``wass_mach_AE_<it>.png``.
        x_init (torch.Tensor, optional):
            Initial positions of the moving samples. Defaults to the
            notebook-level variable ``X_i``.
        y_target (torch.Tensor, optional):
            Positions of the fixed target samples. Defaults to the
            notebook-level variable ``Y_j``.

    Returns:
        numpy.ndarray: positions of the moving samples after ``Nsteps`` steps.
        The input tensors are not modified.
    """
    # Fall back to the notebook globals so existing calls keep working.
    source = X_i if x_init is None else x_init
    target = Y_j if y_target is None else y_target

    show = (disp == 1)
    display_its = [0, 1, 4, Nsteps-1]

    # Make sure that we won't modify the reference samples
    x_i, y_j = source.clone(), target.clone()

    # We're going to perform gradient descent on Loss(α, β)
    # wrt. the positions x_i of the diracs masses that make up α:
    x_i.requires_grad = True

    fig = None
    if show:
        # Use colors to identify the particles (only needed when plotting)
        colors = (70 * source[:, 0]).cos() * (70 * source[:, 1]).cos()
        colors = colors.detach().cpu().numpy()
        fig = plt.figure(figsize=(6, 6))

    for i in range(Nsteps):  # Euler scheme ===============
        # Compute cost and gradient
        loss_value = loss(x_i, y_j)
        [g] = torch.autograd.grad(loss_value, [x_i])

        if show and (i in display_its):  # display
            ax = plt.subplot(1,1,1)
            plt.set_cmap("hsv")
            plt.scatter(
                [10], [10]
            )  # shameless hack to prevent a slight change of axis...
            display_samples(ax, y_j, [(0.55, 0.55, 0.95)])
            display_samples(ax, x_i, colors)
            ax.set_title("it = {}".format(i))
            plt.axis([0, 1.69, -0.35, 0.35])
            plt.gca().set_aspect("equal", adjustable="box")
            plt.xticks([], [])
            plt.yticks([], [])
            plt.tight_layout()
            plt.savefig(f'wass_mach_AE_{i}.png',transparent = False, facecolor = 'white')
            # The figure is reused for the next snapshot, so wipe it.
            plt.clf()

        # in-place modification of the tensor's values
        # (the step is scaled by the number of samples, len(x_i))
        x_i.data -= lr * len(x_i) * g

    if fig is not None:
        # Every snapshot is already on disk; release the (now blank) figure
        # instead of leaving one open figure per call.
        plt.close(fig)

    return x_i.detach().cpu().numpy()
    

In [None]:
# dataset 
snap1 = load_hdf5_snapshot("data/mach3p5/su2_snapshot.000000")
snap2 = load_hdf5_snapshot("data/mach2p5/su2_snapshot.000000")
# mesh points (only the first two columns of the CSV are kept)
# Passing the path lets NumPy open and close the file itself; wrapping it in
# open(...) left the handle open.
xy = np.loadtxt("data/xy.csv", delimiter=",")
xy = xy[:,0:2]

# decompose full state vector into variable vectors
#xi, var12, var13, var14 = find_states(snap1, N, 4)
#yj, var22, var23, var24 = find_states(snap2, N, 4)
xi = snap1
yj = snap2

# parameters
tol = 1e-1  # tolerance for comparison of the two datasets (problem dependent)
M = 5000     # number of (randomly sampled) pebbles to use in algorithm (5000 max if using CPU)

# pre-process dataset
# Both datasets are shifted by the same offset, the minimum of the reference
# snapshot, so compute it once.
yj_min = min(yj)
Xi = create_weights(xi, 1/tol, yj_min)
Yj = create_weights(yj, 1/tol, yj_min)

Xi, Yj, xy = remove_zeros(Xi, Yj, xy)

# never ask for more pebbles than there are remaining mesh points
M2 = min(len(Xi), M)
print("M: {}".format(M2))

X_i = create_pebbles(Xi, xy, M2)
Y_j = create_pebbles(Yj, xy, M2)

X_i = torch.from_numpy(X_i).type(torch.FloatTensor)
Y_j = torch.from_numpy(Y_j).type(torch.FloatTensor)

In [None]:
xfinal = gradient_descent(SamplesLoss("sinkhorn", p=2, blur=0.01),15)

# find wasserstein cost
xinit = X_i.detach().cpu().numpy()
# Euclidean norm of the total displacement, i.e. the square root of the sum of
# squared per-point displacements, computed with NumPy instead of the builtin
# sum (same formula as used in the multi-file loop).
dw = np.linalg.norm(xinit - xfinal)
dw

# Wasserstein cost for multiple files

In [None]:
# loop over simulation data in parameter space
files = ["mach2", "mach3", "mach3p5", "mach4"]
# Passing the path lets NumPy open and close the file itself; wrapping it in
# open(...) left the handle open.
xy_orig = np.loadtxt("data/xy.csv", delimiter=",")
xy_orig = xy_orig[:,0:2]

n_nodes = 20200  # number of nodes extracted from every snapshot by find_states

snap2 = load_hdf5_snapshot("data/mach2p5/su2_snapshot.000000")
yj, var22, var23, var24 = find_states(snap2, n_nodes, 4)

# parameters
tol = 1e-3  # tolerance for comparison of the two datasets
M = 2000    # number of (randomly sampled) pebbles to use in algorithm
Niter = 5   # number of gradient descent iterations

# The reference offset and reference weights do not depend on the loop
# variable, so they are computed once instead of once per file.
yj_min = yj.min()
Yj_ref = create_weights(yj, 1/tol, yj_min)

d = np.empty([len(files), 1])

for i, name in enumerate(files):
    # dataset 
    snap1 = load_hdf5_snapshot("data/" + name + "/su2_snapshot.000000")

    # decompose full state vector into variable vectors
    xi, var12, var13, var14 = find_states(snap1, n_nodes, 4)
    
    # pre-process dataset
    Xi = create_weights(xi, 1/tol, yj_min)
    
    Xi, Yj, xy = remove_zeros(Xi, Yj_ref, xy_orig)
    
    # never ask for more pebbles than there are remaining mesh points
    M2 = min(len(Xi), M)
    
    X_i = create_pebbles(Xi, xy, M2)
    X_i = torch.from_numpy(X_i).type(torch.FloatTensor)
    
    if (i==0):
        # NOTE(review): the target pebbles are drawn once, from the mesh
        # points that survive remove_zeros for the FIRST file only, and are
        # reused for all later files. Confirm this is intended.
        Y_j = create_pebbles(Yj, xy, M2)
        Y_j = torch.from_numpy(Y_j).type(torch.FloatTensor)
    
    print("Parameter: {}, M: {}".format(name, M2))
    st = time.time()
    
    # gradient_descent reads the notebook-level X_i and Y_j set above
    xfinal = gradient_descent(SamplesLoss("sinkhorn", p=2, blur=0.01), Niter, disp=0)
    
    et = time.time()
    elapsed_time = et - st
    print('Execution time:', elapsed_time, 'seconds')
    
    # find wasserstein cost
    xinit = X_i.detach().cpu().numpy()
    d[i] = np.linalg.norm(xinit - xfinal)
    #d[i] = sum(np.sqrt((xinit[:,0] - xfinal[:,0])**2 + (xinit[:,1] - xfinal[:,1])**2))/ M2
    
    
    


In [None]:
# Plot the cost stored in ``d`` against the parameter value of each case.
fig, ax = plt.subplots(figsize=(12, 5))
labels = [2, 3, 3.5, 4]
ax.plot(labels, d)
ax.tick_params(axis='both', which='major', labelsize=20)
#ax.set_xticklabels(labels, rotation=45, ha='right')
fig.tight_layout()
