In [1]:
import datetime
import hashlib
import itertools
import json
import os
import pickle
import re
import warnings

from collections import namedtuple
from copy import copy
from pathlib import Path

In [2]:
import cartopy.feature as cfeature
import cartopy.crs as ccrs

import joblib

import matplotlib as mpl
from matplotlib import gridspec
from matplotlib import pyplot as plt

import netCDF4
from netCDF4 import Dataset

import numpy as np

import pandas as pd

import scipy as sp

import sklearn
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.preprocessing import StandardScaler

import tqdm

In [3]:
import matplotlib.patheffects

In [4]:
# Path effects that first stroke an artist with a 3pt white outline and then
# draw the artist normally; passed as `path_effects=` to the axis annotations.
outlined = [
    mpl.patheffects.Stroke(linewidth=3, foreground="white"),
    mpl.patheffects.Normal(),
]

In [5]:
import colorcet as cc

# Colormap used to colour scatter points by their time index value.
time_cmap = cc.m_CET_L19

In [6]:
# Paths and open datasets of one trajectory run share the same field layout:
# the run date, the model output, and the aerosol / anthropogenic / biogenic /
# meteorology inputs.
_TRAJECTORY_FIELDS = ("date", "out", "aer", "ant", "bio", "met")

TrajectoryPaths = namedtuple("TrajectoryPaths", _TRAJECTORY_FIELDS)
TrajectoryDatasets = namedtuple("TrajectoryDatasets", _TRAJECTORY_FIELDS)

# Raw frames, scaled train/valid/test splits and the fitted scalers of one
# (or several combined) trajectory run(s).
MLDataset = namedtuple(
    "MLDataset",
    (
        "date", "paths",
        "X_raw", "Y_raw",
        "X_train", "X_valid", "X_test",
        "Y_train", "Y_valid", "Y_test",
        "X_scaler", "Y_scaler",
    ),
)

# Features and labels of one perturbed run.
PerturbedDataset = namedtuple(
    "PerturbedDataset",
    ("date", "perturbation", "paths", "X", "Y"),
)

In [7]:
def get_path_for_perturbation(dt: datetime.datetime, perturbation: Path) -> TrajectoryPaths:
    """Locate the output file and the four input files of one trajectory run.

    All paths live below ``<parent of the current working directory>/trajectories``.

    Args:
        dt: Date and hour of the run; used for the dated directory and file names.
        perturbation: Relative directory of the run variant (e.g. ``Path("baseline")``).

    Returns:
        A ``TrajectoryPaths`` with ``date=dt`` and the paths of the model output
        and of the aerosol, anthropogenic, biogenic and meteorology inputs.

    Raises:
        FileNotFoundError: If any of the five files does not exist. The message
            lists only the files that are actually missing.
    """
    base = Path.cwd().parent / "trajectories"

    day = dt.strftime('%Y%m%d')
    # Input files are numbered by ``24 - hour``, zero-padded to two digits.
    release = f"{24-dt.hour:02}"

    input_dir = base / "inputs" / perturbation / "HYDE_BASE_Y2018" / f"OUTPUT_bwd_{day}"
    emission_dir = input_dir / "EMISSIONS_0422"

    paths = {
        "out": base / "outputs" / perturbation / dt.strftime('%Y%m%d_T%H') / "output.nc",
        "aer": emission_dir / f"{day}_7daybwd_Hyde_traj_AER_{release}_L3.nc",
        "ant": emission_dir / f"{day}_7daybwd_Hyde_traj_ANT_{release}_L3.nc",
        "bio": emission_dir / f"{day}_7daybwd_Hyde_traj_BIO_{release}_L3.nc",
        "met": input_dir / "METEO" / f"METEO_{day}_R{release}.nc",
    }

    missing = [str(path) for path in paths.values() if not path.exists()]
    if missing:
        raise FileNotFoundError(
            f"Missing trajectory file(s) for {dt} in {perturbation}: " + ", ".join(missing)
        )

    return TrajectoryPaths(date=dt, **paths)

In [8]:
def load_trajectory_dataset(paths: TrajectoryPaths) -> TrajectoryDatasets:
    """Open the five NetCDF files of one trajectory run read-only.

    Args:
        paths: File locations as returned by ``get_path_for_perturbation``.

    Returns:
        A ``TrajectoryDatasets`` holding the open ``netCDF4.Dataset`` handles
        and the run date. The caller is responsible for closing the handles.

    Raises:
        Whatever ``netCDF4.Dataset`` raises when a file cannot be opened. In
        that case the handles that were already opened are closed first.
    """
    opened = {}

    try:
        for field in ("out", "aer", "ant", "bio", "met"):
            opened[field] = Dataset(getattr(paths, field), "r", format="NETCDF4")
    except BaseException:
        # Do not leak the handles that were opened before the failure
        for handle in opened.values():
            handle.close()
        raise

    return TrajectoryDatasets(date=paths.date, **opened)

In [9]:
# Cartopy coordinate reference systems and map extent for map plots.
# NOTE(review): none of these three names is referenced by the cells visible
# here — presumably used by later map-plotting cells; confirm.
data_proj = ccrs.PlateCarree()
projection = ccrs.LambertConformal(
    central_latitude=50, central_longitude=20, standard_parallels=(25, 25)
)
# presumably [lon_min, lon_max, lat_min, lat_max] in degrees — TODO confirm
extent = [-60, 60, 40, 80]

