# Analysis of Adaptive Runs

In [None]:
import a3fe as a3 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams, rcParamsDefault
import pickle
rcParams.update(rcParamsDefault)
plt.style.use("seaborn-v0_8-colorblind")
plt.rc('text.latex', preamble=r'\usepackage{amsmath}')
from typing import List, Tuple, Dict, Callable, Union
%matplotlib inline
from scipy.stats import linregress, kruskal, t, sem, wilcoxon
import logging
from matplotlib import gridspec
import pymbar

# Dict mapping directory names to calculation names.
# The keys are also the system labels used to index the results dictionaries below, and the
# iteration order of this dict sets the order in which systems are plotted.
ligs = {
    "T4L": "t4l",
    "MIF": "mif_180_anti",
    "MDM2-PIP2": "mdm2_pip2_short",
    "PDE2A": "pde2a_p10",
    "MDM2-Nutlin": "mdm2_nutlin_notprot",
}

# Conversion factor applied (together with the per-leg relative costs) to turn sampling times into GPU hours
REF_COST = 0.21 # GPU hours per ns

# Define how we want to rename some systems for display.
# Systems missing from this dict are displayed under their own key (see the `rename.get(lig, lig)` calls below).
rename = {"PDE2A": "PDE2a", "MDM2-PIP2": "MDM2-Pip2"}

## Load in the results of the slow analyses

In [None]:
# Map each module-level variable name to the pickle file holding its pre-computed results.
files = {
    "equil_dgs": "final_analysis/equil_dgs.pkl",
    "costs": "final_analysis/costs.pkl",
    "equil_times": "final_analysis/equil_times.pkl",
    "final_dgs_30_nonadapt": "final_analysis/final_dGs_30.pkl",
    "final_dgs_all": "final_analysis/final_dGs_all.pkl",
    "restraint_corrections": "final_analysis/restraint_corrections.pkl",
    "comparitive_conv_data": "final_analysis/comparitive_conv_data.pkl",
    "lam_vals": "final_analysis/lam_vals.pkl",
    "sampling_times": "final_analysis/sampling_times.pkl",
    "dgs_conv_nonequil": "final_analysis/dgs_conv_nonequil.pkl",
    "dgs_conv_nonadapt_nonequil": "final_analysis/dgs_conv_nonadapt_nonequil_times.pkl",
    "gpu_times": "final_analysis/gpu_times.pkl",
    "comparitive_conv_data_calcs": "final_analysis/comparitive_conv_data_calcs.pkl",
}

# Load every pickle and bind it to a global of the same name as its key, so that later cells
# can refer to e.g. `equil_dgs` or `costs` directly.
# NOTE(review): unpickling can execute arbitrary code - only load files that you generated yourself.
for var, file in files.items():
    with open(file, "rb") as f:
        globals()[var] = pickle.load(f)

## Notebook Layout

- Equilibration analysis
- Overall results summary table
- Spacing of intermediate states (λ windows)
- Relative costs
- Allocation of sampling time
- Table of GPU times
- Convergence of free energy estimates
- Comparison of uncertainties

## Equilibration Analysis

We want to collect all the equilibration plots onto a summary figure and we also want the data so that we can re-analyse it with Chodera's method and compare.

In [None]:
def get_95_ci(dgs: List[float]) -> Union[float, np.ndarray]:
    """Get the half-width of the 95% confidence interval of the mean.

    The interval is the Student's t interval built from the standard error
    of the mean, taken over the first axis of ``dgs`` (one entry per repeat).

    Args:
        dgs: Either a 1D sequence of values (one per repeat) or a 2D
            array-like whose first axis indexes the repeats.

    Returns:
        The distance from the mean to the upper 95% confidence bound: a single
        number for 1D input, or one value per column for 2D input. The result
        is NaN when fewer than two repeats are supplied.
    """
    dgs = np.asarray(dgs, dtype=float)
    n_repeats = len(dgs)
    # A two-sided 95 % interval is mean +/- t_(0.975, n - 1) * SEM, so the
    # half-width can be computed directly without subtracting the mean again.
    return t.ppf(0.975, n_repeats - 1) * sem(dgs, axis=0)

# Scale the time axis of every equilibration data set by the relative cost of its leg.
# NOTE: the scaled times are written back into `equil_dgs`, so running this cell a second
# time without reloading the data applies the cost factor twice.
for lig, legs_data in equil_dgs.items():
    for leg, stages_data in legs_data.items():
        for stage, stage_data in stages_data.items():
            stage_data["times"] = np.array(stage_data["times"]) * costs[lig][leg]

In [None]:
# Get the equilibration times with Chodera's method, applied to the free energy
# trace averaged over all repeat runs of each stage
equil_times_chodera = {}
for lig in ligs:
    per_leg = {}
    for leg, stages_data in equil_dgs[lig].items():
        per_stage = {}
        for stage, stage_data in stages_data.items():
            # Average over runs, then keep only the detected start index
            mean_dgs = np.mean(stage_data["dgs"], axis=0)
            idx0 = pymbar.timeseries.detectEquilibration(mean_dgs)[0]
            per_stage[stage] = idx0
        per_leg[leg] = per_stage
    equil_times_chodera[lig] = per_leg

In [None]:
# Repeat above, but get Chodera equilibration times for each individual run
# (one detected start index per run, stored as a list per stage)
equil_times_chodera_indiv = {}
for lig in ligs:
    equil_times_chodera_indiv[lig] = {}
    for leg, stages_data in equil_dgs[lig].items():
        equil_times_chodera_indiv[lig][leg] = {
            stage: [
                pymbar.timeseries.detectEquilibration(dg)[0]
                for dg in stage_data["dgs"]
            ]
            for stage, stage_data in stages_data.items()
        }

In [None]:
# Plot the equilibration times for each stage for each leg for each system
# (paired t equilibration estimate only). One grid column per system and one grid row
# per (leg, stage) pair; every panel is split into a free energy axis (top) and a
# p value axis (bottom).
fig = plt.figure(figsize = (19, 30), dpi = 600)
gs = gridspec.GridSpec(10, 5, figure = fig)
for i, lig in enumerate(ligs):
    for j, leg in enumerate(equil_times[lig]):
        for k, stage in enumerate(equil_times[lig][leg]):
            # Legs are stacked in blocks of three rows
            combined_row_ind = 3*j + k
            gs0 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec = gs[combined_row_ind, i], hspace = 0)
            ax1 = fig.add_subplot(gs0[0])
            ax2 = fig.add_subplot(gs0[1], sharex = ax1)

            # Shared info
            stage_str = stage.split(".")[1].lower().capitalize()
            leg_str = leg.split(".")[1].lower().capitalize()
            times = equil_times[lig][leg][stage]["times"]
            p_vals = equil_times[lig][leg][stage]["p_vals"]
            equil_time = equil_times[lig][leg][stage]["equil_time"]

            # On the first axis, plot the blocked free energy changes
            av_dg = np.mean(equil_dgs[lig][leg][stage]["dgs"], axis = 0)
            ci = get_95_ci(equil_dgs[lig][leg][stage]["dgs"])
            av_dg_times = equil_dgs[lig][leg][stage]["times"][0]
            ax1.plot(av_dg_times, av_dg, label = "Free energy change")
            # Fill between the confidence intervals
            ax1.fill_between(
                av_dg_times,
                av_dg - ci,
                av_dg + ci,
                alpha=0.2,
            )
            # Vertical line at the equilibration time
            ax1.axvline(equil_time, linestyle = "--", label = "Equil. time (paired $t$)", color = "black")
            lig_title = rename.get(lig, lig)
            ax1.set_title(f"{lig_title} {leg_str} {stage_str}")
            # Raw string so that the LaTeX backslash is not treated as an (invalid) escape sequence
            ax1.set_ylabel(r"$\Delta G$ (kcal/mol)")
            # Hide all x axis labels to avoid overlap
            ax1.get_xaxis().set_visible(False)

            # Plot the 30 ns non-adaptive result
            non_adapt_dgs = final_dgs_30_nonadapt[lig][leg][stage]
            val = non_adapt_dgs[0]
            half_width = non_adapt_dgs[1]
            # Plot a horizontal line at the value with error bars.
            # Use the second colour of the colour cycle so that this line can be told apart from
            # the "Free energy change" line, which is drawn in the first colour. (The previous
            # expression, next(iter(cycle)), always returned the first colour.)
            color = rcParams['axes.prop_cycle'].by_key()['color'][1]
            start = av_dg_times[0]
            end = av_dg_times[-1]
            # Use plot instead of axhline so we can choose where to start and end
            ax1.plot([start, start, end, end], [val, val, val, val], linestyle = "-", label = r"Final Non-Adaptive $\Delta G$", alpha = 0.8, color = color)
            # Fill between for error bar. Make sure the start and end match the data
            ax1.fill_between(
                [start, end],
                val - half_width,
                val + half_width,
                alpha=0.2,
            )

            # On the second axis, plot the p values
            ax2.plot(times, p_vals, marker = "o", label = "$p$ value")
            ax2.axvline(equil_time, linestyle = "--", label = "Equil. time (paired $t$)", color = "black")
            # Horizontal line at p = 0.05
            ax2.axhline(0.05, linestyle = "--", label = "p = 0.05", color="black")
            ax2.set_xlabel("GPU Hours")
            ax2.set_ylabel("$p$ value")
            # Set y limits at 0 and 1
            ax2.set_ylim(0, 1)
            fig.tight_layout()
            # Hide the label at 1 to avoid clash with upper x axis
            ax2.get_yticklabels()[-1].set_visible(False)

