# Create Synthetic Datasets for Testing Using the Free Vanish Stage

In order to test the performance of equilibration detection methods, we create sets of synthetic datasets modelled on long ABFE runs for six different systems (see the 30 ns non-adaptive runs [here](https://chemrxiv.org/engage/chemrxiv/article-details/6670b524c9c6a5c07aafa972) for details). Here, we create synthetic data based on the free vanish stage of an ABFE calculation.

This notebook is organised as follows:

- [Load the data](#load)
- [Fit the data](#fit)
    - [Model trends with exponentials](#exp)
    - [Model stationary distributions with uncorrelated Gaussian noise](#noise)
    - [Reintroduce desired correlation structure with Cholesky decomposition](#cholesky)
- [Create synthetic datasets](#synth)


In [None]:
import red
import pickle as pkl
import matplotlib.pyplot as plt
from pymbar import timeseries
import numpy as np
import scipy
import scipy.stats as st
import tqdm
import pandas as pd
from scipy.linalg import toeplitz
import seaborn as sns

# Set ggplot style
plt.style.use('ggplot')

In [None]:
# Utility functions

def get_subplots(systems: list[str]) -> tuple[plt.Figure, list[plt.Axes]]:
    """
    Create a grid of square-ish subplots (at most two columns), one per system.

    Parameters
    ----------
    systems : list[str]
        System names. Each name is used as the title of one axes.

    Returns
    -------
    tuple[plt.Figure, list[plt.Axes]]
        The figure and a flat sequence of axes (a list for a single system,
        otherwise the flattened array returned by ``plt.subplots``). Axes
        beyond ``len(systems)`` are removed from the figure.
    """
    n_systems = len(systems)

    # Lay the panels out in at most two columns
    n_cols = min(2, n_systems)
    n_rows = int(np.ceil(n_systems / n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))

    # A single panel comes back as a bare Axes, so wrap it; otherwise flatten the grid
    axs = [axs] if n_systems == 1 else axs.flatten()

    # Title each used panel with its system name
    for ax, system in zip(axs, systems):
        ax.set_title(system)

    # Drop any panels left over in the last row
    for spare_ax in axs[n_systems:]:
        fig.delaxes(spare_ax)

    return fig, axs

def block_average(data: np.ndarray, n_blocks: int = 100) -> np.ndarray:
    """
    Block average the data using the requested number of blocks.

    Trailing samples that do not fill a complete block are discarded, so every
    block contains ``len(data) // n_blocks`` samples.
    """
    samples_per_block = len(data) // n_blocks
    usable = data[: n_blocks * samples_per_block]
    return np.array([chunk.mean() for chunk in np.array_split(usable, n_blocks)])

def block_average_times(times: np.ndarray, n_blocks: int = 100) -> np.ndarray:
    """
    Get the times corresponding to block averages.

    The time reported for each block is the time of the first sample in that
    block, with blocks of ``len(times) // n_blocks`` samples.
    """
    samples_per_block = len(times) // n_blocks
    block_start_idxs = samples_per_block * np.arange(n_blocks)
    return np.asarray(times)[block_start_idxs]

def get_time_idxs(times: np.ndarray, time: float) -> int:
    """
    Get the index of the first time greater than or equal to the requested time.

    Parameters
    ----------
    times : np.ndarray
        The sampled times.
    time : float
        The requested time, in the same units as ``times``.

    Returns
    -------
    int
        Index of the first entry of ``times`` that is ``>= time``.

    Raises
    ------
    ValueError
        If no entry of ``times`` reaches the requested time. Without this
        check ``np.argmax`` on an all-False mask would silently return 0.
    """
    reached = np.asarray(times) >= time
    if not reached.any():
        raise ValueError(f"No time is greater than or equal to the requested time ({time}).")
    return int(np.argmax(reached))

## Load the data <a name="load"></a>

In [None]:
# NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
with open('gradient_arrays_30ns.pkl', 'rb') as f:
    gradient_arrays_30ns = pkl.load(f)

systems = list(gradient_arrays_30ns.keys())
n_repeats = gradient_arrays_30ns[systems[0]]["free"]["vanish"][0.0]["grads"].shape[0]

# np.trapz is deprecated in favour of np.trapezoid from NumPy 2.0; use whichever is available.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz

# Get timeseries of the overall free energy change for the free vanish stage using
# thermodynamic integration with the trapezoidal rule.
# NOTE(review): an earlier version of this comment referred to the bound vanish stage;
# the data read below is the free vanish stage.
timeseries_data = {}
for system in systems:
    all_data = gradient_arrays_30ns[system]["free"]["vanish"]
    lam_values = list(all_data.keys())
    # Times are the same for all lambda windows, so just take the first set of times
    times = np.array(all_data[lam_values[0]]["times"])
    # Stack the gradients into an array indexed as (lambda, repeat, time), restricted to
    # the repeats and time points used for the integration
    grads = np.stack(
        [np.asarray(all_data[lam]["grads"], dtype=float)[:n_repeats, :len(times)] for lam in lam_values]
    )
    # Integrate over the lambda values to get the overall dgs at each time (kcal mol-1).
    # Integrating along the lambda axis handles every time point and repeat in one call;
    # the transpose gives the (time, repeat) layout used by the rest of the notebook.
    dgs = _trapezoid(grads, x=lam_values, axis=0).T
    timeseries_data[system] = {"times": times, "dgs": dgs}


In [None]:
# Plot the timeseries data for each system

fig, axs = get_subplots(systems)
for i, system in enumerate(systems):
    system_times = timeseries_data[system]["times"]
    system_dgs = timeseries_data[system]["dgs"]
    # One line (and legend entry) per repeat, i.e. per column of the dgs array
    repeat_labels = [f"Repeat {j + 1}" for j in range(system_dgs.shape[1])]
    axs[i].plot(system_times, system_dgs, alpha=0.5, label=repeat_labels)
    axs[i].plot(system_times, np.mean(system_dgs, axis=1), color='black', alpha=0.5, label="Mean")
    axs[i].set_xlabel("Time / ns")
    # Raw string: "\D" is not a valid escape sequence in an ordinary string literal
    axs[i].set_ylabel(r"$\Delta G$ / kcal mol$^{-1}$")

fig.tight_layout()
# A single shared legend, anchored outside the second-to-last axes
axs[-2].legend(bbox_to_anchor=(1.4, 0.8), loc='upper left')
fig.savefig("output_free/timeseries_dg.png", dpi=300, bbox_inches='tight')

In [None]:
# Plot block-averaged data to better visualise trends
fig, axs = get_subplots(systems)

for i, system in enumerate(systems):
    # Apply block averaging to each replicate run individually
    block_averaged_dgs = np.array([block_average(timeseries_data[system]["dgs"][:,j]) for j in range(n_repeats)]).T
    times = block_average_times(timeseries_data[system]["times"])
    axs[i].plot(times, block_averaged_dgs, alpha=0.5, label=[f"Repeat {j + 1}" for j in range(n_repeats)])
    axs[i].plot(times, np.mean(block_averaged_dgs, axis=1), color='black', alpha=0.8, label="Mean")
    axs[i].set_xlabel("Time / ns")
    # Raw string: "\D" is not a valid escape sequence in an ordinary string literal
    axs[i].set_ylabel(r"$\Delta G$ / kcal mol$^{-1}$")

fig.tight_layout()
# A single shared legend, anchored outside the second-to-last axes
axs[-2].legend(bbox_to_anchor=(1.4, 0.8), loc='upper left')
fig.savefig("output_free/timeseries_block_dg.png", dpi=300, bbox_inches='tight')

## Fit the data <a name="fit"></a>

### Model trends with exponentials <a name="exp"></a>

In [None]:
# For each of the cases, take the last 20 ns to be equilibrated. Fit an exponential decay to the data after subtracting the mean value of the last 20 ns.

def exp_decay(x: float, a: float, b: float) -> float:
    """Evaluate the exponential decay ``a * exp(-b * x)`` (``x`` may also be an array)."""
    decay_factor = np.exp(-b * x)
    return a * decay_factor

# Per-system dictionary that collects every fitted parameter used later to build synthetic data
synthetic_data_params = {system:{} for system in systems}
# Index of the first sample at or after 10 ns; taken from the first system's times and reused for all systems
idx_10ns = get_time_idxs(timeseries_data[systems[0]]["times"], 10)

for system in systems:
    # Mean over all repeats and all samples from 10 ns onwards (the region treated as equilibrated)
    mean = np.mean(timeseries_data[system]["dgs"][idx_10ns:,:])
    synthetic_data_params[system]["equil_region_mean"] = mean

    # Subtract the mean from the data before fitting the exponential decay
    # (the subtraction creates a new array, so the stored timeseries is left untouched)
    shifted_data = timeseries_data[system]["dgs"] - mean

    # Fit the exponential decay to the data (the fit is to the mean over repeats)
    popt, pcov = scipy.optimize.curve_fit(exp_decay, timeseries_data[system]["times"], shifted_data.mean(axis=1), p0=[30, 1.5])
    synthetic_data_params[system]["exp_params"] = popt
    
    # Calculate the half-life of the exponential decay
    synthetic_data_params[system]["half_life"] = np.log(2) / popt[1] # in ns

    # If the half life is ridiculously fast, assume no initial transient and avoid fitting more exponentials
    if synthetic_data_params[system]["half_life"] < 0.05:
        # Zero current exponential parameters
        synthetic_data_params[system]["exp_params"] = [0, 0]
        synthetic_data_params[system]["half_life"] = np.inf
        # Zero fast exponential parameters
        fast_half_life = np.inf
        popt_fast = [0, 0]
        synthetic_data_params[system]["fast_exp_params"] = popt_fast
        synthetic_data_params[system]["fast_half_life"] = fast_half_life

    else:
        # Subtract the first exponential fit and add another to model the fast initial decay
        fit_exp_series = exp_decay(timeseries_data[system]["times"], *popt)
        # Change the shape of fit_exp_series to match the shape of shifted_data
        fit_exp_series = np.tile(fit_exp_series, (n_repeats, 1)).T
        shifted_data -= fit_exp_series
        # Second fit, to the residual left after removing the first exponential
        popt_fast, pcov_fast = scipy.optimize.curve_fit(exp_decay, timeseries_data[system]["times"], shifted_data.mean(axis=1), p0=[30, 0.5])
        # If we get a negative a, set it to 0
        if popt_fast[0] < 0:
            popt_fast[0] = 0
            popt_fast[1] = 0
        # Guard against division by zero when the rate has been zeroed (or fitted as non-positive)
        fast_half_life = np.log(2) / popt_fast[1] if popt_fast[1] > 0 else np.inf
        synthetic_data_params[system]["fast_exp_params"] = popt_fast
        synthetic_data_params[system]["fast_half_life"] = fast_half_life


    print(30*"#")
    print(f"System: {system}")
    print(f"Half-life: {synthetic_data_params[system]['half_life']:.2f} ns")
    # NOTE(review): popt is the raw first fit, so when the parameters were zeroed above these
    # printed values differ from the stored synthetic_data_params[system]["exp_params"].
    print(f"Exponential decay parameters: a = {popt[0]:.2f}, b = {popt[1]:.2f}")
    print(f"Fast half-life: {fast_half_life:.2f} ns")
    print(f"Fast exponential decay parameters: a = {popt_fast[0]:.2f}, b = {popt_fast[1]:.2f}")


In [None]:
# Plot the block-averaged mean traces with the exponential decay fit
fig, axs = get_subplots(systems)

for i, system in enumerate(systems):
    params = synthetic_data_params[system]
    # Apply block averaging to each replicate run individually
    block_averaged_dgs = np.array([block_average(timeseries_data[system]["dgs"][:,j]) for j in range(n_repeats)]).T
    times = block_average_times(timeseries_data[system]["times"])
    axs[i].plot(times, block_averaged_dgs, alpha=0.5, label=[f"Repeat {j + 1}" for j in range(n_repeats)])
    axs[i].plot(times, np.mean(block_averaged_dgs, axis=1), color='black', alpha=0.8, label="Mean")
    # Evaluate each fitted component once; all curves are offset by the equilibrated-region mean
    slow_fit = exp_decay(times, *params["exp_params"])
    fast_fit = exp_decay(times, *params["fast_exp_params"])
    offset = params["equil_region_mean"]
    # Plot slow and fast exponential fits, and the combined fit
    axs[i].plot(times, slow_fit + offset, color='green', label="Exponential decay fit", alpha=0.5)
    axs[i].plot(times, fast_fit + offset, color='blue', label="Fast exponential decay fit", alpha=0.5)
    axs[i].plot(times, slow_fit + fast_fit + offset, color='red', label="Combined fit")
    axs[i].set_xlabel("Time / ns")
    # Raw string: "\D" is not a valid escape sequence in an ordinary string literal
    axs[i].set_ylabel(r"$\Delta G$ / kcal mol$^{-1}$")
    # Label the plot with the exponential decay parameters and half-life
    axs[i].text(0.5, 0.9, f"Half-life: {params['half_life']:.2f} ns", horizontalalignment='center', verticalalignment='center', transform=axs[i].transAxes)
    axs[i].text(0.5, 0.8, f"a = {params['exp_params'][0]:.2f} kcal mol$^{{-1}}$", horizontalalignment='center', verticalalignment='center', transform=axs[i].transAxes)


fig.tight_layout()
# A single shared legend, anchored outside the second-to-last axes
axs[-2].legend(bbox_to_anchor=(1.3, 0.8), loc='upper left')
fig.savefig("output_free/timeseries_block_dg_exp_fit.png", dpi=300, bbox_inches='tight')

In [None]:
# As above, but zoom in on the first 0.2 ns (using 5000 blocks for finer time resolution)

fig, axs = get_subplots(systems)

for i, system in enumerate(systems):
    params = synthetic_data_params[system]
    # Apply block averaging to each replicate run individually
    block_averaged_dgs = np.array([block_average(timeseries_data[system]["dgs"][:,j], 5000) for j in range(n_repeats)]).T
    times = block_average_times(timeseries_data[system]["times"], 5000)
    axs[i].plot(times, block_averaged_dgs, alpha=0.5, label=[f"Repeat {j + 1}" for j in range(n_repeats)])
    axs[i].plot(times, np.mean(block_averaged_dgs, axis=1), color='black', alpha=0.8, label="Mean")
    # Evaluate each fitted component once; all curves are offset by the equilibrated-region mean
    slow_fit = exp_decay(times, *params["exp_params"])
    fast_fit = exp_decay(times, *params["fast_exp_params"])
    offset = params["equil_region_mean"]
    axs[i].plot(times, slow_fit + offset, color='green', label="Exponential decay fit", alpha=0.5)
    axs[i].plot(times, fast_fit + offset, color='blue', label="Fast exponential decay fit", alpha=0.5)
    axs[i].plot(times, slow_fit + fast_fit + offset, color='red', label="Combined fit")
    axs[i].set_xlabel("Time / ns")
    # Raw string: "\D" is not a valid escape sequence in an ordinary string literal
    axs[i].set_ylabel(r"$\Delta G$ / kcal mol$^{-1}$")
    axs[i].set_xlim(0, 0.2)
    # Label the plot with the exponential decay parameters and half-life
    axs[i].text(0.5, 0.9, f"Half-life: {params['half_life']:.2f} ns", horizontalalignment='center', verticalalignment='center', transform=axs[i].transAxes)
    axs[i].text(0.5, 0.8, f"a = {params['exp_params'][0]:.2f} kcal mol$^{{-1}}$", horizontalalignment='center', verticalalignment='center', transform=axs[i].transAxes)

fig.tight_layout()
# A single shared legend, anchored outside the second-to-last axes
axs[-2].legend(bbox_to_anchor=(1.3, 0.8), loc='upper left')
fig.savefig("output_free/timeseries_block_dg_exp_fit_zoom.png", dpi=300, bbox_inches='tight')

### Model stationary distributions with uncorrelated Gaussian noise <a name="noise"></a>

In [None]:
def plot_normality(data: np.ndarray, axs: list[plt.Axes]) -> None:
    """
    Plot the histogram, kernel density estimate and QQ plot for a given set of data.

    Parameters
    ----------
    data : np.ndarray
        The data to plot.

    axs : list[plt.Axes]
        Three axes to draw on: ``axs[0]`` receives the histogram, ``axs[1]``
        the kernel density estimate and ``axs[2]`` the QQ plot, annotated
        with the Shapiro-Wilk p-value.

    Returns
    -------
    None
    """
    # Plot the histogram, kernel density estimate, and QQ plot
    axs[0].hist(data, edgecolor="black")
    sns.kdeplot(data, ax=axs[1], color="black", linewidth=2)
    st.probplot(data, plot=axs[2])

    # Set the axis labels
    axs[0].set_xlabel("Value")
    axs[0].set_ylabel("Frequency")
    axs[0].set_title("Histogram")
    axs[1].set_xlabel("Value")
    # A kernel density estimate shows a probability density, not a count
    axs[1].set_ylabel("Density")
    axs[1].set_title("Kernel Density Estimate")
    axs[2].set_xlabel("Theoretical Normal Quantiles")
    axs[2].set_ylabel("Ordered Values")
    axs[2].set_title("QQ Plot")

    # Compute the Shapiro-Wilk test and print the p value
    _, p_value = st.shapiro(data)
    axs[2].text(
        0.5,
        0.95,
        f"Shapiro-Wilk p-value: {p_value:.2f}",
        transform=axs[2].transAxes,
        horizontalalignment="center",
        verticalalignment="top",
    )

In [None]:
# Check how Gaussian the stationary distributions are for each system

# squeeze=False keeps axs two-dimensional even for a single system, so that the
# axs[i] / axs[i, 0] indexing below always works
fig, axs = plt.subplots(len(systems), 3, figsize=(12, 4 * len(systems)), squeeze=False)
for i, system in enumerate(systems):
    # Get the stationary distribution for the mean trace (mean over repeats, from 10 ns onwards)
    stationary_dgs = timeseries_data[system]["dgs"][idx_10ns:,:].mean(axis=1)
    # Plot the normality of the stationary distribution
    plot_normality(stationary_dgs, axs[i])
    # Replace the generic histogram title with one naming the system
    axs[i, 0].set_title(f"{system} Stationary Distribution")

fig.tight_layout()
fig.savefig("output_free/stationary_dg_normality.png", dpi=300, bbox_inches='tight')

### Calculate the autocorrelation of the stationary distributions <a name="autocorr"></a>

In [None]:
from red.variance import _get_autocovariance, _get_gamma_cap, _get_initial_convex_sequence, _get_initial_positive_sequence, _get_initial_monotone_sequence

Fitting the covariance - we take the autocovariance at lag times of 0 and 1 directly from the data, and calculate later lags based on Geyer's initial convex sequence method. This is intended to fit initial lags well (which make large contributions to the autocovariance), while removing effects from later lag times caused by non-stationarity.

In [None]:
for system in systems:
    # Get the stationary distribution for the mean trace
    stationary_dgs = timeseries_data[system]["dgs"][idx_10ns:,:].mean(axis=1)
    # The series is passed on as a 1D array. (A previous `stationary_dgs.reshape(1, -1)`
    # statement here discarded its result and so never changed the shape; it has been removed.)
    # Get the autocovariance series
    autocov_series = _get_autocovariance(stationary_dgs)
    # Get the gamma series using Geyer's initial convex sequence method
    gamma_series = _get_gamma_cap(autocov_series)[1:]
    # Get the initial convex sequence
    initial_convex_sequence = _get_initial_convex_sequence(gamma_series)
    # Convert the initial convex gamma sequence back into an autocovariance series
    # Do this by interpolating, doubling the amount of data
    x_interpolate = np.arange(2*len(initial_convex_sequence))
    # x_gamma is at x = 0.5, 2.5, 4.5, etc.
    x_gamma = np.arange(len(initial_convex_sequence)) * 2 + 0.5
    # Halve the interpolated values before using them as individual autocovariances
    autocov_convex = np.interp(x_interpolate, x_gamma, initial_convex_sequence) / 2
    # Lags 0 and 1 are taken directly from the data; later lags come from the convex sequence
    autocov_convex = np.concatenate([autocov_series[:2], autocov_convex])
    # Save parameters
    synthetic_data_params[system]["autocov_convex"] = autocov_convex
    synthetic_data_params[system]["autocov"] = autocov_series
    # Get the total variance (accounting for autocorrelation) from the autocov_convex data:
    # the lag-0 term is counted once and every other lag twice
    synthetic_data_params[system]["total_variance"] = 2 * np.sum(autocov_convex) - autocov_convex[0]
    synthetic_data_params[system]["max_lag_idx"] = len(autocov_convex) - 1
    

In [None]:
# Compare the raw autocovariance with the initial-convex-sequence version over all lags
fig, axs = get_subplots(systems)

first_idx = 0
last_idx = None

for i, system in enumerate(systems):
    ax = axs[i]
    params = synthetic_data_params[system]
    raw_acov = params["autocov"][first_idx:last_idx]
    convex_acov = params["autocov_convex"][first_idx:last_idx]
    ax.plot(np.arange(len(raw_acov)), raw_acov, label="Autocovariance")
    ax.plot(np.arange(len(convex_acov)), convex_acov, label="Initial Convex Sequence")
    ax.set_xlabel("Lag")
    ax.set_ylabel("Autocovariance / kcal$^2$ mol$^{-2}$")
    ax.set_title(system)
    # Annotate with the total variance
    ax.text(0.5, 0.7, f"Total Variance:\n {params['total_variance']:.2f} kcal$^2$ mol$^{{-2}}$", horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
    # Annotate with the maximum lag index
    ax.text(0.5, 0.5, f"Max Lag Index:\n {params['max_lag_idx']}", horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)

fig.tight_layout()
axs[-2].legend(bbox_to_anchor=(1.25, 0.7), loc='upper left')
fig.savefig("output_free/autocovariance_initial_convex_long.png", dpi=300, bbox_inches='tight')


In [None]:
# Compare the raw autocovariance with the initial-convex-sequence version for the first 10 lags
fig, axs = get_subplots(systems)

first_idx = 0
last_idx = 10

for i, system in enumerate(systems):
    ax = axs[i]
    params = synthetic_data_params[system]
    raw_acov = params["autocov"][first_idx:last_idx]
    convex_acov = params["autocov_convex"][first_idx:last_idx]
    ax.plot(np.arange(len(raw_acov)), raw_acov, label="Autocovariance", marker='o')
    ax.plot(np.arange(len(convex_acov)), convex_acov, label="Initial Convex Sequence", marker='o')
    ax.set_xlabel("Lag")
    ax.set_ylabel("Autocovariance / kcal$^2$ mol$^{-2}$")
    ax.set_title(system)

fig.tight_layout()
axs[-2].legend(bbox_to_anchor=(1.25, 0.7), loc='upper left')
fig.savefig("output_free/autocovariance_initial_convex_zoom.png", dpi=300, bbox_inches='tight')


In [None]:
# Summarise every system in one table: exponential parameters (half-life and a) for the
# slow and fast fits, total variance, and max lag index

overview_columns = [
    "Half-life (ns)",
    "a (kcal mol$^{-1}$)",
    "Fast Half-life (ns)",
    "Fast a (kcal mol$^{-1}$)",
    "Total Variance (kcal$^2$ mol$^{-2}$)",
    "Max Lag Index",
]
overview_df = pd.DataFrame(columns=overview_columns)
for system in systems:
    params = synthetic_data_params[system]
    # One row per system, in the same order as overview_columns
    overview_df.loc[system] = [
        params["half_life"],
        params["exp_params"][0],
        params["fast_half_life"],
        params["fast_exp_params"][0],
        params["total_variance"],
        params["max_lag_idx"],
    ]
overview_df.to_csv("output_free/overview.csv")
overview_df

In [None]:
def format_significant_figures(x, sig_figs=2):
    """
    Format ``x`` as a fixed-point string.

    Note that, despite the name, ``sig_figs`` is the number of digits kept
    after the decimal point (``"f"`` formatting), not a count of significant
    figures.
    """
    return format(x, f".{sig_figs}f")

# Create the DataFrame
overview_df = pd.DataFrame(columns=["Half-life (ns)", "a (kcal mol$^{-1}$)", "Fast Half-life (ns)", "Fast a (kcal mol$^{-1}$)", "Total Variance (kcal$^2$ mol$^{-2}$)", "Max Lag Index"])
for system in systems:
    overview_df.loc[system] = [synthetic_data_params[system]["half_life"], synthetic_data_params[system]["exp_params"][0], synthetic_data_params[system]["fast_half_life"], synthetic_data_params[system]["fast_exp_params"][0], synthetic_data_params[system]["total_variance"], synthetic_data_params[system]["max_lag_idx"]]

# Apply the formatting function to all elements in the DataFrame.
# DataFrame.applymap is deprecated in favour of DataFrame.map (pandas >= 2.1), so use
# map when it exists and fall back to applymap on older pandas.
if hasattr(overview_df, "map"):
    overview_df = overview_df.map(lambda x: format_significant_figures(x, 2))
else:
    overview_df = overview_df.applymap(lambda x: format_significant_figures(x, 2))

# Save csv and latex. to_latex returns None when it is given a path to write to, so its
# result is not kept.
overview_df.to_csv("output_free/overview.csv")
overview_df.to_latex("output_free/overview.tex", index=True, escape=False)

overview_df

In [None]:
# Persist the fitted parameters so the synthetic-data generation can be repeated later
with open('output_free/synthetic_data_params.pkl', 'wb') as params_file:
    pkl.dump(synthetic_data_params, params_file)

## Create synthetic datasets <a name="synth"></a>



Generate the following synthetic datasets:

- "Standard" - Standard synthetic data as described above, up to 8 ns as this is a typical length for an ABFE calculation window

Here, we want an example of both a fairly uncorrelated dataset and a fairly correlated dataset. For the former, we'll use benzene (the `T4L` system in the code below), and for the latter we'll use the PDE2A ligand.

In [None]:
def get_cholesky(autocov_fn: np.ndarray, n_data: int) -> np.ndarray:
    """
    Get the lower-triangular Cholesky factor of the autocovariance matrix.

    The autocovariance function is truncated, or zero-padded at long lags, to
    exactly ``n_data`` entries and used as the first column of a symmetric
    Toeplitz matrix, which is then factorised.
    """
    n_missing = n_data - len(autocov_fn)
    if n_missing < 0:
        # More lags than data points: keep only the first n_data lags
        first_column = autocov_fn[:n_data]
    else:
        # Too few lags: treat the missing long-lag autocovariances as zero
        first_column = np.pad(autocov_fn, (0, n_missing), 'constant')

    # Build the autocovariance matrix and factorise it
    return scipy.linalg.cholesky(toeplitz(first_column), lower=True)

def generate_data(cholesky_mat: np.ndarray, scale:float = 1) -> np.ndarray:
    """
    Generate correlated data from a Cholesky factor of an autocovariance matrix.

    Independent Gaussian samples (zero mean, standard deviation ``scale``) are
    drawn from NumPy's global random state and multiplied by ``cholesky_mat``.
    """
    n_samples = cholesky_mat.shape[0]
    white_noise = np.random.normal(loc=0.0, scale=scale, size=n_samples)
    return cholesky_mat @ white_noise


def generate_data_with_transient(cholesky_mat: np.ndarray,
                                 times: np.ndarray,
                                 slow_exponential_params: np.ndarray,
                                 fast_exponential_params: np.ndarray,
                                 scale: float = 1) -> np.ndarray:
    """
    Generate synthetic data for equilibration detection with an exponential transient.

    Parameters
    ----------
    cholesky_mat : np.ndarray
        Cholesky factor passed to ``generate_data`` to produce the stationary noise.
    times : np.ndarray
        Times at which the exponential transients are evaluated.
    slow_exponential_params : np.ndarray
        Parameters unpacked into ``exp_decay`` for the slow transient.
    fast_exponential_params : np.ndarray
        Parameters unpacked into ``exp_decay`` for the fast transient.
    scale : float, optional, default=1
        Scale forwarded to ``generate_data``.

    Returns
    -------
    np.ndarray
        The sum of the stationary data and the slow and fast transients.
    """
    # Get the exponential transients
    slow_transient_data = exp_decay(times, *slow_exponential_params)
    fast_transient_data = exp_decay(times, *fast_exponential_params)

    # Get the stationary data
    stationary_data = generate_data(cholesky_mat, scale=scale)

    # Add the three contributions together
    return stationary_data + slow_transient_data + fast_transient_data

In [None]:
# Build the "standard" synthetic set: N_REPEATS series per system, covering the first 8 ns

SET_NAME = "standard"
N_REPEATS = 1000
# Number of samples before the first time at or after 8 ns (from the first system's times)
IDX = get_time_idxs(timeseries_data[systems[0]]["times"], 8)

synthetic_data_free_vanish = {}
synthetic_data_free_vanish[SET_NAME] = {}

for system in tqdm.tqdm(["T4L", "PDE2A"]):
    synthetic_data_free_vanish[SET_NAME][system] = {}
    # The Cholesky factor depends only on the system, so compute it once per system
    autocov_fn = synthetic_data_params[system]["autocov_convex"]
    cholesky_mat = get_cholesky(autocov_fn, IDX)
    times = timeseries_data[system]["times"][:IDX]
    synthetic_data_free_vanish[SET_NAME]["times"] = times
    # Each repeat is an independent draw, stored under its integer index
    synthetic_data_free_vanish[SET_NAME][system].update(
        {
            i: {
                "data": generate_data_with_transient(
                    cholesky_mat,
                    times,
                    synthetic_data_params[system]["exp_params"],
                    synthetic_data_params[system]["fast_exp_params"],
                )
            }
            for i in range(N_REPEATS)
        }
    )

# with open("output_free/synthetic_data_free_vanish.pkl", "wb") as f:
#     pkl.dump(synthetic_data_free_vanish, f)


In [None]:
import red
from pymbar import timeseries
import numpy as np
import pickle as pkl

# with open("output_free/synthetic_data_free_vanish.pkl", "rb") as f:
# NOTE(review): this loads the *bound* vanish synthetic data into a variable named for the
# free vanish stage, replacing any data generated earlier in this notebook - confirm that
# this is intended.
# NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
with open("output/synthetic_data_bound_vanish.pkl", "rb") as f:
    synthetic_data_free_vanish = pkl.load(f)

In [None]:
from statsmodels.tsa.stattools import acovf
import matplotlib.pyplot as plt

def standard_cov(data):
    """
    Autocovariance at every lag from 0 to ``len(data) - 1`` via ``np.correlate``.

    The data are used as given (no mean is subtracted) and each lag is
    normalised by the total number of samples.
    """
    n = len(data)
    # The "full" correlation has 2n - 1 entries with lag 0 at index n - 1;
    # keep the non-negative lags only
    full_correlation = np.correlate(data, data, mode='full')
    return full_correlation[n - 1:] / n

def slow_standard_cov(data):
    """
    Autocovariance at every lag from 0 to ``len(data) - 1`` using an explicit loop over lags.

    The data are used as given (no mean is subtracted) and each lag is
    normalised by the total number of samples ("biased" estimate).
    """
    n = len(data)
    acov = np.zeros(n)

    # Lag 0 is the plain sum of squares
    acov[0] = data.dot(data)
    # Each later lag pairs the series with a copy of itself shifted by that lag
    for lag in range(1, n):
        acov[lag] = data[lag:].dot(data[:n - lag])

    # "Biased" estimate: divide by n for every lag, rather than by the number of pairs
    return acov / n

def fft_cov(data):
    """
    Autocovariance at every lag from 0 to ``len(data) - 1`` using statsmodels' FFT-based ``acovf``.

    The options mirror the other implementations in this cell: no demeaning
    and no adjustment of the normalisation.
    """
    longest_lag = len(data) - 1
    return acovf(data, adjusted=False, demean=False, fft=True, nlag=longest_lag)

# Time all three of the above functions against data of increasing length

import time


def _time_call(func, data):
    """Return the elapsed time in seconds of a single call ``func(data)``."""
    # perf_counter is a monotonic, high-resolution clock intended for measuring short
    # durations; time.time() is wall-clock time and can be coarse or jump.
    start = time.perf_counter()
    func(data)
    return time.perf_counter() - start


full_data = synthetic_data_free_vanish["standard"]["PDE2A"][0]["data"]
total_samples = len(full_data)

n_samples = np.arange(200, 800, 1)
times_standard = []
times_slow_standard = []
times_fft = []

for n in n_samples:
    # Time each implementation once on the first n samples
    data = full_data[:n]
    times_standard.append(_time_call(standard_cov, data))
    times_fft.append(_time_call(fft_cov, data))
    times_slow_standard.append(_time_call(slow_standard_cov, data))

# Plot the times
fig, ax = plt.subplots()
ax.plot(n_samples, times_standard, label="Standard Covariance")
ax.plot(n_samples, times_fft, label="FFT Covariance")
ax.plot(n_samples, times_slow_standard, label="Slow Standard Covariance")
ax.set_xlabel("Number of Samples")
ax.set_ylabel("Time / s")
ax.legend()


In [None]:
# Repeat the timing comparison over the full length of the data (in steps of 100 samples),
# leaving out the slow loop-based implementation
full_data = synthetic_data_free_vanish["standard"]["PDE2A"][0]["data"]
total_samples = len(full_data)

n_samples = np.arange(100, total_samples, 100)
times_standard = []
times_fft = []

for n in n_samples:
    data = full_data[:n]
    # perf_counter is a monotonic, high-resolution clock intended for measuring short
    # durations; time.time() is wall-clock time and can be coarse or jump.
    start = time.perf_counter()
    standard_cov(data)
    end = time.perf_counter()
    times_standard.append(end - start)

    start = time.perf_counter()
    fft_cov(data)
    end = time.perf_counter()
    times_fft.append(end - start)

# Plot the times
fig, ax = plt.subplots()
ax.plot(n_samples, times_standard, label="Standard Covariance")
ax.plot(n_samples, times_fft, label="FFT Covariance")
ax.set_xlabel("Number of Samples")
ax.set_ylabel("Time / s")
ax.legend()


In [None]:
time t, g, neff = red.detect_equilibration_init_seq(synthetic_data_free_vanish["standard"]["PDE2A"][2]["data"][:], method="max_ess", sequence_estimator="initial_convex", plot=True)

In [None]:
# Single-run timing: pymbar equilibration detection with fast=False
%time t,g, neff = timeseries.detectEquilibration(synthetic_data_free_vanish["standard"]["T4L"][0]["data"][:], fast=False)

In [None]:
# Single-run timing: red initial-sequence method with the "positive" sequence estimator
%time t, g, neff = red.detect_equilibration_init_seq(synthetic_data_free_vanish["standard"]["T4L"][0]["data"][:], method="max_ess", sequence_estimator="positive")

In [None]:
# Single-run timing: red window method with its default arguments
%time t, g, neff = red.detect_equilibration_window(synthetic_data_free_vanish["standard"]["T4L"][0]["data"][:])

In [None]:
# Repeated timing: pymbar equilibration detection with its default arguments
%timeit t,g, neff = timeseries.detectEquilibration(synthetic_data_free_vanish["standard"]["T4L"][0]["data"])

In [None]:
# Repeated timing: pymbar equilibration detection with fast=False
%timeit t,g, neff = timeseries.detectEquilibration(synthetic_data_free_vanish[SET_NAME]["T4L"][0]["data"], fast=False)

In [None]:
# Repeated timing: red initial-sequence method
%timeit t, g, neff = red.detect_equilibration_init_seq(synthetic_data_free_vanish[SET_NAME]["T4L"][0]["data"], method="max_ess")

In [None]:
# Number of samples in the series being timed
len(synthetic_data_free_vanish[SET_NAME]["T4L"][0]["data"])

In [None]:
# Profile pymbar equilibration detection on the first 10000 samples
%prun t,g, neff = timeseries.detectEquilibration(synthetic_data_free_vanish["standard"]["T4L"][0]["data"][:10000], fast=False)

In [None]:
# Profile the red initial-sequence method on the full series
%prun t, g, neff = red.detect_equilibration_init_seq(synthetic_data_free_vanish["standard"]["T4L"][0]["data"][:])

In [None]:
import numpy as _np
from typing import Union as _Union
from copy import deepcopy as _deepcopy


def _get_autocovariance(
    data: _np.ndarray,
    max_lag: _Union[None, int] = None,
    mean: _Union[None, float] = None,
) -> _np.ndarray:
    """
    Calculate the auto-covariance as a function of lag time for a time series.

    Parameters
    ----------
    data : numpy.ndarray
        A time series of data with shape (n_samples,).

    max_lag : int, optional, default=None
        The maximum lag time to use when calculating the auto-covariance function.
        If None, the maximum lag time will be the length of the time series minus one.
        The default is None.

    mean: float, optional, default=None
        The mean of the time series. If None, the mean will be calculated from the
        time series. This is useful when the mean has been calculated from an
        ensemble of time series.

    Returns
    -------
    numpy.ndarray
        The auto-covariance function of the time series, for lags 0 to max_lag.
        This is the "biased" estimate: every lag is divided by n_samples.
    """
    # Work on a floating-point copy so that the caller's array is never modified and
    # integer input can have a non-integer mean subtracted from it.
    data = _np.array(data, dtype=float)

    # Get the length of the time series.
    n_samples: int = data.shape[0]

    # If max_lag is None, use every available lag.
    if max_lag is None:
        max_lag = n_samples - 1

    # If mean is None, calculate it from the time series.
    if mean is None:
        mean = data.mean()

    # Subtract the mean from the data.
    data -= mean  # type: ignore

    # The "full" correlation has lag 0 at index n_samples - 1; keep lags 0..max_lag.
    full_correlation = _np.correlate(data, data, mode="full")
    return full_correlation[n_samples - 1 : n_samples + max_lag] / n_samples

# Evaluate the function defined above on the first 5000 samples
_get_autocovariance(synthetic_data_free_vanish["standard"]["T4L"][0]["data"][:5000])

In [None]:
from statsmodels.tsa.stattools import acovf, acf

In [None]:
# statsmodels autocovariance of the full series, FFT-based
acovf(synthetic_data_free_vanish["standard"]["T4L"][0]["data"][:], fft=True)

In [None]:
# statsmodels autocovariance of the full series, without the FFT
acovf(synthetic_data_free_vanish["standard"]["T4L"][0]["data"][:], fft=False)

In [None]:
# Repeated timing: statsmodels without the FFT
%timeit acovf(synthetic_data_free_vanish["standard"]["T4L"][0]["data"][:], fft=False)

In [None]:
# Repeated timing: the function defined above
%timeit _get_autocovariance(synthetic_data_free_vanish["standard"]["T4L"][0]["data"][:]) 

In [None]:
# statsmodels autocovariance of the first 5000 samples, without the FFT
acovf(synthetic_data_free_vanish["standard"]["T4L"][0]["data"][:5000], fft=False)

In [None]:
# Profile the function defined above on the first 5000 samples
%prun _get_autocovariance(synthetic_data_free_vanish["standard"]["T4L"][0]["data"][:5000])

In [None]:
def statistical_inefficiency(A_n, B_n=None, fast=False, mintime=3, fft=False):
    """Estimate the statistical inefficiency g of one timeseries, or of a pair.

    Parameters
    ----------
    A_n : array_like of float
        Samples of timeseries A.
    B_n : array_like of float, optional, default=None
        Samples of timeseries B, which must have the same shape as A_n. When given,
        the cross-correlation of A and B is used; otherwise the autocorrelation of
        A is used.
    fast : bool, optional, default=False
        If True, the spacing between evaluated lags grows by one after every lag
        (the faster but less accurate scheme of Ref. [1]). Not used when the FFT
        route is taken.
    mintime : int, optional, default=3
        The accumulation may only stop on a non-positive correlation value once the
        lag exceeds mintime. Increase it if the correlation function has a strong
        initial negative peak.
    fft : bool, optional, default=False
        If True and B_n is None, the calculation is delegated to
        statistical_inefficiency_fft().

    Returns
    -------
    g : float
        Estimated statistical inefficiency (1 + 2 tau, with tau the correlation
        time). Values below 1.0 are replaced by 1.0.

    Raises
    ------
    ParameterError
        If A_n and B_n differ in shape, or if the sample covariance of the two
        series is exactly zero.

    References
    ----------
    [1] J. D. Chodera, W. C. Swope, J. W. Pitera, C. Seok, and K. A. Dill. Use of the weighted
    histogram analysis method for the analysis of simulated and parallel tempering simulations.
    JCTC 3(1):26-41, 2007.

    Examples
    --------

    Compute statistical inefficiency of timeseries data with known correlation time.

    >>> from pymbar.testsystems import correlated_timeseries_example
    >>> A_n = correlated_timeseries_example(N=100000, tau=5.0)
    >>> g = statistical_inefficiency(A_n, fast=True)

    """
    A_n = np.array(A_n)

    # The FFT route only covers the autocorrelation case.
    if fft and B_n is None:
        return statistical_inefficiency_fft(A_n, mintime=mintime)

    # Without a second series, correlate A with a copy of itself.
    B_n = np.array(A_n if B_n is None else B_n)

    if A_n.shape != B_n.shape:
        raise ParameterError("A_n and B_n must have same dimensions.")

    n_samples = A_n.size

    # Fluctuations of each series about its own sample mean.
    dA_n = A_n.astype(np.float64) - A_n.mean()
    dB_n = B_n.astype(np.float64) - B_n.mean()

    # Covariance estimator chosen so that the normalised correlation equals 1 at lag 0.
    sigma2_AB = (dA_n * dB_n).mean()
    if sigma2_AB == 0:
        raise ParameterError(
            "Sample covariance sigma_AB^2 = 0 -- cannot compute statistical inefficiency"
        )

    # Start from the uncorrelated value and add the contribution of each lag in turn.
    g = 1.0
    lag = 1
    step = 1
    while lag < n_samples - 1:
        n_pairs = n_samples - lag
        early_a, late_a = dA_n[:n_pairs], dA_n[lag:]
        early_b, late_b = dB_n[:n_pairs], dB_n[lag:]

        # Symmetrised, normalised fluctuation correlation at this lag.
        C = np.sum(early_a * late_b + early_b * late_a) / (
            2.0 * float(n_pairs) * sigma2_AB
        )

        # Stop at the first non-positive value found beyond mintime.
        if C <= 0.0 and lag > mintime:
            break

        g += 2.0 * C * (1.0 - float(lag) / float(n_samples)) * float(step)

        lag += step
        if fast:
            # Fast mode: widen the gap between successive lags.
            step += 1

    # Enforce the lower bound g >= 1.
    return 1.0 if g < 1.0 else g

In [None]:
# Time statistical_inefficiency on the first 5000 samples of entry 0 for PDE2A.
%timeit statistical_inefficiency(synthetic_data_free_vanish["standard"]["PDE2A"][0]["data"][:5000])

In [None]:
# Time _get_autocovariance on the same 5000-sample slice for comparison.
%timeit _get_autocovariance(synthetic_data_free_vanish["standard"]["PDE2A"][0]["data"][:5000])

In [None]:
# Time red's equilibration-window detection (method="max_ess") on the unsliced T4L series.
# NOTE(review): SET_NAME is assigned in a later cell — confirm it is defined before this cell runs.
%timeit t, g, neff = red.detect_equilibration_window(synthetic_data_free_vanish[SET_NAME]["T4L"][0]["data"], method="max_ess")

In [None]:
# Distribution of the fraction of data discarded by Chodera's method
# (pymbar's detectEquilibration, called with only the data argument) over the
# first N_IT synthetic datasets of one system.
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns

N_IT = 50
SET_NAME = "standard"
SYSTEM = "PDE2A"

fracs_discarded = np.zeros(N_IT)
for i in tqdm(range(N_IT)):
    # Select the i-th dataset so that each iteration analyses a different series.
    data = synthetic_data_free_vanish[SET_NAME][SYSTEM][i]["data"]
    t, g, Neff = timeseries.detectEquilibration(data)
    frac_discarded = t/len(data)
    fracs_discarded[i] = frac_discarded
    print(f"Fraction discarded: {frac_discarded:.2f}, t={t}, g={g}, Neff={Neff}")

# Plot distribution of fractions discarded
fig, ax = plt.subplots()
sns.histplot(fracs_discarded, ax=ax, kde=True)
ax.set_xlabel("Fraction Discarded")
ax.set_ylabel("Frequency")


In [None]:
# Repeat the Chodera-method scan over the first N_IT datasets, this time
# calling detectEquilibration with fast=False.
import seaborn as sns

fracs_discarded = np.zeros(N_IT)
for i in tqdm(range(N_IT)):
    data = synthetic_data_free_vanish[SET_NAME][SYSTEM][i]["data"]
    t, g, Neff = timeseries.detectEquilibration(data, fast=False)
    # Store the fraction of the series lying before the detected index t.
    fracs_discarded[i] = frac_discarded = t / len(data)
    print(f"Fraction discarded: {frac_discarded:.2f}, t={t}, g={g}, Neff={Neff}")

# Histogram (with KDE overlay) of the discarded fractions.
fig, ax = plt.subplots()
sns.histplot(data=fracs_discarded, kde=True, ax=ax)
ax.set_xlabel("Fraction Discarded")
ax.set_ylabel("Frequency")


In [None]:

# Fraction discarded by red's initial-sequence detection (method="max_ess",
# sequence_estimator="positive") for each of the first N_IT datasets.
import seaborn as sns

fracs_discarded = np.zeros(N_IT)
for i in range(N_IT):
    data = synthetic_data_free_vanish[SET_NAME][SYSTEM][i]["data"]
    t, g, neff = red.detect_equilibration_init_seq(data, method="max_ess", sequence_estimator="positive")
    frac_discarded = t/len(data)
    fracs_discarded[i] = frac_discarded
    # Report the neff returned by this call.
    print(f"Fraction discarded: {frac_discarded:.2f}, t={t}, g={g}, Neff={neff}")

# Plot distribution of fractions discarded
fig, ax = plt.subplots()
sns.histplot(fracs_discarded, ax=ax, kde=True)
ax.set_xlabel("Fraction Discarded")
ax.set_ylabel("Frequency")


In [None]:
# Fraction discarded by red's initial-sequence detection (method="min_sse",
# sequence_estimator="positive") for each of the first N_IT datasets.
import seaborn as sns

fracs_discarded = np.zeros(N_IT)
for i in range(N_IT):
    data = synthetic_data_free_vanish[SET_NAME][SYSTEM][i]["data"]
    t, g, neff = red.detect_equilibration_init_seq(data, method="min_sse", sequence_estimator="positive")
    frac_discarded = t/len(data)
    fracs_discarded[i] = frac_discarded
    # Report the neff returned by this call.
    print(f"Fraction discarded: {frac_discarded:.2f}, t={t}, g={g}, Neff={neff}")

# Plot distribution of fractions discarded
fig, ax = plt.subplots()
sns.histplot(fracs_discarded, ax=ax, kde=True)
ax.set_xlabel("Fraction Discarded")
ax.set_ylabel("Frequency")


In [None]:
# Chodera-method scan (pymbar's detectEquilibration, called with only the data
# argument) over N_IT datasets.
# NOTE(review): this cell reads synthetic_data_bound_vanish and the lower-case
# `system`, whereas the neighbouring cells use synthetic_data_free_vanish and
# SYSTEM — confirm both names are defined and that this is intended.
import seaborn as sns

N_IT = 50

fracs_discarded = np.zeros(N_IT)
for i in range(N_IT):
    data = synthetic_data_bound_vanish[SET_NAME][system][i]["data"]
    t, g, Neff = timeseries.detectEquilibration(data)
    # Store the fraction of the series lying before the detected index t.
    fracs_discarded[i] = frac_discarded = t / len(data)
    print(f"Fraction discarded: {frac_discarded:.2f}, t={t}, g={g}, Neff={Neff}")

# Histogram (with KDE overlay) of the discarded fractions.
fig, ax = plt.subplots()
sns.histplot(data=fracs_discarded, kde=True, ax=ax)
ax.set_xlabel("Fraction Discarded")
ax.set_ylabel("Frequency")


In [None]:
# Run pymbar (fast=False) and red (max_ess, "positive" sequence estimator) on
# the `data` left over from the previous cell and compare their outputs.
pymbar_result = timeseries.detectEquilibration(data, fast=False)
red_result = red.detect_equilibration_init_seq(
    data,
    method="max_ess",
    sequence_estimator="positive",
)
t1, g1, neff1 = pymbar_result
t2, g2, neff2 = red_result

# The detected index and the effective sample size must match exactly.
# The g1 == g2 check is left disabled; both values are displayed in the following cells.
assert t1 == t2
assert neff1 == neff2

In [None]:
# Display g1, the second value returned by timeseries.detectEquilibration above.
g1

In [None]:
# Display g2, the second value returned by red.detect_equilibration_init_seq above.
g2