In [1]:
import datetime
import hashlib
import itertools
import json
import os
import pickle
import re
import warnings

from collections import namedtuple
from copy import copy
from pathlib import Path

In [2]:
import cartopy.feature as cfeature
import cartopy.crs as ccrs

import joblib

import matplotlib as mpl
from matplotlib import gridspec
from matplotlib import pyplot as plt

import netCDF4
from netCDF4 import Dataset

import numpy as np

import pandas as pd

import scipy as sp

import sklearn
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.preprocessing import StandardScaler

import tqdm

In [3]:
import colorcet as cc

# Colour maps from colorcet: time_cmap colours points by output time step,
# temporal_cmap supplies the line colours of the neighbouring trajectories.
time_cmap, temporal_cmap = cc.m_CET_L19, cc.m_CET_L8

In [4]:
# Colour map for the SOSAA box levels: sampled once per level height in the CCN plots.
height_cmap = mpl.cm.rainbow

In [5]:
# Paths of the output file and of the aerosol / anthropogenic / biogenic emission
# and meteorology input files of one trajectory run, together with its date.
TrajectoryPaths = namedtuple("TrajectoryPaths", "date out aer ant bio met")
# Same fields as TrajectoryPaths, but holding the opened netCDF4 datasets.
TrajectoryDatasets = namedtuple("TrajectoryDatasets", "date out aer ant bio met")
# A prepared ML dataset: raw features/labels, train/validation/test splits and
# the X/Y scalers.
MLDataset = namedtuple(
    "MLDataset",
    "date paths X_raw Y_raw "
    "X_train X_valid X_test "
    "Y_train Y_valid Y_test "
    "X_scaler Y_scaler",
)

In [6]:
# Output directory names look like "YYYYMMDD_THH"; the four groups are parsed as
# year, month, day and hour. `.match` anchors only at the start of the name.
OUTDIR_PATTERN = re.compile(
    r"(\d{4})"    # year
    r"(\d{2})"    # month
    r"(\d{2})"    # day
    r"_T(\d{2})"  # hour
)

In [7]:
# Discover every baseline trajectory run (output directories named "YYYYMMDD_THH")
# and the input files that belong to it. Fails loudly, listing the missing files,
# if any output or input file of a run is absent.
traj_datetimes = dict()

base = Path.cwd().parent / "trajectories"

# Sorted so that discovery (and the first reported error) is deterministic.
for child in sorted((base / "outputs" / "baseline").iterdir()):
    if not child.is_dir():
        continue

    match = OUTDIR_PATTERN.match(child.name)
    if match is None:
        continue

    year, month, day, hour = (int(group) for group in match.groups())
    date = datetime.datetime(year=year, month=month, day=day, hour=hour)

    ymd = date.strftime('%Y%m%d')
    # Input file names carry `24 - hour`, zero-padded (hour 0 -> "24").
    run_tag = f"{24 - date.hour:02}"
    input_dir = base / "inputs" / "baseline" / "HYDE_BASE_Y2018" / f"OUTPUT_bwd_{ymd}"
    emissions_dir = input_dir / "EMISSIONS_0422"

    paths = TrajectoryPaths(
        date=date,
        out=child / "output.nc",
        aer=emissions_dir / f"{ymd}_7daybwd_Hyde_traj_AER_{run_tag}_L3.nc",
        ant=emissions_dir / f"{ymd}_7daybwd_Hyde_traj_ANT_{run_tag}_L3.nc",
        bio=emissions_dir / f"{ymd}_7daybwd_Hyde_traj_BIO_{run_tag}_L3.nc",
        met=input_dir / "METEO" / f"METEO_{ymd}_R{run_tag}.nc",
    )

    missing = [
        str(path)
        for path in (paths.out, paths.aer, paths.ant, paths.bio, paths.met)
        if not path.exists()
    ]
    if missing:
        raise FileNotFoundError(
            f"Incomplete trajectory run {date:%Y-%m-%d %H:00}; missing: {', '.join(missing)}"
        )

    traj_datetimes[date] = paths

# Unique calendar days that have at least one trajectory run.
traj_dates = sorted({d.date() for d in traj_datetimes})

In [8]:
def load_trajectory_dataset(paths: TrajectoryPaths) -> TrajectoryDatasets:
    """Open the output and the four input netCDF files of one trajectory run (read-only).

    Parameters
    ----------
    paths : TrajectoryPaths
        Date and file paths of the run.

    Returns
    -------
    TrajectoryDatasets
        The same date with one open ``netCDF4.Dataset`` per field. The datasets
        stay open; the caller owns them.

    Raises
    ------
    Whatever ``netCDF4.Dataset`` raises for an unreadable file. In that case the
    files already opened by this call are closed before the error propagates.
    """
    import contextlib

    with contextlib.ExitStack() as stack:
        opened = {}
        for field in ("out", "aer", "ant", "bio", "met"):
            dataset = Dataset(getattr(paths, field), "r", format="NETCDF4")
            # Close it again if a later file fails to open.
            stack.callback(dataset.close)
            opened[field] = dataset
        # Every file opened: cancel the pending closes and hand them to the caller.
        stack.pop_all()

    return TrajectoryDatasets(date=paths.date, **opened)

