# ML Dataset stats

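Verification checks on the generated ML dataset: missing values in the features, targets and observations, and how well the fitted model parameters describe the model ensemble.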

In [None]:
# IPython autoreload (mode 2): re-import edited modules such as crims2s
# before running code, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import dask.distributed
import logging
import pathlib
import xarray as xr
import torch
import torch.distributions

from crims2s.dask import create_dask_cluster
from crims2s.util import add_biweekly_dim

In [None]:
# Root directory of the ML-ready dataset to verify (the commented line is an older snapshot).
# NOTE: the following cell reassigns ML_DATASET_DIR, so this value only applies if that cell is skipped.
#ML_DATASET_DIR = '***BASEDIR***/mlready/2021-08-08-test/'
ML_DATASET_DIR = '***BASEDIR***/mlready/2021-08-28-test'

In [None]:
# Use the 2021-09-07 test-set snapshot; this replaces any value set in the previous cell.
ML_DATASET_DIR = '***BASEDIR***/mlready/2021-09-07-test-set'

## Load dataset

In [None]:
def preprocess_one_example(dataset):
    """Give one per-file dataset a new leading ``forecast_time`` dimension.

    Passed as the ``preprocess`` hook of ``xr.open_mfdataset`` in this
    notebook, so the per-file datasets can be concatenated along
    ``forecast_time``.

    Args:
        dataset: the dataset read from a single example file.

    Returns:
        The result of ``dataset.expand_dims('forecast_time')``.
    """
    new_dim = 'forecast_time'
    expanded = dataset.expand_dims(new_dim)
    return expanded

In [None]:
# Collect the NetCDF example files in a deterministic (sorted) order.
# Match on the real file suffix: a substring test on the name ('.nc' in name)
# would also pick up entries such as 'a.nc.tmp' or 'x.ncx'.
dataset_files = sorted(x for x in pathlib.Path(ML_DATASET_DIR).iterdir() if x.suffix == '.nc')

In [None]:
# Peek at the first 10 example files.
dataset_files[:10]

In [None]:
# Sanity check: every example file must have a readable '/edges' group.
# The `with` block closes each dataset as soon as it has been opened, instead
# of leaving the file's release to garbage collection. If a file cannot be
# opened, the exception propagates and `f` still names the offending file.
for f in dataset_files:
    with xr.open_dataset(f, group='/edges'):
        pass

In [None]:
# `f` is the last file visited by the loop above: the failing file if the loop raised.
print(f)

In [None]:
# Open the '/features' group of the first 3 example files as one dataset,
# stacked along 'forecast_time' in file (sorted) order.
# combine='nested' is what makes `concat_dim` take effect; with the default
# combine='by_coords', recent xarray versions reject a `concat_dim` argument.
features = xr.open_mfdataset(
    dataset_files[:3],
    group='/features',
    concat_dim='forecast_time',
    combine='nested',
    preprocess=preprocess_one_example,
)

In [None]:
# Open the '/model' group of the first 3 example files, stacked along
# 'forecast_time' in file order. combine='nested' is required for
# `concat_dim` to be honoured (the default 'by_coords' rejects it).
model = xr.open_mfdataset(
    dataset_files[:3],
    group='/model',
    concat_dim='forecast_time',
    combine='nested',
    preprocess=preprocess_one_example,
)

In [None]:
# Open the '/obs' group of the first 3 example files, stacked along
# 'forecast_time' in file order. combine='nested' is required for
# `concat_dim` to be honoured (the default 'by_coords' rejects it).
obs = xr.open_mfdataset(
    dataset_files[:3],
    group='/obs',
    concat_dim='forecast_time',
    combine='nested',
    preprocess=preprocess_one_example,
)

In [None]:
# Open the '/model_parameters' group of the first 3 example files, stacked
# along 'forecast_time' in file order. combine='nested' is required for
# `concat_dim` to be honoured (the default 'by_coords' rejects it).
parameters = xr.open_mfdataset(
    dataset_files[:3],
    group='/model_parameters',
    concat_dim='forecast_time',
    combine='nested',
    preprocess=preprocess_one_example,
)

In [None]:
# Inspect the '/model' dataset opened above.
model

In [None]:
# Inspect the '/model_parameters' dataset opened above.
parameters

In [None]:
# Total number of missing values per obs variable, summed over all dimensions.
obs.isnull().sum().compute()

In [None]:
# Inspect the '/features' dataset opened above.
features

In [None]:
# Inspect the '/obs' dataset opened above.
obs

In [None]:
# Number of entries along the obs lead_time dimension.
len(obs.lead_time)

In [None]:
# Scratch arithmetic: presumably the lead-time count from the previous cell (46)
# expressed in weeks; confirm against that cell's output.
46 / 7

In [None]:
# NOTE(review): `sample` is never defined in this notebook, so this cell (and
# the later `sample.sel(...)` cell) raised NameError. The downstream cells
# expect a dataset with an `x` variable and `variable`, `realization`,
# `lead_time` and `forecast_time` dimensions. That presumably matches the
# `features` group opened above, so it is bound here. Confirm this is the
# intended dataset.
sample = features
# Null counts, summed over the latitude, longitude and realization dimensions.
some_computed = sample.isnull().sum(dim=['latitude', 'longitude', 'realization'])

In [None]:
# `some_computed` holds null counts. Collapse them further over variables and
# forecast times (presumably leaving lead_time), then plot the `x` variable.
nulls_by_lead = some_computed.sum(dim=['variable', 'forecast_time']).compute()
nulls_by_lead.x.plot()

In [None]:
# Null counts for the first two lead_time positions, summed over variables.
some_computed.isel(lead_time=[0, 1]).sum(dim='variable').compute()

In [None]:
# Sum the 'sst' feature over forecast_time, realization and every lead_time
# except the first, then plot `x` (presumably leaving a latitude/longitude map).
# NOTE(review): this sums raw values, not nulls as the neighbouring cells do.
# If a missing-value map was intended, add `.isnull()` before `.sum(...)`;
# confirm.
sst_after_first_lead = sample.sel(variable='sst').isel(lead_time=slice(1, None))
sst_totals = sst_after_first_lead.sum(dim=['forecast_time', 'lead_time', 'realization'])
sst_totals.compute().x.plot()

## Check targets (y)

In [None]:
# Open the '/terciles' (target) group of the first 10 example files, stacked
# along 'forecast_time' in file order. combine='nested' is required for
# `concat_dim` to be honoured (the default 'by_coords' rejects it).
target = xr.open_mfdataset(
    dataset_files[:10],
    group='/terciles',
    concat_dim='forecast_time',
    combine='nested',
    preprocess=preprocess_one_example,
)

In [None]:
# Load the target data into memory so the following checks do not re-read the files.
target = target.compute()

In [None]:
# Inspect the in-memory target dataset.
target

In [None]:
# Missing target values for the first forecast_time, counted over the
# category and lead_time dimensions and plotted for t2m.
first_forecast_nulls = target.isnull().isel(forecast_time=0)
first_forecast_nulls.sum(dim=['category', 'lead_time']).t2m.plot()

## Check obs

In [None]:
# Re-open '/obs' from the first 10 example files (replacing the 3-file `obs`),
# stacked along 'forecast_time' in file order. combine='nested' is required
# for `concat_dim` to be honoured (the default 'by_coords' rejects it).
obs = xr.open_mfdataset(
    dataset_files[:10],
    group='/obs',
    concat_dim='forecast_time',
    combine='nested',
    preprocess=preprocess_one_example,
)

In [None]:
# Count missing obs values over forecast_time and lead_time, then plot t2m
# (presumably a spatial map of the remaining dimensions).
obs_null_counts = obs.isnull().sum(dim=['forecast_time', 'lead_time'])
obs_null_counts.t2m.plot()

In [None]:
# Open '/model_parameters' for ALL example files (not a subset). No concat_dim
# is given, so xarray's default `combine` strategy decides how files are merged.
model_params = xr.open_mfdataset(dataset_files, group='/model_parameters', preprocess=preprocess_one_example)

In [None]:
# Inspect the parameters of all example files.
model_params

# Check parameters fit

## For T2M

In [None]:
# Re-express the model's lead times with the project helper add_biweekly_dim
# (presumably grouping them into two-week periods; weeks_12=False is a
# crims2s.util option, so check its semantics there).
model_biweekly = add_biweekly_dim(model, weeks_12=False)

In [None]:
# Inspect the regrouped model dataset.
model_biweekly

In [None]:
# t2m ensemble for the parameter-fit check: drop the first lead_time entry,
# remove size-1 dimensions, and order axes as
# (realization, lead_time, forecast_time, <remaining dims>).
t2m_after_first_lead = model_biweekly.isel(lead_time=slice(1, None)).t2m
reworked_t2m = t2m_after_first_lead.squeeze().transpose('realization', 'lead_time', 'forecast_time', ...)

In [None]:
# Inspect the reshaped t2m ensemble.
reworked_t2m

In [None]:
# Materialise the lazy array data and wrap it in a torch tensor
# (axis order as set by the transpose in the previous cell).
t2m_data = torch.tensor(reworked_t2m.data.compute())

In [None]:
# Fitted t2m Normal parameters (mean, std) as torch tensors.
# NOTE(review): these keep the dimension order of `parameters` (after squeeze),
# while `t2m_data` was transposed to (realization, lead_time, forecast_time, ...).
# torch broadcasts by position, not by name, so confirm the orders agree
# (compare the shapes printed below).
t2m_mu, t2m_sigma = (
    torch.tensor(parameters[name].squeeze().data.compute())
    for name in ('t2m_mu', 't2m_sigma')
)

In [None]:
# Largest fitted t2m mean, as a quick range sanity check.
t2m_mu.max()

In [None]:
# Parameter tensor shape; must broadcast against t2m_data.shape below.
t2m_mu.shape

In [None]:
# Data tensor shape: (realization, lead_time, forecast_time, ...) per the transpose above.
t2m_data.shape

In [None]:
# Normal(loc=fitted mean, scale=fitted std) for t2m.
distribution = torch.distributions.Normal(t2m_mu, t2m_sigma)

In [None]:
# Mean negative log-likelihood of the t2m ensemble under the fitted Normal (lower = better fit).
-distribution.log_prob(t2m_data).mean()

## For TP

In [None]:
# tp ensemble for the last lead_time entry only, as a torch tensor with axes
# ordered (realization, forecast_time, <remaining dims>).
last_lead_tp = model_biweekly.isel(lead_time=-1).tp.squeeze()
tp_data = torch.tensor(last_lead_tp.transpose('realization', 'forecast_time', ...).data.compute())

In [None]:
# Shape of the tp data tensor.
tp_data.shape

In [None]:
# Re-inspect the parameters to see which tp parameters are available.
parameters

In [None]:
# Fitted tp Normal parameters (mean, std) as torch tensors.
# NOTE(review): like the t2m case, these keep the dimension order of
# `parameters`; confirm it lines up positionally with `tp_data`.
tp_mu, tp_sigma = (
    torch.tensor(parameters[name].squeeze().data.compute())
    for name in ('tp_mu', 'tp_sigma')
)

In [None]:
# Smallest fitted tp std; torch's Normal needs scale > 0, hence the epsilon added in the next cell.
tp_sigma.min()

In [None]:
# Normal fit for tp; + 1e-9 keeps the scale strictly positive where the fitted std is 0.
distribution = torch.distributions.Normal(tp_mu, tp_sigma + 1e-9)

In [None]:
# Mean negative log-likelihood of the tp ensemble under the Normal fit.
-distribution.log_prob(tp_data).mean()

In [None]:
# Fitted tp Gamma parameters (alpha, beta) as torch tensors.
tp_alpha, tp_beta = (
    torch.tensor(parameters[name].squeeze().data.compute())
    for name in ('tp_alpha', 'tp_beta')
)

In [None]:
# Gamma(concentration=tp_alpha, rate=tp_beta) fit for tp.
distribution = torch.distributions.Gamma(tp_alpha, tp_beta)

In [None]:
# Mean negative log-likelihood under the Gamma fit; the 1e-9 shift keeps exact
# zeros inside the Gamma's strictly positive support.
-distribution.log_prob(tp_data + 1e-9).mean()

In [None]:
# Smallest tp value in the ensemble; negative values would fall outside the Gamma support.
tp_data.min()

In [None]:
# Smallest fitted tp mean.
tp_mu.min()