In [15]:
# %load defaults.ipy
from numpy import *
import matplotlib
matplotlib.rcParams['savefig.dpi'] = 300
%matplotlib inline
import matplotlib.pyplot as plt
import netCDF4
from IPython.core.display import display, HTML
import matplotlib2tikz
import os
import h5py
import ot
import sys

# Drop every module search path entry whose name mentions matplotlib or netcdf.
# The list is rebuilt in a single pass and written back in place. Calling
# sys.path.remove() while iterating over sys.path makes the loop skip the entry
# that follows each removed one, and an entry matching both keywords would be
# removed twice and raise ValueError.
sys.path[:] = [
    p for p in sys.path
    if 'matplotlib' not in p.lower() and 'netcdf' not in p.lower()
]

from mpl_toolkits.mplot3d import Axes3D
def showAndSave(name, savetikz=True):
    """
    Save the current matplotlib figure, display it and close all figures.

    The figure is written to ``img/<name>.png`` and, when ``savetikz`` is
    true, also to ``img_tikz/<name>.tikz`` through matplotlib2tikz. The output
    directories are created when they do not exist yet, so the first call in a
    fresh working directory no longer fails.

    Parameters
    ----------
    name : str
        Base name of the output files. Every '.', '=' and ' ' is stripped
        from it before it is used.
    savetikz : bool, optional
        Whether to also export the figure as TikZ (default True).
    """
    name = name.replace('.', '').replace('=', '').replace(' ', '')

    print("Saving %s" % name)
    if savetikz:
        if not os.path.isdir('img_tikz'):
            os.makedirs('img_tikz')
        # The figure size is left to the LaTeX document via these two macros.
        matplotlib2tikz.save('img_tikz/' + name + '.tikz',
                             figureheight='\\figureheight',
                             figurewidth='\\figurewidth')

    if not os.path.isdir('img'):
        os.makedirs('img')
    plt.savefig('img/' + name + '.png', bbox_inches='tight')
    plt.show()
    # Close everything so figures do not pile up when called in a loop.
    plt.close('all')

def load(f, sample):
    """
    Read the ``rho`` field of one sample and return its ``[:, :, 0]`` slice.

    Parameters
    ----------
    f : str
        If the string contains '.nc' it is opened as a netCDF file and the
        variable ``sample_<sample>_rho`` is read. Otherwise it is treated as a
        directory and the HDF5 file ``<f>/kh_<sample>/kh_1.h5`` is opened and
        its ``rho`` dataset is read.
    sample : int
        Index of the sample to read.

    Returns
    -------
    The first slice along the third axis of the selected field.
    """
    if '.nc' in f:
        with netCDF4.Dataset(f) as d:
            return d.variables['sample_%d_rho' % sample][:, :, 0]
    else:
        f = os.path.join(f, 'kh_%d/kh_1.h5' % sample)
        # Open read-only explicitly: h5py releases before 3.0 default to
        # append mode, which can create a missing file or modify an existing
        # one even though this function only reads.
        with h5py.File(f, 'r') as d:
            return d['rho'][:, :, 0]

# Histogram plotting

In [10]:
def plot_histograms2(N, M, perturbation, name, minValue, maxValue, x, y, xp, yp, valuesx, valuesy):
    """
    Plot the joint histogram of the values observed at two spatial points.

    Two figures are produced and passed to ``showAndSave``: a 2D colour map of
    the normalised histogram and a 3D surface of the same histogram.

    Parameters
    ----------
    N : int
        Resolution; used in the titles and in the output file names.
    M : int
        Not used by this function; kept so existing callers keep working.
    perturbation : float
        Perturbation size; used in the titles and in the output file names.
    name : str
        Name of the experiment; used in the titles and in the file names.
    minValue, maxValue : float
        Histogram range, applied to both axes.
    x, y : float
        Coordinates of the first point (labels and file names only).
    xp, yp : float
        Coordinates of the second point (labels and file names only).
    valuesx, valuesy : sequence of float
        Values at the first and at the second point, one entry per sample.
    """
    value_range = [[minValue, maxValue], [minValue, maxValue]]
    xlabel = 'Value of $\\rho(%.2f,%.2f)$' % (x, y)
    ylabel = 'Value of $\\rho(%.2f,%.2f)$' % (xp, yp)
    title = ('Histogram at resolution %d, $\\epsilon = %.4f$, for %s,\n'
             'between $(%.2f, %.2f)$ and $(%.2f, %.2f)$'
             % (N, perturbation, name, x, y, xp, yp))

    # ``density`` replaces the ``normed`` keyword, which current NumPy and
    # Matplotlib releases no longer accept.
    plt.hist2d(valuesx, valuesy, bins=20, density=True, range=value_range)
    plt.colorbar()
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)

    showAndSave('hist2pt_perturbation_%s_%.5f_%d_%.1f_%.1f_%.1f_%.1f' % (name, perturbation, N, x, y, xp, yp))

    H, xedges, yedges = histogram2d(valuesx, valuesy, bins=20, density=True, range=value_range)

    fig = plt.figure()
    # add_subplot works on old and new Matplotlib alike; passing ``projection``
    # to Figure.gca() is rejected by recent releases.
    ax = fig.add_subplot(111, projection='3d')
    Xvalues, Yvalues = meshgrid(xedges[:-1], yedges[:-1])
    # histogram2d indexes H as [x_bin, y_bin] while meshgrid returns arrays
    # indexed [y, x]; transpose so every height sits above its own bin.
    ax.plot_surface(Xvalues, Yvalues, H.T)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    showAndSave('hist2pt_perturbation_surface_%s_%.5f_%d_%.1f_%.1f_%.1f_%.1f' % (name, perturbation, N, x, y, xp, yp), savetikz=False)

