# Diffusion Model implementation for energy landscape exploration

Testing ground for a diffusion model implemented in PyTorch.

## LLM Assisted Model

1. Data Handling

    Current: The code uses a preprocessed .npy trajectory dataset.

    Modify: Parse .mdcrd files into usable trajectory tensors.

        Convert to .npy or torch.Tensor sequences.

        Normalize or align structures (e.g. RMSD alignment).

        Split into tuples: (start_frame, end_frame, full_path_sequence).
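The alignment and splitting steps above can be sketched as follows. This is a minimal NumPy illustration, not the project's code: `kabsch_align` and `make_training_tuple` are hypothetical helper names, and aligning every frame to the first frame is an assumption.

```python
import numpy as np

def kabsch_align(mobile, reference):
    """Rigidly superpose `mobile` (n_atoms, 3) onto `reference` (n_atoms, 3).

    Both structures are centred first, so the result is expressed relative
    to the centroid of the reference.
    """
    mob_c = mobile - mobile.mean(axis=0)
    ref_c = reference - reference.mean(axis=0)
    # SVD of the covariance matrix yields the optimal rotation (Kabsch algorithm)
    u, _, vt = np.linalg.svd(mob_c.T @ ref_c)
    # Flip the last axis if needed so the result is a proper rotation (det = +1)
    d = np.sign(np.linalg.det(u @ vt))
    rot = u @ np.diag([1.0, 1.0, d]) @ vt
    return mob_c @ rot

def make_training_tuple(coords):
    """coords: (n_frames, n_atoms, 3) -> (start_frame, end_frame, full_path_sequence)."""
    aligned = np.stack([kabsch_align(frame, coords[0]) for frame in coords])
    flat = aligned.reshape(len(aligned), -1)  # (n_frames, n_atoms * 3)
    return flat[0], flat[-1], flat
```

With MDTraj, `Trajectory.superpose` performs a comparable alignment directly on a loaded trajectory.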

2. Input Preparation

    You’ll now want to encode the start and end frames as conditional inputs.

        Option A: Concatenate start + end frames to the noisy input.

        Option B: Encode start + end through a separate encoder and use in cross-attention or FiLM layers in the U-Net.
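Both conditioning options can be illustrated with plain arrays. The sketch below is an assumption about how they might look, written in NumPy rather than PyTorch for brevity; `concat_condition`, `film`, `w_gamma` and `w_beta` are hypothetical names, and the real U-Net may condition differently.

```python
import numpy as np

def concat_condition(noisy_path, start, end):
    """Option A: append the start and end frames to every noisy frame.

    noisy_path: (T, F); start, end: (F,). Returns an array of shape (T, 3F).
    """
    n_frames = noisy_path.shape[0]
    cond = np.concatenate([start, end])                 # (2F,)
    cond = np.broadcast_to(cond, (n_frames, cond.size))  # (T, 2F)
    return np.concatenate([noisy_path, cond], axis=1)

def film(features, cond_embedding, w_gamma, w_beta):
    """Option B (FiLM): scale and shift hidden features using the condition.

    features: (T, H); cond_embedding: (E,); w_gamma, w_beta: (E, H).
    """
    gamma = cond_embedding @ w_gamma  # per-channel scale offset, (H,)
    beta = cond_embedding @ w_beta    # per-channel shift, (H,)
    return (1.0 + gamma) * features + beta
```

Option A only changes the input width of the first layer, while Option B leaves the input shape unchanged and injects the endpoints inside the network.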

3. Modify the UNet + Diffusion Model

    Input shape: Instead of only receiving noisy intermediate x_t, the model receives:

    model(noisy_frame_t, timestep_t, start_frame, end_frame)

    Diffusion target: Instead of denoising a full trajectory from pure noise, the model learns to interpolate the path that connects the two known endpoints.

    Loss: Compare predicted frame at time t to ground-truth frame in the trajectory using MSE/L2.
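A single training step under this scheme might look like the sketch below. It assumes the standard DDPM forward process and a model that predicts the clean frames directly; the linear beta schedule values and the names `q_sample` and `training_loss` are illustrative assumptions, not taken from `Landscape_DDPM`.

```python
import numpy as np

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances (assumed schedule)."""
    return np.linspace(beta_start, beta_end, timesteps)

def q_sample(x0, t, alphas_cumprod, noise):
    """Forward diffusion: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    abar_t = alphas_cumprod[t]
    return np.sqrt(abar_t) * x0 + np.sqrt(1.0 - abar_t) * noise

def training_loss(model, path, start, end, t, alphas_cumprod, rng):
    """MSE between the model's predicted frames and the ground-truth path.

    model: callable (noisy_path, t, start, end) -> predicted clean path.
    """
    noise = rng.standard_normal(path.shape)
    x_t = q_sample(path, t, alphas_cumprod, noise)
    pred = model(x_t, t, start, end)
    return np.mean((pred - path) ** 2)
```

Predicting the added noise instead of the clean frames is an equally common target; only the last two lines of `training_loss` would change.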

4. Sampling (Inference)

    After training:

        Input start_frame and end_frame.

        Initialize the intermediate frames as noise.

        Use reverse diffusion steps to denoise iteratively — conditioning on both endpoints — to generate a discrete transition path.
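The sampling procedure can be sketched as a loop over reverse diffusion steps. This is a simplified illustration that assumes a model predicting the clean path and uses the standard DDPM posterior; clamping the first and last frames to the known endpoints after each step is an assumption, and `sample_path` is a hypothetical name.

```python
import numpy as np

def sample_path(model, start, end, n_frames, betas, rng):
    """Generate a path of `n_frames` frames between `start` and `end`.

    model: callable (noisy_path, t, start, end) -> predicted clean path.
    """
    alphas = 1.0 - betas
    abar = np.cumprod(alphas)
    abar_prev = np.concatenate([[1.0], abar[:-1]])

    # Initialise every frame as Gaussian noise
    x = rng.standard_normal((n_frames, start.size))

    for t in reversed(range(len(betas))):
        x0_pred = model(x, t, start, end)
        # Mean and variance of the DDPM posterior q(x_{t-1} | x_t, x0_pred)
        coef_x0 = np.sqrt(abar_prev[t]) * betas[t] / (1.0 - abar[t])
        coef_xt = np.sqrt(alphas[t]) * (1.0 - abar_prev[t]) / (1.0 - abar[t])
        mean = coef_x0 * x0_pred + coef_xt * x
        var = betas[t] * (1.0 - abar_prev[t]) / (1.0 - abar[t])
        noise = rng.standard_normal(x.shape) if t > 0 else 0.0
        x = mean + np.sqrt(var) * noise
        # Keep the known endpoints fixed (assumed conditioning strategy)
        x[0], x[-1] = start, end
    return x
```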

### MDCRD data preparation

The MDTraj Python library is used to read the MDCRD files.
Each trajectory needs its corresponding topology file (prmtop).

Next, the coordinates are extracted.
MDTraj returns them as an array of shape [num_frames, num_atoms, 3].
The next step is to normalize and reshape them.

Because the goal is to predict transition paths between two endpoint structures, the training samples take the form:
(start_frame, end_frame, trajectory_segment)
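One way to build such samples from a coordinate array is sketched below. It is an illustration only: centring each frame on its centroid stands in for the normalization step, and `make_segments`, `segment_length` and `stride` are hypothetical names that do not appear in the notebook, which uses each whole path as a single sample.

```python
import numpy as np

def make_segments(coords, segment_length, stride):
    """Cut a trajectory into (start_frame, end_frame, trajectory_segment) samples.

    coords: (n_frames, n_atoms, 3). Each returned segment has shape
    (segment_length, n_atoms * 3).
    """
    # Remove overall translation by centring every frame on its centroid
    centred = coords - coords.mean(axis=1, keepdims=True)
    # Flatten each frame into one feature vector
    flat = centred.reshape(centred.shape[0], -1)

    samples = []
    for first in range(0, flat.shape[0] - segment_length + 1, stride):
        segment = flat[first:first + segment_length]
        samples.append((segment[0], segment[-1], segment))
    return samples
```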

In [None]:
# for all atom coordinates
import mdtraj as md
import numpy as np
import os

# Configuration
trajectory_folder = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Paths/'       # folder containing all .mdcrd files
topology_file = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/PK1/coords.prmtop'  # shared topology

all_paths = []

for file in os.listdir(trajectory_folder): #loop over all mdcrd files in the directory
    if file.endswith(".mdcrd"): #check if the file ends with .mdcrd to make sure that only mdcrd files are selected
        filepath = os.path.join(trajectory_folder, file) #get the full filepath of the path file
        #print(f"Processing: {file}")
        
        traj = md.load_mdcrd(filepath, top=topology_file) #load the trajectory into python

        coords = traj.xyz #get the xyz (Cartesian) coordinates of the trajectory as a numpy array
                          #all of the distances in the Trajectory are stored in nanometers. The time unit is picoseconds. Angles are stored in degrees (not radians).
        
        flattened = coords.reshape(coords.shape[0], -1) #need to flatten the frames into 1D for diffusion U-net model to accept
        #goal is to predict paths between 2 endpoints so training samples should have a start and end point as well
        #create tuples like (start_frame, end_frame, path) for training data
        start = flattened[0]    # shape: (n_atoms * 3,), get the 1st frame of the path
        end = flattened[-1]    # shape: (n_atoms * 3,), get the last frame of the path
        path = flattened # the entire path
        all_paths.append((start, end, path)) #append the tuple to a new list


In [None]:
# for CG model
import mdtraj as md
import numpy as np
import os
from tqdm import tqdm

# configuration
trajectory_folder = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Paths/' # Folder with .mdcrd files
prmtop_file = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/PK1/coords.prmtop' # Topology file

PO5_BOND_LENGTH = 0.1615  # P-O5' bond length used to scale the HO5'->O5' vector (0.1615 nm = 1.615 Å); given in nm because MDTraj uses nm internally

atomic_weights = {  # Atomic weights for centre of mass calculations
    "C": 12.011,
    "N": 14.007,
    "O": 16.000
}

CG_bases = {  # sets of atoms for each nucleotide type base group
    "A1": ["N7", "N9", "C4", "C5", "C8"],
    "A2": ["N1", "C2", "N3", "C4", "C5", "N6", "C6"],
    "C1": ["N1", "C2", "N3", "N4", "C4", "C5", "C6", "O2"],
    "G1": ["N7", "N9", "C4", "C5", "C8"],
    "G2": ["N1", "N2", "C2", "N3", "C4", "C5", "C6", "O6"],
    "T1": ["N1", "C2", "N3", "C4", "C5", "C6", "O2", "O4"],
    "U1": ["N1", "C2", "N3", "C4", "C5", "C6", "O2", "O4"]
}

CG_bases_simple = {  # sets of atoms for each nucleotide type base group
    "A1": ["N", "N", "C", "C", "C"],
    "A2": ["N", "C", "N", "C", "C", "N", "C"],
    "C1": ["N", "C", "N", "N", "C", "C", "C", "O"],
    "G1": ["N", "N", "C", "C", "C"],
    "G2": ["N", "N", "C", "N", "C", "C", "C", "O"],
    "T1": ["N", "C", "N", "C", "C", "C", "O", "O"],
    "U1": ["N", "C", "N", "C", "C", "C", "O", "O"]
}

def compute_dihedrals(coords, quadruplets):
    """
    Compute dihedral angles from general bead coordinates.

    Parameters:
    - coords: np.ndarray, shape (n_frames, n_beads, 3)
    - quadruplets: list of (i, j, k, l) index tuples referring to bead positions

    Returns:
    - angles: np.ndarray, shape (n_frames, len(quadruplets)) — dihedral angles in radians
    """
    angles = []
    for (i, j, k, l) in quadruplets:
        p1 = coords[:, i]
        p2 = coords[:, j]
        p3 = coords[:, k]
        p4 = coords[:, l]

        # bond vectors
        b1 = p2 - p1
        b2 = p3 - p2
        b3 = p4 - p3

        # normal vectors to the planes
        n1 = np.cross(b1, b2)
        n2 = np.cross(b2, b3)

        # Normalize safely
        n1 /= np.linalg.norm(n1, axis=1, keepdims=True) + 1e-8
        n2 /= np.linalg.norm(n2, axis=1, keepdims=True) + 1e-8
        b2 /= np.linalg.norm(b2, axis=1, keepdims=True) + 1e-8

        # Compute angle using atan2 of the torsion formula
        x = np.sum(n1 * n2, axis=1)
        y = np.sum(np.cross(n1, n2) * b2, axis=1)
        angle = np.arctan2(y, x)

        angles.append(angle)

    return np.stack(angles, axis=1)


# function for calculating centre of mass
def compute_com(coords, elements):
    """
    Compute center of mass from coordinates and atom types.
        - coords = A NumPy array of shape (n_atoms, 3 coordinates), each row is the 3D coordinate [x, y, z] of one atom
        - elements = A list of N atomic element symbols in the same order as the coordinates
    """
    weights = np.array([atomic_weights[e] for e in elements]) # For each atom element, look up its atomic weights from the predefined dictionary, outputs a vector of atomic masses
    return np.average(coords, axis=0, weights=weights) # centre of mass: average of the xyz coordinates weighted by atomic mass

all_paths = []  # for storing (start, end, path) triplets

# processing path files
for filename in tqdm(os.listdir(trajectory_folder)): # loop through each file in the trajectory folder that end in .mdcrd
    if not filename.endswith(".mdcrd"): # if the file does not end in .mdcrd then skip
        continue

    # loading in files and topologies
    path_file = os.path.join(trajectory_folder, filename) # create full file path from the filename and folder path
    # print(f"Processing: {filename}") # print which file is being processed

    traj = md.load_mdcrd(path_file, top=prmtop_file) #load mdcrd file path
    top = traj.topology # load the topology for the path

    n_frames = traj.n_frames # find the number of frames in the path
    coarse_grained = []  # create empty list to store the CG data, shape: (n_frames, n_beads, 3 coordinates)

    # main CG processing
    for res in top.residues: # loop over each residue 
        atom_map = {a.name: a.index for a in res.atoms} # for each atom in the residue, build a dictionary mapping atom names to their indices (used later)

        # backbone bead extraction
        if res.index == 0: # special case for 1st nucleotide: estimate P from HO5' and O5'
            ho5 = traj.xyz[:, atom_map["HO5'"], :] # extract coordinate positions of HO5'
            o5 = traj.xyz[:, atom_map["O5'"], :] # extract coordinate positions of O5'
            direction = o5 - ho5 # find direction vector between HO5' and O5'
            unit_vec = direction / np.linalg.norm(direction, axis=1, keepdims=True) # normalize the direction vector by dividing it with its length to remove the magnitude (length) of the vector so we only have the direction in the form of a unit vector
                                                                                    # np.linalg.norm is used to calculate the Euclidean length of the direction vector 
            p_bead = o5 + unit_vec * PO5_BOND_LENGTH # extending vector to match P-O5' bond length
        else:
            p_bead = traj.xyz[:, atom_map["P"], :] # if this is not the 1st residue, just extract the P atom

        # extract coordinates of specific atoms for CG beads
        # outputs 4 arrays of shape (n_frames, 3 coordinates)
        o5_bead = traj.xyz[:, atom_map["O5'"], :]
        c5_bead = traj.xyz[:, atom_map["C5'"], :]
        c4_bead = traj.xyz[:, atom_map["C4'"], :]
        c1_bead = traj.xyz[:, atom_map["C1'"], :]

        # collect the backbone as a list of 5 arrays (P plus the 4 atoms above)
        backbone = [p_bead, o5_bead, c5_bead, c4_bead, c1_bead]  # 5 arrays, each (n_frames, 3 coordinates)

        # base bead calculation setup
        resname = res.name.strip() # extract the name of the residue 
        base_beads = [] # list setup to store the calculated base beads
        base_names = []
        # base centre of mass bead calculations
        for key, atom_list in CG_bases.items(): # extract the base names and atom lists from the defined dictionary setup before
            if key.startswith(resname[0]): # resname is the current residue's name (e.g., "ADE", "CYT"). resname[0] would be the base letter: "A", "C", "G", "T", or "U". 
                                           # checks if the CG base type (eg: "A1", "A2") matches the nucleotide
                coords_list = [] # initialize lists to store the coordinates and element names
                elements_list = []
                for atom_name in atom_list: # loop over each atom in the list of predefined CG base atoms 
                    if atom_name in atom_map: # check if the name of the atom is in the atom map of the residue
                        coords_list.append(traj.xyz[:, atom_map[atom_name], :]) # extract and append the coordinates of that atom for every frame. creates a list of arrays of shape (n_frames, 3)
                        elements_list.append(top.atom(atom_map[atom_name]).element.symbol) # extract the element symbol (e.g., 'C', 'N') for that atom index and appends it to a list
                if len(coords_list) > 0: # proceed only if we successfully collected atoms from the residue
                    coords_arr = np.stack(coords_list, axis=1)  # join the list of arrays together to create a shape of (n_frames, n_atoms, 3)
                    coms = np.array([compute_com(coords_arr[i], elements_list) for i in range(n_frames)]) # for each frame of the path compute the COM for the nucleotide, returns an array of shape (n_frames, 3) — 1 or 2 COM per frame
                    base_beads.append(coms[:, np.newaxis, :]) # save value to list

        if base_beads: # if at least one base bead was computed, add it to the full CG set (an empty list would make np.concatenate fail)
            base_beads_cat = np.concatenate(base_beads, axis=1)  #  concatenate the beads together so it can be concatenated into the full bead set later
            full_bead_set = backbone + [base_beads_cat]
        else: # if no base bead is calculated it's just the backbone set
            full_bead_set = backbone

        # merge all beads into a single array per residue into shape (n_frames, N_beads, 3)
        # Each bead is a NumPy array of shape (n_frames, 3): coordinates across trajectory frames for a single bead
        reshaped_beads = [b if b.ndim == 3 else b[:, np.newaxis, :] for b in full_bead_set] # check that all bead arrays are shape (n_frames, N_beads, 3)
                                                                                            # If the bead has shape (n_frames, 3) (i.e., a single bead) insert a new axis at position 1 → becomes (n_frames, 1, 3).
        residue_beads = np.concatenate(reshaped_beads, axis=1) # concatenate beads along axis 1 (the bead dimension), outputs (n_frames, total_beads_per_residue, 3 coordinates)

        coarse_grained.append(residue_beads) # append to list

    # once all CG beads are calculated 
    # Join together the coarse-grained coordinates from each residue in the molecule along the bead dimension.
    cg_traj = np.concatenate(coarse_grained, axis=1)  # (n_frames, total_beads, 3)

    # Flatten each frame from 2D (total_beads, 3) to 1D (total_beads * 3) so each frame becomes a single vector of coordinates.
    cg_flat = cg_traj.reshape(n_frames, -1)  # (n_frames, total_beads * 3)

    # split out and save the start and end frames and the full path for dataset setup later
    start = cg_flat[0]
    end = cg_flat[-1]
    path = cg_flat

    # append to output list
    all_paths.append((start, end, path))

# ========== SAVE FOR TRAINING ==========
#np.save("cg_paths.npy", all_paths, allow_pickle=True)
print(f"Loaded {len(all_paths)} CG trajectories.")


100%|██████████| 27/27 [00:21<00:00,  1.24it/s]

Loaded 27 CG trajectories.





In [30]:
print (base_names)
len(base_beads[0])

['A1', 'G1', 'G2']


447

In [24]:
elements_list

['N', 'N', 'C', 'N', 'C', 'C', 'C', 'O']

In [None]:
import mdtraj as md
import numpy as np
import os
from tqdm import tqdm

def idx(res, atom_name):
    try:
        return next(a.index for a in res.atoms if a.name == atom_name)
    except StopIteration:
        return None

def compute_dihedrals(coords, quadruplets):
    """
    Compute dihedral angles from general bead coordinates.

    Parameters:
    - coords: np.ndarray, shape (n_frames, n_beads, 3)
    - quadruplets: list of (i, j, k, l) index tuples referring to bead positions

    Returns:
    - angles: np.ndarray, shape (n_frames, len(quadruplets)) — dihedral angles in radians
    """
    angles = []
    for (i, j, k, l) in quadruplets:
        p1 = coords[:, i]
        p2 = coords[:, j]
        p3 = coords[:, k]
        p4 = coords[:, l]

        # bond vectors
        b1 = p2 - p1
        b2 = p3 - p2
        b3 = p4 - p3

        # normal vectors to the planes
        n1 = np.cross(b1, b2)
        n2 = np.cross(b2, b3)

        # Normalize safely
        n1 /= np.linalg.norm(n1, axis=1, keepdims=True) + 1e-8
        n2 /= np.linalg.norm(n2, axis=1, keepdims=True) + 1e-8
        b2 /= np.linalg.norm(b2, axis=1, keepdims=True) + 1e-8

        # Compute angle using atan2 of the torsion formula
        x = np.sum(n1 * n2, axis=1)
        y = np.sum(np.cross(n1, n2) * b2, axis=1)
        angle = np.arctan2(y, x)

        angles.append(angle)

    return np.stack(angles, axis=1)

# Configuration
trajectory_folder = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Paths/'
prmtop_file = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/PK1/coords.prmtop'
PO5_BOND_LENGTH = 0.1615  # nm

all_paths = []

for filename in tqdm(os.listdir(trajectory_folder)):
    if not filename.endswith(".mdcrd"):
        continue

    traj = md.load_mdcrd(os.path.join(trajectory_folder, filename), top=prmtop_file)
    top = traj.topology
    atom_indices = {atom.name: atom.index for atom in top.atoms}
    n_frames = traj.n_frames

    dihedrals_all = []

    residues = list(top.residues)
    n_residues = len(residues)

    # main CG processing
    for res in top.residues: # loop over each residue 
        atom_map = {a.name: a.index for a in res.atoms} # for each atom in the residue, build a dictionary mapping atom names to their indices (used later)

        # backbone bead extraction
        if res.index == 0: # special case for 1st nucleotide: estimate P from HO5' and O5'
            ho5 = traj.xyz[:, atom_map["HO5'"], :] # extract coordinate positions of HO5'
            o5 = traj.xyz[:, atom_map["O5'"], :] # extract coordinate positions of O5'
            direction = o5 - ho5 # find direction vector between HO5' and O5'
            unit_vec = direction / np.linalg.norm(direction, axis=1, keepdims=True) # normalize the direction vector by dividing it with its length to remove the magnitude (length) of the vector so we only have the direction in the form of a unit vector
                                                                                    # np.linalg.norm is used to calculate the Euclidean length of the direction vector 
            p_bead = o5 + unit_vec * PO5_BOND_LENGTH # extending vector to match P-O5' bond length
        else:
            p_bead = traj.xyz[:, atom_map["P"], :] # if this is not the 1st residue, just extract the P atom

        # extract coordinates of specific atoms for CG beads
        # outputs 4 arrays of shape (n_frames, 3 coordinates)
        O5 = traj.xyz[:, atom_map["O5'"], :] 
        C5 = traj.xyz[:, atom_map["C5'"], :]
        R4 = traj.xyz[:, atom_map["C4'"], :]
        R1 = traj.xyz[:, atom_map["C1'"], :]

        # base bead calculation setup
        resname = res.name.strip() # extract the name of the residue 
        base_beads = [] # list setup to store the calculated base beads
        base_names = []
        # base centre of mass bead calculations
        for key, atom_list in CG_bases.items(): # extract the base names and atom lists from the defined dictionary setup before
            if key.startswith(resname[0]): # resname is the current residue's name (e.g., "ADE", "CYT"). resname[0] would be the base letter: "A", "C", "G", "T", or "U". 
                                        # checks if the CG base type (eg: "A1", "A2") matches the nucleotide
                coords_list = [] # initialize lists to store the coordinates and element names
                elements_list = []
                for atom_name in atom_list: # loop over each atom in the list of predefined CG base atoms 
                    if atom_name in atom_map: # check if the name of the atom is in the atom map of the residue
                        coords_list.append(traj.xyz[:, atom_map[atom_name], :]) # extract and append the coordinates of that atom for every frame. creates a list of arrays of shape (n_frames, 3)
                        elements_list.append(top.atom(atom_map[atom_name]).element.symbol) # extract the element symbol (e.g., 'C', 'N') for that atom index and appends it to a list
                        for key2, atom_list2 in CG_bases_simple.items():
                            if elements_list == atom_list2:
                                base_names.append(key2)
                if len(coords_list) > 0: # proceed only if we successfully collected atoms from the residue
                    coords_arr = np.stack(coords_list, axis=1)  # join the list of arrays together to create a shape of (n_frames, n_atoms, 3)
                    coms = np.array([compute_com(coords_arr[i], elements_list) for i in range(n_frames)]) # for each frame of the path compute the COM for the nucleotide, returns an array of shape (n_frames, 3) — 1 or 2 COM per frame
                    base_beads.append(coms[:, np.newaxis, :]) # save value to list

        # --- Dihedral Atom Selections ---
        # r0 and r1 were never defined in this cell. Assumption: both refer to the current residue,
        # so every entry below is an atom index into traj.xyz. Under this assumption sets 2 and 3
        # repeat C1' and give a degenerate angle; the bead-based cell further down uses the base bead for X1.
        r0 = r1 = res
        atom_sets = [
            # 0: R4-R1-A1/G1-A2/G2
            [idx(r0, "C4'"), idx(r0, "C1'"), idx(r1, "N9") or idx(r1, "N1"), idx(r1, "C4")],
            # 1: R4-A1/G1-A2/G2-R1
            [idx(r0, "C4'"), idx(r1, "N9") or idx(r1, "N1"), idx(r1, "C4"), idx(r1, "C1'")],
            # 2: C-R4-R1-X1
            [idx(r0, "C5'"), idx(r0, "C4'"), idx(r0, "C1'"), idx(r1, "C1'")],
            # 3: P-R4-R1-X1
            [idx(r0, "P"), idx(r0, "C4'"), idx(r0, "C1'"), idx(r1, "C1'")],
            # 4: C-R4-P-O
            [idx(r0, "C5'"), idx(r0, "C4'"), idx(r0, "P"), idx(r0, "O5'")],
            # 5: R1-R4-P-O
            [idx(r0, "C1'"), idx(r0, "C4'"), idx(r0, "P"), idx(r0, "O5'")],
            # 6: O-C-R4-P
            [idx(r0, "O5'"), idx(r0, "C5'"), idx(r0, "C4'"), idx(r0, "P")],
            # 7: O-C-R4-R1
            [idx(r0, "O5'"), idx(r0, "C5'"), idx(r0, "C4'"), idx(r0, "C1'")],
            # 8: P-O-C-R4
            [idx(r0, "P"), idx(r0, "O5'"), idx(r0, "C5'"), idx(r0, "C4'")],
            # 9: R4-P-O-C
            [idx(r0, "C4'"), idx(r0, "P"), idx(r0, "O5'"), idx(r0, "C5'")],
        ]

        frame_dihedrals = []

        for atom_quad in atom_sets:
            if None in atom_quad:
                frame_dihedrals.append(np.full(n_frames, np.nan))  # Placeholder for missing data
            else:
                # compute_dihedrals takes the coordinates plus a list of index quadruplets and returns (n_frames, 1)
                angle = compute_dihedrals(traj.xyz, [tuple(atom_quad)])[:, 0]  # (n_frames,)
                frame_dihedrals.append(angle)

        dihedrals_all.append(np.stack(frame_dihedrals, axis=1))  # (n_frames, 10)

    if len(dihedrals_all) == 0:
        continue

    traj_dihedrals = np.stack(dihedrals_all, axis=1)  # (n_frames, n_residues_used, 10)
    traj_flat = traj_dihedrals.reshape(n_frames, -1)  # (n_frames, n_residues_used * 10)

    # Save (start, end, path)
    start = traj_flat[0]
    end = traj_flat[-1]
    path = traj_flat
    all_paths.append((start, end, path))

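# Save paths
# all_paths is ragged (paths differ in length), so store it in an explicit object array;
# recent NumPy versions refuse to convert a ragged list implicitly
paths_arr = np.empty(len(all_paths), dtype=object)
for i, triplet in enumerate(all_paths):
    paths_arr[i] = triplet
np.save("cg_dihedral_paths.npy", paths_arr, allow_pickle=True)
print(f"Loaded {len(all_paths)} trajectories with internal coordinates.")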


In [1]:
import mdtraj as md
import numpy as np
import os
from tqdm import tqdm

# Configuration
trajectory_folder = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Paths/'
prmtop_file = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/PK1/coords.prmtop'
PO5_BOND_LENGTH = 0.1615  # nm

# Atomic weights for COM calculations
atomic_weights = {"C": 12.011, "N": 14.007, "O": 16.000}

# Nucleotide base definitions
CG_bases = {
    "A1": ["N7", "N9", "C4", "C5", "C8"],
    "A2": ["N1", "C2", "N3", "C4", "C5", "N6", "C6"],
    "C1": ["N1", "C2", "N3", "N4", "C4", "C5", "C6", "O2"],
    "G1": ["N7", "N9", "C4", "C5", "C8"],
    "G2": ["N1", "N2", "C2", "N3", "C4", "C5", "C6", "O6"],
    "T1": ["N1", "C2", "N3", "C4", "C5", "C6", "O2", "O4"],
    "U1": ["N1", "C2", "N3", "C4", "C5", "C6", "O2", "O4"]
}

def compute_dihedrals(coords, quadruplets):
    """
    Compute dihedral angles from bead coordinates.
        - coords: A numpy array of shape (n_frames, n_beads, 3) containing 3D coordinates of beads across multiple frames.
        - quadruplets: A list of tuples (i, j, k, l), that defines the indices of 4 beads needed to compute a dihedral angle.
    """
    angles = [] # list to store computed dihedrals
    for (i, j, k, l) in quadruplets: # iterate over each set of 4 bead indices (i, j, k, l) in the predefined quadruplets
        # extract the coordinates of the 4 beads
        p1 = coords[:, i]
        p2 = coords[:, j]
        p3 = coords[:, k]
        p4 = coords[:, l]

        # Calculates the vectors connecting consecutive beads I.E the bond vectors 
        b1 = p2 - p1 # vector from bead i to j.
        b2 = p3 - p2 # vector from bead j to k.
        b3 = p4 - p3 # vector from bead k to l.

        # cross product to find normals to the two planes
        n1 = np.cross(b1, b2) # normal to the plane formed by (p1, p2, p3).
        n2 = np.cross(b2, b3) # normal to the plane formed by (p2, p3, p4).

        # Normalizes n1, n2, and b2 to unit vectors to avoid scaling artifacts and remove lengths, 1e-8 is added to prevent division-by-zero and floating-point instability when normalizing near-zero vectors
        n1_norm = np.linalg.norm(n1, axis=1, keepdims=True) + 1e-8 
        n2_norm = np.linalg.norm(n2, axis=1, keepdims=True) + 1e-8
        b2_norm = np.linalg.norm(b2, axis=1, keepdims=True) + 1e-8

        n1 /= n1_norm
        n2 /= n2_norm
        b2 /= b2_norm

        # Compute dihedral using atan2
        x = np.sum(n1 * n2, axis=1) # dot product of n1 and n2 (cosine of the angle between normals).
        y = np.sum(np.cross(n1, n2) * b2, axis=1) # Projection of the cross product (n1 × n2) onto b2 (sine of the angle).
        angle = np.arctan2(y, x) # Uses arctan2(y, x) to compute the dihedral angle in the range [-π, π].
        angles.append(angle) # append to storage list
        
    return np.stack(angles, axis=1) # convert the list of angles (one per quadruplet) into a 2D NumPy array of shape (n_frames, n_quadruplets).

def compute_com(coords, elements):
    """Compute center of mass from coordinates and elements."""
    weights = np.array([atomic_weights[e] for e in elements])
    return np.average(coords, axis=0, weights=weights)

# Quadruplet definitions for dihedrals (0-based indexing)
PURINE_QUADRUPLETS = [  # For A, G (10 dihedrals)
    (3, 4, 5, 6),   # 0: R4-R1-A1/G1-A2/G2
    (3, 5, 6, 4),   # 1: R4-A1/G1-A2/G2-R1
    (2, 3, 4, 5),   # 2: C-R4-R1-X1
    (0, 3, 4, 5),   # 3: P-R4-R1-X1
    (2, 3, 0, 1),   # 4: C-R4-P-O
    (4, 3, 0, 1),   # 5: R1-R4-P-O
    (1, 2, 3, 0),   # 6: O-C-R4-P
    (1, 2, 3, 4),   # 7: O-C-R4-R1
    (0, 1, 2, 3),   # 8: P-O-C-R4
    (3, 0, 1, 2)    # 9: R4-P-O-C
]

PYRIMIDINE_QUADRUPLETS = [  # For C, T, U (8 dihedrals, skipping 0 and 1)
    (2, 3, 4, 5),   # 2: C-R4-R1-X1 
    (0, 3, 4, 5),   # 3: P-R4-R1-X1 
    (2, 3, 0, 1),   # 4: C-R4-P-O 
    (4, 3, 0, 1),   # 5: R1-R4-P-O 
    (1, 2, 3, 0),   # 6: O-C-R4-P 
    (1, 2, 3, 4),   # 7: O-C-R4-R1 
    (0, 1, 2, 3),   # 8: P-O-C-R4 
    (3, 0, 1, 2)    # 9: R4-P-O-C 
]

all_paths = []

# Process each trajectory
for filename in tqdm(os.listdir(trajectory_folder)):
    if not filename.endswith(".mdcrd"):
        continue

    path_file = os.path.join(trajectory_folder, filename)
    traj = md.load_mdcrd(path_file, top=prmtop_file)
    top = traj.topology
    n_frames = traj.n_frames
    
    dihedral_trajs = []  # Store dihedrals per residue per frame

    for res in top.residues:
        atom_map = {a.name: a.index for a in res.atoms}
        resname = res.name.strip()

        # Backbone bead calculation
        if res.index == 0:
            ho5 = traj.xyz[:, atom_map["HO5'"], :]
            o5 = traj.xyz[:, atom_map["O5'"], :]
            direction = o5 - ho5
            unit_vec = direction / (np.linalg.norm(direction, axis=1, keepdims=True) + 1e-8)
            p_bead = o5 + unit_vec * PO5_BOND_LENGTH
        else:
            p_bead = traj.xyz[:, atom_map["P"], :]

        o5_bead = traj.xyz[:, atom_map["O5'"], :]
        c5_bead = traj.xyz[:, atom_map["C5'"], :]
        c4_bead = traj.xyz[:, atom_map["C4'"], :]
        c1_bead = traj.xyz[:, atom_map["C1'"], :]
        backbone = [p_bead, o5_bead, c5_bead, c4_bead, c1_bead] # bead indices are fixed here (indices 0-4) which allows for predefined quadruplets 

        # Base beads
        base_beads = []
        for key, atom_list in CG_bases.items():
            if key.startswith(resname[0]):
                coords_list = []
                elements_list = []
                for atom_name in atom_list:
                    if atom_name in atom_map:
                        coords_list.append(traj.xyz[:, atom_map[atom_name], :])
                        elements_list.append(top.atom(atom_map[atom_name]).element.symbol)
                if coords_list:
                    coords_arr = np.stack(coords_list, axis=1)
                    coms = np.array([compute_com(coords_arr[i], elements_list) for i in range(n_frames)])
                    base_beads.append(coms)

        # Combine all beads for residue
        if base_beads:
            full_beads = backbone + [bead[:, np.newaxis, :] for bead in base_beads]
        else:
            full_beads = backbone
            
        reshaped_beads = [b if b.ndim == 3 else b[:, np.newaxis, :] for b in full_beads]
        residue_beads = np.concatenate(reshaped_beads, axis=1)  # (n_frames, n_beads, 3)
        # print("Bead order:", ["P", "O5'", "C5'", "C4'", "C1'", "A1", "A2"][:len(full_beads)])

        # Select quadruplets based on residue type
        if resname[0] in ['A', 'G']:  # Purines: 10 dihedrals
            dihedrals = compute_dihedrals(residue_beads, PURINE_QUADRUPLETS)
        else:  # Pyrimidines: 8 dihedrals (skip 0 and 1)
            dihedrals = compute_dihedrals(residue_beads, PYRIMIDINE_QUADRUPLETS)
            
        dihedral_trajs.append(dihedrals)

    # Combine dihedrals across residues
    dihedrals_combined = np.concatenate(dihedral_trajs, axis=1)  # (n_frames, n_dihedrals_total)
    
    # Extract start, end, and path
    start = dihedrals_combined[0]
    end = dihedrals_combined[-1]
    path = dihedrals_combined
    
    all_paths.append((start, end, path))

# Save results
# np.save("dihedral_paths.npy", all_paths, allow_pickle=True)
print(f"Processed {len(all_paths)} trajectories. Dihedrals per frame: {all_paths[0][2].shape[1]}")

# check that the dihedrals are consistent across the data samples
n_dihedrals = all_paths[0][2].shape[1]  # Get from the first sample's path
assert all(path.shape[1] == n_dihedrals for (_, _, path) in all_paths), \
       "Inconsistent dihedral counts!"

100%|██████████| 27/27 [00:20<00:00,  1.30it/s]

Processed 27 trajectories. Dihedrals per frame: 198
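The torsion formula used in `compute_dihedrals` can be checked on a geometry whose dihedral is known in advance. The sketch below re-implements the same formula for a single quadruplet (the name `dihedral` and the four-point test geometry are illustrative assumptions) and compares the result with the angle used to build the points.

```python
import numpy as np

def dihedral(coords, quad):
    """Dihedral angle in radians for one index quadruplet; coords is (n_frames, n_points, 3)."""
    i, j, k, l = quad
    b1 = coords[:, j] - coords[:, i]
    b2 = coords[:, k] - coords[:, j]
    b3 = coords[:, l] - coords[:, k]
    # Unit normals of the two planes and unit vector along the central bond
    n1 = np.cross(b1, b2)
    n2 = np.cross(b2, b3)
    n1 = n1 / (np.linalg.norm(n1, axis=1, keepdims=True) + 1e-8)
    n2 = n2 / (np.linalg.norm(n2, axis=1, keepdims=True) + 1e-8)
    b2 = b2 / (np.linalg.norm(b2, axis=1, keepdims=True) + 1e-8)
    x = np.sum(n1 * n2, axis=1)
    y = np.sum(np.cross(n1, n2) * b2, axis=1)
    return np.arctan2(y, x)

def four_point_frame(phi):
    """Four points whose dihedral about the z-axis bond equals phi."""
    return np.array([
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
        [np.cos(phi), np.sin(phi), 1.0],
    ])

expected = np.array([0.0, np.pi / 2, -np.pi / 2, 1.0])
frames = np.stack([four_point_frame(phi) for phi in expected])
angles = dihedral(frames, (0, 1, 2, 3))
print(np.allclose(angles, expected, atol=1e-6))
```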





The structure of the output is as follows:
the number of path files is len(all_paths)
    each element of all_paths is a tuple of 3 elements
    1st element is the coordinates of the 1st frame
    2nd element is the coordinates of the last frame
    3rd element is the coordinates of every frame in the path (including the first and last)

For the all-atom representation, each frame should hold 708 (atoms) * 3 (coordinates) = 2124 elements, one per atom coordinate
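The layout described above can be verified programmatically. The sketch below uses synthetic data in place of the real trajectories; `summarize_paths` and `toy_paths` are hypothetical names, and the 708-atom frame size is taken from the all-atom case.

```python
import numpy as np

def summarize_paths(all_paths):
    """Check the (start, end, path) layout and report the size of each path."""
    summary = []
    for start, end, path in all_paths:
        # The endpoints must match the first and last rows of the full path
        assert np.array_equal(start, path[0])
        assert np.array_equal(end, path[-1])
        summary.append({"n_frames": path.shape[0], "n_features": path.shape[1]})
    return summary

# Synthetic stand-in for the real data: two paths of different lengths
rng = np.random.default_rng(0)
toy_paths = []
for n_frames in (5, 8):
    path = rng.standard_normal((n_frames, 708 * 3))
    toy_paths.append((path[0], path[-1], path))

summary = summarize_paths(toy_paths)
```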

In [6]:
len(all_paths) # number of paths: one (start, end, path) tuple per .mdcrd file

test = all_paths[1] # the (start, end, path) tuple of the second path
len(test) # 3 elements: start frame, end frame and full path

test2 = test[1] # end frame of that path, a flat feature vector
len(test2) # number of features per frame

330

### Building a PyTorch dataset from the MDCRD data

In [2]:
import torch
import numpy as np

def normalize_dihedrals(dihedrals):
    """
    Normalize dihedral angles from [-π, π] to [-1, 1].
    Args:
        dihedrals: Tensor or array of shape (..., n_dihedrals) in radians.
    Returns:
        Normalized dihedrals in [-1, 1].
    """
    return dihedrals / np.pi

# Example usage (for your dataset/dataloader):
# Assuming 'all_paths' contains (start, end, path) tuples with dihedrals in radians
normalized_paths = []
for start, end, path in all_paths:
    norm_start = normalize_dihedrals(start)  # (n_dihedrals,)
    norm_end = normalize_dihedrals(end)      # (n_dihedrals,)
    norm_path = normalize_dihedrals(path)    # (n_frames, n_dihedrals)
    normalized_paths.append((norm_start, norm_end, norm_path))

In [3]:
# While training a model, we typically want to pass samples in batches and reshuffle the data at every epoch to reduce model overfitting
# DataLoader is an iterable that abstracts this complexity in an easy API.
from Landscape_DDPM import MolecularPathDataset, collate_paths
from torch.utils.data import DataLoader

dataset = MolecularPathDataset(normalized_paths) # use the normalized dihedral paths, not the all-atom coordinates
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_paths)

In [3]:
# how to access the dataset
sample = dataset[26]
print("Start shape:", sample['start'].shape)
print("End shape:", sample['end'].shape)
print("Path shape:", sample['path'].shape)

Start shape: torch.Size([396])
End shape: torch.Size([396])
Path shape: torch.Size([447, 396])


In [15]:
test = collate_paths(dataset)
print("Path shape:", test['path'].shape)
print(test['mask'])

Path shape: torch.Size([27, 769, 2124])
tensor([[1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        ...,
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.]])
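
The output above shows paths padded to a common length (769 frames) with a mask that is 1 for real frames and 0 for padding. The sketch below shows one way such a collate function could be written. `pad_collate` is a hypothetical stand-in, not the actual `collate_paths` from `Landscape_DDPM`, and the dictionary keys are assumed from the printed samples.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    """Hypothetical collate: zero-pad variable-length paths and build a frame mask."""
    starts = torch.stack([item['start'] for item in batch])   # (B, F)
    ends = torch.stack([item['end'] for item in batch])       # (B, F)
    paths = [item['path'] for item in batch]                  # each (T_i, F)
    lengths = torch.tensor([p.shape[0] for p in paths])       # (B,)
    padded = pad_sequence(paths, batch_first=True)            # (B, T_max, F), zero-padded
    # mask[b, t] is 1.0 where frame t is real data and 0.0 where it is padding
    frame_idx = torch.arange(padded.shape[1])
    mask = (frame_idx[None, :] < lengths[:, None]).float()    # (B, T_max)
    return {'start': starts, 'end': ends, 'path': padded, 'mask': mask}

batch = [
    {'start': torch.zeros(4), 'end': torch.ones(4), 'path': torch.ones(3, 4)},
    {'start': torch.zeros(4), 'end': torch.ones(4), 'path': torch.ones(5, 4)},
]
out = pad_collate(batch)
print(out['path'].shape, out['mask'][0])
```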


In [None]:
sample = next(iter(dataloader))
F = sample['start'].shape[-1]
n_atoms = F // 3
print(F)
print(n_atoms)

396
132


In [7]:
from Landscape_DDPM import GaussianDiffusion, Trainer, UNet

# define the U-net structure
#n_atoms = 708
#F = n_atoms * 3
#sample = next(iter(dataloader))
#F = all_paths[0][2].shape[1]
#n_atoms = F // 3

# Get dihedral dimensions from the first sample
sample = next(iter(dataloader))
n_dihedrals = sample['path'].shape[-1]  # Shape is (B, T, n_dihedrals)

model = UNet(
    input_dim=n_dihedrals, # number of dihedral features per frame
    base_dim=64,
    dim_mults=(1, 2, 2, 4),
    time_emb_dim=128,
    out_dim=None         # same as input_dim by default
)


diffusion_model = GaussianDiffusion(
    model,                        # U-net model
    timesteps = 100,             # number of diffusion steps 
    loss_type='periodic'
)

trainer = Trainer(
    diffusion=diffusion_model,           # Your GaussianDiffusion instance
    dataloader=dataloader,               # From your earlier collate_paths function
    ema_decay=0.995,
    learning_rate=1e-5,
    results_folder='C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Diffusion_model/Models',
    save_name='molecular_path_diffusion_test.pt',
    use_amp=False
)
trainer.train(num_epochs=10)

Epoch 1/10: 100%|██████████| 4/4 [00:01<00:00,  3.07it/s, loss=65.8]
Epoch 2/10: 100%|██████████| 4/4 [00:01<00:00,  3.23it/s, loss=65.6]
Epoch 3/10: 100%|██████████| 4/4 [00:01<00:00,  3.20it/s, loss=65.5]
Epoch 4/10: 100%|██████████| 4/4 [00:01<00:00,  3.18it/s, loss=65.6]
Epoch 5/10: 100%|██████████| 4/4 [00:01<00:00,  2.88it/s, loss=65.6]
Epoch 6/10: 100%|██████████| 4/4 [00:01<00:00,  2.77it/s, loss=65.5]
Epoch 7/10: 100%|██████████| 4/4 [00:01<00:00,  3.02it/s, loss=65.7]
Epoch 8/10: 100%|██████████| 4/4 [00:01<00:00,  3.23it/s, loss=65.8]
Epoch 9/10: 100%|██████████| 4/4 [00:01<00:00,  3.05it/s, loss=65.7]
Epoch 10/10: 100%|██████████| 4/4 [00:01<00:00,  3.23it/s, loss=65.7]

Model saved to C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Diffusion_model/Models\molecular_path_diffusion_test.pt
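
The run above uses `loss_type='periodic'`, whose implementation is not shown here. As a rough illustration only, a periodic loss for angles scaled to [-1, 1] could wrap the prediction error so that values near -1 and +1 count as close, and could use the padding mask to ignore padded frames. The function name `periodic_mse` and its signature are assumptions, not the `Landscape_DDPM` API.

```python
import torch

def periodic_mse(pred, target, mask=None, period=2.0):
    """Hypothetical periodic MSE for angles normalized to [-1, 1] (period 2).

    pred, target: (B, T, F); mask: optional (B, T) with 1 for real frames.
    """
    diff = pred - target
    # Wrap the error into [-period/2, period/2) so 0.95 and -0.95 are 0.1 apart
    wrapped = torch.remainder(diff + period / 2, period) - period / 2
    sq = wrapped ** 2
    if mask is None:
        return sq.mean()
    m = mask.unsqueeze(-1)                                  # (B, T, 1)
    n_valid = (m.sum() * sq.shape[-1]).clamp(min=1.0)       # count of unmasked elements
    return (sq * m).sum() / n_valid

pred = torch.tensor([[[0.95]]])
target = torch.tensor([[[-0.95]]])
print(periodic_mse(pred, target))   # close to 0.01, not 1.9 ** 2
```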





In [8]:
import torch
from Landscape_DDPM import GaussianDiffusion, Trainer, UNet
# Load model state dict
# define the U-net structure
# for all atom coordinate model
#n_atoms = 708
#F = n_atoms * 3

#for CG model
#sample = next(iter(dataloader))
#F = sample['start'].shape[-1]
#n_atoms = F // 3

# Get dihedral dimensions from the first sample
sample = next(iter(dataloader))
n_dihedrals = sample['path'].shape[-1]  # Shape is (B, T, n_dihedrals)

model = UNet(
    input_dim=n_dihedrals, # number of dihedral features per frame
    base_dim=64,
    dim_mults=(1, 2, 2, 4),
    time_emb_dim=128,
    out_dim=None         # same as input_dim by default
)

checkpoint = torch.load("C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Diffusion_model/Models/molecular_path_diffusion_test.pt")
model.load_state_dict(checkpoint['ema'])

# Load diffusion wrapper
diffusion = GaussianDiffusion(model, timesteps=100)

In [None]:
# FOR ALL ATOM COORDINATE MODEL
import mdtraj as md
import torch
# Configuration
topology_file = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/PK1/coords.prmtop'

traj = md.load_mdcrd('C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Paths/path5.mdcrd' , top=topology_file) #load the trajectory into python

coords = traj.xyz #get the xyz (Cartesian) coordinates of the trajectory as a numpy array
flattened = coords.reshape(coords.shape[0], -1) #flatten each frame into 1D so the diffusion U-net model can accept it

#convert into tensors
start = torch.tensor(flattened[0]).float()
end = torch.tensor(flattened[-1]).float()
#add a leading batch dimension: the model expects batched input and we have a single sample
start = start.unsqueeze(0)  # (1, F)
end = end.unsqueeze(0)      # (1, F)
print("start shape:", start.shape)  # Should be (1, F)
print("end shape:", end.shape)      # Should be (1, F)

#generate path
generated_path = diffusion.sample(model, start, end, frames=50, device = 'cpu')  # (1, 50, F)

In [None]:
# FOR CG MODEL
import mdtraj as md
import numpy as np
import torch

# Config
topology_file = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/PK1/coords.prmtop'
PO5_BOND_LENGTH = 0.1615  # nm
atomic_weights = {"C": 12.011, "N": 14.007, "O": 16.000}
CG_bases = {
    "A1": ["N7", "N9", "C4", "C5", "C8"],
    "A2": ["N1", "C2", "N3", "C4", "C5", "N6", "C6"],
    "C1": ["N1", "C2", "N3", "N4", "C4", "C5", "C6", "O2"],
    "G1": ["N7", "N9", "C4", "C5", "C8"],
    "G2": ["N1", "N2", "C2", "N3", "C4", "C5", "C6", "O6"],
    "T1": ["N1", "C2", "N3", "C4", "C5", "C6", "O2", "O4"],
    "U1": ["N1", "C2", "N3", "C4", "C5", "C6", "O2", "O4"]
}

def compute_com(coords, elements):
    weights = np.array([atomic_weights[e] for e in elements])
    return np.average(coords, axis=0, weights=weights)

# === Load and Coarse-Grain a Single Path ===
traj = md.load_mdcrd('C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Paths/path5.mdcrd', top=topology_file)
top = traj.topology
n_frames = traj.n_frames

cg_beads = []
# main CG processing
for res in top.residues: # loop over each residue 
    atom_map = {a.name: a.index for a in res.atoms} # map each atom name in the residue to its index, used for the lookups below

    # backbone bead extraction
    if res.index == 0: # special case for the 1st nucleotide: estimate P from HO5' and O5'
        ho5 = traj.xyz[:, atom_map["HO5'"], :] # coordinates of HO5' in every frame
        o5 = traj.xyz[:, atom_map["O5'"], :] # coordinates of O5' in every frame
        direction = o5 - ho5 # direction vector from HO5' to O5'
        unit_vec = direction / np.linalg.norm(direction, axis=1, keepdims=True) # divide by the Euclidean length (np.linalg.norm) to get a unit vector that keeps only the direction
        p_bead = o5 + unit_vec * PO5_BOND_LENGTH # extend from O5' along that direction by the P-O5' bond length
    else:
        p_bead = traj.xyz[:, atom_map["P"], :] # any other residue: just extract the P atom

    # extract coordinates of specific atoms for CG beads
    # outputs 5 arrays with (n_frames, 3 coordinates)
    o5_bead = traj.xyz[:, atom_map["O5'"], :] 
    c5_bead = traj.xyz[:, atom_map["C5'"], :]
    c4_bead = traj.xyz[:, atom_map["C4'"], :]
    c1_bead = traj.xyz[:, atom_map["C1'"], :]

    # collect backbone by setting up a list of 5 arrays 
    backbone = [p_bead, o5_bead, c5_bead, c4_bead, c1_bead]  # list of 5 arrays, each (n_frames, 3 coordinates)
    
    # Base COMs
    resname = res.name.strip() # extract the name of the residue 
    base_coms = []
    # base centre of mass bead calculations
    for key, atom_list in CG_bases.items(): # extract the base names and atom lists from the defined dictionary setup before
        if key.startswith(resname[0]): # resname is the current residue's name (e.g., "ADE", "CYT"). resname[0] would be the base letter: "A", "C", "G", "T", or "U". 
                                        # checks if the CG base type (eg: "A1", "A2") matches the nucleotide
            coords_list = [] # initialize lists to store the coordinates and element names
            elements_list = []
            for atom_name in atom_list: # loop over each atom in the list of predefined CG base atoms 
                if atom_name in atom_map: # check if the name of the atom is in the atom map of the residue
                    coords_list.append(traj.xyz[:, atom_map[atom_name], :]) # extract and append the coordinates of that atom for every frame. creates a list of arrays of shape (n_frames, 3)
                    elements_list.append(top.atom(atom_map[atom_name]).element.symbol) # extract the element symbol (e.g., 'C', 'N') for that atom index and appends it to a list
            if len(coords_list) > 0: # proceed only if we successfully collected atoms from the residue
                coords_arr = np.stack(coords_list, axis=1)  # join the list of arrays together to create a shape of (n_frames, n_atoms, 3)
                coms = np.array([compute_com(coords_arr[i], elements_list) for i in range(n_frames)]) # mass-weighted COM of this base bead in every frame, shape (n_frames, 3); purines yield two such beads, pyrimidines one
                base_coms.append(coms[:, np.newaxis, :]) # store as (n_frames, 1, 3)

    if base_coms: # non-empty list: at least one base bead was computed, so add it to the full CG set
        base_beads_cat = np.concatenate(base_coms, axis=1)  # concatenate the base beads so they can be joined to the full bead set later
        full_bead_set = backbone + [base_beads_cat]
    else: # if no base bead is calculated it's just the backbone set
        full_bead_set = backbone

    # merge all beads into a single array per residue into shape (n_frames, N_beads, 3)
    # Each bead is a NumPy array of shape (n_frames, 3): coordinates across trajectory frames for a single bead
    reshaped_beads = [b if b.ndim == 3 else b[:, np.newaxis, :] for b in full_bead_set] # check that all bead arrays are shape (n_frames, N_beads, 3)
                                                                                        # If the bead has shape (n_frames, 3) (i.e., a single bead) insert a new axis at position 1 → becomes (n_frames, 1, 3).
    residue_beads = np.concatenate(reshaped_beads, axis=1) # concatenate beads along axis 1 (the bead dimension), giving (n_frames, total_beads_per_residue, 3 coordinates)

    cg_beads.append(residue_beads) # append to list

# Join together the coarse-grained coordinates from each residue in the molecule along the bead dimension.
cg_traj = np.concatenate(cg_beads, axis=1)  # (n_frames, total_beads, 3)

# Flatten each frame from 2D (total_beads, 3) to 1D (total_beads * 3) so each frame becomes a single vector of coordinates.
cg_flat = cg_traj.reshape(n_frames, -1)  # (n_frames, total_beads * 3)

# Extract start and end structure and convert to tensors for inference
start = torch.tensor(cg_flat[0], dtype=torch.float32).unsqueeze(0)
end = torch.tensor(cg_flat[-1], dtype=torch.float32).unsqueeze(0)

# run inference 
generated_path = diffusion.sample(model, start, end, frames=50, device='cpu')  # (1, 50, F)


In [10]:
import mdtraj as md
import numpy as np
import os
import torch
from tqdm import tqdm

# Configuration
trajectory_folder = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Paths/'
prmtop_file = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/PK1/coords.prmtop'
PO5_BOND_LENGTH = 0.1615  # nm

# Atomic weights for COM calculations
atomic_weights = {"C": 12.011, "N": 14.007, "O": 16.000}

# Nucleotide base definitions
CG_bases = {
    "A1": ["N7", "N9", "C4", "C5", "C8"],
    "A2": ["N1", "C2", "N3", "C4", "C5", "N6", "C6"],
    "C1": ["N1", "C2", "N3", "N4", "C4", "C5", "C6", "O2"],
    "G1": ["N7", "N9", "C4", "C5", "C8"],
    "G2": ["N1", "N2", "C2", "N3", "C4", "C5", "C6", "O6"],
    "T1": ["N1", "C2", "N3", "C4", "C5", "C6", "O2", "O4"],
    "U1": ["N1", "C2", "N3", "C4", "C5", "C6", "O2", "O4"]
}

def compute_dihedrals(coords, quadruplets):
    """
    Compute dihedral angles from bead coordinates.
        - coords: A numpy array of shape (n_frames, n_beads, 3) containing 3D coordinates of beads across multiple frames.
        - quadruplets: A list of tuples (i, j, k, l), that defines the indices of 4 beads needed to compute a dihedral angle.
    """
    angles = [] # list to store computed dihedrals
    for (i, j, k, l) in quadruplets: # iterate over each set of 4 bead indices (i, j, k, l) in the predefined quadruplets
        # extract the coordinates of the 4 beads
        p1 = coords[:, i]
        p2 = coords[:, j]
        p3 = coords[:, k]
        p4 = coords[:, l]

        # Calculate the vectors connecting consecutive beads, i.e. the bond vectors
        b1 = p2 - p1 # vector from bead i to j.
        b2 = p3 - p2 # vector from bead j to k.
        b3 = p4 - p3 # vector from bead k to l.

        # cross product to find normals to the two planes
        n1 = np.cross(b1, b2) # normal to the plane formed by (p1, p2, p3).
        n2 = np.cross(b2, b3) # normal to the plane formed by (p2, p3, p4).

        # Normalizes n1, n2, and b2 to unit vectors to avoid scaling artifacts and remove lengths, 1e-8 is added to prevent division-by-zero and floating-point instability when normalizing near-zero vectors
        n1_norm = np.linalg.norm(n1, axis=1, keepdims=True) + 1e-8 
        n2_norm = np.linalg.norm(n2, axis=1, keepdims=True) + 1e-8
        b2_norm = np.linalg.norm(b2, axis=1, keepdims=True) + 1e-8

        n1 /= n1_norm
        n2 /= n2_norm
        b2 /= b2_norm

        # Compute dihedral using atan2
        x = np.sum(n1 * n2, axis=1) # dot product of n1 and n2 (cosine of the angle between normals).
        y = np.sum(np.cross(n1, n2) * b2, axis=1) # projection of the cross product (n1 × n2) onto b2 (sine of the angle).
        angle = np.arctan2(y, x) # Uses arctan2(y, x) to compute the dihedral angle in the range [-π, π].
        angles.append(angle) # append to storage list
        
    return np.stack(angles, axis=1) # convert the list of angles (one per quadruplet) into a 2D NumPy array of shape (n_frames, n_quadruplets).

def compute_com(coords, elements):
    """Compute center of mass from coordinates and elements."""
    weights = np.array([atomic_weights[e] for e in elements])
    return np.average(coords, axis=0, weights=weights)

# Quadruplet definitions for dihedrals (0-based indexing)
PURINE_QUADRUPLETS = [  # For A, G (10 dihedrals)
    (3, 4, 5, 6),   # 0: R4-R1-A1/G1-A2/G2
    (3, 5, 6, 4),   # 1: R4-A1/G1-A2/G2-R1
    (2, 3, 4, 5),   # 2: C-R4-R1-X1
    (0, 3, 4, 5),   # 3: P-R4-R1-X1
    (2, 3, 0, 1),   # 4: C-R4-P-O
    (4, 3, 0, 1),   # 5: R1-R4-P-O
    (1, 2, 3, 0),   # 6: O-C-R4-P
    (1, 2, 3, 4),   # 7: O-C-R4-R1
    (0, 1, 2, 3),   # 8: P-O-C-R4
    (3, 0, 1, 2)    # 9: R4-P-O-C
]

PYRIMIDINE_QUADRUPLETS = [  # For C, T, U (8 dihedrals, skipping 0 and 1)
    (2, 3, 4, 5),   # 2: C-R4-R1-X1 
    (0, 3, 4, 5),   # 3: P-R4-R1-X1 
    (2, 3, 0, 1),   # 4: C-R4-P-O 
    (4, 3, 0, 1),   # 5: R1-R4-P-O 
    (1, 2, 3, 0),   # 6: O-C-R4-P 
    (1, 2, 3, 4),   # 7: O-C-R4-R1 
    (0, 1, 2, 3),   # 8: P-O-C-R4 
    (3, 0, 1, 2)    # 9: R4-P-O-C 
]

# Process a single trajectory (path5) from the trajectory folder
traj = md.load_mdcrd(os.path.join(trajectory_folder, 'path5.mdcrd'), top=prmtop_file)
top = traj.topology
n_frames = traj.n_frames

dihedral_trajs = []  # Store dihedrals per residue per frame

for res in top.residues:
    atom_map = {a.name: a.index for a in res.atoms}
    resname = res.name.strip()

    # Backbone bead calculation
    if res.index == 0:
        ho5 = traj.xyz[:, atom_map["HO5'"], :]
        o5 = traj.xyz[:, atom_map["O5'"], :]
        direction = o5 - ho5
        unit_vec = direction / (np.linalg.norm(direction, axis=1, keepdims=True) + 1e-8)
        p_bead = o5 + unit_vec * PO5_BOND_LENGTH
    else:
        p_bead = traj.xyz[:, atom_map["P"], :]

    o5_bead = traj.xyz[:, atom_map["O5'"], :]
    c5_bead = traj.xyz[:, atom_map["C5'"], :]
    c4_bead = traj.xyz[:, atom_map["C4'"], :]
    c1_bead = traj.xyz[:, atom_map["C1'"], :]
    backbone = [p_bead, o5_bead, c5_bead, c4_bead, c1_bead] # bead indices are fixed here (indices 0-4) which allows for predefined quadruplets 

    # Base beads
    base_beads = []
    for key, atom_list in CG_bases.items():
        if key.startswith(resname[0]):
            coords_list = []
            elements_list = []
            for atom_name in atom_list:
                if atom_name in atom_map:
                    coords_list.append(traj.xyz[:, atom_map[atom_name], :])
                    elements_list.append(top.atom(atom_map[atom_name]).element.symbol)
            if coords_list:
                coords_arr = np.stack(coords_list, axis=1)
                coms = np.array([compute_com(coords_arr[i], elements_list) for i in range(n_frames)])
                base_beads.append(coms)

    # Combine all beads for residue
    if base_beads:
        full_beads = backbone + [bead[:, np.newaxis, :] for bead in base_beads]
    else:
        full_beads = backbone
        
    reshaped_beads = [b if b.ndim == 3 else b[:, np.newaxis, :] for b in full_beads]
    residue_beads = np.concatenate(reshaped_beads, axis=1)  # (n_frames, n_beads, 3)
    # print("Bead order:", ["P", "O5'", "C5'", "C4'", "C1'", "A1", "A2"][:len(full_beads)])

    # Select quadruplets based on residue type
    if resname[0] in ['A', 'G']:  # Purines: 10 dihedrals
        dihedrals = compute_dihedrals(residue_beads, PURINE_QUADRUPLETS)
    else:  # Pyrimidines: 8 dihedrals (skip 0 and 1)
        dihedrals = compute_dihedrals(residue_beads, PYRIMIDINE_QUADRUPLETS)
        
    dihedral_trajs.append(dihedrals)

# Combine dihedrals across residues
dihedrals_combined = np.concatenate(dihedral_trajs, axis=1)  # (n_frames, total_dihedrals): 10 per purine, 8 per pyrimidine

# Extract start and end structure and convert to tensors for inference
start = torch.tensor(dihedrals_combined[0], dtype=torch.float32).unsqueeze(0)
end = torch.tensor(dihedrals_combined[-1], dtype=torch.float32).unsqueeze(0)

# run inference 
generated_path = diffusion.sample(model, start, end, frames=50, device='cpu')  # (1, 50, F)
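
The sign convention of the `atan2` formula used in `compute_dihedrals` can be checked on geometries whose dihedral is known by construction. The sketch below applies the same formula to a single set of four points; `dihedral` is a hypothetical single-frame helper written for this check, and the test points are made up.

```python
import numpy as np

def dihedral(p1, p2, p3, p4):
    """Dihedral angle (radians) for four points, same atan2 formula as above."""
    b1, b2, b3 = p2 - p1, p3 - p2, p4 - p3
    n1 = np.cross(b1, b2)                      # normal to plane (p1, p2, p3)
    n2 = np.cross(b2, b3)                      # normal to plane (p2, p3, p4)
    b2_hat = b2 / np.linalg.norm(b2)
    x = np.dot(n1, n2)                         # proportional to cos(angle)
    y = np.dot(np.cross(n1, n2), b2_hat)       # proportional to sin(angle)
    return np.arctan2(y, x)

p1 = np.array([1.0, 0.0, 0.0])
p2 = np.array([0.0, 0.0, 0.0])
p3 = np.array([0.0, 0.0, 1.0])

cis = dihedral(p1, p2, p3, np.array([1.0, 0.0, 1.0]))      # p4 on the same side as p1
trans = dihedral(p1, p2, p3, np.array([-1.0, 0.0, 1.0]))   # p4 on the opposite side
quarter = dihedral(p1, p2, p3, np.array([0.0, 1.0, 1.0]))  # p4 rotated by 90 degrees
print(np.degrees([cis, abs(trans), quarter]))
```

Skipping the normalization of `n1` and `n2` does not change the result here, because `atan2` depends only on the ratio and signs of its two arguments.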


In [None]:
import numpy as np
generated_path = generated_path.squeeze(0)  # Shape: (50, F), remove the batch dimension
generated_xyz = generated_path.reshape(-1, n_atoms, 3) # reshape to the layout save_xyz expects: (number of frames, number of atoms, XYZ coordinates for each atom)
generated_coordinates = generated_xyz.numpy() # convert to numpy array
generated_coordinates.shape # (number of frames, number of atoms, 3)
                            # so each frame holds 708 atoms with 3 coordinates each

(50, 708, 3)

In [11]:
import mdtraj as md

# Configuration
topology = md.load_prmtop('C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/PK1/coords.prmtop')
traj = md.load_mdcrd('C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Paths/path5.mdcrd' , top=topology) #load the trajectory into python

topology.atoms
atom_elements = [atom.element.symbol for atom in topology.atoms]

In [12]:
from Landscape_DDPM import cg_save_xyz
cg_save_xyz(generated_path, 'generated_path_test.xyz', topology_file='C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/PK1/coords.prmtop')

AssertionError: Each bead must have 3 coordinates (x, y, z)
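
The assertion message indicates that `cg_save_xyz` checks for three coordinates per bead. One plausible cause, not confirmed by the source, is that the array passed in still has its batch dimension or holds dihedral features rather than Cartesian coordinates. The sketch below is a hypothetical XYZ writer with an explicit shape check; `xyz_string`, its parameters and the placeholder element symbols are all assumptions, and it is not the `Landscape_DDPM` implementation.

```python
import numpy as np

def xyz_string(coords, elements, scale=1.0):
    """Hypothetical multi-frame XYZ formatter.

    coords: array of shape (n_frames, n_beads, 3).
    elements: one symbol per bead.
    scale: unit conversion factor, e.g. 10.0 to convert mdtraj's nm to angstrom.
    """
    coords = np.asarray(coords, dtype=float) * scale
    if coords.ndim != 3 or coords.shape[-1] != 3:
        raise ValueError(f"expected (n_frames, n_beads, 3), got {coords.shape}")
    n_frames, n_beads, _ = coords.shape
    if len(elements) != n_beads:
        raise ValueError("need exactly one element symbol per bead")
    lines = []
    for f in range(n_frames):
        lines.append(str(n_beads))             # XYZ header: particle count
        lines.append(f"frame {f}")             # XYZ comment line
        for sym, (x, y, z) in zip(elements, coords[f]):
            lines.append(f"{sym} {x:.4f} {y:.4f} {z:.4f}")
    return "\n".join(lines) + "\n"

flat = np.zeros((2, 6))                        # 2 frames, 2 beads * 3 coords, flattened
text = xyz_string(flat.reshape(2, 2, 3), ["P", "C"])
print(text)
```

A flattened Cartesian path of shape `(1, 50, F)` would need `squeeze(0)` and a reshape to `(50, F // 3, 3)` before a writer like this accepts it. Dihedral features would first have to be converted back to bead coordinates, which this sketch does not attempt.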