In [1]:
import os
from glob import glob

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from dmelon.utils import check_folder

In [2]:
# TensorFlow / CUDA runtime settings. They only take effect if set before
# TensorFlow is imported (it is imported in a later cell).
os.environ.update(
    {
        "TF_CPP_MIN_LOG_LEVEL": "3",  # silence TensorFlow's C++ log output
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # enumerate GPUs by PCI bus id
        "CUDA_VISIBLE_DEVICES": "0",  # expose only GPU 0 to this process
    }
)

# Directory holding the preprocessed GODAS observation files read below.
OBS_PATH = "../data/processed/obs_train/"

In [3]:
# Per-channel scale factors used to normalise the model input channels.
# NOTE(review): presumably the standard deviations of each channel in the
# training data — confirm where these numbers come from.
std = xr.DataArray(
    data=[1.36, 13.86, 1.97, 1.14],
    dims=("channel",),
    coords={"channel": ["sst", "ssh", "uas", "vas"]},
)
std

## GODAS (validation)

In [4]:
# Network inputs: the `sst_anom` variable divided channel-wise by `std`,
# with axes reordered to (time, channel, lag, lat, lon).
godas_input = (
    xr.open_dataset(os.path.join(OBS_PATH, "godas.train_set.nc"))["sst_anom"] / std
).transpose("time", "channel", "lag", "lat", "lon")

# Index targets; both carry a `lead` dimension.
godas_label_Eindex = xr.open_dataset(os.path.join(OBS_PATH, "godas.E_index.nc"))["E_index"]
godas_label_Cindex = xr.open_dataset(os.path.join(OBS_PATH, "godas.C_index.nc"))["C_index"]

# Seasonal-cycle encoding, with a trailing singleton `channel` axis added to
# every variable of the dataset.
godas_time = xr.open_dataset(os.path.join(OBS_PATH, "godas.train_time_set.nc"))
godas_time = godas_time.expand_dims(dim={"channel": [1]}, axis=-1)
godas_time_sin = godas_time["time_sin"]
godas_time_cos = godas_time["time_cos"]

# Flag time steps whose lead-1 E-index is >= 1.5 and whose `month` coordinate
# is 12. NOTE(review): confirm whether `month` is the initialisation month or
# the target month of the labels.
godas_extreme_class = (godas_label_Eindex.sel(lead=1, drop=True) >= 1.5) & (
    godas_label_Eindex.month == 12
)

# Turn the flag into 1/NaN, back-fill each flagged step over up to the 11
# preceding time steps, and set every remaining step to 0.
godas_extreme_class = (
    godas_extreme_class.astype(int)
    .where(godas_extreme_class)
    .bfill(dim="time", limit=11)
    .fillna(0)
)

# Alias used by the prediction / LRP cells.
godas_data = godas_input

In [5]:
import tensorflow as tf
from model_definition import (
    CategoricalFalseNegatives,
    CategoricalFalsePositives,
    CategoricalTrueNegatives,
    CategoricalTruePositives,
    CriticalScoreIndex,
    ECNet_keras,
)

# Seed the Python, NumPy and TensorFlow RNGs in one call.
tf.keras.utils.set_random_seed(1337)

# Allocate GPU memory on demand rather than reserving it all up front.
# Guarded so the notebook still runs (on CPU) when no GPU is visible, instead
# of crashing with an IndexError on `gpus[0]`.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)
    print(f"Using GPU: {gpus[0]}")
else:
    print("No GPU found; running on CPU")

# input data on the CPU
# https://github.com/keras-team/keras/issues/16997#issuecomment-1252488327
# NOTE(review): BUFFER_SIZE is not used in the visible cells — presumably a
# tf.data shuffle buffer size; confirm or remove.
BUFFER_SIZE = 100000
# Lead range selected from the label array (label-based slice, inclusive).
start_lead = 1
end_lead = 12

Using GPU: PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


In [6]:
NENS = 30

