<h1> Demo. Data Evaluation </h1>

An example of SSH reconstruction has been produced in the "example_data_oi.ipynb" notebook. Here, an example of data evaluation is proposed. The notebook is structured as follows: 1) reading of the reference and reconstructed SSH fields, 2) regridding of both fields onto the same spatio-temporal grid, 3) comparison of the reconstructed and reference SSH fields, and 4) evaluation of the reconstruction (statistical/spectral scores).

In [None]:
import xarray as xr
import cftime
import geoviews as gv
import matplotlib.pylab as plt
from matplotlib.ticker import ScalarFormatter

import numpy as np
from datetime import datetime, timedelta
import numpy
import pyinterp
import holoviews as hv
import xrft

from dask.diagnostics import ProgressBar

# Load the bokeh plotting extension for geoviews (used for the interactive maps below)
gv.extension('bokeh')
from matplotlib import rcParams
# Sans-serif font family used by the matplotlib figures of this notebook
rcParams['font.sans-serif'] = 'TeX Gyre Heros'

#### 1) Read reference & reconstructed SSH fields

##### Read reference SSH field

In [None]:
# Open every NetCDF file of ./dc_ref lazily and concatenate them along 'time'
dc_ref = xr.open_mfdataset(
    './dc_ref/*.nc',
    concat_dim='time',
    combine='nested',
)
dc_ref

##### Read reconstructed SSH field

In [None]:
# Open the reconstructed SSH file; variables are combined using their coordinates
dc_reconstruction = xr.open_mfdataset(
    'ssh_rec.nc',
    combine='by_coords',
)
dc_reconstruction

#### 2) Regridding: put reconstructed & reference SSH fields onto the same grid

##### *The reconstructed SSH is not on the same spatio-temporal grid as the reference field.*
*A regridding onto a common spatio-temporal grid is required for the comparison. Here we have made the choice to "temporally degrade" the reference SSH field into daily-mean samples and to interpolate the reconstructed field onto this new reference spatio-temporal grid. For this, the **pyinterp package** is used.*  

In [None]:
# Group the reference dataset into 1-day bins along 'time', then average each bin
daily_bins = dc_ref.resample(time='1D')
dc_ref_sample = daily_bins.mean()
dc_ref_sample

##### Define reconstruction grid (source grid to interpolate)

In [None]:
# Coordinates of the reconstruction (source) grid
rec_lon = dc_reconstruction["lon"][:]
rec_lat = dc_reconstruction["lat"][:]
rec_time = dc_reconstruction["time"][:]

# pyinterp axes describing the source grid; the longitude axis is declared
# non-circular (is_circle=False)
x_axis = pyinterp.Axis(rec_lon, is_circle=False)
y_axis = pyinterp.Axis(rec_lat)
z_axis = pyinterp.TemporalAxis(rec_time)

In [None]:
# Reverse the dimension order of the reconstructed SSH so that it lines up with
# the (x_axis, y_axis, z_axis) order handed to pyinterp.Grid3D.
# NOTE(review): presumably (time, lat, lon) -> (lon, lat, time) — confirm on the file.
ssh_rec = dc_reconstruction["ssh_rec"][:].transpose()

In [None]:
# Source grid for the interpolation: the three axes plus the raw SSH array.
# NOTE(review): `ssh_rec` is reassigned by the interpolation cell further down,
# so this cell must be run before it (re-run the transpose cell first otherwise).
rec_values = ssh_rec.data
grid = pyinterp.Grid3D(x_axis, y_axis, z_axis, rec_values)

##### Define reference grid (target grid to interpolate onto)

In [None]:
# Coordinates of the target (daily reference) grid; the times are cast with
# z_axis.safe_cast so that they are expressed like the source temporal axis
target_lon = dc_ref_sample.lon.values
target_lat = dc_ref_sample.lat.values
target_time = z_axis.safe_cast(dc_ref_sample.time.values)

# 3-D coordinate arrays, matrix ("ij") indexing -> shape (lon, lat, time)
mx, my, mz = np.meshgrid(target_lon, target_lat, target_time, indexing="ij")

##### Interpolate...

In [None]:
# Interpolate the source grid at every target point (flattened coordinates).
# bounds_error=True: a target point outside the source grid raises instead of
# being silently filled.
interpolated = pyinterp.trivariate(
    grid,
    mx.flatten(),
    my.flatten(),
    mz.flatten(),
    bounds_error=True,
)
# Back to the meshgrid shape, then reverse the axes order.
# NOTE: this overwrites the `ssh_rec` DataArray defined earlier with a plain array.
ssh_rec = interpolated.reshape(mx.shape).T

##### Save the SSH reconstruction interpolated onto the reference spatio-temporal grid into *dc_reconstruction_interp* dataset 

