In [None]:
# Re-import edited modules (such as crims2s) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Debias precipitation using a gamma distribution

Same approach as the other debiasing notebooks, but this time we model precipitation with a gamma distribution instead of a Gaussian.
Because precipitation is non-negative and right-skewed, this should lead to a slightly better debiasing.

In [None]:
import dask
import dask.array as da
import dask.distributed
import datetime
import matplotlib.pyplot as plt
import os
import numpy as np
import pandas as pd
import pathlib
import scipy.stats
import typing
import xarray as xr
import xskillscore as xs

In [None]:
from crims2s.dask import create_dask_cluster
from crims2s.util import fix_dataset_dims

In [None]:
# Root locations of the data used in this notebook.
# NOTE(review): OBSERVATIONS has a '/' right after the ***BASEDIR*** placeholder while
# the other two paths do not — confirm which form the placeholder expects.
INPUT_TRAIN = '***BASEDIR***training-input/0.3.0/netcdf'
OBSERVATIONS = '***BASEDIR***/processed/training-output-reference/'
BENCHMARK = '***BASEDIR***training-output-benchmark/'
# Backward-compatible alias for the original (misspelled) name.
BENCHNMARK = BENCHMARK

## Start a Dask cluster

In [None]:
# Number of jobs to scale the cluster to.
N_DASK_JOBS = 2

cluster = create_dask_cluster()
cluster.scale(jobs=N_DASK_JOBS)

In [None]:
# Attach a dask.distributed client to the cluster created above.
client = dask.distributed.Client(cluster)

In [None]:
# Show the client's repr as the cell output.
client

## Generic Functions

In [None]:
def extract_train_validation_from_lead_time(xr_data, last_train_year=2018) -> typing.Tuple:
    """Split data into training and validation parts along ``forecast_year``.

    Despite the function name, the split is on the forecast year, not on the lead time.

    Args:
        xr_data: object with a ``forecast_year`` coordinate that supports ``.sel``
            (an xarray Dataset/DataArray in this notebook).
        last_train_year: last forecast year (inclusive) assigned to the training part;
            all later years go to validation. Defaults to 2018.

    Returns:
        ``(train, validation)`` tuple.
    """
    xr_data_sub_train = xr_data.sel(forecast_year=slice(None, last_train_year))
    # Label-based slicing is inclusive on both ends, so validation starts one year
    # after the last training year to keep the two parts disjoint.
    xr_data_sub_val = xr_data.sel(forecast_year=slice(last_train_year + 1, None))

    return xr_data_sub_train, xr_data_sub_val

In [None]:
def compute_and_correct_bias(data_center_train, data_center_val, obs_train):
    """Additive bias correction of validation forecasts.

    The bias is ``obs_train - data_center_train`` averaged over the ``lead_time`` and
    ``forecast_year`` dimensions. It is then added to ``data_center_val``.

    Returns:
        ``(bias, corrected)`` where ``corrected = data_center_val + bias``.
    """
    training_error = obs_train - data_center_train
    mean_error = training_error.mean(dim=['lead_time', 'forecast_year'])
    return mean_error, data_center_val + mean_error