# Add a legend to the figure, off to the right hand side of all the plots.
# The entries are taken from the sixth-from-last axes that was created.
axs = fig.get_axes()
axs[-6].legend(bbox_to_anchor = (1.05, 0.5), loc = "center left", borderaxespad = 0)
fig.tight_layout()
fig.savefig("final_analysis/equil_times_no_chodera.png", dpi = 600, bbox_inches = "tight")

In [None]:
# Plot the equilibration times for each stage for each leg for each system, this time
# including the Chodera (maximum effective sample size) estimates for the run-averaged
# trace and for every individual run. One grid column per system and one grid row per
# (leg, stage) pair; every panel is split into a free energy axis (top) and a p value
# axis (bottom).
fig = plt.figure(figsize = (19, 30), dpi = 600)
gs = gridspec.GridSpec(10, 5, figure = fig)
for i, lig in enumerate(ligs):
    for j, leg in enumerate(equil_times[lig]):
        for k, stage in enumerate(equil_times[lig][leg]):
            # Legs are stacked in blocks of three rows
            combined_row_ind = 3*j + k
            gs0 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec = gs[combined_row_ind, i], hspace = 0)
            ax1 = fig.add_subplot(gs0[0])
            ax2 = fig.add_subplot(gs0[1], sharex = ax1)

            # Shared info
            stage_str = stage.split(".")[1].lower().capitalize()
            leg_str = leg.split(".")[1].lower().capitalize()
            times = equil_times[lig][leg][stage]["times"]
            p_vals = equil_times[lig][leg][stage]["p_vals"]
            equil_time = equil_times[lig][leg][stage]["equil_time"]

            # On the first axis, plot the blocked free energy changes
            av_dg = np.mean(equil_dgs[lig][leg][stage]["dgs"], axis = 0)
            ci = get_95_ci(equil_dgs[lig][leg][stage]["dgs"])
            av_dg_times = equil_dgs[lig][leg][stage]["times"][0]
            ax1.plot(av_dg_times, av_dg, label = "Free energy change")
            # Fill between the confidence intervals
            ax1.fill_between(
                av_dg_times,
                av_dg - ci,
                av_dg + ci,
                alpha=0.2,
            )
            # Vertical line at the equilibration time
            ax1.axvline(equil_time, linestyle = "--", label = "Equil. time (paired $t$)", color = "black")
            lig_title = rename.get(lig, lig)
            ax1.set_title(f"{lig_title} {leg_str} {stage_str}")
            # Raw strings below so that LaTeX backslashes are not treated as (invalid) escape sequences
            ax1.set_ylabel(r"$\Delta G$ (kcal/mol)")
            # Red vertical lines at the individual run Chodera equilibration times.
            # Only the first line is labelled so that the legend gets a single entry.
            for z, equil_idx_chodera_indiv in enumerate(equil_times_chodera_indiv[lig][leg][stage]):
                # Convert the detected index into a position on the time axis
                equil_time_chodera_indiv = av_dg_times[equil_idx_chodera_indiv]
                if z !=0:
                    ax1.axvline(equil_time_chodera_indiv, linestyle = ":", color = "red")
                else:
                    ax1.axvline(equil_time_chodera_indiv, linestyle = ":", color = "red", label = r"Equil. time (Max $N_\mathrm{Eff.}$, per run)")
            # Vertical line at Chodera's equilibration time
            equil_time_chodera = av_dg_times[equil_times_chodera[lig][leg][stage]]
            ax1.axvline(equil_time_chodera, linestyle = ":", label = r"Equil. time (Max $N_\mathrm{Eff.}$, mean)", color = "black")
            # Hide all x axis labels to avoid overlap
            ax1.get_xaxis().set_visible(False)

            # Plot the 30 ns non-adaptive result
            non_adapt_dgs = final_dgs_30_nonadapt[lig][leg][stage]
            val = non_adapt_dgs[0]
            half_width = non_adapt_dgs[1]
            # Plot a horizontal line at the value with error bars.
            # Use the second colour of the colour cycle so that this line can be told apart from
            # the "Free energy change" line, which is drawn in the first colour. (The previous
            # expression, next(iter(cycle)), always returned the first colour.)
            color = rcParams['axes.prop_cycle'].by_key()['color'][1]
            start = av_dg_times[0]
            end = av_dg_times[-1]
            # Use plot instead of axhline so we can choose where to start and end
            ax1.plot([start, start, end, end], [val, val, val, val], linestyle = "-", label = r"Final Non-Adaptive $\Delta G$", alpha = 0.8, color = color)
            # Fill between for error bar. Make sure the start and end match the data
            ax1.fill_between(
                [start, end],
                val - half_width,
                val + half_width,
                alpha=0.2,
            )

            # On the second axis, plot the p values
            ax2.plot(times, p_vals, marker = "o", label = "$p$ value")
            ax2.axvline(equil_time, linestyle = "--", label = "Equil. time (paired $t$)", color = "black")
            # Horizontal line at p = 0.05
            ax2.axhline(0.05, linestyle = "--", label = "p = 0.05", color="black")
            ax2.set_xlabel("GPU Hours")
            ax2.set_ylabel("$p$ value")
            # Set y limits at 0 and 1
            ax2.set_ylim(0, 1)
            fig.tight_layout()
            # Hide the label at 1 to avoid clash with upper x axis
            ax2.get_yticklabels()[-1].set_visible(False)

# Add a legend to the figure, off to the right hand side of all the plots.
# The entries are taken from the sixth-from-last axes that was created.
axs = fig.get_axes()
axs[-6].legend(bbox_to_anchor = (1.05, 0.5), loc = "center left", borderaxespad = 0)
fig.tight_layout()
fig.savefig("final_analysis/equil_times_all.png", dpi = 600, bbox_inches = "tight")
            

In [None]:
# Repeat above, but just plot the bound vanish stages.
# One panel per system, two panels per grid row; every panel is split into a free energy
# axis (top) and a p value axis (bottom).
fig = plt.figure(figsize = (8, 17), dpi = 600)
gs = gridspec.GridSpec(6, 3, figure = fig)
leg = "LegType.BOUND"
stage = "StageType.VANISH"
for i, lig in enumerate(ligs):
    col = i % 2
    row = i // 2
    gs0 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec = gs[row, col], hspace = 0)
    ax1 = fig.add_subplot(gs0[0])
    ax2 = fig.add_subplot(gs0[1], sharex = ax1)

    # Shared info
    stage_str = stage.split(".")[1].lower().capitalize()
    leg_str = leg.split(".")[1].lower().capitalize()
    times = equil_times[lig][leg][stage]["times"]
    p_vals = equil_times[lig][leg][stage]["p_vals"]
    equil_time = equil_times[lig][leg][stage]["equil_time"]

    # On the first axis, plot the blocked free energy changes
    av_dg = np.mean(equil_dgs[lig][leg][stage]["dgs"], axis = 0)
    ci = get_95_ci(equil_dgs[lig][leg][stage]["dgs"])
    av_dg_times = equil_dgs[lig][leg][stage]["times"][0]
    ax1.plot(av_dg_times, av_dg, label = "Free energy change")
    # Fill between the confidence intervals
    ax1.fill_between(
        av_dg_times,
        av_dg - ci,
        av_dg + ci,
        alpha=0.2,
    )
    # Vertical line at the equilibration time
    ax1.axvline(equil_time, linestyle = "--", label = "Equil. time (paired $t$)", color = "black")
    # Use the display name straight away (PDE2A -> PDE2a, MDM2-PIP2 -> MDM2-Pip2) rather than
    # patching the titles afterwards through hard-coded axes indices
    lig_title = rename.get(lig, lig)
    ax1.set_title(f"{lig_title} {leg_str} {stage_str}")
    # Raw strings below so that LaTeX backslashes are not treated as (invalid) escape sequences
    ax1.set_ylabel(r"$\Delta G$ (kcal/mol)")
    # Red vertical lines at the individual run Chodera equilibration times.
    # Only the first line is labelled so that the legend gets a single entry.
    for z, equil_idx_chodera_indiv in enumerate(equil_times_chodera_indiv[lig][leg][stage]):
        # Convert the detected index into a position on the time axis
        equil_time_chodera_indiv = av_dg_times[equil_idx_chodera_indiv]
        if z !=0:
            ax1.axvline(equil_time_chodera_indiv, linestyle = ":", color = "red")
        else:
            ax1.axvline(equil_time_chodera_indiv, linestyle = ":", color = "red", label = r"Equil. time (Max $N_\mathrm{Eff.}$, per run)")
    # Vertical line at Chodera's equilibration time
    equil_time_chodera = av_dg_times[equil_times_chodera[lig][leg][stage]]
    ax1.axvline(equil_time_chodera, linestyle = ":", label = r"Equil. time (Max $N_\mathrm{Eff.}$, mean)", color = "black")
    # Hide all x axis labels to avoid overlap
    ax1.get_xaxis().set_visible(False)

    # Plot the 30 ns non-adaptive result
    non_adapt_dgs = final_dgs_30_nonadapt[lig][leg][stage]
    val = non_adapt_dgs[0]
    half_width = non_adapt_dgs[1]
    # Plot a horizontal line at the value with error bars.
    # Use the second colour of the colour cycle so that this line can be told apart from
    # the "Free energy change" line, which is drawn in the first colour. (The previous
    # expression, next(iter(cycle)), always returned the first colour.)
    color = rcParams['axes.prop_cycle'].by_key()['color'][1]
    start = av_dg_times[0]
    end = av_dg_times[-1]
    # Use plot instead of axhline so we can choose where to start and end
    ax1.plot([start, start, end, end], [val, val, val, val], linestyle = "-", label = r"Final Non-Adaptive $\Delta G$", alpha = 0.8, color = color)
    # Fill between for error bar. Make sure the start and end match the data
    ax1.fill_between(
        [start, end],
        val - half_width,
        val + half_width,
        alpha=0.2,
    )

    # On the second axis, plot the p values
    ax2.plot(times, p_vals, marker = "o", label = "$p$ value")
    ax2.axvline(equil_time, linestyle = "--", label = "Equil. time (paired $t$)", color = "black")
    # Horizontal line at p = 0.05
    ax2.axhline(0.05, linestyle = "--", label = "p = 0.05", color="black")
    ax2.set_xlabel("GPU Hours")
    ax2.set_ylabel("$p$ value")
    # Set y limits at 0 and 1
    ax2.set_ylim(0, 1)
    fig.tight_layout()
    # Hide the label at 1 to avoid clash with upper x axis
    ax2.get_yticklabels()[-1].set_visible(False)

