In [1]:
import datetime
import os

from pathlib import Path

import adjustText
import matplotlib as mpl
import numpy as np
import pandas as pd

from matplotlib import pyplot as plt

In [2]:
import matplotlib.patheffects

outlined = [
    matplotlib.patheffects.Stroke(linewidth=3, foreground="white"),
    matplotlib.patheffects.Normal(),
]

In [3]:
class TruncatedColormap(mpl.colors.Colormap):
    def __init__(self, cmap):
        self.cmap = cmap
        self.N = cmap.N

    def __call__(self, X, alpha=None, bytes=False):
        return self.cmap.__call__(X/1.1, alpha=alpha, bytes=bytes)

In [4]:
import colorcet as cc

traj_cmap = matplotlib.cm.tab10
error_cmap = TruncatedColormap(cc.m_CET_L8)
r2_cmap = TruncatedColormap(cc.m_CET_L19_r)

In [5]:
def format_perturbation(f):
    return f.replace(
        "anthropogenic", "ant",
    ).replace(
        "biogenic", "bio",
    ).replace(
        "aerosols", "aer",
    ).replace(
        "monoterpenes", "mtp",
    ).replace(
        "sesquiterpenes", "sqt",
    ).replace(
        "so2", "SO$_2$",
    ).replace(
        "nox", "NO$_x$",
    ).replace(
        "temperature", "$T$",
    ).replace(
        "/", " "
    ).replace(
        "_", " "
    ).replace(
        "mul", r"$\times$"
    ).replace(
        "div", r"$\div$"
    ).replace(
        "add", "$+$"
    ).replace(
        "sub", "$-$"
    )

In [6]:
dts = [
    "2018-05-14 10:00:00", "2018-05-15 19:00:00", "2018-05-17 00:00:00",
    "2018-05-19 04:00:00", "2018-05-21 15:00:00", "2018-05-23 13:00:00",
]

In [7]:
os.chdir("../../thesis/evaluation/results/")

In [8]:
def plot_temporal_generalisation(path: Path, title: str):
    df = pd.read_csv(path)
    
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))

    ax.set_title(f"{title} Performance on Time-Adjacent Trajectories")
    ax.set_ylabel(r"mean absolute error, over $\log_{10}(CCN)$")

    markers = ['o', '^', 'D', 's', '*', 'X']
    markersizes = [6, 7, 5, 6, 9, 7]
    
    for i, dt in enumerate(dts):
        dft = df[df["model_date"] == dt]

        dt = datetime.datetime.fromisoformat(dt).strftime('%d.%m %H:%M')

        ax.fill_between(
            dft["data_hour_offset"], dft["mae"] - dft["mae_stdv"], dft["mae"] + dft["mae_stdv"],
            color=traj_cmap(i), alpha=0.35,
        )

    for i, dt in enumerate(dts):
        dft = df[df["model_date"] == dt]

        dt = datetime.datetime.fromisoformat(dt).strftime('%d.%m %H:%M')

        ax.plot(
            dft["data_hour_offset"], dft["mae"], label=dt,
            marker=markers[i], markersize=markersizes[i], c=traj_cmap(i),
        )
        
    texts = []

    for i, dt in enumerate(dts):
        dft = df[df["model_date"] == dt]

        for _, (offset, mae, conf) in dft[["data_hour_offset", "mae", "mae_conf"]].iterrows():
            if np.round(conf, 2) < 0.99:
                texts.append(ax.text(
                    offset, mae, f"{conf:.0%}", va="center", ha="center",
                    path_effects=outlined, c=traj_cmap(i),
                ))

    ax.set_xticks([-4, -2, -1, 0, 1, 2, 4])
    ax.set_xticklabels([
        "$t-4$h", "$t-2$h", "$t-1$h", "t", "$t+1$h", "$t+2$h", "$t+4$h"
    ])
    
    if ax.get_ylim()[0] < 0.1:
        ax.set_ylim((0.0, ax.get_ylim()[1]))

    ax.legend(loc="upper center", ncol=3, bbox_to_anchor=(0.5, -0.075))
    
    while True:
        bboxes = [text.get_tightbbox(fig.canvas.get_renderer()) for text in texts]

        bad_texts = []
        for i, text in enumerate(texts):
            if bboxes[i].count_overlaps(bboxes[:i] + bboxes[i+1:]) > 0:
                bad_texts.append(text)
        
        if len(bad_texts) <= 1:
            break

        xlim = ax.get_xlim()
        ax.set_xlim((xlim[0]+0.1, xlim[1]-0.1))
        adjustText.adjust_text(
            bad_texts, avoid_self=True, time_lim=1, ax=ax, force_static=(0.0, 0.0),
            force_text=(0.2, 0.4), force_pull=(1, 1),
        )
        ax.set_xlim(xlim)

    plt.savefig(path.with_suffix(".pdf"), dpi=100, transparent=True, bbox_inches='tight')
    # plt.show()
    plt.close(fig)