for ens_number in range(NENS):
    # Drop the Keras global state left by the previous member. Loading many
    # models in one loop otherwise keeps accumulating that state and memory.
    # This is done at the top of the loop so the last `model` stays usable.
    tf.keras.backend.clear_session()

    EXP_NAME = f"ecmodel_m{ens_number:02d}"
    OUT_PATH = os.path.join("../models/IGP-UHM-v1.0", EXP_NAME)
    MODEL_FT_PATH = os.path.join(OUT_PATH, f"model_ft_{ens_number:02d}")

    # Inference only: load the already fine-tuned member and predict on GODAS.
    print(f"Doing finetune ensemble member: {ens_number:02d}")

    model = tf.keras.models.load_model(
        MODEL_FT_PATH,
        custom_objects=dict(
            CriticalScoreIndex=CriticalScoreIndex,
            CategoricalTruePositives=CategoricalTruePositives,
            CategoricalFalsePositives=CategoricalFalsePositives,
            CategoricalTrueNegatives=CategoricalTrueNegatives,
            CategoricalFalseNegatives=CategoricalFalseNegatives,
        ),
    )

    # Four model outputs: E-index, C-index, time encoding (sin, cos) and class
    # scores. Each output is cast to a float32 array.
    eindex_hat, cindex_hat, time_hat, class_hat = (
        np.array(output).astype(np.float32)
        for output in model.predict(godas_data.data)
    )

    model_test_output = xr.Dataset(
        {
            "eindex": (["time", "lead"], eindex_hat),
            "cindex": (["time", "lead"], cindex_hat),
            "tsin": (["time"], time_hat[:, 0]),
            "tcos": (["time"], time_hat[:, 1]),
            # NOTE(review): column 1 is presumably the "extreme" class; confirm.
            "class_hat": (["time"], class_hat[:, 1]),
        },
        coords={
            "time": godas_data.time,
            "lead": godas_label_Eindex.sel(lead=slice(start_lead, end_lead)).lead,
            "month": (["time"], godas_data.time.dt.month.data),
            "year": (["time"], godas_data.time.dt.year.data),
        },
    )

    model_test_output.to_netcdf(
        os.path.join(OUT_PATH, f"model_output_ft_godas_{ens_number:02d}.nc")
    )

Doing finetune ensemble member: 00
Doing finetune ensemble member: 01
Doing finetune ensemble member: 02
Doing finetune ensemble member: 03
Doing finetune ensemble member: 04
Doing finetune ensemble member: 05
Doing finetune ensemble member: 06
Doing finetune ensemble member: 07
Doing finetune ensemble member: 08
Doing finetune ensemble member: 09
Doing finetune ensemble member: 10
Doing finetune ensemble member: 11
Doing finetune ensemble member: 12
Doing finetune ensemble member: 13
Doing finetune ensemble member: 14
Doing finetune ensemble member: 15
Doing finetune ensemble member: 16
Doing finetune ensemble member: 17
Doing finetune ensemble member: 18
Doing finetune ensemble member: 19
Doing finetune ensemble member: 20
Doing finetune ensemble member: 21
Doing finetune ensemble member: 22
Doing finetune ensemble member: 23
Doing finetune ensemble member: 24
Doing finetune ensemble member: 25
Doing finetune ensemble member: 26
Doing finetune ensemble member: 27
Doing finetune ensem

In [10]:
import innvestigate
import tensorflow as tf
from model_definition import (
    CategoricalFalseNegatives,
    CategoricalFalsePositives,
    CategoricalTrueNegatives,
    CategoricalTruePositives,
    CriticalScoreIndex,
    ECNet_keras,
)

# innvestigate's analyzers work on TF1-style graphs, so TF2 behaviour (eager
# execution, resource variables, ...) is switched off before loading models.
# NOTE(review): earlier cells already ran models eagerly in this kernel; a
# fresh kernel is the reproducible way to run this section.
tf.compat.v1.disable_v2_behavior()

# Guarded so this section does not crash with an IndexError when no GPU is
# visible.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)
    print(f"Using GPU: {gpus[0]}")
else:
    print("No GPU found; running on CPU")

Instructions for updating:
non-resource variables are not supported in the long term
Using GPU: PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


In [12]:
NENS = 30
# Samples to explain with LRP (here: the GODAS validation inputs).
concat_data = godas_data

