# Create Synthetic Datasets for Testing

To test the performance of equilibration detection methods, we create sets of synthetic datasets modelled on long absolute binding free energy (ABFE) runs for six different systems (see the 30 ns non-adaptive runs [here](https://chemrxiv.org/engage/chemrxiv/article-details/6670b524c9c6a5c07aafa972) for details).

This notebook is organised as follows:

- [Load the data](#load)
- [Fit the data](#fit)
    - [Model trends with exponentials](#exp)
    - [Model stationary distributions with uncorrelated Gaussian noise](#noise)
    - [Calculate the autocorrelation of the stationary distributions](#autocorr)
- [Create synthetic datasets](#synth) (reintroducing the desired correlation structure with a Cholesky decomposition)


In [None]:
import red
import pickle as pkl
import matplotlib.pyplot as plt
from pymbar import timeseries
import numpy as np
import os
import scipy
import scipy.stats as st
import tqdm
import pandas as pd
from scipy.linalg import toeplitz
import seaborn as sns

# Apply the ggplot look to every figure produced by this notebook
plt.style.use(style="ggplot")

In [None]:
# Utility functions

def get_subplots(systems: list[str]) -> tuple[plt.Figure, list[plt.Axes]]:
    """Create a grid of subplots with one titled panel per system.

    Panels are laid out in at most two columns, 3 x 3 inches each. Grid cells
    left over when the number of systems does not fill the last row are
    removed from the figure.

    Parameters
    ----------
    systems : list[str]
        Names of the systems; each is used as the title of its panel.

    Returns
    -------
    fig : plt.Figure
        The created figure.
    axs : list[plt.Axes]
        Flat, row-major list of axes, so ``axs[i]`` belongs to ``systems[i]``.
        Removed (unused) axes are still present at the end of the list.

    Raises
    ------
    ValueError
        If ``systems`` is empty.
    """
    if not systems:
        raise ValueError("At least one system is required to create subplots.")

    # Plot two columns side-by-side
    n_cols = min(2, len(systems))
    n_rows = int(np.ceil(len(systems) / n_cols))
    # squeeze=False always returns a 2D array of axes, so a single system needs
    # no special case and the return type is always a list.
    fig, axs_grid = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
    axs = list(axs_grid.flatten())

    # Set titles
    for ax, system in zip(axs, systems):
        ax.set_title(system)

    # Remove any unused axes
    for ax in axs[len(systems):]:
        fig.delaxes(ax)

    return fig, axs

def block_average(data: np.ndarray, n_blocks: int = 100, block_size: int | None = None) -> np.ndarray:
    """Block avarage the data using the requested number of blocks."""
    if n_blocks and block_size is None:
        n_samples = len(data)
        block_size = n_samples // n_blocks
    elif block_size and n_blocks is None:
        n_samples = len(data)
        n_blocks = n_samples // block_size
    else:
        raise ValueError("Either n_blocks or block_size should be provided.")
    blocks = np.array_split(data[:n_blocks * block_size], n_blocks)
    return np.array([np.mean(block) for block in blocks])

def block_average_times(times: np.ndarray, n_blocks: int = 100) -> np.ndarray:
    """Get the times corresponding to block averages."""
    n_samples = len(times)
    block_size = n_samples // n_blocks
    return np.array([times[block_size * i] for i in range(n_blocks)])

def get_time_idxs(times: np.ndarray, time: float) -> int:
    """Get the index of the first time greater than or equal to ``time``.

    Parameters
    ----------
    times : np.ndarray
        Sample times.
    time : float
        Threshold time, in the same units as ``times``.

    Returns
    -------
    int
        Index of the first entry of ``times`` that is ``>= time``.

    Raises
    ------
    ValueError
        If no entry of ``times`` reaches ``time``. ``np.argmax`` alone would
        silently return 0 in that case.
    """
    times = np.asarray(times)
    reached = times >= time
    if not reached.any():
        latest = times.max() if times.size else None
        raise ValueError(f"No time in the series is >= {time} (latest time: {latest}).")
    return int(np.argmax(reached))

## Load the data <a name="load"></a>

In [None]:
with open('gradient_arrays_30ns.pkl', 'rb') as f:
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    gradient_arrays_30ns = pkl.load(f)

LAM = 0.45
# LAM = 0.95

systems = list(gradient_arrays_30ns.keys())
n_repeats = gradient_arrays_30ns[systems[0]]["bound"]["vanish"][0.0]["grads"].shape[0]

# Build a timeseries for the bound vanish stage of each system.
# We use the bound vanish stage as it has the most pronounced equilibration behaviour.
# NOTE: rather than integrating over all lambda windows, only the "grads"
# values of the single lambda window LAM are used, so "dgs" holds those raw
# values (shape (n_times, n_repeats)), not an integrated free energy.
timeseries_data = {}
for system in systems:
    all_data = gradient_arrays_30ns[system]["bound"]["vanish"]
    lam_values = list(all_data.keys())
    # Times are the same for all lambda windows, so just take the first set of times
    times = np.array(all_data[lam_values[0]]["times"])
    # grads is indexed [repeat, time]; transpose to (time, repeat).
    grads = np.asarray(all_data[LAM]["grads"])
    dgs = np.ascontiguousarray(grads[:n_repeats, :len(times)].T, dtype=float)
    if dgs.shape != (len(times), n_repeats):
        raise ValueError(
            f"Expected at least {n_repeats} repeats and {len(times)} time points "
            f"for {system} at lambda = {LAM}, got grads of shape {grads.shape}."
        )
    timeseries_data[system] = {"times": times, "dgs": dgs}


In [None]:
# Plot the timeseries data for each system

os.makedirs("output_single", exist_ok=True)
fig, axs = get_subplots(systems)
for i, system in enumerate(systems):
    system_times = timeseries_data[system]["times"]
    system_dgs = timeseries_data[system]["dgs"]
    axs[i].plot(system_times, system_dgs, alpha=0.5, label=[f"Repeat {j + 1}" for j in range(system_dgs.shape[1])])
    axs[i].plot(system_times, np.mean(system_dgs, axis=1), color='black', alpha=0.5, label="Mean")
    axs[i].set_xlabel("Time / ns")
    # Raw string: "\D" is an invalid escape sequence in a normal string literal
    axs[i].set_ylabel(r"$\Delta G$ / kcal mol$^{-1}$")

fig.tight_layout()
axs[-2].legend(bbox_to_anchor=(1.4, 0.8), loc='upper left')
fig.savefig("output_single/timeseries_dg.png", dpi=300, bbox_inches='tight')

In [None]:
# Plot block-averaged data to better visualise trends
os.makedirs("output_single", exist_ok=True)
fig, axs = get_subplots(systems)

for i, system in enumerate(systems):
    # Apply block averaging to each replicate run individually
    block_averaged_dgs = np.array([block_average(timeseries_data[system]["dgs"][:,j]) for j in range(n_repeats)]).T
    times = block_average_times(timeseries_data[system]["times"])
    axs[i].plot(times, block_averaged_dgs, alpha=0.5, label=[f"Repeat {j + 1}" for j in range(n_repeats)])
    axs[i].plot(times, np.mean(block_averaged_dgs, axis=1), color='black', alpha=0.8, label="Mean")
    axs[i].set_xlabel("Time / ns")
    # Raw string: "\D" is an invalid escape sequence in a normal string literal
    axs[i].set_ylabel(r"$\Delta G$ / kcal mol$^{-1}$")

fig.tight_layout()
axs[-2].legend(bbox_to_anchor=(1.4, 0.8), loc='upper left')
fig.savefig("output_single/timeseries_block_dg.png", dpi=300, bbox_inches='tight')

## Fit the data <a name="fit"></a>

### Model trends with exponentials <a name="exp"></a>

In [None]:
# For each of the cases, take the last 20 ns to be equilibrated. Fit an exponential decay to the data after subtracting the mean value of the last 20 ns.

def exp_decay(x: float | np.ndarray, a: float, b: float) -> float:
    return a * np.exp(-b * x)

synthetic_data_params = {system: {} for system in systems}
idx_10ns = get_time_idxs(timeseries_data[systems[0]]["times"], 10)

for system in systems:
    fit_times = timeseries_data[system]["times"]
    system_dgs = timeseries_data[system]["dgs"]

    # Mean of the region treated as equilibrated (everything from 10 ns onwards)
    mean = np.mean(system_dgs[idx_10ns:, :])
    synthetic_data_params[system]["equil_region_mean"] = mean

    # Subtract the mean from the data before fitting the exponential decay.
    # This creates a new array, so timeseries_data is left untouched.
    shifted_data = system_dgs - mean

    # Fit the slow exponential decay to the mean over repeats
    popt, _ = scipy.optimize.curve_fit(exp_decay, fit_times, shifted_data.mean(axis=1), p0=[30, 1.5])
    synthetic_data_params[system]["exp_params"] = popt

    # Calculate the half-life of the exponential decay
    synthetic_data_params[system]["half_life"] = np.log(2) / popt[1]  # in ns

    # Subtract the slow fit from every repeat (broadcast over the repeat axis)
    # and fit another exponential to model the fast initial decay
    shifted_data -= exp_decay(fit_times, *popt)[:, np.newaxis]
    popt_fast, _ = scipy.optimize.curve_fit(exp_decay, fit_times, shifted_data.mean(axis=1), p0=[30, 0.5])
    # A negative amplitude means there is no fast decay: drop the component
    if popt_fast[0] < 0:
        popt_fast[:] = 0
    fast_half_life = np.log(2) / popt_fast[1] if popt_fast[1] > 0 else np.inf

    # The "fast" component must decay faster than the slow one; otherwise drop it
    if fast_half_life > synthetic_data_params[system]["half_life"]:
        popt_fast[:] = 0
        fast_half_life = np.inf

    synthetic_data_params[system]["fast_exp_params"] = popt_fast
    synthetic_data_params[system]["fast_half_life"] = fast_half_life

    print(30*"#")
    print(f"System: {system}")
    print(f"Half-life: {synthetic_data_params[system]['half_life']:.2f} ns")
    print(f"Exponential decay parameters: a = {popt[0]:.2f}, b = {popt[1]:.2f}")
    print(f"Fast half-life: {fast_half_life:.2f} ns")
    print(f"Fast exponential decay parameters: a = {popt_fast[0]:.2f}, b = {popt_fast[1]:.2f}")


In [None]:
# Plot the block-averaged mean traces with the exponential decay fit
os.makedirs("output_single", exist_ok=True)
fig, axs = get_subplots(systems)

for i, system in enumerate(systems):
    params = synthetic_data_params[system]
    # Apply block averaging to each replicate run individually
    block_averaged_dgs = np.array([block_average(timeseries_data[system]["dgs"][:,j]) for j in range(n_repeats)]).T
    times = block_average_times(timeseries_data[system]["times"])
    slow_fit = exp_decay(times, *params["exp_params"])
    fast_fit = exp_decay(times, *params["fast_exp_params"])
    baseline = params["equil_region_mean"]
    axs[i].plot(times, block_averaged_dgs, alpha=0.5, label=[f"Repeat {j + 1}" for j in range(n_repeats)])
    axs[i].plot(times, np.mean(block_averaged_dgs, axis=1), color='black', alpha=0.8, label="Mean")
    # Plot slow and fast exponential fits, and the combined fit
    axs[i].plot(times, slow_fit + baseline, color='green', label="Exponential decay fit", alpha=0.5)
    axs[i].plot(times, fast_fit + baseline, color='blue', label="Fast exponential decay fit", alpha=0.5)
    axs[i].plot(times, slow_fit + fast_fit + baseline, color='red', label="Combined fit")
    axs[i].set_xlabel("Time / ns")
    # Raw string: "\D" is an invalid escape sequence in a normal string literal
    axs[i].set_ylabel(r"$\Delta G$ / kcal mol$^{-1}$")
    # Label the plot with the exponential decay parameters and half-life
    axs[i].text(0.5, 0.9, f"Half-life: {params['half_life']:.2f} ns", horizontalalignment='center', verticalalignment='center', transform=axs[i].transAxes)
    axs[i].text(0.5, 0.8, f"a = {params['exp_params'][0]:.2f} kcal mol$^{{-1}}$", horizontalalignment='center', verticalalignment='center', transform=axs[i].transAxes)


fig.tight_layout()
axs[-2].legend(bbox_to_anchor=(1.3, 0.8), loc='upper left')
fig.savefig("output_single/timeseries_block_dg_exp_fit.png", dpi=300, bbox_inches='tight')

In [None]:
# As above, but zoom in on the first 0.2 ns, using 5000 blocks for finer time resolution

os.makedirs("output_single", exist_ok=True)
fig, axs = get_subplots(systems)

for i, system in enumerate(systems):
    params = synthetic_data_params[system]
    # Apply block averaging to each replicate run individually
    block_averaged_dgs = np.array([block_average(timeseries_data[system]["dgs"][:,j], 5000) for j in range(n_repeats)]).T
    times = block_average_times(timeseries_data[system]["times"], 5000)
    slow_fit = exp_decay(times, *params["exp_params"])
    fast_fit = exp_decay(times, *params["fast_exp_params"])
    baseline = params["equil_region_mean"]
    axs[i].plot(times, block_averaged_dgs, alpha=0.5, label=[f"Repeat {j + 1}" for j in range(n_repeats)])
    axs[i].plot(times, np.mean(block_averaged_dgs, axis=1), color='black', alpha=0.8, label="Mean")
    axs[i].plot(times, slow_fit + baseline, color='green', label="Exponential decay fit", alpha=0.5)
    axs[i].plot(times, fast_fit + baseline, color='blue', label="Fast exponential decay fit", alpha=0.5)
    axs[i].plot(times, slow_fit + fast_fit + baseline, color='red', label="Combined fit")
    axs[i].set_xlabel("Time / ns")
    # Raw string: "\D" is an invalid escape sequence in a normal string literal
    axs[i].set_ylabel(r"$\Delta G$ / kcal mol$^{-1}$")
    axs[i].set_xlim(0, 0.2)
    # Label the plot with the exponential decay parameters and half-life
    axs[i].text(0.5, 0.9, f"Half-life: {params['half_life']:.2f} ns", horizontalalignment='center', verticalalignment='center', transform=axs[i].transAxes)
    axs[i].text(0.5, 0.8, f"a = {params['exp_params'][0]:.2f} kcal mol$^{{-1}}$", horizontalalignment='center', verticalalignment='center', transform=axs[i].transAxes)

fig.tight_layout()
axs[-2].legend(bbox_to_anchor=(1.3, 0.8), loc='upper left')
fig.savefig("output_single/timeseries_block_dg_exp_fit_zoom.png", dpi=300, bbox_inches='tight')

### Model stationary distributions with uncorrelated Gaussian noise <a name="noise"></a>

In [None]:
def plot_normality(data: np.ndarray, axs: list[plt.Axes]) -> None:
    """
    Plot the histogram, kernel density estimate and QQ plot for a set of data.

    The Shapiro-Wilk normality test p-value is written onto the QQ plot.

    Parameters
    ----------
    data : np.ndarray
        The data to plot.
    axs : list[plt.Axes]
        At least three axes. The histogram, kernel density estimate and QQ
        plot are drawn on ``axs[0]``, ``axs[1]`` and ``axs[2]``, respectively.

    Returns
    -------
    None

    Notes
    -----
    For more than 5000 samples, scipy warns that the Shapiro-Wilk p-value may
    not be accurate.
    """
    # Plot the histogram, kernel density estimate, and QQ plot
    axs[0].hist(data, edgecolor="black")
    sns.kdeplot(data, ax=axs[1], color="black", linewidth=2)
    st.probplot(data, plot=axs[2])

    # Set the axis labels
    axs[0].set_xlabel("Value")
    axs[0].set_ylabel("Frequency")
    axs[0].set_title("Histogram")
    axs[1].set_xlabel("Value")
    # A KDE is a probability density, not a count
    axs[1].set_ylabel("Density")
    axs[1].set_title("Kernel Density Estimate")
    axs[2].set_xlabel("Theoretical Normal Quantiles")
    axs[2].set_ylabel("Ordered Values")
    axs[2].set_title("QQ Plot")

    # Compute the Shapiro-Wilk test and annotate the p value. Use significant
    # figures so that small p-values are not all shown as "0.00".
    _, p_value = st.shapiro(data)
    axs[2].text(
        0.5,
        0.95,
        f"Shapiro-Wilk p-value: {p_value:.2g}",
        transform=axs[2].transAxes,
        horizontalalignment="center",
        verticalalignment="top",
    )

In [None]:
# Check how Gaussian the stationary distributions are for each system

os.makedirs("output_single", exist_ok=True)
# squeeze=False keeps axs 2D (one row of three panels per system), so
# axs[i, 0] also works when there is only a single system
fig, axs = plt.subplots(len(systems), 3, figsize=(12, 4 * len(systems)), squeeze=False)
for i, system in enumerate(systems):
    # Get the stationary distribution for the mean trace
    stationary_dgs = timeseries_data[system]["dgs"][idx_10ns:,:].mean(axis=1)
    # Plot the normality of the stationary distribution
    plot_normality(stationary_dgs, axs[i])
    axs[i, 0].set_title(f"{system} Stationary Distribution")

fig.tight_layout()
fig.savefig("output_single/stationary_dg_normality.png", dpi=300, bbox_inches='tight')

### Calculate the autocorrelation of the stationary distributions <a name="autocorr"></a>

In [None]:
from red.variance import _get_autocovariance, _get_gamma_cap, _get_initial_convex_sequence, _get_initial_positive_sequence, _get_initial_monotone_sequence

Fitting the autocovariance: we take the autocovariance at lags 0 and 1 directly from the data, and calculate later lags from Geyer's initial convex sequence estimator, interpolating the convex sequence of summed pairs of autocovariances back onto individual lags. This is intended to reproduce the initial lags accurately (as they make large contributions to the total variance), while removing effects at later lag times caused by non-stationarity.

In [None]:
for system in systems:
    # Get the stationary distribution for the mean trace
    stationary_dgs = timeseries_data[system]["dgs"][idx_10ns:,:].mean(axis=1)
    # NOTE(review): red receives this 1D array. If red requires a 2D
    # (n_chains, n_samples) array, pass stationary_dgs.reshape(1, -1) instead
    # and check that the resulting autocovariance is unchanged.
    # Get the autocovariance series
    autocov_series = _get_autocovariance(stationary_dgs)
    # Gamma series for Geyer's method (presumably Gamma_k = autocov(2k) + autocov(2k + 1));
    # the first element is dropped because lags 0 and 1 are taken directly from the data below
    gamma_series = _get_gamma_cap(autocov_series)[1:]
    # Get the initial convex sequence
    initial_convex_sequence = _get_initial_convex_sequence(gamma_series)
    # Convert the initial convex gamma sequence back into an autocovariance series
    # Do this by interpolating, doubling the amount of data
    x_interpolate = np.arange(2*len(initial_convex_sequence))
    # x is measured from lag 2 (lags 0 and 1 are prepended below), so the
    # gamma values sit at x = 0.5, 2.5, 4.5, etc.
    x_gamma = np.arange(len(initial_convex_sequence)) * 2 + 0.5
    # Each gamma value covers two lags, so halve it to get per-lag values
    autocov_convex = np.interp(x_interpolate, x_gamma, initial_convex_sequence) / 2
    autocov_convex = np.concatenate([autocov_series[:2], autocov_convex])
    # Save parameters
    synthetic_data_params[system]["autocov_convex"] = autocov_convex
    synthetic_data_params[system]["autocov"] = autocov_series
    # Get the total variance (accounting for autocorrelation) from the autocov_convex data
    synthetic_data_params[system]["total_variance"] = 2 * np.sum(autocov_convex) - autocov_convex[0]
    synthetic_data_params[system]["max_lag_idx"] = len(autocov_convex) - 1

In [None]:
# Plot all of the autocovariance series against the initial convex sequences
os.makedirs("output_single", exist_ok=True)
fig, axs = get_subplots(systems)

first_idx = 0
last_idx = None

for i, system in enumerate(systems):
    autocov_full = synthetic_data_params[system]["autocov"]
    convex_full = synthetic_data_params[system]["autocov_convex"]
    # Slice the lag indices exactly like the data so that the x-axis shows the
    # true lag even when first_idx is not 0
    axs[i].plot(np.arange(len(autocov_full))[first_idx:last_idx], autocov_full[first_idx:last_idx], label="Autocovariance")
    axs[i].plot(np.arange(len(convex_full))[first_idx:last_idx], convex_full[first_idx:last_idx], label="Initial Convex Sequence")
    axs[i].set_xlabel("Lag")
    axs[i].set_ylabel("Autocovariance / kcal$^2$ mol$^{-2}$")
    axs[i].set_title(system)
    # Add the total variance to the plot
    axs[i].text(0.5, 0.7, f"Total Variance:\n {synthetic_data_params[system]['total_variance']:.2f} kcal$^2$ mol$^{{-2}}$", horizontalalignment='center', verticalalignment='center', transform=axs[i].transAxes)
    # Add max lag index to the plot
    axs[i].text(0.5, 0.5, f"Max Lag Index:\n {synthetic_data_params[system]['max_lag_idx']}", horizontalalignment='center', verticalalignment='center', transform=axs[i].transAxes)

fig.tight_layout()
axs[-2].legend(bbox_to_anchor=(1.25, 0.7), loc='upper left')
fig.savefig("output_single/autocovariance_initial_convex_long.png", dpi=300, bbox_inches='tight')


In [None]:
# Plot the first few lags of the autocovariance series against the initial convex sequences
os.makedirs("output_single", exist_ok=True)
fig, axs = get_subplots(systems)

first_idx = 0
last_idx = 10

for i, system in enumerate(systems):
    autocov_full = synthetic_data_params[system]["autocov"]
    convex_full = synthetic_data_params[system]["autocov_convex"]
    # Slice the lag indices exactly like the data so that the x-axis shows the
    # true lag even when first_idx is not 0
    axs[i].plot(np.arange(len(autocov_full))[first_idx:last_idx], autocov_full[first_idx:last_idx], label="Autocovariance", marker='o')
    axs[i].plot(np.arange(len(convex_full))[first_idx:last_idx], convex_full[first_idx:last_idx], label="Initial Convex Sequence", marker='o')
    axs[i].set_xlabel("Lag")
    axs[i].set_ylabel("Autocovariance / kcal$^2$ mol$^{-2}$")
    axs[i].set_title(system)

fig.tight_layout()
axs[-2].legend(bbox_to_anchor=(1.25, 0.7), loc='upper left')
fig.savefig("output_single/autocovariance_initial_convex_zoom.png", dpi=300, bbox_inches='tight')


In [None]:
def format_significant_figures(x, sig_figs=2):
    """Render ``x`` as a string rounded to ``sig_figs`` significant figures (``g`` format)."""
    spec = f".{sig_figs}g"
    return format(x, spec)

# Create the DataFrame summarising the fitted parameters for each system
os.makedirs("output_single", exist_ok=True)
overview_df = pd.DataFrame(columns=["Half-life (ns)", "a (kcal mol$^{-1}$)", "Fast Half-life (ns)", "Fast a (kcal mol$^{-1}$)", "Total Variance (kcal$^2$ mol$^{-2}$)", "Max Lag Index"])
for system in systems:
    params = synthetic_data_params[system]
    overview_df.loc[system] = [params["half_life"], params["exp_params"][0], params["fast_half_life"], params["fast_exp_params"][0], params["total_variance"], params["max_lag_idx"]]

# Apply the rounding function to all elements in the DataFrame.
# DataFrame.applymap is deprecated since pandas 2.1 in favour of DataFrame.map;
# fall back to applymap on older pandas versions.
elementwise = overview_df.map if hasattr(overview_df, "map") else overview_df.applymap
overview_df = elementwise(lambda x: format_significant_figures(x, 2))

# Save csv and latex (to_latex writes the file and returns None when given a path)
overview_df.to_csv("output_single/overview.csv")
overview_df.to_latex("output_single/overview.tex", index=True, escape=False)

# Display the DataFrame
overview_df

In [None]:
# Save the fitted parameters so that the synthetic data can be regenerated later
os.makedirs("output_single", exist_ok=True)
with open('output_single/synthetic_data_params.pkl', 'wb') as f:
    pkl.dump(synthetic_data_params, f)

## Create synthetic datasets <a name="synth"></a>

Generate the following synthetic datasets:

- "Standard" - Standard synthetic data as described above, up to 8 ns, as this is a typical length for an ABFE calculation window
- "Short" - Use only the first 0.2 ns of data
- "Subsampled" - Keep only 1 out of every 100 points
- "Noisy" - Increase the variance by a factor of 5
- "Block Averaged" - Decrease the variance by a factor of 10

In [None]:
# Treat FIGURES_ONLY as a boolean flag: unset, empty, "0", "false", "no" and
# "off" mean False. A plain truthiness check would treat FIGURES_ONLY=0 as True.
_FIGURES_ONLY = os.environ.get("FIGURES_ONLY", "").strip().lower() not in {"", "0", "false", "no", "off"}

if not _FIGURES_ONLY:

    def get_cholesky(autocov_fn: np.ndarray, n_data: int) -> np.ndarray:
        """Get the lower Cholesky factor of the Toeplitz autocovariance matrix.

        The autocovariance function is truncated or zero-padded to ``n_data``
        lags and expanded into a symmetric ``n_data`` x ``n_data`` Toeplitz
        matrix ``C``; the lower-triangular ``L`` with ``C = L @ L.T`` is returned.

        Raises
        ------
        LinAlgError
            From ``scipy.linalg.cholesky`` if ``C`` is not positive definite
            (which truncation/zero-padding does not rule out).
        """
        # Figure out if we need to pad or truncate the autocovariance function
        if len(autocov_fn) > n_data:
            autocov_fn = autocov_fn[:n_data]
        else:
            autocov_fn = np.pad(autocov_fn, (0, n_data - len(autocov_fn)), 'constant')

        # Generate the autocovariance matrix and return its Cholesky factor
        autocov_matrix = toeplitz(autocov_fn)
        return scipy.linalg.cholesky(autocov_matrix, lower=True)

    def generate_data(cholesky_mat: np.ndarray, scale: float = 1, rng: np.random.Generator | None = None) -> np.ndarray:
        """Generate a correlated stationary series from a Cholesky factor.

        White Gaussian noise with standard deviation ``scale`` is multiplied by
        ``cholesky_mat``, giving a series whose covariance is
        ``scale**2 * cholesky_mat @ cholesky_mat.T``.

        Parameters
        ----------
        cholesky_mat : np.ndarray
            Lower Cholesky factor, e.g. from ``get_cholesky``.
        scale : float, default 1
            Standard deviation of the white noise.
        rng : np.random.Generator | None, default None
            Random generator to draw from. If None, NumPy's global random state
            is used.
        """
        n_data = cholesky_mat.shape[0]
        if rng is None:
            white_noise = np.random.normal(size=n_data, scale=scale)
        else:
            white_noise = rng.normal(scale=scale, size=n_data)
        return cholesky_mat @ white_noise

    def generate_data_with_transient(cholesky_mat: np.ndarray,
                                     times: np.ndarray,
                                     slow_exponential_params: np.ndarray,
                                     fast_exponential_params: np.ndarray,
                                     scale: float = 1,
                                     rng: np.random.Generator | None = None) -> np.ndarray:
        """Generate synthetic data for equilibration detection with an exponential transient.

        The result is a correlated stationary series (see ``generate_data``)
        plus a slow and a fast exponential decay (``exp_decay``) evaluated at
        ``times``. ``times`` must have one entry per row of ``cholesky_mat``.
        """
        # Get the exponential transients
        slow_transient_data = exp_decay(times, *slow_exponential_params)
        fast_transient_data = exp_decay(times, *fast_exponential_params)

        # Get the stationary data
        stationary_data = generate_data(cholesky_mat, scale=scale, rng=rng)

        # Add them together
        return stationary_data + slow_transient_data + fast_transient_data

    # Generate data up to 8 ns
    SET_NAME = "standard"
    N_REPEATS = 1000
    # Seed so that the synthetic datasets can be regenerated exactly
    SEED = 0
    # NOTE: the cut-off index is taken from the first system's times and used
    # for all systems, assuming they share the same time grid.
    IDX = get_time_idxs(timeseries_data[systems[0]]["times"], 8)

    rng = np.random.default_rng(SEED)
    synthetic_data_bound_vanish = {SET_NAME: {}}

    for system in tqdm.tqdm(systems, desc="Generating synthetic data"):
        synthetic_data_bound_vanish[SET_NAME][system] = {}
        autocov_fn = synthetic_data_params[system]["autocov_convex"]
        cholesky_mat = get_cholesky(autocov_fn, IDX)
        times = timeseries_data[system]["times"][:IDX]
        # The times are stored once per set, next to the system keys, and are
        # overwritten for each system.
        synthetic_data_bound_vanish[SET_NAME]["times"] = times
        for i in range(N_REPEATS):
            data = generate_data_with_transient(cholesky_mat,
                                                times,
                                                synthetic_data_params[system]["exp_params"],
                                                synthetic_data_params[system]["fast_exp_params"],
                                                rng=rng)
            synthetic_data_bound_vanish[SET_NAME][system][i] = {"data": data}

    os.makedirs("output_single", exist_ok=True)
    with open("output_single/synthetic_data_bound_vanish.pkl", "wb") as f:
        pkl.dump(synthetic_data_bound_vanish, f)