# Evaluation of the MIOST SSH reconstruction

In [None]:
import xarray as xr
import numpy
import hvplot.xarray
import pyinterp
import dask
import warnings
import xrft
import os
import sys
import pandas as pd
import logging
# Suppress every warning for the rest of the session.
# NOTE(review): this also hides deprecation notices from the libraries imported here.
warnings.filterwarnings('ignore')
import numpy as np

##  1. Input files

##### Library versions

In [None]:
# Report the version of every package this notebook relies on, one per line.
for _label, _module in (
    ('xarray', xr),
    ('numpy', numpy),
    ('hvplot', hvplot),
    ('pyinterp', pyinterp),
    ('dask', dask),
    ('logging', logging),
    ('xrft', xrft),
):
    print(_label, _module.__version__)

In [None]:
# Configure the root logger (getLogger() without a name) to emit INFO-level
# records and above.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
# Add the parent directory to the module search path; presumably so that the
# `src` package imported in the next cell can be resolved.
sys.path.append('..')

In [None]:
# Project helpers. These star imports are expected to provide the names used
# further down that are not defined in this notebook: rec_regrid,
# rmse_based_scores, psd_based_scores, plot_psd_score and
# plot_psd_score_seasonal (which module exports which — TODO confirm).
from src.mod_regrid import *
from src.mod_eval import *
from src.mod_plot import *

### 1.1 Read Nature run SSH for mapping evaluation

In [None]:
if not os.path.exists('../ds_ref/ds_ref_1_20.nc'):
    print('ds_ref file not found...')
    print('download it...')
    os.makedirs('../ds_ref/',exist_ok=True)
    # Get nature run (it may take several minutes depending on your connection!!!!)
    !wget -O ds_ref_1_20.nc https://s3.eu-central-1.wasabisys.com/melody/data_challenge_Daniel_Guillaume/public/dc_ref/NATL60-CJM165-daily-reg-1_20.nc
ds_ref = xr.open_dataset('./ds_ref_1_20.nc')
ds_ref

### 1.2 Read MIOST SSH reconstruction

In [None]:
if not os.path.exists('../ds_rec/ds_rec_miost_1_20.nc'):
    print('ds_rec file not found...')
    print('download it...')
    !wget -O ../ds_rec/ds_rec_miost_1_20.nc https://s3.eu-central-1.wasabisys.com/melody/data_challenge_Daniel_Guillaume/public/mappings/miost/1_20/mapping_miost_Natl60_1_20_alg_c2_h2ag_h2b_j3_s3a_s3b_component_geos_lwe_barotrop.nc

ds_rec = xr.open_dataset('../ds_rec/ds_rec_miost_1_20.nc')
ds_rec

In [None]:
# Rename the reconstruction's coordinates to `lon`/`lat`, the names used to
# index both datasets in the cells below.
ds_rec=ds_rec.rename({'longitude':'lon', 'latitude':'lat'})

In [None]:
# Crop both datasets to the same box: 27..46 in lat and -66..-47 in lon.
ds_ref = ds_ref.sel(lat=slice(27, 46), lon=slice(-66, -47))
ds_rec = ds_rec.sel(lat=slice(27, 46), lon=slice(-66, -47))

In [None]:
ds_ref

In [None]:
ds_rec

### 1.3 Regrid SSH reconstructions onto nature run grid (if needed)

In [None]:
# Regrid the reconstruction's `sla` field using ds_ref as the target.
# NOTE(review): rec_regrid is a project helper; per the section title it puts
# ds_rec on the nature-run grid — confirm details in src.mod_regrid.
ds_rec = rec_regrid(ds_rec, ds_ref, field='sla')

In [None]:
ds_rec

In [None]:
ds_ref

In [None]:
# Add an `ssh` variable to the reconstruction: its `sla` plus the reference
# dataset's `mdt` (presumably the mean dynamic topography — TODO confirm).
ds_rec=ds_rec.assign(ssh=ds_rec.sla+ds_ref.mdt)

In [None]:
ds_ref

## 2. Select periods (whole year, seasons) and domain of evaluation

In [None]:
# Evaluation periods: the whole experiment window, plus four seasonal
# sub-windows that each span 40 days between their first and last dates.
Whole_year = slice(np.datetime64('2012-10-21'), np.datetime64('2013-09-10'))
Mid_autumn = slice(np.datetime64('2012-10-21'), np.datetime64('2012-11-30'))
Mid_winter = slice(np.datetime64('2013-02-01'), np.datetime64('2013-03-13'))
Mid_spring = slice(np.datetime64('2013-04-30'), np.datetime64('2013-06-09'))
Mid_summer = slice(np.datetime64('2013-07-11'), np.datetime64('2013-08-20'))

