In [None]:
from __future__ import annotations

import datetime
import hashlib
import itertools
import json
import os
import pickle
import re
import warnings

from collections import namedtuple
from copy import copy
from functools import partial
from pathlib import Path
from typing import List

In [None]:
import cartopy.feature as cfeature
import cartopy.crs as ccrs

import joblib

import matplotlib as mpl
from matplotlib import gridspec
from matplotlib import pyplot as plt

import netCDF4
from netCDF4 import Dataset

import numpy as np

import pandas as pd

import scipy as sp

import sklearn
from sklearn.covariance import EmpiricalCovariance
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.preprocessing import StandardScaler

import tqdm

In [None]:
# Date of a trajectory run plus the filesystem paths of its output file and its
# aerosol / anthropogenic / biogenic / meteorology input files.
TrajectoryPaths = namedtuple("TrajectoryPaths", ["date", "out", "aer", "ant", "bio", "met"])
# Same field layout as TrajectoryPaths, but holding the opened netCDF4 datasets.
TrajectoryDatasets = namedtuple("TrajectoryDatasets", ["date", "out", "aer", "ant", "bio", "met"])
# Raw features/labels, their scaled train/valid/test splits and the fitted scalers.
MLDataset = namedtuple("MLDataset", ["date", "paths", "X_raw", "Y_raw", "X_train", "X_valid", "X_test", "Y_train", "Y_valid", "Y_test", "X_scaler", "Y_scaler"])
# NOTE(review): not constructed in the visible part of the notebook; presumably the
# features (X) and labels (Y) of one perturbed run - confirm against its users.
PerturbedDataset = namedtuple("PerturbedDataset", ["date", "perturbation", "paths", "X", "Y"])

In [None]:
def get_path_for_perturbation(dt: datetime.datetime, perturbation: Path) -> TrajectoryPaths:
    """Locate the output file and the four input files of one trajectory run.

    All paths are resolved below ``<parent of the current working directory>/trajectories``.

    Args:
        dt: date and hour of the run; the day selects the input directory and
            ``24 - dt.hour`` is the two-digit number used in the input file names.
        perturbation: sub-directory naming the perturbation (e.g. ``Path("baseline")``).

    Returns:
        A ``TrajectoryPaths`` with ``date=dt`` and the five file paths.

    Raises:
        FileNotFoundError: if at least one of the five files does not exist; the
            message lists exactly the missing files.
    """
    base = Path.cwd().parent / "trajectories"

    day = dt.strftime('%Y%m%d')
    release = f"{24-dt.hour:02}"

    # All four input files share the same per-day directory
    inputs = base / "inputs" / perturbation / "HYDE_BASE_Y2018" / f"OUTPUT_bwd_{day}"
    emissions = inputs / "EMISSIONS_0422"

    paths = {
        "out": base / "outputs" / perturbation / dt.strftime('%Y%m%d_T%H') / "output.nc",
        "aer": emissions / f"{day}_7daybwd_Hyde_traj_AER_{release}_L3.nc",
        "ant": emissions / f"{day}_7daybwd_Hyde_traj_ANT_{release}_L3.nc",
        "bio": emissions / f"{day}_7daybwd_Hyde_traj_BIO_{release}_L3.nc",
        "met": inputs / "METEO" / f"METEO_{day}_R{release}.nc",
    }

    missing = [path for path in paths.values() if not path.exists()]

    if missing:
        # FileNotFoundError is still an Exception, so existing handlers keep working
        raise FileNotFoundError(
            f"Missing trajectory file(s) for {dt:%Y-%m-%d %H:%M} ({perturbation}): "
            + ", ".join(str(path) for path in missing)
        )

    return TrajectoryPaths(date=dt, **paths)

In [None]:
def load_trajectory_dataset(paths: TrajectoryPaths) -> TrajectoryDatasets:
    """Open the five NetCDF files of a trajectory run read-only.

    Args:
        paths: the file locations of the run.

    Returns:
        A ``TrajectoryDatasets`` carrying ``paths.date`` and the opened datasets.
        The caller is responsible for closing them.

    Raises:
        Whatever ``netCDF4.Dataset`` raises when a file cannot be opened; datasets
        that were already opened are closed before the error propagates.
    """
    opened = {}

    try:
        for name in ("out", "aer", "ant", "bio", "met"):
            opened[name] = Dataset(getattr(paths, name), "r", format="NETCDF4")
    except Exception:
        # Do not leak the handles that were opened before the failure
        for dataset in opened.values():
            dataset.close()
        raise

    return TrajectoryDatasets(date=paths.date, **opened)

In [None]:
# Plate Carree (plain longitude/latitude) coordinate reference system.
# NOTE(review): presumably the CRS the plotted data is expressed in - confirm at the plotting calls.
data_proj = ccrs.PlateCarree()
# Lambert conformal conic projection centred on 50 deg latitude / 20 deg longitude.
projection = ccrs.LambertConformal(
    central_latitude=50, central_longitude=20, standard_parallels=(25, 25)
)
# NOTE(review): presumably [lon_min, lon_max, lat_min, lat_max] in degrees - confirm at the call site.
extent = [-60, 60, 40, 80]

In [None]:
def get_ccn_concentration(ds: TrajectoryDatasets):
    """Sum the particle number concentration over all sufficiently large size bins.

    Size bins whose ``dp_dry_fs`` value exceeds 80e-9 are selected and
    ``nconc_par`` is summed over those bins (its second axis).

    Returns:
        A DataFrame with a single ``ccn`` column, indexed by ``(time, level)``.
    """
    dry_diameter = ds.out["dp_dry_fs"][:].data
    number_concentration = ds.out["nconc_par"][:].data

    (large_bins,) = np.nonzero(dry_diameter > 80e-9)
    ccn = np.sum(number_concentration[:, large_bins, :], axis=1)

    n_levels = ds.out["lev"].shape[0]
    n_times = ds.out["time"].shape[0]

    # One row per (time, level) pair, time-major
    frame = pd.DataFrame({
        "time": np.repeat(get_output_time(ds), n_levels),
        "level": np.tile(ds.out["lev"][:].data, n_times),
        "ccn": ccn.flatten(),
    })

    return frame.set_index(["time", "level"])

In [None]:
def get_output_time(ds: TrajectoryDatasets):
    """Return the output time axis shifted so that it is relative to ``ds.date``.

    The ``time`` variable of the output file carries a ``first_day_of_month``
    attribute; the number of seconds between that instant and ``ds.date`` is
    subtracted from every time value.
    """
    time_var = ds.out["time"]

    month_start = datetime.datetime.strptime(
        time_var.__dict__["first_day_of_month"], "%Y-%m-%d %H:%M:%S",
    )
    offset_seconds = (ds.date - month_start).total_seconds()

    return time_var[:].data - offset_seconds

def interpolate_meteorology_values(ds: TrajectoryDatasets, key: str):
    """Bilinearly interpolate a (time, level) meteorology variable onto the output grid.

    ``scipy.interpolate.interp2d`` (used previously) was removed in SciPy 1.14,
    so the interpolation is done with ``RegularGridInterpolator`` instead. As
    before, points outside the meteorology grid evaluate to 0.0 and any netCDF
    mask on the variable is ignored.

    Args:
        ds: the opened trajectory datasets.
        key: name of the variable in ``ds.met``; it must be shaped
            ``(len(time), len(lev))``.

    Returns:
        Array of shape ``(len(output time), len(output level))``. Rows and columns
        follow the order of the output time and level axes.
        NOTE(review): ``interp2d`` returned values for the *sorted* query axes;
        the two only differ if the output axes are not ascending.
    """
    out_t = get_output_time(ds)
    out_h = ds.out["lev"][:].data

    met_t = np.asarray(ds.met["time"][:].data, dtype=float)
    met_h = np.asarray(ds.met["lev"][:].data, dtype=float)
    met_t_h = np.asarray(ds.met[key][:], dtype=float)

    # The grid axes must be ascending; reorder them the same way interp2d did
    t_order = np.argsort(met_t)
    h_order = np.argsort(met_h)

    met_t_h_int = sp.interpolate.RegularGridInterpolator(
        (met_t[t_order], met_h[h_order]), met_t_h[t_order][:, h_order],
        method="linear", bounds_error=False, fill_value=0.0,
    )

    # Evaluate on the full (time, level) product of the output axes
    grid_t, grid_h = np.meshgrid(out_t, out_h, indexing="ij")

    return met_t_h_int(np.stack([grid_t, grid_h], axis=-1))

def interpolate_meteorology_time_values(ds: TrajectoryDatasets, key: str):
    """Linearly interpolate a time-only meteorology variable onto the output times.

    The interpolated series is copied to every output level, so the result has
    shape ``(len(output time), len(output level))``. Output times outside the
    meteorology time axis evaluate to 0.0.
    """
    out_t = get_output_time(ds)
    n_levels = ds.out["lev"][:].data.shape[0]

    interpolate_in_time = sp.interpolate.interp1d(
        x=ds.met["time"][:].data, y=ds.met[key][:],
        kind="linear", bounds_error=False, fill_value=0.0,
    )

    # Column vector of interpolated values, broadcast across all levels
    series = interpolate_in_time(x=out_t).reshape(-1, 1)

    return np.repeat(series, n_levels, axis=1)

def interpolate_biogenic_emissions(ds: TrajectoryDatasets, key: str):
    """Interpolate a biogenic emission in time and spread it over the lowest output levels.

    ``ds.bio[key]`` is linearly interpolated from the biogenic time axis onto the
    output times (0.0 outside that axis). The result is written only to output
    levels with ``lev <= 10.0``, weighted per level; all other levels stay 0.

    Returns:
        Array of shape ``(len(output time), len(output level))``.

    Raises:
        IndexError: if no output level satisfies ``lev <= 10.0`` (the ``[-1]``
            lookup on the empty cumulative sum fails).
    """
    out_t = get_output_time(ds)
    out_h = ds.out["lev"][:].data
    
    # depth of each box layer, assuming level heights are midpoints and end points are clamped
    out_d = (np.array(list(out_h[1:])+[out_h[-1]]) - np.array([out_h[0]]+list(out_h[:-1]))) / 2.0
    
    bio_t = ds.bio["time"][:].data
    
    # Biogenic emissions are limited to boxes at <= 10m height
    # (np.nonzero returns a 1-tuple of index arrays; the tuple itself is used as index below)
    biogenic_emission_layers = np.nonzero(out_h <= 10.0)
    biogenic_emission_layer_height_cumsum = np.cumsum(out_d[biogenic_emission_layers])
    # NOTE(review): these weights are the *cumulative* depth fractions (the last one is 1.0),
    # not each layer's own share of the total depth, so they do not sum to 1 - confirm intended.
    biogenic_emission_layer_proportion = biogenic_emission_layer_height_cumsum / biogenic_emission_layer_height_cumsum[-1]
    num_biogenic_emission_layers = sum(out_h <= 10.0)
    
    bio_t_h = np.zeros(shape=(out_t.size, out_h.size))
    
    bio_t_int = sp.interpolate.interp1d(
        x=bio_t, y=ds.bio[key][:], kind="linear", bounds_error=False, fill_value=0.0,
    )
    
    # Split up the biogenic emissions relative to the depth of the boxes
    # The right-hand side is built as (layers, 1, time) and transposed to (time, 1, layers);
    # presumably this matches the extra axis introduced by indexing with the 1-tuple - TODO confirm.
    bio_t_h[:,biogenic_emission_layers] = (
        np.tile(bio_t_int(x=out_t), (num_biogenic_emission_layers, 1, 1)) * biogenic_emission_layer_proportion.reshape(-1, 1, 1)
    ).T
    
    return bio_t_h

def interpolate_aerosol_emissions(ds: TrajectoryDatasets, key: str):
    """Bilinearly interpolate an aerosol emission variable onto the output grid.

    ``scipy.interpolate.interp2d`` (used previously) was removed in SciPy 1.14,
    so the interpolation is done with ``RegularGridInterpolator`` instead. As
    before, points outside the emission grid evaluate to 0.0 and any netCDF
    mask on the variable is ignored.

    Args:
        ds: the opened trajectory datasets.
        key: name of the variable in ``ds.aer``; its transpose must be shaped
            ``(len(time), len(mid_layer_height))``.

    Returns:
        Array of shape ``(len(output time), len(output level))``. Rows and columns
        follow the order of the output time and level axes.
        NOTE(review): ``interp2d`` returned values for the *sorted* query axes;
        the two only differ if the output axes are not ascending.
    """
    out_t = get_output_time(ds)
    out_h = ds.out["lev"][:].data

    aer_t = np.asarray(ds.aer["time"][:].data, dtype=float)
    aer_h = np.asarray(ds.aer["mid_layer_height"][:].data, dtype=float)
    aer_t_h = np.asarray(ds.aer[key][:], dtype=float).T

    # The grid axes must be ascending; reorder them the same way interp2d did
    t_order = np.argsort(aer_t)
    h_order = np.argsort(aer_h)

    aer_t_h_int = sp.interpolate.RegularGridInterpolator(
        (aer_t[t_order], aer_h[h_order]), aer_t_h[t_order][:, h_order],
        method="linear", bounds_error=False, fill_value=0.0,
    )

    # Evaluate on the full (time, level) product of the output axes
    grid_t, grid_h = np.meshgrid(out_t, out_h, indexing="ij")

    return aer_t_h_int(np.stack([grid_t, grid_h], axis=-1))

