In [None]:
from typing import Dict, Optional, Any
import os
import pickle
import collections
from pathlib import Path

import logging
logging.getLogger('fontTools').setLevel(logging.ERROR)  # Only show errors, not warnings or info

import mdtraj as md
import numpy as np
import scipy.stats
import pyemma
import pandas as pd
# NOTE(review): presumably imported for its side effect of registering matplotlib styles
# such as "ipynb" used below — confirm.
import lovelyplots
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.colors
# Never show axis values as an offset, and only switch to scientific notation far outside (-10000, 10000).
mpl.rcParams['axes.formatter.useoffset'] = False
mpl.rcParams['axes.formatter.limits'] = (-10000, 10000)  # Controls range before scientific notation is used
plt.style.use("ipynb")

# Uncomment to switch to the non-interactive Agg backend, e.g. for headless runs where
# figures should only be written to disk (note: the function is `mpl.use`, not `plt.use`).
# mpl.use('agg')

# Project-local helpers: plotting of pyemma features and peptide-name utilities.
import pyemma_helper
from jamun import utils

### Paths

Select the experiment and look up its sampled trajectory and reference trajectory; the corresponding analysis results are loaded below.

In [None]:
# Root directory that holds all analysis results. Set the JAMUN_RESULTS_DIR environment
# variable to point the notebook elsewhere without editing this cell.
results_dir = os.environ.get("JAMUN_RESULTS_DIR", "/data/bucket/kleinhej/jamun-analysis/")
print(f"Results directory: {results_dir}")

In [None]:
# Select exactly one experiment to analyse by uncommenting its line.
# experiment = "Our_2AA"
# experiment = "Our_5AA"
experiment = "Timewarp_2AA"
# experiment = "Timewarp_4AA"
# experiment = "MDGen_4AA"
# experiment = "MDGen_4AA_new"

# sample_runs.csv maps an experiment to its sampled trajectory and reference trajectory;
# the first row matching the selected experiment is used.
runs_df = pd.read_csv("sample_runs.csv")
is_selected_run = runs_df["experiment"] == experiment
if not is_selected_run.any():
    raise ValueError(f"Experiment {experiment} not found in runs_df")

traj_name = runs_df.loc[is_selected_run, "trajectory"].iloc[0]
ref_traj_name = runs_df.loc[is_selected_run, "reference"].iloc[0]

for label, value in (("Experiment", experiment), ("Trajectory", traj_name), ("Reference", ref_traj_name)):
    print(f"{label}: {value}")

In [None]:
# Figures for this experiment / trajectory / reference combination are written here.
plot_dir_parts = ["/data/bucket/kleinhej/jamun-plots", experiment, traj_name, f"ref={ref_traj_name}"]
output_dir = os.path.join(*plot_dir_parts)
os.makedirs(output_dir, exist_ok=True)

print(f"Plots will be saved to {output_dir}")

### Load Analysis Results for All Peptides

In [None]:
# Per-peptide result pickles live under <results_dir>/<experiment>/<traj_name>/ref=<ref_traj_name>/.
results_path_parts = (results_dir, experiment, traj_name, f"ref={ref_traj_name}")
results_path = os.path.join(*results_path_parts)
print(f"Searching for results in {results_path}")

In [None]:
def load_results_path(results_path: str) -> pd.DataFrame:
    """Load every pickled per-peptide result in ``results_path`` into a DataFrame.

    The directory path must end in ``<traj_name>/ref=<ref_traj_name>``; both names are
    parsed from the path and stored on every row. Each ``<peptide>.pkl`` file must
    unpickle to a dict with ``"results"`` and ``"args"`` entries; files with any other
    extension are ignored.

    Args:
        results_path: Directory containing the ``.pkl`` result files.

    Returns:
        DataFrame with columns ``traj``, ``ref_traj``, ``peptide``, ``results`` and
        ``args``: one row per ``.pkl`` file, in sorted filename order. The columns are
        present even when no result files are found.

    Raises:
        ValueError: If the last path component does not start with ``"ref="``.
    """
    columns = ["traj", "ref_traj", "peptide", "results", "args"]

    # The directory layout encodes the names: .../<traj_name>/ref=<ref_traj_name>
    parts = Path(results_path).parts
    traj_name = parts[-2]
    ref_dir_name = parts[-1]
    ref_prefix = "ref="
    if not ref_dir_name.startswith(ref_prefix):
        raise ValueError(f"Expected reference trajectory name to start with 'ref=', got {ref_dir_name}")
    ref_traj_name = ref_dir_name[len(ref_prefix):]

    records = []
    for results_file in sorted(os.listdir(results_path)):
        peptide, ext = os.path.splitext(results_file)
        if ext != ".pkl":
            continue

        # NOTE: pickle.load can execute arbitrary code; only use this on trusted result directories.
        with open(os.path.join(results_path, results_file), "rb") as f:
            all_results = pickle.load(f)

        records.append({
            "traj": traj_name,
            "ref_traj": ref_traj_name,
            "peptide": peptide,
            "results": all_results["results"],
            "args": all_results["args"],
        })

    # Passing `columns` keeps the frame well-formed (e.g. a "peptide" column exists) when no files match.
    return pd.DataFrame(records, columns=columns)


results_df = load_results_path(results_path)
# Fail here with a clear message instead of with a confusing KeyError/IndexError in a later cell.
if results_df.empty:
    raise ValueError(f"No .pkl result files found in {results_path}")
results_df

In [None]:
# Also, load TBG results for the same experiment.
def add_recursively(original_results, tbg_results, add_key):
    """Graft TBG results into ``original_results`` in place, mirroring its nesting.

    The two dicts are walked in parallel. At every dict level of ``original_results``
    that contains a ``"traj"`` entry, ``original_results[add_key]`` is set to
    ``tbg_results["traj"]`` and the walk does not descend further there. Non-dict
    values are left untouched. Branches (or ``"traj"`` entries) that ``tbg_results``
    does not have are skipped, so the corresponding ``add_key`` entry is simply absent.

    Args:
        original_results: Nested results dict to extend (mutated in place).
        tbg_results: Nested results dict laid out like ``original_results``.
        add_key: Key under which the copied ``"traj"`` entries are stored, e.g. ``"TBG"``.

    Raises:
        ValueError: If ``add_key`` already exists at a visited level, e.g. when the
            merge is run twice on the same results.
    """
    if not isinstance(original_results, dict) or not isinstance(tbg_results, dict):
        return

    if add_key in original_results:
        raise ValueError(f"Key '{add_key}' already exists in original_results")

    if "traj" in original_results:
        if "traj" in tbg_results:
            original_results[add_key] = tbg_results["traj"]
        return

    for key, value in original_results.items():
        # Missing TBG branches are tolerated; downstream lookups of add_key handle KeyError.
        if key in tbg_results:
            add_recursively(value, tbg_results[key], add_key)


        