In [30]:
def _load_point_samples(filename, indices, samples):
    """
    Collect the field values at the grid points ``indices x indices``.

    Returns an array of shape ``(len(indices), len(indices), samples)`` whose
    entry ``[a, b, k]`` is the value of sample ``k`` at grid position
    ``(indices[a], indices[b])``.
    """
    values = zeros((len(indices), len(indices), samples))
    for k in range(samples):
        field = load(filename, k)
        values[:, :, k] = field[ix_(indices, indices)]
    return values


def plotHistograms(name, resolution, perturbations, basename, samples):
    """
    Plot two-point histograms for every perturbation and every pair of points.

    The points are taken from a fixed 4 x 4 grid of relative coordinates. For
    each pair of points the histogram range is the minimum and maximum over
    all perturbations and all samples, so that the histograms of one pair are
    directly comparable across perturbations.

    Only the values at the sampled grid points are kept in memory and every
    file is read once, instead of loading the complete
    ``resolution x resolution x samples`` field twice per perturbation.

    Parameters
    ----------
    name : str
        Name of the experiment, forwarded to ``plot_histograms2``.
    resolution : int
        Grid resolution; relative coordinates are multiplied by it and
        truncated to obtain grid indices.
    perturbations : sequence of float
        Perturbation sizes; each is substituted into ``basename``.
    basename : str
        File name template with a ``{perturbation}`` placeholder.
    samples : int
        Number of samples to read per perturbation.
    """
    points = [0.2, 0.4, 0.7, 0.8]
    indices = [int(point * resolution) for point in points]

    if len(perturbations) == 0:
        # Nothing to load and nothing to plot.
        return

    point_samples = [
        _load_point_samples(basename.format(perturbation=p), indices, samples)
        for p in perturbations
    ]

    # Extremes per grid point over all perturbations (axis 0) and all samples
    # (axis 3); the range of a pair is then the extreme of its two points.
    stacked = array(point_samples)
    point_min = stacked.min(axis=(0, 3))
    point_max = stacked.max(axis=(0, 3))

    for p, values in zip(perturbations, point_samples):
        for a, x in enumerate(points):
            for b, y in enumerate(points):
                for ap, xp in enumerate(points):
                    for bp, yp in enumerate(points):
                        plot_histograms2(
                            resolution, samples, p, name,
                            minimum(point_min[a, b], point_min[ap, bp]),
                            maximum(point_max[a, b], point_max[ap, bp]),
                            x, y, xp, yp,
                            values[a, b, :], values[ap, bp, :])

        

# Computing Wasserstein distances

In [20]:
def wasserstein_point2_fast(data1, data2, i, j, ip, jp, a, b, xs, xt):
    """
    Computes the Wasserstein distance for a single pair of points in the
    spatial domain.

    For every sample the values of ``data1`` at (i, j) and (ip, jp) form one
    2D source point, and the values of ``data2`` at the same positions form
    one 2D target point. ``a`` and ``b`` are the weights of the source and
    target points. ``xs`` and ``xt`` are caller-supplied buffers with two
    columns that are overwritten on every call.
    """
    locations = ((i, j), (ip, jp))

    # Fill the source buffer first, then the target buffer.
    for column, (row, col) in enumerate(locations):
        xs[:, column] = data1[row, col, :]
    for column, (row, col) in enumerate(locations):
        xt[:, column] = data2[row, col, :]

    cost = ot.dist(xs, xt, metric='euclidean')
    plan = ot.emd(a, b, cost)

    # Total cost of the plan: sum over all entries of plan * cost.
    return (plan * cost).sum()