In [10]:
def get_ccn_concentration(ds: TrajectoryDatasets):
    """Sum the particle number concentration over all large size bins.

    A size bin is included when its ``dp_dry_fs`` value exceeds ``80e-9``
    (80 nm if the variable is stored in metres — TODO confirm units).

    Returns:
        A DataFrame indexed by ``(time, level)`` with a single ``ccn`` column.
    """
    dry_diameter = ds.out["dp_dry_fs"][:].data
    (selected_bins,) = np.nonzero(dry_diameter > 80e-9)

    # Collapse the bin axis (axis 1) of the selected bins
    number_concentration = ds.out["nconc_par"][:].data
    ccn = np.sum(number_concentration[:, selected_bins, :], axis=1)

    n_times = ds.out["time"].shape[0]
    n_levels = ds.out["lev"].shape[0]

    # Time-major layout: every time is repeated once per level
    frame = pd.DataFrame({
        "time": np.repeat(get_output_time(ds), n_levels),
        "level": np.tile(ds.out["lev"][:].data, n_times),
        "ccn": ccn.flatten(),
    })

    return frame.set_index(["time", "level"])

In [11]:
def get_output_time(ds: TrajectoryDatasets):
    """Return the output time axis shifted so that ``ds.date`` is time zero.

    The ``time`` variable of the output file carries a ``first_day_of_month``
    attribute; the number of seconds between that timestamp and ``ds.date`` is
    subtracted from every output time value.
    """
    time_variable = ds.out["time"]

    month_start = datetime.datetime.strptime(
        time_variable.__dict__["first_day_of_month"], "%Y-%m-%d %H:%M:%S",
    )
    seconds_since_month_start = (ds.date - month_start).total_seconds()

    return time_variable[:].data - seconds_since_month_start

def interpolate_meteorology_values(ds: TrajectoryDatasets, key: str):
    """Linearly interpolate a (time, level) meteorology variable onto the output grid.

    Points outside the meteorology grid are filled with 0.0.

    NOTE(review): ``scipy.interpolate.interp2d`` is deprecated since SciPy 1.10
    and removed in SciPy 1.14, so this only runs on older SciPy versions.

    Args:
        ds: Open datasets of one trajectory run.
        key: Name of the variable in the meteorology file.

    Returns:
        The interpolated values as returned by ``interp2d`` for
        ``x=output levels`` and ``y=output times``.
    """
    target_times = get_output_time(ds)
    target_levels = ds.out["lev"][:].data

    interpolator = sp.interpolate.interp2d(
        x=ds.met["lev"][:].data,
        y=ds.met["time"][:].data,
        z=ds.met[key][:],
        kind="linear", bounds_error=False, fill_value=0.0,
    )

    return interpolator(x=target_levels, y=target_times)

def interpolate_meteorology_time_values(ds: TrajectoryDatasets, key: str):
    """Linearly interpolate a time-only meteorology variable onto the output grid.

    The variable is interpolated to the output times (0.0 outside the
    meteorology time range) and the resulting column is then copied to every
    output level.

    Returns:
        An array of shape ``(number of output times, number of output levels)``.
    """
    target_times = get_output_time(ds)
    n_levels = ds.out["lev"][:].data.shape[0]

    interpolator = sp.interpolate.interp1d(
        x=ds.met["time"][:].data,
        y=ds.met[key][:],
        kind="linear", bounds_error=False, fill_value=0.0,
    )

    column = interpolator(x=target_times).reshape(-1, 1)

    return np.repeat(column, n_levels, axis=1)

def interpolate_biogenic_emissions(ds: TrajectoryDatasets, key: str):
    """Interpolate a biogenic emission onto the output grid, near the surface only.

    The emission ``key`` is linearly interpolated to the output times (0.0
    outside the emission time range) and written into the output levels at
    heights <= 10.0; all other levels stay 0.0.

    NOTE(review): each low level is weighted by the *cumulative* layer depth
    divided by the total depth (so the topmost low level gets weight 1.0 and
    the weights do not sum to 1). If a per-layer share was intended this would
    need ``out_d / out_d.sum()`` instead — confirm before changing, since it
    alters the features.

    Returns:
        An array of shape ``(number of output times, number of output levels)``.
    """
    out_t = get_output_time(ds)
    out_h = ds.out["lev"][:].data
    
    # depth of each box layer, assuming level heights are midpoints and end points are clamped
    # i.e. (next height - previous height) / 2, with the first/last height repeated at the ends
    out_d = (np.array(list(out_h[1:])+[out_h[-1]]) - np.array([out_h[0]]+list(out_h[:-1]))) / 2.0
    
    bio_t = ds.bio["time"][:].data
    
    # Biogenic emissions are limited to boxes at <= 10m height
    # (np.nonzero returns a 1-tuple holding the index array)
    biogenic_emission_layers = np.nonzero(out_h <= 10.0)
    biogenic_emission_layer_height_cumsum = np.cumsum(out_d[biogenic_emission_layers])
    biogenic_emission_layer_proportion = biogenic_emission_layer_height_cumsum / biogenic_emission_layer_height_cumsum[-1]
    num_biogenic_emission_layers = sum(out_h <= 10.0)
    
    bio_t_h = np.zeros(shape=(out_t.size, out_h.size))
    
    bio_t_int = sp.interpolate.interp1d(
        x=bio_t, y=ds.bio[key][:], kind="linear", bounds_error=False, fill_value=0.0,
    )
    
    # Split up the biogenic emissions relative to the depth of the boxes
    # The right-hand side is (layers, 1, times) before and (times, 1, layers)
    # after the transpose; indexing with the 1-tuple selects a target of the
    # same (times, 1, layers) shape.
    bio_t_h[:,biogenic_emission_layers] = (
        np.tile(bio_t_int(x=out_t), (num_biogenic_emission_layers, 1, 1)) * biogenic_emission_layer_proportion.reshape(-1, 1, 1)
    ).T
    
    return bio_t_h

