# Analysis of Non-Adaptive Runs

In [None]:
import a3fe as a3 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams, rcParamsDefault
import pickle
# Reset matplotlib to its default settings before applying the plotting style
rcParams.update(rcParamsDefault)
plt.style.use("seaborn-v0_8-colorblind")
# Add amsmath to the LaTeX preamble used for text rendering
# NOTE(review): presumably only takes effect when text.usetex is enabled — confirm
plt.rc('text.latex', preamble=r'\usepackage{amsmath}')
from typing import List, Tuple, Dict, Callable, Union
%matplotlib inline
from scipy.stats import linregress, kruskal, t, sem

# Define how we want to rename some systems: maps the key used in the data to
# the name shown in plots and tables. Systems not listed keep their data key.
rename = {"PDE2A": "PDE2a", "MDM2-PIP2": "MDM2-Pip2"}

## Load Results of Slow Analyses

In [None]:
# First, get the required data from Zenodo: download the gradient data archive
# into final_analysis/ and unpack it in the same directory

! wget --output-document final_analysis/grad_data.pkl.tar.gz "https://zenodo.org/records/11520013/files/grad_data.pkl.tar.gz?download=1"
! tar -xzf final_analysis/grad_data.pkl.tar.gz -C final_analysis

In [None]:
# This is required because a3fe was called EnsEquil when the data was generated
# and pickled, and the current hack in a3fe which points EnsEquil to a3fe results
# in a circular import issue with EnsEquil
from EnsEquil.run import *

# Name of each result variable -> pickle file (inside final_analysis/) holding it
file_names = {"final_dGs_all": "final_dGs_all.pkl",
              "restraint_corrections": "restraint_corrections.pkl",
              "grad_data": "grad_data.pkl",
              "restraint_dicts": "restraint_dicts.pkl",
              "dgs_conv_nonequil": "dgs_conv_nonequil_times.pkl",
              "final_dGs": "final_dGs_30.pkl",
            }

# NOTE(security): unpickling can execute arbitrary code, so only load pickle
# files that come from a source you trust.
loaded_results = {}
for key, value in file_names.items():
    file_path = f"final_analysis/{value}"
    with open(file_path, "rb") as f:
        loaded_results[key] = pickle.load(f)

# Bind every result to an explicit module-level name (rather than injecting
# names through globals()) so that the variables used by the rest of the
# notebook are visible to readers and to linters.
final_dGs_all = loaded_results["final_dGs_all"]
restraint_corrections = loaded_results["restraint_corrections"]
grad_data = loaded_results["grad_data"]
restraint_dicts = loaded_results["restraint_dicts"]
dgs_conv_nonequil = loaded_results["dgs_conv_nonequil"]
final_dGs = loaded_results["final_dGs"]

## Notebook Layout

- Tables of overall results
- Convergence plots
- Kruskal-Wallis H-test to check for significant inter-run differences between gradient distributions
- Plots of SDs of gradients vs lambda
- Plots of SEMs of gradients vs lambda

## Overall Results

In [None]:
def get_95_ci(data: np.ndarray) -> float:
    """
    Get the half-width of the 95% confidence interval of the mean of the data.

    The interval is a Student's t interval centred on the sample mean, with
    len(data) - 1 degrees of freedom and scale given by scipy.stats.sem.

    Parameters
    ----------
    data : np.ndarray
        The values to compute the confidence interval for.

    Returns
    -------
    float
        The distance from the sample mean to the upper bound of the 95% C.I.
        (a single number, not a (lower, upper) pair).
    """
    mean_free_energy = np.mean(data)
    _, upper_bound = t.interval(
        0.95,
        len(data) - 1,
        mean_free_energy,
        scale=sem(data),
    )
    # The interval is symmetric about the mean, so one half-width describes it
    return upper_bound - mean_free_energy
    
# Per-system symmetry correction, added to the mean calculated free energy
# (presumably in kcal mol^-1, as for the plotted free energies — confirm)
symmetry_corrections = {"T4L": -0.41, "MIF": -0.65, "MDM2-Nutlin": 0.00, "MDM2-PIP2": 0.00, "PDE2A": 0.00}