def interpolate_anthropogenic_emissions(ds: TrajectoryDatasets, key: str):
    """Bilinearly interpolate an anthropogenic emission variable onto the output grid.

    ``scipy.interpolate.interp2d`` (used previously) was removed in SciPy 1.14,
    so the interpolation is done with ``RegularGridInterpolator`` instead. As
    before, points outside the emission grid evaluate to 0.0 and any netCDF
    mask on the variable is ignored.

    Args:
        ds: the opened trajectory datasets.
        key: name of the variable in ``ds.ant``; its transpose must be shaped
            ``(len(time), len(mid_layer_height))``.

    Returns:
        Array of shape ``(len(output time), len(output level))``. Rows and columns
        follow the order of the output time and level axes.
        NOTE(review): ``interp2d`` returned values for the *sorted* query axes;
        the two only differ if the output axes are not ascending.
    """
    out_t = get_output_time(ds)
    out_h = ds.out["lev"][:].data

    ant_t = np.asarray(ds.ant["time"][:].data, dtype=float)
    ant_h = np.asarray(ds.ant["mid_layer_height"][:].data, dtype=float)
    ant_t_h = np.asarray(ds.ant[key][:], dtype=float).T

    # The grid axes must be ascending; reorder them the same way interp2d did
    t_order = np.argsort(ant_t)
    h_order = np.argsort(ant_h)

    ant_t_h_int = sp.interpolate.RegularGridInterpolator(
        (ant_t[t_order], ant_h[h_order]), ant_t_h[t_order][:, h_order],
        method="linear", bounds_error=False, fill_value=0.0,
    )

    # Evaluate on the full (time, level) product of the output axes
    grid_t, grid_h = np.meshgrid(out_t, out_h, indexing="ij")

    return ant_t_h_int(np.stack([grid_t, grid_h], axis=-1))

In [None]:
def get_meteorology_features(ds: TrajectoryDatasets):
    """Build the meteorology feature table, one row per (time, level) pair.

    Returns:
        A DataFrame indexed by ``(time, level)`` with the columns ``met_t``,
        ``met_q``, ``met_ssr``, ``met_lsm`` and ``met_blh`` (in that order).
    """
    n_levels = ds.out["lev"].shape[0]
    n_times = ds.out["time"].shape[0]

    columns = {
        "time": np.repeat(get_output_time(ds), n_levels),
        "level": np.tile(ds.out["lev"][:].data, n_times),
    }

    # Variables interpolated in both time and level.
    # Previously tried and currently disabled: u, v, qc, omega, mla
    # NOTE: lp is excluded because it allows the model to overfit
    for key in ("t", "q"):
        columns[f"met_{key}"] = interpolate_meteorology_values(ds, key).flatten()

    # Variables interpolated in time only and repeated for every level.
    # Previously tried and currently disabled: sp, cp, sshf, lsp, ewss, nsss, tcc, z
    for key in ("ssr", "lsm", "blh"):
        columns[f"met_{key}"] = interpolate_meteorology_time_values(ds, key).flatten()

    return pd.DataFrame(columns).set_index(["time", "level"])

def get_bio_emissions_features(ds: TrajectoryDatasets):
    """Build the biogenic emission feature table, one row per (time, level) pair.

    Returns:
        A DataFrame indexed by ``(time, level)`` with one ``bio_*`` column per
        emitted species, in the order of the table below.
    """
    # (feature column, variable name in the biogenic emissions file)
    # NOTE(review): "bio_iosprene" and "bio_butanes_and_higher_alkenes" look like
    # typos, but they are kept as-is because the column names may be used elsewhere.
    species = (
        ("bio_acetaldehyde", "acetaldehyde"),
        ("bio_acetone", "acetone"),
        ("bio_butanes_and_higher_alkanes", "butanes-and-higher-alkanes"),
        ("bio_butanes_and_higher_alkenes", "butenes-and-higher-alkenes"),
        ("bio_ch4", "CH4"),
        ("bio_co", "CO"),
        ("bio_ethane", "ethane"),
        ("bio_ethanol", "ethanol"),
        ("bio_ethene", "ethene"),
        ("bio_formaldehyde", "formaldehyde"),
        ("bio_hydrogen_cyanide", "hydrogen-cyanide"),
        ("bio_iosprene", "isoprene"),
        ("bio_mbo", "MBO"),
        ("bio_methanol", "methanol"),
        ("bio_methyl_bromide", "methyl-bromide"),
        ("bio_methyl_chloride", "methyl-chloride"),
        ("bio_methyl_iodide", "methyl-iodide"),
        ("bio_other_aldehydes", "other-aldehydes"),
        ("bio_other_ketones", "other-ketones"),
        ("bio_other_monoterpenes", "other-monoterpenes"),
        ("bio_pinene_a", "pinene-a"),
        ("bio_pinene_b", "pinene-b"),
        ("bio_propane", "propane"),
        ("bio_propene", "propene"),
        ("bio_sesquiterpenes", "sesquiterpenes"),
        ("bio_toluene", "toluene"),
        ("bio_ch2br2", "CH2Br2"),
        ("bio_ch3i", "CH3I"),
        ("bio_chbr3", "CHBr3"),
        ("bio_dms", "DMS"),
    )

    columns = {
        "time": np.repeat(get_output_time(ds), ds.out["lev"].shape[0]),
        "level": np.tile(ds.out["lev"][:].data, ds.out["time"].shape[0]),
    }

    for column, variable in species:
        columns[column] = interpolate_biogenic_emissions(ds, variable).flatten()

    return pd.DataFrame(columns).set_index(["time", "level"])

def get_aer_emissions_features(ds: TrajectoryDatasets):
    """Build the aerosol emission feature table, one row per (time, level) pair.

    Returns:
        A DataFrame indexed by ``(time, level)`` with one ``aer_<lo>_<hi>_nm``
        column per size bin, read from the variable ``<lo>-<hi>nm``.
    """
    # Lower and upper edge of each size bin, as spelled in the variable names
    size_bins = (
        (3, 10), (10, 20), (20, 30), (30, 50), (50, 70),
        (70, 100), (100, 200), (200, 400), (400, 1000),
    )

    columns = {
        "time": np.repeat(get_output_time(ds), ds.out["lev"].shape[0]),
        "level": np.tile(ds.out["lev"][:].data, ds.out["time"].shape[0]),
    }

    for lower, upper in size_bins:
        columns[f"aer_{lower}_{upper}_nm"] = interpolate_aerosol_emissions(
            ds, f"{lower}-{upper}nm",
        ).flatten()

    return pd.DataFrame(columns).set_index(["time", "level"])

def get_ant_emissions_features(ds: TrajectoryDatasets):
    """Build the anthropogenic emission feature table, one row per (time, level) pair.

    Returns:
        A DataFrame indexed by ``(time, level)`` with one ``ant_*`` column per
        species. The column name is the variable name lower-cased, with dashes
        replaced by underscores and prefixed by ``ant_``.
    """
    # Variable names in the anthropogenic emissions file, in column order
    species = (
        "co", "nox", "co2", "nh3", "ch4", "so2", "nmvoc",
        "alcohols", "ethane", "propane", "butanes", "pentanes", "hexanes",
        "ethene", "propene", "acetylene", "isoprene", "monoterpenes",
        "other-alkenes-and-alkynes",
        "benzene", "toluene", "xylene", "trimethylbenzene", "other-aromatics",
        "esters", "ethers", "formaldehyde", "other-aldehydes",
        "total-ketones", "total-acids", "other-VOCs",
    )

    columns = {
        "time": np.repeat(get_output_time(ds), ds.out["lev"].shape[0]),
        "level": np.tile(ds.out["lev"][:].data, ds.out["time"].shape[0]),
    }

    for variable in species:
        column = "ant_" + variable.replace("-", "_").lower()
        columns[column] = interpolate_anthropogenic_emissions(ds, variable).flatten()

    return pd.DataFrame(columns).set_index(["time", "level"])

In [None]:
# https://stackoverflow.com/a/67809235
def df_to_numpy(df):
    """Reshape a DataFrame into an array with one axis per index level.

    The leading axes have the sizes of the index levels (or the index length
    for a flat index). A trailing column axis is appended only when the frame
    has more than one column.
    """
    levels = getattr(df.index, "levels", None)

    if levels is None:
        # Flat index: a single leading axis
        dims = [len(df.index)]
    else:
        dims = [len(level) for level in levels]

    n_columns = df.shape[-1]

    if n_columns > 1:
        dims.append(n_columns)

    return df.to_numpy().reshape(dims)

In [None]:
def generate_time_level_windows():
    """List every (time window, level window) pair used for feature averaging.

    Each window is an inclusive ``(first offset, last offset)`` pair relative to
    the current time step or level. The seven time windows are paired in turn
    with the windows above, around and below the current level, giving 21 pairs.
    """
    # -0.5h, -1.5h, -3h, -6h, -12h, -24h, -48h
    # 0, -2, -5, -11, -23, -47, -95
    time_windows = [(0, 0), (-2, -1), (-5, -3), (-11, -6), (-23, -12), (-47, -24), (-95, -48)]

    # +1l, +2l, +4l, +8l, +16l, +32l, +64
    level_bands = (
        [(1, 1), (1, 2), (1, 4), (2, 8), (2, 16), (3, 32), (3, 64)],  # above
        [(0, 0), (0, 0), (0, 0), (-1, 1), (-1, 1), (-2, 2), (-2, 2)],  # around
        [(-1, -1), (-2, -1), (-4, -1), (-8, -2), (-16, -2), (-32, -3), (-64, -3)],  # below
    )

    return [
        (time_window, level_window)
        for band in level_bands
        for time_window, level_window in zip(time_windows, band)
    ]

In [None]:
def generate_windowed_feature_names(columns):
    """Name every windowed feature as ``<column><time window><level window>``.

    The names are ordered window-major: all columns for the first
    (time, level) window pair, then all columns for the next pair, and so on,
    going through the above, around and below level windows in turn.
    """
    time_labels = ["-0.5h", "-1.5h", "-3h", "-6h", "-12h", "-24h", "-48h"]

    level_label_bands = (
        ["+1l", "+2l", "+4l", "+8l", "+16l", "+32l", "+64l"],  # above
        ["+0l", "+0l", "+0l", "±1l", "±1l", "±2l", "±2l"],  # around
        ["-1l", "-2l", "-4l", "-8l", "-16l", "-32l", "-64l"],  # below
    )

    return [
        f"{column}{time_label}{level_label}"
        for band in level_label_bands
        for time_label, level_label in zip(time_labels, band)
        for column in columns
    ]

In [None]:
def time_level_window_mean_v1(input, t_range, l_range):
    """Reference implementation: mean over a time/level window, element by element.

    For every ``(t, l, f)`` the mean of ``input`` over the inclusive offsets
    ``t_range`` (first axis) and ``l_range`` (second axis) around ``(t, l)`` is
    taken for feature ``f``. Windows are clipped to the array; a window that is
    empty after clipping yields 0.0.
    """
    n_t, n_l, n_f = input.shape[0], input.shape[1], input.shape[2]

    def clipped(center, offsets, size):
        # Inclusive offsets -> half-open slice clipped to [0, size]
        start = min(max(0, center + offsets[0]), size)
        stop = max(0, min(center + 1 + offsets[1], size))
        return slice(start, stop)

    output = np.zeros(shape=input.shape)

    for t in range(n_t):
        rows = clipped(t, t_range, n_t)

        for l in range(n_l):
            cols = clipped(l, l_range, n_l)

            for f in range(n_f):
                window = input[rows, cols, f]

                # Empty windows keep the initial 0.0
                if window.size > 0:
                    output[t, l, f] = np.mean(window)

    return output

def time_level_window_mean_v2(input, t_range, l_range):
    """Mean over a time/level window, computed for all features at once.

    For every ``(t, l)`` the mean of ``input`` over the inclusive offsets
    ``t_range`` (first axis) and ``l_range`` (second axis) around ``(t, l)`` is
    taken separately for each feature. Windows are clipped to the array; a
    window that is empty after clipping leaves the output at 0.0.
    """
    n_t, n_l = input.shape[0], input.shape[1]

    # The level bounds do not depend on the time step
    level_bounds = [
        (min(max(0, l + l_range[0]), n_l), max(0, min(l + 1 + l_range[1], n_l)))
        for l in range(n_l)
    ]

    output = np.zeros(shape=input.shape)

    for t in range(n_t):
        t_start = min(max(0, t + t_range[0]), n_t)
        t_stop = max(0, min(t + 1 + t_range[1], n_t))

        if t_start == t_stop:
            continue

        time_slab = input[t_start:t_stop]

        for l, (l_start, l_stop) in enumerate(level_bounds):
            if l_start != l_stop:
                output[t, l, :] = np.mean(time_slab[:, l_start:l_stop, :], axis=(0, 1))

    return output

def time_level_window_mean_v3(input, t_range, l_range):
    """Mean over a time/level window, computed with two zero-padded convolutions.

    A 0/1 kernel marks the inclusive offsets ``t_range`` (first axis) and
    ``l_range`` (second axis). Convolving ``input`` gives the windowed sum and
    convolving an all-ones array gives the number of in-bounds entries; their
    ratio is the mean. Where the window has no in-bounds entry the result is 0.
    """
    # Half-width of a kernel that is centred on offset 0 and covers the window
    half_t = max(abs(min(t_range[0], 0)), abs(max(0, t_range[1])))
    half_l = max(abs(min(l_range[0], 0)), abs(max(0, l_range[1])))

    footprint = np.zeros(shape=(half_t*2 + 1, half_l*2 + 1, 1))
    footprint[
        t_range[0]+half_t:t_range[1]+half_t+1,
        l_range[0]+half_l:l_range[1]+half_l+1,
        :,
    ] = 1.0
    # Convolution mirrors the kernel, so mirror it up front
    mirrored = footprint[::-1,::-1]

    def windowed_sum(values):
        return sp.ndimage.convolve(values, mirrored, mode='constant', cval=0.0)

    counts = windowed_sum(np.ones_like(input))

    means = np.zeros_like(input)

    np.divide(windowed_sum(input), counts, out=means, where=counts > 0)

    return means

