## Relative bias between data sets

Compare the mean difference between data sets over all times and by season. This gives a reference for the relative bias between them. We can then see how the bias compares to the error variance: if the bias is on the same scale as the error variance, that would show that the bias is not as important as the error variance.

In [1]:
import hvplot.xarray
import holoviews as hv
import panel as pn
import cartopy.crs as ccrs
import numpy as np
import xarray as xr
import itertools
import warnings

In [2]:
# Input files, one per data set. The order here must match `dataset_name` and
# `dataset_abrv` below, because later cells pair them up by position.
files = ['Data/ssebop/ssebop_aet_regridded.nc',
         'Data/gleam/gleam_aet.nc',
         'Data/era5/era5_aet_regridded.nc',
         'Data/nldas/nldas_aet_regridded.nc',
         'Data/terraclimate/terraclimate_aet_regridded.nc',
         'Data/wbet/wbet_aet_regridded.nc',
         ]
dataset_name = ['SSEBop', 'GLEAM', 'ERA5', 'NLDAS', 'TerraClimate', 'WBET']
dataset_abrv = ['S', 'G', 'E', 'N', 'T', 'W']

# First and last time stamp of every file: row 0 holds the start dates,
# row 1 the end dates, with one column per entry in `files`.
date_ranges = np.zeros((2, len(files)), dtype='datetime64[ns]')
for i, file in enumerate(files):
    # Open each file only long enough to read its time extent. The context
    # manager closes the handle again, and the local name avoids shadowing
    # the builtin `set` for the rest of the session.
    with xr.open_dataset(file, engine='netcdf4',
                         chunks={'lon': -1, 'lat': -1, 'time': -1}) as ds_file:
        date_ranges[:, i] = [ds_file.time.min().values, ds_file.time.max().values]

# Take the third oldest start and third most recent end dates.
# Index -3 of the ascending sort is the third most recent end for any number
# of files (it is the same element as index 3 when there are six files).
date_range = [np.sort(date_ranges[0, :])[2], np.sort(date_ranges[1, :])[-3]]
date_range

[numpy.datetime64('1958-01-01T00:00:00.000000000'),
 numpy.datetime64('2022-12-01T00:00:00.000000000')]

In [3]:
def preprocess(ds):
    """
    Clip one opened file to the shared analysis window.

    The window bounds are read from the module-level `date_range`
    (start first, end second), and the selection is a label-based
    slice along the `time` coordinate.
    """
    window_start, window_end = date_range
    analysis_window = slice(window_start, window_end)
    return ds.sel(time=analysis_window)

# Open every file (each clipped by `preprocess`), stack them along a new
# 'dataset_name' dimension, and label that dimension with the full names.
ds = xr.open_mfdataset(
    files,
    engine='netcdf4',
    preprocess=preprocess,
    combine='nested',
    concat_dim='dataset_name',
).assign_coords({'dataset_name': dataset_name})
ds.dataset_name.attrs['description'] = 'Dataset name'

# Time has to be the leading dimension for the TC computation. The combined
# data set is under 1GiB, so load it into memory rather than keeping it as a
# dask array.
ds = ds.transpose('time', ...).compute()
ds

In [4]:
# Every unordered pair of dataset abbreviations, each stored as a two-item list
combos = [list(pair) for pair in itertools.combinations(dataset_abrv, 2)]
combos

[['S', 'G'],
 ['S', 'E'],
 ['S', 'N'],
 ['S', 'T'],
 ['S', 'W'],
 ['G', 'E'],
 ['G', 'N'],
 ['G', 'T'],
 ['G', 'W'],
 ['E', 'N'],
 ['E', 'T'],
 ['E', 'W'],
 ['N', 'T'],
 ['N', 'W'],
 ['T', 'W']]

In [5]:
def common_date_range(ds, combo):
    """
    Return the time slice shared by all datasets listed in `combo`.

    Parameters
    ----------
    ds : object indexable by 'dataset_name'
        Its 'dataset_name' values are searched for each entry of `combo`.
        The position of a match is used as the column index into the
        module-level `date_ranges` array (row 0 = start, row 1 = end).
    combo : iterable
        Dataset labels to intersect, written the same way as the values
        currently stored in ``ds['dataset_name']``.

    Returns
    -------
    slice
        From the latest start date to the earliest end date of the
        requested datasets.

    Raises
    ------
    IndexError
        If a label in `combo` is not present in ``ds['dataset_name']``.
    """
    # Read the labels once instead of indexing the coordinate per element.
    names = ds['dataset_name'].values.tolist()

    start_dates = []
    end_dates = []
    for abrv in combo:
        if abrv not in names:
            # Same exception type the original list-index lookup produced,
            # but with a message that says which label is missing.
            raise IndexError(
                f"dataset {abrv!r} not found in ds['dataset_name'] ({names})")
        idx = names.index(abrv)
        start_dates.append(date_ranges[0, idx])
        end_dates.append(date_ranges[1, idx])

    return slice(np.max(start_dates), np.min(end_dates))

In [6]:
# We want to ignore all of the sqrt and log warnings with negative values
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Index by abbreviation on a relabelled copy. `ds` itself is never modified,
# so an error part-way through this cell cannot leave it carrying the
# abbreviations instead of the full dataset names.
ds_abrv = ds.assign_coords({'dataset_name': dataset_abrv})

