In [None]:
# Load required libraries

import os 
import re
from collections import defaultdict, OrderedDict
from typing import List, Tuple, Dict, Optional
import numpy as np
import mdtraj as md
import matplotlib.pyplot as plt 


In [None]:
# Define function to compute RMSF differences between two sets of trajectories 

# Default residue windows (resSeq, inclusive) compared by rmsf_backbone_difference_two_tops.
_DEFAULT_RESIDUE_RANGES = ((2136, 2216), (2569, 2636))


def _numeric_dcds(dir_path: str) -> List[str]:
    """Return the paths of ``<integer>.dcd`` files in ``dir_path``, sorted numerically.

    Numeric (not lexicographic) ordering keeps ``10.dcd`` after ``9.dcd``.
    Any file whose stem is not purely digits is ignored.
    """
    numbered = []
    for fn in os.listdir(dir_path):
        m = re.fullmatch(r"(\d+)\.dcd", fn)
        if m:
            numbered.append((int(m.group(1)), os.path.join(dir_path, fn)))
    numbered.sort(key=lambda item: item[0])
    return [path for _, path in numbered]


def _per_residue_backbone_rmsf(
    dir_path: str,
    topology_name: str,
    residue_ranges,
    backbone_atoms: Tuple[str, ...],
    chunk_size: int,
    n_frames_remove: int,
) -> Dict[int, float]:
    """Compute backbone RMSF (Å) per residue number for all numbered DCDs in ``dir_path``.

    Every frame is superposed onto frame 0 of the lowest-numbered DCD using all
    protein backbone atoms. RMSF is then taken about the mean position of each
    selected atom and averaged over the atoms of each ``resSeq``. The first
    ``n_frames_remove`` frames of EACH DCD are discarded before accumulation.

    Raises:
        FileNotFoundError: topology or numbered DCDs missing.
        ValueError: no protein backbone atoms fall inside ``residue_ranges``.
        RuntimeError: no frames left after trimming.
    """
    top_path = os.path.join(dir_path, topology_name)
    if not os.path.exists(top_path):
        raise FileNotFoundError(f"No topology {topology_name} in {dir_path}")

    dcds = _numeric_dcds(dir_path)
    if not dcds:
        raise FileNotFoundError(f"No numbered DCDs in {dir_path}")

    # Alignment reference: first frame of the lowest-numbered DCD.
    ref = md.load_frame(dcds[0], 0, top=top_path)
    top = ref.topology
    align_idx = top.select("protein and backbone")

    wanted_names = set(backbone_atoms)

    def _in_ranges(resseq: int) -> bool:
        return any(lo <= resseq <= hi for lo, hi in residue_ranges)

    # Selected atom indices and the resSeq of each, in topology order.
    sel_idx: List[int] = []
    sel_resseq: List[int] = []
    for atom in top.atoms:
        res = atom.residue
        if res.is_protein and atom.name in wanted_names and _in_ranges(res.resSeq):
            sel_idx.append(atom.index)
            sel_resseq.append(res.resSeq)
    if not sel_idx:
        raise ValueError(
            f"No protein backbone atoms {tuple(backbone_atoms)} found in "
            f"residue ranges {list(residue_ranges)} for {top_path}"
        )
    sel = np.asarray(sel_idx, dtype=int)

    # Running mean / sum of squared deviations, merged chunk-by-chunk
    # (Chan et al. parallel variance update) in float64.
    count = 0
    mean = np.zeros((len(sel), 3), dtype=np.float64)
    m2 = np.zeros((len(sel), 3), dtype=np.float64)

    for dcd in dcds:
        # Equilibration frames still to drop from the start of THIS trajectory;
        # a chunk boundary must not restart the count.
        to_skip = n_frames_remove
        for chunk in md.iterload(dcd, top=top_path, chunk=chunk_size):
            if to_skip > 0:
                if chunk.n_frames <= to_skip:
                    to_skip -= chunk.n_frames
                    continue
                chunk = chunk[to_skip:]
                to_skip = 0
            if chunk.n_frames == 0:
                continue

            chunk.superpose(ref, atom_indices=align_idx, ref_atom_indices=align_idx)
            xyz = chunk.xyz[:, sel, :].astype(np.float64) * 10.0  # nm -> Å

            n_b = xyz.shape[0]
            mean_b = xyz.mean(axis=0)
            m2_b = ((xyz - mean_b) ** 2).sum(axis=0)

            total = count + n_b
            delta = mean_b - mean
            mean += delta * (n_b / total)
            m2 += m2_b + delta ** 2 * (count * n_b / total)
            count = total

    if count == 0:
        raise RuntimeError("No frames after trimming!")

    # Population variance summed over x/y/z, then sqrt -> per-atom RMSF.
    rmsf_atom = np.sqrt(m2.sum(axis=1) / count)

    # Average atoms sharing a resSeq. NOTE: chains are not distinguished, so
    # identical resSeq values in different chains are pooled together.
    per_res = defaultdict(list)
    for resseq, value in zip(sel_resseq, rmsf_atom):
        per_res[resseq].append(value)
    return {rs: float(np.mean(vals)) for rs, vals in per_res.items()}