In [9]:
# Plotted region as [lon_min, lon_max, lat_min, lat_max], in data_proj degrees.
extent = [-60, 60, 40, 80]

# Coordinate system of the trajectory lon/lat data.
data_proj = ccrs.PlateCarree()

# Map projection of the trajectory plots.
projection = ccrs.LambertConformal(
    central_longitude=20,
    central_latitude=50,
    standard_parallels=(25, 25),
)

In [10]:
def get_ccn_concentration(ds: TrajectoryDatasets):
    """Return the CCN concentration of one trajectory as a (time, level)-indexed table.

    The particle number concentrations ``nconc_par`` are summed over all size bins
    whose dry diameter ``dp_dry_fs`` exceeds 80e-9 (80 nm, assuming metres).
    ``nconc_par`` is indexed as (time, bin, level); after summing out the bin axis,
    the (time, level) grid is flattened time-major into a single "ccn" column.
    """
    is_ccn_bin = ds.out["dp_dry_fs"][:].data > 80e-9
    (ccn_bins,) = np.nonzero(is_ccn_bin)
    per_bin = ds.out["nconc_par"][:].data[:, ccn_bins, :]
    ccn = per_bin.sum(axis=1)

    n_levels = ds.out["lev"].shape[0]
    n_times = ds.out["time"].shape[0]
    table = pd.DataFrame({
        "time": np.repeat(get_output_time(ds), n_levels),
        "level": np.tile(ds.out["lev"][:].data, n_times),
        "ccn": ccn.flatten(),
    })
    return table.set_index(["time", "level"])

In [11]:
def get_output_time(ds: TrajectoryDatasets):
    """Return the output time axis relative to the trajectory date ``ds.date``.

    ``ds.out["time"]`` values are treated as seconds since the variable's
    ``first_day_of_month`` attribute (format "%Y-%m-%d %H:%M:%S"). Subtracting the
    seconds between that reference and ``ds.date`` makes ``ds.date`` the origin,
    so earlier output steps get negative times.
    """
    time_var = ds.out["time"]
    reference = datetime.datetime.strptime(
        time_var.__dict__["first_day_of_month"], "%Y-%m-%d %H:%M:%S",
    )
    offset_seconds = (ds.date - reference).total_seconds()

    return time_var[:].data - offset_seconds

def interpolate_meteorology_values(ds: TrajectoryDatasets, key: str):
    """Bilinearly interpolate a (time, level) meteorology field onto the output grid.

    Parameters
    ----------
    ds : TrajectoryDatasets
        Opened datasets of one trajectory. ``ds.met[key]`` is read as a 2-D array
        on the grid (``ds.met["time"]``, ``ds.met["lev"]``).
    key : str
        Name of the meteorology variable.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_output_times, n_output_levels)``, in the order of
        get_output_time(ds) and ``ds.out["lev"]``. Points outside the meteorology
        grid are 0.0.

    Notes
    -----
    Uses ``RegularGridInterpolator`` because ``scipy.interpolate.interp2d`` is
    deprecated and removed in recent SciPy releases. Linear interpolation on a
    rectangular grid gives the same values. Unlike interp2d, the query points are
    not re-sorted, so the rows and columns line up with the (time, level) labels
    built by the callers.
    """
    out_t = get_output_time(ds)
    out_h = ds.out["lev"][:].data

    met_t = np.asarray(ds.met["time"][:].data, dtype=float)
    met_h = np.asarray(ds.met["lev"][:].data, dtype=float)
    # np.asarray drops any netCDF mask, as the former interp2d call did.
    met_t_h = np.asarray(ds.met[key][:], dtype=float)

    # RegularGridInterpolator needs monotonic axes; interp2d used to sort them itself.
    t_order = np.argsort(met_t)
    h_order = np.argsort(met_h)
    interpolator = sp.interpolate.RegularGridInterpolator(
        (met_t[t_order], met_h[h_order]),
        met_t_h[np.ix_(t_order, h_order)],
        method="linear", bounds_error=False, fill_value=0.0,
    )

    query_t, query_h = np.meshgrid(out_t, out_h, indexing="ij")
    return interpolator((query_t, query_h))