In [9]:
plot_temporal_generalisation(Path("temporal-generalisation-rf.csv"), "RF")
plot_temporal_generalisation(Path("temporal-generalisation-padre-rf.csv"), "PADRE-RF")

In [10]:
def plot_clumping_generalisation(path: Path, title: str):
    df = pd.read_csv(path)

    fig, ax = plt.subplots(1, 1, figsize=(6, 4))

    ax.set_title(f"{title} Performance for increasing Clumping Factors")
    ax.set_ylabel(r"mean absolute error, over $\log_{10}(CCN)$")
    
    markers = ['o', '^', 'D', 's', '*', 'X']
    markersizes = [6, 7, 5, 6, 9, 7]
    
    for i, dt in enumerate(dts):
        dft = df[df["model_date"] == dt]

        dt = datetime.datetime.fromisoformat(dt).strftime('%d.%m %H:%M')

        ax.fill_between(
            dft["model_clump"], dft["mae"] - dft["mae_stdv"], dft["mae"] + dft["mae_stdv"],
            color=traj_cmap(i), alpha=0.35,
        )

    for i, dt in enumerate(dts):
        dft = df[df["model_date"] == dt]

        dt = datetime.datetime.fromisoformat(dt).strftime('%d.%m %H:%M')

        ax.plot(
            dft["model_clump"], dft["mae"], label=dt,
            marker=markers[i], markersize=markersizes[i], c=traj_cmap(i),
        )

    for i, dt in enumerate(dts):
        dft = df[df["model_date"] == dt]

        for _, (clump, mae, conf) in dft[["model_clump", "mae", "mae_conf"]].iterrows():
            if np.round(conf, 2) < 0.99:
                ax.text(clump, mae, f"{conf:.0%}", va="center", ha="center", path_effects=outlined, c=traj_cmap(i))

    ax.set_xscale("function", functions=(lambda x: -np.log(1-x), lambda x: 1-np.exp(-x)))
    ax.set_xticks([0.0, 0.5, 0.75, 0.85, 0.9])
    ax.set_xticklabels([
        r"$0\%$", r"$50\%$", r"$75\%$", r"$85\%$", r"$90\%$",
    ])
    
    if ax.get_ylim()[0] < 0.1:
        ax.set_ylim((0.0, ax.get_ylim()[1]))

    ax.legend(loc="upper center", ncol=3, bbox_to_anchor=(0.5, -0.075))

    plt.savefig(path.with_suffix(".pdf"), dpi=100, transparent=True, bbox_inches='tight')
    # plt.show()
    plt.close(fig)

In [11]:
plot_clumping_generalisation(Path("clumped-generalisation-rf.csv"), "RF")
plot_clumping_generalisation(Path("clumped-generalisation-padre-rf.csv"), "PADRE-RF")

  ax.set_xscale("function", functions=(lambda x: -np.log(1-x), lambda x: 1-np.exp(-x)))
  ax.set_xscale("function", functions=(lambda x: -np.log(1-x), lambda x: 1-np.exp(-x)))
  ax.set_xscale("function", functions=(lambda x: -np.log(1-x), lambda x: 1-np.exp(-x)))
  ax.set_xscale("function", functions=(lambda x: -np.log(1-x), lambda x: 1-np.exp(-x)))