# Experimental binding free energies, their uncertainties and where they were taken from
exp_dgs = {"T4L": {"Exp. dG": -5.19, "Exp. dG Err": 0.16, "source": "https://pubs.acs.org/doi/epdf/10.1021/ct060037v"},
           "MIF": {"Exp. dG": -8.98, "Exp. dG Err": 0.28, "source": "Fluorescence polarisation assay, https://pubs.acs.org/doi/10.1021/jacs.6b04910"},
           "MDM2-Nutlin": {"Exp. dG": -11.14, "Exp. dG Err": 0.27, "source": "ITC,  https://pubs.rsc.org/en/content/articlelanding/2022/SC/D2SC00028H"},
           "PDE2A": {"Exp. dG": -14.35, "Exp. dG Err": 0.82, "source": "IC50,  https://pubs.acs.org/doi/pdf/10.1021/acs.jctc.1c01208 https://pubs.acs.org/doi/10.1021/acs.jmedchem.8b00116?ref=PDF"},
           "MDM2-PIP2": {"Exp. dG": -9.11, "Exp. dG Err": 0.01, "source": "ITC, https://pubs.acs.org/doi/full/10.1021/ja305839b"}}

# Build one row per system: the calculated free energy (mean over repeats plus
# symmetry correction, with 95 % C.I.) for each run time, followed by the
# experimental value and its source
results_summary = {}
for system in final_dGs_all:
    results_summary[system] = {}
    for time in final_dGs_all[system]:
        dgs = np.array(final_dGs_all[system][time]["dgs"])
        dg_tot = np.mean(dgs) + symmetry_corrections[system]
        # The symmetry correction is a constant shift, so it does not affect the C.I.
        dg_err = get_95_ci(dgs)
        result_str = rf"{dg_tot:.2f} $\pm$ {dg_err:.2f}"
        results_summary[system][time] = result_str
    exp_res = exp_dgs[system]["Exp. dG"]
    exp_err = exp_dgs[system]["Exp. dG Err"]
    exp_str = rf"{exp_res:.2f} $\pm$ {exp_err:.2f}"
    results_summary[system][r"Exp. $\Delta G^o_\textrm{Bind}$"] = exp_str
    results_summary[system]["Exp. Source"] = exp_dgs[system]["source"]

# Turn results summary into a dataframe (transposed so that each system is a row)
results_summary_df = pd.DataFrame(results_summary).T
# Use the display names defined in the shared `rename` mapping
results_summary_df = results_summary_df.rename(index=rename)
# escape=False so that the LaTeX maths in the cells and headers is written verbatim
results_summary_df.to_latex("final_analysis/results_summary.tex", escape=False)

In [None]:
# Repeat as above, but make a much more explicit table including all contributions to the final free energy change and the
# symmetry corrections. After the transpose below, each row is a system/time combination and each column is a
# contribution to the free energy change.
results_summary_detailed = {}
for system in final_dGs_all:
    for time in final_dGs_all[system]:
        # Apply the display name here: the row labels are "<system> <time>", so
        # renaming the index by the bare system name afterwards would match nothing
        title = f"{rename.get(system, system)} {time}"
        results_summary_detailed[title] = {}
        for leg_type in final_dGs_all[system][time]:
            # "dgs" holds the overall free energies rather than a leg
            if leg_type == "dgs":
                continue
            leg_name = leg_type.split(".")[1].capitalize()
            for stage_type in final_dGs_all[system][time][leg_type]:
                # "dg" is skipped so that only the per-stage entries are tabulated
                if stage_type == "dg":
                    continue
                stage_name = stage_type.split(".")[1].capitalize()
                dgs = np.array(final_dGs_all[system][time][leg_type][stage_type]["dg"])
                dg_tot = np.mean(dgs)
                dg_err = get_95_ci(dgs)
                result_str = f"{dg_tot:.2f}" + r" $\pm$ " + f"{dg_err:.2f}"
                results_summary_detailed[title][f"{leg_name} {stage_name}"] = result_str
        # Add in symmetry correction and restraint correction, along with experimental results
        restraint_corr = restraint_corrections[system]
        restraint_corr_str = f"{restraint_corr:.2f}"
        results_summary_detailed[title]["Restraint Correction"] = restraint_corr_str
        symmetry_corr = symmetry_corrections[system]
        symmetry_corr_str = f"{symmetry_corr:.2f}"
        results_summary_detailed[title]["Symmetry Correction"] = symmetry_corr_str
        exp_res = f"{exp_dgs[system]['Exp. dG']:.2f}" + r" $\pm$ " + f"{exp_dgs[system]['Exp. dG Err']:.2f}"
        results_summary_detailed[title][r"Exp. $\Delta G^o_\textrm{Bind}$"] = exp_res

