# Diffusion Model implementation for energy landscape exploration

Testing ground for a diffusion model implemented in PyTorch.

In [3]:
#load in libraries
import torch
import numpy as np
import pandas as pd
import mdtraj

In [None]:
# Device handle used for this experiment (GPU)
device = torch.device("cuda")

# Build the U-Net backbone with the keyword settings below.
# NOTE(review): `Unet` is not imported in the cells visible here — confirm where it comes from.
model = Unet(dim=32, dim_mults=(1, 2, 2, 4), groups=8)

In [None]:
# define diffusion model
# NOTE(review): this cell only rebuilds the U-Net; no GaussianDiffusion(...) call is
# present here although the recorded TypeError refers to one, and the `diffusion`
# name used by the Trainer cell is not assigned in the visible cells — confirm.
device = torch.device("cuda")

# U-Net backbone (same settings as the previous cell)
model = Unet(dim=32, dim_mults=(1, 2, 2, 4), groups=8)


TypeError: GaussianDiffusion.__init__() missing 1 required keyword-only argument: 'image_size'

In [None]:
# Training configuration
trainer = Trainer(
    # diffusion model
    diffusion,
    # folder of trajectories
    folder='traj_AIB9',
    # training batch size
    train_batch_size=128,
    # learning rate
    train_lr=1e-5,
    # total training steps
    train_num_steps=2_000_000,
    # gradient accumulation steps
    gradient_accumulate_every=1,
    # exponential moving average decay
    ema_decay=0.995,
    # NOTE(review): `op_num` is not defined in the visible cells — confirm
    op_number=op_num,
    # mixed precision training with apex is switched off
    fp16=False,
)

In [None]:
# start training
#trainer.train()

In [None]:
# load trained model
# NOTE(review): `model_id` is presumably the index of a saved checkpoint — confirm against Trainer.load.
model_id = 30     # value handed to trainer.load below
trainer.load(model_id)  # uses the `trainer` object built in the training-parameters cell

## LLM Assisted Model

1. Data Handling

    Current: The code uses a preprocessed .npy trajectory dataset.

    Modify: Parse .mdcrd files into usable trajectory tensors.

        Convert to .npy or torch.Tensor sequences.

        Normalize or align structures (e.g. RMSD alignment).

        Split into tuples: (start_frame, end_frame, full_path_sequence).

2. Input Preparation

    You’ll now want to encode the start and end frames as conditional inputs.

        Option A: Concatenate start + end frames to the noisy input.

        Option B: Encode start + end through a separate encoder and use in cross-attention or FiLM layers in the U-Net.

3. Modify the UNet + Diffusion Model

    Input shape: Instead of only receiving noisy intermediate x_t, the model receives:

    model(noisy_frame_t, timestep_t, start_frame, end_frame)

    Diffusion target: Instead of denoising a full trajectory from pure noise, the model learns to interpolate the path that connects the two known endpoints.

    Loss: Compare predicted frame at time t to ground-truth frame in the trajectory using MSE/L2.

4. Sampling (Inference)

    After training:

        Input start_frame and end_frame.

        Initialize the intermediate frames as noise.

        Use reverse diffusion steps to denoise iteratively — conditioning on both endpoints — to generate a discrete transition path.

### MDCRD data preparation

MDTraj python library to handle MDCRD files
Corresponding topology file (prmtop) required

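Next we need to extract the coordinates.
This typically gives an array of shape [num_frames, num_atoms, 3].
Let’s normalize and reshape it (note: the code below only reshapes each frame into a flat vector; no normalization is applied yet):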

Since the goal is to predict transition paths between two endpoint structures, you’ll want to create training samples like:
(start_frame, end_frame, trajectory_segment)

In [1]:
import mdtraj as md
import numpy as np
import os

# Configuration
trajectory_folder = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Paths/'       # folder containing all .mdcrd files
topology_file = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/PK1/coords.prmtop'  # shared topology

# Parse the shared topology once; passing the file name to load_mdcrd would
# re-read the prmtop for every single path file.
topology = md.load_prmtop(topology_file)

# One (start_frame, end_frame, full_path) tuple per .mdcrd file.
all_paths = []

# os.listdir returns entries in arbitrary order, so sort the listing to keep
# the sample order reproducible between runs and machines.
for file in sorted(os.listdir(trajectory_folder)):
    if not file.endswith(".mdcrd"):  # only path files are of interest
        continue

    filepath = os.path.join(trajectory_folder, file)  # full path of the path file
    print(f"Processing: {file}")

    traj = md.load_mdcrd(filepath, top=topology)  # load the trajectory

    # Cartesian coordinates as a numpy array.
    # MDTraj stores distances in nanometers, time in picoseconds and angles in degrees.
    coords = traj.xyz

    # Flatten every frame into a 1D vector so the U-net can consume it.
    flattened = coords.reshape(coords.shape[0], -1)

    # The model is trained to connect two endpoints, so every sample carries its
    # own first and last frame next to the complete path.
    start = flattened[0]   # first frame of the path
    end = flattened[-1]    # last frame of the path
    path = flattened       # the entire path (all frames)
    all_paths.append((start, end, path))


Processing: path106.mdcrd
Processing: path136.mdcrd
Processing: path139.mdcrd
Processing: path140.mdcrd
Processing: path141.mdcrd
Processing: path142.mdcrd
Processing: path20.mdcrd
Processing: path22.mdcrd
Processing: path25.mdcrd
Processing: path34.mdcrd
Processing: path35.mdcrd
Processing: path39.mdcrd
Processing: path40.mdcrd
Processing: path41.mdcrd
Processing: path43.mdcrd
Processing: path44.mdcrd
Processing: path5.mdcrd
Processing: path51.mdcrd
Processing: path55.mdcrd
Processing: path56.mdcrd
Processing: path57.mdcrd
Processing: path63.mdcrd
Processing: path65.mdcrd
Processing: path81.mdcrd
Processing: path86.mdcrd
Processing: path89.mdcrd
Processing: path91.mdcrd


The structure of the output is as follows:
len(all_paths) is the number of path files that were loaded
    each element of all_paths is a tuple holding 3 elements
    1st element: the coordinates of the first frame
    2nd element: the coordinates of the last frame
    3rd element: the coordinates of every frame of the path (first and last frames included)

Each frame holds 2124 values: 708 atoms × 3 Cartesian coordinates (x, y, z)

In [20]:
len(all_paths[0]) # each element of all_paths is a (start, end, path) tuple, so this length is 3 (not the number of frames)

test = all_paths[0] # the (start, end, path) tuple of the first path file
len(test[1]) # test[1] is the last frame: 708 atoms * 3 xyz coords = 2124 elements

2124

### Building a PyTorch dataset from the MDCRD data

In [2]:
# While training a model, we typically want to pass samples in batches and reshuffle the data at every epoch to reduce model overfitting
# DataLoader is an iterable that abstracts this complexity in an easy API.
from Landscape_DDPM import MolecularPathDataset, collate_paths
from torch.utils.data import DataLoader

# Wrap the (start, end, path) tuples in the project dataset class
dataset = MolecularPathDataset(all_paths)

# Shuffled mini-batches of 8 paths, assembled by the project's collate function
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_paths,
)

In [3]:
# how to access the dataset
# Fetch the first sample and report the shape stored under each of its keys
sample = dataset[0]
for label, key in (("Start shape:", 'start'), ("End shape:", 'end'), ("Path shape:", 'path')):
    print(label, sample[key].shape)

Start shape: torch.Size([2124])
End shape: torch.Size([2124])
Path shape: torch.Size([299, 2124])


In [4]:
from Landscape_DDPM import GaussianDiffusion, Trainer, UNet

# Size of one flattened frame: three coordinates per atom
n_atoms = 708
F = 3 * n_atoms

# U-net operating on flattened frames; out_dim is left at None
# (per the original note: same as the input dimension by default)
model = UNet(input_dim=F, base_dim=64, dim_mults=(1, 2, 2, 4), time_emb_dim=128, out_dim=None)

# Diffusion wrapper around the U-net with 100 diffusion steps
diffusion_model = GaussianDiffusion(model, timesteps=100)

# Trainer fed by the dataloader built from the MDCRD paths
trainer = Trainer(
    diffusion=diffusion_model,
    dataloader=dataloader,
    ema_decay=0.995,
    learning_rate=1e-5,
    results_folder='C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Diffusion_model/Models',
    save_name='molecular_path_diffusion.pt',
    use_amp=True,
)

# Run ten epochs of training
trainer.train(num_epochs=10)

Epoch 1/10: 100%|██████████| 4/4 [00:15<00:00,  3.80s/it, loss=2.25e+3]
Epoch 2/10: 100%|██████████| 4/4 [00:14<00:00,  3.54s/it, loss=2.22e+3]
Epoch 3/10: 100%|██████████| 4/4 [00:15<00:00,  3.77s/it, loss=2.2e+3] 
Epoch 4/10: 100%|██████████| 4/4 [00:15<00:00,  3.80s/it, loss=2.18e+3]
Epoch 5/10: 100%|██████████| 4/4 [00:15<00:00,  3.79s/it, loss=2.18e+3]
Epoch 6/10: 100%|██████████| 4/4 [00:14<00:00,  3.70s/it, loss=2.17e+3]
Epoch 7/10: 100%|██████████| 4/4 [00:15<00:00,  3.76s/it, loss=2.18e+3]
Epoch 8/10: 100%|██████████| 4/4 [00:14<00:00,  3.54s/it, loss=2.17e+3]
Epoch 9/10: 100%|██████████| 4/4 [00:12<00:00,  3.22s/it, loss=2.17e+3]
Epoch 10/10: 100%|██████████| 4/4 [00:14<00:00,  3.60s/it, loss=2.16e+3]

Model saved to C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Diffusion_model/Models\molecular_path_diffusion.pt





In [22]:
import torch
from Landscape_DDPM import GaussianDiffusion, Trainer, UNet
# Load model state dict
# Rebuild the U-net with the same settings as in the training cell so the saved
# weights match the layer shapes.
n_atoms = 708
F = n_atoms * 3

model = UNet(
    input_dim=F,           # atoms × 3 coordinates
    base_dim=64,
    dim_mults=(1, 2, 2, 4),
    time_emb_dim=128,
    out_dim=None         # same as in_dim by default
)

# map_location="cpu" lets the checkpoint load even when it was written from a
# GPU session and this session has no CUDA; sampling in this notebook is run
# with device='cpu'.
checkpoint = torch.load(
    "C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Diffusion_model/Models/molecular_path_diffusion.pt",
    map_location="cpu",
)
# Use the weights stored under the 'ema' key of the checkpoint.
model.load_state_dict(checkpoint['ema'])

# Load diffusion wrapper
diffusion = GaussianDiffusion(model, timesteps=100)

In [23]:
import mdtraj as md
import torch
# Shared topology for the path files
topology_file = 'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/PK1/coords.prmtop'

# Reference trajectory whose endpoints condition the sampler
traj = md.load_mdcrd(
    'C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Paths/path5.mdcrd',
    top=topology_file,
)

# Cartesian coordinates as a numpy array, then one flat vector per frame
coords = traj.xyz
flattened = coords.reshape(len(coords), -1)

# First and last frame as float tensors with a leading batch axis of 1,
# because the model expects a batch and only one sample is wanted
start = torch.tensor(flattened[0]).float().unsqueeze(0)   # (1, F)
end = torch.tensor(flattened[-1]).float().unsqueeze(0)    # (1, F)
print("start shape:", start.shape)  # Should be (1, F)
print("end shape:", end.shape)      # Should be (1, F)

# Generate a 50-frame path between the two endpoints on the CPU
generated_path = diffusion.sample(model, start, end, frames=50, device='cpu')  # (1, 50, F)

start shape: torch.Size([1, 2124])
end shape: torch.Size([1, 2124])


In [24]:
import numpy as np
generated_path = generated_path.squeeze(0)  # Shape: (50, F), remove the batch axis
generated_xyz = generated_path.reshape(-1, n_atoms, 3)  # back to (frames, atoms, xyz)
# detach() + cpu() make the conversion work even if the sampler returns a tensor
# that tracks gradients or lives on the GPU; a bare .numpy() raises in both cases.
generated_coordinates = generated_xyz.detach().cpu().numpy()
generated_coordinates.shape # output multidimensional array (number of frames, number of atoms, XYZ coordinates for each atom)
                            # so frame 1 will contain 708 entries of 3 coordinates

(50, 708, 3)

In [26]:
def save_xyz(coordinates, filename, atom_names=None):
    """
    Save a trajectory as a multi-frame XYZ file.

    Every frame is written as: a line with the atom count, a comment line
    "Frame <k>" (1-based), then one "<name> <x> <y> <z>" line per atom with
    five decimal places.

    Parameters:
    - coordinates: (T, N, 3) array of frames
    - filename: output filename
    - atom_names: list of atom names (optional, defaults to 'C'); must provide
      at least N entries, extra entries are ignored

    Raises:
    - ValueError: if `coordinates` is not of shape (T, N, 3) or `atom_names`
      has fewer than N entries. Both are checked before the file is opened,
      so no truncated file is left behind.

    NOTE(review): values are written exactly as given, without unit
    conversion — make sure they are in the unit the consumer of the XYZ
    file expects.
    """
    frames, atoms, dims = coordinates.shape
    if dims != 3:
        raise ValueError(
            f"coordinates must have shape (T, N, 3), got {tuple(coordinates.shape)}"
        )

    if atom_names is None:
        atom_names = ['C'] * atoms  # Default to carbon
    elif len(atom_names) < atoms:
        raise ValueError(
            f"atom_names has {len(atom_names)} entries but each frame has {atoms} atoms"
        )

    with open(filename, 'w') as file:
        for f in range(frames):  # loop over every frame
            # header: number of atoms, then a comment line naming the frame (1-based)
            lines = [f"{atoms}\n", f"Frame {f+1}\n"]
            # zip stops after `atoms` rows, so surplus atom names are ignored
            lines.extend(
                f"{name} {x:.5f} {y:.5f} {z:.5f}\n"
                for name, (x, y, z) in zip(atom_names, coordinates[f])
            )
            file.writelines(lines)  # one batched write per frame

In [None]:
import mdtraj as md

# Configuration
# Parse the topology once and reuse the parsed object for loading the trajectory,
# so this cell does not depend on `topology_file` defined in another cell.
topology = md.load_prmtop('C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/PK1/coords.prmtop')
traj = md.load_mdcrd('C:/Users/ckcho/OneDrive/Desktop/KCL Bioinformatics/Research_project/Paths/path5.mdcrd', top=topology)  # load the trajectory into python

# Element symbol of every atom, in topology order; used as atom names for the XYZ export.
atom_elements = [atom.element.symbol for atom in topology.atoms]

list

In [27]:
# The model was trained on MDTraj coordinates, which are stored in nanometers,
# while XYZ files are conventionally read as ångström: convert (1 nm = 10 Å)
# before writing so the structure is not shrunk tenfold in viewers.
# NOTE(review): drop the factor if the downstream tool really expects nanometers.
save_xyz(generated_coordinates * 10.0, 'generated_path.xyz', atom_elements)