def interpolate_aerosol_emissions(ds: TrajectoryDatasets, key: str):
    """Linearly interpolate an aerosol emission variable onto the output grid.

    The variable is transposed before interpolation over
    ``x=mid_layer_height`` and ``y=time``; points outside the emission grid are
    filled with 0.0.

    NOTE(review): ``scipy.interpolate.interp2d`` is deprecated since SciPy 1.10
    and removed in SciPy 1.14, so this only runs on older SciPy versions.

    Returns:
        The interpolated values as returned by ``interp2d`` for
        ``x=output levels`` and ``y=output times``.
    """
    target_times = get_output_time(ds)
    target_levels = ds.out["lev"][:].data

    interpolator = sp.interpolate.interp2d(
        x=ds.aer["mid_layer_height"][:].data,
        y=ds.aer["time"][:].data,
        z=ds.aer[key][:].T,
        kind="linear", bounds_error=False, fill_value=0.0,
    )

    return interpolator(x=target_levels, y=target_times)

def interpolate_anthropogenic_emissions(ds: TrajectoryDatasets, key: str):
    """Linearly interpolate an anthropogenic emission variable onto the output grid.

    The variable is transposed before interpolation over
    ``x=mid_layer_height`` and ``y=time``; points outside the emission grid are
    filled with 0.0.

    NOTE(review): ``scipy.interpolate.interp2d`` is deprecated since SciPy 1.10
    and removed in SciPy 1.14, so this only runs on older SciPy versions.

    Returns:
        The interpolated values as returned by ``interp2d`` for
        ``x=output levels`` and ``y=output times``.
    """
    target_times = get_output_time(ds)
    target_levels = ds.out["lev"][:].data

    interpolator = sp.interpolate.interp2d(
        x=ds.ant["mid_layer_height"][:].data,
        y=ds.ant["time"][:].data,
        z=ds.ant[key][:].T,
        kind="linear", bounds_error=False, fill_value=0.0,
    )

    return interpolator(x=target_levels, y=target_times)

In [12]:
def get_meteorology_features(ds: TrajectoryDatasets):
    """Build the meteorology feature frame, indexed by ``(time, level)``.

    Currently enabled variables: ``t`` and ``q`` (interpolated in time and
    level) and ``ssr``, ``lsm`` and ``blh`` (interpolated in time only).

    Disabled variables, kept here for reference:
      - time and level: u, v, qc, omega, mla, lp
      - time only: sp, cp, sshf, lsp, ewss, nsss, tcc, z
    NOTE: lp is excluded because it allows the model to overfit
    """
    n_times = ds.out["time"].shape[0]
    n_levels = ds.out["lev"].shape[0]

    columns = {
        "time": np.repeat(get_output_time(ds), n_levels),
        "level": np.tile(ds.out["lev"][:].data, n_times),
    }

    # (column name, meteorology variable, interpolation routine)
    enabled = [
        ("met_t", "t", interpolate_meteorology_values),
        ("met_q", "q", interpolate_meteorology_values),
        ("met_ssr", "ssr", interpolate_meteorology_time_values),
        ("met_lsm", "lsm", interpolate_meteorology_time_values),
        ("met_blh", "blh", interpolate_meteorology_time_values),
    ]

    for column, variable, interpolate in enabled:
        columns[column] = interpolate(ds, variable).flatten()

    return pd.DataFrame(columns).set_index(["time", "level"])

def get_bio_emissions_features(ds: TrajectoryDatasets):
    """Build the biogenic emission feature frame, indexed by ``(time, level)``.

    Every listed variable of the biogenic emission file is interpolated with
    ``interpolate_biogenic_emissions`` and stored under its column name.

    NOTE(review): two column names look like typos — ``bio_iosprene`` (for
    ``isoprene``) and ``bio_butanes_and_higher_alkenes`` (for
    ``butenes-and-higher-alkenes``). They are kept as-is because they become
    feature names; renaming them would change downstream names.
    """
    # (column name, variable name in the biogenic emission file)
    species = [
        ("bio_acetaldehyde", "acetaldehyde"),
        ("bio_acetone", "acetone"),
        ("bio_butanes_and_higher_alkanes", "butanes-and-higher-alkanes"),
        ("bio_butanes_and_higher_alkenes", "butenes-and-higher-alkenes"),
        ("bio_ch4", "CH4"),
        ("bio_co", "CO"),
        ("bio_ethane", "ethane"),
        ("bio_ethanol", "ethanol"),
        ("bio_ethene", "ethene"),
        ("bio_formaldehyde", "formaldehyde"),
        ("bio_hydrogen_cyanide", "hydrogen-cyanide"),
        ("bio_iosprene", "isoprene"),
        ("bio_mbo", "MBO"),
        ("bio_methanol", "methanol"),
        ("bio_methyl_bromide", "methyl-bromide"),
        ("bio_methyl_chloride", "methyl-chloride"),
        ("bio_methyl_iodide", "methyl-iodide"),
        ("bio_other_aldehydes", "other-aldehydes"),
        ("bio_other_ketones", "other-ketones"),
        ("bio_other_monoterpenes", "other-monoterpenes"),
        ("bio_pinene_a", "pinene-a"),
        ("bio_pinene_b", "pinene-b"),
        ("bio_propane", "propane"),
        ("bio_propene", "propene"),
        ("bio_sesquiterpenes", "sesquiterpenes"),
        ("bio_toluene", "toluene"),
        ("bio_ch2br2", "CH2Br2"),
        ("bio_ch3i", "CH3I"),
        ("bio_chbr3", "CHBr3"),
        ("bio_dms", "DMS"),
    ]

    n_times = ds.out["time"].shape[0]
    n_levels = ds.out["lev"].shape[0]

    columns = {
        "time": np.repeat(get_output_time(ds), n_levels),
        "level": np.tile(ds.out["lev"][:].data, n_times),
    }

    for column, variable in species:
        columns[column] = interpolate_biogenic_emissions(ds, variable).flatten()

    return pd.DataFrame(columns).set_index(["time", "level"])