results_summary_detailed_df = pd.DataFrame(results_summary_detailed).T
results_summary_detailed_df.to_latex("final_analysis/results_summary_detailed.tex", escape=False)

In [None]:
# Double all the force constants, because of the definition in SOMD
# (presumably k x^2 rather than 0.5 k x^2 — TODO confirm)
# NOTE: this modifies restraint_dicts in place, so running this cell a second
# time doubles the values again
for system in restraint_dicts:
    force_constants = restraint_dicts[system]["force_constants"]
    for key in force_constants:
        force_constants[key] *= 2

# Now turn the restraints dict into a nice dataframe: one column per system,
# one row per restraint parameter, with every value pre-formatted as a string
restraint_df_dict = {}
for system in restraint_dicts:
    restraint_df_dict[system] = {}
    # Anchor points are written as whole numbers
    for index in restraint_dicts[system]["anchor_points"]:
        restraint_df_dict[system][index] = str(round(restraint_dicts[system]["anchor_points"][index]))
    # Equilibrium values and force constants are written to two decimal places
    for equil_val in restraint_dicts[system]["equilibrium_values"]:
        restraint_df_dict[system][equil_val] = f'{restraint_dicts[system]["equilibrium_values"][equil_val]:.2f}'
    for force_const in restraint_dicts[system]["force_constants"]:
        restraint_df_dict[system][force_const] = f'{restraint_dicts[system]["force_constants"][force_const]:.2f}'

# LaTeX labels (with units) for the restraint parameters. thetaA0 is an angle,
# so it is labelled in Rad like the other angles and dihedrals (it was
# previously labelled in Angstrom).
replace_dict = {"r0": r"$r_0$ / $\mathrm{\AA}$", 
                "thetaA0": r"$\theta_{\mathrm{A}0}$ / Rad", 
                "thetaB0": r"$\theta_{\mathrm{B}0}$ / Rad", 
                "phiA0": r"$\phi_{\mathrm{A}0}$ / Rad", 
                "phiB0": r"$\phi_{\mathrm{B}0}$ / Rad", 
                "phiC0": r"$\phi_{\mathrm{C}0}$ / Rad",
                "kr": r"$k_r$ / kcal mol$^{-1}$ $\mathrm{\AA}^{-2}$", 
                "kthetaA": r"$k_{\theta \mathrm{A}}$ / kcal mol$^{-1}$ $\mathrm{Rad}^{-2}$", 
                "kthetaB": r"$k_{\theta \mathrm{B}}$ / kcal mol$^{-1}$ $\mathrm{Rad}^{-2}$", 
                "kphiA": r"$k_{\phi \mathrm{A}}$ / kcal mol$^{-1}$ $\mathrm{Rad}^{-2}$", 
                "kphiB": r"$k_{\phi \mathrm{B}}$ / kcal mol$^{-1}$ $\mathrm{Rad}^{-2}$",
                "kphiC": r"$k_{\phi \mathrm{C}}$ / kcal mol$^{-1}$ $\mathrm{Rad}^{-2}$"}

# Create the dataframe
restraint_df = pd.DataFrame(restraint_df_dict)
# Replace names using the replace_dict
restraint_df = restraint_df.rename(columns=replace_dict, index=replace_dict)
# Save, making sure not to truncate the label units in the column names
with pd.option_context("max_colwidth", 1000):
    restraint_df.to_latex("final_analysis/restraint_params.tex", escape=False)

## Convergence Analysis

In [None]:
# Show the top-level keys of the convergence data (indexed by system below)
dgs_conv_nonequil.keys()