# Add a legend to the figure, anchored below and to the right of the ninth axes created
axs = fig.get_axes()
fig.tight_layout()
axs[8].legend(bbox_to_anchor = (1.15, -0.35), loc = "lower left", borderaxespad = 0)
fig.savefig("final_analysis/equil_times.png", dpi = 600, bbox_inches = "tight")
            

In [None]:
# For the bound vanish stage of every system, compare the two equilibration estimates:
# discard the data before each estimate, average what is left over all runs and times,
# and record the absolute difference from the final 30 ns non-adaptive result.
# NOTE(review): the Chodera estimate is an array index, whereas the paired-t estimate is
# `equil_time` rounded to an integer and then used as an index - confirm that `equil_time`
# is expressed in index units here.
results_overall = {"Chodera": [], "Paired_t": []}
for lig in ligs:
    for stage in ["StageType.VANISH"]:
        bound_equil_times = equil_times[lig]["LegType.BOUND"][stage]
        equil_time_chodera = equil_times_chodera[lig]["LegType.BOUND"][stage]
        equil_time_paired_t = round(bound_equil_times["equil_time"])
        dgs = equil_dgs[lig]["LegType.BOUND"][stage]["dgs"]
        final_nonequil_dg = final_dgs_30_nonadapt[lig]["LegType.BOUND"][stage][0]
        equil_times_local = {
            "Chodera": equil_time_chodera,
            "Paired_t": equil_time_paired_t,
        }
        for method, equil_time in equil_times_local.items():
            # Keep only the columns from the equilibration index onwards, for every run
            dgs_trunc = dgs[:, equil_time:]
            mean_dgs = np.mean(dgs_trunc)
            results_overall[method].append(abs(mean_dgs - final_nonequil_dg))

In [None]:
# Bar plot of the absolute differences for each method for each system
fig, ax = plt.subplots(figsize = (5, 4))
x = np.arange(len(ligs))
width = 0.3
# One bar per method, drawn side by side within each system's group
for i, method in enumerate(results_overall):
    label = method if method == "Chodera" else "Paired $t$"
    ax.bar(x + i*width, results_overall[method], width, label = label)
# Centre the tick between the two bars of each group
ax.set_xticks(x + width/2)
ax.set_xticklabels([rename.get(lig, lig) for lig in ligs])
# This label needs a real newline (\n) as well as a literal LaTeX backslash, so the
# backslash before "Delta" is escaped explicitly instead of relying on an invalid escape
ax.set_ylabel("Absolute difference in $\\Delta G$ Compared \nto 30 ns Result / kcal mol$^{-1}$")
ax.legend()
fig.tight_layout()
fig.savefig("final_analysis/abs_diffs_chod_paired_t_initial_sys.png", dpi = 600, bbox_inches = "tight")

In [None]:
# Get the RMSE for Chodera and paired-t (root mean square of the per-system absolute differences)
rmse_chod = np.sqrt(np.mean(np.array(results_overall["Chodera"])**2))
rmse_paired_t = np.sqrt(np.mean(np.array(results_overall["Paired_t"])**2))
# Plain-text units: inside an f-string "{-1}" is evaluated as an expression, so the previous
# LaTeX-style "$^{-1}$" was printed as "$^-1$"
print(f"RMSE Chodera: {rmse_chod:.2f} kcal mol^-1")
print(f"RMSE Paired t: {rmse_paired_t:.2f} kcal mol^-1")

In [None]:
# Perform a Wilcoxon signed rank test to see if the differences are significant.
# With one pair per system in `ligs` (5 pairs) the smallest attainable exact two-sided
# p value is 2 / 2**5 = 0.0625, so the test cannot reach p < 0.05; get the p value anyway.
stat, p = wilcoxon(results_overall["Chodera"], results_overall["Paired_t"])
print(f"Wilcoxon signed rank test p value: {p:.2f}")

## Overall Results Summary Table

In [None]:
def get_95_ci(data: np.ndarray) -> Union[float, np.ndarray]:
    """Get the half-width of the 95% confidence interval of the mean using scipy.stats.sem.

    Args:
        data: Array of values whose first axis indexes the repeats.

    Returns:
        The distance from the mean to the upper bound of the Student's t 95%
        confidence interval: a single number for 1D data, or one value per
        column for 2D data. The result is NaN when fewer than two repeats are
        supplied.
    """
    data = np.asarray(data, dtype=float)
    # A two-sided 95 % interval is mean +/- t_(0.975, n - 1) * SEM, so the
    # half-width can be computed directly without subtracting the mean again.
    return t.ppf(0.975, len(data) - 1) * sem(data)
    
# Per-system symmetry corrections; these are added to the mean calculated free energy
# when the summary table is built below.
symmetry_corrections = {"T4L": -0.41, "MIF": -0.65, "MDM2-Nutlin": 0.00, "MDM2-PIP2": 0.00, "PDE2A": 0.00}

# Experimental binding free energies, their uncertainties and where they were taken from.
# No experimental value is recorded for MDM2-PIP2 (NaN placeholders).
exp_dgs = {"T4L": {"Exp. dG": -5.19, "Exp. dG Err": 0.16, "source": "https://pubs.acs.org/doi/epdf/10.1021/ct060037v"},
           "MIF": {"Exp. dG": -8.98, "Exp. dG Err": 0.28, "source": "Flourescence polarisation assay, https://pubs.acs.org/doi/10.1021/jacs.6b04910"},
           "MDM2-Nutlin": {"Exp. dG": -11.14, "Exp. dG Err": 0.27, "source": "ITC,  https://pubs.rsc.org/en/content/articlelanding/2022/SC/D2SC00028H"},
           "PDE2A": {"Exp. dG": -14.35, "Exp. dG Err": 0.82, "source": "IC50,  https://pubs.acs.org/doi/pdf/10.1021/acs.jctc.1c01208 https://pubs.acs.org/doi/10.1021/acs.jmedchem.8b00116?ref=PDF"},
           "MDM2-PIP2": {"Exp. dG": np.nan, "Exp. dG Err": np.nan, "source": "N/A"}}


# Build the summary table: for every system and method, the mean free energy over the
# repeat runs (plus the symmetry correction) and its 95 % confidence interval half-width
results_summary = {}
for system, methods in final_dgs_all.items():
    system_results = {}
    for method, method_data in methods.items():
        dgs = np.array(method_data["dgs"])
        dg_tot = np.mean(dgs) + symmetry_corrections[system]
        dg_err = get_95_ci(dgs)
        system_results[method] = f"{dg_tot:.2f}" + r" $\pm$ " + f"{dg_err:.2f}"
    results_summary[system] = system_results

# One row per system (with display names for PDE2a and MDM2-Pip2), one column per method
results_summary_df = pd.DataFrame(results_summary).T.rename(
    index={"PDE2A": "PDE2a", "MDM2-PIP2": "MDM2-Pip2"}
)
results_summary_df.to_latex("final_analysis/results_summary.tex", escape=False)

In [None]:
# Repeat as above, but make a much more explicit table including all contributions to the final free energy change and the
# symmetry corrections. After the transpose at the end, each row is one "system method" combination and each column is one
# contribution to the free energy change.
results_summary_detailed = {}
for system in final_dgs_all:
    # Put the display name (PDE2a, MDM2-Pip2) into the row label itself. Renaming the index
    # afterwards by the bare system name never matched, because the labels are "system method".
    system_title = rename.get(system, system)
    for method in final_dgs_all[system]:
        title = f"{system_title} {method}"
        results_summary_detailed[title] = {}
        for leg_type in final_dgs_all[system][method]:
            # "dgs" holds the overall results rather than a leg, so skip it
            if leg_type == "dgs":
                continue
            leg_name = leg_type.split(".")[1].capitalize()
            for stage_type in final_dgs_all[system][method][leg_type]:
                # "dg" is the leg-level entry rather than a stage, so skip it
                if stage_type == "dg":
                    continue
                stage_name = stage_type.split(".")[1].capitalize()
                dgs = np.array(final_dgs_all[system][method][leg_type][stage_type]["dg"])
                dg_tot = np.mean(dgs)
                dg_err = get_95_ci(dgs)
                result_str = f"{dg_tot:.2f}" + r" $\pm$ " + f"{dg_err:.2f}"
                results_summary_detailed[title][f"{leg_name} {stage_name}"] = result_str
        # Add in symmetry correction and restraint correction, along with experimental results
        restraint_corr = restraint_corrections[system]
        restraint_corr_str = f"{restraint_corr:.2f}"
        results_summary_detailed[title]["Restraint Correction"] = restraint_corr_str
        symmetry_corr = symmetry_corrections[system]
        symmetry_corr_str = f"{symmetry_corr:.2f}"
        results_summary_detailed[title]["Symmetry Correction"] = symmetry_corr_str
        exp_res = f"{exp_dgs[system]['Exp. dG']:.2f}" + r" $\pm$ " + f"{exp_dgs[system]['Exp. dG Err']:.2f}"
        results_summary_detailed[title][r"Exp. $\Delta G^o_\textrm{Bind}$"] = exp_res