In [12]:
def plot_trajectory_generalisation(path: Path, title: str):
    df = pd.read_csv(path)

    fig, ax1 = plt.subplots(1, 1, figsize=(6, 4))

    ax1.set_title(f"{title} Cross-Trajectory Mean Absolute Error")
    
    m = ax1.matshow(
        np.ones(shape=(len(dts), len(dts))), cmap="gray", vmin=0.0, vmax=1.0,
    )

    max_mae = df["mae"].max()

    for i, dt in enumerate(dts):
        dft = df[df["model_date"] == dt]

        for j, (_, (mae, mae_stdv, conf)) in enumerate(dft[["mae", "mae_stdv", "mae_conf"]].iterrows()):
            ax1.add_patch(plt.Circle(
                (j, i), 0.45*np.sqrt(conf), facecolor=error_cmap(mae/max_mae),
                edgecolor='black', lw=0.5, zorder=2,
            ))
            
            c = "white" if (mae/max_mae) < 0.2 else "black"
            
            if np.round(conf, 3) > 0.0:
                if (mae_stdv*max_mae/len(dts)) < 0.4:
                    # Scale the error bars to match the colorbar scale
                    mmin = i - mae_stdv*2.0/len(dts)
                    mmax = i + mae_stdv*2.0/len(dts)

                    ax1.plot([j-0.1, j+0.1], [mmin, mmin], c=c, zorder=3)
                    ax1.plot([j-0.1, j+0.1], [mmax, mmax], c=c, zorder=3)
                else:
                    ax1.plot([j-0.1, j+0.1], [i-0.4, i-0.4], c=c, ls=(0,(1,1)), zorder=3)
                    ax1.plot([j-0.1, j+0.1], [i+0.4, i+0.4], c=c, ls=(0,(1,1)), zorder=3)

    cb = fig.colorbar(mpl.cm.ScalarMappable(
        norm=mpl.colors.Normalize(vmin=0.0, vmax=max_mae), cmap=error_cmap,
    ), ax=ax1)
    cb.set_label(r"mean absolute error, over $\log_{10}(CCN)$")

    labels = [
        datetime.datetime.fromisoformat(dt).strftime('%d.%m') for dt in dts
    ]
    axis = np.arange(len(labels))
    ax1.set_xticks(axis)
    ax1.set_yticks(axis)
    ax1.set_xticklabels(labels)
    ax1.set_yticklabels(labels)
    ax1.xaxis.set_ticks_position("bottom")

    ax1.set_xlabel("test trajectory")
    ax1.set_ylabel("model training trajectory")
    
    ax1.grid()

    plt.tight_layout()
    
    plt.savefig(
        path.with_stem(f"{path.stem}-mae").with_suffix(".pdf"),
        dpi=100, transparent=True, bbox_inches='tight',
    )
    # plt.show()
    plt.close(fig)

    """  """

    fig, ax2 = plt.subplots(1, 1, figsize=(6, 4))

    ax2.set_title(f"{title} Cross-Trajectory R$^2$ Score")

    m = ax2.matshow(
        np.ones(shape=(len(dts), len(dts))), cmap="gray", vmin=0.0, vmax=1.0,
    )

    for i, dt in enumerate(dts):
        dft = df[df["model_date"] == dt]

        for j, (_, (r2, r2_stdv, conf)) in enumerate(dft[["r2", "r2_stdv", "r2_conf"]].iterrows()):
            ax2.add_patch(plt.Circle(
                (j, i), 0.45*np.sqrt(conf), facecolor=r2_cmap((r2+1)/2), edgecolor='black', lw=0.5, zorder=2,
            ))
            
            if np.round(conf, 3) > 0.0:
                if (r2_stdv*2.0/len(dts)) < 0.4:
                    # Scale the error bars to match the colorbar scale
                    rmin = i - r2_stdv*0.95*2.0/len(dts)
                    rmax = i + r2_stdv*0.95*2.0/len(dts)

                    ax2.plot([j-0.1, j+0.1], [rmin, rmin], c="black", zorder=3)
                    ax2.plot([j-0.1, j+0.1], [rmax, rmax], c="black", zorder=3)
                else:
                    ax2.plot([j-0.1, j+0.1], [i-0.4, i-0.4], c="black", ls=(0,(1,1)), zorder=3)
                    ax2.plot([j-0.1, j+0.1], [i+0.4, i+0.4], c="black", ls=(0,(1,1)), zorder=3)


    cb = fig.colorbar(mpl.cm.ScalarMappable(
        norm=mpl.colors.Normalize(vmin=-1.0, vmax=1.0), cmap=r2_cmap,
    ), ax=ax2, extend="min")
    cb.set_label(r"R$^2$ score")

    ax2.set_xticks(axis)
    ax2.set_yticks(axis)
    ax2.set_xticklabels(labels)
    ax2.set_yticklabels(labels)
    ax2.xaxis.set_ticks_position("bottom")

    ax2.set_xlabel("test trajectory")
    ax2.set_ylabel("model training trajectory")

    ax2.grid()

    plt.tight_layout()
    
    plt.savefig(
        path.with_stem(f"{path.stem}-r2").with_suffix(".pdf"),
        dpi=100, transparent=True, bbox_inches='tight',
    )
    # plt.show()
    plt.close(fig)