def interpolate_meteorology_time_values(ds: TrajectoryDatasets, key: str):
    """Interpolate a time-only meteorology series onto the output times and
    broadcast it over all output levels.

    The interpolation is linear in time; values outside the meteorology time range
    are 0.0. The result has shape ``(n_output_times, n_output_levels)``, and every
    column holds the same time series.
    """
    out_t = get_output_time(ds)
    n_levels = ds.out["lev"][:].data.shape[0]

    series = sp.interpolate.interp1d(
        x=ds.met["time"][:].data,
        y=ds.met[key][:],
        kind="linear",
        bounds_error=False,
        fill_value=0.0,
    )(out_t)

    return np.tile(series.reshape(-1, 1), (1, n_levels))

def interpolate_biogenic_emissions(ds: TrajectoryDatasets, key: str, max_height: float = 10.0):
    """Interpolate a biogenic emission series onto the output times and split it
    over the lowest output levels.

    Parameters
    ----------
    ds : TrajectoryDatasets
        Opened datasets of one trajectory. ``ds.bio[key]`` is read as a 1-D series
        on ``ds.bio["time"]``.
    key : str
        Name of the biogenic emission variable.
    max_height : float, default 10.0
        Output levels with height <= ``max_height`` receive the emission.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_output_times, n_output_levels)``. The emission is linearly
        interpolated in time (0.0 outside the emission record). Each emitting level
        gets a share proportional to its box depth, so the shares add up to the
        interpolated emission. Higher levels get 0.

    Raises
    ------
    ValueError
        If no output level lies at or below ``max_height``.

    Notes
    -----
    Earlier versions weighted levels by the normalised *cumulative* depth. Those
    weights summed to more than 1, and the highest emitting level received the full
    emission. Features computed with this version therefore differ from earlier runs.
    """
    out_t = get_output_time(ds)
    out_h = ds.out["lev"][:].data

    # Box depths, assuming level heights are box midpoints and clamping the outer
    # boxes: depth[i] = (h[i+1] - h[i-1]) / 2, with h[-1] := h[0] and h[n] := h[n-1].
    upper = np.concatenate((out_h[1:], out_h[-1:]))
    lower = np.concatenate((out_h[:1], out_h[:-1]))
    out_d = (upper - lower) / 2.0

    emitting = out_h <= max_height
    if not np.any(emitting):
        raise ValueError(
            f"no output level at or below {max_height} m to receive biogenic emission {key!r}"
        )

    depths = out_d[emitting]
    total_depth = depths.sum()
    if total_depth > 0:
        weights = depths / total_depth
    else:
        # Degenerate grid (e.g. a single level): split evenly.
        weights = np.full(depths.shape, 1.0 / depths.size)

    bio_t_int = sp.interpolate.interp1d(
        x=ds.bio["time"][:].data, y=ds.bio[key][:], kind="linear",
        bounds_error=False, fill_value=0.0,
    )
    emission = bio_t_int(out_t)  # one value per output time

    bio_t_h = np.zeros(shape=(out_t.size, out_h.size))
    bio_t_h[:, emitting] = np.outer(emission, weights)

    return bio_t_h

def interpolate_aerosol_emissions(ds: TrajectoryDatasets, key: str):
    """Bilinearly interpolate an aerosol emission field onto the output (time, level) grid.

    Parameters
    ----------
    ds : TrajectoryDatasets
        Opened datasets of one trajectory. ``ds.aer[key]`` is transposed so that
        rows follow ``ds.aer["time"]`` and columns follow ``ds.aer["mid_layer_height"]``.
    key : str
        Name of the aerosol emission variable (a size bin such as "3-10nm").

    Returns
    -------
    numpy.ndarray
        Shape ``(n_output_times, n_output_levels)``, in the order of
        get_output_time(ds) and ``ds.out["lev"]``. Points outside the emission grid
        are 0.0.

    Notes
    -----
    Uses ``RegularGridInterpolator`` because ``scipy.interpolate.interp2d`` is
    deprecated and removed in recent SciPy releases. Linear interpolation on a
    rectangular grid gives the same values. The query points are no longer re-sorted.
    """
    out_t = get_output_time(ds)
    out_h = ds.out["lev"][:].data

    aer_t = np.asarray(ds.aer["time"][:].data, dtype=float)
    aer_h = np.asarray(ds.aer["mid_layer_height"][:].data, dtype=float)
    # np.asarray drops any netCDF mask, as the former interp2d call did.
    aer_t_h = np.asarray(ds.aer[key][:], dtype=float).T

    # RegularGridInterpolator needs monotonic axes; interp2d used to sort them itself.
    t_order = np.argsort(aer_t)
    h_order = np.argsort(aer_h)
    interpolator = sp.interpolate.RegularGridInterpolator(
        (aer_t[t_order], aer_h[h_order]),
        aer_t_h[np.ix_(t_order, h_order)],
        method="linear", bounds_error=False, fill_value=0.0,
    )

    query_t, query_h = np.meshgrid(out_t, out_h, indexing="ij")
    return interpolator((query_t, query_h))