In [None]:
def add_biweekly_dim(dataset):
    """Split a lead-time series into three 14-day windows stacked on a new dimension.

    The windows cover lead times 0-13 days, 14-27 days and 28-41 days. Each window
    gets a ``biweekly_forecast`` coordinate equal to its first lead time, and its
    ``lead_time`` coordinate is re-based so that it starts at 0 inside the window.

    From the second window on, the value at the last lead time of the previous window
    is subtracted. Values accumulated since the start of the forecast thus become
    values accumulated since the end of the previous window.

    Args:
        dataset: xarray object with a timedelta ``lead_time`` dimension and
            ``forecast_year`` / ``forecast_dayofyear`` dimensions.

    Returns:
        The windows concatenated along ``biweekly_forecast``, with dimensions ordered
        ``(forecast_year, forecast_dayofyear, biweekly_forecast, ...)``.
    """
    window_slices = [slice('0D', '13D'), slice('14D', '27D'), slice('28D', '41D')]
    # Select all windows up front so every subtraction below uses the raw
    # (still accumulated) values of the previous window.
    raw_windows = [dataset.sel(lead_time=s) for s in window_slices]

    weeklys = []
    for i, window in enumerate(raw_windows):
        first_lead = pd.to_timedelta(window.lead_time[0].item())

        if i > 0:
            # Subtract BEFORE adding the biweekly_forecast dimension. Subtracting two
            # windows that already carry different biweekly_forecast labels makes
            # xarray align them on that coordinate (NaNs or an alignment error)
            # instead of computing a plain difference. drop=True also removes the
            # scalar lead_time coordinate of the subtrahend.
            window = window - raw_windows[i - 1].isel(lead_time=-1, drop=True)

        window = window.assign_coords(lead_time=(window.lead_time - first_lead))
        window = window.expand_dims(dim='biweekly_forecast').assign_coords(biweekly_forecast=[first_lead])
        weeklys.append(window)

    return xr.concat(weeklys, dim='biweekly_forecast').transpose('forecast_year', 'forecast_dayofyear', 'biweekly_forecast', ...)

# Read data

In [None]:
# Data center whose forecast files are loaded below.
# NOTE(review): the forecast variables below are named ecmwf_* but hold data for
# CENTER, which is currently 'ncep' — consider renaming them.
CENTER = 'ncep'
# Field to debias; it is matched against file names below.
FIELD = 'tp'

In [None]:
# Forecast files whose name contains both the center and the field, sorted by path.
input_path = pathlib.Path(INPUT_TRAIN)
input_files_tp = sorted(
    candidate
    for candidate in input_path.iterdir()
    if all(token in candidate.stem for token in (CENTER, FIELD))
)

In [None]:
# Peek at the first matching forecast files.
input_files_tp[:10]

In [None]:
# Open all forecast files as a single dataset; fix_dataset_dims (from crims2s) is
# applied to each file before they are combined.
ecmwf_tp_raw = xr.open_mfdataset(input_files_tp, preprocess=fix_dataset_dims)

In [None]:
# Plot the first lead time (first member / year / start day) to check whether the
# first step already contains accumulated precipitation.
ecmwf_tp_raw.isel(lead_time=0, realization=0, forecast_year=0, forecast_dayofyear=0).tp.plot()

There are non-zero values at the first lead time, so that step already holds the precipitation accumulated over the first 24 h.

In [None]:
# Split the forecasts into the three biweekly windows (see add_biweekly_dim).
ecmwf_tp = add_biweekly_dim(ecmwf_tp_raw)

In [None]:
# Inspect the structure of the biweekly forecast dataset.
ecmwf_tp

In [None]:
# True where the value at the last lead time of a biweekly window is smaller than at
# its first lead time, i.e. where the within-window accumulation decreases
# (presumably a sanity check — an accumulation should not decrease).
# NOTE(review): despite its name, n_smaller is a boolean mask, not a count.
n_smaller = (ecmwf_tp.isel(lead_time=-1) < ecmwf_tp.isel(lead_time=0))

In [None]:
# Inspect the mask's structure.
n_smaller

In [None]:
# Number of points whose value decreases within each biweekly window.
n_smaller.sum(dim=['realization', 'latitude', 'longitude', 'forecast_dayofyear', 'forecast_year']).compute()

### Observations

In [None]:
obs_path = pathlib.Path(OBSERVATIONS)
# Sort for a deterministic file order (Path.iterdir yields entries in arbitrary
# order), and select files with the same FIELD constant used for the forecasts.
obs_files = sorted(f for f in obs_path.iterdir() if FIELD in f.stem)

In [None]:
# Peek at a few observation files.
obs_files[:4]

In [None]:
# Open all observation files, then re-base lead_time so the first lead time is zero.
obs_tp_raw = xr.open_mfdataset(obs_files)
first_obs_lead = obs_tp_raw.lead_time[0]
obs_tp_raw = obs_tp_raw.assign_coords(lead_time=obs_tp_raw.lead_time - first_obs_lead)