In [None]:
# List the systems in the order we want them plotted
# (this order sets the subplot columns and bar groups in the figures below)
systems = ["T4L", "MIF", "MDM2-PIP2", "PDE2A", "MDM2-Nutlin"]

# Plot the convergence of the free energy changes for each stage for each leg for each time for each system
def plot_dgs_conv(ax: plt.Axes, dg_dict: Dict, leg: str, stage: str, time: str, system: str, equil: bool, show_final_dg: bool = True) -> None:
    """
    Plot the convergence of the free energy changes for a given leg, stage, time and system.

    Parameters
    ----------
    ax : plt.Axes
        The axis to plot on.
    dg_dict : Dict
        Convergence data, indexed as dg_dict[system][time][leg][stage], where
        each entry provides "dgs" and "gpu_times".
    leg : str
        Key of the leg, of the form "<prefix>.<name>".
    stage : str
        Key of the stage, of the form "<prefix>.<name>".
    time : str
        Key of the run time; also used as the legend label.
    system : str
        Key of the system.
    equil : bool
        Currently unused; kept so that existing callers continue to work.
    show_final_dg : bool, optional
        If True, draw horizontal lines at the final free energy change and at
        +/- its error, both read from the module-level ``final_dGs``.
    """
    # Get the data for the given leg, stage, time and system
    data = dg_dict[system][time][leg][stage]
    # Get the simulation times
    times = np.array(data["gpu_times"])
    # Get the free energy changes
    # NOTE(review): assumes the first axis of "dgs" indexes the repeat runs — confirm
    dgs = np.array(data["dgs"])
    # Calculate the mean over the first axis and its 95 % CI with scipy
    mean_free_energy = np.mean(dgs, axis=0)
    upper_bound = t.interval(
        0.95,
        len(dgs) - 1,
        mean_free_energy,
        scale=sem(dgs),
    )[1]
    conf_int = upper_bound - mean_free_energy  # 95 % C.I. half-width
    # Add a horizontal line at the final free energy change and dotted lines at the 95 % CI
    if show_final_dg:
        final_dg, final_dg_err = final_dGs[system][leg][stage]
        # Make the lines lighter
        ax.axhline(final_dg, color="black", linestyle="--", alpha=0.5)
        ax.axhline(final_dg + final_dg_err, color="black", linestyle=":", alpha=0.5)
        ax.axhline(final_dg - final_dg_err, color="black", linestyle=":", alpha=0.5)
    # Plot the free energy changes against the simulation times
    ax.plot(times, mean_free_energy, label=f"{time}")
    # Fill between to show the 95 % CI
    ax.fill_between(times, mean_free_energy - conf_int, mean_free_energy + conf_int, alpha=0.5)
    # Label the plot
    ax.set_xlabel("GPU Hours")
    ax.set_ylabel(r"$\Delta G$ / kcal mol$^{-1}$")
    stage_str = stage.split(".")[1].capitalize()
    leg_str = leg.split(".")[1].capitalize()
    # Use the display name as a plain string; wrapping it in a list would put
    # brackets and quotes into the title
    sys_title = rename.get(system, system)
    ax.set_title(f"{sys_title} {leg_str} {stage_str}")
    ax.legend()
    plt.tight_layout()

    
