# Detailed Analysis of Results for MIF

This notebook is organised as follows:

- [Load and process data](#load)
- [Plot bias, SEM and RMSE](#plot)
- [Plot specific examples](#examples)
- [Compare results to max achievable performance](#max)


In [None]:
import red
import numpy as np
import pickle as pkl
from tqdm import tqdm
import pandas as pd
import seaborn as sns

from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib import gridspec
import matplotlib.pyplot as plt
plt.style.use('ggplot')

colors = sns.color_palette('colorblind')

def get_subplots(systems: list[str], scale_x: float = 1, scale_y: float = 1) -> tuple[plt.Figure, list[plt.Axes]]:
    """Create a figure holding one titled axes per system, at most two columns wide.

    Args:
        systems: Names of the systems; each becomes the title of one axes.
        scale_x: Multiplier applied to the default width of 3 per column.
        scale_y: Multiplier applied to the default height of 3 per row.

    Returns:
        The figure and a flat sequence of its axes (a one-element list when
        there is a single system, otherwise a flattened array). Axes beyond
        ``len(systems)`` are removed from the figure but remain in the sequence.
    """
    n_systems = len(systems)
    n_cols = min(2, n_systems)
    n_rows = int(np.ceil(n_systems / n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(scale_x*3 * n_cols, scale_y*3*n_rows))

    # A single system gives a lone Axes object rather than an array of them
    axs = [axs] if n_systems == 1 else axs.flatten()

    # Title the axes that are in use
    for used_ax, system in zip(axs, systems):
        used_ax.set_title(system)

    # Delete the axes left over when the grid has more cells than systems
    for spare_ax in axs[n_systems:]:
        fig.delaxes(spare_ax)

    return fig, axs

def sanitise_name(name: str) -> str:
    """Turn an identifier-style name into a capitalised display label.

    Underscores are replaced by spaces, runs of whitespace within a line are
    collapsed to a single space, and the first letter of every word is
    upper-cased (the rest of each word is left as it is, so "RMSE" stays
    "RMSE"). Newlines are kept, so multi-line labels stay multi-line; spaces
    around a newline are dropped.

    Args:
        name: The raw name, e.g. ``"window_size \\n max_lag"``.

    Returns:
        The cleaned-up label, e.g. ``"Window Size\\nMax Lag"``.
    """
    def _capitalise_line(line: str) -> str:
        # str.split() with no argument never yields empty words
        return ' '.join(word[0].upper() + word[1:] for word in line.split())

    # Handle each line separately: a plain str.split() over the whole name
    # would also split on (and therefore lose) the newlines.
    lines = name.replace('_', ' ').split('\n')
    return '\n'.join(_capitalise_line(line) for line in lines)


In [None]:
# Load the pickled synthetic-data results that were saved with equilibration times.
# NOTE: unpickling can execute arbitrary code, so only open files produced by this project.
with open("../compute_equil_times/output/synthetic_data_bound_vanish_with_equil_times_saved.pkl", "rb") as f:
    data = pkl.load(f)

In [None]:
# Copy the data, replacing "Window Size $\\sqrt(n)$" with "Window Size $\\sqrt{N_{n_0}}$"
# wherever it occurs in a key of the per-repeat results. The nesting is
# dataset -> system -> repeat -> key. The "times" entry of each dataset is not copied.
replaced_data = {}

for dataset, dataset_results in data.items():
    replaced_data[dataset] = {}
    for system, system_results in dataset_results.items():
        if system == "times":
            continue
        replaced_data[dataset][system] = {
            repeat: {
                key.replace("Window Size $\\sqrt(n)$", "Window Size $\\sqrt{N_{n_0}}$"): value
                for key, value in repeat_results.items()
            }
            for repeat, repeat_results in system_results.items()
        }


In [None]:
# Write the relabelled copy of the data to the file that the rest of the notebook loads.
with open("../compute_equil_times/output/synthetic_data_bound_vanish_with_equil_times.pkl", "wb") as f:
    pkl.dump(replaced_data, f)

## Load and process data <a id="load"></a>

In [None]:
with open("../compute_equil_times/output/synthetic_data_bound_vanish_with_equil_times.pkl", "rb") as f:
    data = pkl.load(f)

# Put "standard" first: the key is False for it and True for everything else,
# and the sort is stable, so the other datasets keep their stored order
datasets = sorted(data, key=lambda name: name != "standard")

# Names are read from the first dataset / system / repeat (index 0); this
# assumes every other entry uses the same systems and methods
systems = [system for system in data[datasets[0]] if system != "times"]
methods = [method for method in data[datasets[0]][systems[0]][0] if method != "data"]

# Create version of the data with MIF only
mif_data = {dataset: data[dataset]["MIF"] for dataset in datasets}

# Separate out the fixed ("Discard ...") methods from the automated ones
fixed_methods = [method for method in methods if "Discard" in method]
automated_methods = [method for method in methods if "Discard" not in method]

In [None]:
def bootstrap(data: np.ndarray, fn: callable = lambda x: x, n_bootstraps: int = 10_000,
              rng: "np.random.Generator | None" = None) -> tuple[np.ndarray, np.ndarray]:
    """Bootstrap the mean of ``fn`` applied to resamples of ``data``.

    Each bootstrap replicate draws ``len(data)`` values from ``data`` with
    replacement, applies ``fn`` to them and stores the mean of the result.

    Args:
        data: One-dimensional sample to resample from.
        fn: Function applied to each resample before averaging. It may be
            element-wise (e.g. ``abs``, ``np.square``) or reduce the resample
            to a single value (e.g. ``np.var``), in which case taking the mean
            leaves that value unchanged.
        n_bootstraps: Number of bootstrap replicates.
        rng: Optional NumPy random generator, for reproducible resampling.
            When omitted, the global ``np.random`` state is used as before.

    Returns:
        A pair ``(ci, replicates)`` where ``ci`` holds the 2.5th and 97.5th
        percentiles of the replicates (a 95 % confidence interval) and
        ``replicates`` is the array of all ``n_bootstraps`` values.
    """
    draw = np.random.choice if rng is None else rng.choice
    n_samples = len(data)
    means = np.zeros(n_bootstraps)
    for i in range(n_bootstraps):
        resample = draw(data, size=n_samples, replace=True)
        means[i] = np.mean(fn(resample))
    return np.percentile(means, [2.5, 97.5]), means

In [None]:
# Show the keys stored for repeat 0 of the T4L system in the standard dataset
data["standard"]["T4L"][0].keys()

In [None]:
# Compute summary statistics, bootstrap confidence intervals and the
# underlying distributions for every dataset / system / method combination.
overall_results = {}
distributions = {}

for dataset in tqdm(datasets):
    overall_results[dataset] = {}
    distributions[dataset] = {}
    for system in systems:
        overall_results[dataset][system] = {}
        distributions[dataset][system] = {}
        repeats = data[dataset][system]

        for method in methods:
            # The results of this method for each repeat
            local_data = [repeats[i][method] for i in repeats]

            # Per-repeat estimate, calculation time and fraction of data discarded
            means = np.array([result["mean"] for result in local_data])
            times = np.array([result["time"] for result in local_data])
            fracs = np.array([result["frac_discarded"] for result in local_data])

            # The true answer is 0, so the unsigned and squared errors follow
            # directly from the means
            mues = abs(means)
            mses = means ** 2

            # 95 % confidence intervals and bootstrap distributions
            mean_ci, mean_distr = bootstrap(means)
            frac_ci, frac_distr = bootstrap(fracs)
            time_ci, time_distr = bootstrap(times)
            var_ci, var_distr = bootstrap(means, np.var)
            mue_ci, mue_distr = bootstrap(means, abs)
            # The bootstrap is over the mean squared error, so take the square
            # root afterwards to express it as an RMSE
            rmse_ci, rmse_distr = bootstrap(means, np.square)
            rmse_ci = np.sqrt(rmse_ci)
            rmse_distr = np.sqrt(rmse_distr)

            overall_results[dataset][system][method] = {
                "bias": means.mean(),
                "variance": means.var(),
                "mue": mues.mean(),
                "rmse": np.sqrt(mses.mean()),
                "time": times.mean(),
                "frac_discarded": fracs.mean(),
                "mean_ci": mean_ci,
                "frac_ci": frac_ci,
                "time_ci": time_ci,
                "var_ci": var_ci,
                "mue_ci": mue_ci,
                "rmse_ci": rmse_ci,
            }

            distributions[dataset][system][method] = {
                # Per-repeat values
                "means": means,
                "fracs": fracs,
                "times": times,
                "mues": mues,
                "mses": mses,
                # Bootstrap distributions
                "mean_boot": mean_distr,
                "frac_boot": frac_distr,
                "time_boot": time_distr,
                "var_boot": var_distr,
                "mue_boot": mue_distr,
                "rmse_boot": rmse_distr,
            }


In [None]:
# Reload previously saved statistics and distributions from the output directory
# (presumably so the statistics cell does not have to be re-run; confirm intended cell order).
with open("output/overall_results.pkl", "rb") as f:
    overall_results = pkl.load(f)

with open("output/distributions.pkl", "rb") as f:
    distributions = pkl.load(f)

In [None]:
# Save the summary statistics and the distributions as separate pickle files
with open("output/overall_results.pkl", "wb") as f:
    pkl.dump(overall_results, f)

with open("output/distributions.pkl", "wb") as f:
    pkl.dump(distributions, f)

## Summary Tables

In [None]:
# Write one LaTeX table of RMSEs per dataset. Each table has a row per
# automated method and a column per system; every cell shows the RMSE with the
# lower CI bound as a subscript and the upper CI bound as a superscript.

for dataset in datasets:
    rows = []
    for method in automated_methods:
        row = {"Method": method}
        for system in systems:
            stats = overall_results[dataset][system][method]
            rmse = stats["rmse"]
            lower_ci, upper_ci = stats["rmse_ci"][0], stats["rmse_ci"][1]
            row[system] = f"${rmse:.3f}_{{{lower_ci:.3f}}}^{{{upper_ci:.3f}}}$"
        rows.append(row)

    # escape=False keeps the LaTeX maths in the cells intact
    df = pd.DataFrame(rows)
    df.to_latex(f"output/rmse_table_{dataset}.tex", index=False, escape=False, column_format= "l" + "c" * len(systems))

## Visualisation of A Few Examples

In [None]:
# Show the keys stored for repeat 0 of the MIF system in the standard dataset
data["standard"]["MIF"][0].keys()

In [None]:
# Plot a few examples

# Window-based detection (method "min_sse") on repeat 9 of the standard MIF data, with plotting on.
# NOTE(review): idx, g, ess are presumably the truncation index, statistical
# inefficiency and effective sample size; confirm against the red documentation.
idx, g, ess = red.detect_equilibration_window(data["standard"]["MIF"][9]["data"], method="min_sse", plot=True)

In [None]:
# Initial-sequence detection with the "positive" sequence estimator on the same repeat, with plotting on
idx, g, ess = red.detect_equilibration_init_seq(data["standard"]["MIF"][9]["data"], sequence_estimator="positive", plot = True)

In [None]:
# Initial-sequence detection with the "initial_convex" sequence estimator on the same repeat, with plotting on
idx, g, ess = red.detect_equilibration_init_seq(data["standard"]["MIF"][9]["data"], sequence_estimator="initial_convex", plot = True)

In [None]:
# Compare four detection variants on one trajectory, drawn side by side in a
# 1 x 4 grid (one grid cell per variant)
from matplotlib import gridspec

example_data_full = data["standard"]["MIF"][9]["data"]
example_data = example_data_full[:]
example_times = np.linspace(0, 8, len(example_data_full)+1)[1:]
fig = plt.figure(figsize=(13, 5))
gridspec_obj = gridspec.GridSpec(1, 4, figure=fig)

red.detect_equilibration_window(example_data, example_times, window_size_fn=None, window_size=1, plot=True, figure=fig, grid_spec_obj=gridspec_obj[0])
red.detect_equilibration_window(example_data, example_times, plot=True, figure=fig, grid_spec_obj=gridspec_obj[1])
red.detect_equilibration_init_seq(example_data, example_times, sequence_estimator="positive", plot = True, figure=fig, grid_spec_obj=gridspec_obj[2])
red.detect_equilibration_init_seq(example_data, example_times, sequence_estimator="initial_convex", plot = True, figure=fig, grid_spec_obj=gridspec_obj[3])

# NOTE(review): the index-based styling below assumes every detection call
# adds three axes to the figure (12 in total), the third of each group being
# the right-hand y axis of the bottom panel; confirm against red's plotting code.
panel_titles = {
    0: "Uncorrelated Estimate",
    3: "Window Size $\\sqrt{N_{n_0}}$",
    6: "Initial Sequence: Chodera",
    9: "Initial Sequence: Convex",
}
inv_var_symbol = "$(\\widehat{\\mathrm{Var}}_{\\mathrm{Trajs}}(\\langle A \\rangle_{[n_{0},N]}))^{-1}$"

axs = fig.get_axes()
for i, ax in enumerate(axs):
    # Drop the per-axes legends; a single shared legend is added below
    ax.legend().remove()

    # Axes 0 keeps its own y label, axes 1 gets the 1 / estimated variance
    # label, and all later axes have theirs cleared
    if i == 1:
        ax.set_ylabel(inv_var_symbol + " \n/ kcal$^{-2}$ mol$^2$")
    elif i > 1:
        ax.set_ylabel("")

    if i in panel_titles:
        ax.set_title(panel_titles[i])

    # Right-hand y axes: no grid, tick labels coloured like the plotted line
    if i in (2, 5, 8, 11):
        ax.yaxis.grid(False)
        lag_color = ax.get_lines()[0].get_color()
        ax.tick_params(axis='y', labelcolor=lag_color)

    # Only the last right-hand axis is labelled
    if i == 11:
        ax.set_ylabel("Window Size or \n Max Lag Index", color=lag_color)

fig.tight_layout()

# Merge the legend entries of the two last axes into a single figure legend
handles_see, labels_see = axs[-2].get_legend_handles_labels()
handles_lag, labels_lag = axs[-1].get_legend_handles_labels()
legend_labels = [inv_var_symbol, "Equilibration Time", "Window Size or Max Lag Index"]
fig.legend(handles_see + handles_lag, legend_labels, loc='upper right', bbox_to_anchor=(0.62, -0.03))


fig.savefig("output/single_example_detection_uncorr.png", dpi=300, bbox_inches='tight')

In [None]:
# Compare the four initial-sequence estimators on one trajectory, drawn side
# by side in a 1 x 4 grid (one grid cell per estimator)
from matplotlib import gridspec

example_data_full = data["standard"]["MIF"][2]["data"]
example_data = example_data_full[:]
example_times = np.linspace(0, 8, len(example_data_full)+1)[1:]
fig = plt.figure(figsize=(13, 5))
gridspec_obj = gridspec.GridSpec(1, 4, figure=fig)

for panel, estimator in enumerate(["positive", "initial_positive", "initial_monotone", "initial_convex"]):
    red.detect_equilibration_init_seq(example_data, example_times, sequence_estimator=estimator, plot = True, figure=fig, grid_spec_obj=gridspec_obj[panel])

# NOTE(review): the index-based styling below assumes every detection call
# adds three axes to the figure (12 in total), the third of each group being
# the right-hand y axis of the bottom panel; confirm against red's plotting code.
panel_titles = {
    0: "Chodera",
    3: "Initial Positive",
    6: "Initial Monotone",
    9: "Initial Convex",
}
inv_var_symbol = "$(\\widehat{\\mathrm{Var}}_{\\mathrm{Trajs}}(\\langle A \\rangle_{[n_{0},N]}))^{-1}$"

axs = fig.get_axes()
for i, ax in enumerate(axs):
    # Drop the per-axes legends; a single shared legend is added below
    ax.legend().remove()

    # Axes 0 keeps its own y label, axes 1 gets the 1 / estimated variance
    # label, and all later axes have theirs cleared
    if i == 1:
        ax.set_ylabel(inv_var_symbol + " \n/ kcal$^{-2}$ mol$^2$")
    elif i > 1:
        ax.set_ylabel("")

    if i in panel_titles:
        ax.set_title(panel_titles[i])

    # Right-hand y axes: no grid, tick labels coloured like the plotted line
    if i in (2, 5, 8, 11):
        ax.yaxis.grid(False)
        lag_color = ax.get_lines()[0].get_color()
        ax.tick_params(axis='y', labelcolor=lag_color)

    # Only the last right-hand axis is labelled
    if i == 11:
        ax.set_ylabel("Window Size or \n Max Lag Index", color=lag_color)

fig.tight_layout()

# Merge the legend entries of the two last axes into a single figure legend
handles_see, labels_see = axs[-2].get_legend_handles_labels()
handles_lag, labels_lag = axs[-1].get_legend_handles_labels()
legend_labels = [inv_var_symbol, "Equilibration Time", "Window Size or Max Lag Index"]
fig.legend(handles_see + handles_lag, legend_labels, loc='upper right', bbox_to_anchor=(0.62, -0.03))

In [None]:
# Display the legend labels that were read from the last axes of the figure
labels_lag

## Load Data Creation Parameters and Plot Bias, SEM and RMSE <a id="plot"></a>

In [None]:
# Load the parameters that were used to create the synthetic data
with open("../synthetic_data_creation/output/synthetic_data_params.pkl", "rb") as f:
    params = pkl.load(f)

In [None]:
def exp_decay(x: np.ndarray, a: float, b: float) -> np.ndarray:
    """Evaluate the exponential decay ``a * exp(-b * x)`` element-wise.

    Args:
        x: Points at which to evaluate the decay.
        a: Amplitude (the value at ``x = 0``).
        b: Decay rate.
    """
    decay_factor = np.exp(-b * x)
    return a * decay_factor

def compute_bias(times: np.ndarray,
                 exp_params: tuple[float, float],
                 fast_exp_params: tuple[float, float]) -> np.ndarray:
    """Compute the bias of the mean for every possible truncation point.

    The bias at each time is modelled as the sum of two exponential decays.
    Element ``i`` of the result is that bias averaged over ``times[i:]``, i.e.
    the bias of the mean obtained after discarding everything before ``i``.

    Args:
        times: Times at which the data were sampled.
        exp_params: ``(a, b)`` parameters of the first exponential decay.
        fast_exp_params: ``(a, b)`` parameters of the second (fast) decay.

    Returns:
        Array with the same length as ``times`` holding the mean biases.
    """
    # First, compute the bias at each point in the series
    bias = exp_decay(times, *exp_params) + exp_decay(times, *fast_exp_params)

    # Sum of bias[i:] for every i in one pass (reversed cumulative sum), rather
    # than re-summing the tail for each i, which is quadratic in the length
    tail_sums = np.cumsum(bias[::-1])[::-1]
    n_points = np.arange(len(bias), 0, -1)

    return tail_sums / n_points

def compute_mean_variance(times: np.ndarray,
                          autocov_series: np.ndarray) -> float:
    """Compute the variance per point, including correlations, averaged over the times passed.

    For each point the contribution is the lag-0 autocovariance plus twice the
    summed autocovariance over the lags to the points that follow it. Backward
    correlations mirror the forward ones, so doubling the forward sum accounts
    for both once the average over all points is taken. Dividing the returned
    value by ``len(times)`` gives the variance of the mean.

    Args:
        times: The points being averaged over; only the length is used.
        autocov_series: Autocovariance at lags 0, 1, 2, ... Lags beyond the end
            of the series are treated as contributing nothing.

    Returns:
        The mean over all points of ``autocov[0] + 2 * sum(forward autocov)``.
    """
    autocov = np.asarray(autocov_series, dtype=float)
    n_points = len(times)

    # lag_sums[k] is the summed autocovariance over lags 1..k (so lag_sums[0] = 0)
    lag_sums = np.concatenate(([0.0], np.cumsum(autocov[1:])))

    # Point i is followed by n_points - 1 - i points, so that is the number of
    # forward lags it sees (the final point sees none). Cap at the longest
    # lag available in the autocovariance series.
    n_forward_lags = np.minimum(np.arange(n_points - 1, -1, -1), len(lag_sums) - 1)
    forward_cor_variances = lag_sums[n_forward_lags]

    return np.mean(2 * forward_cor_variances + autocov[0])

def compute_sem(times: np.ndarray,
                autocov_series: np.ndarray) -> np.ndarray:
    """
    Compute the standard error of the mean for every possible truncation point.

    Element ``i`` of the result is the standard error of the mean over
    ``times[i:]``. It is assumed that the times passed represent the entire
    series, with the same frequency as the original data.
    """
    n_total = len(times)
    sems = np.zeros(n_total)
    for start in tqdm(range(n_total), total=n_total, desc="Processing Times"):
        n_points = n_total - start
        variance = compute_mean_variance(times[start:], autocov_series)
        sems[start] = np.sqrt(variance / n_points)
    return sems

def compute_rmse(times: np.ndarray,
                 exp_params: tuple[float, float],
                 fast_exp_params: tuple[float, float],
                 autocov_series: np.ndarray) -> np.ndarray:
    """
    Compute the RMSE of the mean for every possible truncation point.

    The RMSE combines the bias (from the exponential parameters) and the
    standard error of the mean (from the autocovariance series) in quadrature.
    It is assumed that the times passed represent the entire series, and that
    the number of data points is the same as that sampled originally.
    """
    bias = compute_bias(times, exp_params, fast_exp_params)
    sem = compute_sem(times, autocov_series)
    squared_error = sem ** 2 + bias ** 2
    return np.sqrt(squared_error)

In [None]:
# Display the dataset names ("standard" was sorted to the front when they were loaded)
datasets

In [None]:
# Calculate the theoretical bias, SEM and RMSE as a function of truncation
# point, plus the optimal truncation point, for every dataset and system
fixed_trunc_error_series = {}

for dataset in datasets:
    fixed_trunc_error_series[dataset] = {}
    for system in systems:
        fixed_trunc_error_series[dataset][system] = {}

        # Get the parameters used to generate the series
        exp_params = params[system]["exp_params"]
        fast_exp_params = params[system]["fast_exp_params"]

        # The "noisy" dataset uses the autocovariance scaled up by a factor of 5
        variance_fac = 5 if dataset == "noisy" else 1
        autocov_series = params[system]["autocov_convex"] * variance_fac

        # The subsampled dataset keeps every 100th lag of the autocovariance
        if dataset == "subsampled":
            autocov_series = autocov_series[::100]

        # Block-averaged data are analysed on the time grid of the standard
        # dataset and downsampled afterwards
        dataset_lookup = "standard" if dataset == "block_averaged" else dataset
        test_data = data[dataset_lookup][system][0]["data"]
        tot_time = 0.2 if dataset == "short" else 8
        times = np.linspace(0, tot_time, len(test_data) + 1)[1:] # The times at which the data were sampled

        # Compute the bias, sem, and rmse
        bias = compute_bias(times, exp_params, fast_exp_params)
        sem = compute_sem(times, autocov_series)
        rmse = np.sqrt(sem ** 2 + bias ** 2)

        # Get the optimal truncation point. For block-averaged data this is
        # found before downsampling, so the index refers to the full series.
        trunc_point = np.argmin(rmse)
        trunc_time = times[trunc_point]

        # Block-averaged data: keep every 100th value and drop the one
        # closest to time 0
        if dataset == "block_averaged":
            bias, sem, rmse = (series[::100][1:] for series in (bias, sem, rmse))

        assert len(sem) == len(data[dataset][system][0]["data"]), f"Length of SEM series {len(sem)} does not match length of data {len(data[dataset][system][0]['data'])}"

        # Save the results
        fixed_trunc_error_series[dataset][system].update({
            "bias_series": bias,
            "sem_series": sem,
            "rmse_series": rmse,
            "optimal_trunc_point": trunc_point,
            "optimal_trunc_time": trunc_time,
        })

In [None]:
# Save the theoretical error series so they can be reloaded without recomputing them
with open("output/fixed_trunc_error_series_stats.pkl", "wb") as f:
    pkl.dump(fixed_trunc_error_series, f)

In [None]:
# Reload the saved theoretical error series
with open("output/fixed_trunc_error_series_stats.pkl", "rb") as f:
    fixed_trunc_error_series = pkl.load(f)

In [None]:
def plot_theoretical_rmse_on_axis(ax: Axes, dataset: str, system: str) -> None:
    """Plot the theoretical RMSE against the empirical RMSE of the fixed-discard methods.

    The theoretical curve comes from ``fixed_trunc_error_series``; the
    empirical points (with confidence intervals as error bars) come from
    ``overall_results`` for every method in ``fixed_methods``. The discarded
    fraction is parsed from the last space-separated token of the method name.
    """
    tot_time = 8 if dataset != "short" else 0.2
    n_samples = len(data[dataset][system][0]["data"])
    test_times = np.linspace(0, tot_time, n_samples + 1)[1:] # The times at which the data was sampled
    rmse_series = fixed_trunc_error_series[dataset][system]["rmse_series"]

    # Empirical RMSEs, with CI bounds converted to distances from the RMSE
    system_results = overall_results[dataset][system]
    real_rmse = np.array([system_results[method]["rmse"] for method in fixed_methods])
    real_cis = np.array([system_results[method]["rmse_ci"] for method in fixed_methods])
    cis_lower = abs(real_cis[:, 0] - real_rmse)
    cis_upper = abs(real_cis[:, 1] - real_rmse)

    # Truncation times corresponding to the discarded fractions
    times = [float(method.split(" ")[-1]) * tot_time for method in fixed_methods]

    # Plot the RMSE
    ax.plot(test_times, rmse_series, label="Theoretical", zorder=2, alpha=0.7)
    ax.errorbar(times, real_rmse, yerr=[cis_lower, cis_upper], fmt='-', label="Empirical", zorder=1, ecolor='black')
    ax.set_xlabel("Truncation Time / ns")
    ax.set_ylabel("$\\mathrm{RMSE}(\\langle \\Delta G \\rangle_{[n_{0},N]})$ / kcal mol$^{-1}$")
    ax.set_title(system)

    # Limit the y axis to 10 % below the lowest and 10 % above the highest
    # empirical RMSE, leaving the last fixed method out of the range
    y_min = np.min(real_rmse[:-1]) * 0.9
    y_max = np.max(real_rmse[:-1]) * 1.1
    ax.set_ylim(y_min, y_max)

# Theoretical vs empirical RMSE for every system in the standard dataset
fig, axs = get_subplots(systems)

for ax, system in zip(axs, systems):
    plot_theoretical_rmse_on_axis(ax, "standard", system)

fig.tight_layout()

# Attach the legend to the second-to-last axes, anchored just outside its top-right corner
axs[-2].legend(bbox_to_anchor=(1.05, 1), loc='upper left')

fig.savefig("output/theoretical_vs_empirical_rmse.png", dpi=300, bbox_inches='tight')

In [None]:
def plot_error_components_on_ax(ax: Axes, dataset: str, system: str, show_min: bool = True) -> None:
    """Plot the theoretical RMSE, bias and SD of the mean against truncation time.

    Args:
        ax: Axes to draw on.
        dataset: Dataset whose series are read from ``fixed_trunc_error_series``.
        system: System to plot; also used as the title.
        show_min: If True, mark the optimal truncation time and the minimum RMSE.
    """
    n_samples = len(data[dataset][system][0]["data"])
    tot_time = 0.2 if dataset == "short" else 8
    test_times = np.linspace(0, tot_time, n_samples + 1)[1:] # The times at which the data was sampled
    series_stats = fixed_trunc_error_series[dataset][system]
    rmse_series = series_stats["rmse_series"]

    # Leave out the last 1 % of the data (plus one point) to avoid the large
    # RMSE there scaling the y-axis
    n_truncate = round(len(test_times) * 0.01)
    shown = slice(None, -(n_truncate + 1))

    # (series key, legend label, zorder), in plotting order
    components = [
        ("rmse_series", "$\\mathrm{RMSE}(\\langle \\Delta G \\rangle_{[n_{0},N]})$", 2),
        ("bias_series", "$\\mathrm{Bias}(\\langle \\Delta G \\rangle_{[n_{0},N]})$", 1),
        ("sem_series", "$\\mathrm{SD}(\\langle \\Delta G \\rangle_{[n_{0},N]})$", 1),
    ]
    for key, label, zorder in components:
        ax.plot(test_times[shown], series_stats[key][shown], label=label, zorder=zorder, alpha=0.7)

    if show_min:
        trunc_time = series_stats["optimal_trunc_time"]
        min_index = np.argmin(rmse_series)
        trunc_rmse = rmse_series[min_index]

        # Dashed cross-hairs through the optimal truncation time and its RMSE
        ax.axvline(trunc_time, color='black', linestyle='--', label="Optimal Truncation Time")
        ax.axhline(trunc_rmse, color='black', linestyle='--')

        # Small red dot at the minimum RMSE
        ax.plot(trunc_time, trunc_rmse, 'ro', alpha=0.7)

    ax.set_xlabel("Truncation Time / ns")
    ax.set_ylabel("Error / kcal mol$^{-1}$")
    ax.set_title(system)

# Plot the bias, SD and RMSE components for every system in the standard
# dataset, without marking the minimum
fig, axs = get_subplots(systems)

for i, system in enumerate(systems):
    ax = axs[i]
    # The axis labels and title are set by the plotting function itself
    plot_error_components_on_ax(ax, "standard", system, show_min=False)

fig.tight_layout()

# Attach the legend to the second-to-last axes, anchored outside its right edge
axs[-2].legend(bbox_to_anchor=(1.2, 0.7), loc='upper left')

fig.savefig("output/error_components.png", dpi=300, bbox_inches='tight')

## Plot Discard Times



In [None]:
# Show the method names stored for the T4L system in the standard dataset's distributions
distributions["standard"]["T4L"].keys()

In [None]:
def plot_discard_times_on_ax(ax: Axes, dataset: str, system: str, n_truncate: int = 100) -> None:
    """Draw horizontal violins of the truncation times chosen by each automated method.

    The discarded fractions in ``distributions`` are converted to times by
    multiplying by the total time (0.2 for the "short" dataset, otherwise 8).
    The optimal truncation time is marked with a dashed vertical line.
    ``n_truncate`` is accepted but not used.
    """
    tot_time = 0.2 if dataset == "short" else 8
    method_stats = distributions[dataset][system]

    # One column of truncation times per automated method
    df_times = pd.DataFrame({method: method_stats[method]["fracs"]*tot_time for method in automated_methods})
    sns.violinplot(data=df_times, ax=ax, orient="h",palette=colors, alpha=1)

    # Dashed vertical line at the optimal truncation point
    trunc_time = fixed_trunc_error_series[dataset][system]["optimal_trunc_time"]
    ax.axvline(trunc_time, color='black', linestyle='--', label="Optimal Truncation Time")

    ax.set_xlabel("Truncation Time / ns")
    ax.set_title(system)

# Plot the distribution of truncation times for every system in the standard dataset
fig, axs = get_subplots(systems)

for i, system in enumerate(systems):
    ax = axs[i]
    # The x label and title are set by the plotting function itself
    plot_discard_times_on_ax(ax, "standard", system, 100)

    # Odd indices are the right-hand column: hide their y tick labels
    in_right_column = i % 2 == 1
    if in_right_column:
        ax.set_yticklabels([])

# Add vertical space between rows only (instead of a full tight layout)
fig.subplots_adjust(hspace=0.5)

# Attach the legend to the second-to-last axes, anchored outside its right edge
axs[-2].legend(bbox_to_anchor=(1.2, 0.7), loc='upper left')

fig.savefig("output/discard_times_standard.png", dpi=300, bbox_inches='tight')
    

## Plot RMSEs

In [None]:
def plot_rmses_on_ax(ax: Axes, dataset: str, system: str, n_truncate: int = 100) -> None:
    """Draw a bar chart of the RMSE of each automated method, with CI error bars.

    A dashed horizontal line marks the minimum of the theoretical fixed-time
    RMSE series for comparison.

    Args:
        ax: Axes to draw on.
        dataset: Dataset to read from ``overall_results``.
        system: System to plot; also used as the title.
        n_truncate: Accepted for interface compatibility but not used.
    """
    # Get RMSEs and confidence intervals
    rmse = [overall_results[dataset][system][method]["rmse"] for method in automated_methods]
    cis = [overall_results[dataset][system][method]["rmse_ci"] for method in automated_methods]

    # Convert CIs from absolute values to distances below / above the RMSE
    cis_lower = np.array(rmse) - np.array(cis)[:,0]
    cis_upper = np.array(cis)[:,1] - np.array(rmse)
    error_bar_settings = {"capsize": 0, "alpha": 1, "elinewidth": 1}
    ax.bar(automated_methods, rmse, yerr=[cis_lower, cis_upper], capsize=5, 
           alpha=1, error_kw=error_bar_settings, color=colors, edgecolor='black', linewidth=1)
    
    # Get the minimum possible fixed-time RMSE and plot a horizontal line to show it
    min_rmse = np.min(fixed_trunc_error_series[dataset][system]["rmse_series"])
    ax.axhline(min_rmse, color='black', linestyle='--', label="Minimum Fixed-Time RMSE")

    # Label the bars with the methods that were actually plotted. Using the
    # full `methods` list here would include the fixed-discard methods and
    # could mislabel the bars.
    ax.set_xticklabels(automated_methods, rotation=90)
    ax.set_title(system)

    # Remove the vertical grid lines
    ax.xaxis.grid(False)

# Plot the RMSE of every automated method for each system in the standard dataset
fig, axs = get_subplots(systems)
last_index = len(systems) - 1

for i, system in enumerate(systems):
    ax = axs[i]
    # The title is set by the plotting function itself
    plot_rmses_on_ax(ax, "standard", system, 100)

    # Only the left-hand column (even indices) keeps its y label
    in_left_column = i % 2 == 0
    ax.set_ylabel("$\\mathrm{RMSE}(\\langle \\Delta G \\rangle)$ \n/ kcal mol$^{-1}$" if in_left_column else "")

    # Remove x labels from first 3 plots
    if i < 3:
        ax.set_xticklabels([])

    # Add the legend to the last plot
    if i == last_index:
        ax.legend(bbox_to_anchor=(1.05, -0.3), loc='upper left')

fig.tight_layout()

fig.savefig("output/rmses_standard.png", dpi=300, bbox_inches='tight')

## Distributions of Unsigned Errors

In [None]:
def plot_unsigned_error_distribution_on_ax(ax: Axes, dataset: str, system: str) -> None:
    """Draw violins of the per-repeat unsigned errors of each automated method.

    A dashed horizontal line marks the minimum of the theoretical fixed-time
    RMSE series for comparison.

    Args:
        ax: Axes to draw on.
        dataset: Dataset to read from ``distributions``.
        system: System to plot; also used as the title.
    """
    # One column of unsigned errors per automated method
    df_unsigned_errors = pd.DataFrame({method: distributions[dataset][system][method]["mues"] for method in automated_methods})
    sns.violinplot(data=df_unsigned_errors, ax=ax, palette=colors)
    ax.set_title(system)
    ax.set_ylabel("Unsigned Error / kcal mol$^{-1}$")

    # Label the violins with the methods that were actually plotted. Using the
    # full `methods` list here would include the fixed-discard methods, so the
    # number of labels would not match the number of violins.
    ax.set_xticklabels(automated_methods, rotation=90)
    
    # Get the minimum possible fixed-time RMSE and plot a horizontal line to show it
    min_rmse = np.min(fixed_trunc_error_series[dataset][system]["rmse_series"])
    ax.axhline(min_rmse, color='black', linestyle='--', label="Minimum Fixed-Time RMSE")

# Plot the distributions of unsigned errors for every system in the standard dataset
fig, axs = get_subplots(systems)

for i, system in enumerate(systems):
    ax = axs[i]
    # The title and y label are set by the plotting function itself
    plot_unsigned_error_distribution_on_ax(ax, "standard", system)

    # Remove x labels from first 3 plots
    if i < 3:
        ax.set_xticklabels([])

    # Odd indices are the right-hand column: clear their y label
    in_right_column = i % 2 == 1
    if in_right_column:
        ax.set_ylabel("")

fig.tight_layout()

fig.savefig("output/unsigned_errors_standard.png", dpi=300, bbox_inches='tight')

## Contour Plots of Components of Error

In [None]:
# Based on the code above
def plot_contour_plot_on_axis(ax: Axes, dataset: str, system: str) -> None:
    """
    Plot bias (x) against standard deviation (y) for every automated method,
    over filled contours of the distance from the origin, sqrt(x**2 + y**2),
    which the colour bar labels as the RMSE.

    The fixed-truncation-time bias/SEM series is drawn as a line for reference.
    Each method is drawn as a scatter point with error bars derived from the
    "mean_ci" and "var_ci" entries of `overall_results`.

    Parameters
    ----------
    ax : Axes
        Axes to draw on. A horizontal colour bar is attached above it.
    dataset : str
        Key into `overall_results` / `fixed_trunc_error_series`.
    system : str
        Key into the per-dataset dictionaries.
    """
    # Fixed-time bias and sem, drawn as a reference curve
    fixed_series = fixed_trunc_error_series[dataset][system]
    ax.plot(fixed_series["bias_series"], fixed_series["sem_series"], label="Fixed Truncation\n Time Limit", zorder=2)

    # Biases, standard deviations and the (asymmetric) error-bar lengths for all methods.
    # Variances and their CIs are square-rooted to put them on the same scale as the bias.
    results = overall_results[dataset][system]
    biases = [results[method]["bias"] for method in automated_methods]
    sems = [results[method]["variance"]**0.5 for method in automated_methods]
    bias_cis_upper = [results[method]["mean_ci"][1] - results[method]["bias"] for method in automated_methods]
    bias_cis_lower = [results[method]["bias"] - results[method]["mean_ci"][0] for method in automated_methods]
    sem_cis_upper = [results[method]["var_ci"][1]**0.5 - results[method]["variance"]**0.5 for method in automated_methods]
    sem_cis_lower = [results[method]["variance"]**0.5 - results[method]["var_ci"][0]**0.5 for method in automated_methods]

    # Create a grid of points reaching 10 % beyond the largest bias or SD.
    # NOTE(review): the CIs computed above are not used to size the grid, and
    # negative biases are not taken into account — confirm this is intended.
    max_bias_or_sem = max(max(biases), max(sems))
    limit = max_bias_or_sem + max_bias_or_sem * 0.1
    x = np.linspace(-limit, limit, 1000)
    y = np.linspace(-limit, limit, 1000)
    X, Y = np.meshgrid(x, y)

    # Calculate distance from origin
    Z = np.sqrt(X**2 + Y**2)

    # Contour spacing: one eighth of the axis limit, rounded to 2 decimal places
    d_error = round(limit / 8, 2)
    if d_error <= 0:
        # Rounding collapsed the step to zero (limit < ~0.04); np.arange cannot
        # take a zero step, so fall back to the unrounded spacing.
        d_error = limit / 8
    contour_levels = np.arange(0, np.max(Z), d_error)
    contourf = ax.contourf(X, Y, Z, levels=contour_levels, cmap='viridis')

    # Add on the equilibration detection results
    for i, method in enumerate(automated_methods):
        ax.errorbar(biases[i], sems[i], xerr=[[bias_cis_lower[i]], [bias_cis_upper[i]]], yerr=[[sem_cis_lower[i]], [sem_cis_upper[i]]],
                    fmt='none', ecolor='black', capsize=5, markerfacecolor='white', zorder=1)

        ax.scatter(biases[i], sems[i], label=method, edgecolors='black', linewidth=1, s=50, zorder=2, color=colors[i])

    # Show only the (almost entirely) positive quadrant, with a small negative margin
    negative_limit = -limit * 0.02
    ax.set_xlim([negative_limit, limit])
    ax.set_ylim([negative_limit, limit])

    # Add a horizontal colour bar above the plot. It is attached explicitly to
    # `ax` and its own figure so it never depends on pyplot's current figure/axes.
    cbar = ax.figure.colorbar(contourf, ax=ax, orientation='horizontal', location='top')
    cbar.set_label("$\\mathrm{RMSE}(\\langle \\Delta G \\rangle )$ / kcal mol$^{-1}$")

    # Set x and y labels, and force aspect ratio to be equal
    ax.set_xlabel("$\\langle \\mathrm{Bias}(\\langle \\Delta G \\rangle) \\rangle$ / kcal mol$^{-1}$")
    ax.set_ylabel("$\\mathrm{SD}(\\langle \\Delta G \\rangle)$ / kcal mol$^{-1}$")
    ax.set_aspect('equal', adjustable='box')

# Plot the bias-vs-SD contour plots for the "standard" dataset, one panel per system.
# scale_y=1.2 makes each panel taller than the default.
fig, axs = get_subplots(systems, scale_y=1.2)

for i, system in enumerate(systems):
    ax = axs[i]
    plot_contour_plot_on_axis(ax, "standard", system)
    # Large title pad; presumably leaves room for the colour bar placed above each axes
    ax.set_title(system, pad=60)

fig.tight_layout()

# Attach the legend to the second-to-last entry of axs. get_subplots deletes
# unused trailing axes from the figure but keeps them in the array, so with an
# odd number of systems axs[-2] is the last populated panel.
axs[-2].legend(loc='center left', bbox_to_anchor=(1.1, 0.5))

fig.savefig("output/bias_vs_sem_contour_plots.png", dpi=300, bbox_inches='tight')

## Overall Plots

Combine the plots of error components, discard times, and RMSEs.

In [None]:
# Create a row of 5 sub-figures, one for each system. Each sub-figure holds a 2x2 grid
# of which three cells are used: discard times (top left), error components (bottom left)
# and RMSEs (bottom right). The top-right cell is created and then deleted.

fig = plt.figure(figsize=(30, 6))

# One row of five sub-figures (hard-coded: assumes at most five systems)
subfigs = fig.subfigures(1, 5)

# NOTE(review): axs is reset to an empty list here and never filled in this cell
axs = []
for i, system in enumerate(systems):

    # Get subfigure, grid spec, and add title
    subfig = subfigs[i]
    gs = gridspec.GridSpec(2, 2, figure=subfig, hspace=0.05, wspace=0.05)
    subfig.suptitle(system, fontsize=16, fontweight='bold')

    # Create subplots with axes shared as required:
    # discard times share x with the components plot, RMSEs share y with it
    components_ax = subfig.add_subplot(gs[1,0])
    discard_ax = subfig.add_subplot(gs[0,0], sharex=components_ax)
    rmse_ax = subfig.add_subplot(gs[1,1], sharey=components_ax)
    unused_ax = subfig.add_subplot(gs[0,1])
    
    # Plot/ delete axes
    plot_error_components_on_ax(components_ax, "standard", system, 100)
    plot_discard_times_on_ax(discard_ax, "standard", system, 100)
    plot_rmses_on_ax(rmse_ax, "standard", system, 100)
    subfig.delaxes(unused_ax)

    # Remove unnecessary labels/ titles; only the first system gets a legend
    components_ax.set_title("")
    if i == 0:
        components_ax.legend(bbox_to_anchor=(-0.05, -0.3), loc='upper left')
    rmse_ax.set_title("$\\mathrm{RMSE}(\\langle \\Delta G \\rangle)$")
    rmse_ax.set_ylabel("")
    discard_ax.set_title("Truncation Time")
    discard_ax.set_xlabel("")
    # Hide the numbers on the x axis for the discard times plot, but don't set labels to empty as this affects the RMSE plot
    discard_ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    rmse_ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
    
    # If this isn't the first system, remove the y tick labels from the discard times plot
    if i != 0:
        discard_ax.set_yticklabels([])
        

fig.savefig("output/combined_plots.png", dpi=300, bbox_inches='tight')


In [None]:
# Same combined plot as the single-row version, but laid out on a 2x3 grid of sub-figures
# (one per system, filled row by row). Each sub-figure holds a 2x2 grid of which three
# cells are used.

fig = plt.figure(figsize=(21, 14))

# 2 rows x 3 columns of sub-figures, flattened so they can be indexed by system number
subfigs = fig.subfigures(2, 3)
subfigs = subfigs.flatten()

# NOTE(review): axs is reset to an empty list here and never filled in this cell
axs = []
for i, system in enumerate(systems):

    # Get subfigure, grid spec, and add title
    subfig = subfigs[i]
    gs = gridspec.GridSpec(2, 2, figure=subfig, hspace=0.05, wspace=0.05)
    subfig.suptitle(system, fontsize=16, fontweight='bold')

    # Create subplots with axes shared as required:
    # discard times share x with the components plot, RMSEs share y with it
    components_ax = subfig.add_subplot(gs[1,0])
    discard_ax = subfig.add_subplot(gs[0,0], sharex=components_ax)
    rmse_ax = subfig.add_subplot(gs[1,1], sharey=components_ax)
    unused_ax = subfig.add_subplot(gs[0,1])
    
    # Plot/ delete axes
    plot_error_components_on_ax(components_ax, "standard", system, 100)
    plot_discard_times_on_ax(discard_ax, "standard", system, 100)
    plot_rmses_on_ax(rmse_ax, "standard", system, 100)
    subfig.delaxes(unused_ax)

    # Remove unnecessary labels/ titles; the legend goes on the third system (index 2)
    components_ax.set_title("")
    if i == 2:
        components_ax.legend(bbox_to_anchor=(0.05, -0.5), loc='upper left')

    # Remove the RMSE x tick labels from the first two systems
    if i < 2:
        rmse_ax.set_xticklabels([])

    rmse_ax.set_title("$\\mathrm{RMSE}(\\langle \\Delta G \\rangle)$")
    rmse_ax.set_ylabel("")
    discard_ax.set_title("Truncation Time")
    discard_ax.set_xlabel("")
    # Hide the numbers on the x axis for the discard times plot, but don't set labels to empty as this affects the RMSE plot
    discard_ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    rmse_ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
    
    # If this isn't the first system in a row (indices 0 and 3), remove the y tick labels from the discard times plot
    if i not in [0, 3]:
        discard_ax.set_yticklabels([])
        

fig.savefig("output/combined_plots_reformatted.png", dpi=300, bbox_inches='tight')


In [None]:
# Repeat above plot, but show distributions of unsigned
# errors instead of RMSE of dataset

fig = plt.figure(figsize=(30, 6))

# One row of five sub-figures (hard-coded: assumes at most five systems)
subfigs = fig.subfigures(1, 5)

# NOTE(review): axs is reset to an empty list here and never filled in this cell
axs = []
for i, system in enumerate(systems):

    # Get subfigure, grid spec, and add title
    subfig = subfigs[i]
    gs = gridspec.GridSpec(2, 2, figure=subfig, hspace=0.05, wspace=0.05)
    subfig.suptitle(system, fontsize=16, fontweight='bold')

    # Create subplots with axes shared as required.
    # rmse_ax keeps its name from the RMSE version of this cell but holds the
    # unsigned-error violins here.
    components_ax = subfig.add_subplot(gs[1,0])
    discard_ax = subfig.add_subplot(gs[0,0], sharex=components_ax)
    rmse_ax = subfig.add_subplot(gs[1,1], sharey=components_ax)
    unused_ax = subfig.add_subplot(gs[0,1])
    
    # Plot/ delete axes
    plot_error_components_on_ax(components_ax, "standard", system, 100)
    plot_discard_times_on_ax(discard_ax, "standard", system, 100)
    plot_unsigned_error_distribution_on_ax(rmse_ax, "standard", system)
    subfig.delaxes(unused_ax)

    # Remove unnecessary labels/ titles; only the first system gets a legend
    components_ax.set_title("")
    if i == 0:
        components_ax.legend(bbox_to_anchor=(-0.05, -0.3), loc='upper left')
    rmse_ax.set_title("Unsigned Errors")
    rmse_ax.set_ylabel("")
    discard_ax.set_title("Time Discarded")
    discard_ax.set_xlabel("")
    # Hide the numbers on the x axis for the discard times plot, but don't set labels to empty as this affects the unsigned-error plot
    discard_ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    rmse_ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
    
    # If this isn't the first system, remove the y tick labels from the discard times plot
    if i != 0:
        discard_ax.set_yticklabels([])
        

fig.savefig("output/combined_plots_unsigned_errors.png", dpi=300, bbox_inches='tight')


## Summary Plots for All Datasets

Redo the above plots, but show results for all datasets.

In [None]:
# First, let's define a function to supply a nice set of axes
def get_subplots_all_datasets(systems: list[str], datasets: list[str], scale_x: float = 1, scale_y: float = 1) -> tuple[plt.Figure, list[plt.Axes]]:
    """
    Create a figure with one row of axes per dataset and one column per system.

    Each panel is nominally 3 x 3 inches, stretched by `scale_x` / `scale_y`.
    The result of `plt.subplots` is returned unchanged (figure, axes).
    """
    n_rows, n_cols = len(datasets), len(systems)
    width = scale_x*3 * n_cols
    height = scale_y*3*n_rows
    return plt.subplots(n_rows, n_cols, figsize=(width, height))


In [None]:
# Theoretical vs. empirical RMSE: one row per dataset, one column per system
fig, axs = get_subplots_all_datasets(systems, datasets)

for i, dataset in enumerate(datasets):
    for j, system in enumerate(systems):
        ax = axs[i][j]
        plot_theoretical_rmse_on_axis(ax, dataset, system)
        panel_title = sanitise_name(f"{system}\n{dataset}")
        ax.set_title(panel_title)

fig.tight_layout()

# The legend hangs below the panel at row 4, column 2 (hard-coded position)
middle_bottom_ax = axs[4][2]
middle_bottom_ax.legend(loc='upper left', bbox_to_anchor=(0.1, -0.4))

fig.savefig("output/theoretical_vs_empirical_rmse_all_datasets.png", dpi=300, bbox_inches='tight')

In [None]:
# Error components: one row per dataset, one column per system

fig, axs = get_subplots_all_datasets(systems, datasets)

for i, dataset in enumerate(datasets):
    for j, system in enumerate(systems):
        ax = axs[i][j]
        plot_error_components_on_ax(ax, dataset, system, show_min=True)
        panel_title = sanitise_name(f"{system} \n {dataset}")
        ax.set_title(panel_title)

fig.tight_layout()

# The legend hangs below the panel at row 4, column 2 (hard-coded position)
middle_bottom_ax = axs[4][2]
middle_bottom_ax.legend(loc='upper left', bbox_to_anchor=(0.1, -0.4))

fig.savefig("output/error_components_all_datasets.png", dpi=300, bbox_inches='tight')

In [None]:
# Discard times: one row per dataset, one column per system

fig, axs = get_subplots_all_datasets(systems, datasets)

for i, dataset in enumerate(datasets):
    for j, system in enumerate(systems):
        ax = axs[i][j]
        plot_discard_times_on_ax(ax, dataset, system, 100)
        panel_title = sanitise_name(f"{system} \n {dataset}")
        ax.set_title(panel_title)

        # Only the left-most column keeps its y tick labels
        if j > 0:
            ax.set_yticklabels([])

        # Only the bottom row keeps its x axis label
        if i < len(datasets) - 1:
            ax.set_xlabel("")

fig.tight_layout()

# The legend hangs below the panel at row 4, column 2 (hard-coded position)
middle_bottom_ax = axs[4][2]
middle_bottom_ax.legend(loc='upper left', bbox_to_anchor=(0.1, -0.4))

fig.savefig("output/discard_times_all_datasets.png", dpi=300, bbox_inches='tight')

In [None]:
# RMSE bar plots: one row per dataset, one column per system

fig, axs = get_subplots_all_datasets(systems, datasets)

for i, dataset in enumerate(datasets):
    for j, system in enumerate(systems):
        ax = axs[i][j]
        plot_rmses_on_ax(ax, dataset, system, 100)
        panel_title = sanitise_name(f"{system} \n {dataset}")
        ax.set_title(panel_title)

        # The left-most column gets the full y axis label.
        # NOTE(review): other columns keep whatever label plot_rmses_on_ax set, if any.
        if j == 0:
            ax.set_ylabel("$\\mathrm{RMSE}(\\langle \\Delta G \\rangle)$ / kcal mol$^{-1}$")

        # Only the bottom row keeps its x tick labels
        if i < len(datasets) - 1:
            ax.set_xticklabels([])

fig.tight_layout()

# The legend hangs well below the panel at row 4, column 2 (hard-coded position)
middle_bottom_ax = axs[4][2]
middle_bottom_ax.legend(loc='upper left', bbox_to_anchor=(0.1, -1.7))

fig.savefig("output/rmses_all_datasets.png", dpi=300, bbox_inches='tight')

In [None]:
# Unsigned error violin plots: one row per dataset, one column per system

fig, axs = get_subplots_all_datasets(systems, datasets)

for i, dataset in enumerate(datasets):
    for j, system in enumerate(systems):
        ax = axs[i, j]
        plot_unsigned_error_distribution_on_ax(ax, dataset, system)
        ax.set_title(sanitise_name(f"{system} \n {dataset}"))

        # Remove x tick labels unless in bottom row
        if i != len(datasets) - 1:
            ax.set_xticklabels([])

        # Only the left-most column keeps a y axis label. The plotting function
        # labels every axes, so the label must be cleared everywhere else.
        if j == 0:
            ax.set_ylabel("Unsigned Error / kcal mol$^{-1}$")
        else:
            ax.set_ylabel("")

fig.tight_layout()

# The legend hangs well below the panel at row 4, column 2 (hard-coded position)
middle_bottom_ax = axs[4, 2]
middle_bottom_ax.legend(bbox_to_anchor=(0.1, -1.7), loc='upper left')

fig.savefig("output/unsigned_errors_all_datasets.png", dpi=300, bbox_inches='tight')

In [None]:
# Bias-vs-SD contour plots: one row per dataset, one column per system

fig, axs = get_subplots_all_datasets(systems, datasets, scale_y=1.2)

for i, dataset in enumerate(datasets):
    for j, system in enumerate(systems):
        ax = axs[i][j]
        plot_contour_plot_on_axis(ax, dataset, system)
        panel_title = sanitise_name(f"{system} \n {dataset}")
        ax.set_title(panel_title, pad=60)

fig.tight_layout()

# The legend hangs below the panel at row 4, column 2 (hard-coded position)
middle_bottom_ax = axs[4][2]
middle_bottom_ax.legend(loc='upper left', bbox_to_anchor=(-0.2, -0.4))

fig.savefig("output/bias_vs_sem_contour_plots_all_datasets.png", dpi=300, bbox_inches='tight')

In [None]:
# Overall plots for all datasets: a 5x5 grid of sub-figures (row = dataset, column = system),
# each holding discard times (top left), error components (bottom left) and RMSEs (bottom right).

fig = plt.figure(figsize=(30, 30))

# Hard-coded 5x5 layout: assumes at most five datasets and five systems
subfigs = fig.subfigures(5, 5)

# NOTE(review): axs is reset to an empty list here and never filled in this cell
axs = []
for j, dataset in enumerate(datasets):
    for i, system in enumerate(systems):

        # Get subfigure, grid spec, and add title
        subfig = subfigs[j, i]
        gs = gridspec.GridSpec(2, 2, figure=subfig, hspace=0.05, wspace=0.05)
        subfig.suptitle(sanitise_name(f"{system} \n {dataset}"), fontsize=16, fontweight='bold')

        # Create subplots with axes shared as required:
        # discard times share x with the components plot, RMSEs share y with it
        components_ax = subfig.add_subplot(gs[1,0])
        discard_ax = subfig.add_subplot(gs[0,0], sharex=components_ax)
        rmse_ax = subfig.add_subplot(gs[1,1], sharey=components_ax)
        unused_ax = subfig.add_subplot(gs[0,1])
        
        # Plot/ delete axes
        plot_error_components_on_ax(components_ax, dataset, system)
        plot_discard_times_on_ax(discard_ax, dataset, system)
        plot_rmses_on_ax(rmse_ax, dataset, system)
        subfig.delaxes(unused_ax)

        # Remove unnecessary labels/ titles; the legend goes on the bottom-left sub-figure only
        components_ax.set_title("")
        if i == 0 and j ==4:
            components_ax.legend(bbox_to_anchor=(-0.05, -0.3), loc='upper left')
        rmse_ax.set_title("$\\mathrm{RMSE}(\\langle \\Delta G \\rangle)$")
        rmse_ax.set_ylabel("")
        discard_ax.set_title("Truncation Time")
        discard_ax.set_xlabel("")

        # Remove the RMSE bar plot tick labels unless this is the bottom row (row index 4)
        if j != 4:
            rmse_ax.set_xticklabels([])

        # Hide the numbers on the x axis for the discard times plot, but don't set labels to empty as this affects the RMSE plot
        discard_ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        rmse_ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
        
        # If this isn't the first system, remove the y tick labels from the discard times plot
        if i != 0:
            discard_ax.set_yticklabels([])
        

fig.savefig("output/combined_plots_all_datasets.png", dpi=300, bbox_inches='tight')


In [None]:
# Overall plots for all datasets, showing distributions of unsigned errors instead of RMSEs:
# a 5x5 grid of sub-figures (row = dataset, column = system).

fig = plt.figure(figsize=(30, 30))

# Hard-coded 5x5 layout: assumes at most five datasets and five systems
subfigs = fig.subfigures(5, 5)

axs = []
for j, dataset in enumerate(datasets):
    for i, system in enumerate(systems):

        # Get subfigure, grid spec, and add title
        subfig = subfigs[j, i]
        gs = gridspec.GridSpec(2, 2, figure=subfig, hspace=0.05, wspace=0.05)
        subfig.suptitle(sanitise_name(f"{system} \n {dataset}"), fontsize=16, fontweight='bold')

        # Create subplots with axes shared as required.
        # rmse_ax keeps its name from the RMSE version of this cell but holds the
        # unsigned-error violins here.
        components_ax = subfig.add_subplot(gs[1,0])
        discard_ax = subfig.add_subplot(gs[0,0], sharex=components_ax)
        rmse_ax = subfig.add_subplot(gs[1,1], sharey=components_ax)
        unused_ax = subfig.add_subplot(gs[0,1])

        # Plot/ delete axes. The violins must use the dataset of the current row;
        # passing the fixed "standard" dataset repeated the same distributions in every row.
        plot_error_components_on_ax(components_ax, dataset, system)
        plot_discard_times_on_ax(discard_ax, dataset, system)
        plot_unsigned_error_distribution_on_ax(rmse_ax, dataset, system)
        subfig.delaxes(unused_ax)

        # Remove unnecessary labels/ titles; the legend goes on the bottom-left sub-figure only
        components_ax.set_title("")
        if i == 0 and j == 4:
            components_ax.legend(bbox_to_anchor=(-0.05, -0.3), loc='upper left')
        # This panel shows unsigned errors, not RMSEs, so title it accordingly
        rmse_ax.set_title("Unsigned Errors")
        rmse_ax.set_ylabel("")
        discard_ax.set_title("Truncation Time")
        discard_ax.set_xlabel("")

        # Remove the violin plot tick labels unless this is the bottom row (row index 4)
        if j != 4:
            rmse_ax.set_xticklabels([])

        # Hide the numbers on the x axis for the discard times plot, but don't set labels to empty as this affects the unsigned-error plot
        discard_ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        rmse_ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)

        # If this isn't the first system, remove the y tick labels from the discard times plot
        if i != 0:
            discard_ax.set_yticklabels([])


fig.savefig("output/combined_plots_all_datasets_unsigned_errors.png", dpi=300, bbox_inches='tight')


In [None]:
# Repeat for the discard times
# NOTE(review): this cell only creates an empty dataset x system grid of axes;
# nothing is plotted on it or saved here.

fig, axs = get_subplots_all_datasets(systems, datasets)



In [None]:
# Combined per-system panels laid out with a single subplot mosaic.
# The real `systems` list defined earlier in the notebook is used: overwriting it
# with placeholder names would break the data lookups here and in every later cell.

# Create a figure
fig = plt.figure(figsize=(30, 12))

# Build the mosaic as a nested list so that every panel gets its own key.
# (The previous string mosaic only produced the single-letter keys "A".."E" and "f",
# so the two-character lookups below raised KeyError.)
# Each system gets a 2x2 group of cells keyed "<letter><role>":
#   <letter>A = discard times (top left)       <letter>C = unused (top right)
#   <letter>B = error components (bottom left) <letter>D = RMSEs (bottom right)
section_letters = [chr(65 + i) for i in range(len(systems))]
mosaic = [
    [f'{letter}{role}' for letter in section_letters for role in 'AC'],
    [f'{letter}{role}' for letter in section_letters for role in 'BD'],
]

# Create subplots using subplot_mosaic
ax_dict = fig.subplot_mosaic(mosaic, gridspec_kw={'hspace': 0.4, 'wspace': 0.4})

# Fill in the panels for each section
axs = []
for i, system in enumerate(systems):
    # Look up the four panels belonging to this system
    components_ax = ax_dict[f'{chr(65 + i)}B']
    discard_ax = ax_dict[f'{chr(65 + i)}A']
    rmse_ax = ax_dict[f'{chr(65 + i)}D']
    unused_ax = ax_dict[f'{chr(65 + i)}C']

    # Add all the used axes to the list
    axs.append([discard_ax, components_ax, rmse_ax])

    # Plot/delete axes
    plot_error_components_on_ax(components_ax, "standard", system, 100)
    plot_discard_times_on_ax(discard_ax, "standard", system, 100)
    plot_rmses_on_ax(rmse_ax, "standard", system, 100)
    fig.delaxes(unused_ax)

    # Remove unnecessary labels/titles
    components_ax.set_title("")
    rmse_ax.set_title("")
    rmse_ax.set_ylabel("")
    discard_ax.set_xlabel("")
    # Hide the numbers on the x axis for the discard times plot, but don't set labels to empty as this affects the RMSE plot
    discard_ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    rmse_ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)

    # If this isn't the first system, remove the labels from the discard times plot
    if i != 0:
        discard_ax.set_yticklabels([])

    # Add a title for each section, centred horizontally over that section's columns
    # (sections sit side by side, so the title moves along x, not y)
    fig.text((i + 0.5) / len(systems), 0.98, f'System: {system}', ha='center', va='center', fontsize=16, transform=fig.transFigure)

# Adjust the layout, reserving the top 5 % of the figure for the section titles
plt.tight_layout(rect=(0, 0, 1, 0.95))
plt.show()

In [None]:
# Scratch check: three evenly spaced values from 1 to 4 (endpoints included)
np.linspace(1,4,3)

In [None]:
# Scratch check: display the results stored for one dataset / system / method combination
overall_results["block_averaged"]["MIF"]["Uncorrelated Estimate"]

In [None]:
# Scratch check: display the current value of `systems`
systems

In [None]:
# Plot bar plots of the RMSE for MIF only
# Make a plot for each dataset (5)

def get_subplots(systems: list[str], scale_x: float = 1, scale_y: float = 1) -> tuple[plt.Figure, list[plt.Axes]]:
    """
    Create a two-column grid of axes with one titled panel per entry of `systems`.

    This redefinition keeps the `scale_x` / `scale_y` parameters of the earlier
    definition in this notebook, so cells that pass them still work after this
    cell has been run. With the default scales each panel is 3 x 3 inches.

    Parameters
    ----------
    systems : list[str]
        Panel titles (system or dataset names); one axes is populated per entry.
    scale_x, scale_y : float
        Multipliers applied to the figure width and height.

    Returns
    -------
    (fig, axs)
        The figure and a flat sequence of axes. Surplus axes are deleted from
        the figure but remain in the returned sequence.
    """
    # Plot two columns side-by-side
    n_cols = min(2, len(systems))
    n_rows = int(np.ceil(len(systems) / n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(scale_x*3 * n_cols, scale_y*3*n_rows))
    if len(systems) == 1:
        # plt.subplots returns a bare Axes for a 1x1 grid
        axs = [axs]
    else:
        axs = axs.flatten()

    # Set titles
    for i, system in enumerate(systems):
        axs[i].set_title(system)

    # Remove any unused axes
    for i in range(len(systems), len(axs)):
        fig.delaxes(axs[i])

    return fig, axs

# Split the methods into automated detection methods and fixed "Discard ..." methods
automated_methods = [method for method in methods if "Discard" not in method]
fixed_methods = [method for method in methods if "Discard" in method]
fig, axs = get_subplots(datasets)

# For MIF only, plot the RMSE for each dataset with a bar plot on a new axis.
SYSTEM = "MIF"

for i, dataset in enumerate(datasets):
    ax = axs[i]
    rmse = [overall_results[dataset][SYSTEM][method]["rmse"] for method in automated_methods]
    cis = [overall_results[dataset][SYSTEM][method]["rmse_ci"] for method in automated_methods]
    # Convert CIs from absolute values to distances from the RMSE, as expected by yerr
    cis_lower = abs(np.array(cis)[:,0] - np.array(rmse))
    cis_upper = abs(np.array(cis)[:,1] - np.array(rmse))
    ax.bar(automated_methods, rmse, yerr=[cis_lower, cis_upper], capsize=5)
    # Rotate labels 90 degrees. The labels must be the methods that were plotted
    # (automated_methods); the full `methods` list also contains the fixed methods.
    ax.set_xticklabels(automated_methods, rotation=90)
    ax.set_ylabel("RMSE / kcal mol$^{-1}$")
    ax.set_title(dataset)

    # Remove x labels from first 3 plots
    if i < 3:
        ax.set_xticklabels([])

    # Remove y labels from right column
    if i % 2 == 1:
        ax.set_ylabel("")

# Stop the long labels overlapping
fig.subplots_adjust(wspace=0.3)
fig.savefig(f"output/rmse_bar_plots_{SYSTEM.lower()}.png", dpi=300, bbox_inches="tight")

In [None]:
# As above, but for a single dataset with one panel per system

DATASET = "standard"

fig, axs = get_subplots(systems)

# For each system, plot the RMSE of every automated method with a bar plot on its own axis.
for i, system in enumerate(systems):
    ax = axs[i]
    rmse = [overall_results[DATASET][system][method]["rmse"] for method in automated_methods]
    cis = [overall_results[DATASET][system][method]["rmse_ci"] for method in automated_methods]

    # Convert CIs from absolute values to distances from the RMSE, as expected by yerr
    cis_lower = abs(np.array(cis)[:,0] - np.array(rmse))
    cis_upper = abs(np.array(cis)[:,1] - np.array(rmse))
    ax.bar(automated_methods, rmse, yerr=[cis_lower, cis_upper], capsize=5, alpha=0.7, label=system)

    # Rotate labels 90 degrees. The labels must be the methods that were plotted
    # (automated_methods); the full `methods` list also contains the fixed methods.
    ax.set_xticklabels(automated_methods, rotation=90)
    ax.set_ylabel("RMSE / kcal mol$^{-1}$")
    ax.set_title(system)

    # Get the lowest RMSE from any fixed-time method and mark it with a red line
    fixed_rmse = [overall_results[DATASET][system][method]["rmse"] for method in fixed_methods]
    # Get the name of the lowest RMSE method.
    # NOTE(review): the last fixed method is excluded here but not in min() below — confirm intent.
    best_fixed_method = fixed_methods[np.argmin(fixed_rmse[:-1])]
    ax.axhline(min(fixed_rmse), color="red", linestyle="--")

    # Plot a blue line at the RMSE of the 0.1 fraction discarded method
    # and a green line at the RMSE of the 0.4 fraction discarded method
    rmse_01 = overall_results[DATASET][system]["Discard Fraction 0.1"]["rmse"]
    rmse_04 = overall_results[DATASET][system]["Discard Fraction 0.4"]["rmse"]
    ax.axhline(rmse_01, color="blue", linestyle="--")
    ax.axhline(rmse_04, color="green", linestyle="--")

    # Remove x labels from first 3 plots
    if i < 3:
        ax.set_xticklabels([])

    # Remove y labels from right column
    if i % 2 == 1:
        ax.set_ylabel("")

# Stop the long labels overlapping
fig.subplots_adjust(wspace=0.3)
fig.savefig("output/rmse_bar_plots_all_systems.png", dpi=300, bbox_inches="tight")

In [None]:
fig, axs = get_subplots(systems)

DATASET = "standard"

# For each system, draw the RMSE of every fixed-time method as a bar plot on its own axis.
for i, system in enumerate(systems):
    ax = axs[i]
    rmse = [overall_results[DATASET][system][method]["rmse"] for method in fixed_methods]
    cis = [overall_results[DATASET][system][method]["rmse_ci"] for method in fixed_methods]

    # Error-bar lengths: distance of each CI bound from the RMSE itself
    cis_lower = np.abs(np.array(cis)[:, 0] - np.array(rmse))
    cis_upper = np.abs(np.array(cis)[:, 1] - np.array(rmse))
    ax.bar(fixed_methods, rmse, alpha=0.7, capsize=5, label=system, yerr=[cis_lower, cis_upper])

    # Vertical tick labels so the long method names fit
    ax.set_xticklabels(fixed_methods, rotation=90)
    ax.set_title(system)

    # Left-hand column (even panel index) keeps the y label; right-hand column does not
    if i % 2 != 0:
        ax.set_ylabel("")
    else:
        ax.set_ylabel("RMSE / kcal mol$^{-1}$")

    # Red dashed line at the lowest RMSE achieved by any fixed-time method
    fixed_rmse = [overall_results[DATASET][system][method]["rmse"] for method in fixed_methods]
    best_fixed_method = fixed_methods[np.argmin(fixed_rmse)]
    ax.axhline(min(fixed_rmse), linestyle="--", color="red")

    # Panels 0-2 lose their x tick labels
    if i <= 2:
        ax.set_xticklabels([])

# Stop the long labels overlapping
fig.subplots_adjust(wspace=0.3)
fig.savefig("output/rmse_bar_plots_all_systems_fixed_methods.png", dpi=300, bbox_inches="tight")

In [None]:
# Plot the distributions of frac_discarded for each system for the "standard" dataset
# Use seaborn to plot the distributions

import seaborn as sns
import pandas as pd

DATASET  = "standard"

fig, axs = get_subplots(systems)

for i, system in enumerate(systems):
    ax = axs[i]
    df_means = pd.DataFrame({method: distributions[DATASET][system][method]["fracs"] for method in automated_methods})
    sns.violinplot(data=df_means, ax=ax)
    ax.set_title(system)
    ax.set_ylabel("Fraction of data discarded")
    # Label the violins with the methods that were plotted (automated_methods);
    # the full `methods` list also contains the fixed methods.
    ax.set_xticklabels(automated_methods, rotation=90)

    # Figure out the optimal (lowest RMSE) fixed truncation fraction and draw a horizontal line.
    # The fraction is parsed from the last space-separated token of the method name.
    # NOTE(review): the last fixed method is excluded from the search ([:-1]) — confirm intent.
    fixed_rmse = [overall_results[DATASET][system][method]["rmse"] for method in fixed_methods]
    best_fixed_method = fixed_methods[np.argmin(fixed_rmse[:-1])]
    best_fixed_time = float(best_fixed_method.split(" ")[-1])
    ax.axhline(best_fixed_time, color="red", linestyle="--")

    # Remove x labels from first 3 plots
    if i < 3:
        ax.set_xticklabels([])

    # Remove y labels from right column
    if i % 2 == 1:
        ax.set_ylabel("")

# Stop the long labels overlapping
fig.subplots_adjust(wspace=0.3)



In [None]:
# Plot distributions of the per-run "mses" values for each automated method
fig, axs = get_subplots(systems)

for i, system in enumerate(systems):
    ax = axs[i]
    df_means = pd.DataFrame({method: distributions["standard"][system][method]["mses"] for method in automated_methods})
    sns.violinplot(data=df_means, ax=ax)
    ax.set_title(system)
    # NOTE(review): the plotted quantity is "mses", but the axis is labelled "Mean" — confirm the label.
    ax.set_ylabel("Mean / kcal mol$^{-1}$")
    # Label the violins with the methods that were plotted (automated_methods);
    # the full `methods` list also contains the fixed methods.
    ax.set_xticklabels(automated_methods, rotation=90)

    # Find the fixed method with the lowest "mue" and draw a horizontal line at the
    # number parsed from its name.
    # NOTE(review): that number is a discard fraction, not an error in the units of
    # this axis — this looks copied from the fraction-discarded plot; confirm intent.
    fixed_rmse = [overall_results["standard"][system][method]["mue"] for method in fixed_methods]
    best_fixed_method = fixed_methods[np.argmin(fixed_rmse[:-1])]
    best_fixed_time = float(best_fixed_method.split(" ")[-1])
    ax.axhline(best_fixed_time, color="red", linestyle="--")

    # Remove x labels from first 3 plots
    if i < 3:
        ax.set_xticklabels([])

    # Remove y labels from right column
    if i % 2 == 1:
        ax.set_ylabel("")

# Stop the long labels overlapping
fig.subplots_adjust(wspace=0.3)


In [None]:
# Violin plots of the per-run "means" for every method (automated and fixed)
fig, axs = get_subplots(systems)

for i, system in enumerate(systems):
    ax = axs[i]
    means_by_method = {method: distributions["standard"][system][method]["means"] for method in methods}
    df_means = pd.DataFrame(means_by_method)
    sns.violinplot(ax=ax, data=df_means)
    ax.set_title(system)
    ax.set_xticklabels(methods, rotation=90)

    # Left-hand column (even panel index) keeps the y label; right-hand column does not
    if i % 2 != 0:
        ax.set_ylabel("")
    else:
        ax.set_ylabel("Mean / kcal mol$^{-1}$")

    # Red dashed line at the number parsed from the name of the fixed method with the
    # lowest "mue" (the last fixed method is left out of the search).
    # NOTE(review): that number is a discard fraction plotted on a "Mean" axis — confirm intent.
    fixed_rmse = [overall_results["standard"][system][method]["mue"] for method in fixed_methods]
    best_fixed_method = fixed_methods[np.argmin(fixed_rmse[:-1])]
    best_fixed_time = float(best_fixed_method.rsplit(" ", 1)[-1])
    ax.axhline(best_fixed_time, linestyle="--", color="red")

    # Panels 0-2 lose their x tick labels
    if i <= 2:
        ax.set_xticklabels([])

# Stop the long labels overlapping
fig.subplots_adjust(wspace=0.3)


In [None]:
# Plot distributions of the per-run unsigned errors ("mues") for each automated method
fig, axs = get_subplots(systems)

for i, system in enumerate(systems):
    ax = axs[i]
    df_means = pd.DataFrame({method: distributions["standard"][system][method]["mues"] for method in automated_methods})
    sns.violinplot(data=df_means, ax=ax)
    ax.set_title(system)
    ax.set_ylabel("MUE / kcal mol$^{-1}$")
    # Label the violins with the methods that were plotted (automated_methods);
    # the full `methods` list also contains the fixed methods.
    ax.set_xticklabels(automated_methods, rotation=90)

    # Find the fixed method with the lowest "mue" and draw a horizontal line at the
    # number parsed from its name.
    # NOTE(review): that number is a discard fraction, not an error in the units of
    # this axis — this looks copied from the fraction-discarded plot; confirm intent.
    fixed_rmse = [overall_results["standard"][system][method]["mue"] for method in fixed_methods]
    best_fixed_method = fixed_methods[np.argmin(fixed_rmse[:-1])]
    best_fixed_time = float(best_fixed_method.split(" ")[-1])
    ax.axhline(best_fixed_time, color="red", linestyle="--")

    # Remove x labels from first 3 plots
    if i < 3:
        ax.set_xticklabels([])

    # Remove y labels from right column
    if i % 2 == 1:
        ax.set_ylabel("")

# Stop the long labels overlapping
fig.subplots_adjust(wspace=0.3)


In [None]:
# Scratch: RMSE of every fixed method on the "standard" dataset.
# Relies on `system` still holding whatever value an earlier cell's loop left in it.
fixed_rmse = [overall_results["standard"][system][method]["rmse"] for method in fixed_methods]

In [None]:
# Scratch check: display the current value of `fixed_rmse`
fixed_rmse

In [None]:
# Bar plots of the RMSE of each automated method for the "MDM2-PIP2" system, one panel per dataset.
# NOTE(review): this cell does not create its own figure; it draws on whatever
# `fig` / `axs` were left behind by the previously executed cell.
for i, dataset in enumerate(datasets):
    ax = axs[i]
    rmse = [overall_results[dataset]["MDM2-PIP2"][method]["rmse"] for method in automated_methods]
    cis = [overall_results[dataset]["MDM2-PIP2"][method]["rmse_ci"] for method in automated_methods]
    # Convert CIs from absolute values to distances from the RMSE, as expected by yerr
    cis_lower = abs(np.array(cis)[:,0] - np.array(rmse))
    cis_upper = abs(np.array(cis)[:,1] - np.array(rmse))
    ax.bar(automated_methods, rmse, yerr=[cis_lower, cis_upper], capsize=5)
    # Rotate labels 90 degrees. The labels must be the methods that were plotted
    # (automated_methods); the full `methods` list also contains the fixed methods.
    ax.set_xticklabels(automated_methods, rotation=90)
    ax.set_ylabel("RMSE")
    ax.set_title(dataset)

fig.tight_layout()
fig.savefig("output/rmse_bar_plots.png")

In [None]:
# Scratch check: display the current value of `cis_upper`
cis_upper

In [None]:
# Compute bias, variance, MAE, RMSE and mean fraction discarded for each method,
# then bootstrap confidence intervals for each of them.
# NOTE: this rebinds `overall_results` and `distributions` to flat, per-method
# dictionaries (no dataset/system levels).

# Imported under a separate alias so the progress bar works whether `tqdm`
# is bound to the module or to the class elsewhere in the notebook
# (the previous `tqdm.tqdm(...)` fails under `from tqdm import tqdm`).
from tqdm import tqdm as progress_bar

overall_results = {}
distributions = {}

for method in methods:
    # NOTE(review): the means are read from `data` but the discard fractions from
    # `synthetic_data_8_ns`, both indexed by the keys of `synthetic_data_8_ns` —
    # confirm that the two containers are meant to differ.
    means = [data[i][method]["mean"] for i in synthetic_data_8_ns]
    # The means are used directly as errors below, which presumably assumes the
    # true value is zero — verify against how the data were generated.
    bias = np.mean(means)
    variance = np.var(means)
    # Get the overall mean absolute error
    maes = np.abs(means)
    mae = np.mean(maes)
    mses = np.square(means)
    rmse = np.sqrt(np.mean(mses))
    # Frac discarded
    fracs = [synthetic_data_8_ns[i][method]["frac_discarded"] for i in synthetic_data_8_ns]
    frac_discarded = np.mean(fracs)
    overall_results[method] = {"bias": bias, "variance": variance, "mae": mae, "rmse": rmse, "frac_discarded": frac_discarded}
    distributions[method] = {"means": means, "fracs": fracs, "maes": maes, "mses": mses}

# Get the 95 % confidence interval for the mean for each of the properties above
for method in progress_bar(methods):
    # Get the 95 % confidence interval for the mean
    mean_ci, mean_distr = bootstrap(distributions[method]["means"])
    frac_ci, frac_distr = bootstrap(distributions[method]["fracs"])
    # Get the 95 % confidence interval for the variance, mae, rmse, and bias.
    # NOTE(review): the mae/rmse statistics passed here are element-wise transforms;
    # whether `bootstrap` averages their output is not visible from this cell — confirm.
    var_ci, var_distr = bootstrap(distributions[method]["means"], np.var)
    mae_ci, mae_distr = bootstrap(distributions[method]["means"], lambda x: np.abs(x))
    rmse_ci, rmse_distr = bootstrap(distributions[method]["means"], lambda x: np.square(x))
    # Sqrt things to get proper rmse
    rmse_ci = np.sqrt(rmse_ci)
    rmse_distr = np.sqrt(rmse_distr)
    # Store each CI as distances from the point estimate rather than absolute bounds
    overall_results[method]["bias_ci"] = [abs(val - overall_results[method]["bias"]) for val in mean_ci]
    overall_results[method]["frac_ci"] = [abs(val - overall_results[method]["frac_discarded"]) for val in frac_ci]
    overall_results[method]["var_ci"] = [abs(val - overall_results[method]["variance"]) for val in var_ci]
    overall_results[method]["mae_ci"] = [abs(val - overall_results[method]["mae"]) for val in mae_ci]
    overall_results[method]["rmse_ci"] = [abs(val - overall_results[method]["rmse"]) for val in rmse_ci]
    # Add the distributions
    distributions[method]["bias_distr"] = mean_distr
    distributions[method]["frac_distr"] = frac_distr
    distributions[method]["var_distr"] = var_distr
    distributions[method]["mae_distr"] = mae_distr
    distributions[method]["rmse_distr"] = rmse_distr