In [None]:
# Wrap the interpolated array in a dataset carrying the reference coordinates;
# the variable is named 'sossheig', like the reference SSH variable.
interp_coords = {
    'time': dc_ref_sample.time.values,
    'lon': dc_ref_sample.lon.values,
    'lat': dc_ref_sample.lat.values,
}
interp_vars = {'sossheig': (('time', 'lat', 'lon'), ssh_rec)}
dc_reconstruction_interp = xr.Dataset(interp_vars, coords=interp_coords)
dc_reconstruction_interp

#### 3) Comparison between reference and reconstructed SSH fields

##### Plot example...

In [None]:
# Date of the snapshot; the nearest available time step is picked in each dataset
time_selection = '2013-01-01T23:00:00'
ssh_plot_opts = dict(vmin=-0.2, vmax=1, cmap='gist_stern')
rec_snapshot = dc_reconstruction_interp.sossheig.sel(time=time_selection, method='nearest')
ref_snapshot = dc_ref_sample.sossheig.sel(time=time_selection, method='nearest')

# Side-by-side maps: reconstruction (left) and reference (right), same colour scale
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
rec_snapshot.plot(**ssh_plot_opts)
plt.title('Reconstruction')
plt.subplot(1, 2, 2)
ref_snapshot.plot(**ssh_plot_opts)
plt.title('Reference')
# plt.savefig('example_ssh.png')

In [None]:

# All fields are block-averaged (6 x 6 x 6 in lon/lat/time, incomplete blocks
# trimmed) before display, otherwise they are too heavy for the interactive viewer
coarsen_windows = {'lon': 6, 'lat': 6, 'time': 6}
map_dims = ['lon', 'lat', 'time']
ssh_range = dict(range=(-0.2, 1.))

# SSH reconstruction
rec_coarse = dc_reconstruction_interp.coarsen(coarsen_windows, boundary="trim").mean()
dataset2 = gv.Dataset(rec_coarse, map_dims, 'sossheig')
images2 = dataset2.to(gv.Image).redim(sossheig=ssh_range)

# SSH reference
ref_coarse = dc_ref_sample.coarsen(coarsen_windows, boundary="trim").mean()
dataset3 = gv.Dataset(ref_coarse, map_dims, 'sossheig')
images3 = dataset3.to(gv.Image).redim(sossheig=ssh_range)

# delta SSH (reconstruction - reference), coarsened the same way
ssh_difference = dc_reconstruction_interp - dc_ref_sample
delta_ssh = ssh_difference.coarsen(coarsen_windows, boundary="trim").mean()
delta_ssh = delta_ssh.rename({'sossheig': 'delta_ssh'})

dataset1 = gv.Dataset(delta_ssh, map_dims, 'delta_ssh')
images1 = dataset1.to(gv.Image).redim(delta_ssh=dict(range=(-0.3, 0.3)))

In [None]:
# Three panels laid out on two columns: reference, reconstruction, difference
panel_size = dict(width=400, height=300)
reference_panel = images3.opts(cmap='gist_stern', colorbar=True, title='SSH reference', toolbar='above', **panel_size)
reconstruction_panel = images2.opts(cmap='gist_stern', colorbar=True, title='SSH reconstruction', toolbar='above', **panel_size)
difference_panel = images1.opts(cmap='coolwarm', colorbar=True, title='SSH reconstruction - reference', **panel_size)
layout = hv.Layout(reference_panel + reconstruction_panel + difference_panel).cols(2)

In [None]:
# Render the layout as the cell output
layout

#### 4) Evaluation of the reconstructed SSH fields

##### Compute RMSE-based score

In [None]:
# Normalised RMSE time series: at each time step, the spatial (lon, lat) RMS of
# the reconstruction error divided by the spatial RMS of the reference SSH
horizontal_dims = ('lon', 'lat')
ssh_error = dc_reconstruction_interp.sossheig - dc_ref_sample.sossheig
rms_error_t = ((ssh_error**2).mean(dim=horizontal_dims))**0.5
rms_reference_t = (((dc_ref_sample.sossheig)**2).mean(dim=horizontal_dims))**0.5
rmse_t = rms_error_t/rms_reference_t

In [None]:
label_style = dict(fontweight='bold', fontsize=20)
tick_fontsize = 18

plt.figure(figsize=(15, 5))

# Left axis: score 1 - normalised RMSE as a function of time
rmse_score_t = 1.0 - rmse_t
rmse_score_t.plot(color='r', lw=3)
plt.ylabel('RMSEs(t)', color='r', **label_style)
plt.xlabel('Time', **label_style)
plt.ylim(0, 1)
plt.xticks(fontsize=tick_fontsize)
plt.yticks(fontsize=tick_fontsize)
plt.grid(ls='--')