# Bounds of the analysis domain ("GF"), used as lon/lat slice limits below.
lon_min, lon_max = -64, -49
lat_min, lat_max = 29, 44

In [None]:
# Spatial indexers shared by every selection below.
_domain = dict(lon=slice(lon_min, lon_max), lat=slice(lat_min, lat_max))
# Seasonal windows, in the order used for the seasonal leaderboard.
_season_windows = (Mid_autumn, Mid_winter, Mid_spring, Mid_summer)

## whole year
ds_ref_whole_year = ds_ref.sel(time=Whole_year, drop=True, **_domain)
ds_rec_whole_year = ds_rec.sel(time=Whole_year, drop=True, **_domain)

## seasons: one subset per window, for the reference and the reconstruction
ds_ref_seasonal = [ds_ref.sel(time=window, drop=True, **_domain)
                   for window in _season_windows]
ds_rec_seasonal = [ds_rec.sel(time=window, drop=True, **_domain)
                   for window in _season_windows]

In [None]:
# Sanity check: within each list, all seasonal subsets must contain the same
# number of time steps.
date_seasonal_ref = np.asarray([season_ds.time.size for season_ds in ds_ref_seasonal])
date_seasonal_rec = np.asarray([season_ds.time.size for season_ds in ds_rec_seasonal])
if (not np.all(date_seasonal_ref == date_seasonal_ref[0])
        or not np.all(date_seasonal_rec == date_seasonal_rec[0])):
    raise ValueError('Every seasonal datasets must have the same size along time axis ')

In [None]:
ds_rec_whole_year

In [None]:
ds_ref_whole_year

##  3. Evaluation MIOST

In [None]:
# Directory that receives the evaluation outputs. makedirs(exist_ok=True)
# creates it (and any missing parent) without a separate existence check.
output_directory = '../results/'
os.makedirs(output_directory, exist_ok=True)

In [None]:
!pip install tabulate

### 3.1 Whole year evaluation

In [None]:
# Eval: RMSE-based and PSD-based scores over the whole evaluation period.
# Both helpers come from the project modules and take (reconstruction, reference).
rmse_t, rmse_xy, leaderboard_nrmse, leaderboard_nrmse_std = rmse_based_scores(ds_rec_whole_year, ds_ref_whole_year)
psd, leaderboard_psds_score, leaderboard_psdt_score  = psd_based_scores(ds_rec_whole_year, ds_ref_whole_year)


# Target files for the optional save step below.
filename_rmse_t = output_directory + 'rmse_t_miost_ssh_reconstruction_2012-10-21-2013-09-10_alg_c2_h2ag_h2b_j3_s3a_s3b.nc'
filename_rmse_xy = output_directory + 'rmse_xy_miost_ssh_reconstruction_2012-10-21-2013-09-10_alg_c2_h2ag_h2b_j3_s3a_s3b.nc'
filename_psd = output_directory + 'psd_miost_ssh_reconstruction_2012-10-21-2013-09-10_alg_c2_h2ag_h2b_j3_s3a_s3b.nc'
# Save results (disabled by default; uncomment to write the scores to disk)
# rmse_t.to_netcdf(filename_rmse_t)
# rmse_xy.to_netcdf(filename_rmse_xy)
psd.name = 'psd_score'
# psd.to_netcdf(filename_psd)

# Print leaderboard: a single row summarising the whole-year metrics.
data = [['MIOST',
         leaderboard_nrmse,
         leaderboard_nrmse_std,
         leaderboard_psds_score,
         leaderboard_psdt_score,
         'GF',
         'eval_miost.ipynb']]

Leaderboard = pd.DataFrame(data,
                           columns=['Method',
                                    "µ(RMSE) ",
                                    "σ(RMSE)",
                                    'λx (degree)',
                                    'λt (days)',
                                    'Domain',
                                    'Reference'])
print("Summary of the leaderboard metrics, over the whole year:")
# to_markdown() relies on the `tabulate` package.
print(Leaderboard.to_markdown())

In [None]:
# Line plot of the `rmse_t` score against time, with the y-axis fixed to [0, 1].
rmse_t.hvplot.line(x='time', y='rmse_t', ylim=(0, 1), cmap=['royalblue'], title='RMSE-based scores')

The figure above shows the time series of the RMSE scores for the reconstruction of SSH.