results_summary_detailed_df = pd.DataFrame(results_summary_detailed).T
results_summary_detailed_df.to_latex("final_analysis/results_summary_detailed.tex", escape=False)

## Spacing of Intermediate States

In [None]:
# For every stage for every system, plot the lambda values vs lambda index and show a dotted line on x = y.
# One column per system and one row per stage; both legs of a stage share a panel.
# Imported once here rather than on every pass through the innermost loop
from matplotlib.ticker import MaxNLocator

fig, axs = plt.subplots(3, 5, figsize = (17, 9))
for i, lig in enumerate(ligs):
    for j, leg in enumerate(lam_vals[lig]):
        # The stage list (and so the row index k) is always taken from the bound leg
        for k, stage in enumerate(lam_vals[lig]["LegType.BOUND"]):
            # The first bound stage is not plotted for the free leg
            # (presumably because the free leg has no such stage - TODO confirm)
            if k == 0 and leg == "LegType.FREE":
                continue
            stage_str = stage.split(".")[1].lower().capitalize()
            leg_str = leg.split(".")[1].lower().capitalize()
            lambda_idx = list(range(len(lam_vals[lig][leg][stage])))
            # Evenly spaced reference values from 0 to 1 for comparison
            linear = np.linspace(0, 1, len(lambda_idx))
            lam_vals_local = lam_vals[lig][leg][stage]
            axs[k, i].plot(lambda_idx, lam_vals_local, label = leg_str, marker = "o")
            axs[k, i].plot(lambda_idx, linear, linestyle = "--", color = "black")
            axs[k, i].set_title(f"{rename.get(lig, lig)} {stage_str}")
            axs[k, i].set_xlabel("Window Index")
            # Raw string so that the LaTeX backslash is not treated as an (invalid) escape sequence
            axs[k, i].set_ylabel(r"$\lambda$")
            # Add text stating the number of windows (offset so the two legs do not overlap)
            offset = 0.05 if leg == "LegType.FREE" else 0.1
            axs[k, i].text(0.65, offset, f"No. windows {leg_str}: {len(lam_vals_local)}", transform = axs[k, i].transAxes, horizontalalignment = "center", verticalalignment = "center")

            # Ensure that only integers are shown on the x axis
            axs[k, i].xaxis.set_major_locator(MaxNLocator(integer=True))

# Add a legend to the figure, off to the right hand side of all the plots
axs[1,4].legend(bbox_to_anchor = (1.05, 0.5), loc = "center left", borderaxespad = 0)
fig.tight_layout()
fig.savefig("final_analysis/lambda_values.png", dpi = 600, bbox_inches = "tight")


## Relative Costs

In [None]:
# Bar plot of the relative costs: one group of bars per leg, one bar per system
# (systems shown under their display names)
costs_df = pd.DataFrame(costs).rename(
    columns={lig: rename.get(lig, lig) for lig in ligs}
)
# Relabel the rows.
# NOTE(review): assumes `costs` yields exactly two rows, bound leg first - confirm.
costs_df.index = ["Bound Leg", "Free Leg"]
# Black edges give the bars a clear outline
costs_df.plot.bar(figsize = (10, 5), rot = 0, xlabel = "Leg", ylabel = "Relative cost", edgecolor = "black")
plt.savefig("final_analysis/costs.png", dpi = 600, bbox_inches = "tight")

## Allocation of Sampling Time

In [None]:

# For every stage of every leg of every system, plot the total sampling time and the
# equilibration time of each lambda window, both converted to GPU hours.
# One column per system; rows are stacked leg by leg.
fig, axs = plt.subplots(5, 5, figsize = (19, 17), dpi = 600)
for i, lig in enumerate(ligs):
    for j, leg in enumerate(sampling_times[lig]):
        for k, stage in enumerate(sampling_times[lig][leg]):
            stage_str = stage.split(".")[1].lower().capitalize()
            leg_str = leg.split(".")[1].lower().capitalize()
            times = sampling_times[lig][leg][stage]["times"]
            times = np.array(times) * costs[lig][leg] * REF_COST
            # Deliberately NOT called `equil_times`: that name is the global dictionary loaded
            # from equil_times.pkl, and overwriting it breaks re-running the equilibration cells
            equil_times_local = sampling_times[lig][leg][stage]["equil_times"]
            equil_times_local = np.array(equil_times_local) * costs[lig][leg] * REF_COST
            lam_vals_local = sampling_times[lig][leg][stage]["lam_vals"]
            # Bar plots for sampling times
            ax = axs[k + 3*j, i]
            # Get reasonable width for bars
            width = 0.6 / len(lam_vals_local)
            ax.bar(lam_vals_local, times, label = "Total sampling time", edgecolor = "black", width = width)
            # Plot the equilibration times (hatched, drawn over the total sampling time bars)
            ax.bar(lam_vals_local, equil_times_local, label = "Equilibration time", edgecolor = "black", width = width, hatch = "///////")
            ax.set_title(f"{rename.get(lig, lig)} {leg_str} {stage_str}")
            # Raw string so that the LaTeX backslash is not treated as an (invalid) escape sequence
            ax.set_xlabel(r"$\lambda$")
            ax.set_ylabel("GPU Hours")

# Create a legend for the figure to the right of all the plots
axs[2,4].legend(bbox_to_anchor = (1.05, 0.5), loc = "center left", borderaxespad = 0)
fig.tight_layout()
fig.savefig("final_analysis/sampling_times.png", dpi = 600, bbox_inches = "tight")

In [None]:
# Coarser summary of the sampling time allocation: the total sampling time (in GPU hours)
# for each stage of each system, all on a single bar plot.

fig, ax = plt.subplots(figsize=(8, 4), dpi=600)
x = np.arange(len(sampling_times))
width = 0.2
# One bar position per stage within each system's group; the hatched free-leg bar is
# drawn at the same position as the bound-leg bar of that stage
for i, stage in enumerate(sampling_times["MDM2-Nutlin"]["LegType.BOUND"]):
    # Both legs of a stage share one colour
    color = ax._get_lines.get_next_color()
    stage_name = stage.split('.')[1].capitalize()
    bar_positions = x + (i * width)
    tot_times_bound = [
        np.sum(sampling_times[system]["LegType.BOUND"][stage]["times"]) * costs[system]["LegType.BOUND"] * REF_COST
        for system in sampling_times
    ]
    ax.bar(bar_positions, tot_times_bound, width, label=f"Bound {stage_name}", edgecolor="k", alpha=1, color=color)
    # The restrain stage is only plotted for the bound leg
    if stage != "StageType.RESTRAIN":
        tot_times_free = [
            np.sum(sampling_times[system]["LegType.FREE"][stage]["times"]) * costs[system]["LegType.FREE"] * REF_COST
            for system in sampling_times
        ]
        ax.bar(bar_positions, tot_times_free, width, label=f"Free {stage_name}", edgecolor="k", alpha=1, color=color, hatch="///////")

# Centre each tick under its group of bars
ax.set_xticks(x + width)
x_labels = [rename.get(system, system) for system in sampling_times]
ax.set_xticklabels(x_labels)
ax.set_ylabel("Total Sampling Time / GPU Hours")
# Put the legend off to the right of the plot
ax.legend(bbox_to_anchor=(1.03, 0.7))
plt.tight_layout()
fig.savefig("final_analysis/sampling_times_summary.png", bbox_inches="tight", dpi=600)


In [None]:
# Same coarse summary of the sampling time allocation as the combined bound/free plot,
# but showing only the free-leg bars.

fig, ax = plt.subplots(figsize=(8, 4), dpi=600)
x = np.arange(len(sampling_times))
width = 0.2
# Loop over the bound-leg stage list so that bar positions are assigned per stage index
for i, stage in enumerate(sampling_times["MDM2-Nutlin"]["LegType.BOUND"]):
    # A colour is drawn for every stage, including the skipped restrain stage
    color = ax._get_lines.get_next_color()
    if stage == "StageType.RESTRAIN":
        continue
    stage_name = stage.split('.')[1].capitalize()
    tot_times_free = [
        np.sum(sampling_times[system]["LegType.FREE"][stage]["times"]) * costs[system]["LegType.FREE"] * REF_COST
        for system in sampling_times
    ]
    ax.bar(x + (i * width), tot_times_free, width, label=f"Free {stage_name}", edgecolor="k", alpha=1, color=color, hatch="///////")

# Tick position for each system's group of bars
ax.set_xticks(x + 1.5*width)
x_labels = [rename.get(system, system) for system in sampling_times]
ax.set_xticklabels(x_labels)
ax.set_ylabel("Total Sampling Time / GPU Hours")
# Put the legend off to the right of the plot
ax.legend(bbox_to_anchor=(1.03, 0.7))
plt.tight_layout()
fig.savefig("final_analysis/sampling_times_summary_free_only.png", bbox_inches="tight", dpi=600)


