### Objective 1
We want a tool that, given an MD trajectory and the relevant SMILES string, calculates the dihedral angles (from which we can obtain their distributions).

### Objective 2
We also want a tool to plot the CSV file data, as a sanity check that the macroscopic properties (temperature, pressure)
are physically reasonable and free of discontinuities.

In [11]:
import os
from typing import List

from openmm import *
from openmm.app import *

import mdtraj as md
import numpy as np

In [None]:
# SMILES string for each system, keyed by system name (insertion order preserved).
# NOTE(review): keys presumably match the per-system MD job directory names
# (cf. the 'Lys-Tyr' results path in the later cell) — confirm for the others.
smiles_dict = dict(
    [
        ('Lys-Tyr', 'CNCc1c(O)ccc(c1)C'),
        ('Lys-Arg', 'N1(C)CN=C(NC1)NC'),
        ('Sulfur-Mediated-Amide', 'SC[C@@H](NC(=O)C)C(=O)C'),
        ('Carboxyl-Carboxyl', 'C(=O)(C)NCc1ccc(cc1)CNC(=O)C'),
        ('Disulfide', 'N[C@H](C(=O)O)CSSC[C@@H](N)C(=O)O'),
        ('Cys-Arg', 'CSCC(=O)NCCCC'),
        ('Cys-Carboxyl', 'CSCC(=O)C'),
    ]
)

def load_trajectory(traj_path: str, system_pdb: str, fraction: float=0.2) -> md.Trajectory:
    """
    Load an MD trajectory from a .dcd file using the topology of the initial system PDB.
    Only keeps the last `fraction` of the frames.

    Args:
        traj_path (str): Path to the .dcd trajectory file.
        system_pdb (str): Path to the system .pdb file used for topology.
        fraction (float): Fraction of the trajectory to keep, counted from the
            end (default: last 20%). Must satisfy 0 < fraction <= 1.

    Returns:
        md.Trajectory: The trailing `fraction` of the frames, with the PDB topology.

    Raises:
        ValueError: If `fraction` is outside (0, 1], or if the atom counts of
            the PDB topology and the loaded trajectory disagree.
    """
    from itertools import islice

    # Reject fractions that would silently yield an empty trajectory (0) or a
    # negative start index that keeps an unrelated tail (> 1).
    if not 0 < fraction <= 1:
        raise ValueError(f"fraction must be in the interval (0, 1], got {fraction!r}")

    # Load the system PDB for topology
    pdb_topology = md.load(system_pdb).topology

    # Load the DCD trajectory, taking its topology from the same PDB file
    traj = md.load(traj_path, top=system_pdb)

    # Sanity check: ensure atom counts match.
    # NOTE(review): both topologies come from `system_pdb`, and mdtraj likely
    # already rejects a DCD whose atom count differs from `top` inside md.load,
    # so this branch is presumably defensive only — confirm.
    if traj.topology.n_atoms != pdb_topology.n_atoms:
        print("❌ Mismatch detected! First few atoms in PDB vs. DCD:")

        # Topology.atoms is a generator, so take the first 10 with islice;
        # slicing it directly (`atoms[:10]`) raises TypeError.
        print("PDB Atom Order:")
        for atom in islice(pdb_topology.atoms, 10):
            print(atom.index, atom.name, atom.element)

        print("\nDCD Atom Order:")
        for atom in islice(traj.topology.atoms, 10):
            print(atom.index, atom.name, atom.element)

        raise ValueError(f"Atom count mismatch: PDB ({pdb_topology.n_atoms}) vs DCD ({traj.topology.n_atoms})")

    # Keep only the last `fraction` of the trajectory
    start_frame = int(traj.n_frames * (1 - fraction))
    return traj[start_frame:]

def filter_traj_water(system_traj: md.Trajectory) -> md.Trajectory:
    """
    Filter the trajectory to remove water residues, identified by atom names.

    A residue counts as water when its set of atom names is exactly
    {"O", "H1", "H2"} or {"O1", "H1", "H2"}. Matching is by atom name only,
    not by residue name. Residues carrying any extra atom (e.g. a virtual
    site) are therefore kept.

    Args:
        system_traj (md.Trajectory): The full trajectory including solvent.

    Returns:
        md.Trajectory: The trajectory without water (or any OHH-only residues).
    """
    # Both the conventional "O" and the "O1" oxygen naming are accepted.
    water_name_sets = ({"O", "H1", "H2"}, {"O1", "H1", "H2"})

    topology = system_traj.topology

    # Indices of residues made up exclusively of one water oxygen + H1 + H2
    water_residue_indices = {
        res.index
        for res in topology.residues
        if {atom.name for atom in res.atoms} in water_name_sets
    }

    # Indices of all atoms that do not belong to a water residue
    non_water_atoms = [
        atom.index
        for atom in topology.atoms
        if atom.residue.index not in water_residue_indices
    ]

    # Slice trajectory to keep only non-water atoms
    return system_traj.atom_slice(non_water_atoms)


In [47]:
# Lys-Tyr system: the system PDB and all seeded trajectories share one results directory.
Lys_Tyr_dir = '/home/bfd21/rds/hpc-work/tbg/cyclization/jobs/md-jobs/Lys-Tyr/results/'
Lys_Tyr_system_pdb = os.path.join(Lys_Tyr_dir, 'system.pdb')
# One trajectory per seed: traj_seed_0.dcd ... traj_seed_9.dcd
Lys_Tyr_trajs = [
    os.path.join(Lys_Tyr_dir, f'traj_seed_{seed}.dcd')
    for seed in range(10)
]

# Inspect seed 0 only: load it with load_trajectory's default trailing
# fraction of frames, then strip the water residues.
traj = load_trajectory(Lys_Tyr_trajs[0], Lys_Tyr_system_pdb)
filt_traj = filter_traj_water(traj)

# Topology of the water-free trajectory
top = filt_traj.topology