def plot_dgs_conv_all(dg_dict: Dict, times = ("0.2 ns", "6 ns", "30 ns"), scale=False, show_final_dg=True) -> Tuple[plt.Figure, np.ndarray]:
    """
    Plot the convergence of the free energy changes for each stage for each leg for each time for each system.

    One column is drawn per entry of the module-level ``systems`` list and one
    row per leg/stage combination (5 rows); all requested times are overlaid
    on the same axis.

    Parameters
    ----------
    dg_dict : Dict
        Convergence data, indexed as dg_dict[system][time][leg][stage].
    times : sequence of str, optional
        Keys of the run times to plot.
    scale : bool, optional
        If True, give every axis in a row the same y-axis span (the largest
        span found in that row), keeping each axis centred where it was.
    show_final_dg : bool, optional
        Passed on to plot_dgs_conv.

    Returns
    -------
    fig : plt.Figure
        The figure.
    axs : np.ndarray
        2D array of the axes, indexed as axs[row, column].
    """
    # Need 5 axes (one per leg/stage combination) for each system
    n_rows = 5
    # Columns are indexed by position in `systems`, so size the grid from it
    n_systems = len(systems)
    # figsize is (width, height): 4 inches per column and per row
    fig, axs = plt.subplots(n_rows, n_systems, figsize=(4 * n_systems, 4 * n_rows), dpi=300)
    for i, system in enumerate(systems):
        for time in times:
            # Restart at the top row for every time so that times are overlaid
            ax_ind = 0
            for leg in dg_dict[system][time]:
                for stage in dg_dict[system][time][leg]:
                    plot_dgs_conv(axs[ax_ind, i], dg_dict, leg, stage, time, system, equil=True, show_final_dg=show_final_dg)
                    ax_ind += 1

    if scale:
        # Set the y-axis scales to be the same within rows (same stages, different systems)
        for row in axs:
            # Half of the largest difference between the upper and lower y-axis bounds in this row
            y_diff = max(ax.get_ylim()[1] - ax.get_ylim()[0] for ax in row) / 2
            for ax in row:
                # Set the y-axis limits to be the mean +/- half the difference
                ax_mean = np.mean(ax.get_ylim())
                ax.set_ylim(ax_mean - y_diff, ax_mean + y_diff)

    return fig, axs

In [None]:
# Plot the convergence for the 0.2 ns and 6 ns runs, with the final free energy
# lines shown, and save the figure
fig, ax = plot_dgs_conv_all(dgs_conv_nonequil, times = ["0.2 ns", "6 ns"], scale=False, show_final_dg = True)
fig.savefig("final_analysis/dgs_conv_nonequil.png", bbox_inches="tight", dpi=600)

## Check for Significant Inter-Run Differences Between Gradient Distributions

In [None]:
def get_sig_diff_grads(data: a3.analyse.GradientData) -> Tuple[int, int]:
    """
    Count the lambda windows where the gradient distributions of the runs
    are significantly different, using the Kruskal-Wallis test.

    Parameters
    ----------
    data : a3.analyse.GradientData
        Gradient data providing ``lam_vals`` and, for each lambda window,
        the ``subsampled_gradients`` of every run.

    Returns
    -------
    n_lam : int
        The total number of lambda windows.
    n_sig_diff : int
        The number of lambda windows with p < 0.05. Callers convert these
        two counts to a percentage themselves.
    """
    n_lam = len(data.lam_vals)
    # Each window's per-run gradient samples are passed to the test as separate groups
    n_sig_diff = sum(
        1
        for i in range(n_lam)
        if kruskal(*data.subsampled_gradients[i])[1] < 0.05
    )
    return n_lam, n_sig_diff

In [None]:
# Build sig_diff[system][time][leg][stage]: the percentage of lambda windows
# whose gradient distributions differ significantly between runs
sig_diff = {}
for system in systems:
    system_grads = grad_data[system]
    system_result = {}
    sig_diff[system] = system_result
    for time in system_grads:
        time_result = {}
        system_result[time] = time_result
        for leg in system_grads[time]:
            leg_result = {}
            time_result[leg] = leg_result
            for stage in system_grads[time][leg]:
                n_lam, n_diff = get_sig_diff_grads(system_grads[time][leg][stage])
                leg_result[stage] = (n_diff / n_lam) * 100