In [None]:
# Repeat above plots, but show total sampling time per window
# (the total for each stage divided by the number of lambda windows in that stage and leg)

fig, ax = plt.subplots(figsize=(8, 4), dpi=600)
x = np.arange(len(sampling_times))
width = 0.2
# One bar position per stage within each system's group; the hatched free-leg bar is
# drawn at the same position as the bound-leg bar of that stage
for i, stage in enumerate(sampling_times["MDM2-Nutlin"]["LegType.BOUND"]):
    # Number of bound-leg windows for this stage, one entry per system
    n_windows = np.array([len(sampling_times[system]["LegType.BOUND"][stage]["times"]) for system in sampling_times])
    # Get single colour
    color = ax._get_lines.get_next_color()
    tot_times_bound = np.array([np.sum(sampling_times[system]["LegType.BOUND"][stage]["times"])*costs[system]["LegType.BOUND"]*REF_COST for system in sampling_times])
    tot_times_bound_per_window = tot_times_bound / n_windows
    ax.bar(x + (i * width), tot_times_bound_per_window, width, label=f"Bound {stage.split('.')[1].capitalize()}", edgecolor="k", alpha=1, color=color)
    if stage != "StageType.RESTRAIN":
        # Divide the free-leg totals by the free leg's own window count; the bound-leg count
        # used previously is only correct when both legs have the same number of windows
        n_windows_free = np.array([len(sampling_times[system]["LegType.FREE"][stage]["times"]) for system in sampling_times])
        tot_times_free = np.array([np.sum(sampling_times[system]["LegType.FREE"][stage]["times"])*costs[system]["LegType.FREE"]*REF_COST for system in sampling_times])
        tot_times_free_per_window = tot_times_free / n_windows_free
        ax.bar(x + (i * width), tot_times_free_per_window, width, label=f"Free {stage.split('.')[1].capitalize()}", edgecolor="k", alpha=1, color=color, hatch="///////")

ax.set_xticks(x + width)
x_labels = [rename.get(system, system) for system in sampling_times]
ax.set_xticklabels(x_labels)
ax.set_ylabel("Total Sampling Time per Window / GPU Hours")
# Put label off to right of plot
ax.legend(bbox_to_anchor=(1.03, 0.7))
plt.tight_layout()
fig.savefig("final_analysis/sampling_times_summary_per_window.png", bbox_inches="tight", dpi=600)


In [None]:
# Repeat the per-window plot above, but show only the free-leg bars

fig, ax = plt.subplots(figsize=(8, 4), dpi=600)
x = np.arange(len(sampling_times))
width = 0.2
# Loop over the bound-leg stage list so that bar positions are assigned per stage index
for i, stage in enumerate(sampling_times["MDM2-Nutlin"]["LegType.BOUND"]):
    # A colour is drawn for every stage, including the skipped restrain stage
    color = ax._get_lines.get_next_color()
    if stage != "StageType.RESTRAIN":
        # Count the windows of the free leg itself; the bound-leg count used previously is
        # only correct when both legs have the same number of windows
        n_windows = np.array([len(sampling_times[system]["LegType.FREE"][stage]["times"]) for system in sampling_times])
        tot_times_free = np.array([np.sum(sampling_times[system]["LegType.FREE"][stage]["times"])*costs[system]["LegType.FREE"]*REF_COST for system in sampling_times])
        tot_times_free_per_window = tot_times_free / n_windows
        ax.bar(x + (i * width), tot_times_free_per_window, width, label=f"Free {stage.split('.')[1].capitalize()}", edgecolor="k", alpha=1, color=color, hatch="///////")

ax.set_xticks(x + 1.5*width)
x_labels = [rename.get(system, system) for system in sampling_times]
ax.set_xticklabels(x_labels)
ax.set_ylabel("Total Sampling Time per Window / GPU Hours")
# Put label off to right of plot
ax.legend(bbox_to_anchor=(1.03, 0.7))
plt.tight_layout()
fig.savefig("final_analysis/sampling_times_summary_per_window_free_only.png", bbox_inches="tight", dpi=600)


## Table of GPU Times

In [None]:
# Build the GPU time table and append a "Total" column holding the sum of each row
gpu_times_df = pd.DataFrame(gpu_times)
gpu_times_df["Total"] = gpu_times_df.sum(axis=1)
# Round to the nearest whole number and store as integers so no decimals are displayed
gpu_times_df = gpu_times_df.round(0).astype(int)
# Export as a LaTeX table
gpu_times_df.to_latex("final_analysis/gpu_times.tex", escape=False)


In [None]:
# Display the GPU time table as the cell output
gpu_times_df

## Convergence of Free Energy Estimates

In [None]:
# List the systems in the order we want them plotted
systems = ["T4L", "MIF", "MDM2-PIP2", "PDE2A", "MDM2-Nutlin"]
# Lookup tables from short lower-case names to the "StageType.*" / "LegType.*" strings
# that are used as dictionary keys in the loaded results
stage_types = {"restrain": "StageType.RESTRAIN", "discharge": "StageType.DISCHARGE", "vanish": "StageType.VANISH"}
leg_types = {"bound": "LegType.BOUND", "free": "LegType.FREE"}

def plot_dgs_conv(ax: plt.Axes, dgs: np.ndarray, times: np.ndarray, leg: str, stage: str, system: str, label: str, show_final_dg: bool = True, scale_er: float = 1) -> None:
    """Plot the mean free energy change across runs, with its 95 % C.I., against GPU time.

    Args:
        ax: Axis to draw on.
        dgs: Free energy changes. The mean and standard error are taken over the
            first axis (one row per run), leaving one value per entry of ``times``.
        times: x values (labelled "GPU Hours") matching the columns of ``dgs``.
        leg: Dotted leg name, e.g. "LegType.BOUND". Indexes ``final_dgs_30_nonadapt``;
            the part after the dot is shown in the title.
        stage: Dotted stage name, e.g. "StageType.VANISH". Used like ``leg``.
        system: System name. Indexes ``final_dgs_30_nonadapt`` and is shown in the
            title after applying the module-level ``rename`` mapping.
        label: Legend label for this curve, e.g. "6 ns" or "Adaptive".
        show_final_dg: If True, mark ``final_dgs_30_nonadapt[system][leg][stage]``
            (value, and value +/- error) with horizontal reference lines.
        scale_er: Factor applied to the C.I. half-width before it is drawn.
    """
    # Tag for the reference lines so they can be recognised on later calls
    ref_gid = "final_dg_reference"
    mean_free_energy = np.mean(dgs, axis=0)
    # Half-width of the two-sided 95 % Student-t interval (len(dgs) - 1 degrees of freedom)
    upper_bound = t.interval(0.95, len(dgs) - 1, mean_free_energy, scale=sem(dgs))[1]
    conf_int = (upper_bound - mean_free_energy) * scale_er
    # This function is called once per curve on the same axis. Draw the reference lines
    # only the first time, otherwise the semi-transparent lines stack up and darken.
    already_marked = any(line.get_gid() == ref_gid for line in ax.lines)
    if show_final_dg and not already_marked:
        final_dg, final_dg_err = final_dgs_30_nonadapt[system][leg][stage]
        # Dashed line at the final value, dotted lines at +/- its error; alpha keeps them light
        ax.axhline(final_dg, color="black", linestyle="--", alpha=0.5, gid=ref_gid)
        ax.axhline(final_dg + final_dg_err, color="black", linestyle=":", alpha=0.5, gid=ref_gid)
        ax.axhline(final_dg - final_dg_err, color="black", linestyle=":", alpha=0.5, gid=ref_gid)
    # Mean curve with a shaded band for the (scaled) 95 % C.I.
    ax.plot(times, mean_free_energy, label=label)
    ax.fill_between(times, mean_free_energy - conf_int, mean_free_energy + conf_int, alpha=0.5)
    # Label the plot
    ax.set_xlabel("GPU Hours")
    ax.set_ylabel(r"$\Delta G$ / kcal mol$^{-1}$")
    stage_str = stage.split(".")[1].capitalize()
    leg_str = leg.split(".")[1].capitalize()
    # Use the shared display names rather than a local copy of the renaming rules
    system_title = rename.get(system, system)
    ax.set_title(f"{system_title} {leg_str} {stage_str}")
    ax.legend()
    plt.tight_layout()

def plot_dgs_conv_all(comparitive_conv_data: Dict, times=("Adaptive", "6 ns", "30 ns"), scale=False, show_final_dg=True) -> Tuple[plt.Figure, plt.Axes]:
    """Plot free energy convergence for every leg/stage (rows) of every system (columns).

    Args:
        comparitive_conv_data: Nested dict indexed as ``[system][leg][stage][time]``,
            each entry providing "dgs" and "gpu_times". Leg and stage keys must be
            keys of ``leg_types`` / ``stage_types``.
        times: Labels of the data sets to overlay on each axis.
        scale: If True, give all axes within a row the same y-range width.
        show_final_dg: Forwarded to ``plot_dgs_conv`` (draws the final-value reference lines).

    Returns:
        The figure and the 2-D array of axes, indexed ``[row, column]``.
    """
    # One row per leg/stage combination encountered while iterating a system's data.
    # The grid is fixed at 5 rows (presumably three bound stages plus two free stages).
    n_rows = 5
    n_systems = len(comparitive_conv_data)
    # figsize is (width, height): width follows the columns, height the rows.
    # squeeze=False keeps axs two-dimensional even for a single system.
    fig, axs = plt.subplots(n_rows, n_systems, figsize=(4 * n_systems, 4 * n_rows), dpi=300, squeeze=False)
    for col, system in enumerate(systems):
        for time_label in times:
            row = 0
            for leg, leg_data in comparitive_conv_data[system].items():
                for stage, stage_data in leg_data.items():
                    run_data = stage_data[time_label]
                    dgs = np.array(run_data["dgs"])
                    times_local = np.array(run_data["gpu_times"])
                    # Warn (but still plot) if the data are not 5 runs x 20 time points
                    if dgs.shape != (5, 20) or times_local.shape != (20,):
                        print(f"Incorrect shape for {system} {leg} {stage} {time_label}")
                    plot_dgs_conv(
                        axs[row, col],
                        dgs,
                        times_local,
                        leg_types[leg],
                        stage_types[stage],
                        system,
                        time_label,
                        show_final_dg=show_final_dg,
                    )
                    row += 1

    if scale:
        # Same y-range width within each row (same stage, different systems),
        # keeping every axis centred on its own midpoint
        for row_axes in axs:
            half_range = max(ax.get_ylim()[1] - ax.get_ylim()[0] for ax in row_axes) / 2
            for ax in row_axes:
                centre = np.mean(ax.get_ylim())
                ax.set_ylim(centre - half_range, centre + half_range)

    return fig, axs