def get_aer_emissions_features(ds: TrajectoryDatasets):
    """Build the aerosol emission feature frame, indexed by ``(time, level)``.

    One column ``aer_<lo>_<hi>_nm`` is produced per size range, read from the
    variable ``<lo>-<hi>nm`` of the aerosol emission file.
    """
    size_ranges = [
        (3, 10), (10, 20), (20, 30), (30, 50), (50, 70),
        (70, 100), (100, 200), (200, 400), (400, 1000),
    ]

    n_times = ds.out["time"].shape[0]
    n_levels = ds.out["lev"].shape[0]

    columns = {
        "time": np.repeat(get_output_time(ds), n_levels),
        "level": np.tile(ds.out["lev"][:].data, n_times),
    }

    for lower, upper in size_ranges:
        values = interpolate_aerosol_emissions(ds, f"{lower}-{upper}nm")
        columns[f"aer_{lower}_{upper}_nm"] = values.flatten()

    return pd.DataFrame(columns).set_index(["time", "level"])

def get_ant_emissions_features(ds: TrajectoryDatasets):
    """Build the anthropogenic emission feature frame, indexed by ``(time, level)``.

    Each listed variable of the anthropogenic emission file becomes one column
    named ``ant_`` + the variable name in lower case with ``-`` replaced by ``_``.
    """
    variables = [
        "co", "nox", "co2", "nh3", "ch4", "so2", "nmvoc",
        "alcohols", "ethane", "propane", "butanes", "pentanes", "hexanes",
        "ethene", "propene", "acetylene", "isoprene", "monoterpenes",
        "other-alkenes-and-alkynes",
        "benzene", "toluene", "xylene", "trimethylbenzene", "other-aromatics",
        "esters", "ethers", "formaldehyde", "other-aldehydes",
        "total-ketones", "total-acids", "other-VOCs",
    ]

    n_times = ds.out["time"].shape[0]
    n_levels = ds.out["lev"].shape[0]

    columns = {
        "time": np.repeat(get_output_time(ds), n_levels),
        "level": np.tile(ds.out["lev"][:].data, n_times),
    }

    for variable in variables:
        column = "ant_" + variable.lower().replace("-", "_")
        columns[column] = interpolate_anthropogenic_emissions(ds, variable).flatten()

    return pd.DataFrame(columns).set_index(["time", "level"])

In [13]:
# https://stackoverflow.com/a/67809235
def df_to_numpy(df):
    """Reshape a frame into an array with one axis per index level.

    For a MultiIndex the leading axes have the sizes of the index levels; for
    a flat index there is a single leading axis of the index length. A trailing
    column axis is only added when there is more than one column.

    The reshape assumes the rows cover the full product of the index levels in
    level-major order — it raises otherwise.
    """
    index_levels = getattr(df.index, "levels", None)

    if index_levels is None:
        # Flat index: a single row axis
        dims = [len(df.index)]
    else:
        dims = [len(level) for level in index_levels]

    n_columns = df.shape[-1]
    if n_columns > 1:
        dims.append(n_columns)

    return df.to_numpy().reshape(dims)

In [14]:
def generate_time_level_windows():
    """Return the 21 ``(time window, level window)`` pairs used for averaging.

    Each window is an inclusive ``(start offset, end offset)`` pair relative to
    the current time step / level. The seven time windows are paired in turn
    with the levels above ("top"), around ("mid") and below ("bot").
    """
    # -0.5h, -1.5h, -3h, -6h, -12h, -24h, -48h
    # 0, -2, -5, -11, -23, -47, -95
    time_windows = [(0, 0), (-2, -1), (-5, -3), (-11, -6), (-23, -12), (-47, -24), (-95, -48)]

    # +1l, +2l, +4l, +8l, +16l, +32l, +64
    level_window_groups = (
        # top
        [(1, 1), (1, 2), (1, 4), (2, 8), (2, 16), (3, 32), (3, 64)],
        # mid
        [(0, 0), (0, 0), (0, 0), (-1, 1), (-1, 1), (-2, 2), (-2, 2)],
        # bot
        [(-1, -1), (-2, -1), (-4, -1), (-8, -2), (-16, -2), (-32, -3), (-64, -3)],
    )

    windows = []
    for level_windows in level_window_groups:
        windows.extend(zip(time_windows, level_windows))

    return windows

In [15]:
def generate_windowed_feature_names(columns):
    """Name every windowed feature as ``<column><time label><level label>``.

    The names are ordered like the windows of ``generate_time_level_windows``:
    top, mid, then bot level labels, each paired with the seven time labels,
    with all columns listed for every pair.
    """
    time_labels = ["-0.5h", "-1.5h", "-3h", "-6h", "-12h", "-24h", "-48h"]

    level_label_groups = (
        # top
        ["+1l", "+2l", "+4l", "+8l", "+16l", "+32l", "+64l"],
        # mid
        ["+0l", "+0l", "+0l", "±1l", "±1l", "±2l", "±2l"],
        # bot
        ["-1l", "-2l", "-4l", "-8l", "-16l", "-32l", "-64l"],
    )

    return [
        f"{c}{t}{l}"
        for level_labels in level_label_groups
        for t, l in zip(time_labels, level_labels)
        for c in columns
    ]

In [16]:
def time_level_window_mean_v1(input, t_range, l_range):
    """Reference implementation of the windowed mean, one cell at a time.

    For every ``(t, l, f)`` the result is the mean of feature ``f`` over the
    inclusive offsets ``t_range`` (axis 0) and ``l_range`` (axis 1) around
    ``(t, l)``, clipped to the array bounds. Cells whose clipped window is
    empty are 0.0.
    """
    n_times, n_levels, n_features = input.shape[0], input.shape[1], input.shape[2]

    # Cells with an empty window simply keep this initial 0.0
    output = np.zeros(shape=input.shape)

    for t in range(n_times):
        t_start = min(max(0, t + t_range[0]), n_times)
        t_stop = max(0, min(t + 1 + t_range[1], n_times))

        for l in range(n_levels):
            l_start = min(max(0, l + l_range[0]), n_levels)
            l_stop = max(0, min(l + 1 + l_range[1], n_levels))

            for f in range(n_features):
                window = input[t_start:t_stop, l_start:l_stop, f]

                if window.size > 0:
                    output[t, l, f] = np.mean(window)

    return output

