# Training set creation

This notebook creates the training set for the DL_EC model. The training set is built by combining data from several sources (mostly CMIP6). The data are then split into training and validation sets, and the training set is saved to a file.

In [1]:
import glob
import json
import os
import random
from glob import glob

import numpy as np
import pandas as pd
import xarray as xr
from dmelon.utils import check_folder
from tqdm.notebook import tqdm

In [2]:
def prepare_label(ds_label, index, stime, etime, model_name=None, *, n_leads=24):
    """Build a ``(time, lead)`` label array of future index values.

    Each sample along ``time`` is a start month ``T`` and holds the values of
    ``ds_label[index]`` at ``T, T+1, ..., T+n_leads-1`` along ``lead``, so
    ``lead=0`` is the start month itself. The window is positional (it counts
    time steps, not calendar months), so the index series presumably has to be
    a gap-free monthly record -- TODO confirm for every source.

    Parameters
    ----------
    ds_label : xarray.Dataset
        Dataset holding the index time series.
    index : str
        Name of the variable in ``ds_label`` used as label (e.g. ``"E_index"``).
    stime, etime : str
        Bounds of the start-month slice that is returned.
    model_name : str, optional
        Value of the scalar ``model`` coordinate. Defaults to
        ``ds_label.attrs["model"]``.
    n_leads : int, optional
        Number of lead months per sample (default 24).

    Returns
    -------
    xarray.DataArray
        Label array with dims ``("time", "lead")``, a ``lead`` coordinate
        ``0 .. n_leads-1``, a ``month`` coordinate with the calendar month of
        each start time and a scalar ``model`` coordinate. Start months whose
        window contains any NaN are dropped; no sample exists for the last
        ``n_leads-1`` months of the record because their window is incomplete.

    Raises
    ------
    ValueError
        If ``n_leads`` is smaller than 1.
    KeyError
        If ``model_name`` is None and ``ds_label`` has no ``"model"`` attribute.
    """
    if n_leads < 1:
        raise ValueError(f"n_leads must be >= 1, got {n_leads}")
    if model_name is None:
        model_name = ds_label.attrs["model"]
    offset = n_leads - 1

    label_set = ds_label[index].rolling(time=n_leads).construct("lead").dropna("time")
    label_set = label_set.assign_coords(model=model_name)

    # rolling(...).construct stores, at each time stamp t, the window that *ends*
    # at t (lead n_leads-1 is t itself). Subtracting n_leads-1 months re-labels
    # every sample with the first month of its window, so lead 0 is the start
    # month. Example: with monthly values v0, v1, v2, ... and n_leads=24, the
    # window ending at month 23 holds v0..v23; after the shift it is stamped with
    # month 0, so lead k holds v_k. Stamps are moved to the 15th at 00:00:00.
    # NOTE(review): pd.DateOffset arithmetic assumes a pandas DatetimeIndex on
    # "time" -- confirm for inputs using non-standard (cftime) calendars.
    label_set["time"] = (
        label_set.indexes["time"]
        .to_series()
        .apply(
            lambda x: (x - pd.DateOffset(months=offset)).replace(
                day=15, hour=0, minute=0, second=0
            )
        )
        .values
    )

    # Explicit coordinates along "lead" (and the start month) so leads cannot be
    # mixed up later on.
    label_set["lead"] = np.arange(n_leads)
    label_set["month"] = label_set.time.dt.month
    # Return only the start months aligned with the input data on "time".
    return label_set.transpose("time", "lead").sel(time=slice(stime, etime))


def get_dates(start: tuple, end: tuple):
    """Format ``(year, month)`` bounds as ISO-8601 ``"YYYY-MM"`` strings.

    Parameters
    ----------
    start, end : tuple
        ``(year, month)`` pairs of int-convertible values.

    Returns
    -------
    tuple of str
        ``(input_start, input_end, label_start, label_end)``. The label bounds
        are currently identical to the input bounds; both pairs are returned so
        callers can unpack and later adjust them independently.
    """
    syear, smonth = start
    eyear, emonth = end
    # Zero-padded year and month: strict ISO-8601 partial dates are required for
    # string slicing on a cftime-based index, and pandas accepts them as well.
    start_str = f"{int(syear):04d}-{int(smonth):02d}"
    end_str = f"{int(eyear):04d}-{int(emonth):02d}"
    return start_str, end_str, start_str, end_str

# Observations

## GODAS

In [3]:
# Folder holding the preprocessed observational fields and indices read below.
OBS_PATH = "../data/processed/"
# Number of consecutive months stacked along the "lag" dimension of each input.
ROLLING_AMOUNT = 3

# The observational training files are written to a sub-folder of OBS_PATH;
# check_folder (dmelon) presumably creates it when it does not exist yet.
OUT_PATH = os.path.join(OBS_PATH, "obs_train")
check_folder(OUT_PATH)

In [4]:
# Latest start month whose 24-month label window (start .. start+23) still ends
# by 2022-12, presumably the last month of available observations.
max_val_time = pd.to_datetime("2022-12") - pd.DateOffset(months=23)

sgodas_y, sgodas_m = 1980, 3
egodas_y = max_val_time.year
egodas_m = max_val_time.month

# Date strings used to slice the inputs and the labels.
godas_start, godas_end, godas_label_start, godas_label_end = get_dates(
    start=(sgodas_y, sgodas_m),
    end=(egodas_y, egodas_m),
)
print(f"{godas_start=}, {godas_end=}, {godas_label_start=}, {godas_label_end=}")

godas_start='1980-3', godas_end='2021-1', godas_label_start='1980-3', godas_label_end='2021-1'