In [None]:
# Split the observations into the same three biweekly windows as the forecasts.
obs_tp = add_biweekly_dim(obs_tp_raw)

In [None]:
# Inspect the structure of the biweekly observation dataset.
obs_tp

In [None]:
# Reduce each biweekly window to (value at last lead time) - (value at first lead time).
# NOTE(review): this rebinds obs_tp and removes lead_time, so re-running this cell on
# its own fails. Also, if the value at the first lead of a window is itself non-zero
# (as for the forecasts, see the note on the first lead time above), last - first
# leaves out that first step — TODO confirm this is the intended window total.
obs_tp = obs_tp.isel(lead_time=-1) - obs_tp.isel(lead_time=0)

## Split into train and validation sets

In [None]:
# Forecast years up to 2018 for training, 2019 onwards for validation.
ecmwf_tp_train, ecmwf_tp_val = extract_train_validation_from_lead_time(ecmwf_tp)

In [None]:
# Same year-based split for the observations.
obs_tp_train, obs_tp_val = extract_train_validation_from_lead_time(obs_tp)

In [None]:
# Inspect the forecast training split.
ecmwf_tp_train

In [None]:
# Inspect a single point of the second biweekly window (hard-coded example indices).
ecmwf_tp_train.isel(biweekly_forecast=1, forecast_dayofyear=10, latitude=30, longitude=30)

In [None]:
# Inspect the observation training split.
obs_tp_train

## Fit a gamma distribution

In [None]:
# Load into memory the second biweekly window (biweekly_forecast=1) for the first
# forecast day-of-year; used below to explore the gamma fit.
one_slice = ecmwf_tp_train.isel(biweekly_forecast=1, forecast_dayofyear=0).compute()

In [None]:
# Count, over years and members, the negative values left after de-accumulation;
# they fall outside the support of a gamma distribution with loc=0.
# NOTE(review): lead_time, latitude and longitude remain after the sum, and xarray's
# default .plot() draws a histogram (not a map) for data with more than two dims.
(one_slice.tp < 0.0).sum(dim=['forecast_year', 'realization']).plot()

In [None]:
# Single gamma fit (shape, loc, scale) pooling every value of the slice: all lead
# times, grid points, years and members together.
# NOTE(review): recent SciPy versions raise ValueError if the data contains NaN/inf,
# and fitting this many samples with a free loc can be slow.
scipy.stats.gamma.fit(one_slice.tp.data)

In [None]:
def fit_gamma_to_samples(samples):
    """Fit a gamma distribution to an array of samples, ignoring non-finite values.

    Args:
        samples: numpy array of any shape; it is treated as one flat sample.

    Returns:
        numpy array ``[a, loc, scale]`` as returned by ``scipy.stats.gamma.fit``.
        Three NaNs are returned when fewer than two finite values are available, when
        all values are identical (e.g. a point that is dry in every sample), or when
        SciPy cannot find a valid fit.
    """
    finite = samples[np.isfinite(samples)]
    if finite.size < 2 or np.ptp(finite) == 0:
        return np.full(3, np.nan)
    try:
        return np.asarray(scipy.stats.gamma.fit(finite))
    except RuntimeError:
        # scipy.stats.FitError subclasses RuntimeError; mark this point as unfittable
        # (NaN) rather than aborting the fit for the whole grid.
        return np.full(3, np.nan)


# Fit one gamma distribution per remaining point (lead_time, latitude, longitude),
# using all forecast years and members at that point as the sample.
# NOTE(review): this runs one SciPy fit per point and can be slow on the full grid.
gamma_params = xr.apply_ufunc(
    fit_gamma_to_samples,
    one_slice.tp,
    input_core_dims=[['forecast_year', 'realization']],
    output_core_dims=[['gamma_param']],
    vectorize=True,
).assign_coords(gamma_param=['a', 'loc', 'scale'])
gamma_params