# TBG results exist only for Timewarp_2AA; graft them into each peptide's results under "TBG".
if experiment == "Timewarp_2AA":
    tbg_results_path = os.path.join(
        results_dir, "Timewarp_2AA", "TBG", f"ref={ref_traj_name}"
    )

    tbg_results_df = load_results_path(tbg_results_path)
    # Peptide names are .pkl file stems, so they are unique within one results directory.
    tbg_results_by_peptide = dict(zip(tbg_results_df["peptide"], tbg_results_df["results"]))

    missing_peptides = sorted(set(results_df["peptide"]) - set(tbg_results_by_peptide))
    if missing_peptides:
        raise ValueError(f"No TBG results found in {tbg_results_path} for peptides: {missing_peptides}")

    # add_recursively mutates the nested results dicts stored in results_df in place.
    for peptide, original_results in zip(results_df["peptide"], results_df["results"]):
        add_recursively(original_results, tbg_results_by_peptide[peptide], add_key="TBG")

In [None]:
# Filter based on peptide names.
if "5AA" in experiment:
    peptides = ["KTYDI", "NRLCQ", "VWSPF"]
    peptides = ["uncapped_" + peptide for peptide in peptides]
    sampled_results_df = results_df[results_df["peptide"].isin(peptides)]

else:
    # Sample 4 random peptides
    sampled_results_df = results_df.sample(n=min(len(results_df), 4), random_state=42)


sampled_results_df = sampled_results_df.reset_index(drop=True)
sampled_results_df

In [None]:
def format_traj_name(results_traj_name: str) -> str:
    """Return the plot label for a trajectory key used in the results dicts.

    ``"traj"`` maps to the sampled trajectory's name (the notebook-level ``traj_name``);
    the reference keys and ``"TBG"`` map to fixed labels.

    Raises:
        KeyError: If ``results_traj_name`` is not a known trajectory key.
    """
    # Resolved separately so the fixed labels do not depend on the notebook-level `traj_name`.
    if results_traj_name == "traj":
        return traj_name
    return {
        "ref_traj": "Reference",
        "ref_traj_10x": "Reference\n(10x shorter)",
        "ref_traj_100x": "Reference\n(100x shorter)",
        "TBG": "TBG",
    }[results_traj_name]

def format_quantity(quantity: str) -> str:
    """Return the human-readable plot label for a JSD quantity key.

    Raises:
        KeyError: If ``quantity`` is not one of the known JSD keys.
    """
    label_pairs = (
        ("JSD_backbone_torsions", "Backbone Torsions"),
        ("JSD_sidechain_torsions", "Sidechain Torsions"),
        ("JSD_all_torsions", "All Torsions"),
        ("JSD_TICA-0", "TICA-0 Projections"),
        ("JSD_TICA-0,1", "TICA-0,1 Projections"),
        ("JSD_metastable_probs", "Metastable State Probabilities"),
    )
    label_by_quantity = dict(label_pairs)
    return label_by_quantity[quantity]

def format_peptide_name(peptide: str) -> str:
    """Return a display name for a peptide identifier.

    Strips a leading ``"uncapped_"`` and then a leading ``"capped_"`` prefix. Names that
    still contain underscores are shown with dashes; otherwise the residue codes are
    converted with ``utils.convert_to_one_letter_codes``.
    """
    name = peptide
    for prefix in ("uncapped_", "capped_"):
        if name.startswith(prefix):
            name = name[len(prefix):]
    return name.replace("_", "-") if "_" in name else utils.convert_to_one_letter_codes(name)

### Ramachandran Plots