def _plot_delta_rmsf(residues: List[int], values: List[float], title: str, tick_step: int):
    """Draw ΔRMSF vs residue number as filled curves, breaking at numbering gaps.

    ``residues`` must be non-empty and sorted ascending. Returns ``(fig, ax)``.
    """
    fig, ax = plt.subplots(figsize=(10, 4))

    # Split into runs of consecutive resSeq so gaps are not bridged by a line.
    segments = []
    start = 0
    for i in range(1, len(residues) + 1):
        if i == len(residues) or residues[i] != residues[i - 1] + 1:
            segments.append((residues[start:i], values[start:i]))
            start = i

    for seg_x, seg_y in segments:
        ax.plot(seg_x, seg_y, color="black", linewidth=0.8)
        ax.fill_between(seg_x, seg_y, color="black", alpha=0.9)

    ax.set_xlabel("Residue (resSeq)")
    ax.set_ylabel("ΔRMSF (Å)")
    ax.set_title(title)

    lo, hi = residues[0], residues[-1]
    if hi > lo:  # identical limits would trigger a matplotlib warning
        ax.set_xlim(lo, hi)
    y_top = max(values)
    ax.set_ylim(0, y_top * 1.05 if y_top > 0 else 1.0)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.grid(False)

    # Denser integer residue ticks; tick_params keeps the tick formatter intact.
    ax.set_xticks(np.arange(lo, hi + 1, tick_step))
    ax.tick_params(axis="x", labelrotation=45, labelsize=8)

    fig.tight_layout()
    return fig, ax


