In [3]:
%load_ext nb_black

The nb_black extension is already loaded. To reload it, use:
  %reload_ext nb_black


<IPython.core.display.Javascript object>

In [5]:
# Basics
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Data
import xarray as xr
import h5py

# Helpful
import time
import datetime
import itertools
from itertools import product

# My Methods
from src.utils.CRPS import *
from src.utils.data_split import *
from src.models.EMOS import *

<IPython.core.display.Javascript object>

In [6]:
importlib.reload(CRPS)
importlib.reload(EMOS)

NameError: name 'CRPS' is not defined

<IPython.core.display.Javascript object>

### Goal of this notebook: Implement EMOS and train it on train dataset

#### 1. Load dataset

In [7]:
# Training Dataset
t2m_train = xr.open_dataset("/Data/Delong_BA_Data/Mean_ens_std/t2m_train.h5")
t2m_X_train = t2m_train.t2m_train
t2m_y_train = t2m_train.t2m_truth

<IPython.core.display.Javascript object>

In [8]:
# Test Detaset
t2m_test = xr.open_dataset("/Data/Delong_BA_Data/Mean_ens_std/t2m_test.h5")
t2m_X_test = t2m_test.t2m_test
t2m_y_test = t2m_test.t2mtest_truth

<IPython.core.display.Javascript object>

#### 2. Prepare Data
Make datasets for all 32 lead times

##### 2.1 Train Dataset

In [9]:
t2m_X_train_glob_mean = []
t2m_X_train_glob_std = []
t2m_y_train_glob_truth = []
for i in range(t2m_train.phony_dim_1.shape[0]):
    t2m_X_train_glob_mean.append(
        t2m_X_train.isel(phony_dim_4=0, phony_dim_1=i).values.flatten()
    )
    t2m_X_train_glob_std.append(
        t2m_X_train.isel(phony_dim_4=1, phony_dim_1=i).values.flatten()
    )
    t2m_y_train_glob_truth.append(t2m_y_train.isel(phony_dim_1=i).values.flatten())

<IPython.core.display.Javascript object>

In [10]:
# Only necessary for Baseline, not for EMOS
# Change all zeros in std to a small value epsilon
epsilon = 1e-9  # Small epsilon to add to zeros in std
for i in range(len(t2m_X_train_glob_std)):
    t2m_X_train_glob_std[i][(t2m_X_train_glob_std[i] == 0)] += epsilon

<IPython.core.display.Javascript object>

##### 2.2 Test Dataset

In [11]:
t2m_X_test_glob_mean = []
t2m_X_test_glob_std = []
t2m_y_test_glob_truth = []
for i in range(t2m_test.phony_dim_1.shape[0]):
    t2m_X_test_glob_mean.append(
        t2m_X_test.isel(phony_dim_4=0, phony_dim_1=i).values.flatten()
    )
    t2m_X_test_glob_std.append(
        t2m_X_test.isel(phony_dim_4=1, phony_dim_1=i).values.flatten()
    )
    t2m_y_test_glob_truth.append(t2m_y_test.isel(phony_dim_1=i).values.flatten())

<IPython.core.display.Javascript object>

In [12]:
# Only necessary for Baseline, not for EMOS
# Change all zeros in std to a small value epsilon
epsilon = 1e-9  # Small epsilon to add to zeros in std
for i in range(len(t2m_X_test_glob_std)):
    t2m_X_test_glob_std[i][(t2m_X_test_glob_std[i] == 0)] += epsilon

<IPython.core.display.Javascript object>

In [13]:
# Decide on which lead_time to test
lead_time = 8

<IPython.core.display.Javascript object>

#### 3. Baseline

In [14]:
crps_baseline = crps_normal(
    mu=t2m_X_train_glob_mean[lead_time],
    sigma=t2m_X_train_glob_std[lead_time],
    y=t2m_y_train_glob_truth[lead_time],
)

<IPython.core.display.Javascript object>

In [15]:
crps_baseline.mean()

0.028882670428874848

<IPython.core.display.Javascript object>

#### 4. Train global Emos

In [17]:
EMOS_glob = build_EMOS_network_keras(compile=True)

<IPython.core.display.Javascript object>

In [18]:
t2m_X_train_glob_mean[lead_time]

array([-0.0249654 , -0.02140305, -0.02148053, ...,  0.38360298,
        0.40706664,  0.4334942 ], dtype=float32)

<IPython.core.display.Javascript object>

In [None]:
EMOS_glob.fit(
    [t2m_X_train_glob_mean[lead_time], t2m_X_train_glob_std[lead_time]],
    t2m_y_train_glob_truth[lead_time],
    batch_size=5000,
    epochs=5,
    validation_split=0.2,
)

In [None]:
t2m_31_preds = EMOS_glob.predict(
    [t2m_X_test_glob_mean[lead_time], t2m_X_test_glob_std[lead_time]]
)

In [None]:
crps_emos = crps_normal(
    mu=t2m_31_preds[:, 0], sigma=t2m_31_preds[:, 0], y=t2m_y_test_glob_truth[lead_time]
)

In [None]:
crps_emos.mean()