In [None]:
# Per-leg/stage convergence for every system, overlaying the three data sets, saved as a PNG
fig, axs = plot_dgs_conv_all(comparitive_conv_data, times = ["Adaptive", "6 ns", "30 ns"])
fig.savefig("final_analysis/dgs_conv_all.png", dpi=600, bbox_inches="tight")

In [None]:
# Plot as above, but using the overall results for each system.

def plot_dgs_conv_all_calcs(comparitive_conv_data: Dict, times=("Adaptive", "6 ns", "30 ns"), scale=False, show_final_dg=False) -> Tuple[plt.Figure, plt.Axes]:
    """Plot the convergence of the overall free energy change, one axis per system.

    Args:
        comparitive_conv_data: Dict indexed as ``[system][time]``, each entry providing
            "dgs" and "times". Systems are plotted in the dict's iteration order.
        times: Labels of the data sets to overlay on each axis.
        scale: If True, give all axes the same y-range width.
        show_final_dg: Forwarded to ``plot_dgs_conv``.
            NOTE(review): when True, ``plot_dgs_conv`` looks up
            ``final_dgs_30_nonadapt[system]["All."]["All."]`` - confirm such an entry
            exists before enabling it.

    Returns:
        The figure and a 1-D array of axes (one per system).
    """
    n_systems = len(comparitive_conv_data)
    # A single row; squeeze=False then [0] gives a 1-D array even for one system
    fig, axs = plt.subplots(1, n_systems, figsize=(4 * n_systems, 4), dpi=300, squeeze=False)
    axs = axs[0]
    for ax, system in zip(axs, comparitive_conv_data):
        for time_label in times:
            run_data = comparitive_conv_data[system][time_label]
            dgs = np.array(run_data["dgs"])
            times_local = np.array(run_data["times"])
            # "All." is a placeholder leg/stage: there is no single leg or stage here
            plot_dgs_conv(ax, dgs, times_local, "All.", "All.", system, time_label, show_final_dg=show_final_dg)
        # plot_dgs_conv builds the title from the text after the dot in leg/stage, which is
        # empty for "All." and leaves trailing spaces; replace it with the system name only
        ax.set_title(rename.get(system, system))

    if scale:
        # Same y-range width on every axis, each centred on its own midpoint
        half_range = max(ax.get_ylim()[1] - ax.get_ylim()[0] for ax in axs) / 2
        for ax in axs:
            centre = np.mean(ax.get_ylim())
            ax.set_ylim(centre - half_range, centre + half_range)

    return fig, axs

In [None]:
# Overall convergence per system; the figure is only displayed (saving is disabled below)
fig, axs = plot_dgs_conv_all_calcs(comparitive_conv_data_calcs, times = ["Adaptive", "6 ns", "30 ns"])
#fig.savefig("final_analysis/dgs_conv_all_calcs_non.png", dpi=600, bbox_inches="tight")

In [None]:
# Combine the two non-adaptive results to make any equilibration trends clearer

def plot_dgs_conv_all_combined(comparitive_conv_data: Dict, scale=False, show_final_dg=True) -> Tuple[plt.Figure, plt.Axes]:
    """Plot adaptive vs. pooled non-adaptive convergence for every leg/stage and system.

    The "Non-Adaptive" curve pools the runs of the "6 ns" and "30 ns" data sets.

    Args:
        comparitive_conv_data: Nested dict indexed as ``[system][leg][stage][time]`` with
            time keys "Adaptive", "6 ns" and "30 ns", each providing "dgs" and "gpu_times".
        scale: If True, give all axes within a row the same y-range width.
        show_final_dg: Forwarded to ``plot_dgs_conv`` (draws the final-value reference lines).

    Returns:
        The figure and the 2-D array of axes, indexed ``[row, column]``.
    """
    # One row per leg/stage combination; the grid is fixed at 5 rows
    n_rows = 5
    n_systems = len(comparitive_conv_data)
    # figsize is (width, height); squeeze=False keeps axs 2-D for a single system
    fig, axs = plt.subplots(n_rows, n_systems, figsize=(4 * n_systems, 4 * n_rows), dpi=300, squeeze=False)
    for col, system in enumerate(systems):
        for plt_type in ("Adaptive", "Non-Adaptive"):
            row = 0
            for leg, leg_data in comparitive_conv_data[system].items():
                for stage, stage_data in leg_data.items():
                    if plt_type == "Adaptive":
                        dgs = np.array(stage_data["Adaptive"]["dgs"])
                        times_local = np.array(stage_data["Adaptive"]["gpu_times"])
                        # NOTE(review): the adaptive C.I. is shrunk by 1/sqrt(2), presumably to
                        # compare it with the pooled set that has twice as many runs - confirm
                        scale_er = 1 / np.sqrt(2)
                    else:
                        data_6 = stage_data["6 ns"]
                        data_30 = stage_data["30 ns"]
                        # Stack the runs of both data sets along the first (run) axis
                        dgs = np.concatenate((np.array(data_6["dgs"]), np.array(data_30["dgs"])), axis=0)
                        # Only the 6 ns GPU times are used for the pooled curve
                        # (assumes both sets share the same time grid - TODO confirm)
                        times_local = np.array(data_6["gpu_times"])
                        scale_er = 1
                    plot_dgs_conv(
                        axs[row, col],
                        dgs,
                        times_local,
                        leg_types[leg],
                        stage_types[stage],
                        system,
                        label=plt_type,
                        show_final_dg=show_final_dg,
                        scale_er=scale_er,
                    )
                    row += 1

    if scale:
        # Same y-range width within each row (same stage, different systems),
        # keeping every axis centred on its own midpoint
        for row_axes in axs:
            half_range = max(ax.get_ylim()[1] - ax.get_ylim()[0] for ax in row_axes) / 2
            for ax in row_axes:
                centre = np.mean(ax.get_ylim())
                ax.set_ylim(centre - half_range, centre + half_range)

    return fig, axs

In [None]:
# Adaptive vs. pooled non-adaptive convergence for every leg/stage and system, saved as a PNG
fig, ax = plot_dgs_conv_all_combined(comparitive_conv_data)
fig.savefig("final_analysis/dgs_conv_all_combined.png", dpi=600, bbox_inches="tight")

In [None]:
# Repeat above plot but for only the bound vanish stage, to show differences in convergence.

def plot_dgs_conv_all_combined_bound_vanish(comparitive_conv_data: Dict, scale=False, show_final_dg=True) -> Tuple[plt.Figure, plt.Axes]:
    """Plot adaptive vs. pooled non-adaptive convergence of the bound vanish stage only.

    One axis per entry of the module-level ``systems`` list, laid out on a 3 x 2 grid;
    grid cells without a system are hidden. A single legend is placed beside the last
    used axis.

    Args:
        comparitive_conv_data: Nested dict indexed as ``[system][leg][stage][time]`` with
            time keys "Adaptive", "6 ns" and "30 ns", each providing "dgs" and "gpu_times".
        scale: If True, give all used axes the same y-range width.
        show_final_dg: Forwarded to ``plot_dgs_conv`` (draws the final-value reference lines).

    Returns:
        The figure and the flattened (length 6) array of axes.

    Raises:
        ValueError: If there are more systems than grid cells.
    """
    fig, axs = plt.subplots(3, 2, figsize=(5.5, 8), dpi=300)
    axs = axs.flatten()
    if len(systems) > len(axs):
        raise ValueError(f"Cannot plot {len(systems)} systems on a grid of {len(axs)} axes")
    used_axs, spare_axs = axs[:len(systems)], axs[len(systems):]

    # Select the leg and stage by key instead of by position in the dicts, so the plot
    # does not depend on the insertion order of the data
    leg, stage = "bound", "vanish"
    for ax, system in zip(used_axs, systems):
        stage_data = comparitive_conv_data[system][leg][stage]
        for plt_type in ("Adaptive", "Non-Adaptive"):
            if plt_type == "Adaptive":
                dgs = np.array(stage_data["Adaptive"]["dgs"])
                times_local = np.array(stage_data["Adaptive"]["gpu_times"])
                # NOTE(review): the adaptive C.I. is shrunk by 1/sqrt(2), presumably to
                # compare it with the pooled set that has twice as many runs - confirm
                scale_er = 1 / np.sqrt(2)
            else:
                data_6 = stage_data["6 ns"]
                data_30 = stage_data["30 ns"]
                # Stack the runs of both data sets along the first (run) axis
                dgs = np.concatenate((np.array(data_6["dgs"]), np.array(data_30["dgs"])), axis=0)
                # Only the 6 ns GPU times are used for the pooled curve
                # (assumes both sets share the same time grid - TODO confirm)
                times_local = np.array(data_6["gpu_times"])
                scale_er = 1
            plot_dgs_conv(
                ax,
                dgs,
                times_local,
                leg_types[leg],
                stage_types[stage],
                system,
                label=plt_type,
                show_final_dg=show_final_dg,
                scale_er=scale_er,
            )

    # Drop the per-axis legends; one shared legend is added below
    for ax in used_axs:
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    # Hide grid cells that have no system
    for ax in spare_axs:
        ax.axis("off")

    # Stick a legend to the side of the last active axis
    if len(used_axs) > 0:
        used_axs[-1].legend(bbox_to_anchor=(1.45, 0.5), loc="center left", borderaxespad=0)

    if scale and len(used_axs) > 0:
        # Only axes holding data take part: a hidden, empty axis still reports its
        # default y-limits and would otherwise inflate the common range
        half_range = max(ax.get_ylim()[1] - ax.get_ylim()[0] for ax in used_axs) / 2
        for ax in used_axs:
            centre = np.mean(ax.get_ylim())
            ax.set_ylim(centre - half_range, centre + half_range)

    return fig, axs

