In [None]:
import mdtraj as md 
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns

In [None]:
## Define distance calc function

# mdtraj selection strings for the two backbone segments whose centers of mass
# are compared. NOTE(review): residue 2175 is absent from the chain A list —
# confirm that gap is intentional.
CHAIN_A_SELECTION = (
    '(residue 2174 or residue 2176 or residue 2177 or residue 2178 or residue 2179 '
    'or residue 2180 or residue 2181 or residue 2182 or residue 2183) '
    'and chainid 0 and backbone'
)
CHAIN_B_SELECTION = (
    '(residue 2579 or residue 2580 or residue 2581 or residue 2582 or residue 2583 '
    'or residue 2584 or residue 2585) and chainid 1 and backbone'
)


def calc_com_distance(traj, select_a=CHAIN_A_SELECTION, select_b=CHAIN_B_SELECTION):
    """Return the per-frame distance between the centers of mass of two atom selections.

    Parameters
    ----------
    traj : mdtraj.Trajectory
        Trajectory to analyse.
    select_a, select_b : str
        mdtraj atom-selection strings for the two groups. The defaults are the
        chain A / chain B backbone segments this notebook analyses.

    Returns
    -------
    numpy.ndarray
        One Euclidean COM-COM distance per frame, in the trajectory's length
        unit (nanometres for mdtraj).

    Raises
    ------
    ValueError
        If either selection matches no atoms. Catching a typo'd or mismatched
        selection here gives a clear error instead of relying on how
        ``compute_center_of_mass`` treats an empty selection.
    """
    for arg_name, selection in (("select_a", select_a), ("select_b", select_b)):
        if len(traj.topology.select(selection)) == 0:
            raise ValueError(f"{arg_name} matched no atoms: {selection!r}")

    com_a = md.compute_center_of_mass(traj, select=select_a)
    com_b = md.compute_center_of_mass(traj, select=select_b)

    # Distance between the two centers of mass, computed frame by frame.
    return np.linalg.norm(com_b - com_a, axis=1)


In [None]:
# Define trimming function

def trim_replicate_timeseries(data, total_ns, n_replicates, trim_ns):
    """Drop an initial window from each replicate of a concatenated time series.

    ``data`` is treated as ``n_replicates`` equal-length replicates joined end
    to end along axis 0, together spanning ``total_ns`` nanoseconds. The first
    ``trim_ns`` ns of every replicate are discarded. The remaining frames are
    re-concatenated in replicate order.

    Parameters
    ----------
    data : array_like
        Time series with frames along axis 0 (any extra axes are carried along).
    total_ns : float
        Simulated time covered by all replicates combined.
    n_replicates : int
        Number of equal-length replicates in ``data``.
    trim_ns : float
        Time to discard from the start of each replicate.

    Returns
    -------
    numpy.ndarray
        The retained frames of every replicate, concatenated along axis 0.

    Raises
    ------
    ValueError
        If ``n_replicates`` or ``total_ns`` is not positive, ``trim_ns`` is
        negative, or trimming would leave no frames in a replicate.
    """
    data = np.asarray(data)
    if n_replicates <= 0:
        raise ValueError(f"n_replicates must be positive, got {n_replicates}")
    if total_ns <= 0:
        raise ValueError(f"total_ns must be positive, got {total_ns}")
    if trim_ns < 0:
        # A negative trim would make each slice start inside the previous replicate.
        raise ValueError(f"trim_ns must be non-negative, got {trim_ns}")

    n_frames = data.shape[0]
    # Floor division: if n_frames is not a multiple of n_replicates, the
    # leftover frames at the very end are dropped.
    frames_per_rep = n_frames // n_replicates

    # Replicate duration follows from total_ns. For total_ns=5000 with 10
    # replicates this is 500 ns, as before.
    ns_per_rep = total_ns / n_replicates
    frames_per_ns = frames_per_rep / ns_per_rep
    trim_frames = int(trim_ns * frames_per_ns)
    if trim_frames >= frames_per_rep:
        raise ValueError(
            f"trimming {trim_ns} ns ({trim_frames} frames) leaves no frames in "
            f"replicates of {frames_per_rep} frames"
        )

    trimmed_segments = [
        data[i * frames_per_rep + trim_frames:(i + 1) * frames_per_rep]
        for i in range(n_replicates)
    ]
    return np.concatenate(trimmed_segments, axis=0)