# Right axis (twin): 'nobs' variable of the reconstruction file, as grey bars
plt.twinx()
plt.bar(dc_reconstruction.time.values, dc_reconstruction.nobs.values,
        color='grey', alpha=0.3, zorder=1)
plt.ylabel('nobs(t)', color='grey', **label_style)
plt.ylim(0, 10000)
plt.yticks(fontsize=tick_fontsize)
plt.title('RMSE-based score', **label_style)
plt.savefig('rmse_t.png')

##### Compute PSD-based score

In [None]:
def _prepare_for_psd(field):
    """Chunk `field` per latitude row and replace its time axis by a sample index.

    The returned array has one chunk per latitude (spanning the whole time and
    lon extents) and a 'time' coordinate equal to 0., 1., 2., ... Since the
    fields compared here are daily means, one step corresponds to one day.
    """
    chunked = field.chunk({"lat": 1, 'time': field.time.size, 'lon': field.lon.size})
    chunked['time'] = np.arange(0, chunked['time'].values.size, 1.)
    return chunked


with ProgressBar():

    # Reconstruction error and reference signal, prepared the same way
    err = _prepare_for_psd(dc_reconstruction_interp.sossheig - dc_ref_sample.sossheig)
    signal = _prepare_for_psd(dc_ref_sample.sossheig)

    # 2-D (time, lon) power spectra, one per latitude, with linear detrending
    # and windowing enabled
    psd_err = xrft.power_spectrum(
        err, dim=['time', 'lon'], detrend='linear', window=True
    ).compute()
    psd_signal = xrft.power_spectrum(
        signal, dim=['time', 'lon'], detrend='linear', window=True
    ).compute()
    

In [None]:
# Average the spectra over latitude and keep only the strictly positive
# frequencies in both lon and time (the other entries are dropped)
signal_positive = (psd_signal.freq_lon > 0.) & (psd_signal.freq_time > 0)
err_positive = (psd_err.freq_lon > 0.) & (psd_err.freq_time > 0)

mean_psd_signal = psd_signal.mean(dim='lat').where(signal_positive, drop=True)
mean_psd_err = psd_err.mean(dim='lat').where(err_positive, drop=True)

In [None]:
# PSD-based score: 1 - PSD(error) / PSD(reference), computed once and reused
# for both the filled contours and the 0.5 iso-line.
psd_score = 1.0 - mean_psd_err/mean_psd_signal
# The axes show the inverse of the spectral coordinates, i.e. a wavelength
# (along lon) and a period (along time), not a wavenumber / frequency.
wavelength_lon = 1./(mean_psd_signal.freq_lon)
period_time = 1./mean_psd_signal.freq_time

label_style = dict(fontweight='bold', fontsize=20)

plt.figure(figsize=(8, 6))
ax = plt.gca()
ax.invert_yaxis()
ax.invert_xaxis()
c1 = plt.contourf(wavelength_lon, period_time, psd_score,
                  levels=np.arange(0, 1.1, 0.1), cmap='RdYlGn', extend='both')
cbar = plt.colorbar(pad=0.01)
# Labels fixed: they previously read "wavenumber" / "frequency" although the
# plotted quantities are 1/freq_lon and 1/freq_time.
plt.xlabel('wavelength (degree_lon)', **label_style)
plt.ylabel('period (days)', **label_style)
#plt.xscale('log')
plt.yscale('log')
plt.grid(linestyle='--', lw=1, color='w')
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.title('PSD-based score', **label_style)
# Plain numbers rather than powers of ten on the (log) axis ticks
for axis in [ax.xaxis, ax.yaxis]:
    axis.set_major_formatter(ScalarFormatter())
# Black iso-line where the score equals 0.5, also reported on the colorbar
c2 = plt.contour(wavelength_lon, period_time, psd_score,
                 levels=[0.5], linewidths=2, colors='k')
cbar.add_lines(c2)

# Boxed labels with arrows, placed to the right of the axes (axes-fraction
# coordinates): (text, arrow tip y, text y)
bbox_props = dict(boxstyle="round,pad=0.5", fc="w", ec="k", lw=2)
scale_labels = (
    ('Resolved scales', 0.8, 0.55),
    ('UN-resolved scales', 0.2, 0.45),
)
for label_text, arrow_tip_y, label_y in scale_labels:
    ax.annotate(label_text,
                xy=(1.15, arrow_tip_y),
                xycoords='axes fraction',
                xytext=(1.15, label_y),
                bbox=bbox_props,
                arrowprops=dict(facecolor='black', shrink=0.05),
                horizontalalignment='left',
                verticalalignment='center')