# Bound vanish stage only, with a common y-range width across systems (scale=True), saved as a PNG
fig, ax = plot_dgs_conv_all_combined_bound_vanish(comparitive_conv_data, scale=True)
fig.savefig("final_analysis/dgs_conv_all_combined_bound_vanish.png", dpi=600, bbox_inches="tight")

In [None]:
# Repeat above plots but for errors only

def plot_cis_conv(ax: plt.Axes, dgs: np.ndarray, times: np.ndarray, leg: str, stage: str, system: str, label: str) -> None:
    """Plot the 95 % C.I. half-width of the running-average free energy against GPU time.

    Args:
        ax: Axis to draw on.
        dgs: Free energy changes, one row per run and one column per entry of ``times``.
            Each row is converted to a cumulative average along the columns before
            the interval is computed across rows.
        times: x values (labelled "GPU Hours") matching the columns of ``dgs``.
        leg: Dotted leg name, e.g. "LegType.BOUND"; the part after the dot is shown in the title.
        stage: Dotted stage name, e.g. "StageType.VANISH"; used like ``leg``.
        system: System name, shown in the title after applying the module-level ``rename`` mapping.
        label: Legend label for this curve, e.g. "6 ns" or "Adaptive".
    """
    # Turn each run into a cumulative average over the time points seen so far
    dgs = np.cumsum(dgs, axis=1) / np.arange(1, len(dgs[0]) + 1)
    mean_free_energy = np.mean(dgs, axis=0)
    # Half-width of the two-sided 95 % Student-t interval (len(dgs) - 1 degrees of freedom)
    upper_bound = t.interval(0.95, len(dgs) - 1, mean_free_energy, scale=sem(dgs))[1]
    conf_int = upper_bound - mean_free_energy
    ax.plot(times, conf_int, label=label)
    # Label the plot
    ax.set_xlabel("GPU Hours")
    ax.set_ylabel(r"95 % CI of $\Delta G$ / kcal mol$^{-1}$")
    stage_str = stage.split(".")[1].capitalize()
    leg_str = leg.split(".")[1].capitalize()
    # Use the same display names as the other convergence plots
    system_title = rename.get(system, system)
    ax.set_title(f"{system_title} {leg_str} {stage_str}")
    ax.legend()
    plt.tight_layout()

def plot_cis_conv_all(comparitive_conv_data: Dict, times=("Adaptive", "6 ns", "30 ns"), scale=False, show_final_dg=True) -> Tuple[plt.Figure, plt.Axes]:
    """Plot the 95 % C.I. convergence for every leg/stage (rows) of every system (columns).

    Args:
        comparitive_conv_data: Nested dict indexed as ``[system][leg][stage][time]``,
            each entry providing "dgs" and "gpu_times". Leg and stage keys must be
            keys of ``leg_types`` / ``stage_types``.
        times: Labels of the data sets to overlay on each axis.
        scale: If True, give all axes within a row the same y-range.
        show_final_dg: Unused; kept so the signature matches the other plotting functions.

    Returns:
        The figure and the 2-D array of axes, indexed ``[row, column]``.
    """
    # One row per leg/stage combination; the grid is fixed at 5 rows
    n_rows = 5
    n_systems = len(comparitive_conv_data)
    # figsize is (width, height); squeeze=False keeps axs 2-D for a single system
    fig, axs = plt.subplots(n_rows, n_systems, figsize=(4 * n_systems, 4 * n_rows), dpi=300, squeeze=False)
    for col, system in enumerate(systems):
        for time_label in times:
            row = 0
            for leg, leg_data in comparitive_conv_data[system].items():
                for stage, stage_data in leg_data.items():
                    run_data = stage_data[time_label]
                    dgs = np.array(run_data["dgs"])
                    times_local = np.array(run_data["gpu_times"])
                    plot_cis_conv(axs[row, col], dgs, times_local, leg_types[leg], stage_types[stage], system, time_label)
                    row += 1

    # Every axis starts at zero
    for ax in axs.flatten():
        ax.set_ylim(bottom=0)

    if scale:
        # With the bottom pinned at zero, sharing a scale along a row means sharing the
        # upper limit. Centring each axis first and then resetting the bottom to zero
        # would leave the rows with unequal ranges.
        for row_axes in axs:
            row_top = max(ax.get_ylim()[1] for ax in row_axes)
            for ax in row_axes:
                ax.set_ylim(0, row_top)

    return fig, axs

In [None]:
# 95 % C.I. convergence for every leg/stage and system, saved as a PNG
fig, axs = plot_cis_conv_all(comparitive_conv_data)
fig.savefig("final_analysis/cis_conv_all.png", dpi=600, bbox_inches="tight")

In [None]:
# Combined analysis of entire dataset
def plot_dgs_conv_entire(ax: plt.axes, dgs: np.array, times: np.array, name: str, time:str) -> None:
    """Draw the run-averaged free energy change and its 95 % C.I. band on ``ax``.

    ``dgs`` is averaged over its first axis; ``times`` supplies the x values, shifted
    back by half of the first time step. ``name`` becomes the axis title and ``time``
    the legend label.
    """
    run_mean = np.mean(dgs, axis=0)
    # Upper end of the 95 % Student-t interval; its distance from the mean is the half-width
    _, upper = t.interval(0.95, len(dgs) - 1, run_mean, scale=sem(dgs))
    ci_half_width = upper - run_mean
    # Move each point back by half a step so it sits mid-way through its interval
    half_step = (times[1] - times[0]) / 2
    mid_times = times - half_step
    ax.plot(mid_times, run_mean, label=f"{time}")
    ax.fill_between(
        mid_times,
        run_mean - ci_half_width,
        run_mean + ci_half_width,
        alpha=0.5,
    )
    # Axis decorations
    ax.set_xlabel("GPU Hours")
    ax.set_ylabel(r"$\Delta G$ / kcal mol$^{-1}$")
    ax.set_title(name)
    ax.legend(loc="lower right")
    plt.tight_layout()

def plot_dgs_conv_entire_dataset(dg_dict_nonadapt: Dict, dg_dict_adapt: Dict, time="6 ns", scale=False) -> Tuple[plt.Figure, plt.Axes]:
    """Plot the convergence of the free energy change summed over the whole dataset.

    For both the non-adaptive and the adaptive data, the per-stage free energies of
    every system are added up (bound-leg terms with a negative sign, all others with a
    positive sign), each system's restraint correction is subtracted, and the GPU
    times of all stages and systems are summed.

    Args:
        dg_dict_nonadapt: Non-adaptive data indexed as ``[system][time][leg][stage]``,
            each entry providing "dgs" and "gpu_times".
        dg_dict_adapt: Adaptive data indexed as ``[system][leg][stage]`` with the same
            entries. It must cover every system/leg/stage of ``dg_dict_nonadapt``.
        time: Which non-adaptive data set to use; also shown in its legend label.
        scale: Unused (there is only one axis).

    Returns:
        The figure and its single axis.
    """
    fig, axs = plt.subplots(1, 1, figsize=(4, 4), dpi=600)
    # Size the accumulators from the first available entry instead of a hard-coded
    # system; every entry is summed into the same arrays, so all share this shape
    first_legs = next(iter(dg_dict_nonadapt.values()))[time]
    first_stages = next(iter(first_legs.values()))
    ref_dgs = next(iter(first_stages.values()))["dgs"]
    n_replicates, n_times = len(ref_dgs), len(ref_dgs[0])
    dgs_nonadapt = np.zeros((n_replicates, n_times))
    dgs_adapt = np.zeros((n_replicates, n_times))
    times_data_nonadapt = np.zeros(n_times)
    times_data_adapt = np.zeros(n_times)
    for system, per_time in dg_dict_nonadapt.items():
        for leg, stages in per_time[time].items():
            # Bound-leg contributions are subtracted, all others added
            dg_multiplier = -1 if leg == "LegType.BOUND" else 1
            for stage, nonadapt_entry in stages.items():
                adapt_entry = dg_dict_adapt[system][leg][stage]
                dgs_nonadapt += np.array(nonadapt_entry["dgs"]) * dg_multiplier
                dgs_adapt += np.array(adapt_entry["dgs"]) * dg_multiplier
                times_data_nonadapt += np.array(nonadapt_entry["gpu_times"])
                times_data_adapt += np.array(adapt_entry["gpu_times"])
        dgs_nonadapt -= restraint_corrections[system]
        dgs_adapt -= restraint_corrections[system]
    # Label the non-adaptive curve with the data set actually plotted
    plot_dgs_conv_entire(axs, dgs_nonadapt, times_data_nonadapt, "Entire Dataset", f"Non-Adaptive {time}")
    plot_dgs_conv_entire(axs, dgs_adapt, times_data_adapt, "Entire Dataset", "Adaptive")

    return fig, axs