def interpolate_anthropogenic_emissions(ds: TrajectoryDatasets, key: str):
    """Bilinearly interpolate an anthropogenic emission field onto the output (time, level) grid.

    Parameters
    ----------
    ds : TrajectoryDatasets
        Opened datasets of one trajectory. ``ds.ant[key]`` is transposed so that
        rows follow ``ds.ant["time"]`` and columns follow ``ds.ant["mid_layer_height"]``.
    key : str
        Name of the anthropogenic emission variable.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_output_times, n_output_levels)``, in the order of
        get_output_time(ds) and ``ds.out["lev"]``. Points outside the emission grid
        are 0.0.

    Notes
    -----
    Uses ``RegularGridInterpolator`` because ``scipy.interpolate.interp2d`` is
    deprecated and removed in recent SciPy releases. Linear interpolation on a
    rectangular grid gives the same values. The query points are no longer re-sorted.
    """
    out_t = get_output_time(ds)
    out_h = ds.out["lev"][:].data

    ant_t = np.asarray(ds.ant["time"][:].data, dtype=float)
    ant_h = np.asarray(ds.ant["mid_layer_height"][:].data, dtype=float)
    # np.asarray drops any netCDF mask, as the former interp2d call did.
    ant_t_h = np.asarray(ds.ant[key][:], dtype=float).T

    # RegularGridInterpolator needs monotonic axes; interp2d used to sort them itself.
    t_order = np.argsort(ant_t)
    h_order = np.argsort(ant_h)
    interpolator = sp.interpolate.RegularGridInterpolator(
        (ant_t[t_order], ant_h[h_order]),
        ant_t_h[np.ix_(t_order, h_order)],
        method="linear", bounds_error=False, fill_value=0.0,
    )

    query_t, query_h = np.meshgrid(out_t, out_h, indexing="ij")
    return interpolator((query_t, query_h))

In [12]:
def get_meteorology_features(ds: TrajectoryDatasets):
    """Build the meteorological features (``met_*`` columns) of one trajectory.

    Profile variables go through interpolate_meteorology_values. Time-only variables
    go through interpolate_meteorology_time_values and are broadcast over all levels.
    Every field is flattened time-major onto the output (time, level) grid, and the
    table is indexed by (time, level).
    """
    # (column, meteorology variable, interpolator), in output column order.
    # Disabled candidates: profiles u, v, qc, omega, mla and time series
    # sp, cp, sshf, lsp, ewss, nsss, tcc, z.
    # NOTE: lp is excluded because it allows the model to overfit
    selected = [
        ("met_t", "t", interpolate_meteorology_values),
        ("met_q", "q", interpolate_meteorology_values),
        ("met_ssr", "ssr", interpolate_meteorology_time_values),
        ("met_lsm", "lsm", interpolate_meteorology_time_values),
        ("met_blh", "blh", interpolate_meteorology_time_values),
    ]

    columns = {
        "time": np.repeat(get_output_time(ds), ds.out["lev"].shape[0]),
        "level": np.tile(ds.out["lev"][:].data, ds.out["time"].shape[0]),
    }
    for column, variable, interpolate in selected:
        columns[column] = interpolate(ds, variable).flatten()

    return pd.DataFrame(columns).set_index(["time", "level"])

def get_bio_emissions_features(ds: TrajectoryDatasets):
    """Build the biogenic emission features (``bio_*`` columns) of one trajectory.

    Each species is spread onto the output (time, level) grid by
    interpolate_biogenic_emissions and flattened time-major. The table is indexed
    by (time, level).
    """
    species = [
        "acetaldehyde", "acetone", "butanes-and-higher-alkanes", "butenes-and-higher-alkenes",
        "CH4", "CO", "ethane", "ethanol", "ethene", "formaldehyde", "hydrogen-cyanide",
        "isoprene", "MBO", "methanol", "methyl-bromide", "methyl-chloride", "methyl-iodide",
        "other-aldehydes", "other-ketones", "other-monoterpenes", "pinene-a", "pinene-b",
        "propane", "propene", "sesquiterpenes", "toluene", "CH2Br2", "CH3I", "CHBr3", "DMS",
    ]
    # Columns are named "bio_" + the lower-cased key with "-" replaced by "_", except
    # for two historical names. They look like typos but are kept unchanged, since
    # downstream code may rely on them (NOTE(review): confirm before renaming).
    irregular_names = {
        "butenes-and-higher-alkenes": "bio_butanes_and_higher_alkenes",
        "isoprene": "bio_iosprene",
    }

    columns = {
        "time": np.repeat(get_output_time(ds), ds.out["lev"].shape[0]),
        "level": np.tile(ds.out["lev"][:].data, ds.out["time"].shape[0]),
    }
    for key in species:
        name = irregular_names.get(key, "bio_" + key.replace("-", "_").lower())
        columns[name] = interpolate_biogenic_emissions(ds, key).flatten()

    return pd.DataFrame(columns).set_index(["time", "level"])

