In [None]:
import numpy as np
import numpy.ma as ma
import gudhi as gd
from gudhi.wasserstein import wasserstein_distance as wd
from netCDF4 import Dataset
import matplotlib.pyplot as plt
import sys
import os
from glob import glob

In [None]:
def extract_PD_arrays_with_cubical_complex(data, infinity = 300.):
    """Compute dimension-0 and dimension-1 persistence diagrams of ``data``.

    A gudhi cubical complex is built with ``data`` as its top-dimensional
    cells, and the resulting persistence pairs are split by dimension.

    Parameters
    ----------
    data : array-like
        Values handed to ``gd.CubicalComplex`` as ``top_dimensional_cells``.
        Must provide a ``max()`` method.
    infinity : float, optional
        Accepted for backward compatibility but not used: infinite death
        values are replaced by ``data.max()``.
        NOTE(review): callers pass a value here; confirm whether infinite
        deaths were meant to be capped at ``infinity`` instead.

    Returns
    -------
    pd0, pd1 : numpy.ndarray
        Arrays of shape ``(n, 2)`` whose columns are (birth, death) for the
        pairs of dimension 0 and dimension 1 respectively. Either array may
        have zero rows.
    """
    cubical_complex = gd.CubicalComplex(top_dimensional_cells=data)
    pairs = cubical_complex.persistence()

    # Each persistence entry is indexed as (dimension, (birth, death)).
    dim = np.array([pair[0] for pair in pairs], dtype=int)
    birth = np.array([pair[1][0] for pair in pairs], dtype=float)
    death = np.array([pair[1][1] for pair in pairs], dtype=float)

    # np.inf rather than np.Infinity: the latter alias was removed in NumPy 2.0.
    death[death == np.inf] = data.max()

    is_dim0 = dim == 0
    is_dim1 = dim == 1
    pd0 = np.column_stack((birth[is_dim0], death[is_dim0]))
    pd1 = np.column_stack((birth[is_dim1], death[is_dim1]))
    return pd0, pd1

In [None]:
def calculate_wasserstein_distance(pd0_array1, pd0_array2, pd1_array1, pd1_array2, order=2):
    """Return the summed Wasserstein distance over two pairs of diagrams.

    The distance between ``pd0_array1`` and ``pd0_array2`` is added to the
    distance between ``pd1_array1`` and ``pd1_array2``; both are computed by
    gudhi's ``wasserstein_distance`` with the given ``order``.
    """
    distance_dim0 = wd(pd0_array1, pd0_array2, order=order)
    distance_dim1 = wd(pd1_array1, pd1_array2, order=order)
    return distance_dim0 + distance_dim1

In [None]:
def plot_PD (pd0, pd1, data_min, data_max):
    """Scatter-plot two persistence diagrams on shared birth/death axes.

    Parameters
    ----------
    pd0, pd1 : numpy.ndarray
        Two-column arrays; column 0 is plotted on the x axis (birth) and
        column 1 on the y axis (death). ``pd0`` is drawn as red dots labelled
        '0', ``pd1`` as hollow blue triangles labelled '1'.
    data_min, data_max : float
        Bounds used for both axes, widened outward by 5 %.
    """
    # Scale by 0.95/1.05 for non-negative bounds; for negative bounds the
    # factors are swapped so the padding still moves the limit outward
    # instead of clipping points near the edge.
    lower = data_min * (0.95 if data_min >= 0 else 1.05)
    upper = data_max * (1.05 if data_max >= 0 else 0.95)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(pd0[:,0], pd0[:,1], color='r', s=7, label='0')
    ax.scatter(pd1[:,0], pd1[:,1], marker='^', s=5, facecolors='none', edgecolor='blue', label='1')
    ax.set_xlim([lower, upper])
    ax.set_ylim([lower, upper])
    ax.set_xlabel('Birth')
    ax.set_ylabel('Death')
    ax.legend(loc='lower right')
    plt.show()

In [None]:
# TDA target: choose one of the listed values for each setting.

# Domain: AFR-44, EUR-11, NAM-44
domain = 'EUR-11'

# Variable: pr, rlut, rsds
variable = 'pr'

# Season: annual, summer, winter
season = 'annual'


In [None]:
cwd = os.getcwd()

# Pick the reference dataset for the chosen variable; any other variable
# stops execution with a message.
if variable == 'pr':
    ref_name = 'TRMM-L3'
elif variable in ('rlut', 'rsds'):
    ref_name = 'CERES-EBAF'
else:
    sys.exit('variable must be pr, rlut, or rsds')

# Data directory under the current working directory:
# <cwd>/evaluation_result/<domain>/<ref_name>/<variable>/<season>/
# The trailing empty element keeps the final '/' that later cells rely on.
datadir = '/'.join([cwd, 'evaluation_result', domain, ref_name, variable, season, ''])

In [None]:
# Reference file: <first 3 chars of domain>_<season>_<variable>_<ref_name>.nc
ref_file = datadir + '_'.join([domain[0:3], season, variable, ref_name]) + '.nc'

# Every other .nc file in datadir sharing the same
# <domain prefix>_<season>_<variable> stem is treated as a model file.
model_files = sorted(
    path
    for path in glob(datadir + '_'.join([domain[0:3], season, variable]) + '*.nc')
    if path != ref_file
)
nmodel = len(model_files)

In [None]:
# A model name is the file name without the shared
# "<domain prefix>_<season>_<variable>_" stem and without the ".nc" extension.
# The stem is built explicitly rather than with os.path.commonprefix, which
# compares character by character: it would also strip leading characters
# that the model names happen to share, and the whole name when only one
# model file exists.
prefix = domain[0:3] + '_' + season + '_' + variable + '_'
model_names = []
for path in model_files:
    name = os.path.basename(path)[:-3]   # drop '.nc' (guaranteed by the glob pattern)
    if name.startswith(prefix):
        name = name[len(prefix):]
    model_names.append(name)
print(model_names)
               

In [None]:
print('Reading subset datasets from '+datadir)
# Open the reference file in a context manager so the handle is closed as soon
# as the field has been read; slicing with [:] loads the values into memory.
with Dataset(ref_file) as f0:
    # Average over the first axis (presumably time -- TODO confirm).
    ref_data = ma.mean(f0.variables[variable][:], axis=0)
ref0, ref1 = extract_PD_arrays_with_cubical_complex(ref_data, infinity=ref_data.max())
plot_PD (ref0, ref1, data_min=ref_data.min(), data_max=ref_data.max())

In [None]:
# For every model: read its field, build its persistence diagrams, print the
# order-2 Wasserstein distance to the reference diagrams and plot the diagrams
# on the reference's axis range.
for imodel, model_file in enumerate(model_files):
    # Context manager closes each file after reading, so handles do not pile up.
    with Dataset(model_file) as f1:
        model_data = ma.mean(f1.variables[variable][:], axis=0)
    model0, model1 = extract_PD_arrays_with_cubical_complex(model_data, infinity=ref_data.max())
    w2 = calculate_wasserstein_distance(ref0, model0, ref1, model1, order=2)
    # '%10.3f' (width 10, 3 decimals); the former '%f10.3' printed a 6-decimal
    # float followed by the literal text '10.3'.
    print(model_names[imodel], ' W2 is %10.3f' % w2)
    plot_PD (model0, model1, data_min=ref_data.min(), data_max=ref_data.max())