In [None]:
# Filled-contour map of rmse_xy over lon/lat; contour levels run from 0 up to
# (but excluding) 0.5 in steps of 0.05, with the colour bar labelled in metres.
rmse_xy.hvplot.contourf(x='lon', y='lat', levels=list(numpy.arange(0.,0.5, 0.05)), height=300, width=400, cmap='Reds', subplots=True, clabel='RMSE[m]')

In [None]:
# Give the PSD score a length-1 `experiment` dimension labelled 'Whole year',
# then plot it with the project helper.
# NOTE(review): re-running this cell without re-running the evaluation cell
# applies expand_dims to an array that already has `experiment` — confirm behaviour.
psd = psd.expand_dims({'experiment':1})
psd['experiment'] = ['Whole year']
# time_min/time_max/step_time presumably set the time-scale axis range and tick step — TODO confirm in src.mod_plot
plot_psd_score(psd,time_min=5,time_max=25, step_time=5)

The PSD-based score evaluates the spatio-temporal scales resolved by the mapping (green area). The resolution limit is defined as the contour where the PSD score equals 0.5, shown as the black contour in the figure (i.e. the space-time scales at which the error level of the reconstructed SSH is half that of the true SSH signal).

### 3.2 Seasonal evaluation 

In [None]:
# Seasonal evaluation: compute the same RMSE- and PSD-based scores as for the
# whole year, once per seasonal subset, and collect them in a single table.
Seasons = ['Mid_autumn','Mid_winter','Mid_spring','Mid_summer']
Psd_seasonal = []
Data = []
# Method/Domain/Reference are only filled in on the first row of the table.
Method = ['MIOST','','','']
Domain = ['GF','','','']
Reference = ['eval_miost.ipynb','','','']

# The loop variables deliberately do NOT reuse the names `ds_ref`/`ds_rec`:
# rebinding those would replace the notebook-level datasets with the last
# seasonal subset and silently break any earlier cell that is re-run.
for i, (season, ds_ref_season, ds_rec_season) in enumerate(zip(Seasons, ds_ref_seasonal, ds_rec_seasonal)):
    # Eval
    rmse_t, rmse_xy, leaderboard_nrmse, leaderboard_nrmse_std = rmse_based_scores(ds_rec_season, ds_ref_season)
    psd, leaderboard_psds_score, leaderboard_psdt_score  = psd_based_scores(ds_rec_season, ds_ref_season)

    # One set of target files per season, so that enabling the save step below
    # does not overwrite the same files on every iteration.
    filename_rmse_t = output_directory + f'rmse_t_miost_ssh_reconstruction_{season}_alg_c2_h2ag_h2b_j3_s3a_s3b.nc'
    filename_rmse_xy = output_directory + f'rmse_xy_miost_ssh_reconstruction_{season}_alg_c2_h2ag_h2b_j3_s3a_s3b.nc'
    filename_psd = output_directory + f'psd_miost_ssh_reconstruction_{season}_alg_c2_h2ag_h2b_j3_s3a_s3b.nc'
    # Save results (disabled by default; uncomment to write the scores to disk)
    # rmse_t.to_netcdf(filename_rmse_t)
    # rmse_xy.to_netcdf(filename_rmse_xy)
    psd.name = 'psd_score'
    # psd.to_netcdf(filename_psd)

    # Leaderboard row for this season
    Data.append([Method[i],
                 season,
                 leaderboard_nrmse,
                 leaderboard_nrmse_std,
                 leaderboard_psds_score,
                 leaderboard_psdt_score,
                 Domain[i],
                 Reference[i]])

    # Keep the PSD score of each season for the comparison plot.
    Psd_seasonal.append(psd)

Leaderboard_seasonal = pd.DataFrame(Data,
                                    columns=['Method',
                                             'Season',
                                             "µ(RMSE) ",
                                             "σ(RMSE)",
                                             'λx (degree)',
                                             'λt (days)',
                                             'Domain',
                                             'Reference'])

print("Summary of the leaderboard metrics, for each season:")
# to_markdown() relies on the `tabulate` package.
print(Leaderboard_seasonal.to_markdown())

In [None]:
# Stack the per-season PSD scores along a new `experiment` dimension and label
# each entry with its season. Using `Psd_seasonal` and `Seasons` directly keeps
# the labels in step with the list the evaluation loop iterated over.
psd_concat = xr.concat(Psd_seasonal, dim='experiment')
psd_concat['experiment'] = Seasons
plot_psd_score_seasonal(psd_concat, time_min=5, time_max=25, step_time=5)