In [None]:
def plot_ramachandran_contour(results: Dict[str, Any], dihedral_index: int, ax: Optional[plt.Axes] = None) -> plt.Axes:
    """Draw a filled Ramachandran (phi/psi) PMF contour plot.

    Args:
        results: Dict with ``"pmf"`` (indexable by dihedral pair), and ``"xedges"`` /
            ``"yedges"`` (histogram bin edges along phi and psi).
        dihedral_index: Which entry of ``results["pmf"]`` to plot.
        ax: Axes to draw on; a new 10x10 figure is created when ``None``.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 10))

    pmf, xedges, yedges = results["pmf"], results["xedges"], results["yedges"]
    # Place each bin's PMF value at the bin centre rather than at its left edge.
    # NOTE(review): assumes pmf[dihedral_index] is laid out as contourf expects
    # (shape (len(ycenters), len(xcenters))) — confirm against the PMF computation.
    xedges = np.asarray(xedges)
    yedges = np.asarray(yedges)
    xcenters = 0.5 * (xedges[:-1] + xedges[1:])
    ycenters = 0.5 * (yedges[:-1] + yedges[1:])
    ax.contourf(xcenters, ycenters, pmf[dihedral_index], cmap="viridis", levels=50)
    ax.contour(xcenters, ycenters, pmf[dihedral_index], colors="white", linestyles="solid", levels=10, linewidths=0.25)

    ax.set_aspect("equal", adjustable="box")
    # Raw strings: "\p" is not a valid Python escape and triggers a SyntaxWarning.
    ax.set_xlabel(r"$\phi$")
    ax.set_ylabel(r"$\psi$")

    # The outermost ticks are nudged inwards so their labels are not clipped at the plot edge.
    tick_eps = 0.1
    ticks = [-np.pi + tick_eps, -np.pi / 2, 0, np.pi / 2, np.pi - tick_eps]
    tick_labels = [r"$-\pi$", r"$-\pi/2$", r"$0$", r"$\pi/2$", r"$\pi$"]
    ax.set_xticks(ticks, tick_labels)
    ax.set_yticks(ticks, tick_labels)
    return ax


def get_num_dihedrals(experiment: str, pmf_type: str) -> int:
    """Return the number of (psi, phi) Ramachandran pairs to plot for an experiment.

    Args:
        experiment: Experiment name. The count is derived from its "2AA" / "4AA" / "5AA"
            substring, with "Our_2AA" special-cased.
        pmf_type: ``"internal"`` for psi_2 - phi_2, psi_3 - phi_3, etc.;
            ``"all"`` for psi_1 - phi_2, psi_2 - phi_3, etc. (one more pair).

    Returns:
        Number of dihedral pairs.

    Raises:
        ValueError: If ``pmf_type`` is unknown or the experiment name matches none of
            the known peptide lengths.
    """
    if pmf_type not in ["internal", "all"]:
        raise ValueError(f"Invalid pmf_type: {pmf_type}")

    if experiment == "Our_2AA":
        num_dihedrals = 1
    elif "2AA" in experiment:
        num_dihedrals = 0
    elif "4AA" in experiment:
        num_dihedrals = 2
    elif "5AA" in experiment:
        num_dihedrals = 3
    else:
        # Previously fell through and raised UnboundLocalError.
        raise ValueError(f"Cannot determine number of dihedrals for experiment: {experiment}")

    if pmf_type == "all":
        num_dihedrals += 1

    return num_dihedrals

#### Ramachandran Plots against Reference

In [None]:
# Ramachandran contours per sampled peptide: reference (left group) vs. sampled trajectory (right group).
pmf_type = "all"
num_dihedrals = get_num_dihedrals(experiment, pmf_type)
# Each group title is anchored on panel num_dihedrals // 2 of its group: at that panel's centre
# (x=0.5) for an odd panel count, at its left edge (x=0.0) for an even one, i.e. over the group's middle.
label_offset = 0.0 if num_dihedrals % 2 == 0 else 0.5

# (key into results["PMFs"], column-group title), in left-to-right order.
panel_groups = [
    ("ref_traj", format_traj_name("ref_traj")),
    ("traj", format_traj_name("traj")),
]
num_groups = len(panel_groups)
group_stride = num_dihedrals + 1  # one column per dihedral pair plus a narrow spacer column
num_rows = len(sampled_results_df)

width_ratios = []
for group_index in range(num_groups):
    if group_index > 0:
        width_ratios.append(0.1)  # spacer between groups
    width_ratios.extend([1.0] * num_dihedrals)

fig, axs = plt.subplots(
    num_rows,
    num_groups * group_stride - 1,
    figsize=(4 * num_groups * num_dihedrals, 4 * num_rows),
    gridspec_kw={"width_ratios": width_ratios, "hspace": 0.1},
    squeeze=False,  # keep axs 2-D even when a single peptide is sampled
)

# Column-group titles go on the top row once, rather than being redrawn for every peptide row.
for group_index, (_, group_title) in enumerate(panel_groups):
    title_ax = axs[0, group_index * group_stride + num_dihedrals // 2]
    title_ax.text(
        label_offset,
        1.1,
        group_title,
        horizontalalignment="center",
        verticalalignment="center",
        transform=title_ax.transAxes,
        fontsize=22,
    )

for i, row in sampled_results_df.iterrows():
    peptide = row["peptide"]
    pmfs = row["results"]["PMFs"]

    for group_index, (pmf_key, _) in enumerate(panel_groups):
        group_start = group_index * group_stride
        for j in range(num_dihedrals):
            plot_ramachandran_contour(pmfs[pmf_key][f"pmf_{pmf_type}"], j, axs[i, group_start + j])
        if group_index > 0:
            axs[i, group_start - 1].axis("off")  # hide the spacer column before this group

    # Peptide name, rotated, to the right of the last panel in the row.
    axs[i, -1].text(
        1.1,
        0.5,
        format_peptide_name(peptide),
        rotation=90,
        verticalalignment="center",
        horizontalalignment="center",
        transform=axs[i, -1].transAxes,
        fontsize=18,
    )

    # Keep x ticks only on the bottom row and y ticks only on the first column.
    if i != num_rows - 1:
        for ax in axs[i]:
            ax.set_xticks([])
            ax.set_xlabel("")
    for ax in axs[i, 1:]:
        ax.set_yticks([])
        ax.set_ylabel("")

plt.subplots_adjust(hspace=0.06, wspace=0.04)
plt.savefig(os.path.join(output_dir, "ramachandran_contours.pdf"), dpi=300)
plt.show()

#### Ramachandran Plots against Reference (Shortened)

In [None]:
# Ramachandran contours per sampled peptide: reference, sampled trajectory and a 100x shorter reference.
pmf_type = "all"
num_dihedrals = get_num_dihedrals(experiment, pmf_type)
# Each group title is anchored on panel num_dihedrals // 2 of its group: at that panel's centre
# (x=0.5) for an odd panel count, at its left edge (x=0.0) for an even one, i.e. over the group's middle.
label_offset = 0.0 if num_dihedrals % 2 == 0 else 0.5

# (key into results["PMFs"], column-group title), in left-to-right order.
panel_groups = [
    ("ref_traj", format_traj_name("ref_traj")),
    ("traj", format_traj_name("traj")),
    ("ref_traj_100x", format_traj_name("ref_traj_100x")),
]
num_groups = len(panel_groups)
group_stride = num_dihedrals + 1  # one column per dihedral pair plus a narrow spacer column
num_rows = len(sampled_results_df)

width_ratios = []
for group_index in range(num_groups):
    if group_index > 0:
        width_ratios.append(0.1)  # spacer between groups
    width_ratios.extend([1.0] * num_dihedrals)

fig, axs = plt.subplots(
    num_rows,
    num_groups * group_stride - 1,
    figsize=(4 * num_groups * num_dihedrals, 4 * num_rows),
    gridspec_kw={"width_ratios": width_ratios, "hspace": 0.1},
    squeeze=False,  # keep axs 2-D even when a single peptide is sampled
)

# Column-group titles go on the top row once, rather than being redrawn for every peptide row.
for group_index, (_, group_title) in enumerate(panel_groups):
    title_ax = axs[0, group_index * group_stride + num_dihedrals // 2]
    title_ax.text(
        label_offset,
        1.1,
        group_title,
        horizontalalignment="center",
        verticalalignment="center",
        transform=title_ax.transAxes,
        fontsize=22,
    )

for i, row in sampled_results_df.iterrows():
    peptide = row["peptide"]
    pmfs = row["results"]["PMFs"]

    for group_index, (pmf_key, _) in enumerate(panel_groups):
        group_start = group_index * group_stride
        for j in range(num_dihedrals):
            plot_ramachandran_contour(pmfs[pmf_key][f"pmf_{pmf_type}"], j, axs[i, group_start + j])
        if group_index > 0:
            axs[i, group_start - 1].axis("off")  # hide the spacer column before this group

    # Peptide name, rotated, to the right of the last panel in the row.
    axs[i, -1].text(
        1.1,
        0.5,
        format_peptide_name(peptide),
        rotation=90,
        verticalalignment="center",
        horizontalalignment="center",
        transform=axs[i, -1].transAxes,
        fontsize=18,
    )

    # Keep x ticks only on the bottom row and y ticks only on the first column.
    if i != num_rows - 1:
        for ax in axs[i]:
            ax.set_xticks([])
            ax.set_xlabel("")
    for ax in axs[i, 1:]:
        ax.set_yticks([])
        ax.set_ylabel("")

plt.subplots_adjust(hspace=0.06, wspace=0.04)
plt.savefig(os.path.join(output_dir, "ramachandran_contours_with_shortened_reference.pdf"), dpi=300)
plt.show()

In [None]:
# For experiment "Timewarp_2AA", plot the TBG results as well.
if experiment == "Timewarp_2AA":
    pmf_type = "all"
    num_dihedrals = get_num_dihedrals(experiment, pmf_type)
    # Each group title is anchored on panel num_dihedrals // 2 of its group: at that panel's centre
    # (x=0.5) for an odd panel count, at its left edge (x=0.0) for an even one, i.e. over the group's middle.
    label_offset = 0.0 if num_dihedrals % 2 == 0 else 0.5

    # (key into results["PMFs"], column-group title), in left-to-right order.
    panel_groups = [
        ("ref_traj", format_traj_name("ref_traj")),
        ("traj", format_traj_name("traj")),
        ("TBG", "TBG"),
        ("ref_traj_100x", format_traj_name("ref_traj_100x")),
    ]
    num_groups = len(panel_groups)
    group_stride = num_dihedrals + 1  # one column per dihedral pair plus a narrow spacer column
    num_rows = len(sampled_results_df)

    width_ratios = []
    for group_index in range(num_groups):
        if group_index > 0:
            width_ratios.append(0.1)  # spacer between groups
        width_ratios.extend([1.0] * num_dihedrals)

    fig, axs = plt.subplots(
        num_rows,
        num_groups * group_stride - 1,
        # 4 inches per panel, consistent with the 2- and 3-group figures above.
        figsize=(4 * num_groups * num_dihedrals, 4 * num_rows),
        gridspec_kw={"width_ratios": width_ratios, "hspace": 0.1},
        squeeze=False,  # keep axs 2-D even when a single peptide is sampled
    )

    # Column-group titles go on the top row once, rather than being redrawn for every peptide row.
    for group_index, (_, group_title) in enumerate(panel_groups):
        title_ax = axs[0, group_index * group_stride + num_dihedrals // 2]
        title_ax.text(
            label_offset,
            1.2,
            group_title,
            horizontalalignment="center",
            verticalalignment="center",
            transform=title_ax.transAxes,
            fontsize=22,
        )

    for i, row in sampled_results_df.iterrows():
        peptide = row["peptide"]
        pmfs = row["results"]["PMFs"]

        for group_index, (pmf_key, _) in enumerate(panel_groups):
            group_start = group_index * group_stride
            for j in range(num_dihedrals):
                plot_ramachandran_contour(pmfs[pmf_key][f"pmf_{pmf_type}"], j, axs[i, group_start + j])
            if group_index > 0:
                axs[i, group_start - 1].axis("off")  # hide the spacer column before this group

        # Peptide name, rotated, to the right of the last panel in the row.
        axs[i, -1].text(
            1.1,
            0.5,
            format_peptide_name(peptide),
            rotation=90,
            verticalalignment="center",
            horizontalalignment="center",
            transform=axs[i, -1].transAxes,
            fontsize=18,
        )

        # Keep x ticks only on the bottom row and y ticks only on the first column.
        if i != num_rows - 1:
            for ax in axs[i]:
                ax.set_xticks([])
                ax.set_xlabel("")
        for ax in axs[i, 1:]:
            ax.set_yticks([])
            ax.set_ylabel("")

    plt.subplots_adjust(hspace=0.06, wspace=0.04)
    plt.savefig(os.path.join(output_dir, "ramachandran_contours_with_TBG.pdf"), dpi=300)
    plt.show()

#### Ramachandran Plots for a Single Peptide

In [None]:
# Ramachandran contours for one peptide (the first sampled one): reference on top, sampled trajectory below.
pmf_type = "all"
num_dihedrals = get_num_dihedrals(experiment, pmf_type)

# Take the peptide name and its PMFs from the same row, instead of relying on a `row`
# variable left over from a previous cell's loop (generally a different peptide).
single_peptide_row = sampled_results_df.iloc[0]
peptide = single_peptide_row["peptide"]
pmfs = single_peptide_row["results"]["PMFs"]

fig, axs = plt.subplots(2, num_dihedrals, figsize=(4 * num_dihedrals, 8), squeeze=False)
for j in range(num_dihedrals):
    plot_ramachandran_contour(pmfs["ref_traj"][f"pmf_{pmf_type}"], j, axs[0, j])
    plot_ramachandran_contour(pmfs["traj"][f"pmf_{pmf_type}"], j, axs[1, j])

# Keep y ticks only on the first column and x ticks only on the bottom row.
for i in range(2):
    for j in range(1, len(axs[i])):
        axs[i, j].set_yticks([])
        axs[i, j].set_ylabel("")

for j in range(len(axs[0])):
    axs[0, j].set_xticks([])
    axs[0, j].set_xlabel("")

# Row labels, rotated, to the right of the last column.
for row_index, traj_key in enumerate(("ref_traj", "traj")):
    axs[row_index, -1].text(
        1.1,
        0.5,
        format_traj_name(traj_key),
        rotation=90,
        verticalalignment="center",
        horizontalalignment="center",
        transform=axs[row_index, -1].transAxes,
    )
fig.suptitle(format_peptide_name(peptide))
plt.subplots_adjust(hspace=0.06, wspace=0.04)
plt.savefig(os.path.join(output_dir, f"ramachandran_contours_{peptide}.pdf"), dpi=300)
plt.show()

### Feature Histograms

In [None]:
# Torsion-angle histograms per sampled peptide: reference trajectory (left) vs. sampled trajectory (right).
fig, axs = plt.subplots(nrows=len(sampled_results_df), ncols=2, figsize=(14, 4 * len(sampled_results_df)), squeeze=False)
column_traj_keys = ("ref_traj", "traj")

for i, row in sampled_results_df.iterrows():
    peptide = row["peptide"]

    feats = row["results"]["featurization"]
    histograms = row["results"]["feature_histograms"]

    for column, traj_key in enumerate(column_traj_keys):
        torsion_hists = histograms[traj_key]["torsions"]
        pyemma_helper.plot_feature_histograms(
            torsion_hists["histograms"],
            torsion_hists["edges"],
            feature_labels=feats[traj_key]["feats"]["torsions"].describe(),
            ax=axs[i, column],
        )

    # Peptide name, rotated, to the right of the row.
    label_ax = axs[i, -1]
    label_ax.text(
        1.1,
        0.5,
        format_peptide_name(peptide),
        rotation=90,
        verticalalignment="center",
        horizontalalignment="center",
        transform=label_ax.transAxes,
    )

for column, traj_key in enumerate(column_traj_keys):
    axs[0, column].set_title(format_traj_name(traj_key))
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "feature_histograms.pdf"), dpi=300)
plt.show()

In [None]:
# Histograms of a random subset (at most 10) of distance features per sampled peptide:
# reference trajectory (left) vs. sampled trajectory (right).
# A seeded generator makes the chosen subset, and therefore the saved figure, reproducible.
max_distance_features = 10
distance_rng = np.random.default_rng(42)

fig, axs = plt.subplots(nrows=len(sampled_results_df), ncols=2, figsize=(14, 4 * len(sampled_results_df)), squeeze=False)
for i, row in sampled_results_df.iterrows():
    peptide = row["peptide"]

    feats = row["results"]["featurization"]
    histograms = row["results"]["feature_histograms"]

    num_hists = len(histograms["ref_traj"]["distances"]["histograms"])
    # The same feature indices are used for both trajectories so the two columns are comparable.
    indices = distance_rng.choice(num_hists, replace=False, size=min(num_hists, max_distance_features))

    for column, traj_key in enumerate(("ref_traj", "traj")):
        distance_hists = histograms[traj_key]["distances"]
        # describe() is evaluated once per trajectory rather than once per selected feature.
        distance_labels = feats[traj_key]["feats"]["distances"].describe()
        pyemma_helper.plot_feature_histograms(
            distance_hists["histograms"][indices],
            distance_hists["edges"][indices],
            feature_labels=[distance_labels[k] for k in indices],
            ax=axs[i, column],
        )

    axs[i, 1].set_xlim(axs[i, 0].get_xlim())  # Ensure both axes have the same x-limits
    axs[i, 1].set_ylim(axs[i, 0].get_ylim())  # Ensure both axes have the same y-limits

    axs[i, -1].text(
        1.1,
        0.5,
        format_peptide_name(peptide),
        rotation=90,
        verticalalignment="center",
        horizontalalignment="center",
        transform=axs[i, -1].transAxes,
    )

axs[0, 0].set_title(format_traj_name("ref_traj"))
axs[0, 1].set_title(format_traj_name("traj"))
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "distance_histograms.pdf"), dpi=300)
plt.show()

### Torsion Angle Decorrelation Times

In [None]:
# Gather per-torsion decorrelation times over all peptides, split into backbone (PHI/PSI)
# and sidechain (CHI) torsions. total_count counts every torsion; pairs where either
# decorrelation time is NaN are counted but not collected.
torsion_types = ("backbone", "sidechain")
all_ref_decorrelation_times = {torsion: [] for torsion in torsion_types}
all_traj_decorrelation_times = {torsion: [] for torsion in torsion_types}
total_count = {torsion: 0 for torsion in torsion_types}

for peptide_results in results_df["results"]:
    decorrelations = peptide_results["torsion_decorrelations"]

    for feat, feat_decorrelation in decorrelations.items():
        ref_time = feat_decorrelation["ref_traj_decorrelation_time"]
        traj_time = feat_decorrelation["traj_decorrelation_time"]

        if 'PHI' in feat or 'PSI' in feat:
            torsion_type = "backbone"
        elif 'CHI' in feat:
            torsion_type = "sidechain"
        else:
            raise ValueError(f"Unknown torsion type: {feat}")

        total_count[torsion_type] += 1

        if not (np.isnan(ref_time) or np.isnan(traj_time)):
            all_ref_decorrelation_times[torsion_type].append(ref_time)
            all_traj_decorrelation_times[torsion_type].append(traj_time)

all_ref_decorrelation_times = {torsion: np.asarray(times) for torsion, times in all_ref_decorrelation_times.items()}
all_traj_decorrelation_times = {torsion: np.asarray(times) for torsion, times in all_traj_decorrelation_times.items()}

#### Backbone Torsion Angle Decorrelation

In [None]:
print(f"Number of backbone torsions with valid decorrelation times: {len(all_ref_decorrelation_times['backbone'])} out of {total_count['backbone']}")

backbone_ref_times = all_ref_decorrelation_times["backbone"]
backbone_traj_times = all_traj_decorrelation_times["backbone"]

# Scatter plot of decorrelation times: reference (x) vs. sampled trajectory (y), on log-log axes.
fig, ax = plt.subplots()
ax.scatter(backbone_ref_times, backbone_traj_times, alpha=0.3, edgecolors="none", color="tab:blue")
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel(format_traj_name("ref_traj"))
ax.set_ylabel(format_traj_name("traj"))
ax.set_title("Decorrelation Times of Backbone Torsions")

# Least-squares fit in log-log space; its R² is annotated on the plot.
# linregress requires at least two points, so the fit is skipped for smaller samples.
# Set show_fit_line to True to also draw the fitted power law over the 5th-95th
# percentile range of the reference decorrelation times.
show_fit_line = False
if len(backbone_ref_times) >= 2:
    slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(
        np.log(backbone_ref_times), np.log(backbone_traj_times)
    )
    if show_fit_line:
        x_line = np.percentile(backbone_ref_times, [5, 95])
        y_line = np.exp(slope * np.log(x_line) + intercept)
        ax.plot(x_line, y_line, color="tab:blue", linestyle="--")
    ax.text(0.65, 0.90, f"R² = {r_value**2:.3f}", transform=ax.transAxes, color="tab:blue")

fig.tight_layout()
fig.savefig(os.path.join(output_dir, "backbone_torsion_decorrelation_times.pdf"), dpi=300)
plt.show()

In [None]:
print(f"Number of backbone torsions with valid decorrelation times: {len(all_ref_decorrelation_times['backbone'])} out of {total_count['backbone']}")

# Speedup per torsion = reference decorrelation time / sampled-trajectory decorrelation time.
backbone_torsion_speedups = all_ref_decorrelation_times["backbone"] / all_traj_decorrelation_times["backbone"]

# 20 log-spaced bins (21 edges) spanning the observed speedups.
log10_lowest = np.log10(np.min(backbone_torsion_speedups))
log10_highest = np.log10(np.max(backbone_torsion_speedups))
bins = np.logspace(log10_lowest, log10_highest, 21)

plt.hist(backbone_torsion_speedups, bins=bins)
plt.xscale("log")
plt.xlabel("Speedup Factor")
# Fixed decade ticks from 1e-3 to 1e3.
decade_ticks = [1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3]
plt.xticks(decade_ticks)
plt.ylabel("Frequency")
plt.suptitle("Speedups of Backbone Torsion Decorrelation Times")
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "backbone_torsion_speedups.pdf"), dpi=300)
plt.show()

#### Sidechain Torsion Angle Decorrelation

In [None]:
print(f"Number of sidechain torsions with valid decorrelation times: {len(all_ref_decorrelation_times['sidechain'])} out of {total_count['sidechain']}")

sidechain_ref_times = all_ref_decorrelation_times["sidechain"]
sidechain_traj_times = all_traj_decorrelation_times["sidechain"]

# Scatter plot of decorrelation times: reference (x) vs. sampled trajectory (y), on log-log axes.
fig, ax = plt.subplots()
ax.scatter(sidechain_ref_times, sidechain_traj_times, alpha=0.3, edgecolors="none", color="tab:orange")
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel(format_traj_name("ref_traj"))
ax.set_ylabel(format_traj_name("traj"))
ax.set_title("Decorrelation Times of Sidechain Torsions")

# Least-squares fit in log-log space; its R² is annotated on the plot.
# linregress requires at least two points, so the fit is skipped for smaller samples.
# Set show_fit_line to True to also draw the fitted power law over the 5th-95th
# percentile range of the reference decorrelation times.
show_fit_line = False
if len(sidechain_ref_times) >= 2:
    slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(
        np.log(sidechain_ref_times), np.log(sidechain_traj_times)
    )
    if show_fit_line:
        x_line = np.percentile(sidechain_ref_times, [5, 95])
        y_line = np.exp(slope * np.log(x_line) + intercept)
        ax.plot(x_line, y_line, color="tab:orange", linestyle="--")
    ax.text(0.65, 0.90, f"R² = {r_value**2:.3f}", transform=ax.transAxes, color="tab:orange")

fig.tight_layout()
fig.savefig(os.path.join(output_dir, "sidechain_torsion_decorrelation_times.pdf"), dpi=300)
plt.show()

In [None]:
print(f"Number of sidechain torsions with valid decorrelation times: {len(all_ref_decorrelation_times['sidechain'])} out of {total_count['sidechain']}")

# Speedup per torsion = reference decorrelation time / sampled-trajectory decorrelation time.
sidechain_torsion_speedups = all_ref_decorrelation_times["sidechain"] / all_traj_decorrelation_times["sidechain"]

# 20 log-spaced bins (21 edges) spanning the observed speedups.
log10_lowest = np.log10(np.min(sidechain_torsion_speedups))
log10_highest = np.log10(np.max(sidechain_torsion_speedups))
bins = np.logspace(log10_lowest, log10_highest, 21)

plt.hist(sidechain_torsion_speedups, bins=bins)
plt.xscale("log")
plt.xlabel("Speedup Factor")
plt.ylabel("Frequency")
plt.suptitle("Speedups of Sidechain Torsion Decorrelation Times")
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "sidechain_torsion_speedups.pdf"), dpi=300)
plt.show()

### Jensen-Shannon Divergences (JSD)

In [None]:
def get_JSD_results(quantity: str, name: str, key: str, df: Optional[pd.DataFrame] = None) -> np.ndarray:
    """Collect the final JSD value of one quantity for one trajectory across peptides.

    For every row, reads ``row["results"][key][name][quantity]``; rows where any of
    these lookups is missing are skipped.

    Args:
        quantity: JSD quantity, e.g. ``"JSD_backbone_torsions"``.
        name: Trajectory key, e.g. ``"traj"`` or ``"ref_traj"``.
        key: Results section, e.g. ``"JSD_torsions"``.
        df: Table with a ``"results"`` column; defaults to the notebook-level ``results_df``.

    Returns:
        Array of the collected values, in row order (empty if no row has the entry).
    """
    if df is None:
        df = results_df

    JSDs = []
    for results in df["results"]:
        try:
            JSDs.append(results[key][name][quantity])
        except KeyError:
            continue

    return np.asarray(JSDs)

In [None]:
# Final (full-trajectory) JSDs for every quantity and trajectory, collected over all peptides.
# Maps each JSD quantity to the results section it is stored under.
JSD_section_by_quantity = {
    "JSD_backbone_torsions": "JSD_torsions",
    "JSD_sidechain_torsions": "JSD_torsions",
    "JSD_all_torsions": "JSD_torsions",
    "JSD_TICA-0": "JSD_TICA",
    "JSD_TICA-0,1": "JSD_TICA",
    "JSD_metastable_probs": "JSD_MSM",
}
traj_names = ["traj", "ref_traj", "ref_traj_10x", "ref_traj_100x"]
if experiment == "Timewarp_2AA":
    traj_names.append("TBG")

JSD_final_results = {
    quantity: {name: get_JSD_results(quantity, name, section) for name in traj_names}
    for quantity, section in JSD_section_by_quantity.items()
}

In [None]:
# Rows are trajectory names, columns are JSD quantities; each cell holds a per-peptide array.
JSD_final_results_df = pd.DataFrame.from_dict(JSD_final_results)

# Reduce each array cell to its mean / standard deviation; non-array cells become None.
means_series = JSD_final_results_df.map(lambda cell: None if not isinstance(cell, np.ndarray) else np.mean(cell))
stds_series = JSD_final_results_df.map(lambda cell: None if not isinstance(cell, np.ndarray) else np.std(cell))

means_series

In [None]:
jsd_means_csv_path = os.path.join(output_dir, "JSDs.csv")
means_series.to_csv(jsd_means_csv_path)

In [None]:
# Distribution over peptides of the sampled trajectory's JSD on metastable-state probabilities.
JSD_MSM = JSD_final_results["JSD_metastable_probs"]["traj"]

# With no values, JSD_MSM.max() below would raise.
if JSD_MSM.size == 0:
    print("No metastable-state JSDs available for the sampled trajectory; skipping histogram.")
else:
    plt.hist(JSD_MSM)
    plt.title("Jensen-Shannon Distances of Metastable State Probabilities")
    plt.xlabel("JSD")
    plt.xticks(np.arange(0.1, JSD_MSM.max() + 0.1, 0.1))
    plt.ylabel("Frequency")
    plt.ticklabel_format(useOffset=False, style="plain")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "jsd_metastable_probs.pdf"), dpi=300)
    plt.show()

#### JSD against Trajectory Progress

In [None]:
def get_JSD_results_against_time(
    quantity: str, name: str, key: str, df: Optional[pd.DataFrame] = None
) -> Dict[str, Any]:
    """Collect JSD-vs-trajectory-progress curves for one quantity across results rows.

    For each row of ``df``, reads ``row["results"][key][name]``, which maps a step
    count to a dict of JSD values, and extracts ``quantity`` at every step. Rows
    lacking ``key``, ``name`` or ``quantity`` are skipped, mirroring
    ``get_JSD_results``.

    Args:
        quantity: JSD quantity to extract at each step (e.g. "JSD_all_torsions").
        name: Trajectory name (e.g. "traj", "ref_traj").
        key: Top-level results key (e.g. "JSD_torsions_against_time").
        df: DataFrame with a "results" column; defaults to the notebook's ``results_df``.

    Returns:
        Dict with
          "steps": 1D array of step counts shared by all used rows,
          "JSDs": array of shape (n_rows_used, n_steps),
          "progress": ``steps`` divided by the final step.

    Raises:
        ValueError: if no row has results for ``key``/``name``/``quantity``, or if
            rows disagree on the step grid.
    """
    if df is None:
        df = results_df

    steps = None
    JSDs = []
    for _, row in df.iterrows():
        try:
            per_step = row["results"][key][name]
            row_JSDs = np.asarray([v[quantity] for v in per_step.values()])
        except KeyError:
            continue

        row_steps = np.asarray(list(per_step.keys()))
        if steps is None:
            steps = row_steps
        # Compare shapes first: np.allclose would silently broadcast e.g. [10, 20] vs [10].
        elif steps.shape != row_steps.shape or not np.allclose(steps, row_steps):
            raise ValueError(
                f"Inconsistent steps for key={key!r}, name={name!r}: "
                f"{steps.tolist()} vs {row_steps.tolist()}"
            )

        JSDs.append(row_JSDs)

    if steps is None:
        raise ValueError(
            f"No results found for key={key!r}, name={name!r}, quantity={quantity!r}"
        )

    return {
        "steps": steps,
        "JSDs": np.stack(JSDs),
        "progress": steps / steps[-1],
    }

In [None]:
# JSD_results[quantity][traj_name] -> dict from get_JSD_results_against_time
# ("steps", "JSDs", "progress"). Only the torsion quantities are populated here;
# TICA ("JSD_TICA_against_time") and MSM ("JSD_MSM_against_time") curves are
# currently not loaded, so those entries stay empty.
JSD_results = {
    quantity: {}
    for quantity in (
        "JSD_backbone_torsions",
        "JSD_sidechain_torsions",
        "JSD_all_torsions",
        "JSD_TICA-0",
        "JSD_TICA-0,1",
        "JSD_metastable_probs",
    )
}
traj_names = ["traj", "ref_traj", "ref_traj_10x", "ref_traj_100x"]

for quantity in ("JSD_backbone_torsions", "JSD_sidechain_torsions", "JSD_all_torsions"):
    JSD_results[quantity].update(
        (name, get_JSD_results_against_time(quantity, name, "JSD_torsions_against_time"))
        for name in traj_names
    )

In [None]:
# One figure per JSD quantity: mean JSD (± std over the stacked results rows)
# against normalized trajectory progress, one curve per trajectory.
for quantity, per_traj in JSD_results.items():
    # Quantities whose against-time results are not loaded have no curves;
    # skip them rather than drawing an empty figure.
    if not per_traj:
        continue

    fig, ax = plt.subplots()
    for name, curves in per_traj.items():
        mean = np.mean(curves["JSDs"], axis=0)
        std = np.std(curves["JSDs"], axis=0)
        progress = curves["progress"]

        # Fixed color for the sampled trajectory; others follow the color cycle.
        color = "tab:orange" if name == "traj" else None
        (line,) = ax.plot(progress, mean, label=format_traj_name(name), color=color)

        # Shaded ±1 std band in the same color as its mean line.
        ax.fill_between(progress, mean - std, mean + std, alpha=0.2, color=line.get_color())

    ax.set_ylim(0, 1)
    ax.set_xlabel("Trajectory Progress")
    ax.set_ylabel("JSD")
    ax.legend(bbox_to_anchor=(1.05, 0.5), loc='center left')
    ax.set_title(f"JSD vs Trajectory Progress\n{format_quantity(quantity)}")
    plt.show()


### TICA Analysis

In [None]:
# TICA-0,1 free-energy surfaces: reference (left column) vs. sampled trajectory
# (right column), one row per sampled system.
n_systems = len(sampled_results_df)
fig, axs = plt.subplots(nrows=n_systems, ncols=2, figsize=(12, 3.5 * n_systems), squeeze=False)

# enumerate() gives positional row numbers for the axes grid; iterrows() yields
# index labels, which need not be 0..n-1 (e.g. after DataFrame.sample or filtering
# without reset_index).
for i, (_, row) in enumerate(sampled_results_df.iterrows()):
    peptide = row["peptide"]
    tica_histograms = row["results"]["TICA_histograms"]
    ref_ax, traj_ax = axs[i]

    # Plot free energy.
    pyemma_helper.plot_free_energy(*tica_histograms["ref_traj"], cmap="plasma", ax=ref_ax)
    ref_ax.ticklabel_format(useOffset=False, style="plain")

    pyemma_helper.plot_free_energy(*tica_histograms["traj"], cmap="plasma", ax=traj_ax)
    traj_ax.ticklabel_format(useOffset=False, style="plain")

    # Column titles only on the first row.
    if i == 0:
        traj_ax.set_title(format_traj_name("traj"))
        ref_ax.set_title(format_traj_name("ref_traj"))

    # Use the reference panel's limits for both so the two are directly comparable.
    traj_ax.set_xlim(ref_ax.get_xlim())
    traj_ax.set_ylim(ref_ax.get_ylim())

    # Peptide label to the right of the row.
    traj_ax.text(
        1.4,
        0.5,
        format_peptide_name(peptide),
        rotation=90,
        verticalalignment="center",
        horizontalalignment="center",
        transform=traj_ax.transAxes,
    )

plt.suptitle("TICA-0,1 Projections", fontsize="x-large")
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "tica_projections.pdf"), dpi=300)
plt.show()

In [None]:
# Per-system TICA-0 speedup = reference decorrelation time / sampled decorrelation time.
tica_0_speedups = []
for _, row in results_df.iterrows():
    decorrelations = row["results"]["TICA_decorrelations"]

    # np.divide yields inf/nan instead of raising on a zero denominator.
    with np.errstate(divide="ignore", invalid="ignore"):
        speedup_factor = np.divide(
            decorrelations['ref_traj_decorrelation_time'],
            decorrelations['traj_decorrelation_time'],
        )
    # Only finite, positive ratios can be placed on a log axis / in log-spaced bins.
    if not (np.isfinite(speedup_factor) and speedup_factor > 0):
        continue

    tica_0_speedups.append(speedup_factor)

print(f"Number of systems with valid decorrelations: {len(tica_0_speedups)} out of {len(results_df)}")

fig, ax = plt.subplots()
if tica_0_speedups:
    # 20 log-spaced bins spanning the observed speedups.
    bins = np.logspace(np.log10(min(tica_0_speedups)), np.log10(max(tica_0_speedups)), 21)
    ax.hist(tica_0_speedups, bins=bins)
ax.set_xscale("log")
ax.set_xlabel("Speedup Factor")
ax.set_ylabel("Frequency")
fig.suptitle("Speedups of TICA-0 Decorrelation Times")
fig.tight_layout()
fig.savefig(os.path.join(output_dir, "tica_0_speedups.pdf"), dpi=300)
plt.show()

### MSM State Probabilities

In [None]:
# Pool the metastable-state probabilities of every system (reference vs. sampled
# trajectory) and compare them state by state.
msm_jsd_results = [system_results["JSD_MSM"]["traj"] for system_results in results_df["results"]]
all_ref_metastable_probs = np.concatenate([r["ref_metastable_probs"] for r in msm_jsd_results])
all_traj_metastable_probs = np.concatenate([r["traj_metastable_probs"] for r in msm_jsd_results])

fig, ax = plt.subplots()
ax.scatter(all_ref_metastable_probs, all_traj_metastable_probs, alpha=0.3, edgecolors="none")

# Least-squares fit of sampled vs. reference probabilities.
slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(
    all_ref_metastable_probs, all_traj_metastable_probs
)

# Extend the fit line beyond [0, 1] so it spans the whole (clipped) view.
x_line = np.array([-0.5, 1.5])
y_line = intercept + slope * x_line
ax.plot(x_line, y_line, color='red', linestyle='--')
ax.text(0.45, 0.90, f'R² = {r_value**2:.3f}', transform=ax.transAxes, color='red')

ax.set_title("Metastable State Probabilities")
ax.set_xlim((0, 1))
ax.set_ylim((0, 1))
ax.set_xlabel(format_traj_name("ref_traj"))
ax.set_ylabel(format_traj_name("traj"))
fig.tight_layout()
fig.savefig(os.path.join(output_dir, "metastable_probs.pdf"), dpi=300)
plt.show()

### Transition and Flux Matrices

In [None]:
# Row 0: MSM transition matrices; row 1: matrices estimated from the sampled trajectory.
# squeeze=False keeps axs 2-D even when only one system is sampled.
fig, axs = plt.subplots(2, len(sampled_results_df), figsize=(15, 5), squeeze=False)

mean_correlation = results_df["results"].apply(
    lambda x: x["MSM_matrices"]["traj"]["transition_spearman_correlation"]
).mean()
print(f"Mean correlation for transition matrices: {mean_correlation:.2f}")

# enumerate() gives positional column numbers; iterrows() yields index labels,
# which need not be 0..n-1 for a sampled/filtered DataFrame.
for i, (_, row) in enumerate(sampled_results_df.iterrows()):
    peptide = row["peptide"]
    msm_matrices = row["results"]["MSM_matrices"]["traj"]
    correlation = msm_matrices["transition_spearman_correlation"]

    # Transition probabilities lie in [0, 1], so use a fixed color range.
    im = axs[0, i].imshow(msm_matrices["msm_transition_matrix"], cmap='Blues', vmin=0, vmax=1)
    axs[1, i].imshow(msm_matrices["traj_transition_matrix"], cmap='Blues', vmin=0, vmax=1)
    axs[0, i].set_title(f"{format_peptide_name(peptide)}\nρ = {correlation:.2f}")

# Row labels to the left of the first column.
for row_idx, traj_key in enumerate(("ref_traj", "traj")):
    axs[row_idx, 0].text(
        -0.4,
        0.5,
        format_traj_name(traj_key),
        horizontalalignment="right",
        verticalalignment="center",
        transform=axs[row_idx, 0].transAxes,
    )

fig.colorbar(im, ax=axs, orientation='vertical', fraction=0.022)
plt.savefig(os.path.join(output_dir, "transition_matrices.pdf"), dpi=300)
plt.show()

In [None]:
# Row 0: MSM flux matrices; row 1: flux matrices from the sampled trajectory.
# squeeze=False keeps axs 2-D even when only one system is sampled.
fig, axs = plt.subplots(2, len(sampled_results_df), figsize=(15, 5), squeeze=False)

# (peptide, MSM matrix results) per sampled system, in display order.
sampled_msm = [
    (row["peptide"], row["results"]["MSM_matrices"]["traj"])
    for _, row in sampled_results_df.iterrows()
]

# One shared color scale across every panel so fluxes are comparable between systems.
vmin, vmax = np.inf, -np.inf
for _, msm_matrices in sampled_msm:
    for flux_matrix in (msm_matrices["msm_flux_matrix"], msm_matrices["traj_flux_matrix"]):
        vmin = min(vmin, np.min(flux_matrix))
        vmax = max(vmax, np.max(flux_matrix))
# gamma=0.5 compresses large values so small fluxes remain visible.
norm = matplotlib.colors.PowerNorm(gamma=0.5, vmin=vmin, vmax=vmax)

mean_correlation = results_df["results"].apply(
    lambda x: x["MSM_matrices"]["traj"]["flux_spearman_correlation"]
).mean()
print(f"Mean correlation for flux matrices: {mean_correlation:.2f}")

# Positional column index i (not the DataFrame index label).
for i, (peptide, msm_matrices) in enumerate(sampled_msm):
    correlation = msm_matrices["flux_spearman_correlation"]

    im = axs[0, i].imshow(msm_matrices["msm_flux_matrix"], cmap='Blues', norm=norm)
    axs[1, i].imshow(msm_matrices["traj_flux_matrix"], cmap='Blues', norm=norm)
    axs[0, i].set_title(f"{format_peptide_name(peptide)}\nρ = {correlation:.2f}")

# Row labels to the left of the first column.
for row_idx, traj_key in enumerate(("ref_traj", "traj")):
    axs[row_idx, 0].text(
        -0.4,
        0.5,
        format_traj_name(traj_key),
        horizontalalignment="right",
        verticalalignment="center",
        transform=axs[row_idx, 0].transAxes,
    )

fig.colorbar(im, ax=axs, orientation='vertical', fraction=0.022)
plt.savefig(os.path.join(output_dir, "flux_matrices.pdf"), dpi=300)
plt.show()