In [None]:
# Whole-dataset convergence: 6 ns non-adaptive vs. adaptive, saved as a PNG
fig, ax = plot_dgs_conv_entire_dataset(dgs_conv_nonadapt_nonequil, dgs_conv_nonequil, time="6 ns")
fig.savefig("final_analysis/dgs_conv_entire_dataset.png", dpi=600, bbox_inches="tight")

In [None]:
def _overall_final_dg(system: str) -> Tuple[float, float]:
    """Combine the per-stage entries of ``final_dgs_30_nonadapt[system]`` into one value.

    Bound-leg stages are subtracted and all other stages added, then the system's
    restraint correction is subtracted - the same convention used when the per-stage
    convergence data are summed for the per-system plots.

    NOTE(review): assumes ``final_dgs_30_nonadapt[system]`` maps leg -> stage ->
    (value, error), that the stage errors are independent (they are combined in
    quadrature), and that the restraint correction carries no error - confirm.
    """
    total = 0.0
    variance = 0.0
    for leg, stages in final_dgs_30_nonadapt[system].items():
        sign = -1 if leg == "LegType.BOUND" else 1
        for stage_dg, stage_err in stages.values():
            total += sign * stage_dg
            variance += stage_err ** 2
    return total - restraint_corrections[system], np.sqrt(variance)


def plot_dgs_conv_system(ax: plt.Axes, dgs: np.ndarray, times: np.ndarray, system: str, label: str, show_final_dg: bool = True, scale_er: float = 1) -> None:
    """Plot the convergence of the overall free energy change for one system.

    Args:
        ax: Axis to draw on.
        dgs: Free energy changes, one row per run and one column per entry of ``times``.
        times: x values (labelled "GPU Hours"); shifted back by half of the first time
            step before plotting.
        system: System name, shown in the title after applying the module-level ``rename`` mapping.
        label: Legend label for this curve, e.g. "6 ns" or "Adaptive".
        show_final_dg: If True, mark the overall final value of the system (value, and
            value +/- error, see ``_overall_final_dg``) with horizontal reference lines.
        scale_er: Factor applied to the C.I. half-width before it is drawn.
    """
    # Tag for the reference lines so they can be recognised on later calls
    ref_gid = "final_dg_reference"
    mean_free_energy = np.mean(dgs, axis=0)
    # Half-width of the two-sided 95 % Student-t interval (len(dgs) - 1 degrees of freedom)
    upper_bound = t.interval(0.95, len(dgs) - 1, mean_free_energy, scale=sem(dgs))[1]
    conf_int = (upper_bound - mean_free_energy) * scale_er
    # There is no single leg or stage at system level, so the reference value is built
    # from all of the system's stages. Draw it once per axis so the semi-transparent
    # lines do not stack up when several curves share the axis.
    already_marked = any(line.get_gid() == ref_gid for line in ax.lines)
    if show_final_dg and not already_marked:
        final_dg, final_dg_err = _overall_final_dg(system)
        ax.axhline(final_dg, color="black", linestyle="--", alpha=0.5, gid=ref_gid)
        ax.axhline(final_dg + final_dg_err, color="black", linestyle=":", alpha=0.5, gid=ref_gid)
        ax.axhline(final_dg - final_dg_err, color="black", linestyle=":", alpha=0.5, gid=ref_gid)
    # Take half delta t off the times to get the average times
    times = times - (times[1] - times[0]) / 2
    # Mean curve with a shaded band for the (scaled) 95 % C.I.
    ax.plot(times, mean_free_energy, label=label)
    ax.fill_between(times, mean_free_energy - conf_int, mean_free_energy + conf_int, alpha=0.5)
    # Label the plot
    ax.set_xlabel("GPU Hours")
    ax.set_ylabel(r"$\Delta G$ / kcal mol$^{-1}$")
    # Apply every display rename, not only the PDE2A one
    system_title = rename.get(system, system)
    ax.set_title(f"{system_title}")
    ax.legend()
    plt.tight_layout()

def plot_cis_conv_system(ax: plt.axes, dgs: np.ndarray,times: np.ndarray, system: str, label: str) -> None:
    """Draw the 95 % C.I. half-width of the running-average free energy for one system.

    Each row of ``dgs`` is replaced by its cumulative average along the columns; the
    interval is then computed across rows and plotted against ``times``. ``label`` is
    the legend entry (e.g. "0.2 ns" or "Adaptive") and ``system`` the title, after
    applying the module-level ``rename`` mapping.
    """
    # Running mean of every run over the time points seen so far
    n_points = len(dgs[0])
    running_mean = np.cumsum(dgs, axis=1) / np.arange(1, n_points + 1)
    across_runs = np.mean(running_mean, axis=0)
    # Upper end of the 95 % Student-t interval; its distance from the mean is the half-width
    upper = t.interval(
        0.95,
        len(running_mean) - 1,
        across_runs,
        scale=sem(running_mean),
    )[1]
    ci_half_width = upper - across_runs
    ax.plot(times, ci_half_width, label=label)
    # Axis decorations
    ax.set_xlabel("GPU Hours")
    ax.set_ylabel(r"95 % CI of $\Delta G$ / kcal mol$^{-1}$")
    ax.set_title(f"{rename.get(system, system)}")
    ax.legend()
    plt.tight_layout()

def plot_dgs_conv_per_system(dg_dict_nonadapt: Dict, scale=False, errors: bool = False) -> Tuple[plt.Figure, plt.Axes]:
    """Plot the convergence of the overall free energy change, one axis per system.

    For every data set of a system, the per-stage free energies are summed (bound-leg
    terms with a negative sign, all others with a positive sign), the system's
    restraint correction is subtracted, and the GPU times of all stages are summed.

    Args:
        dg_dict_nonadapt: Data indexed as ``[system][time][leg][stage]``, each entry
            providing "dgs" and "gpu_times". A "6 ns" data set is required: it sizes
            the accumulators and sets the upper x-limit of each axis.
        scale: Unused.
        errors: If True, plot the 95 % C.I. (``plot_cis_conv_system``) instead of the
            mean free energy (``plot_dgs_conv_system``).

    Returns:
        The figure and a 1-D array of axes (one per system).
    """
    n_systems = len(dg_dict_nonadapt)
    # One column per system; squeeze=False then [0] gives a 1-D array even for one system
    fig, axs = plt.subplots(1, n_systems, figsize=(4 * n_systems, 4), squeeze=False)
    axs = axs[0]
    # Size the accumulators from the first system's "6 ns" data instead of a hard-coded system
    ref_legs = next(iter(dg_dict_nonadapt.values()))["6 ns"]
    ref_stages = next(iter(ref_legs.values()))
    ref_dgs = next(iter(ref_stages.values()))["dgs"]
    n_replicates, n_times = len(ref_dgs), len(ref_dgs[0])
    for ax, (system, per_time) in zip(axs, dg_dict_nonadapt.items()):
        max_time = 0
        for time_label, legs in per_time.items():
            dgs = np.zeros((n_replicates, n_times))
            times_data = np.zeros(n_times)
            for leg, stages in legs.items():
                # Bound-leg contributions are subtracted, all others added
                dg_multiplier = -1 if leg == "LegType.BOUND" else 1
                for entry in stages.values():
                    dgs += np.array(entry["dgs"]) * dg_multiplier
                    times_data += np.array(entry["gpu_times"])
            dgs -= restraint_corrections[system]
            if time_label == "6 ns":
                max_time = max(max_time, np.max(times_data))
            if errors:
                plot_cis_conv_system(ax, dgs, times_data, system, time_label)
            else:
                plot_dgs_conv_system(ax, dgs, times_data, system, time_label, show_final_dg=False)
        # Set the x limit to the 6 ns final time
        ax.set_xlim(-10, max_time)

    # Return the axes created here, not whatever module-level `ax` happens to exist
    return fig, axs

In [None]:
# Merge the adaptive and non-adaptive convergence data into a single dict per system:
# the adaptive data go first under "Adaptive", followed by each non-adaptive data set
# under its own label
dgs_conv_nonequil_combined = {}
for system in dgs_conv_nonequil:
    dgs_conv_nonequil_combined[system] = {
        "Adaptive": dgs_conv_nonequil[system],
        **dgs_conv_nonadapt_nonequil[system],
    }

In [None]:
# Overall free energy convergence per system for every data set, saved as a PNG
fig, ax = plot_dgs_conv_per_system(dgs_conv_nonequil_combined)
fig.savefig("final_analysis/dgs_conv_per_system.png", dpi=600, bbox_inches="tight")

In [None]:
# Same per-system figure, but showing the 95 % C.I. instead of the mean (errors=True)
fig, ax = plot_dgs_conv_per_system(dgs_conv_nonequil_combined, errors=True)
fig.savefig("final_analysis/cis_conv_per_system.png", dpi=600, bbox_inches="tight")