In [None]:
# Plot the percentage of lambda windows where the gradients are significantly different all on one plot as a bar plot
fig, ax = plt.subplots(figsize=(8, 4), dpi=600)
x = np.arange(len(systems))
width = 0.2
# Take the stage names from an explicit reference system instead of relying on
# a `system` loop variable left over from an earlier cell. The last entry of
# `systems` is the value that variable holds when the cells are run in order.
# NOTE(review): assumes every system has the same stages — confirm
reference_system = systems[-1]
# Each stage gets one slot of width `width` per system. Where a stage has both
# a bound and a free leg, the slot is split in half so that the two bars sit
# next to each other instead of being drawn on top of one another.
for i, stage in enumerate(grad_data[reference_system]["30 ns"]["bound"]):
    # Get single colour, shared by the bound and free bars of this stage
    color = ax._get_lines.get_next_color()
    slot_centre = x + (i * width)
    bound_vals = [sig_diff[system]["30 ns"]["bound"][stage] for system in systems]
    if stage == "restrain":
        # No free-leg bar is drawn for the restrain stage, so the bound bar fills the slot
        ax.bar(slot_centre, bound_vals, width, label=f"Bound {stage.capitalize()}", edgecolor="k", alpha=1, color=color)
    else:
        free_vals = [sig_diff[system]["30 ns"]["free"][stage] for system in systems]
        bar_width = width / 2
        ax.bar(slot_centre - bar_width / 2, bound_vals, bar_width, label=f"Bound {stage.capitalize()}", edgecolor="k", alpha=1, color=color)
        ax.bar(slot_centre + bar_width / 2, free_vals, bar_width, label=f"Free {stage.capitalize()}", edgecolor="k", alpha=1, color=color, hatch="///////")

# Set the names to display
system_names = [rename.get(system, system) for system in systems]
ax.set_xticks(x + width)
ax.set_xticklabels(system_names)
ax.set_ylabel("% Windows with Significant Inter-run \n Differences Between Gradient Distributions")
# Put label off to right of plot
ax.legend(bbox_to_anchor=(1.03, 0.7))
plt.tight_layout()
fig.savefig("final_analysis/sig_diff_grads.png", bbox_inches="tight", dpi=600)


## Plot Thermodynamic Length Metrics

In [None]:
# Sanity check that the weights are correct and we get the same result from trapezoidal integration
restrain_data = grad_data["T4L"]["0.2 ns"]["bound"]["restrain"]
# Average the gradients over axes 1 and 2, leaving one mean gradient per lambda value
mean_gradients = np.array(restrain_data.gradients).mean(axis=1).mean(axis=1)
res_trap = np.trapz(mean_gradients, restrain_data.lam_vals)
res_weights = np.sum(mean_gradients * restrain_data.lam_val_weights)
# The two sums are accumulated in a different order, so compare within
# floating point tolerance rather than requiring bit-for-bit equality
assert np.isclose(res_trap, res_weights), f"Trapezoidal result {res_trap} != weighted sum {res_weights}"

In [None]:
def get_imporovement_factor(grad_data: a3.analyse.GradientData, er_type: str) -> np.ndarray:
    """
    Get the predicted improvement factor for a gradient data object.

    The factor is the square of the integral of the chosen metric over lambda,
    divided by the integral of the squared metric over lambda (both by the
    trapezoidal rule). The metric is the time-normalised SEM for er_type "SEM"
    and the square root of ``vars_intra`` for er_type "SD".

    Raises
    ------
    ValueError
        If er_type is neither "SEM" nor "SD".
    """
    if er_type not in ("SEM", "SD"):
        raise ValueError(f"Error type {er_type} not recognised.")

    if er_type == "SEM":
        metric = grad_data.get_time_normalised_sems()
    else:
        metric = np.sqrt(grad_data.vars_intra)

    # Square of the integral, over the integral of the square
    squared_integral = np.trapz(metric, grad_data.lam_vals) ** 2
    integral_of_squares = np.trapz(metric**2, grad_data.lam_vals)
    return squared_integral / integral_of_squares