def time_level_window_mean_v2(input, t_range, l_range):
    """Windowed mean over time and level, vectorised over the feature axis.

    For every ``(t, l)`` the result is the mean of all features over the
    inclusive offsets ``t_range`` (axis 0) and ``l_range`` (axis 1) around
    ``(t, l)``, clipped to the array bounds.

    Args:
        input: Array of shape ``(times, levels, features)``.
        t_range: ``(start offset, end offset)`` along the time axis, inclusive.
        l_range: ``(start offset, end offset)`` along the level axis, inclusive.

    Returns:
        A float array of the same shape as ``input``. Cells whose clipped
        window is empty (including inverted ranges) are 0.0, matching
        ``time_level_window_mean_v1``.
    """
    n_times, n_levels = input.shape[0], input.shape[1]

    output = np.zeros(shape=input.shape)

    # The level bounds do not depend on the time step: compute them once and
    # drop the levels whose window is empty.
    level_windows = []
    for l in range(n_levels):
        minl = min(max(0, l+l_range[0]), n_levels)
        maxl = max(0, min(l+1+l_range[1], n_levels))

        if minl < maxl:
            level_windows.append((l, minl, maxl))

    for t in range(n_times):
        mint = min(max(0, t+t_range[0]), n_times)
        maxt = max(0, min(t+1+t_range[1], n_times))

        # `>=` rather than `==`: an inverted range must not reach np.mean,
        # which would return NaN for the empty slice
        if mint >= maxt:
            continue

        time_slab = input[mint:maxt]

        for l, minl, maxl in level_windows:
            output[t,l,:] = np.mean(time_slab[:,minl:maxl,:], axis=(0,1))

    return output

def time_level_window_mean_v3(input, t_range, l_range):
    """Windowed mean over time and level, computed with two convolutions.

    A box kernel covering the offsets ``t_range`` x ``l_range`` is convolved
    with the input (window sums) and with an array of ones (number of
    in-bounds cells per window); the result is their ratio, and 0 where the
    window contains no in-bounds cell.

    NOTE(review): ``result`` takes the dtype of ``input``; presumably ``input``
    is a float array — confirm, as an integer input would not hold the means.
    NOTE(review): not called by the cells visible here; its agreement with the
    v1/v2 variants is not checked in this file.
    """
    # Largest absolute offset per axis; the kernel is centred, so it must
    # extend this far in both directions
    min_t = min(t_range[0], 0)
    max_t = max(0, t_range[1])
    abs_t = max(abs(min_t), abs(max_t))
    
    min_l = min(l_range[0], 0)
    max_l = max(0, l_range[1])
    abs_l = max(abs(min_l), abs(max_l))
    
    # Kernel with ones at the window offsets (size 1 along the feature axis),
    # then reversed along the time and level axes
    kernel = np.zeros(shape=(abs_t*2 + 1, abs_l*2 + 1, 1))
    kernel[t_range[0]+abs_t:t_range[1]+abs_t+1,l_range[0]+abs_l:l_range[1]+abs_l+1,:] = 1.0
    kernel = kernel[::-1,::-1]
    
    # Number of in-bounds cells in each window (out-of-bounds cells count as 0)
    quot = sp.ndimage.convolve(np.ones_like(input), kernel, mode='constant', cval=0.0)
    
    result = np.zeros_like(input)
    
    # Window sum divided by the cell count; entries with an empty window keep 0
    np.divide(
        sp.ndimage.convolve(input, kernel, mode='constant', cval=0.0),
        quot, out=result, where=quot > 0,
    )
    
    return result

In [17]:
def get_raw_features_for_dataset(ds: TrajectoryDatasets):
    """Collect all un-windowed features of one run in a single frame.

    The biogenic, aerosol, anthropogenic and meteorology feature frames are
    joined column-wise, in that order. The aerosol features are multiplied by
    ``1e21`` beforehand.
    """
    feature_groups = [
        get_bio_emissions_features(ds),
        get_aer_emissions_features(ds) * 1e21,
        get_ant_emissions_features(ds),
        get_meteorology_features(ds),
    ]

    return pd.concat(feature_groups, axis="columns")

In [18]:
def get_features_from_raw_features(raw_features):
    """Turn the raw feature frame into windowed-mean features.

    For each window of ``generate_time_level_windows`` the mean of every raw
    feature is computed (in parallel via joblib); the results are stacked
    behind the time and level values and flattened back into a frame indexed
    by ``(time, level)``.

    The first 95 time steps and the final time step are dropped.
    """
    raw_values = df_to_numpy(raw_features)

    index = raw_features.index
    plane_shape = (index.levels[0].size, index.levels[1].size, 1)

    time_plane = index.get_level_values(0).to_numpy().reshape(plane_shape)
    level_plane = index.get_level_values(1).to_numpy().reshape(plane_shape)

    windowed = joblib.Parallel(n_jobs=-1)([
        joblib.delayed(time_level_window_mean_v2)(raw_values, t, l)
        for t, l in generate_time_level_windows()
    ])

    stacked = np.concatenate([time_plane, level_plane] + windowed, axis=2)

    # Trim off the first two days, for which the time features are ill-defined
    # NOTE(review): the last step is dropped as well, presumably so that the
    # rows line up with labels that start one step later — confirm.
    trimmed = stacked[95:-1,:,:]

    column_names = ["time", "level"] + generate_windowed_feature_names(raw_features.columns)

    n_rows = trimmed.shape[0]*trimmed.shape[1]
    flattened = trimmed.reshape(n_rows, trimmed.shape[2])

    return pd.DataFrame(flattened, columns=column_names).set_index(["time", "level"])