In [None]:
def get_raw_features_for_dataset(ds: TrajectoryDatasets):
    """Assemble the un-windowed feature table of one trajectory run.

    The biogenic, aerosol, anthropogenic and meteorology feature tables are
    joined column-wise (in that order). The aerosol emissions are multiplied
    by 1e21 before joining.
    """
    feature_groups = [
        get_bio_emissions_features(ds),
        get_aer_emissions_features(ds) * 1e21,
        get_ant_emissions_features(ds),
        get_meteorology_features(ds),
    ]

    return pd.concat(feature_groups, axis="columns")

In [None]:
def get_features_from_raw_features(raw_features):
    """Turn the raw (time, level) feature table into windowed-mean features.

    For each (time window, level window) pair the windowed mean of every raw
    feature is computed in parallel; the results are stacked along the feature
    axis behind the time and level values themselves.

    Returns:
        A DataFrame indexed by ``(time, level)`` whose columns are named by
        ``generate_windowed_feature_names``. The first 95 time steps and the
        last time step are dropped.
    """
    raw_grid = df_to_numpy(raw_features)

    index = raw_features.index
    plane_shape = (index.levels[0].size, index.levels[1].size, 1)

    time_plane = index.get_level_values(0).to_numpy().reshape(plane_shape)
    level_plane = index.get_level_values(1).to_numpy().reshape(plane_shape)

    windowed = joblib.Parallel(n_jobs=-1)(
        joblib.delayed(time_level_window_mean_v2)(raw_grid, t_range, l_range)
        for t_range, l_range in generate_time_level_windows()
    )

    stacked = np.concatenate([time_plane, level_plane, *windowed], axis=2)

    # Trim off the first two days, for which the time features are ill-defined
    kept = stacked[95:-1,:,:]
    n_times, n_levels, n_columns = kept.shape

    column_names = ["time", "level"] + generate_windowed_feature_names(raw_features.columns)

    flat = pd.DataFrame(kept.reshape(n_times*n_levels, n_columns), columns=column_names)

    return flat.set_index(["time", "level"])

In [None]:
def get_labels_for_dataset(ds: TrajectoryDatasets):
    """Build the label table (CCN concentration) of one trajectory run.

    Returns:
        A DataFrame indexed by ``(time, level)`` with a single ``ccn`` column.
        The first 96 time steps are dropped.
    """
    ccn = get_ccn_concentration(ds)
    ccn_grid = df_to_numpy(ccn)

    index = ccn.index
    plane_shape = (index.levels[0].size, index.levels[1].size, 1)

    stacked = np.concatenate([
        index.get_level_values(0).to_numpy().reshape(plane_shape),
        index.get_level_values(1).to_numpy().reshape(plane_shape),
        ccn_grid.reshape((ccn_grid.shape[0], ccn_grid.shape[1], 1)),
    ], axis=2)

    # Trim off the first two days, for which the time features are ill-defined
    kept = stacked[96:,:,:]
    n_times, n_levels, n_columns = kept.shape

    flat = pd.DataFrame(
        kept.reshape(n_times*n_levels, n_columns), columns=["time", "level", "ccn"],
    )

    return flat.set_index(["time", "level"])

In [None]:
def hash_for_dt(dt):
    """Hash one datetime, or a tuple/list of datetimes, with SHAKE-256.

    Each datetime is formatted as ``%d.%m.%Y-%H:00%z`` and the formatted values
    are joined with ``.`` before hashing.

    Returns:
        The ``hashlib.shake_256`` object; call ``digest(n)`` / ``hexdigest(n)`` on it.
    """
    dates = dt if isinstance(dt, (tuple, list)) else [dt]

    stamp = '.'.join(date.strftime('%d.%m.%Y-%H:00%z') for date in dates)

    hasher = hashlib.shake_256()
    hasher.update(stamp.encode('ascii'))

    return hasher

In [None]:
class Clump:
    """
    Clumped 0/1 sampler using a Markov Process

    P(0) = p and P(1) = 1-p
    clump = 0 => IID samples
    clump -> 1 => highly correlated samples

    The chain has two states (0 and 1). ``C[i, j]`` is the probability of moving
    from state ``i`` to state ``j``; the rows are chosen such that ``(p, 1-p)``
    is the stationary distribution of the chain.
    """
    def __init__(self, p=0.5, clump=0.0, rng=None):
        """Set up the transition matrix and draw the initial state.

        Args:
            p: stationary probability of sampling 0; must satisfy ``0 <= p < 1``.
            clump: correlation strength between consecutive samples.
            rng: random number generator providing ``random()``; required.

        Raises:
            ValueError: if ``rng`` is missing or ``p`` is outside ``[0, 1)``.
        """
        if rng is None:
            # The default exists only for keyword-style calls; sampling the
            # initial state without a generator is impossible
            raise ValueError("Clump requires an explicit random number generator `rng`")

        if not 0.0 <= p < 1.0:
            # p == 1 would divide by zero below
            raise ValueError(f"p must be in [0, 1), got {p}")

        a = 1 - (1-p)*(1-clump)
        b = (1-a)*p/(1-p)

        self.C = np.array([[a, 1-a],[b, 1-b]])

        self.i = 0 if rng.random() < p else 1

    def sample(self, rng):
        """Advance the chain by one step and return the new state (0 or 1)."""
        # Probability of moving to state 0 from the current state
        p = self.C[self.i,0]
        u = rng.random()

        self.i = 0 if u < p else 1

        return self.i

    def steady(self, X):
        """Apply one transition step to the state distribution(s) ``X``."""
        return np.matmul(X, self.C)

In [None]:
def train_test_split(X, Y, test_size=0.25, random_state=None, shuffle=True, clump=0.0):
    """Split X and Y into train and test parts using a clumped 0/1 sampler.

    For arrays every row is assigned independently by the ``Clump`` Markov chain.
    For DataFrames the assignment is made once per distinct value of the
    first index level, so all rows sharing that value end up on the same side.
    Rows are grouped by their actual index value rather than by position, so the
    grouping is also correct when the rows have been shuffled by an earlier split.

    Args:
        X, Y: features and labels of the same type and length (both DataFrames
            or both indexable arrays).
        test_size: stationary probability of a row/group being assigned to the
            test part; must be in ``(0, 1)``.
        random_state: random number generator providing ``random()`` and
            ``shuffle()``; required.
        shuffle: whether to shuffle the row order within each part.
        clump: correlation between consecutive assignments, in ``[0, 1)``.

    Returns:
        ``(X_train, X_test, Y_train, Y_test)``.
    """
    assert len(X) == len(Y)
    assert type(X) == type(Y)
    assert test_size > 0.0
    assert test_size < 1.0
    assert random_state is not None
    assert clump >= 0.0
    assert clump < 1.0

    c = Clump(p=test_size, clump=clump, rng=random_state)

    if isinstance(X, pd.DataFrame):
        assert X.index.values.shape == Y.index.values.shape

        # Split only based on the first-level index instead of flattening:
        # map every row to the group of its first-level index value
        group_of_row, groups = pd.factorize(X.index.get_level_values(0))
        n_draws = len(groups)
    else:
        # Every row is its own group
        group_of_row = np.arange(len(X))
        n_draws = len(X)

    # One draw per group, in order of first appearance; 1 = train, 0 = test
    C = np.array([c.sample(random_state) for _ in range(n_draws)], dtype=int)

    in_train = C[group_of_row] == 1
    I_train, = np.nonzero(in_train)
    I_test, = np.nonzero(~in_train)

    if shuffle:
        random_state.shuffle(I_train)
        random_state.shuffle(I_test)

    if isinstance(X, pd.DataFrame):
        X_train = X.iloc[I_train]
        X_test = X.iloc[I_test]

        Y_train = Y.iloc[I_train]
        Y_test = Y.iloc[I_test]
    else:
        X_train = X[I_train]
        X_test = X[I_test]

        Y_train = Y[I_train]
        Y_test = Y[I_test]

    return X_train, X_test, Y_train, Y_test

In [None]:
def load_and_cache_dataset(dt: datetime.datetime, clump: float, datasets: dict) -> MLDataset:
    """Load, split and scale the dataset for one date or for a collection of dates.

    Results are memoised in ``datasets`` under the key ``(dt, clump)``.

    Args:
        dt: a single datetime, or a tuple/list of datetimes. A collection is
            sorted, each date is loaded (and cached) on its own, and the
            per-date splits are concatenated before fitting fresh scalers.
        clump: clumping strength forwarded to ``train_test_split``.
        datasets: cache dictionary, updated in place.

    Returns:
        The ``MLDataset`` with scaled train/valid/test features and labels.
    """
    if isinstance(dt, (tuple, list)):
        dt = tuple(sorted(dt))

    cached = datasets.get((dt, clump))

    if cached is not None:
        return cached

    if isinstance(dt, tuple):
        mls = [load_and_cache_dataset(dtt, clump, datasets) for dtt in dt]

        dp = tuple(ml.paths for ml in mls)
        X_raw = pd.concat([ml.X_raw for ml in mls], axis='index')
        Y = pd.concat([ml.Y_raw for ml in mls], axis='index')

        # Undo the per-date scaling so that common scalers can be fitted below
        train_features = np.concatenate([ml.X_scaler.inverse_transform(ml.X_train) for ml in mls], axis=0)
        train_labels = np.concatenate([ml.Y_scaler.inverse_transform(ml.Y_train) for ml in mls], axis=0)
        valid_features = np.concatenate([ml.X_scaler.inverse_transform(ml.X_valid) for ml in mls], axis=0)
        valid_labels = np.concatenate([ml.Y_scaler.inverse_transform(ml.Y_valid) for ml in mls], axis=0)
        test_features = np.concatenate([ml.X_scaler.inverse_transform(ml.X_test) for ml in mls], axis=0)
        test_labels = np.concatenate([ml.Y_scaler.inverse_transform(ml.Y_test) for ml in mls], axis=0)
    else:
        dp = get_path_for_perturbation(dt, Path("baseline"))
        ds = load_trajectory_dataset(dp)

        try:
            X_raw = get_raw_features_for_dataset(ds)

            X = get_features_from_raw_features(X_raw)
            Y = np.log10(get_labels_for_dataset(ds) + 1)
        finally:
            # Close the NetCDF datasets, also when feature extraction fails
            for handle in (ds.out, ds.aer, ds.ant, ds.bio, ds.met):
                handle.close()

        # Seed deterministically from the date so that the split is reproducible
        rng = np.random.RandomState(seed=int.from_bytes(hash_for_dt(dt).digest(4), 'little'))

        train_features, test_features, train_labels, test_labels = train_test_split(
            X, Y, test_size=0.25, random_state=rng, clump=clump,
        )
        train_features, valid_features, train_labels, valid_labels = train_test_split(
            train_features, train_labels, test_size=1.0/3.0, random_state=rng, clump=clump,
        )

    # Scale features to N(0,1)
    # - only fit on training data
    # - OOD inputs for constants at training time are blown up
    feature_scaler = StandardScaler().fit(train_features)
    feature_scaler.scale_[np.nonzero(feature_scaler.var_ == 0.0)] = np.nan_to_num(np.inf)

    label_scaler = StandardScaler().fit(train_labels)

    train_features = feature_scaler.transform(train_features)
    train_labels = label_scaler.transform(train_labels)
    valid_features = feature_scaler.transform(valid_features)
    valid_labels = label_scaler.transform(valid_labels)
    test_features = feature_scaler.transform(test_features)
    test_labels = label_scaler.transform(test_labels)

    dataset = MLDataset(
        date=dt, paths=dp, X_raw=X_raw, Y_raw=Y,
        X_train=train_features, X_valid=valid_features, X_test=test_features,
        Y_train=train_labels, Y_valid=valid_labels, Y_test=test_labels,
        X_scaler=feature_scaler, Y_scaler=label_scaler,
    )

    datasets[(dt, clump)] = dataset

    return dataset

In [None]:
# Notebook-level cache dictionary, initially empty.
# NOTE(review): presumably passed as `datasets` to load_and_cache_dataset, which
# stores MLDataset objects under (date, clump) keys - confirm at the call sites.
DATASETS = dict()

In [None]:
# Notebook-level cache dictionary, initially empty.
# NOTE(review): presumably holds PerturbedDataset objects; its users are not
# visible in this part of the notebook - confirm.
PERTURBED_DATASETS = dict()

In [None]:
import abc

# Result of IcarusRSM.predict: the prediction, its uncertainty and the confidence in it
IcarusPrediction = namedtuple("IcarusPrediction", "prediction uncertainty confidence")

class IcarusRSM(abc.ABC):
    """Abstract interface of a model that is fitted once and then queried for predictions."""

    @abc.abstractmethod
    def fit(
        self,
        X_train: np.ndarray, Y_train: np.ndarray,
        X_valid: np.ndarray, Y_valid: np.ndarray,
        rng: np.random.Generator, **kwargs,
    ) -> IcarusRSM:
        """Fit the model on the training and validation data and return ``self``."""
        return self

    @abc.abstractmethod
    def predict(self, X_test: np.ndarray, rng: np.random.Generator, **kwargs) -> IcarusPrediction:
        """Predict for ``X_test``; implementations return an ``IcarusPrediction``."""
        return