In [None]:
def plot_length_on_ax(ax: plt.Axes, grad_data: Dict, system: str, stage:str, len_type:str, free_only:bool=False)-> None:
    """
    Plot a thermodynamic length metric against lambda on an axis.

    One line is drawn per time and leg (solid for bound, dashed for free, one
    colour per time), and the mean +/- SD over times of the improvement factor
    is written on the axis for each leg.

    Parameters
    ----------
    ax : plt.Axes
        The axis to plot on.
    grad_data : Dict
        Gradient data objects, indexed as grad_data[system][time][leg][stage].
    system : str
        Key of the system to plot.
    stage : str
        Key of the stage to plot; legs without this stage are skipped.
    len_type : str
        "SEM" to plot the time-normalised SEMs of the gradients; otherwise the
        square root of ``vars_intra`` is plotted. The improvement factor
        calculation only accepts "SEM" or "SD".
    free_only : bool, optional
        If True, skip the bound leg.
    """
    line_style = {"bound": "-", "free": "--"}
    improvement_factors = {"bound": [], "free": []}
    for time in grad_data[system]:
        # One colour per time, shared by the bound and free lines
        color = ax._get_lines.get_next_color()
        for leg in grad_data[system][time]:
            if leg == "bound" and free_only:
                continue
            # Ensure the data is present in the dictionary
            if stage not in grad_data[system][time][leg]:
                continue
            data = grad_data[system][time][leg][stage]
            quantity = data.get_time_normalised_sems(smoothen=False) if len_type == "SEM" else np.sqrt(data.vars_intra)
            ax.plot(data.lam_vals, quantity, color=color, linestyle=line_style[leg], label=f"{time} {leg}")
            improvement_factors[leg].append(get_imporovement_factor(data, len_type))

    # Only the bound leg is annotated for the restrain stage
    legs = ["bound"] if stage == "restrain" else ["bound", "free"]
    for leg in legs:
        if leg == "bound" and free_only:
            continue
        av_improvement_factor = np.mean(improvement_factors[leg])
        sd_improvement_factor = np.std(improvement_factors[leg])
        # Stack the bound label above the free label; a lone free label takes the upper position
        offset = 0.04 if leg == "bound" else -0.04
        if free_only:
            offset = 0.04
        # Raw f-string so that the backslash in "\pm" is not treated as an escape sequence
        ax.text(0.965, 0.935 + offset, rf"IF {leg.capitalize()}: {av_improvement_factor:.2f} $\pm$ {sd_improvement_factor:.2f}", transform=ax.transAxes, horizontalalignment="right", verticalalignment="top")
    sys_title = rename.get(system, system)
    ax.set_title(f"{sys_title} {stage.capitalize()}")
    ax.set_xlabel(r"$\lambda$")
    ylabel = r"$\sqrt{t_\lambda}\sigma\left(\left\langle\frac{\partial H}{\partial \lambda}\right\rangle_\lambda\right)$ / kcal mol$^{-1}$ ns$^{\frac{1}{2}}$" if len_type == "SEM" else r"$\sigma\left(\frac{\partial H}{\partial \lambda}\right)$ / kcal mol$^{-1}$"

    ax.set_ylabel(ylabel)


Plot the SDs

In [None]:
# SD figure for every system: one column per system, one row per bound-leg stage
# (a 3 x 5 grid)
fig, axs = plt.subplots(nrows=3, ncols=5, figsize=(17, 9), sharex=True)
for i, system in enumerate(systems):
    bound_stages = grad_data[system]["0.2 ns"]["bound"]
    for j, stage in enumerate(bound_stages):
        plot_length_on_ax(axs[j, i], grad_data, system, stage, len_type="SD")
# A single legend, taken from the middle-right axis and placed outside the grid
legend_ax = axs[1, 4]
legend_ax.legend(bbox_to_anchor=(1.05, 0.85))
plt.tight_layout()
fig.savefig("final_analysis/SD_non_adapt.png", bbox_inches="tight", dpi=600)

Plot the SEMs

In [None]:

# SEM figure for every system: one column per system, one row per bound-leg stage
# (a 3 x 5 grid)
fig, axs = plt.subplots(nrows=3, ncols=5, figsize=(17, 9), sharex=True)
for i, system in enumerate(systems):
    bound_stages = grad_data[system]["0.2 ns"]["bound"]
    for j, stage in enumerate(bound_stages):
        plot_length_on_ax(axs[j, i], grad_data, system, stage, len_type="SEM")
# A single legend, taken from the middle-right axis and placed outside the grid
legend_ax = axs[1, 4]
legend_ax.legend(bbox_to_anchor=(1.05, 0.85))
plt.tight_layout()
fig.savefig("final_analysis/SEM_non_adapt.png", bbox_inches="tight", dpi=600)