def get_aer_emissions_features(ds: TrajectoryDatasets):
    """Build the aerosol emission features (``aer_*`` columns) of one trajectory.

    Each size bin "A-Bnm" becomes column "aer_A_B_nm". It is interpolated onto the
    output (time, level) grid by interpolate_aerosol_emissions and flattened
    time-major. The table is indexed by (time, level).
    """
    size_bins = [
        "3-10nm", "10-20nm", "20-30nm", "30-50nm", "50-70nm",
        "70-100nm", "100-200nm", "200-400nm", "400-1000nm",
    ]

    columns = {
        "time": np.repeat(get_output_time(ds), ds.out["lev"].shape[0]),
        "level": np.tile(ds.out["lev"][:].data, ds.out["time"].shape[0]),
    }
    for size_bin in size_bins:
        # "3-10nm" -> "aer_3_10_nm"
        name = "aer_" + size_bin[:-2].replace("-", "_") + "_nm"
        columns[name] = interpolate_aerosol_emissions(ds, size_bin).flatten()

    return pd.DataFrame(columns).set_index(["time", "level"])

def get_ant_emissions_features(ds: TrajectoryDatasets):
    """Build the anthropogenic emission features (``ant_*`` columns) of one trajectory.

    Each species key becomes column "ant_" + key, lower-cased with "-" replaced by
    "_". It is interpolated onto the output (time, level) grid by
    interpolate_anthropogenic_emissions and flattened time-major. The table is
    indexed by (time, level).
    """
    species = [
        "co", "nox", "co2", "nh3", "ch4", "so2", "nmvoc", "alcohols",
        "ethane", "propane", "butanes", "pentanes", "hexanes", "ethene", "propene",
        "acetylene", "isoprene", "monoterpenes", "other-alkenes-and-alkynes",
        "benzene", "toluene", "xylene", "trimethylbenzene", "other-aromatics",
        "esters", "ethers", "formaldehyde", "other-aldehydes", "total-ketones",
        "total-acids", "other-VOCs",
    ]

    columns = {
        "time": np.repeat(get_output_time(ds), ds.out["lev"].shape[0]),
        "level": np.tile(ds.out["lev"][:].data, ds.out["time"].shape[0]),
    }
    for key in species:
        columns["ant_" + key.replace("-", "_").lower()] = (
            interpolate_anthropogenic_emissions(ds, key).flatten()
        )

    return pd.DataFrame(columns).set_index(["time", "level"])

In [13]:
# https://stackoverflow.com/a/67809235
def df_to_numpy(df):
    """Reshape a DataFrame into an N-D array with one axis per index level.

    A MultiIndex contributes one axis per level, sized by the number of distinct
    values in that level. Any other index contributes a single axis of len(df).
    A trailing column axis is added only when there is more than one column.
    Rows must already be in full Cartesian-product order (first level slowest).
    """
    index = df.index
    if hasattr(index, "levels"):
        dims = [len(values) for values in index.levels]
    else:
        dims = [len(index)]

    n_columns = df.shape[-1]
    if n_columns > 1:
        dims.append(n_columns)

    return df.to_numpy().reshape(dims)

In [14]:
def get_raw_features_for_dataset(ds: TrajectoryDatasets):
    """Return all raw features of one trajectory as one (time, level)-indexed table.

    Columns come in the order biogenic, aerosol, anthropogenic, meteorology. The
    aerosol emissions are multiplied by 1e21 before joining.
    """
    feature_groups = [
        get_bio_emissions_features(ds),
        get_aer_emissions_features(ds) * 1e21,
        get_ant_emissions_features(ds),
        get_meteorology_features(ds),
    ]
    return pd.concat(feature_groups, axis="columns")