for ens_number in range(NENS):
    print(f"Reading ensemble member: {ens_number:02d}")

    # With v2 behaviour disabled, every loaded model and analyzer adds ops to
    # the default graph. Resetting it per member stops the graph and memory
    # from growing across the ensemble. This is done at the top of the loop
    # so the last member's objects remain usable afterwards.
    tf.keras.backend.clear_session()

    EXP_NAME = f"ecmodel_m{ens_number:02d}"
    OUT_PATH = os.path.join("../models/IGP-UHM-v1.0", EXP_NAME)
    MODEL_FT_PATH = os.path.join(OUT_PATH, f"model_ft_{ens_number:02d}")

    model = tf.keras.models.load_model(
        MODEL_FT_PATH,
        custom_objects=dict(
            CriticalScoreIndex=CriticalScoreIndex,
            CategoricalTruePositives=CategoricalTruePositives,
            CategoricalFalsePositives=CategoricalFalsePositives,
            CategoricalTrueNegatives=CategoricalTrueNegatives,
            CategoricalFalseNegatives=CategoricalFalseNegatives,
        ),
    )

    # Same network, but its output is taken before the last layer's output
    # activation (innvestigate `pre_output_tensors`). LRP therefore explains
    # the raw class scores, saved below as the "nosoft" output.
    model_class = tf.keras.models.Model(
        inputs=model.input,
        outputs=innvestigate.backend.graph.pre_output_tensors(model.layers[-1].output),
    )

    class_hat = np.array(model_class.predict(godas_data.data)).astype(np.float32)

    model_test_output = xr.Dataset(
        {
            # NOTE(review): index 1 is presumably the "extreme" class; confirm.
            "class_hat": (["time"], class_hat[:, 1]),
        },
        coords={
            "time": godas_data.time,
            "month": (["time"], godas_data.time.dt.month.data),
            "year": (["time"], godas_data.time.dt.year.data),
        },
    )
    model_test_output.to_netcdf(
        os.path.join(OUT_PATH, f"model_output_ft_godas_nosoft_{ens_number:02d}.nc")
    )

    lrpSeqA_analyzer_model_class = innvestigate.create_analyzer(
        "lrp.sequential_preset_a", model_class, neuron_selection_mode="index"
    )

    # One analyze call per time step. Each sample gets a leading length-1
    # batch axis, and relevance is computed for output neuron 1.
    LRPSEQAFLAT_heatmaps_model_class = [
        lrpSeqA_analyzer_model_class.analyze(
            map_sample.expand_dims("time").data, neuron_selection=1
        )
        for map_sample in concat_data
    ]

    # Relevance maps live in input space, so reshape the stacked results to
    # the input's (time, channel, lag, lat, lon) shape. Reshaping only drops
    # the per-sample batch axes. A bare `.squeeze()` would also remove any
    # other length-1 axis (e.g. one time step or one lag) and break `dims`.
    LRPSEQAFLAT_heatmaps_model_class = xr.DataArray(
        np.asarray(LRPSEQAFLAT_heatmaps_model_class).reshape(concat_data.shape),
        dims=["time", "channel", "lag", "lat", "lon"],
        coords={
            "time": concat_data.time,
            "channel": concat_data.channel,
            "lag": concat_data.lag,
            "lat": concat_data.lat,
            "lon": concat_data.lon,
        },
        name="LRPSEQAFLAT_heatmaps_model_class",
    )

    LRPSEQAFLAT_heatmaps_model_class.to_netcdf(
        os.path.join(OUT_PATH, f"LRPSEQAFLAT_heatmaps_godas_class_{ens_number:02d}.nc")
    )

Reading ensemble member: 00
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 01


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 02


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 03


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 04


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 05


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 06


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 07


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 08


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 09


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 10


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 11


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 12


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 13


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 14


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 15


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 16


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 17


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 18


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 19


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 20


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 21


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 22


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 23


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 24


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 25


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 26


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 27


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 28


  updates=self.state_updates,
  updates=self.state_updates,


Reading ensemble member: 29


  updates=self.state_updates,
  updates=self.state_updates,
