In [None]:
# Automatically reload imported modules (e.g. the project's crims2s helpers) before
# each cell runs, so edits to their source take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Linear model

Use a linear model to map the input forecast fields to an output distribution.

In [None]:
import dask
import dask.array as da
import dask.distributed
import dask_jobqueue
import datetime
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import scipy.stats
import xarray as xr
import xskillscore as xs

In [None]:
from crims2s.util import fix_dataset_dims

In [None]:
# Root locations of the training data.
# NOTE(review): '***BASEDIR***' looks like a placeholder for a machine-specific base path — confirm.
INPUT_TRAIN = '***BASEDIR***training-input/0.3.0/netcdf'
OBSERVATIONS = '***BASEDIR***training-output-reference/'
BENCHMARK = '***BASEDIR***training-output-benchmark/'
# Misspelled name kept as an alias so any existing references keep working.
BENCHNMARK = BENCHMARK

## Scale up the dask cluster

In [None]:
# SLURM-backed dask cluster. Each worker job runs these shell commands before the
# dask worker starts, so workers use the same conda environment as the notebook.
# NOTE(review): newer dask-jobqueue releases renamed `env_extra` to
# `job_script_prologue`; switch if a deprecation warning appears.
cluster = dask_jobqueue.SLURMCluster(
    env_extra=[
        'source ***HOME***.bash_profile',
        'conda activate s2s',
    ],
)

In [None]:
# Number of SLURM jobs to request for dask workers. Scale once to the final size:
# requesting more jobs and immediately scaling back down only wastes allocations.
# NOTE(review): an earlier comment said "two working nodes" while requesting 3 jobs — confirm the intended count.
N_WORKER_JOBS = 3
cluster.scale(jobs=N_WORKER_JOBS)
client = dask.distributed.Client(cluster)

In [None]:
# Display the client to inspect the SLURM cluster it is connected to.
client

## Read data

### ECMWF

In [None]:
# Forecasting centre and input variables to load. Both are matched as substrings
# of the training-input file names when collecting files below.
CENTER = 'ecmwf'
INPUT_FIELDS = [
    't2m',
    'sm100',
    'st100',
    'sst',
    'gh',  # only the plev=500 level is kept later (renamed to gh500)
]

In [None]:
input_path = pathlib.Path(INPUT_TRAIN)

# Collect every CENTER file whose stem mentions any requested field. A single
# filter with any() lists each file once, even if its stem contains several field
# names, so open_mfdataset never receives duplicates. Sorting keeps the order
# deterministic.
input_files = sorted(
    f
    for f in input_path.iterdir()
    if CENTER in f.stem and any(field in f.stem for field in INPUT_FIELDS)
)
assert input_files, f'no {CENTER} files matching {INPUT_FIELDS} found in {input_path}'

In [None]:
# Peek at the first ten discovered files to sanity-check the name filter.
input_files[:10]

In [None]:
# Open all discovered files lazily as one dataset and tidy it in a single chain.
# Re-running this cell starts again from the raw files, instead of failing on an
# already-squeezed dimension or an already-renamed variable.
ecmwf = (
    xr.open_mfdataset(input_files, preprocess=fix_dataset_dims)
    # Remove the depth dimension; squeeze raises if it is not of length 1.
    .squeeze(dim=['depth_below_and_layer'], drop=True)
    # Keep only the plev=500 level and drop the now-scalar coordinate.
    .sel(plev=500., drop=True)
    # Name gh after the level that was kept.
    .rename_vars({'gh': 'gh500'})
)

In [None]:
# Inspect the combined ECMWF dataset after the dimension clean-up and renaming.
ecmwf

In [None]:
# Weeks 3-4 lead times (label slice, inclusive of both ends), then split by forecast
# year: training up to 2018, validation from 2019 on.
ecmwf_w34 = ecmwf.sel(lead_time=slice('14D', '27D'))
ecmwf_w34_train, ecmwf_w34_val = (
    ecmwf_w34.sel(forecast_year=year_range)
    for year_range in (slice(None, 2018), slice(2019, None))
)

In [None]:
# Inspect the validation split (forecast years from 2019 onwards).
ecmwf_w34_val

### Observations

In [None]:
obs_path = pathlib.Path(OBSERVATIONS)
# Only the t2m reference files are needed. Sort them so the file list (and the
# order open_mfdataset receives) is deterministic, as for input_files.
obs_files = sorted(f for f in obs_path.iterdir() if 't2m' in f.stem)
assert obs_files, f'no t2m observation files found in {obs_path}'

In [None]:
obs = (
    xr.open_mfdataset(obs_files, preprocess=fix_dataset_dims)
    # NOTE(review): drops the first lead_time step of the observations — presumably
    # so the remaining lead times line up with the forecasts; confirm.
    .isel(lead_time=slice(1, None))
)

# Same weeks 3-4 lead-time window as the forecasts.
obs_w34 = obs.sel(lead_time=slice('14D', '27D'))

In [None]:
# Split observations by forecast year with the same cut-off as the forecasts:
# training up to 2018, validation from 2019 on.
obs_w34_train, obs_w34_val = (
    obs_w34.sel(forecast_year=year_range)
    for year_range in (slice(None, 2018), slice(2019, None))
)

In [None]:
# Inspect the observation validation split (forecast years from 2019 onwards).
obs_w34_val

## Normalize fields

In [None]:
# Inspect the training split before computing its normalization statistics.
ecmwf_w34_train

In [None]:
# Per-variable mean and standard deviation over all dimensions of the training split.
# Computing both in one dask.compute call merges the two graphs, so the training
# data is read once rather than once per statistic. It also defines both names
# before anything uses them, so this works on a fresh kernel.
ecmwf_w34_train_mean, ecmwf_w34_train_std = dask.compute(
    ecmwf_w34_train.mean(),
    ecmwf_w34_train.std(),
)
ecmwf_w34_train_mean.t2m

In [None]:
# Per-variable means of the training split (reduced over all dimensions), used to centre the fields.
ecmwf_w34_train_mean

In [None]:
# Standardize each variable with the training-split statistics (lazy; nothing is
# computed here). The validation split is scaled with the same training statistics.
# This keeps model inputs consistent and avoids leaking validation-year information
# into the preprocessing.
ecmwf_w34_train_normalized = (ecmwf_w34_train - ecmwf_w34_train_mean) / ecmwf_w34_train_std
ecmwf_w34_val_normalized = (ecmwf_w34_val - ecmwf_w34_train_mean) / ecmwf_w34_train_std

In [None]:
# Inspect the (lazy) normalized training set. Because it is centred and scaled with
# its own statistics, each variable's mean should be ~0 and its std ~1 once computed.
ecmwf_w34_train_normalized