In [None]:
# Load trajectories: one DCD plus its topology per system.
# NOTE(review): every path is a 'path/to/...' placeholder and all six systems
# currently point at the same files — presumably each system has its own
# joined trajectory; fill in the real paths per entry.
_TRAJECTORY_FILES = {
    "atp": ('path/to/joined.dcd', 'path/to/step3_input.pdb'),
    "adp": ('path/to/joined.dcd', 'path/to/step3_input.pdb'),
    "amp": ('path/to/joined.dcd', 'path/to/step3_input.pdb'),
    "camp": ('path/to/joined.dcd', 'path/to/step3_input.pdb'),
    "adenosine": ('path/to/joined.dcd', 'path/to/step3_input.pdb'),
    "guanosine": ('path/to/joined.dcd', 'path/to/step3_input.pdb'),
}

# Dict insertion order keeps the load order identical to the table above.
_loaded_trajs = {
    system: md.load_dcd(dcd_path, top=top_path)
    for system, (dcd_path, top_path) in _TRAJECTORY_FILES.items()
}

atp_traj = _loaded_trajs["atp"]
adp_traj = _loaded_trajs["adp"]
amp_traj = _loaded_trajs["amp"]
camp_traj = _loaded_trajs["camp"]
adenosine_traj = _loaded_trajs["adenosine"]
guanosine_traj = _loaded_trajs["guanosine"]


In [None]:
# Compute each system's per-frame COM distance from its trajectory, then drop
# the first TRIM_NS of every replicate. Deriving the distances from the loaded
# trajectories keeps this cell idempotent and runnable on a fresh kernel.
TOTAL_NS = 5000      # combined simulated time of all replicates in one trajectory
N_REPLICATES = 10    # replicates concatenated in each trajectory
TRIM_NS = 100        # time discarded from the start of each replicate


def _trimmed_com_distance(traj):
    """COM-distance time series for ``traj`` with each replicate's first TRIM_NS removed."""
    return trim_replicate_timeseries(calc_com_distance(traj), TOTAL_NS, N_REPLICATES, TRIM_NS)


atp_dist = _trimmed_com_distance(atp_traj)
adp_dist = _trimmed_com_distance(adp_traj)
amp_dist = _trimmed_com_distance(amp_traj)
camp_dist = _trimmed_com_distance(camp_traj)
adenosine_dist = _trimmed_com_distance(adenosine_traj)
guanosine_dist = _trimmed_com_distance(guanosine_traj)


In [None]:
# Pool the adenine-nucleotide distances (ATP, ADP, AMP, cAMP) into one array,
# in that order.
adenine_nucleotide_dist = np.concatenate(
    (atp_dist, adp_dist, amp_dist, camp_dist)
)

In [None]:
# Plot facet plot

# One filled-KDE panel per distribution, stacked with a shared x axis and a
# common y scale so panel heights are directly comparable.
# The x10 converts mdtraj's nanometres to Å.
# NOTE(review): "Dist N" labels look like placeholders. adenine_nucleotide_dist
# (pooled in an earlier cell) is not plotted — confirm atp_dist is the intended
# first panel.
distributions = [atp_dist * 10, adenosine_dist * 10, guanosine_dist * 10]
labels = ["Dist 1", "Dist 2", "Dist 3"]
colors = ["gray", "gray", "gray"]

fig, axes = plt.subplots(len(distributions), 1, figsize=(4, 3), sharex=True, gridspec_kw={"hspace": 0})
# plt.subplots returns a bare Axes rather than an array when there is one row.
axes = np.atleast_1d(axes)

ymax = 0  # largest y upper limit over all panels, applied to every panel below

for ax, dist, label, color in zip(axes, distributions, labels, colors):
    sns.kdeplot(
        dist,
        ax=ax,
        alpha=0.6,
        color=color,
        linewidth=0,
        fill=True,
        bw_adjust=0.7,
    )
    ymax = max(ymax, ax.get_ylim()[1])

    ax.set_title(label, loc="left", fontsize=10, pad=-15)  # negative pad puts the title inside the panel
    ax.set_yticks([])
    ax.set_ylabel("")
    ax.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=True)
    for spine in ax.spines.values():
        spine.set_visible(True)

# Span every distribution so no panel's KDE is clipped. sharex=True means
# setting the limits on one axis applies them to all panels.
x_min = min(np.min(dist) for dist in distributions)
x_max = max(np.max(dist) for dist in distributions)
axes[0].set_xlim(x_min, x_max)

for ax in axes:
    ax.set_ylim(0, ymax)

axes[-1].set_xlabel("COM distance (Å)")
plt.tight_layout()
plt.savefig('adn_gsn_facet_plot.png', dpi=300)
plt.show()