In [15]:
def get_labels_for_dataset(ds: TrajectoryDatasets):
    """Return the CCN labels of one trajectory without its first 96 output steps.

    The CCN table from get_ccn_concentration is viewed as a (time, level) grid, with
    the time and level coordinates stacked alongside the values. The first 96 time
    steps are dropped, and the rest is turned back into a table with a
    (time, level) MultiIndex and a single "ccn" column.
    """
    ccn = get_ccn_concentration(ds)
    ccn_grid = df_to_numpy(ccn)

    grid_shape = (ccn.index.levels[0].size, ccn.index.levels[1].size, 1)
    time_grid = ccn.index.get_level_values(0).to_numpy().reshape(grid_shape)
    level_grid = ccn.index.get_level_values(1).to_numpy().reshape(grid_shape)
    value_grid = ccn_grid.reshape((ccn_grid.shape[0], ccn_grid.shape[1], 1))

    # Trim off the first two days, for which the time features are ill-defined
    kept = np.concatenate([time_grid, level_grid, value_grid], axis=2)[96:, :, :]

    n_rows = kept.shape[0] * kept.shape[1]
    table = pd.DataFrame(
        kept.reshape(n_rows, kept.shape[2]), columns=["time", "level", "ccn"],
    )
    return table.set_index(["time", "level"])

In [16]:
def plot_feature_violin(ax, X_raw, f, title, unit, zero, n_samples=10000):
    """Plot the distribution of one feature as a violin-shaped cloud of points.

    The non-zero values of column ``f`` of ``X_raw``, shifted by ``zero``, define
    the violin outline. The violin body itself is removed, so only the points are
    drawn. ``n_samples`` values are drawn from them with replacement, using the
    global NumPy RNG, and scattered inside the outline. Each point is coloured with
    ``time_cmap`` by its time value's position in the index.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    X_raw : pandas.DataFrame
        Features with a (time, level) MultiIndex.
    f : str
        Column to plot.
    title, unit : str
        Axes title and y-axis label.
    zero : float
        Offset added to every value before plotting (e.g. -273.15 for K -> degC).
    n_samples : int, default 10000
        Number of points drawn.

    If the column has no non-zero values, only the title and labels are set.
    Previously that case crashed inside ``violinplot``.
    """
    column = X_raw[f]
    values = column.to_numpy()
    nonzero = np.nonzero(values)
    shown = values[nonzero] + zero

    ax.set_title(title)
    ax.get_xaxis().set_visible(False)
    ax.set_ylabel(unit)

    if shown.size == 0:
        # violinplot cannot estimate a density from an empty sample.
        return

    violin = ax.violinplot(shown, showextrema=False)
    body = violin["bodies"][0]
    # Only the outline is needed for placing the points; the filled body is not drawn.
    body.remove()
    outline = body.get_paths()[0]

    time_levels = column.index.levels[0]
    time_colours = time_cmap(np.linspace(0, 1, time_levels.size))
    point_times = column.index.get_level_values(0).to_numpy()[nonzero]

    picks = np.random.choice(len(shown), size=n_samples)
    point_values = shown[picks]
    point_colours = time_colours[np.searchsorted(time_levels, point_times[picks])]

    # Horizontal position of each point, centred on x = 1. Try up to 10 random
    # offsets, narrowing the spread to 4x the last rejected offset each time. The
    # last attempt is kept even if it never fell inside the outline.
    positions = []
    for value in point_values:
        spread = 1.0
        for _ in range(10):
            x_pos = (np.random.random() - 0.5) * 0.5 * spread + 1.0
            if outline.contains_point((x_pos, value)):
                break
            spread = abs(x_pos - 1.0) * 4.0
        positions.append(x_pos)

    ax.scatter(positions, point_values, s=1, c=point_colours, rasterized=True)

In [17]:
np.random.seed(42)

# Arrival hours, counted from 2018-05-09 00:00, of the central trajectories to plot.
CENTRAL_HOURS = [130, 163, 192, 244, 303, 349]
# Arrival-hour offsets and line widths of the trajectories drawn on each map. The
# central trajectory (offset 0) comes last, so after the inner loop `ds`, `dt`,
# `X_raw` and `Y_raw` all refer to it.
TRAJECTORY_OFFSETS = [-4, 4, -2, 2, -1, 1, 0]
TRAJECTORY_WIDTHS = [3, 3, 5, 5, 7, 7, 10]
# Legend entries ordered t-4h ... t+4h, as indices into the drawing order above.
LEGEND_ORDER = [0, 2, 4, 6, 5, 3, 1]
LEGEND_LABELS = ["$t-4$h", "$t-2$h", "$t-1$h", "$t$", "$t+1$h", "$t+2$h", "$t+4$h"]

# Line colours: symmetric offset pairs share mirrored positions in temporal_cmap;
# the central trajectory is black.
hcs = temporal_cmap(np.linspace(0.1, 0.9, 9))
hcs = list(hcs[[0, 8, 2, 6, 3, 5]]) + ['black']


def sum_pf_series(X, fs):
    """Return the element-wise sum of the columns ``fs`` of DataFrame ``X``.

    Raises ValueError when ``fs`` is empty.
    """
    if not fs:
        raise ValueError("sum_pf_series needs at least one column")
    total = X[fs[0]].copy()
    for name in fs[1:]:
        total = total + X[name]
    return total