In [19]:
def get_labels_for_dataset(ds: TrajectoryDatasets):
    """Build the label frame (``ccn`` column) indexed by ``(time, level)``.

    The CCN concentration is arranged on the (time, level) grid, the first 96
    time steps are dropped, and the grid is flattened back into a frame.
    """
    ccn = get_ccn_concentration(ds)

    index = ccn.index
    plane_shape = (index.levels[0].size, index.levels[1].size, 1)

    ccn_grid = df_to_numpy(ccn)

    stacked = np.concatenate([
        index.get_level_values(0).to_numpy().reshape(plane_shape),
        index.get_level_values(1).to_numpy().reshape(plane_shape),
        ccn_grid.reshape((ccn_grid.shape[0], ccn_grid.shape[1], 1)),
    ], axis=2)

    # Trim off the first two days, for which the time features are ill-defined
    trimmed = stacked[96:,:,:]

    n_rows = trimmed.shape[0]*trimmed.shape[1]
    flattened = trimmed.reshape(n_rows, trimmed.shape[2])

    return pd.DataFrame(
        flattened, columns=["time", "level", "ccn"],
    ).set_index(["time", "level"])

In [20]:
def hash_for_dt(dt):
    """Return a SHAKE-256 hash object fed with the given date(s).

    A single datetime is treated like a one-element list. Each datetime is
    formatted as ``%d.%m.%Y-%H:00%z`` and the formatted dates are joined with
    ``.`` before hashing. The caller picks the digest length.
    """
    dates_to_hash = dt if isinstance(dt, (tuple, list)) else [dt]

    formatted = [moment.strftime('%d.%m.%Y-%H:00%z') for moment in dates_to_hash]

    hasher = hashlib.shake_256()
    hasher.update('.'.join(formatted).encode('ascii'))

    return hasher