In [None]:
class RandomForestSosaaRSM(IcarusRSM):
    """Random forest regressor combined with a PCA-based out-of-distribution detector.

    ``fit`` trains three parts: a truncated PCA whose reconstruction error is
    scored with a Mahalanobis distance, a logistic regression that turns that
    score into a confidence, and a random forest whose per-tree spread serves
    as the uncertainty estimate.
    """
    def fit(
        self,
        X_train: np.ndarray, Y_train: np.ndarray,
        X_valid: np.ndarray, Y_valid: np.ndarray,
        rng: np.random.Generator,
        n_trees: int = 16, verbose: int = 1,
    ) -> RandomForestSosaaRSM:
        """Train the OOD detector and the random forest.

        Args:
            X_train, Y_train: training features and labels; ``Y_train`` must have
                exactly one column.
            X_valid, Y_valid: validation features and labels; only ``X_valid`` is
                used (to calibrate the OOD classifier).
            rng: random number generator, used for the synthetic OOD inputs and
                passed as ``random_state`` to the scikit-learn estimators.
                NOTE(review): annotated as ``np.random.Generator``, but scikit-learn
                documents None/int/``RandomState`` for ``random_state`` - confirm
                what callers actually pass.
            n_trees: number of trees in the random forest.
            verbose: progress is printed if > 0; also forwarded to the forest.

        Returns:
            ``self``.
        """
        assert Y_train.shape[1:] == (1,)
        
        if verbose > 0:
            print("Training the RandomForestSosaaRSM")

            print(" - Training the OOD detector")
            print("   - Fitting truncated PCA")
        
        self.pca = PCA(random_state=rng).fit(X_train)
        # Index of the first component at which the cumulative explained variance
        # ratio reaches 0.95; only the components before it are kept for reconstruction
        self.bn = np.searchsorted(np.cumsum(self.pca.explained_variance_ratio_), 0.95)
        
        if verbose > 0:
            print("   - Fitting truncated PCA reconstruction error covariance")
        
        self.cov = EmpiricalCovariance().fit((self._predict_truncated_pca(X_train) - X_train))
        
        if verbose > 0:
            print("   - Generating FGSM OOD inputs")
        
        # Perturbation direction: the first principal component that is NOT kept
        adv_grad = self.pca.components_[self.bn]
        
        # Slightly jittered validation inputs, pushed along the sign of that component
        # by a per-sample magnitude; a single random sign is shared by all samples.
        # The order of the three rng draws matters for reproducibility.
        X_ood = rng.normal(loc=X_valid, scale=0.01) + np.sign(adv_grad) * np.abs(
            rng.normal(loc=2.0, scale=0.5, size=(len(X_valid), 1))
        ) * rng.choice([-1, 1])
        
        if verbose > 0:
            print("   - Training the OOD classifier")
        
        # Mahalanobis distance of the reconstruction error for in-distribution and OOD inputs
        M_id = self.cov.mahalanobis(self._predict_truncated_pca(X_valid) - X_valid)
        M_ood = self.cov.mahalanobis(self._predict_truncated_pca(X_ood) - X_ood)
        
        self.scaler = StandardScaler().fit(M_id.reshape(-1, 1))
        
        # Label 1 = in-distribution, label 0 = out-of-distribution
        # NOTE(review): penalty='none' was deprecated in scikit-learn 1.2 and removed
        # in 1.4 (use penalty=None there) - confirm the pinned scikit-learn version.
        self.ood_detector = LogisticRegression(
            penalty='none', class_weight="balanced", random_state=rng,
        ).fit(
            np.concatenate([
                self.scaler.transform(M_id.reshape(-1, 1)),
                self.scaler.transform(M_ood.reshape(-1, 1)),
            ], axis=0).reshape(-1, 1),
            np.concatenate([
                np.ones(len(M_id)), np.zeros(len(M_ood)),
            ], axis=0),
        )
        
        if verbose > 0:
            print(" - Training the Prediction Model and Uncertainty Quantifier")
        
        self.predictor = RandomForestRegressor(
            n_estimators=n_trees, random_state=rng, n_jobs=-1, min_samples_leaf=5,
            max_features=1.0/3.0, verbose=verbose,
        ).fit(X_train, Y_train.ravel())
        
        if verbose > 0:
            print(" - Finished training the RandomForestSosaaRSM")
        
        return self
    
    def predict(
        self, X_test: np.ndarray, rng: np.random.Generator, verbose: int = 1,
    ) -> IcarusPrediction:
        """Predict with every tree and score how in-distribution the inputs are.

        Args:
            X_test: features to predict for.
            rng: accepted for interface compatibility; not used.
            verbose: progress is printed only if > 1.

        Returns:
            An ``IcarusPrediction`` whose ``prediction`` and ``uncertainty`` are the
            mean and standard deviation over the individual trees, each of shape
            ``(len(X_test), 1)``, and whose ``confidence`` is the OOD classifier's
            probability of the in-distribution class, of shape ``(len(X_test),)``.
        """
        # No extra randomness is needed during prediction
        _rng = rng
        
        if verbose > 1:
            print("Predicting with the RandomForestSosaaRSM")
            
            print(" - Generating confidence scores")
        
        # Column 1 of predict_proba belongs to label 1, i.e. in-distribution
        confidence = self.ood_detector.predict_proba(self.scaler.transform(
            self.cov.mahalanobis(self._predict_truncated_pca(X_test) - X_test).reshape(-1, 1)
        ))[:,1]
            
        if verbose > 1:
            print(" - Generating ensemble predictions")
            
        def tree_predict(i: int) -> np.ndarray:
            if verbose > 1:
                print(f"   - Predicting tree {i}/{len(self.predictor.estimators_)}")
            
            return self.predictor.estimators_[i].predict(X_test)
        
        predictions = joblib.Parallel(n_jobs=-1, prefer="threads")(
            joblib.delayed(tree_predict)(i) for i in range(len(self.predictor.estimators_))
        )
        
        prediction = np.mean(np.stack(predictions, axis=0), axis=0).reshape((len(X_test), 1))
        uncertainty = np.std(np.stack(predictions, axis=0), axis=0).reshape((len(X_test), 1))
        
        if verbose > 1:
            print(" - Finished predicting with the RandomForestSosaaRSM")
            
        return IcarusPrediction(
            prediction=prediction, uncertainty=uncertainty, confidence=confidence,
        )

    def _predict_truncated_pca(self, X: np.ndarray) -> np.ndarray:
        """Project ``X`` onto the first ``self.bn`` principal components and back."""
        if self.pca.mean_ is not None:
            X = X - self.pca.mean_
        
        X_trans = np.dot(X, self.pca.components_[:self.bn].T)
        X = np.dot(X_trans, self.pca.components_[:self.bn])
        
        if self.pca.mean_ is not None:
            X = X + self.pca.mean_
        
        return X

In [None]:
class PairwiseDifferenceRegressionRandomForestSosaaRSM(IcarusRSM):
    """Random-forest model that regresses pairwise target differences.

    The forest is trained on pairs of samples: its input is an "anchor" row
    concatenated with the scaled feature difference between a second row and
    the anchor, and its target is the difference of the two rows' targets.
    At prediction time the predicted difference is added to the anchor's known
    target. `fit` additionally trains an out-of-distribution (OOD) detector on
    truncated-PCA reconstruction errors and calibrates the per-tree spread on
    validation pairs.
    """

    def fit(
        self,
        X_train: np.ndarray, Y_train: np.ndarray,
        X_valid: np.ndarray, Y_valid: np.ndarray,
        rng: np.random.Generator,
        n_trees: int = 16, n_samples: int = 16, verbose: int = 1,
    ) -> PairwiseDifferenceRegressionRandomForestSosaaRSM:
        """Train the OOD detector, the forest and the uncertainty calibration.

        Args:
            X_train, Y_train: Training features and targets; `Y_train` must
                have exactly one target column. Both are stored on the
                instance and reused as anchors.
            X_valid, Y_valid: Validation features and targets, used to train
                the OOD classifier and to calibrate the uncertainty.
            rng: Random source. It must provide `choice` and `normal`, and it
                is also handed to scikit-learn as `random_state`.
                NOTE(review): `train_and_cache_model` in this notebook passes
                a `np.random.RandomState`, not a `Generator`.
            n_trees: Number of trees in the forest.
            n_samples: Number of random pairs generated per original sample.
            verbose: Progress messages are printed when greater than 0; the
                value is also forwarded to the forest.

        Returns:
            `self`.
        """
        assert Y_train.shape[1:] == (1,)
        
        # The unpaired training data is kept as the pool of anchor samples.
        self.X_train = X_train
        self.Y_train = Y_train
        
        if verbose > 0:
            print("Training the PairwiseDifferenceRegressionRandomForestSosaaRSM")
            
            print(" - Resampling the training and validation datasets")
    
        # Ia_* index the anchors (always drawn from the training set), Ib_*
        # index the paired rows (training or validation set). All draws are
        # with replacement, n_samples pairs per original row.
        Ia_train = rng.choice(len(self.X_train), size=len(X_train)*n_samples, replace=True)
        Ib_train = rng.choice(len(X_train), size=len(X_train)*n_samples, replace=True)
        Ia_valid = rng.choice(len(self.X_train), size=len(X_valid)*n_samples, replace=True)
        Ib_valid = rng.choice(len(X_valid), size=len(X_valid)*n_samples, replace=True)
        
        # N(0,1)-N(0,1) ~ N(0,2) -> divide by sqrt(2) s.t. all features are N(0,1)
        X_train = np.concatenate([
            self.X_train[Ia_train], (X_train[Ib_train]-self.X_train[Ia_train]) / np.sqrt(2.0),
        ], axis=1)
        Y_train = Y_train[Ib_train] - self.Y_train[Ia_train]
        
        X_valid = np.concatenate([
            self.X_train[Ia_valid], (X_valid[Ib_valid]-self.X_train[Ia_valid]) / np.sqrt(2.0),
        ], axis=1)
        Y_valid = Y_valid[Ib_valid] - self.Y_train[Ia_valid]
            
        if verbose > 0:
            print(" - Training the OOD detector")
            print("   - Fitting truncated PCA")
        
        # bn is the index of the first component at which the cumulative
        # explained-variance ratio reaches 0.95; the truncated reconstruction
        # keeps components [0, bn).
        self.pca = PCA(random_state=rng).fit(X_train)
        self.bn = np.searchsorted(np.cumsum(self.pca.explained_variance_ratio_), 0.95)
        
        if verbose > 0:
            print("   - Fitting truncated PCA reconstruction error covariance")
        
        self.cov = EmpiricalCovariance().fit((self._predict_truncated_pca(X_train) - X_train))
        
        if verbose > 0:
            print("   - Generating FGSM OOD inputs")
        
        # First principal axis that the truncated reconstruction does not keep.
        adv_grad = self.pca.components_[self.bn]
        
        # Synthetic OOD rows: validation pairs with small Gaussian jitter,
        # shifted along sign(adv_grad) by a per-row magnitude |N(2.0, 0.5)|.
        # NOTE(review): rng.choice([-1, 1]) draws a single sign that is shared
        # by every row - confirm that a per-row sign was not intended.
        X_ood = rng.normal(loc=X_valid, scale=0.01) + np.sign(adv_grad) * np.abs(
            rng.normal(loc=2.0, scale=0.5, size=(len(X_valid), 1))
        ) * rng.choice([-1, 1])
        
        if verbose > 0:
            print("   - Training the OOD classifier")
        
        # Reconstruction-error distances for in-distribution and OOD rows.
        M_id = self.cov.mahalanobis(self._predict_truncated_pca(X_valid) - X_valid)
        M_ood = self.cov.mahalanobis(self._predict_truncated_pca(X_ood) - X_ood)
        
        # The scaler is fitted on the in-distribution distances only.
        self.scaler = StandardScaler().fit(M_id.reshape(-1, 1))
        
        # Label 1 marks in-distribution rows, label 0 the synthetic OOD rows.
        # NOTE(review): newer scikit-learn releases are believed to expect
        # penalty=None instead of the string 'none' - confirm against the
        # installed version.
        self.ood_detector = LogisticRegression(
            penalty='none', class_weight="balanced", random_state=rng,
        ).fit(
            np.concatenate([
                self.scaler.transform(M_id.reshape(-1, 1)),
                self.scaler.transform(M_ood.reshape(-1, 1)),
            ], axis=0).reshape(-1, 1),
            np.concatenate([
                np.ones(len(M_id)), np.zeros(len(M_ood)),
            ], axis=0),
        )
        
        if verbose > 0:
            print(" - Training the Prediction Model and Uncertainty Quantifier")
        
        self.predictor = RandomForestRegressor(
            n_estimators=n_trees, random_state=rng, n_jobs=-1, min_samples_leaf=5,
            max_features=1.0/3.0, verbose=verbose,
        ).fit(X_train, Y_train)
        
        if verbose > 0:
            print(" - Calibrating the Uncertainty Quantifier")
        
        def tree_predict(i: int) -> np.ndarray:
            if verbose > 1:
                print(f"   - Predicting tree {i}/{len(self.predictor.estimators_)}")
            
            return self.predictor.estimators_[i].predict(X_valid)
            
        valid_predictions = joblib.Parallel(n_jobs=-1, prefer="threads")(
            joblib.delayed(tree_predict)(i) for i in range(len(self.predictor.estimators_))
        )
        
        # Ensemble mean and spread across the trees for every validation pair.
        Y_valid_pred = np.mean(np.stack(valid_predictions, axis=0), axis=0)
        Y_valid_stdv = np.std(np.stack(valid_predictions, axis=0), axis=0)
            
        # Validation residuals in units of the tree spread.
        # NOTE(review): a zero spread (all trees agree) divides by zero here
        # and would propagate inf/nan into Zc_mean and Zc_stdv.
        Zc = (Y_valid.flatten() - Y_valid_pred.flatten()) / Y_valid_stdv.flatten()
        
        # Used in predict() to shift the prediction and rescale the spread.
        self.Zc_mean = np.mean(Zc)
        self.Zc_stdv = np.std(Zc)
        
        if verbose > 0:
            print("- Finished training the PairwiseDifferenceRegressionRandomForestSosaaRSM")
        
        return self
    
    def predict(
        self, X_test: np.ndarray, rng: np.random.Generator, verbose: int = 1,
        direct_difference: bool = False, X_base: np.ndarray = None, Y_base: np.ndarray = None,
    ) -> IcarusPrediction:
        """Predict targets for `X_test` relative to anchor samples.

        Args:
            X_test: Feature rows to predict for.
            rng: Random source; only used to pick anchors when
                `direct_difference` is False.
            verbose: Progress messages are printed when greater than 1.
            direct_difference: When False, every test row is paired with one
                randomly chosen stored training row. When True, test row `i`
                is paired with anchor row `i`, so the anchor set must have the
                same length as `X_test`.
            X_base, Y_base: Optional anchor features and targets for the
                `direct_difference` mode. They are used only when both are
                given; otherwise the stored training data is the anchor set.

        Returns:
            An `IcarusPrediction` with the recalibrated prediction and
            uncertainty, each shaped `(len(X_test), 1)`, and a confidence per
            row taken from the OOD classifier's probability for label 1
            (the in-distribution label used in `fit`).
        """
        if verbose > 1:
            print("Predicting with the PairwiseDifferenceRegressionRandomForestSosaaRSM")
            
            print(" - Resampling the input dataset")
    
        if not direct_difference:
            # Only one anchor sample is produced, call predict several times for more
            Ia_test = rng.choice(len(self.X_train), size=len(X_test), replace=True)
            X_train = self.X_train
            Y_train = self.Y_train
        else:
            if (X_base is not None) and (Y_base is not None):
                X_train = X_base
                Y_train = Y_base
            else:
                X_train = self.X_train
                Y_train = self.Y_train
            
            assert len(X_test) == len(X_train)
            assert len(X_train) == len(Y_train)
            
            # Row-by-row pairing: anchor i belongs to test row i.
            Ia_test = np.arange(len(X_train))
        
        # N(0,1)-N(0,1) ~ N(0,2) -> divide by sqrt(2) s.t. all features are N(0,1)
        X_test = np.concatenate([
            X_train[Ia_test], (X_test-X_train[Ia_test]) / np.sqrt(2.0),
        ], axis=1)
        
        if verbose > 1:
            print(" - Generating confidence scores")
        
        confidence = self.ood_detector.predict_proba(self.scaler.transform(
            self.cov.mahalanobis(self._predict_truncated_pca(X_test) - X_test).reshape(-1, 1)
        ))[:,1]
            
        if verbose > 1:
            print(" - Generating ensemble predictions")
        
        def tree_predict(i: int) -> np.ndarray:
            if verbose > 1:
                print(f"   - Predicting tree {i}/{len(self.predictor.estimators_)}")
            
            return self.predictor.estimators_[i].predict(X_test)
        
        predictions = joblib.Parallel(n_jobs=-1, prefer="threads")(
            joblib.delayed(tree_predict)(i) for i in range(len(self.predictor.estimators_))
        )
        
        # Mean and spread of the predicted differences across the trees.
        prediction = np.mean(np.stack(predictions, axis=0), axis=0).reshape((len(X_test), 1))
        uncertainty = np.std(np.stack(predictions, axis=0), axis=0).reshape((len(X_test), 1))
        
        if verbose > 1:
            print(" - Recalibrating predictions and uncertainties")
        
        # Anchor target + predicted difference, shifted by the calibrated mean
        # residual; the spread is rescaled by the calibrated residual stdv.
        prediction = Y_train[Ia_test] + (
            prediction.flatten() + self.Zc_mean * uncertainty.flatten()
        ).reshape((len(X_test), 1))
        uncertainty = (uncertainty.flatten() * self.Zc_stdv).reshape((len(X_test), 1))
        
        if verbose > 1:
            print(" - Finished predicting with the PairwiseDifferenceRegressionRandomForestSosaaRSM")
            
        return IcarusPrediction(
            prediction=prediction, uncertainty=uncertainty, confidence=confidence,
        )

    def _predict_truncated_pca(self, X: np.ndarray) -> np.ndarray:
        """Reconstruct `X` from its projection onto the first `self.bn` PCA axes."""
        if self.pca.mean_ is not None:
            X = X - self.pca.mean_
        
        # Project onto the kept components, then map back to feature space.
        X_trans = np.dot(X, self.pca.components_[:self.bn].T)
        X = np.dot(X_trans, self.pca.components_[:self.bn])
        
        if self.pca.mean_ is not None:
            X = X + self.pca.mean_
        
        return X

