# Simple linear model

Try to train a simple linear model for one center, one forecast date, and one lead time.

In [None]:
# Load IPython's autoreload extension; mode 2 reloads all modules before each
# cell is executed, so edits to imported modules are picked up automatically.
%load_ext autoreload
%autoreload 2

In [None]:
import dask
import dask.array as da
import dask.distributed
import dask_jobqueue
import pathlib
import xarray as xr

In [None]:
# Directory holding the training-input NetCDF files.
INPUT_TRAIN = '***HOME***hdd_scratch/s2s/training-input/0.3.0/netcdf'
# Path object for the same directory, so its entries can be iterated later.
input_path = pathlib.Path(INPUT_TRAIN)

In [None]:
# Directory of the benchmark training-output files; `t2m.nc` and
# `t2m-weeks-34.nc` are read from it further down.
OUTPUT_TRAIN = '***HOME***hdd_scratch/s2s/training-output-benchmark/'

In [None]:
# Directory of the reference training-output files; `t2m-20200102.nc` is read
# from it further down.
REFERENCE_TRAIN = '***HOME***hdd_scratch/s2s/training-output-reference/'

## Boot Dask cluster

In [None]:
# Create a SLURM-backed Dask cluster. The extra commands source the bash
# profile and activate the `s2s` conda environment in each job script.
# No resources (cores, memory, queue) are passed here.
# presumably they come from the dask-jobqueue configuration file; verify
# NOTE(review): newer dask_jobqueue releases deprecate `env_extra` in favour of
# `job_script_prologue` — confirm against the installed version.
cluster = dask_jobqueue.SLURMCluster(
    env_extra=['source ***HOME***.bash_profile','conda activate s2s'],
)

In [None]:
# Request two SLURM jobs for the cluster.
cluster.scale(jobs=2)  # Scale to two working nodes as configured.
# Connect a distributed client to that cluster.
client = dask.distributed.Client(cluster)

In [None]:
# Display the client summary.
client

## Read input data

In [None]:
# Substrings used to pick input files by name: the forecasting center and the
# forecast date.
CENTER = 'ecmwf'
DATE = '20200102'  # NOTE(review): looks like YYYYMMDD (2020-01-02) — confirm

In [None]:
# Select the input files whose stem contains both the center and the date.
# The result is sorted because Path.iterdir() yields entries in arbitrary
# order; a fixed order keeps everything derived from `files` reproducible
# across runs and machines.
files = sorted(
    f for f in input_path.iterdir()
    if CENTER in f.stem and DATE in f.stem
)

In [None]:
# Display the selected file paths.
files

In [None]:
# Collect the third dash-separated token of every selected file stem.
# presumably that token is the variable name — verify against the file naming
variable = [path.stem.split('-')[2] for path in files]

In [None]:
# Display the extracted tokens in sorted order (the list itself is left as is).
sorted(variable)

In [None]:
# `datasets` is used from this cell onwards but is never assigned in the
# notebook, so a top-to-bottom run fails with NameError. Open one dataset per
# selected file.
# NOTE(review): assumes plain xr.open_dataset is the intended loader — confirm
# (e.g. whether `chunks=` should be passed to load lazily with Dask).
datasets = [xr.open_dataset(f) for f in files]

# Check whether the first dataset has an `entireAtmosphere` coordinate.
'entireAtmosphere' in datasets[0].coords

In [None]:
# Display the first dataset.
datasets[0]

In [None]:
# Partition the datasets by whether they carry a `plev` coordinate.
dataset_with_level = [ds for ds in datasets if 'plev' in ds.coords]
dataset_wo_level = [ds for ds in datasets if 'plev' not in ds.coords]

In [None]:
# Display the first dataset that has no `plev` coordinate.
dataset_wo_level[0]

In [None]:
# Merge the datasets without a `plev` coordinate into a single dataset.
full_dataset = xr.merge(dataset_wo_level)

In [None]:
# Display the merged dataset.
full_dataset

In [None]:
# Reduce the merged dataset to a single lead time: take index 21 along
# `lead_time` and the first entry along `meanSea`, `entireAtmosphere` and
# `nominal_top`, then remove five variables by name.
# `drop_vars` is used because `Dataset.drop` is deprecated for removing
# variables.
one_lead_time = (
    full_dataset
    .isel(lead_time=21, meanSea=0, entireAtmosphere=0, nominal_top=0)
    .drop_vars(['rsn', 'siconc', 'lsm', 'sst', 'tp'])
)

In [None]:
# Display the single-lead-time dataset.
one_lead_time

In [None]:
# Count the missing values in each variable.
one_lead_time.isnull().sum()

In [None]:
# A Dataset has no `.shape` attribute (only the arrays inside it do), so show
# the length of each dimension through `.sizes` instead.
one_lead_time.sizes

In [None]:
# Stack the variables along a new `field` dimension, put the dimensions in a
# fixed order, and take the underlying array.
np_input = (
    one_lead_time
    .to_array(dim='field')
    .transpose('forecast_time', 'latitude', 'longitude', 'field', 'realization')
    .data
)

In [None]:
# Flatten the trailing `field` and `realization` axes into one feature axis.
# The three leading sizes are read from the array instead of being hard-coded
# as (20, 121, 240), so a different number of forecast times or grid points
# cannot silently scramble the layout. Re-running the cell is still a no-op.
np_input = np_input.reshape(*np_input.shape[:3], -1)

In [None]:
# Display the shape of the reshaped input array.
np_input.shape

## Read output values

In [None]:
# List the contents of the benchmark output directory.
!ls {OUTPUT_TRAIN}

In [None]:
# Open the `t2m.nc` file from the benchmark output directory.
# OUTPUT_TRAIN already ends with '/', so the joined path contains '//'.
# NOTE(review): presumably harmless on POSIX — confirm or drop the extra slash.
t2m = xr.open_dataset(OUTPUT_TRAIN + '/t2m.nc')

In [None]:
# Display the t2m dataset.
t2m

In [None]:
# Keep the forecast times that fall on the 2nd day of a month, then take the
# first lead time.
# NOTE(review): the mask tests only the day of month, so the 2nd of every month
# present in `forecast_time` is selected, not only the date held in DATE —
# confirm this is intended.
first_forecast_t2m = t2m.sel(forecast_time=t2m.forecast_time.dt.day == 2).isel(lead_time=0)

In [None]:
# Display the selected subset.
first_forecast_t2m

In [None]:
# Display the lead times cast to whole-day timedeltas.
t2m.lead_time.data.astype('timedelta64[D]')

In [None]:
# Plot the `t2m` variable at category index 2 and forecast_time index 20.
first_forecast_t2m.isel(category=2, forecast_time=20).t2m.plot()

## Check one of the bi-weekly files

In [None]:
# Open the `t2m-weeks-34.nc` file from the benchmark output directory.
t2m_weekly = xr.open_dataset(OUTPUT_TRAIN + '/t2m-weeks-34.nc')

In [None]:
# Display the bi-weekly t2m dataset.
t2m_weekly

## Check the other type of forecast (reference output)

In [None]:
# Open the `t2m-20200102.nc` file from the reference output directory.
reference_t2m = xr.open_dataset(REFERENCE_TRAIN + '/t2m-20200102.nc')

In [None]:
# Display the reference t2m dataset.
reference_t2m

In [None]:
# Display the reference lead times cast to whole-day timedeltas.
reference_t2m.lead_time.data.astype('timedelta64[D]')