# Build one difference Dataset per pair of data sets
ds_diff = []
for combo in combos:
    ds_combo = ds_abrv.sel(time=common_date_range(ds_abrv, combo), dataset_name=combo)

    # With two entries along 'dataset_name', diff leaves a single element:
    # the second data set of the pair minus the first.
    da_diff = ds_combo.aet.diff('dataset_name')
    da_diff = da_diff.squeeze('dataset_name').drop_vars('dataset_name')

    # The label is the reversed pair so it reads in subtraction order,
    # e.g. the pair ['S', 'W'] is stored as 'WS' (W minus S).
    # NOTE(review): passing the full ds.time here appears to place every pair
    # on the same time axis as `ds`, which the season mask in the plotting
    # cell depends on - confirm.
    ds_diff.append(xr.Dataset(data_vars={'difference': da_diff},
                              coords={'dataset_combo': [''.join(reversed(combo))],
                                      'time': ds.time, 'lat': ds.lat, 'lon': ds.lon}))

ds_diff = xr.concat(ds_diff, dim='dataset_combo')

ds_diff.difference.attrs['description'] = 'Difference between two data sets listed in dataset_combo'
ds_diff.dataset_combo.attrs['description'] = ('Dataset combination used in difference '
                                             '(abbreviations: T=TerraClimate, E=ERA5, '
                                             'N=NLDAS, G=GLEAM, W=WBET, S=SSEBop).')
ds_diff.difference.attrs['units'] = 'mm.month-1'

ds_diff

In [7]:
# Interactive map of the pairwise differences, with one selector widget for
# the dataset pair and one for the time step.
difference_map = ds_diff.difference.hvplot(
    groupby=['dataset_combo', 'time'],
    geo=True,
    coastline=True,
    clim=(-75, 75),
    cmap='PuOr',
)
plt = difference_map.opts(frame_width=500)

# Show the selector widgets above the map
pn.panel(plt, widget_location='top')

In [8]:
# Compiled TC average errors, relabelled so that `dataset_name` can be
# indexed with the one-letter abbreviations.
# NOTE(review): this assumes the file stores its datasets in the same order
# as `dataset_abrv` - confirm.
tc_est_averages = xr.open_dataset('Data/compiled_TC_avg_errs.nc', engine='netcdf4')
tc_est_averages['dataset_name'] = dataset_abrv

# 'All' followed by each distinct season label found on the time axis
seasons = ['All', *np.unique(ds.time.dt.season)]


def mean_diff_plots(dataset_combo='WS', season='All'):
    """
    Map the median pairwise difference next to the TC median error.

    Parameters
    ----------
    dataset_combo : str
        Two-letter label selected from ``ds_diff.dataset_combo``. Its first
        and second characters are also used to pick the two data sets out of
        ``tc_est_averages``.
    season : str
        'All' keeps every time step. Any other value keeps only the time
        steps whose ``time.dt.season`` equals it, and is also used to select
        the matching ``season`` entry of ``tc_est_averages``.

    Returns
    -------
    A two-column layout of four maps: the median difference over time, its
    absolute value, and for each of the two data sets the percent difference
    ``(median_error - absolute difference) / median_error * 100``.
    """
    first, second = dataset_combo[0], dataset_combo[1]

    tc_avg_season = tc_est_averages.sel(season=season)
    if season == 'All':
        ds_season = ds_diff
    else:
        # isel applies the mask by position, so it is built from ds_diff's own
        # time axis rather than from a different data set's time coordinate.
        ds_season = ds_diff.isel(time=(ds_diff.time.dt.season == season))

    ds_combo = ds_season.sel(dataset_combo=dataset_combo)
    ds_median = ds_combo.difference.median(dim='time')
    ds_median_abs = abs(ds_median)
    ds_median_abs.name = 'absolute difference'
    ds1_error = tc_avg_season.median_error.sel(dataset_name=first)
    ds2_error = tc_avg_season.median_error.sel(dataset_name=second)
    ds1_bias_var_diff = (ds1_error - ds_median_abs)/ds1_error * 100
    ds2_bias_var_diff = (ds2_error - ds_median_abs)/ds2_error * 100

    def _map(da, clim, cmap, title):
        # Every panel shares the same projection, coastline and frame width.
        return da.hvplot(geo=True, coastline=True, clim=clim, cmap=cmap,
                         title=title).opts(frame_width=500)

    plt = (_map(ds_median, (-50, 50), 'PuOr',
                f'Median Difference of {first} - {second} (Bias)')
           + _map(ds_median_abs, (0, 50), 'Purples',
                  f'Median Absolute Difference of {first} - {second} (Absolute Bias)')
           # Optional panels, disabled in the original layout:
           # + _map(ds1_error, (0, 50), 'Purples',
           #        f'Median Error Standard Deviation for {first}')
           # + _map(ds2_error, (0, 50), 'Purples',
           #        f'Median Error Standard Deviation for {second}')
           + _map(ds1_bias_var_diff, (-100, 100), 'PuOr',
                  f'Percent Difference of Median Error of {first} and Absolute Bias')
           + _map(ds2_bias_var_diff, (-100, 100), 'PuOr',
                  f'Percent Difference of Median Error of {second} and Absolute Bias'))

    return plt.cols(2)

# Only offer the pairs that have W as the common base
dataset_combo_widget = pn.widgets.Select(
    name="dataset_combo",
    value="WS",
    options=['WS', 'WG', 'WE', 'WN', 'WT'],
)
season_widget = pn.widgets.Select(
    name="season",
    value="All",
    options=['All', 'DJF', 'MAM', 'JJA', 'SON'],
)

# Redraw the maps whenever either selector changes
bound_plot = pn.bind(
    mean_diff_plots,
    dataset_combo=dataset_combo_widget,
    season=season_widget,
)

# Selectors stacked above the bound plot
pn.Column(dataset_combo_widget, season_widget, bound_plot)

As a reminder, the data set abbreviations are: **S=SSEBop, G=GLEAM, E=ERA5, N=NLDAS, T=TerraClimate, W=WBET**