In [None]:
class PercentilePairwiseDifferenceRegressionRandomForestSosaaRSM(IcarusRSM):
    """Pairwise-difference random forest with a percentile-based confidence.

    The forest is trained on pairs of samples: its input is an "anchor" row
    concatenated with the scaled feature difference between a second row and
    the anchor, and its target is the difference of the two rows' targets.
    Confidence is not produced by a classifier; it is derived from where a
    row's truncated-PCA reconstruction error falls among the sorted
    reconstruction errors of the validation pairs.
    """

    def fit(
        self,
        X_train: np.ndarray, Y_train: np.ndarray,
        X_valid: np.ndarray, Y_valid: np.ndarray,
        rng: np.random.Generator,
        n_trees: int = 16, n_samples: int = 16, verbose: int = 1,
    ) -> PercentilePairwiseDifferenceRegressionRandomForestSosaaRSM:
        """Train the forest, the error reference and the uncertainty calibration.

        Args:
            X_train, Y_train: Training features and targets; `Y_train` must
                have exactly one target column. Both are stored on the
                instance and reused as anchors.
            X_valid, Y_valid: Validation features and targets, used for the
                reference reconstruction errors and the calibration.
            rng: Random source. It must provide `choice`, and it is also
                handed to scikit-learn as `random_state`.
                NOTE(review): `train_and_cache_model` in this notebook passes
                a `np.random.RandomState`, not a `Generator`.
            n_trees: Number of trees in the forest.
            n_samples: Number of random pairs generated per original sample.
            verbose: Progress messages are printed when greater than 0; the
                value is also forwarded to the forest.

        Returns:
            `self`.
        """
        assert Y_train.shape[1:] == (1,)
        
        # The unpaired training data is kept as the pool of anchor samples.
        self.X_train = X_train
        self.Y_train = Y_train
        
        if verbose > 0:
            print("Training the PercentilePairwiseDifferenceRegressionRandomForestSosaaRSM")
            
            print(" - Resampling the training and validation datasets")
    
        # Ia_* index the anchors (always drawn from the training set), Ib_*
        # index the paired rows (training or validation set). All draws are
        # with replacement, n_samples pairs per original row.
        Ia_train = rng.choice(len(self.X_train), size=len(X_train)*n_samples, replace=True)
        Ib_train = rng.choice(len(X_train), size=len(X_train)*n_samples, replace=True)
        Ia_valid = rng.choice(len(self.X_train), size=len(X_valid)*n_samples, replace=True)
        Ib_valid = rng.choice(len(X_valid), size=len(X_valid)*n_samples, replace=True)
        
        # N(0,1)-N(0,1) ~ N(0,2) -> divide by sqrt(2) s.t. all features are N(0,1)
        X_train = np.concatenate([
            self.X_train[Ia_train], (X_train[Ib_train]-self.X_train[Ia_train]) / np.sqrt(2.0),
        ], axis=1)
        Y_train = Y_train[Ib_train] - self.Y_train[Ia_train]
        
        X_valid = np.concatenate([
            self.X_train[Ia_valid], (X_valid[Ib_valid]-self.X_train[Ia_valid]) / np.sqrt(2.0),
        ], axis=1)
        Y_valid = Y_valid[Ib_valid] - self.Y_train[Ia_valid]
            
        if verbose > 0:
            print(" - Training the OOD detector")
            print("   - Fitting truncated PCA")
        
        # bn is the index of the first component at which the cumulative
        # explained-variance ratio reaches 0.95; the truncated reconstruction
        # keeps components [0, bn).
        self.pca = PCA(random_state=rng).fit(X_train)
        self.bn = np.searchsorted(np.cumsum(self.pca.explained_variance_ratio_), 0.95)
        
        if verbose > 0:
            print("   - Fitting truncated PCA reconstruction error covariance")
        
        self.cov = EmpiricalCovariance().fit((self._predict_truncated_pca(X_train) - X_train))
        
        # Sorted reconstruction-error distances of the validation pairs; this
        # is the reference distribution that predict() ranks new rows against.
        self.err_valid = np.sort(
            self.cov.mahalanobis(self._predict_truncated_pca(X_valid) - X_valid)
        )
        
        if verbose > 0:
            print(" - Training the Prediction Model and Uncertainty Quantifier")
        
        self.predictor = RandomForestRegressor(
            n_estimators=n_trees, random_state=rng, n_jobs=-1, min_samples_leaf=5,
            max_features=1.0/3.0, verbose=verbose,
        ).fit(X_train, Y_train)
        
        if verbose > 0:
            print(" - Calibrating the Uncertainty Quantifier")
        
        def tree_predict(i: int) -> np.ndarray:
            if verbose > 1:
                print(f"   - Predicting tree {i}/{len(self.predictor.estimators_)}")
            
            return self.predictor.estimators_[i].predict(X_valid)
            
        valid_predictions = joblib.Parallel(n_jobs=-1, prefer="threads")(
            joblib.delayed(tree_predict)(i) for i in range(len(self.predictor.estimators_))
        )
        
        # Ensemble mean and spread across the trees for every validation pair.
        Y_valid_pred = np.mean(np.stack(valid_predictions, axis=0), axis=0)
        Y_valid_stdv = np.std(np.stack(valid_predictions, axis=0), axis=0)
            
        # Validation residuals in units of the tree spread.
        # NOTE(review): a zero spread (all trees agree) divides by zero here
        # and would propagate inf/nan into Zc_mean and Zc_stdv.
        Zc = (Y_valid.flatten() - Y_valid_pred.flatten()) / Y_valid_stdv.flatten()
        
        # Used in predict() to shift the prediction and rescale the spread.
        self.Zc_mean = np.mean(Zc)
        self.Zc_stdv = np.std(Zc)
        
        if verbose > 0:
            print("- Finished training the PercentilePairwiseDifferenceRegressionRandomForestSosaaRSM")
        
        return self
    
    def predict(
        self, X_test: np.ndarray, rng: np.random.Generator, verbose: int = 1,
        direct_difference: bool = False, X_base: np.ndarray = None, Y_base: np.ndarray = None,
    ) -> IcarusPrediction:
        """Predict targets for `X_test` relative to anchor samples.

        Args:
            X_test: Feature rows to predict for.
            rng: Random source; only used to pick anchors when
                `direct_difference` is False.
            verbose: Progress messages are printed when greater than 1.
            direct_difference: When False, every test row is paired with one
                randomly chosen stored training row. When True, test row `i`
                is paired with anchor row `i`, so the anchor set must have the
                same length as `X_test`.
            X_base, Y_base: Optional anchor features and targets for the
                `direct_difference` mode. They are used only when both are
                given; otherwise the stored training data is the anchor set.

        Returns:
            An `IcarusPrediction` with the recalibrated prediction and
            uncertainty, each shaped `(len(X_test), 1)`, and a confidence per
            row equal to one minus the fraction of validation reconstruction
            errors that are strictly smaller than the row's own error.
        """
        if verbose > 1:
            print("Predicting with the PercentilePairwiseDifferenceRegressionRandomForestSosaaRSM")
            
            print(" - Resampling the input dataset")
    
        if not direct_difference:
            # Only one anchor sample is produced, call predict several times for more
            Ia_test = rng.choice(len(self.X_train), size=len(X_test), replace=True)
            X_train = self.X_train
            Y_train = self.Y_train
        else:
            if (X_base is not None) and (Y_base is not None):
                X_train = X_base
                Y_train = Y_base
            else:
                X_train = self.X_train
                Y_train = self.Y_train
            
            assert len(X_test) == len(X_train)
            assert len(X_train) == len(Y_train)
            
            # Row-by-row pairing: anchor i belongs to test row i.
            Ia_test = np.arange(len(X_train))
        
        # N(0,1)-N(0,1) ~ N(0,2) -> divide by sqrt(2) s.t. all features are N(0,1)
        X_test = np.concatenate([
            X_train[Ia_test], (X_test-X_train[Ia_test]) / np.sqrt(2.0),
        ], axis=1)
        
        if verbose > 1:
            print(" - Generating confidence scores")
            
        # Rank each row's reconstruction error within the sorted validation
        # errors; a larger error means a higher rank and a lower confidence.
        confidence = 1.0 - np.searchsorted(
            self.err_valid,
            self.cov.mahalanobis((self._predict_truncated_pca(X_test) - X_test)),
        ) / len(self.err_valid)
            
        if verbose > 1:
            print(" - Generating ensemble predictions")
        
        def tree_predict(i: int) -> np.ndarray:
            if verbose > 1:
                print(f"   - Predicting tree {i}/{len(self.predictor.estimators_)}")
            
            return self.predictor.estimators_[i].predict(X_test)
        
        predictions = joblib.Parallel(n_jobs=-1, prefer="threads")(
            joblib.delayed(tree_predict)(i) for i in range(len(self.predictor.estimators_))
        )
        
        # Mean and spread of the predicted differences across the trees.
        prediction = np.mean(np.stack(predictions, axis=0), axis=0).reshape((len(X_test), 1))
        uncertainty = np.std(np.stack(predictions, axis=0), axis=0).reshape((len(X_test), 1))
        
        if verbose > 1:
            print(" - Recalibrating predictions and uncertainties")
        
        # Anchor target + predicted difference, shifted by the calibrated mean
        # residual; the spread is rescaled by the calibrated residual stdv.
        prediction = Y_train[Ia_test] + (
            prediction.flatten() + self.Zc_mean * uncertainty.flatten()
        ).reshape((len(X_test), 1))
        uncertainty = (uncertainty.flatten() * self.Zc_stdv).reshape((len(X_test), 1))
        
        if verbose > 1:
            print(" - Finished predicting with the PercentilePairwiseDifferenceRegressionRandomForestSosaaRSM")
            
        return IcarusPrediction(
            prediction=prediction, uncertainty=uncertainty, confidence=confidence,
        )

    def _predict_truncated_pca(self, X: np.ndarray) -> np.ndarray:
        """Reconstruct `X` from its projection onto the first `self.bn` PCA axes."""
        if self.pca.mean_ is not None:
            X = X - self.pca.mean_
        
        # Project onto the kept components, then map back to feature space.
        X_trans = np.dot(X, self.pca.components_[:self.bn].T)
        X = np.dot(X_trans, self.pca.components_[:self.bn])
        
        if self.pca.mean_ is not None:
            X = X + self.pca.mean_
        
        return X