def wasserstein2pt_fast(data1, data2):
    """
    Approximate the L^1(W_1) distance (||W_1(nu1, nu2)||_{L^1})

    The distance of ``wasserstein_point2_fast`` is averaged over all pairs of
    points taken from a 10 x 10 grid of relative positions 0.0, 0.1, ..., 0.9.
    The sample count is read from the third axis of ``data1`` and the grid
    size from its first axis.
    """
    sample_count = data1.shape[2]
    grid_size = data1.shape[0]

    # Uniform weights and the work buffers shared by all point pairs.
    a = ones(sample_count) / sample_count
    b = ones(sample_count) / sample_count
    xs = zeros((sample_count, 2))
    xt = zeros((sample_count, 2))

    points = 0.1 * array(range(0, 10))
    # Relative positions converted to grid indices once, up front.
    indices = [int(point * grid_size) for point in points]

    distance = 0
    for i in indices:
        for j in indices:
            for ip in indices:
                for jp in indices:
                    distance += wasserstein_point2_fast(data1, data2, i, j, ip, jp, a, b, xs, xt)

    return distance / len(points)**4


def plotWassersteinConvergence(name, basename, r, perturbations, samples=None):
    """
    Plot the two-point Wasserstein distance to a reference perturbation.

    The last entry of ``perturbations`` is the reference. Every other entry is
    compared against it with ``wasserstein2pt_fast`` and the distances are
    drawn on a log-log plot against the perturbation they belong to.

    Parameters
    ----------
    name : str
        Name of the experiment; used in the title and in the file name.
    basename : str
        File name template with a ``{perturbation}`` placeholder.
    r : int
        Grid resolution of the loaded fields.
    perturbations : sequence of float
        Perturbation sizes; the last one is the reference.
    samples : int, optional
        Number of samples to load per perturbation. Defaults to ``r``, which
        is what this function always used before the parameter existed.
    """
    if samples is None:
        samples = r

    compared = perturbations[:-1]
    wasserstein2pterrors = []

    if len(compared) > 0:
        # The reference data does not depend on the loop variable: load it once.
        filename_coarse = basename.format(perturbation=perturbations[-1])
        data2 = zeros((r, r, samples))
        for k in range(samples):
            data2[:, :, k] = load(filename_coarse, k)

        # One buffer, refilled for every perturbation.
        data1 = zeros((r, r, samples))
        for p in compared:
            filename = basename.format(perturbation=p)
            for k in range(samples):
                data1[:, :, k] = load(filename, k)

            wasserstein2pterrors.append(wasserstein2pt_fast(data1, data2))
            print("wasserstein2pterrors=%s" % wasserstein2pterrors)

    # Error k belongs to perturbations[k], so the x values are
    # perturbations[:-1]; perturbations[1:] would shift every point by one and
    # place the last error on the reference perturbation itself.
    plt.loglog(compared, wasserstein2pterrors, '-o')
    plt.xlabel("Perturbation $\\epsilon$")
    plt.ylabel('$||W_1(\\nu^{2, \\Delta x, \\epsilon}, \\nu^{2,\\Delta x, \\epsilon_0})||_{L^1(D\\times D)}$')
    plt.title("Wasserstein convergence for %s\nfor second correlation measure,\nwith respect to perturbation size\nagainst a reference solution with $\\epsilon_0=%.4f$" % (name, perturbations[-1]))
    showAndSave('%s_wasserstein_perturbation_convergence_2pt_all' % name)

# Kelvin-Helmholtz

In [None]:
# Experiment settings for the Kelvin-Helmholtz perturbation study.
name = 'Kelvin-Helmholtz'
basename = 'kh_perts/q{perturbation}/kh_1.nc'
resolution = 1024
samples = 1024
perturbations = [
    0.09, 0.075, 0.06, 0.05, 0.025, 0.01, 0.0075, 0.005, 0.0025,
]

# Convergence plot first, then the two-point histograms.
plotWassersteinConvergence(name=name, basename=basename, r=resolution,
                           perturbations=perturbations)
plotHistograms(name=name, resolution=resolution, perturbations=perturbations,
               basename=basename, samples=samples)


In [None]:
# Perturbation sizes, largest first.
perturbations = sorted([0.01, 0.025, 0.05, 0.06, 0.075, 0.09], reverse=True)
print(perturbations)

basename = 'kh_conv/q{perturbation}/kh_1.nc'

# Print the file name that belongs to each perturbation.
for p in perturbations:
    print(basename.format(perturbation=p))