In [13]:
plot_trajectory_generalisation(Path("trajectory-generalisation-rf.csv"), "RF")
plot_trajectory_generalisation(Path("trajectory-generalisation-padre-rf.csv"), "PADRE-RF")

In [14]:
def plot_perturbation_generalisation(path: Path, title: str):
    df = pd.read_csv(path)

    fig, ax1 = plt.subplots(1, 1, figsize=(12, 4))

    ax1.set_title(f"{title} R$^2$ score and Confidence for different small Perturbations")

    xticks = np.arange(16)
    yticks = np.arange(len(dts))

    m = ax1.matshow(
        np.ones(shape=(len(dts), 16)), cmap="gray", vmin=0.0, vmax=1.0,
    )

    for i, dt in enumerate(dts):
        dft = df[df["model_date"] == dt]
        dft = dft[[("1.01" in p) or ("0.04K" in p) for _, p in dft["perturbation"].items()]]

        for j, (_, (r2, r2_stdv, conf)) in enumerate(dft[["r2", "r2_stdv", "r2_conf"]].iterrows()):
            ax1.add_patch(plt.Circle(
                (j, i), 0.45*np.sqrt(conf), facecolor=r2_cmap((r2+5)/6), edgecolor='black', lw=0.5, zorder=2,
            ))

            if np.round(conf, 3) > 0.0:
                if (r2_stdv*6.0/len(dts)) < 0.4:
                    # Scale the error bars to match the colorbar scale
                    rmin = i - r2_stdv*0.95*6.0/len(dts)
                    rmax = i + r2_stdv*0.95*6.0/len(dts)

                    ax1.plot([j-0.1, j+0.1], [rmin, rmin], c="black", zorder=3)
                    ax1.plot([j-0.1, j+0.1], [rmax, rmax], c="black", zorder=3)
                else:
                    ax1.plot([j-0.1, j+0.1], [i-0.4, i-0.4], c="black", ls=(0,(1,1)), zorder=3)
                    ax1.plot([j-0.1, j+0.1], [i+0.4, i+0.4], c="black", ls=(0,(1,1)), zorder=3)

    cb = fig.colorbar(mpl.cm.ScalarMappable(
        norm=mpl.colors.Normalize(vmin=-5.0, vmax=1.0), cmap=r2_cmap,
    ), ax=ax1, extend="min", pad=0.025)
    cb.set_label(r"R$^2$ score")
    cb.set_ticks([-5, -4, -3, -2, -1, 0, 1])

    labels = [
        datetime.datetime.fromisoformat(dt).strftime('%d.%m') for dt in dts
    ]
    axis = np.arange(len(labels))
    ax1.set_xticks(xticks)
    ax1.set_xticklabels([
        format_perturbation(p) for _, p in df["perturbation"][16:32].items()
    ], rotation=45, ha="right", rotation_mode="anchor")
    ax1.xaxis.set_ticks_position("bottom")
    ax1.set_yticks(yticks)
    ax1.set_yticklabels(labels)

    ax1.set_ylabel("model training trajectory")

    ax1.grid()

    plt.tight_layout()
    
    plt.savefig(
        path.with_stem(f"{path.stem}-small").with_suffix(".pdf"),
        dpi=100, transparent=True, bbox_inches='tight',
    )
    # plt.show()
    plt.close(fig)

    """ """

    fig, ax2 = plt.subplots(1, 1, figsize=(12, 4))

    ax2.set_title(f"{title} R$^2$ score and Confidence for different large Perturbations")

    m = ax2.matshow(
        np.ones(shape=(len(dts), 16)), cmap="gray", vmin=0.0, vmax=1.0,
    )

    for i, dt in enumerate(dts):
        dft = df[df["model_date"] == dt]
        dft = dft[[("1.5" in p) or ("2K" in p) for _, p in dft["perturbation"].items()]]

        for j, (_, (r2, r2_stdv, conf)) in enumerate(dft[["r2", "r2_stdv", "r2_conf"]].iterrows()):
            ax2.add_patch(plt.Circle(
                (j, i), 0.45*np.sqrt(conf), facecolor=r2_cmap((r2+5)/6), edgecolor='black', lw=0.5, zorder=2,
            ))

            if np.round(conf, 3) > 0.0:
                if (r2_stdv*6.0/len(dts)) < 0.4:
                    # Scale the error bars to match the colorbar scale
                    rmin = i - r2_stdv*0.95*6.0/len(dts)
                    rmax = i + r2_stdv*0.95*6.0/len(dts)

                    ax2.plot([j-0.1, j+0.1], [rmin, rmin], c="black", zorder=3)
                    ax2.plot([j-0.1, j+0.1], [rmax, rmax], c="black", zorder=3)
                else:
                    ax2.plot([j-0.1, j+0.1], [i-0.4, i-0.4], c="black", ls=(0,(1,1)), zorder=3)
                    ax2.plot([j-0.1, j+0.1], [i+0.4, i+0.4], c="black", ls=(0,(1,1)), zorder=3)

    cb = fig.colorbar(mpl.cm.ScalarMappable(
        norm=mpl.colors.Normalize(vmin=-5.0, vmax=1.0), cmap=r2_cmap,
    ), ax=ax2, extend="min", pad=0.025)
    cb.set_label(r"R$^2$ score")
    cb.set_ticks([-5, -4, -3, -2, -1, 0, 1])

    ax2.set_xticks(xticks)
    ax2.set_xticklabels([
        format_perturbation(p) for _, p in df["perturbation"][0:16].items()
    ], rotation=45, ha="right", rotation_mode="anchor")
    ax2.xaxis.set_ticks_position("bottom")
    ax2.set_yticks(yticks)
    ax2.set_yticklabels(labels)

    ax2.set_ylabel("model training trajectory")

    ax2.grid()

    plt.tight_layout()
    
    plt.savefig(
        path.with_stem(f"{path.stem}-large").with_suffix(".pdf"),
        dpi=100, transparent=True, bbox_inches='tight',
    )
    # plt.show()
    plt.close(fig)

In [15]:
plot_perturbation_generalisation(Path("perturbation-generalisation-rf.csv"), "SOSAA-RF")
plot_perturbation_generalisation(Path("perturbation-generalisation-padre-rf.csv"), "SOSAA-PADRE-RF")
plot_perturbation_generalisation(Path("perturbation-generalisation-padre-rf-direct.csv"), "SOSAA-PADRE-RF-direct")
plot_perturbation_generalisation(Path("perturbation-generalisation-percentile-padre-rf.csv"), "SOSAA-PADRE-RF-percentile")
plot_perturbation_generalisation(Path("perturbation-generalisation-percentile-padre-rf-direct.csv"), "SOSAA-PADRE-RF-direct-percentile")