# Species summed into the "monoterpenes" panel.
monoterpenes = ['bio_other_monoterpenes', 'bio_pinene_a', 'bio_pinene_b']
# Species left out of the "anthropogenic" and "biogenic" totals.
exclude = [
    'ant_ch4', 'ant_so2', 'ant_nh3', 'ant_nox', 'ant_co', 'bio_ch4', 'bio_dms',
    'bio_co', 'bio_chbr3', 'bio_ch3i', 'bio_ch2br2', 'bio_other_monoterpenes',
    'bio_pinene_a', 'bio_pinene_b', 'bio_sesquiterpenes',
]

EMISSION_UNIT_E15 = r"log$_{10}(E \times 10^{15}$ kg m$^{-2}$ s$^{-1})$"
EMISSION_UNIT_E21 = r"log$_{10}(E \times 10^{21}$ kg m$^{-2}$ s$^{-1})$"
# (summary column, panel title, y-axis unit, offset added before plotting)
FEATURE_PANELS = [
    ("anthropogenic", "Anthropogenic Emissions", EMISSION_UNIT_E15, 0.0),
    ("biogenic", "Non-terpene Biogenic Emissions", EMISSION_UNIT_E15, 0.0),
    ("aerosols", "Aerosol Emissions", EMISSION_UNIT_E21, 0.0),
    ("monoterpenes", "Monoterpene Emissions", EMISSION_UNIT_E15, 0.0),
    ("sesquiterpenes", "Sesquiterpene Emissions", EMISSION_UNIT_E15, 0.0),
    ("so2", "Anthropogenic SO$_2$ Emissions", EMISSION_UNIT_E15, 0.0),
    ("nox", "Anthropogenic NO$_x$ Emissions", EMISSION_UNIT_E15, 0.0),
    # Raw string: "\d" is an invalid escape sequence in a normal string literal.
    ("temperature", "Air Temperature", r"$\degree$C", -273.15),
]