def rmsf_backbone_difference_two_tops(
    dir1: str,
    dir2: str,
    residue_ranges: Optional[List[Tuple[int, int]]] = None,
    topology_name: str = "system.pdb",
    chunk_size: int = 1000,
    n_frames_remove: int = 0,
    backbone_atoms: Tuple[str, ...] = ("N", "CA", "C", "O"),
    save_plot: Optional[str] = None,
    return_top_n: Optional[int] = None,
    tick_step: int = 5,
    labels: Optional[Tuple[str, str]] = None,
):
    """Compare per-residue backbone RMSF between two simulation directories.

    Each directory must contain ``topology_name`` and DCD files named
    ``<integer>.dcd`` (read in numeric order). For each directory, frames are
    superposed on frame 0 of its lowest-numbered DCD using all protein backbone
    atoms. The RMSF (Å) of the ``backbone_atoms`` in ``residue_ranges`` is then
    averaged per residue number. Residues are keyed by ``resSeq`` only, so equal
    numbers in different chains are pooled.

    Args:
        dir1, dir2: Simulation directories to compare.
        residue_ranges: Inclusive (start, end) resSeq windows. Defaults to
            ((2136, 2216), (2569, 2636)).
        topology_name: Topology filename inside each directory.
        chunk_size: Frames per ``md.iterload`` chunk (memory/speed trade-off only).
        n_frames_remove: Leading frames dropped from EACH DCD (equilibration).
        backbone_atoms: Atom names included in the RMSF.
        save_plot: If given, path the figure is saved to (300 dpi).
        return_top_n: If truthy, keep only the N largest differences in the
            returned dict (the plot always shows every common residue).
        tick_step: Spacing between x-axis residue ticks.
        labels: Optional names for the two systems in the plot title. Defaults
            to each directory's basename (trailing slashes ignored).

    Returns:
        ``(fig, diffs)``. ``diffs`` is an OrderedDict mapping ``"RESNAME resSeq"``
        to ``|RMSF1 - RMSF2|``, sorted from largest to smallest.

    Raises:
        ValueError: negative ``n_frames_remove``, empty selection, or no common residues.
        FileNotFoundError: a topology or numbered DCDs are missing.
        RuntimeError: no frames remain after trimming.
    """
    if residue_ranges is None:
        residue_ranges = list(_DEFAULT_RESIDUE_RANGES)
    if n_frames_remove < 0:
        raise ValueError(f"n_frames_remove must be >= 0, got {n_frames_remove}")

    rmsf1 = _per_residue_backbone_rmsf(
        dir1, topology_name, residue_ranges, backbone_atoms, chunk_size, n_frames_remove
    )
    rmsf2 = _per_residue_backbone_rmsf(
        dir2, topology_name, residue_ranges, backbone_atoms, chunk_size, n_frames_remove
    )

    common = sorted(rmsf1.keys() & rmsf2.keys())
    if not common:
        raise ValueError(f"No residues in common between {dir1} and {dir2}")
    diffs = {rs: abs(rmsf1[rs] - rmsf2[rs]) for rs in common}

    # Residue names for labels come from dir1's topology (topology only, no coordinates needed).
    top1 = md.load_topology(os.path.join(dir1, topology_name))
    rs_to_name = {
        res.resSeq: res.name
        for res in top1.residues
        if res.is_protein and res.resSeq in diffs
    }

    sorted_items = sorted(
        ((f"{rs_to_name.get(rs, 'UNK')} {rs}", diffs[rs]) for rs in common),
        key=lambda kv: kv[1],
        reverse=True,
    )
    if return_top_n:
        sorted_items = sorted_items[:return_top_n]
    diffs_ordered = OrderedDict(sorted_items)

    if labels is None:
        # normpath strips a trailing slash, which would otherwise yield an empty basename.
        labels = (os.path.basename(os.path.normpath(dir1)),
                  os.path.basename(os.path.normpath(dir2)))
    fig, _ = _plot_delta_rmsf(
        common,
        [diffs[rs] for rs in common],
        f"ΔRMSF: {labels[0]} vs {labels[1]}",
        tick_step,
    )

    if save_plot:
        fig.savefig(save_plot, dpi=300)

    return fig, diffs_ordered
        

In [None]:

# Run the comparison between trajectories
# Paths are specific to the cluster this analysis ran on; edit PROJECT_DIR to relocate.
PROJECT_DIR = "/data/chodera/viktor/IP3R_MD_t"
APO_DIR = os.path.join(PROJECT_DIR, "test_systems", "type2_jd_zn_apo_test", "data_final")
ATP_DIR = os.path.join(PROJECT_DIR, "systems", "type2_jd_zn_ATP", "data_final")

fig, diffs = rmsf_backbone_difference_two_tops(
    dir1=APO_DIR,
    dir2=ATP_DIR,  # no trailing slash, so the plot-title label is not empty
    residue_ranges=[(2136, 2216), (2569, 2636)],
    topology_name="step3_input.pdb",
    chunk_size=1000,
    n_frames_remove=100,  # equilibration frames to discard
    save_plot="delta_rmsf.png",
    return_top_n=100,
)