In [None]:
# SEM figure again, restricted to the free leg: one column per system, one row
# per free-leg stage (a 2 x 5 grid)
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(17, 6), sharex=True)
for i, system in enumerate(systems):
    free_stages = grad_data[system]["0.2 ns"]["free"]
    for j, stage in enumerate(free_stages):
        plot_length_on_ax(axs[j, i], grad_data, system, stage, len_type="SEM", free_only=True)
# A single legend, taken from the bottom-right axis and placed outside the grid
legend_ax = axs[1, 4]
legend_ax.legend(bbox_to_anchor=(1.05, 0.85))
plt.tight_layout()
fig.savefig("final_analysis/SEM_non_adapt_free_only.png", bbox_inches="tight", dpi=600)

In [None]:
# Adapted from https://stackoverflow.com/questions/25812255/row-and-column-headers-in-matplotlibs-subplots
# (answer: https://stackoverflow.com/a/25814386)
def add_headers(
    fig,
    *,
    row_headers=None,
    col_headers=None,
    row_pad=1,
    col_pad=5,
    rotate_row_headers=True,
    **text_kwargs
):
    """
    Annotate the subplots of a figure with bold column and/or row headers.

    Column headers are placed above the axes in the first row and row headers
    to the left of the y-axis label of the axes in the first column. Headers
    are looked up by the starting column/row of each axis' subplot spec.

    Parameters
    ----------
    fig
        Figure whose axes are annotated.
    row_headers, col_headers : sequence of str, optional
        Header texts; no headers of that kind are added when None.
    row_pad, col_pad : float
        Extra offset, in points, between the header and the axis.
    rotate_row_headers : bool
        If True, rotate the row headers by 90 degrees.
    **text_kwargs
        Passed on to every ``ax.annotate`` call.
    """
    label_cols = col_headers is not None
    label_rows = row_headers is not None

    for ax in fig.get_axes():
        spec = ax.get_subplotspec()

        # Column header: centred just above the top edge of the axis
        if label_cols and spec.is_first_row():
            col_text = col_headers[spec.colspan.start]
            ax.annotate(
                col_text,
                xy=(0.5, 1),
                xytext=(0, col_pad),
                xycoords="axes fraction",
                textcoords="offset points",
                ha="center",
                va="baseline",
                weight="bold",
                **text_kwargs,
            )

        # Row header: positioned relative to the y-axis label, shifted left of it
        if label_rows and spec.is_first_col():
            row_text = row_headers[spec.rowspan.start]
            left_shift = -ax.yaxis.labelpad - row_pad
            ax.annotate(
                row_text,
                xy=(0, 0.5),
                xytext=(left_shift, 0),
                xycoords=ax.yaxis.label,
                textcoords="offset points",
                ha="right",
                va="center",
                rotation=rotate_row_headers * 90,
                weight="bold",
                **text_kwargs,
            )

In [None]:
# Example figure for the main text, using PDE2a as a fairly representative system:
# one row per bound-leg stage, SDs in the left column and SEMs in the right column.
fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(7, 11), sharex=True)
for col, len_type in enumerate(("SD", "SEM")):
    for i, stage in enumerate(grad_data["PDE2A"]["0.2 ns"]["bound"]):
        plot_length_on_ax(axs[i, col], grad_data, "PDE2A", stage, len_type)

# Single legend, placed below the bottom-left axis
axs[2, 0].legend(bbox_to_anchor=(0.8, -0.25))

# The per-axis titles are replaced by the row/column headers added below
for ax in axs.flat:
    ax.set_title("")

# Bold maths column headers (SD on the left, SEM on the right); rows are named by stage
bold_headers=[
    r"$\mathbf{\sigma\left(\frac{\partial H}{\partial \lambda}\right)}$",
    r"$\mathbf{\sqrt{t_\lambda}\sigma\left(\left\langle\frac{\partial H}{\partial \lambda}\right\rangle_\lambda\right)}$",
]
add_headers(
    fig,
    row_headers=["Restrain", "Discharge", "Vanish"],
    col_headers=bold_headers,
    row_pad=5,
    col_pad=15,
    rotate_row_headers=True,
    fontsize=12,
)

# Tidy the layout and write the figure to disk
fig.tight_layout()
fig.savefig("final_analysis/PDE2a_SD_SEM.png", bbox_inches="tight", dpi=600)