In [None]:
def train_and_cache_model(dt: datetime.datetime, clump: float, datasets: dict, models: dict, cls, **kwargs):
    """Return a fitted `cls` model for `(dt, clump)`, training it only if needed.

    Lookup order: the in-memory `models` dict, then a joblib file in the
    current working directory, then a fresh training run whose result is
    written to that file and stored in `models`.

    Args:
        dt: A single datetime, or a tuple/list of datetimes (normalised to a
            sorted tuple before it is used as part of the cache key).
        clump: Clump value forwarded to `load_and_cache_dataset`; also part of
            the cache key and of the file name.
        datasets: Dataset cache handed to `load_and_cache_dataset`.
        models: In-memory model cache keyed by `(cls.__name__, dt, clump)`;
            updated in place.
        cls: Model class; instantiated without arguments and trained via `fit`.
        **kwargs: Extra keyword arguments for `cls().fit`. They are not part of
            either cache key, so a cached model is returned regardless of them.

    Returns:
        The cached, loaded or newly trained model.
    """
    if isinstance(dt, (tuple, list)):
        dt = tuple(sorted(dt))

    model_key = (cls.__name__, dt, clump)

    cached = models.get(model_key)

    if cached is not None:
        return cached

    model_path = f"{cls.__name__.lower()}.icarus.{hash_for_dt(dt).hexdigest(8)}.{clump}.jl"

    if Path(model_path).exists():
        # joblib files are pickle-based: only load caches this notebook wrote.
        try:
            model = joblib.load(model_path)
        except Exception as err:
            # A broken cache file is not fatal: report it and retrain below.
            # Catching Exception (not everything) lets Ctrl-C still interrupt.
            warnings.warn(
                f"Could not load cached model from {model_path!r} ({err!r}); retraining",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            models[model_key] = model

            return model

    dataset = load_and_cache_dataset(dt, clump, datasets)

    # Seed derived from the date hash so that retraining is reproducible.
    rng = np.random.RandomState(seed=int.from_bytes(hash_for_dt(dt).digest(4), 'little'))

    model = cls().fit(
        X_train=dataset.X_train, Y_train=dataset.Y_train,
        X_valid=dataset.X_valid, Y_valid=dataset.Y_valid,
        rng=rng, **kwargs,
    )

    joblib.dump(model, model_path)

    models[model_key] = model

    return model

In [None]:
# Train (or load from cache) the plain random-forest models.
# Every entry is an hour offset: day 9 + offset // 24 of May 2018, hour offset % 24.
# A tuple/list entry trains one model on several datetimes with clump 0.0.
RANDOM_FOREST_MODELS = {}

for h in tqdm.tqdm([130, 163, 192, 244, 303, 349]):
    if not isinstance(h, (tuple, list)):
        dt = datetime.datetime(2018, 5, 9 + h // 24, h % 24)

        for clump in tqdm.tqdm([0.0, 0.5, 0.75, 0.85, 0.9]):
            train_and_cache_model(dt, clump, DATASETS, RANDOM_FOREST_MODELS, RandomForestSosaaRSM)
    else:
        dt = tuple(
            datetime.datetime(2018, 5, 9 + hs // 24, hs % 24) for hs in sorted(h)
        )

        train_and_cache_model(dt, 0.0, DATASETS, RANDOM_FOREST_MODELS, RandomForestSosaaRSM)

In [None]:
# Train (or load from cache) the pairwise-difference random-forest models.
# Every entry is an hour offset: day 9 + offset // 24 of May 2018, hour offset % 24.
# A tuple/list entry trains one model on several datetimes with clump 0.0.
PAIRWISE_RANDOM_FOREST_MODELS = {}

for h in tqdm.tqdm([130, 163, 192, 244, 303, 349]):
    if not isinstance(h, (tuple, list)):
        dt = datetime.datetime(2018, 5, 9 + h // 24, h % 24)

        for clump in tqdm.tqdm([0.0, 0.5, 0.75, 0.85, 0.9]):
            train_and_cache_model(
                dt, clump, DATASETS, PAIRWISE_RANDOM_FOREST_MODELS,
                PairwiseDifferenceRegressionRandomForestSosaaRSM,
            )
    else:
        dt = tuple(
            datetime.datetime(2018, 5, 9 + hs // 24, hs % 24) for hs in sorted(h)
        )

        train_and_cache_model(
            dt, 0.0, DATASETS, PAIRWISE_RANDOM_FOREST_MODELS,
            PairwiseDifferenceRegressionRandomForestSosaaRSM,
        )

In [None]:
# Train (or load from cache) the percentile pairwise-difference models.
# Every entry is an hour offset: day 9 + offset // 24 of May 2018, hour offset % 24.
# Only the 0.75 clump is trained for single datetimes; a tuple/list entry
# trains one model on several datetimes with clump 0.0.
PERCENTILE_PAIRWISE_RANDOM_FOREST_MODELS = {}

for h in tqdm.tqdm([130, 163, 192, 244, 303, 349]):
    if not isinstance(h, (tuple, list)):
        dt = datetime.datetime(2018, 5, 9 + h // 24, h % 24)

        for clump in tqdm.tqdm([0.75]):
            train_and_cache_model(
                dt, clump, DATASETS, PERCENTILE_PAIRWISE_RANDOM_FOREST_MODELS,
                PercentilePairwiseDifferenceRegressionRandomForestSosaaRSM,
            )
    else:
        dt = tuple(
            datetime.datetime(2018, 5, 9 + hs // 24, hs % 24) for hs in sorted(h)
        )

        train_and_cache_model(
            dt, 0.0, DATASETS, PERCENTILE_PAIRWISE_RANDOM_FOREST_MODELS,
            PercentilePairwiseDifferenceRegressionRandomForestSosaaRSM,
        )

In [None]:
class Table:
    """Column-wise accumulator of result rows that is written out as a CSV file.

    Used as a context manager, the table is written to `filepath` when the
    `with` block finishes without an exception. `backup()` writes the current
    contents on demand.
    """

    def __init__(self, filepath, keys=()):
        """Create an empty table.

        Args:
            filepath: Destination of the CSV file.
            keys: Column names; every column starts out empty.
        """
        self.filepath = filepath

        self.cols = {k: [] for k in keys}

    def __enter__(self):
        return self

    def insert(self, **kwargs):
        """Append one value to each named column.

        Raises:
            KeyError: If a keyword does not name an existing column. The check
                runs before anything is appended, so a rejected insert leaves
                all columns unchanged (and therefore equally long).
        """
        for k in kwargs:
            if k not in self.cols:
                raise KeyError(k)

        for k, v in kwargs.items():
            self.cols[k].append(v)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Only persist complete results; exceptions are never suppressed.
        if exc_val is None:
            self.backup()

        return False

    def backup(self):
        """Write the current contents to `filepath` (header row, no index)."""
        df = pd.DataFrame(self.cols)
        df.to_csv(self.filepath, header=True, index=False)

In [None]:
def analyse_icarus_predictions(
    predictions: List[IcarusPrediction],
    analysis: Callable[[List[np.ndarray], np.ndarray, np.random.Generator, dict], np.ndarray],
    rng: np.random.Generator,
    n_uncertain_samples: int = 1,  # number of draws taken from each prediction per run
    n_analysis_runs: int = 100,  # number of repeats of the analysis to gather uncertainty
    **kwargs,
):
    """Run `analysis` repeatedly on random draws from the given predictions.

    In every run, each entry is kept with probability equal to its joint
    confidence, and for the kept entries a value is drawn from a normal
    distribution centred on the prediction with the uncertainty as scale.

    Args:
        predictions: Non-empty list of predictions that are analysed together.
        analysis: Called as `analysis(draws, kept_indices, rng, **kwargs)`,
            where `draws` holds one array per prediction.
        rng: Random source providing `random` and `normal`.
        n_uncertain_samples: Number of keep-and-draw rounds concatenated into
            the input of one analysis run.
        n_analysis_runs: Number of analysis runs.
        **kwargs: Forwarded to `analysis`.

    Returns:
        An `IcarusPrediction` holding the mean and standard deviation of the
        analysis results over all runs, and the mean joint confidence.
    """
    assert len(predictions) > 0

    # predictions that need to coexist multiply their confidence
    joint_confidence = np.prod([p.confidence for p in predictions], axis=0)
    # independent predictions average their confidence
    overall_confidence = np.mean(joint_confidence)

    def run_once():
        kept_per_round = []
        draws_per_prediction = [[] for _ in predictions]

        for _ in range(n_uncertain_samples):
            # Keep an entry when a uniform draw does not exceed its confidence.
            (kept,) = np.nonzero(
                rng.random(size=joint_confidence.shape) <= joint_confidence
            )
            kept_per_round.append(kept)

            for bucket, p in zip(draws_per_prediction, predictions):
                bucket.append(
                    rng.normal(loc=p.prediction[kept], scale=p.uncertainty[kept])
                )

        kept_all = np.concatenate(kept_per_round, axis=0)
        draws_all = [np.concatenate(bucket, axis=0) for bucket in draws_per_prediction]

        return analysis(draws_all, kept_all, rng, **kwargs)

    outcomes = np.stack([run_once() for _ in range(n_analysis_runs)], axis=0)

    return IcarusPrediction(
        prediction=np.mean(outcomes, axis=0),
        uncertainty=np.std(outcomes, axis=0),
        confidence=overall_confidence,
    )

In [None]:
def combine_many_icarus_predictions(
    model: IcarusRSM,
    X_test: np.ndarray,
    rng: np.random.Generator,
    n_samples: int,
    n_uncertain_samples: int, # number of draws taken from each prediction per run
    n_analysis_runs: int, # number of repeats of the analysis to gather uncertainty
    **kwargs,
) -> IcarusPrediction:
    """Merge `n_samples` repeated predictions of `model` into one prediction.

    `model.predict(X_test, rng, **kwargs)` is called `n_samples` times. For
    every test row, the repeated results are then combined with
    `analyse_icarus_predictions`, using the mean of the drawn values (or 0.0
    when no value was kept) as the analysis.

    Returns:
        An `IcarusPrediction` with `prediction` and `uncertainty` reshaped to
        a single column and one `confidence` per test row. All three are
        empty when `n_samples` is 0.
    """
    def mean_of_draws(draws_per_prediction, _kept, _rng, **_unused):
        (draws,) = draws_per_prediction

        return 0.0 if len(draws) == 0 else np.mean(draws)

    runs = [model.predict(X_test, rng, **kwargs) for _ in range(n_samples)]

    means, spreads, confidences = [], [], []

    n_rows = len(runs[0].prediction) if len(runs) > 0 else 0

    for row in range(n_rows):
        # Gather this row's value from every repeated prediction.
        per_row = IcarusPrediction(
            prediction=np.array([r.prediction[row] for r in runs]),
            uncertainty=np.array([r.uncertainty[row] for r in runs]),
            confidence=np.array([r.confidence[row] for r in runs]),
        )

        combined = analyse_icarus_predictions(
            [per_row],
            mean_of_draws,
            rng,
            n_uncertain_samples=n_uncertain_samples,
            n_analysis_runs=n_analysis_runs,
        )

        means.append(combined.prediction)
        spreads.append(combined.uncertainty)
        confidences.append(combined.confidence)

    return IcarusPrediction(
        prediction=np.array(means).reshape(-1, 1),
        uncertainty=np.array(spreads).reshape(-1, 1),
        confidence=np.array(confidences),
    )

In [None]:
def evaluate_model_matrix(datasets, models, n_samples, title=None, **kwargs):
    """Evaluate every model on the test split of every model's dataset.

    For each ordered pair of keys in `models`, the first key's model predicts
    the test split of the second key's dataset. MSE, MAE and R2 (with their
    spread and confidence) are collected into a CSV file named after `title`.

    Args:
        datasets: Dataset cache handed to `load_and_cache_dataset`.
        models: Mapping from `(class name, datetime, clump)` to a fitted model.
        n_samples: Number of repeated predictions combined per model.
        title: Required despite the `None` default; lower-cased and with
            spaces replaced by dashes it becomes the CSV file name.
        **kwargs: Forwarded to the model's `predict`.

    Raises:
        ValueError: If `title` is not given.
    """
    if title is None:
        # Fail with a clear message instead of an AttributeError on None.lower().
        raise ValueError("evaluate_model_matrix needs a title to derive its CSV file name")

    rng = np.random.RandomState(seed=42)

    with Table(
        filepath=f"{title.lower().replace(' ', '-')}.csv",
        keys=[
            "model_date", "model_clump", "data_date", "data_clump", "mse", "mse_stdv",
            "mse_conf", "mae", "mae_stdv", "mae_conf", "r2", "r2_stdv", "r2_conf",
        ],
    ) as table:
        for key in models.keys():
            _cls, dt, clump = key

            model = models[key]
            mlm = load_and_cache_dataset(dt, clump, datasets)

            for key2 in tqdm.tqdm(models.keys()):
                _cls2, dt2, clump2 = key2

                mlt = load_and_cache_dataset(dt2, clump2, datasets)
                # Undo the data set's own feature scaling, then apply the
                # scaling the model was trained with.
                X_test = mlm.X_scaler.transform(
                    mlt.X_scaler.inverse_transform(mlt.X_test, copy=True)
                )

                Y_test = mlt.Y_scaler.inverse_transform(mlt.Y_test)
                Y_pred = combine_many_icarus_predictions(
                    model,
                    X_test,
                    rng,
                    n_samples,
                    10, # n_uncertain_samples
                    100, # n_analysis_runs
                    **kwargs,
                )

                # Bring the prediction back to unscaled target units.
                Y_pred = IcarusPrediction(
                    prediction=mlm.Y_scaler.inverse_transform(Y_pred.prediction),
                    uncertainty=Y_pred.uncertainty * mlm.Y_scaler.scale_,
                    confidence=Y_pred.confidence,
                )

                def mse_mae_r2_analysis(Y_sampled, I_kept, _rng, **_kwargs):
                    Y_sampled, = Y_sampled
                    Y_true = Y_test[I_kept]

                    # NOTE(review): when no sample is kept the metrics fall
                    # back to 0.0 / 0.0 / 1.0, i.e. to "perfect" scores.
                    if len(Y_sampled) >= 1:
                        mse = mean_squared_error(Y_true, Y_sampled)
                        mae = mean_absolute_error(Y_true, Y_sampled)
                    else:
                        mse = 0.0
                        mae = 0.0

                    if len(Y_sampled) >= 2:
                        r2 = r2_score(Y_true, Y_sampled)
                    else:
                        r2 = 1.0

                    return np.array([mse, mae, r2])

                mse_mae_r2 = analyse_icarus_predictions(
                    [Y_pred],
                    mse_mae_r2_analysis,
                    rng,
                    n_uncertain_samples=1,
                    n_analysis_runs=10,
                )

                print(mse_mae_r2)

                table.insert(
                    model_date=dt, model_clump=clump, data_date=dt2, data_clump=clump2,
                    mse=mse_mae_r2.prediction[0],
                    mse_stdv=mse_mae_r2.uncertainty[0],
                    mse_conf=mse_mae_r2.confidence,
                    mae=mse_mae_r2.prediction[1],
                    mae_stdv=mse_mae_r2.uncertainty[1],
                    mae_conf=mse_mae_r2.confidence,
                    r2=mse_mae_r2.prediction[2],
                    r2_stdv=mse_mae_r2.uncertainty[2],
                    r2_conf=mse_mae_r2.confidence,
                )

In [None]:
# Cross-trajectory evaluation of the single-datetime RF models trained with clump 0.75.
evaluate_model_matrix(
    DATASETS,
    {
        key: model
        for key, model in RANDOM_FOREST_MODELS.items()
        if not isinstance(key[1], tuple) and key[2] == 0.75
    },
    1,
    title="Trajectory Generalisation RF",
)

In [None]:
# Cross-trajectory evaluation of the single-datetime pairwise models trained with clump 0.75.
evaluate_model_matrix(
    DATASETS,
    {
        key: model
        for key, model in PAIRWISE_RANDOM_FOREST_MODELS.items()
        if not isinstance(key[1], tuple) and key[2] == 0.75
    },
    4,
    title="Trajectory Generalisation PADRE-RF",
)

In [None]:
def evaluate_model_neighbourhood_matrix(datasets, models, all_models, n_samples, title=None, **kwargs):
    """Evaluate each model on datasets at neighbouring hours.

    For every key in `models`, the model looked up in `all_models` predicts
    the test split of the dataset at the model's datetime shifted by each of
    -4, -2, -1, 0, 1, 2 and 4 hours (same clump). MSE, MAE and R2 (with their
    spread and confidence) are collected into a CSV file named after `title`.

    Args:
        datasets: Dataset cache handed to `load_and_cache_dataset`.
        models: Mapping whose keys `(class name, datetime, clump)` select the
            models to evaluate.
        all_models: Mapping from the same kind of key to the fitted models.
        n_samples: Number of repeated predictions combined per model.
        title: Required despite the `None` default; lower-cased and with
            spaces replaced by dashes it becomes the CSV file name.
        **kwargs: Forwarded to the model's `predict`.

    Raises:
        ValueError: If `title` is not given.
    """
    if title is None:
        # Fail with a clear message instead of an AttributeError on None.lower().
        raise ValueError(
            "evaluate_model_neighbourhood_matrix needs a title to derive its CSV file name"
        )

    offsets = [-4, -2, -1, 0, 1, 2, 4]

    rng = np.random.RandomState(seed=42)

    with Table(
        filepath=f"{title.lower().replace(' ', '-')}.csv",
        keys=[
            "model_date", "model_clump", "data_hour_offset", "data_clump",
            "mse", "mse_stdv", "mse_conf", "mae", "mae_stdv", "mae_conf",
            "r2", "r2_stdv", "r2_conf",
        ],
    ) as table:
        for key in models.keys():
            _cls, dt, clump = key

            model = all_models[key]
            mlm = load_and_cache_dataset(dt, clump, datasets)

            for i in offsets:
                mlt = load_and_cache_dataset(dt + datetime.timedelta(hours=i), clump, datasets)

                # Undo the data set's own feature scaling, then apply the
                # scaling the model was trained with.
                X_test = mlm.X_scaler.transform(
                    mlt.X_scaler.inverse_transform(mlt.X_test, copy=True)
                )

                Y_test = mlt.Y_scaler.inverse_transform(mlt.Y_test)
                Y_pred = combine_many_icarus_predictions(
                    model,
                    X_test,
                    rng,
                    n_samples,
                    10, # n_uncertain_samples
                    100, # n_analysis_runs
                    **kwargs,
                )

                # Bring the prediction back to unscaled target units.
                Y_pred = IcarusPrediction(
                    prediction=mlm.Y_scaler.inverse_transform(Y_pred.prediction),
                    uncertainty=Y_pred.uncertainty * mlm.Y_scaler.scale_,
                    confidence=Y_pred.confidence,
                )

                def mse_mae_r2_analysis(Y_sampled, I_kept, _rng, **_kwargs):
                    Y_sampled, = Y_sampled
                    Y_true = Y_test[I_kept]

                    # NOTE(review): when no sample is kept the metrics fall
                    # back to 0.0 / 0.0 / 1.0, i.e. to "perfect" scores.
                    if len(Y_sampled) >= 1:
                        mse = mean_squared_error(Y_true, Y_sampled)
                        mae = mean_absolute_error(Y_true, Y_sampled)
                    else:
                        mse = 0.0
                        mae = 0.0

                    if len(Y_sampled) >= 2:
                        r2 = r2_score(Y_true, Y_sampled)
                    else:
                        r2 = 1.0

                    return np.array([mse, mae, r2])

                mse_mae_r2 = analyse_icarus_predictions(
                    [Y_pred],
                    mse_mae_r2_analysis,
                    rng,
                    n_uncertain_samples=1,
                    n_analysis_runs=10,
                )

                print(mse_mae_r2)

                table.insert(
                    model_date=dt, model_clump=clump, data_hour_offset=i, data_clump=clump,
                    mse=mse_mae_r2.prediction[0],
                    mse_stdv=mse_mae_r2.uncertainty[0],
                    mse_conf=mse_mae_r2.confidence,
                    mae=mse_mae_r2.prediction[1],
                    mae_stdv=mse_mae_r2.uncertainty[1],
                    mae_conf=mse_mae_r2.confidence,
                    r2=mse_mae_r2.prediction[2],
                    r2_stdv=mse_mae_r2.uncertainty[2],
                    r2_conf=mse_mae_r2.confidence,
                )

In [None]:
# Neighbouring-hour evaluation of the single-datetime RF models trained with clump 0.75.
evaluate_model_neighbourhood_matrix(
    DATASETS,
    {
        key: model
        for key, model in RANDOM_FOREST_MODELS.items()
        if not isinstance(key[1], tuple) and key[2] == 0.75
    },
    RANDOM_FOREST_MODELS,
    1,
    title="Temporal Generalisation RF",
)

In [None]:
# Neighbouring-hour evaluation of the single-datetime pairwise models trained with clump 0.75.
evaluate_model_neighbourhood_matrix(
    DATASETS,
    {
        key: model
        for key, model in PAIRWISE_RANDOM_FOREST_MODELS.items()
        if not isinstance(key[1], tuple) and key[2] == 0.75
    },
    PAIRWISE_RANDOM_FOREST_MODELS,
    4,
    title="Temporal Generalisation PADRE-RF",
)

In [None]:
def evaluate_model_clump_matrix(datasets, models, all_models, n_samples, title=None, **kwargs):
    """Evaluate, per datetime, the model of every clump value on its own test split.

    For every key in `models`, the models for the clump values 0.0, 0.5, 0.75,
    0.85 and 0.9 are looked up in `all_models` under the same class name and
    datetime, and each predicts the test split of its own dataset. MSE, MAE
    and R2 (with their spread and confidence) are collected into a CSV file
    named after `title`.

    Args:
        datasets: Dataset cache handed to `load_and_cache_dataset`.
        models: Mapping whose keys `(class name, datetime, clump)` select the
            class names and datetimes to evaluate; the key's clump is ignored.
        all_models: Mapping from the same kind of key to the fitted models.
        n_samples: Number of repeated predictions combined per model.
        title: Required despite the `None` default; lower-cased and with
            spaces replaced by dashes it becomes the CSV file name.
        **kwargs: Forwarded to the model's `predict`.

    Raises:
        ValueError: If `title` is not given.
    """
    if title is None:
        # Fail with a clear message instead of an AttributeError on None.lower().
        raise ValueError("evaluate_model_clump_matrix needs a title to derive its CSV file name")

    clumps = [0.0, 0.5, 0.75, 0.85, 0.9]

    rng = np.random.RandomState(seed=42)

    with Table(
        filepath=f"{title.lower().replace(' ', '-')}.csv",
        keys=[
            "model_date", "model_clump", "mse", "mse_stdv", "mse_conf",
            "mae", "mae_stdv", "mae_conf", "r2", "r2_stdv", "r2_conf",
        ],
    ) as table:
        for key in models.keys():
            cls, dt, _ = key

            for c in clumps:
                model = all_models[(cls, dt, c)]
                mlm = load_and_cache_dataset(dt, c, datasets)

                # Model and data share a scaler here, so no re-scaling is needed.
                X_test = np.copy(mlm.X_test)

                Y_test = mlm.Y_scaler.inverse_transform(mlm.Y_test)
                Y_pred = combine_many_icarus_predictions(
                    model,
                    X_test,
                    rng,
                    n_samples,
                    10, # n_uncertain_samples
                    100, # n_analysis_runs
                    **kwargs,
                )

                # Bring the prediction back to unscaled target units.
                Y_pred = IcarusPrediction(
                    prediction=mlm.Y_scaler.inverse_transform(Y_pred.prediction),
                    uncertainty=Y_pred.uncertainty * mlm.Y_scaler.scale_,
                    confidence=Y_pred.confidence,
                )

                def mse_mae_r2_analysis(Y_sampled, I_kept, _rng, **_kwargs):
                    Y_sampled, = Y_sampled
                    Y_true = Y_test[I_kept]

                    # NOTE(review): when no sample is kept the metrics fall
                    # back to 0.0 / 0.0 / 1.0, i.e. to "perfect" scores.
                    if len(Y_sampled) >= 1:
                        mse = mean_squared_error(Y_true, Y_sampled)
                        mae = mean_absolute_error(Y_true, Y_sampled)
                    else:
                        mse = 0.0
                        mae = 0.0

                    if len(Y_sampled) >= 2:
                        r2 = r2_score(Y_true, Y_sampled)
                    else:
                        r2 = 1.0

                    return np.array([mse, mae, r2])

                mse_mae_r2 = analyse_icarus_predictions(
                    [Y_pred],
                    mse_mae_r2_analysis,
                    rng,
                    n_uncertain_samples=1,
                    n_analysis_runs=10,
                )

                print(mse_mae_r2)

                table.insert(
                    model_date=dt, model_clump=c,
                    mse=mse_mae_r2.prediction[0],
                    mse_stdv=mse_mae_r2.uncertainty[0],
                    mse_conf=mse_mae_r2.confidence,
                    mae=mse_mae_r2.prediction[1],
                    mae_stdv=mse_mae_r2.uncertainty[1],
                    mae_conf=mse_mae_r2.confidence,
                    r2=mse_mae_r2.prediction[2],
                    r2_stdv=mse_mae_r2.uncertainty[2],
                    r2_conf=mse_mae_r2.confidence,
                )

In [None]:
# Per-clump evaluation of the RF models; the 0.75 filter only selects the datetimes.
evaluate_model_clump_matrix(
    DATASETS,
    {
        key: model
        for key, model in RANDOM_FOREST_MODELS.items()
        if not isinstance(key[1], tuple) and key[2] == 0.75
    },
    RANDOM_FOREST_MODELS,
    1,
    title="Clumped Generalisation RF",
)

In [None]:
# Per-clump evaluation of the pairwise models; the 0.75 filter only selects the datetimes.
evaluate_model_clump_matrix(
    DATASETS,
    {
        key: model
        for key, model in PAIRWISE_RANDOM_FOREST_MODELS.items()
        if not isinstance(key[1], tuple) and key[2] == 0.75
    },
    PAIRWISE_RANDOM_FOREST_MODELS,
    4,
    title="Clumped Generalisation PADRE-RF",
)

In [None]:
def _perturbation_difference_metrics(Y_base, Y_test, Y_prtb, Y_ref):
    """Score a predicted perturbation response against the true response.

    The true response is ``Y_test - Y_base`` and the predicted response is
    ``Y_prtb - Y_ref``.

    Returns:
        ``np.array([mse, mae, r2, r2b])`` where ``r2b`` is the R2 score of the
        per-bin mean responses, with samples binned by their ``Y_base`` value.
        With no predictions, ``mse`` and ``mae`` fall back to 0.0; with fewer
        than two values, ``r2`` and ``r2b`` fall back to 1.0.
    """
    delta_true = Y_test - Y_base
    delta_pred = Y_prtb - Y_ref

    if len(Y_prtb) >= 1:
        mse = mean_squared_error(delta_true, delta_pred)
        mae = mean_absolute_error(delta_true, delta_pred)
    else:
        mse = 0.0
        mae = 0.0

    # R2 is only computed when there are at least two samples
    if len(Y_prtb) >= 2:
        r2 = r2_score(delta_true, delta_pred)
    else:
        r2 = 1.0

    # TODO: remove
    # Binned R2: average both responses within 150 bins of the baseline value
    #  and compare the bin means of the non-empty bins.
    bins = np.linspace(0.0, 15.0, 150)
    binids = np.searchsorted(bins, Y_base.flatten())
    bincnt = np.bincount(binids, minlength=150)
    binsum_true = np.bincount(binids, weights=delta_true.flatten(), minlength=150)
    binsum_pred = np.bincount(binids, weights=delta_pred.flatten(), minlength=150)
    nonzero = bincnt != 0
    binmean_true = binsum_true[nonzero] / bincnt[nonzero]
    binmean_pred = binsum_pred[nonzero] / bincnt[nonzero]

    if len(binmean_true) >= 2:
        r2b = r2_score(binmean_true, binmean_pred)
    else:
        r2b = 1.0

    return np.array([mse, mae, r2, r2b])


def _diff_mse_mae_r2_analysis(Y_true, Y_preds, I_pred, rng, **kwargs):
    """Analysis callback: the predicted response is ``Y_prtb - Y_unpt``.

    ``Y_true`` is ``[Y_base, Y_test]``; both are indexed with ``I_pred``
    before being compared against the predictions in ``Y_preds``.
    """
    Y_unpt, Y_prtb = Y_preds
    Y_base, Y_test = Y_true

    return _perturbation_difference_metrics(
        Y_base[I_pred], Y_test[I_pred], Y_prtb, Y_unpt,
    )


def _direct_mse_mae_r2_analysis(Y_true, Y_preds, I_pred, rng, **kwargs):
    """Analysis callback: the predicted response is ``Y_prtb - Y_base``.

    ``Y_true`` is ``[Y_base, Y_test]``; both are indexed with ``I_pred``
    before being compared against the single prediction in ``Y_preds``.
    """
    Y_prtb, = Y_preds
    Y_base, Y_test = Y_true
    Y_base = Y_base[I_pred]

    return _perturbation_difference_metrics(
        Y_base, Y_test[I_pred], Y_prtb, Y_base,
    )


def evaluate_model_perturbation_matrix(
    datasets, perturbed_datasets, models, n_samples, title=None, direct_difference=False, **kwargs,
):
    """Score every model on every input perturbation and record the results.

    For each ``(cls, dt, clump)`` key in ``models`` and each perturbation, the
    model predicts on the perturbed features and the predicted change is
    scored (MSE, MAE, R2) against the true change ``Y_test - Y_base``. One row
    per (model, perturbation) is inserted into a ``Table`` created with the
    file path ``"<title lower-cased, spaces replaced by hyphens>.csv"``.

    Args:
        datasets: Cache handed to ``load_and_cache_dataset``.
        perturbed_datasets: Dict keyed by ``(dt, perturbation)``; entries that
            are missing (or ``None``) are loaded and stored in place.
        models: Mapping from ``(cls, dt, clump)`` to a fitted model.
        n_samples: Forwarded to ``combine_many_icarus_predictions``.
        title: Required. Used to derive the results file name.
        direct_difference: If true, ``X_base`` and ``Y_base`` are forwarded to
            the model through ``kwargs`` and the predicted change is
            ``Y_prtb - Y_base``. Otherwise the model's prediction on the
            unperturbed features is subtracted instead.
        **kwargs: Forwarded to ``combine_many_icarus_predictions``.

    Raises:
        ValueError: If ``title`` is ``None``.
    """
    # Fail fast: the title names the results file, and without this check the
    #  error would only surface after all perturbed datasets were loaded.
    if title is None:
        raise ValueError(
            "evaluate_model_perturbation_matrix requires a title "
            "to derive the results file name"
        )

    results_path = f"{title.lower().replace(' ', '-')}.csv"

    # Inputs that are scaled up and down by a factor; temperature is shifted
    #  by an offset instead. The order below fixes the row order of the table.
    scaled_inputs = [
        "anthropogenic", "biogenic", "aerosols", "monoterpenes",
        "sesquiterpenes", "so2", "nox",
    ]

    perturbations = []
    for factor, offset in [("1.5", "2K"), ("1.01", "0.04K")]:
        for name in scaled_inputs:
            perturbations.append(Path(name) / f"mul_{factor}")
            perturbations.append(Path(name) / f"div_{factor}")

        perturbations.append(Path("temperature") / f"add_{offset}")
        perturbations.append(Path("temperature") / f"sub_{offset}")

    rng = np.random.RandomState(seed=42)

    # Load (and cache) the perturbed datasets for every model date
    for (_, dt, _) in models.keys():
        for perturbation in tqdm.tqdm(perturbations):
            if perturbed_datasets.get((dt, perturbation)) is not None:
                continue

            dp = get_path_for_perturbation(dt, Path("perturbation") / perturbation)
            ds = load_trajectory_dataset(dp)

            try:
                X_raw = get_raw_features_for_dataset(ds)

                X = get_features_from_raw_features(X_raw)
                Y = np.log10(get_labels_for_dataset(ds) + 1)
            finally:
                # Close the NetCDF datasets, even if the extraction failed
                ds.out.close()
                ds.aer.close()
                ds.ant.close()
                ds.bio.close()
                ds.met.close()

            perturbed_datasets[(dt, perturbation)] = PerturbedDataset(
                dt, perturbation, dp, X, Y,
            )

    analysis = (
        _direct_mse_mae_r2_analysis if direct_difference else _diff_mse_mae_r2_analysis
    )

    with Table(
        filepath=results_path,
        keys=[
            "model_date", "model_clump", "perturbation", "mse", "mse_stdv",
            "mse_conf", "mae", "mae_stdv", "mae_conf", "r2", "r2_stdv", "r2_conf",
        ],
    ) as table:
        for (_, dt, clump), model in models.items():
            mlm = load_and_cache_dataset(dt, clump, datasets)

            X_raw = get_features_from_raw_features(mlm.X_raw.copy())
            X_base = np.nan_to_num(mlm.X_scaler.transform(X_raw))
            Y_base = mlm.Y_raw.to_numpy()

            if not direct_difference:
                # Baseline prediction on the unperturbed features
                Y_unpt = combine_many_icarus_predictions(
                    model,
                    np.copy(X_base),
                    rng,
                    n_samples,
                    10, # n_uncertain_samples
                    100, # n_analysis_runs
                    **kwargs,
                )

                # Undo the label scaling of the prediction and its uncertainty
                Y_unpt = IcarusPrediction(
                    prediction=mlm.Y_scaler.inverse_transform(Y_unpt.prediction),
                    uncertainty=Y_unpt.uncertainty * mlm.Y_scaler.scale_,
                    confidence=Y_unpt.confidence,
                )

            for p in tqdm.tqdm(perturbations):
                dsp = perturbed_datasets[(dt, p)]

                X_prtb = np.nan_to_num(mlm.X_scaler.transform(dsp.X, copy=True))
                Y_test = dsp.Y.to_numpy()

                if direct_difference:
                    # The scaled baseline labels are recomputed for every
                    #  perturbation so each prediction gets a fresh array.
                    kwargs["X_base"] = X_base
                    kwargs["Y_base"] = mlm.Y_scaler.transform(mlm.Y_raw)

                Y_prtb = combine_many_icarus_predictions(
                    model,
                    X_prtb,
                    rng,
                    n_samples,
                    10, # n_uncertain_samples
                    100, # n_analysis_runs
                    **kwargs,
                )

                Y_prtb = IcarusPrediction(
                    prediction=mlm.Y_scaler.inverse_transform(Y_prtb.prediction),
                    uncertainty=Y_prtb.uncertainty * mlm.Y_scaler.scale_,
                    confidence=Y_prtb.confidence,
                )

                mse_mae_r2 = analyse_icarus_predictions(
                    [Y_prtb] if direct_difference else [Y_unpt, Y_prtb],
                    partial(analysis, [Y_base, Y_test]),
                    rng,
                    n_uncertain_samples=1,
                    n_analysis_runs=10,
                )

                print(direct_difference, mse_mae_r2)

                # The binned R2 (index 3) is computed but not stored; the
                #  same confidence value is recorded for all three metrics.
                table.insert(
                    model_date=dt, model_clump=clump, perturbation=p,
                    mse=mse_mae_r2.prediction[0],
                    mse_stdv=mse_mae_r2.uncertainty[0],
                    mse_conf=mse_mae_r2.confidence,
                    mae=mse_mae_r2.prediction[1],
                    mae_stdv=mse_mae_r2.uncertainty[1],
                    mae_conf=mse_mae_r2.confidence,
                    r2=mse_mae_r2.prediction[2],
                    r2_stdv=mse_mae_r2.uncertainty[2],
                    r2_conf=mse_mae_r2.confidence,
                )

            table.backup()

In [None]:
# Perturbation evaluation for the random-forest models whose key has a
# non-tuple second entry and a third entry equal to 0.75.
evaluate_model_perturbation_matrix(
    DATASETS,
    PERTURBED_DATASETS,
    {
        model_key: model
        for model_key, model in RANDOM_FOREST_MODELS.items()
        if not isinstance(model_key[1], tuple) and model_key[2] == 0.75
    },
    1,
    title="Perturbation Generalisation RF",
)

In [None]:
# Perturbation evaluation for the pairwise random-forest models whose key has
# a non-tuple second entry and a third entry equal to 0.75.
evaluate_model_perturbation_matrix(
    DATASETS,
    PERTURBED_DATASETS,
    {
        model_key: model
        for model_key, model in PAIRWISE_RANDOM_FOREST_MODELS.items()
        if not isinstance(model_key[1], tuple) and model_key[2] == 0.75
    },
    4,
    title="Perturbation Generalisation PADRE-RF",
)

In [None]:
# Perturbation evaluation for the percentile pairwise random-forest models
# whose key has a non-tuple second entry and a third entry equal to 0.75.
evaluate_model_perturbation_matrix(
    DATASETS,
    PERTURBED_DATASETS,
    {
        model_key: model
        for model_key, model in PERCENTILE_PAIRWISE_RANDOM_FOREST_MODELS.items()
        if not isinstance(model_key[1], tuple) and model_key[2] == 0.75
    },
    4,
    title="Perturbation Generalisation Percentile PADRE-RF",
)

In [None]:
class DirectDifferencePairwiseDifferenceRegressionRandomForestSosaaRSM(IcarusRSM):
    """Adapter around a pairwise-difference model that always predicts with
    ``direct_difference=True``.

    Fitting is delegated unchanged to the wrapped model.
    """

    def __init__(self, padre_rf: PairwiseDifferenceRegressionRandomForestSosaaRSM):
        # The wrapped model is stored by reference, not copied
        self.padre_rf = padre_rf

    def fit(
        self,
        X_train: np.ndarray, Y_train: np.ndarray,
        X_valid: np.ndarray, Y_valid: np.ndarray,
        rng: np.random.Generator,
        **kwargs,
    ) -> DirectDifferencePairwiseDifferenceRegressionRandomForestSosaaRSM:
        """Fit the wrapped model with the given data and return this adapter."""
        wrapped = self.padre_rf
        wrapped.fit(X_train, Y_train, X_valid, Y_valid, rng, **kwargs)

        return self

    def predict(
        self, X_test: np.ndarray, rng: np.random.Generator, **kwargs
    ) -> IcarusPrediction:
        """Predict with the wrapped model, forcing ``direct_difference=True``.

        All other keyword arguments are forwarded to the wrapped model.
        """
        wrapped_predict = self.padre_rf.predict

        return wrapped_predict(X_test, rng, direct_difference=True, **kwargs)

# Direct-difference evaluation: every selected pairwise random-forest model is
# wrapped so that its predictions are made with direct_difference=True.
evaluate_model_perturbation_matrix(
    DATASETS,
    PERTURBED_DATASETS,
    {
        model_key: DirectDifferencePairwiseDifferenceRegressionRandomForestSosaaRSM(model)
        for model_key, model in PAIRWISE_RANDOM_FOREST_MODELS.items()
        if not isinstance(model_key[1], tuple) and model_key[2] == 0.75
    },
    1,
    title="Perturbation Generalisation PADRE-RF Direct",
    direct_difference=True,
)

In [None]:
# Direct-difference evaluation for the percentile pairwise random-forest
# models, each wrapped so that it predicts with direct_difference=True.
evaluate_model_perturbation_matrix(
    DATASETS,
    PERTURBED_DATASETS,
    {
        model_key: DirectDifferencePairwiseDifferenceRegressionRandomForestSosaaRSM(model)
        for model_key, model in PERCENTILE_PAIRWISE_RANDOM_FOREST_MODELS.items()
        if not isinstance(model_key[1], tuple) and model_key[2] == 0.75
    },
    1,
    title="Perturbation Generalisation Percentile PADRE-RF Direct",
    direct_difference=True,
)