In [21]:
"""
Clumped 0/1 sampler using a Markov Process

P(0) = p and P(1) = 1-p
clump = 0 => IID samples
clump -> 1 => highly correlated samples

"""
class Clump:
    """Two-state Markov chain that produces clumped 0/1 samples.

    The chain is built so that its stationary distribution is ``P(0) = p`` and
    ``P(1) = 1 - p``. With ``clump = 0`` both rows of the transition matrix are
    equal (independent samples); as ``clump`` approaches 1 the chain tends to
    stay in its current state.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, p=0.5, clump=0.0, rng=None):
        """Set up the transition matrix and draw the initial state.

        Args:
            p: Stationary probability of state 0; must satisfy ``0 <= p < 1``.
            clump: Correlation strength; 0 gives independent samples.
            rng: Random source with a ``random()`` method. One value is drawn
                here to pick the initial state. Defaults to a fresh, unseeded
                ``np.random.RandomState``.

        Raises:
            ValueError: If ``p`` is outside ``[0, 1)`` (``p = 1`` would divide
                by zero below).
        """
        if not 0.0 <= p < 1.0:
            raise ValueError(f"p must satisfy 0 <= p < 1, got {p!r}")

        if rng is None:
            rng = np.random.RandomState()

        # a = P(stay in 0), b = P(1 -> 0); chosen so that (p, 1-p) is stationary
        a = 1 - (1-p)*(1-clump)
        b = (1-a)*p/(1-p)

        # Row i holds the transition probabilities out of state i
        self.C = np.array([[a, 1-a],[b, 1-b]])

        # Current state, initialised from the stationary distribution
        self.i = 0 if rng.random() < p else 1

    def sample(self, rng):
        """Advance the chain by one step using ``rng`` and return the new state (0 or 1)."""
        p = self.C[self.i,0]
        u = rng.random()

        self.i = 0 if u < p else 1

        return self.i

    def steady(self, X):
        """Apply one transition step to the state distribution ``X`` (``X @ C``)."""
        return np.matmul(X, self.C)

In [22]:
def train_test_split(X, Y, test_size=0.25, random_state=None, shuffle=True, clump=0.0):
    """Split ``X`` and ``Y`` into train and test parts with a clumped sampler.

    A ``Clump`` chain with ``p=test_size`` assigns every group to test
    (state 0) or train (state 1). For DataFrames a group is a run of
    consecutive rows as long as the second index level, so rows are only split
    by their position along the first index level; for other inputs every
    element is its own group.

    NOTE(review): the DataFrame grouping is positional. When the input rows
    were already shuffled (e.g. the train part of an earlier call with
    ``shuffle=True``), the groups no longer correspond to single first-level
    index values — confirm that this is intended.

    Args:
        X, Y: Inputs of the same type and length.
        test_size: Stationary share of test groups, strictly between 0 and 1.
        random_state: Random source (required); used for sampling and shuffling.
        shuffle: Whether to shuffle the selected row positions of each part.
        clump: Clumping strength passed to ``Clump``, in ``[0, 1)``.

    Returns:
        ``X_train, X_test, Y_train, Y_test``.
    """
    assert len(X) == len(Y)
    assert type(X) == type(Y)
    assert test_size > 0.0
    assert test_size < 1.0
    assert random_state is not None
    assert clump >= 0.0
    assert clump < 1.0

    sampler = Clump(p=test_size, clump=clump, rng=random_state)

    is_frame = isinstance(X, pd.DataFrame)

    if is_frame:
        assert X.index.values.shape == Y.index.values.shape

        # Split only based on the first-level index instead of flattening
        group_size = len(X.index.levels[1])
        n_groups = len(X) // group_size
    else:
        n_groups = len(X)

    states = np.array([sampler.sample(random_state) for _ in range(n_groups)])
    train_positions, = np.nonzero(states)
    test_positions, = np.nonzero(1-states)

    if is_frame:
        # Expand every selected group into the positions of all of its rows
        within_group = np.arange(group_size)
        train_positions = (train_positions[:, None] * group_size + within_group[None, :]).reshape(-1)
        test_positions = (test_positions[:, None] * group_size + within_group[None, :]).reshape(-1)

    if shuffle:
        random_state.shuffle(train_positions)
        random_state.shuffle(test_positions)

    if is_frame:
        return (
            X.iloc[train_positions], X.iloc[test_positions],
            Y.iloc[train_positions], Y.iloc[test_positions],
        )

    return (
        X[train_positions], X[test_positions],
        Y[train_positions], Y[test_positions],
    )

In [23]:
def load_and_cache_dataset(dt: datetime.datetime, clump: float, datasets: dict) -> MLDataset:
    """Load, split and scale the baseline dataset for one or several dates.

    Results are memoised in ``datasets`` under the key ``(dt, clump)``.

    For a single date the baseline run is read, features and ``log10(ccn + 1)``
    labels are built and split into train / valid / test parts using a random
    generator seeded from the date. For a tuple or list of dates the
    per-date datasets are loaded (and cached) individually and their unscaled
    splits are concatenated. In both cases fresh scalers are fitted on the
    training part only.

    Args:
        dt: A single datetime, or a tuple/list of datetimes (sorted internally).
        clump: Clumping strength forwarded to ``train_test_split``.
        datasets: Cache dictionary, updated in place.

    Returns:
        The ``MLDataset`` for ``dt``. Its ``Y_raw`` holds the log-transformed
        labels.
    """
    is_collection = isinstance(dt, (tuple, list))

    if is_collection:
        dt = tuple(sorted(dt))

    cached = datasets.get((dt, clump))

    if cached is not None:
        return cached

    if is_collection:
        mls = [load_and_cache_dataset(dtt, clump, datasets) for dtt in dt]

        dp = tuple(ml.paths for ml in mls)
        X_raw = pd.concat([ml.X_raw for ml in mls], axis='index')
        Y = pd.concat([ml.Y_raw for ml in mls], axis='index')

        # Undo the per-date scaling so that the combined data can be rescaled
        train_features = np.concatenate([ml.X_scaler.inverse_transform(ml.X_train) for ml in mls], axis=0)
        train_labels = np.concatenate([ml.Y_scaler.inverse_transform(ml.Y_train) for ml in mls], axis=0)
        valid_features = np.concatenate([ml.X_scaler.inverse_transform(ml.X_valid) for ml in mls], axis=0)
        valid_labels = np.concatenate([ml.Y_scaler.inverse_transform(ml.Y_valid) for ml in mls], axis=0)
        test_features = np.concatenate([ml.X_scaler.inverse_transform(ml.X_test) for ml in mls], axis=0)
        test_labels = np.concatenate([ml.Y_scaler.inverse_transform(ml.Y_test) for ml in mls], axis=0)
    else:
        dp = get_path_for_perturbation(dt, Path("baseline"))
        ds = load_trajectory_dataset(dp)

        try:
            X_raw = get_raw_features_for_dataset(ds)

            X = get_features_from_raw_features(X_raw)
            Y = np.log10(get_labels_for_dataset(ds) + 1)

            # Seed from the date so that the split is reproducible per date
            rng = np.random.RandomState(seed=int.from_bytes(hash_for_dt(dt).digest(4), 'little'))

            train_features, test_features, train_labels, test_labels = train_test_split(
                X, Y, test_size=0.25, random_state=rng, clump=clump,
            )
            train_features, valid_features, train_labels, valid_labels = train_test_split(
                train_features, train_labels, test_size=1.0/3.0, random_state=rng, clump=clump,
            )
        finally:
            # Close the NetCDF datasets, also when building the dataset failed
            for handle in (ds.out, ds.aer, ds.ant, ds.bio, ds.met):
                handle.close()

    # Scale features to N(0,1)
    # - only fit on training data
    # - OOD inputs for constants at training time are blown up
    # NOTE(review): the scale of zero-variance features is set to the largest
    # finite float, which maps any deviation from the training constant to ~0
    # after scaling — confirm this matches the intent stated above.
    feature_scaler = StandardScaler().fit(train_features)
    feature_scaler.scale_[np.nonzero(feature_scaler.var_ == 0.0)] = np.nan_to_num(np.inf)

    label_scaler = StandardScaler().fit(train_labels)

    train_features = feature_scaler.transform(train_features)
    train_labels = label_scaler.transform(train_labels)
    valid_features = feature_scaler.transform(valid_features)
    valid_labels = label_scaler.transform(valid_labels)
    test_features = feature_scaler.transform(test_features)
    test_labels = label_scaler.transform(test_labels)

    dataset = MLDataset(
        date=dt, paths=dp, X_raw=X_raw, Y_raw=Y,
        X_train=train_features, X_valid=valid_features, X_test=test_features,
        Y_train=train_labels, Y_valid=valid_labels, Y_test=test_labels,
        X_scaler=feature_scaler, Y_scaler=label_scaler,
    )

    datasets[(dt, clump)] = dataset

    return dataset

In [24]:
# In-memory cache of loaded datasets, keyed by (date or tuple of dates, clump)
# when passed as the `datasets` argument of `load_and_cache_dataset`.
DATASETS = {}

In [25]:
def _perturbation_directories():
    """List the relative directories of all perturbed runs.

    Two perturbation strengths are covered in turn: every emission group is
    multiplied and divided by the factor, and the temperature is raised and
    lowered by the matching offset.
    """
    emission_groups = (
        "anthropogenic", "biogenic", "aerosols",
        "monoterpenes", "sesquiterpenes", "so2", "nox",
    )
    # (emission factor, temperature offset)
    strengths = (("1.5", "2K"), ("1.01", "0.04K"))

    directories = []

    for factor, offset in strengths:
        for group in emission_groups:
            directories.append(Path(group) / f"mul_{factor}")
            directories.append(Path(group) / f"div_{factor}")

        directories.append(Path("temperature") / f"add_{offset}")
        directories.append(Path("temperature") / f"sub_{offset}")

    return directories


perturbations = _perturbation_directories()

In [26]:
# Dates and hours (all in May 2018) of the trajectory runs to process,
# given as (day, hour).
dates = [
    datetime.datetime(year=2018, month=5, day=day, hour=hour)
    for day, hour in [(14, 10), (15, 19), (17, 0), (19, 4), (21, 15), (23, 13)]
]

In [27]:
from matplotlib.ticker import ScalarFormatter

# Based on https://stackoverflow.com/a/42156450
class ScalarFormatterForceSignedFormat(ScalarFormatter):
    """ScalarFormatter whose tick format is fixed to one decimal with an explicit sign."""

    def _set_format(self):
        # Overrides the formatter's private format hook (see the linked
        # Stack Overflow answer) so the format is not derived from the ticks
        self.format = "%+1.1f"
        
class ScalarFormatterForceFormat(ScalarFormatter):
    """ScalarFormatter whose tick format is fixed to one decimal."""

    def _set_format(self):
        # Overrides the formatter's private format hook (see the linked
        # Stack Overflow answer) so the format is not derived from the ticks
        self.format = "%1.1f"

In [28]:
def _load_ccn_labels(dt, run_directory):
    """Load the CCN label frame of one run and always close its NetCDF files."""
    dp = get_path_for_perturbation(dt, run_directory)
    ds = load_trajectory_dataset(dp)

    try:
        # Raw concentrations: the log10(... + 1) transform used for the
        # training labels is not applied here
        return get_labels_for_dataset(ds)
    finally:
        # Close the NetCDF datasets
        for handle in (ds.out, ds.aer, ds.ant, ds.bio, ds.met):
            handle.close()


def _visible_tick_exponent(ticks, view_limits):
    """Return the decimal exponent of the largest absolute tick inside the view.

    Using the absolute value keeps the exponent defined when the largest tick
    is zero or negative (e.g. an axis that only shows decreases). Ticks outside
    the view limits are ignored; 0 is returned when nothing non-zero remains.

    NOTE(review): intended to match the order of magnitude that
    ScalarFormatter applies to the tick labels — confirm against the installed
    Matplotlib version.
    """
    lower, upper = sorted(view_limits)
    ticks = np.asarray(ticks, dtype=float)
    magnitudes = np.abs(ticks[(ticks >= lower) & (ticks <= upper)])

    if magnitudes.size == 0 or magnitudes.max() == 0.0:
        return 0

    return int(np.floor(np.log10(magnitudes.max())))


# The baseline labels only depend on the date, so they are loaded once per
# date instead of once per (perturbation, date) pair.
baseline_labels = {}

# One scatter plot per (perturbation, date): baseline CCN on the x axis, the
# change caused by the perturbation on the y axis, coloured by time.
for perturbation in perturbations:
    for dt in dates:
        if dt not in baseline_labels:
            baseline_labels[dt] = _load_ccn_labels(dt, Path("baseline"))

        Y_base = baseline_labels[dt]
        Y_pert = _load_ccn_labels(dt, Path("perturbation") / perturbation)

        fig, ax = plt.subplots(1, 1, figsize=(2, 2))

        try:
            ax.scatter(
                Y_base["ccn"], Y_pert["ccn"]-Y_base["ccn"], s=1,
                c=Y_pert.index.get_level_values(0), cmap=time_cmap,
                rasterized=True,
            )

            # Zero line across the current x range, without changing that range
            xlim = ax.get_xlim()
            ax.plot(xlim, [0, 0], c="black", lw=1)
            ax.set_xlim(xlim)

            xfmt = ScalarFormatterForceFormat(useOffset=False, useMathText=True)
            xfmt.set_powerlimits((-1, 1))
            ax.xaxis.set_major_formatter(xfmt)

            # Replace the offset text by an annotation inside the axes
            ax.xaxis.get_offset_text().set_visible(False)
            exponent_axis = _visible_tick_exponent(ax.get_xticks(), ax.get_xlim())
            ax.annotate(
                fr"$\times 10^{{{exponent_axis}}}$", xy=(0.97, 0.03),
                xycoords="axes fraction", ha="right", va="bottom", path_effects=outlined,
            )

            yfmt = ScalarFormatterForceSignedFormat(useOffset=False, useMathText=True)
            yfmt.set_powerlimits((-1, 1))
            ax.yaxis.set_major_formatter(yfmt)

            ax.yaxis.get_offset_text().set_visible(False)
            exponent_axis = _visible_tick_exponent(ax.get_yticks(), ax.get_ylim())
            ax.annotate(
                fr"$\times 10^{{{exponent_axis}}}$", xy=(0.03, 0.97),
                xycoords="axes fraction", ha="left", va="top", path_effects=outlined,
            )

            ax.set_title(dt.strftime("%d.%m.%Y %H:%M"))

            # as_posix() keeps the '/' separator on every platform
            pp = perturbation.as_posix().replace('/', '-').replace('_', '-')
            # NOTE(review): the ':' in this name is not a valid file name
            # character on Windows; kept to preserve the existing file names
            tp = dt.strftime('%d.%m.%Y:%H.%M')

            fig.savefig(
                f"perturbation-{tp}-{pp}.pdf", dpi=100, transparent=True, bbox_inches='tight',
            )
        finally:
            # Release the figure even if plotting or saving failed
            plt.close(fig)