for ht in tqdm.tqdm(CENTRAL_HOURS):
    # ---- Map: the central trajectory and its neighbours ----
    fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=100)
    ax.remove()
    ax = fig.add_subplot(1, 1, 1, projection=projection)

    ax.set_extent(extent, crs=data_proj)
    ax.coastlines(linewidth=0.5, color="k", resolution="110m", rasterized=True)
    ax.gridlines(alpha=0.75, draw_labels=True, rasterized=True)
    ax.add_feature(cfeature.LAND, rasterized=True)
    ax.add_feature(cfeature.OCEAN, rasterized=True)

    trajs = []

    for offset, hc, lw in zip(TRAJECTORY_OFFSETS, hcs, TRAJECTORY_WIDTHS):
        h = ht + offset
        dt = datetime.datetime(year=2018, month=5, day=9 + h // 24, hour=h % 24)

        ds = load_trajectory_dataset(traj_datetimes[dt])

        X_raw = get_raw_features_for_dataset(ds)
        # View the features as a (time, level, column) cube whose first two columns
        # are the time and level coordinates. Keep output steps 95 .. T-2, one step
        # earlier than the labels, which get_labels_for_dataset keeps from step 96.
        grid_shape = (X_raw.index.levels[0].size, X_raw.index.levels[1].size, 1)
        X_raw_np = np.concatenate([
            X_raw.index.get_level_values(0).to_numpy().reshape(grid_shape),
            X_raw.index.get_level_values(1).to_numpy().reshape(grid_shape),
            df_to_numpy(X_raw),
        ], axis=2)[95:-1]
        X_raw = pd.DataFrame(
            X_raw_np.reshape(X_raw_np.shape[0] * X_raw_np.shape[1], X_raw_np.shape[2]),
            columns=["time", "level"] + list(X_raw.columns),
        ).set_index(["time", "level"])

        Y_raw = get_labels_for_dataset(ds)
        n_steps = Y_raw.index.levels[0].size

        # Draw only the last n_steps positions, the part of the track covered by the labels.
        trajs.append(ax.plot(
            ds.out["lon"][-n_steps:],
            ds.out["lat"][-n_steps:],
            lw=lw,
            transform=data_proj,
            c=hc,
            zorder=-1,
            solid_capstyle='round',
            rasterized=True,
        )[0])

    # Dots along the central track, coloured by output time step.
    colours = time_cmap(np.linspace(0, 1, n_steps))

    ax.scatter(
        ds.out["lon"][-n_steps:],
        ds.out["lat"][-n_steps:],
        s=10,
        transform=data_proj,
        c=colours,
        rasterized=True,
    )

    ax.text(
        0.5, 1.0, ds.date.strftime("%d.%m.%Y %H:00%z"), ha="center", va="top",
        size=20, c="black", bbox=dict(facecolor='white', alpha=0.5, edgecolor='black'),
    )

    ax.legend(
        [trajs[i] for i in LEGEND_ORDER], LEGEND_LABELS,
        loc='upper center', ncol=7, handlelength=0.25, columnspacing=0.8,
    )

    stem = f"trajectory-{dt.strftime('%d.%m.%Y:%H.%M')}"

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        plt.savefig(f"{stem}.pdf", dpi=100, transparent=True, bbox_inches='tight')
        plt.close(fig)

    # ---- Feature distributions of the central trajectory ----
    X_raw["anthropogenic"] = np.log10(
        sum_pf_series(X_raw, [f for f in X_raw.columns if f.startswith("ant_") and f not in exclude]) * 1e15 + 1
    )
    X_raw["biogenic"] = np.log10(
        sum_pf_series(X_raw, [f for f in X_raw.columns if f.startswith("bio_") and f not in exclude]) * 1e15 + 1
    )
    # Sum of the aer_* size bins. These columns were already multiplied by 1e21 in
    # get_raw_features_for_dataset, which matches the 10^21 unit label of the panel.
    X_raw["aerosols"] = np.log10(
        sum_pf_series(X_raw, [f for f in X_raw.columns if f.startswith("aer_")]) + 1
    )
    X_raw["monoterpenes"] = np.log10(
        sum_pf_series(X_raw, monoterpenes) * 1e15 + 1
    )
    X_raw["sesquiterpenes"] = np.log10(X_raw["bio_sesquiterpenes"] * 1e15 + 1)
    X_raw["so2"] = np.log10(X_raw["ant_so2"] * 1e15 + 1)
    X_raw["nox"] = np.log10(X_raw["ant_nox"] * 1e15 + 1)
    X_raw["temperature"] = X_raw["met_t"]

    for column, title, unit, zero in FEATURE_PANELS:
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))

        plot_feature_violin(ax, X_raw, column, title, unit, zero)

        plt.savefig(f"{stem}-{column}.pdf", dpi=100, transparent=True, bbox_inches='tight')
        plt.close(fig)

    # ---- CCN concentration time series, one line per box level ----
    ccn_concentration = Y_raw

    level_mask = ccn_concentration.index.get_level_values(1)
    level_heights = ccn_concentration.index.levels[1]

    lcolours = height_cmap(np.linspace(0, 1, len(level_heights)))

    fig, ax1 = plt.subplots(1, 1, figsize=(6, 4))

    for i_level, height in enumerate(level_heights):
        at_level = ccn_concentration[level_mask == height]
        ax1.plot(
            at_level.index.get_level_values(0) / (60*60),
            at_level["ccn"], c=lcolours[i_level],
        )

    ax1.set_yscale("log")

    # Mark each output time step along the bottom edge with its time colour.
    ylim = ax1.get_ylim()

    ax1.scatter(
        ccn_concentration.index.levels[0] / (60*60),
        [ylim[0]] * ccn_concentration.index.levels[0].size,
        c=colours,
    )

    ax1.set_ylim(ylim)

    # Ticks every 24 h from the first labelled hour up to, but excluding, hour 0.
    tick_hours = list(range(int(ccn_concentration.index.levels[0][0] // (60*60)), 0, 24))
    ax1.set_xticks(tick_hours)
    ax1.set_xticklabels([
        (ds.date + datetime.timedelta(hours=hour)).strftime("%d.%m") for hour in tick_hours
    ])

    ax1.text(
        0.05, 0.075, ds.date.strftime("%d.%m.%Y %H:00%z"), ha="left", va="bottom",
        size=20, c="black", bbox=dict(facecolor='white', alpha=0.5, edgecolor='black'),
        transform=ax1.transAxes,
    )

    ax1.set_ylabel("CCN concentration [m${}^{-3}$]")

    # Secondary axis: one coloured marker per box level height on the right edge.
    ax2 = ax1.twinx()

    xlim = ax1.get_xlim()

    ax2.scatter(
        [xlim[1]] * X_raw.index.levels[1].size,
        X_raw.index.levels[1],
        c=lcolours,
    )

    ax1.set_xlim(xlim)
    ax2.set_xlim(xlim)

    ax2.invert_yaxis()
    ax2.set_yticks(X_raw.index.levels[1])
    ax2.set_yscale("log")

    ax2.set_ylabel("SOSAA box level height [m]")

    plt.savefig(f"{stem}-ccn.pdf", dpi=100, transparent=True, bbox_inches='tight')
    plt.close(fig)

100%|███████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:53<00:00,  8.86s/it]
