In [None]:
import dask
import dask.distributed
import dask_jobqueue
import pathlib
import xarray as xr

from crims2s.distribution import fit_normal_xarray
from crims2s.util import fix_dataset_dims, add_biweekly_dim

In [None]:
# Root directory that holds the training-input files listed below.
# NOTE(review): `***BASEDIR***` looks like a placeholder — set it to the local data root before running.
TRAINING_INPUT = '***BASEDIR***/training-input'

In [None]:
# pathlib view of TRAINING_INPUT; the next cell lists its entries.
train_input_path = pathlib.Path(TRAINING_INPUT)

In [None]:
# Keep only the ECCC files whose names contain '-t2m-' or '-tp-' (the t2m and
# tp variables are the ones analysed below).
# `iterdir()` yields entries in arbitrary filesystem order, and later cells pick
# `files_list[0]` as the sample file, so sort to make that choice reproducible.
files_list = sorted(
    f
    for f in train_input_path.iterdir()
    if 'eccc' in f.name and ('-t2m-' in f.name or '-tp-' in f.name)
)

In [None]:
# Later cells index files_list[0]; fail early with a clear message if nothing matched
# instead of an opaque IndexError further down.
assert files_list, f'no matching ECCC t2m/tp files found in {train_input_path}'
len(files_list)

In [None]:
# Quick look at the first matching file as stored on disk, before any preprocessing.
# (The dataset is not closed; acceptable for an interactive inspection.)
xr.open_dataset(files_list[0])

In [None]:
# Work on a single file first: open the first matching file and pass it
# through add_biweekly_dim (with weeks_12=False) before fitting anything.
sample = add_biweekly_dim(
    xr.open_dataset(files_list[0]),
    weeks_12=False,
)

In [None]:
# Display the single-file sample after add_biweekly_dim to check its dimensions.
sample

In [None]:
# Inspect the lead_time coordinate of the sample.
sample.lead_time

In [None]:
# Fit normal-distribution parameters (via fit_normal_xarray) to the sample's t2m,
# reducing over the lead_time and realization dimensions.
parameters = fit_normal_xarray(
    sample.t2m,
    dim=['lead_time', 'realization'],
)

In [None]:
# Display the fitted parameters returned by fit_normal_xarray.
parameters

In [None]:
# Plot the fitted t2m_sigma (spread) for forecast_time=0 and biweekly_forecast=0.
# The trailing `;` suppresses the text repr of the returned plot object.
parameters.isel(forecast_time=0, biweekly_forecast=0).t2m_sigma.plot();

In [None]:
# Plot raw t2m for one slice of the sample
# (biweekly_forecast=2, lead_time=5, realization=0, forecast_time=0).
# The trailing `;` suppresses the text repr of the returned plot object.
sample.isel(biweekly_forecast=2, lead_time=5, realization=0, forecast_time=0).t2m.plot();

In [None]:
# Open just the first file lazily through open_mfdataset (before the cluster exists).
# The path is wrapped in a list: some xarray versions treat any non-str `paths`
# argument as an iterable of paths, which fails for a bare pathlib.Path.
d = xr.open_mfdataset([files_list[0]])

In [None]:
# Display the lazily opened single-file dataset.
d

In [None]:
cluster = dask_jobqueue.SLURMCluster(
    name='s2s',
    # Shell lines placed in each SLURM job script so that workers load the
    # user profile and activate the `s2s` conda environment.
    # NOTE(review): `env_extra` is deprecated in newer dask-jobqueue releases in
    # favour of `job_script_prologue` — confirm against the installed version.
    # NOTE(review): no cores/memory are passed here, so presumably they come from
    # the dask-jobqueue configuration file — confirm.
    env_extra=[
        'source ***HOME***.bash_profile',
        'conda activate s2s',
    ],
)

In [None]:
# Request 2 SLURM jobs for the cluster (`jobs=` counts jobs, not individual workers).
cluster.scale(jobs=2)

In [None]:
# Connect this session to the SLURM cluster; the new client registers itself as the
# default scheduler, so later .compute() / .persist() calls run on the cluster.
client = dask.distributed.Client(address=cluster)

In [None]:
# Display the client summary to check that the cluster is reachable.
client

In [None]:
# Lazily open and combine every selected ECCC t2m/tp file; fix_dataset_dims is
# applied to each file's dataset before they are combined.
d = xr.open_mfdataset(
    files_list,
    preprocess=fix_dataset_dims,
)

In [None]:
# Check the forecast_time coordinate of the combined dataset after preprocessing.
d.forecast_time

In [None]:
# Check the forecast_year coordinate of the combined dataset after preprocessing.
d.forecast_year

In [None]:
# Fraction (despite the name, not a count) of missing tp values, averaged over
# forecast_year, forecast_monthday, lead_time, latitude and longitude.
# Any other dimensions of d.tp (e.g. realization) are kept in the result.
nan_counts = (
    d.tp
    .isnull()
    .mean(dim=['forecast_year', 'forecast_monthday', 'lead_time', 'latitude', 'longitude'])
    .compute()
)

In [None]:
# Number of realizations with a missing tp value at every remaining point.
# Persisted because the following cells threshold it several times.
n_null = (
    d.tp
    .isnull()
    .sum(dim=['realization'])
    .persist()
)

In [None]:
# Fraction of points where at least 1 realization of tp is missing.
(n_null >= 1).mean().compute()

In [None]:
# Fraction of points where at least 2 realizations of tp are missing.
(n_null >= 2).mean().compute()

In [None]:
# Fraction of points where at least 3 realizations of tp are missing.
(n_null >= 3).mean().compute()

In [None]:
# Fraction of points where exactly 4 realizations of tp are missing.
# NOTE(review): the previous cells use `>=`; `== 4` matches `>= 4` only if the
# realization dimension has at most 4 members — confirm with d.sizes['realization'].
(n_null == 4).mean().compute()

In [None]:
# nan_counts was already computed into memory when it was created, so a second
# .compute() is redundant; display it directly.
nan_counts