In [5]:
# Label sets built from the observed indices. Each product exists as a plain
# file and as a "_3mn" variant (presumably a 3-month running mean); every label
# array is written to OUT_PATH under the matching file name.

# E and C indices
ecindex = xr.open_dataset(os.path.join(OBS_PATH, "godas.ecindex.nc"))
elabel = prepare_label(
    ecindex,
    "E_index",
    stime=godas_label_start,
    etime=godas_label_end,
    model_name="GODAS",
)
clabel = prepare_label(
    ecindex,
    "C_index",
    stime=godas_label_start,
    etime=godas_label_end,
    model_name="GODAS",
)
elabel.to_netcdf(os.path.join(OUT_PATH, "godas.E_index.nc"))
clabel.to_netcdf(os.path.join(OUT_PATH, "godas.C_index.nc"))

# E and C indices, "_3mn" variant
ecindex_3mn = xr.open_dataset(os.path.join(OBS_PATH, "godas.ecindex_3mn.nc"))
elabel_3mn = prepare_label(
    ecindex_3mn,
    "E_index",
    stime=godas_label_start,
    etime=godas_label_end,
    model_name="GODAS",
)
clabel_3mn = prepare_label(
    ecindex_3mn,
    "C_index",
    stime=godas_label_start,
    etime=godas_label_end,
    model_name="GODAS",
)
elabel_3mn.to_netcdf(os.path.join(OUT_PATH, "godas.E_index_3mn.nc"))
clabel_3mn.to_netcdf(os.path.join(OUT_PATH, "godas.C_index_3mn.nc"))

# en34 index (the dataset is only needed to build the label, so it is not kept)
en34 = prepare_label(
    xr.open_dataset(os.path.join(OBS_PATH, "godas.en34.nc")),
    "en34",
    stime=godas_label_start,
    etime=godas_label_end,
    model_name="GODAS",
)
en34.to_netcdf(os.path.join(OUT_PATH, "godas.en34.nc"))

# en34 index, "_3mn" variant
en34_3mn = prepare_label(
    xr.open_dataset(os.path.join(OBS_PATH, "godas.en34_3mn.nc")),
    "en34",
    stime=godas_label_start,
    etime=godas_label_end,
    model_name="GODAS",
)
en34_3mn.to_netcdf(os.path.join(OUT_PATH, "godas.en34_3mn.nc"))

In [6]:
def _to_mid_month(da):
    """Stamp every time step of *da* on the 15th at 00:00:00 and return *da*.

    The ``time`` coordinate of *da* is replaced in place (the returned object is
    *da* itself); year and month of each stamp are kept. ``prepare_label`` uses
    the same mid-month convention, so inputs and labels align on ``time``.
    """
    da["time"] = (
        da.indexes["time"]
        .to_series()
        .apply(lambda t: t.replace(day=15, hour=0, minute=0, second=0))
        .values
    )
    return da


# Anomaly fields: GODAS SST and SSH, NCEP zonal and meridional wind.
godas_sst_ds = xr.open_dataset(os.path.join(OBS_PATH, "godas.sst.anom.nc"))
godas_sst_anom = _to_mid_month(godas_sst_ds.ssta)
godas_ssh_anom = _to_mid_month(
    xr.open_dataset(os.path.join(OBS_PATH, "godas.ssh.anom.nc")).sla
)
ncep_uwnd_anom = _to_mid_month(
    xr.open_dataset(os.path.join(OBS_PATH, "ncep.uwnd.anom.nc")).uwnda
)
ncep_vwnd_anom = _to_mid_month(
    xr.open_dataset(os.path.join(OBS_PATH, "ncep.vwnd.anom.nc")).vwnda
)

# Stack the ROLLING_AMOUNT most recent months of each field along "lag" and the
# four fields along "channel". Every NaN is replaced by 0: this covers the
# incomplete lag windows at the start of the record and any missing values in
# the fields (presumably land points, which the "mask" coordinate set below
# describes -- confirm).
input_set = (
    xr.concat(
        [
            field.rolling(time=ROLLING_AMOUNT).construct("lag")
            for field in (
                godas_sst_anom,
                godas_ssh_anom,
                ncep_uwnd_anom,
                ncep_vwnd_anom,
            )
        ],
        dim="channel",
    )
    .transpose("time", "lag", "channel", "lat", "lon")
    .sel(time=slice(godas_start, godas_end))
    .fillna(0)
)
# lag 0 is the current month, -1 the previous one, and so on. Derived from
# ROLLING_AMOUNT so the coordinate length always matches the rolling window.
input_set["lag"] = np.arange(-(ROLLING_AMOUNT - 1), 1)
input_set["channel"] = ["sst", "ssh", "uas", "vas"]

input_set = input_set.assign_coords(model="GODAS")
input_set["mask"] = godas_sst_ds.mask

# NOTE: the array holds all four channels. The name "sst_anom" is kept because
# readers of godas.train_set.nc presumably look the variable up by this name.
input_set.name = "sst_anom"

input_set.to_netcdf(os.path.join(OUT_PATH, "godas.train_set.nc"))

In [7]:
# Cyclic (cos, sin) encoding of the calendar month of every input sample, so that
# December and January end up close to each other in feature space.
time_comp1 = np.cos(2 * np.pi * input_set.time.dt.month / 12).rename("time_cos")
time_comp2 = np.sin(2 * np.pi * input_set.time.dt.month / 12).rename("time_sin")

xr.merge([time_comp1, time_comp2]).to_netcdf(
    os.path.join(OUT_PATH, "godas.train_time_set.nc")
)