Scripts for plotting energy distributions from trajectory files

In [None]:
import torch
import numpy as np
#import openmm
import pickle
import mdtraj
from openmmtools import integrators
from simtk import unit
from simtk import openmm
import pandas as pd

In [None]:
# Preferred accelerator is GPU index 3. Fall back to the first GPU when the
# machine exposes fewer than four devices, and to the CPU when CUDA is not
# available, so that the notebook also runs on smaller machines.
if torch.cuda.is_available():
    device = "cuda:3" if torch.cuda.device_count() > 3 else "cuda:0"
else:
    device = "cpu"
dtype = torch.float32

# Zero-dimensional "context" tensor: other tensors are moved onto its device
# and dtype with `.to(ctx)`.
ctx = torch.zeros([], device=device, dtype=dtype)

In [None]:
def find_stride(fname, top_file, fpath_stub):
    """Return the frame count of a DCD trajectory and a stride that thins it to roughly 10,000 frames.

    Parameters
    ----------
    fname : str
        Trajectory file stem (without the ``.dcd`` extension). Names containing
        'coupled' are looked up under ``Coupled_scheme/Trajectories``.
    top_file : str
        Topology file path. Kept for backward compatibility; counting frames
        does not need the topology.
    fpath_stub : str
        Molecule directory that holds the ``Trajectories`` folder for
        non-coupled runs.

    Returns
    -------
    tuple
        ``(n_frames, stride)`` where ``stride`` is an integer >= 1, computed the
        same way as the row step in ``get_energy_fromtxt``.
    """
    traj_dir = 'Coupled_scheme' if 'coupled' in fname else fpath_stub

    # Only the number of frames is needed, so ask the opened file for its
    # length instead of loading every coordinate into memory.
    with mdtraj.open(f'{traj_dir}/Trajectories/{fname}.dcd') as dcd_file:
        n_frames = len(dcd_file)

    # Integer stride, never below 1: a fractional or zero stride is not a
    # meaningful frame step for short trajectories.
    stride = max(1, n_frames // 10000)
    print(stride)
    return n_frames, stride

In [None]:
def get_pot_energies(fname, molecule, average=False):
    """Evaluate the potential energy of the conformations in a DCD trajectory.

    Energies are computed through bgflow's bridge to OpenMM, using the
    serialized system that belongs to ``molecule``. At most the first 10,000
    frames of the strided trajectory are evaluated.

    Parameters
    ----------
    fname : str
        Trajectory file stem (without ``.dcd``). Names containing 'coupled' are
        read from ``Coupled_scheme/Trajectories``.
    molecule : str
        Either 'pro' (proline) or 'ala2' (alanine dipeptide).
    average : bool, optional
        When true, return only a single summary value. NOTE(review): despite
        the parameter name this is the median, not the mean.

    Returns
    -------
    numpy scalar
        The median energy, when ``average`` is true.
    tuple
        Otherwise ``(energies, data_len, data_cap, data_min)``: the energy
        array, the total number of frames in the file, an upper plotting limit
        and the smallest energy. If the largest energy exceeds 800 the upper
        limit is the 80th percentile, but never below 400.

    Raises
    ------
    ValueError
        If ``molecule`` is not recognised.
    """
    # (directory, topology file, serialized system file, number of atoms)
    molecule_settings = {
        'pro': ('proline', 'proline/cis_pro.pdb',
                'proline/noconstraints_xmlsystem.txt', 26),
        'ala2': ('Alanine_dipeptide', 'Alanine_dipeptide/ala2_fromURL.pdb',
                 'Alanine_dipeptide/ala2_noconstraints_system.txt', 22),
    }
    if molecule not in molecule_settings:
        raise ValueError(
            f"molecule not recognised: {molecule!r} (expected 'pro' or 'ala2')")
    fpath_stub, top_file, system_file, n_atoms = molecule_settings[molecule]

    with open(system_file) as f:
        system = openmm.XmlSerializer.deserialize(f.read())

    # bgflow is not imported at the top of the file, so import it where it is used.
    from bgflow.distribution.energy.openmm import OpenMMBridge, OpenMMEnergy

    try:
        # NOTE: pickle can execute arbitrary code; only load parameter files you trust.
        with open(f'{fpath_stub}/parameters/parameters{fname}.pkl', 'rb') as pickle_file:
            parametersdict = pickle.load(pickle_file)
        temperature = parametersdict['Temperature']
        collision_rate = parametersdict['Collision rate']
        timestep = parametersdict['Timestep']
    except Exception:
        # Deliberate best effort: any problem with the parameter file falls
        # back to defaults inferred from the file name (but Ctrl-C still works).
        print(f'{fname} no parameters found')
        if 'coupled' in fname or '1000' in fname:
            print(f'{fname}using 1000K')
            temperature = 1000.0 * unit.kelvin
        else:
            print(f'{fname} using 300K')
            temperature = 300.0 * unit.kelvin
        collision_rate = 1.0 / unit.picosecond
        timestep = 2.0 * unit.femtosecond

    integrator = integrators.LangevinIntegrator(
        temperature=temperature, collision_rate=collision_rate, timestep=timestep)
    energy_bridge = OpenMMBridge(system, integrator, n_workers=1)
    target_energy = OpenMMEnergy(n_atoms * 3, energy_bridge)

    data_len, stride = find_stride(fname, top_file, fpath_stub)
    # The frame stride must be a positive integer, whatever find_stride returned.
    stride = max(1, int(stride))

    traj_dir = 'Coupled_scheme' if 'coupled' in fname else fpath_stub
    trajectory = mdtraj.load(
        f'{traj_dir}/Trajectories/{fname}.dcd', top=top_file, stride=stride)

    coordinates = trajectory.xyz[0:10000]
    # One flattened row of x, y, z values per frame, on ctx's device and dtype.
    data = torch.tensor(coordinates.reshape(-1, n_atoms * 3)).to(ctx)
    energies = target_energy.energy(data).cpu().detach().numpy()

    if average:
        return np.median(energies)

    highest = energies.max()
    if highest > 800:
        # A few very high energies would stretch the histogram range, so cap it.
        data_cap = max(np.percentile(energies, 80), 400)
    else:
        data_cap = highest
    data_min = energies.min()

    return energies, data_len, data_cap, data_min

In [None]:
def get_energy_fromtxt(fname, column, molecule):
    """Read one column from the txt (CSV) log written by an OpenMM simulation.

    Parameters
    ----------
    fname : str
        Log file stem (without ``.txt``). Names containing 'coupled' are read
        from ``Coupled_scheme/Trajectories`` regardless of ``molecule``.
    column : str
        Name of the column to extract, e.g. 'Potential Energy (kJ/mole)'.
    molecule : str
        Either 'pro' or 'ala2'; selects the directory for non-coupled runs.

    Returns
    -------
    tuple
        ``(values, n_rows, largest, smallest)``: the column thinned to roughly
        10,000 entries, the total number of rows in the file, and the largest
        and smallest of the thinned values.

    Raises
    ------
    ValueError
        If the molecule directory is needed but ``molecule`` is not recognised.
    """
    if 'coupled' in fname:
        traj_dir = 'Coupled_scheme'
    elif molecule == 'pro':
        traj_dir = 'proline'
    elif molecule == 'ala2':
        traj_dir = 'Alanine_dipeptide'
    else:
        raise ValueError(
            f"molecule not recognised: {molecule!r} (expected 'pro' or 'ala2')")

    data = pd.read_csv(f'{traj_dir}/Trajectories/{fname}.txt')

    # Keep about 10,000 evenly spaced rows; short files are used in full.
    step = max(1, len(data) // 10000)
    sampled = data[column].to_numpy()[::step]
    return sampled, len(data), sampled.max(), sampled.min()

In [None]:
def average_energy(fname, target_energy, molecule, stride=None):
    """Return the mean potential energy of the conformations in a DCD trajectory.

    Parameters
    ----------
    fname : str
        Trajectory file stem (without ``.dcd``), read from the molecule's
        ``Trajectories`` directory.
    target_energy : object
        Energy model exposing ``energy(tensor)``, e.g. the ``OpenMMEnergy``
        built in ``get_pot_energies``.
    molecule : str
        Either 'pro' or 'ala2'.
    stride : int, optional
        Frame stride passed to ``mdtraj.load``. The default ``None`` leaves the
        choice to mdtraj. Previously this name was undefined in the function.

    Returns
    -------
    numpy scalar
        ``np.average`` of the energies.

    Raises
    ------
    ValueError
        If ``molecule`` is not recognised.
    """
    if molecule == 'pro':
        fpath_stub = 'proline'
        top_file = f'{fpath_stub}/cis_pro.pdb'
        n_atoms = 26
    elif molecule == 'ala2':
        fpath_stub = 'Alanine_dipeptide'
        top_file = 'Alanine_dipeptide/ala2_fromURL.pdb'
        n_atoms = 22
    else:
        raise ValueError(
            f"molecule not recognised: {molecule!r} (expected 'pro' or 'ala2')")

    # Use the directory and atom count of the selected molecule rather than
    # the alanine dipeptide values (directory name, 66 = 22 * 3).
    trajectory = mdtraj.load(
        f'{fpath_stub}/Trajectories/{fname}.dcd', top=top_file, stride=stride)
    coordinates = trajectory.xyz
    data = torch.tensor(coordinates.reshape(-1, n_atoms * 3)).to(ctx)
    energies = target_energy.energy(data).cpu().detach().numpy()
    return np.average(energies)

In [None]:
def plot_temp(ax, temps, bins, index, target_temp=300):
    """Plot a histogram of instantaneous temperatures with mean and target markers.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    temps : array-like with a ``mean()`` method
        Temperatures to histogram.
    bins : sequence or int
        Bin specification forwarded to ``ax.hist``.
    index : int
        Position of this data set among those drawn on ``ax``. It fades the
        histogram, picks the text colour ``C{index}`` and offsets the text.
    target_temp : float, optional
        Thermostat target shown as a red dashed line for the first data set.
        Defaults to 300 (K).
    """
    # Later data sets are drawn more transparently; the floor keeps alpha
    # positive once index reaches 5.
    alpha = max(0.1, 1 - index / 5)
    ax.hist(temps, bins=bins, alpha=alpha)

    mean_temp = temps.mean()
    ax.axvline(mean_temp, color='k', linestyle='dashed')
    ax.text(0.05, (0.9 - index / 10), f'Mean: {mean_temp:.2f} K',
            color=f'C{index}', transform=ax.transAxes)

    # The thermostat target is annotated once, together with the first data set.
    if index == 0:
        ax.text(0.4, 0.9, f'Thermostat aim: {target_temp} K', color='r',
                transform=ax.transAxes)
        ax.axvline(target_temp, color='r', linestyle='dashed')

    ax.set_xlabel('Instantaneous temperature [$K$]')
    ax.set_ylabel(f"Count   [#Samples / {len(temps)}]")
    ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')

In [None]:
def plot_energy(ax,
        *fnamesandlabels, molecule='ala2', binsize=10, log=False, total_energy=False, kinetic_energy=False, potential_energy=False, temperature=False, lefthandplot=True, bottomrowplot=True):
    """Plot overlaid histograms of an energy (or temperature) for several trajectories.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    *fnamesandlabels : tuple
        One ``(fname, label)`` pair per trajectory.
    molecule : str, optional
        'ala2' or 'pro'; forwarded to the data loaders.
    binsize : float, optional
        Bin width for linear bins (ignored when ``log`` is true).
    log : bool, optional
        Use 100 geometrically spaced bins starting at 1 and a log x-axis.
    total_energy, kinetic_energy, potential_energy, temperature : bool, optional
        Select a column of the simulation txt log. The first enabled flag, in
        this order, wins. With none enabled, potential energies are computed
        from the dcd file with ``get_pot_energies``.
    lefthandplot : bool, optional
        Draw the y-axis label.
    bottomrowplot : bool, optional
        Draw the x-axis label.

    Raises
    ------
    ValueError
        If no ``(fname, label)`` pair is given.
    """
    if not fnamesandlabels:
        raise ValueError('plot_energy needs at least one (fname, label) pair')

    # (flag, txt column, x-axis label), highest priority first.
    txt_sources = (
        (total_energy, 'Total Energy (kJ/mole)', 'Total Energy   [$kJ$ $mol^{-1}$]'),
        (kinetic_energy, 'Kinetic Energy (kJ/mole)', 'Kinetic Energy   [$kJ$ $mol^{-1}$]'),
        (potential_energy, 'Potential Energy (kJ/mole)', 'Potential Energy   [$kJ$ $mol^{-1}$]'),
        (temperature, 'Temperature (K)', 'Instantaneous temperature [$K$]'),
    )
    datacol = None
    x_label = 'Potential Energy   [$k_B T$]'
    for enabled, column, axis_label in txt_sources:
        if enabled:
            datacol, x_label = column, axis_label
            break

    energies, lengths, maxes, mins = {}, {}, {}, {}
    for fname, _ in fnamesandlabels:
        if datacol is not None:
            loaded = get_energy_fromtxt(fname, datacol, molecule)
        else:
            loaded = get_pot_energies(fname, molecule)
        energies[fname], lengths[fname], maxes[fname], mins[fname] = loaded

    # NOTE(review): lengths holds the frame/row count of each whole file, not
    # the number of thinned samples - confirm this is the intended denominator.
    data_len = min(lengths.values())
    print(f'data len {data_len}')

    upper = max(maxes.values())
    lower = min(mins.values())
    if log:
        bins = np.geomspace(1, upper, 100)
        print('bins', bins)
        print('min', lower)
        print('max', upper)
        ax.set_xscale('log')
    else:
        bins = np.arange(lower, upper + binsize, binsize)

    # enumerate gives each data set its position directly; looking a bare
    # fname up in the tuple of (fname, label) pairs can never succeed.
    for position, (fname, label) in enumerate(fnamesandlabels):
        samples = energies[fname][0:data_len]
        if temperature:
            plot_temp(ax, samples, bins, position)
        else:
            # Later data sets are more transparent; the floor keeps alpha positive.
            alpha = max(0.1, 1 - position / 6)
            ax.hist(samples, bins=bins, density=False, label=label, alpha=alpha)
            print(f'plotted {position}')

    if bottomrowplot:
        ax.set_xlabel(f"{x_label}")
    if lefthandplot:
        ax.set_ylabel(f"Count   [#Samples / {data_len}]")

    if len(fnamesandlabels) == 1:
        # A single data set is named in the title instead of a legend.
        ax.set_title(f'{fnamesandlabels[0][1]}')
    else:
        ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')

In [None]:
# Example: reproduces figure 11 of the thesis.
# The trajectory files named below are too large to be hosted on github.
import matplotlib.pyplot as plt

# Top panel: flows trained on constrained MD; bottom panel: unconstrained MD.
constrained_runs = [
    ('300K_noconstr_long_1_samplestraj', 'Trained on constrained MD 1'),
    ('300K_noconstr_long_2_samplestraj', 'Trained on constrained MD 2'),
    ('300K_noconstr_long_3_samplestraj', 'Trained on constrained MD 3'),
]
unconstrained_runs = [
    ('300KunconstrainedMD_noconstr_long_1_samplestraj', 'Trained on unconstrained MD 1'),
    ('300KunconstrainedMD_noconstr_long_2_samplestraj', 'Trained on unconstrained MD 2'),
    ('300KunconstrainedMD_noconstr_long_3_samplestraj', 'Trained on unconstrained MD 3'),
]

fig, axes = plt.subplots(2, 1, figsize=(9, 8), sharey=True, sharex=True)

for panel, runs in zip(axes, (constrained_runs, unconstrained_runs)):
    plot_energy(panel, *runs, molecule='ala2', binsize=7)

plt.tight_layout()


In [None]:
#Figure 8
# plot_energy(axes[0], 
# ('TSF_noconstr_long_1_samplestraj', '1,000,000 frames Repeat 1'),
# ('TSFtraj_noconstr_long_2_samplestraj', '1,000,000 frames Repeat 2'),
# ('TSFtraj_noconstr_long_3_samplestraj', '1,000,000 frames Repeat 3'),
# molecule='ala2', binsize=7,) 
# plot_energy(axes[1], 
# ('TSFtrajstride10_noconstr_long_1_samplestraj', '100,000 frames Repeat 1'),
# ('TSFtrajstride10_noconstr_long_2_samplestraj', '100,000 frames Repeat 2'),
# ('TSFtrajstride10_noconstr_long_3_samplestraj', '100,000 frames Repeat 3'),
# molecule='ala2', binsize=7,) 

#Figure 9
# plot_energy(axes[0], 
# ('300K_noconstr_long_1_samplestraj', 'No constraints in training 1'),
# ('300K_noconstr_long_2_samplestraj', 'No constraints in training 2'),
# ('300K_noconstr_long_3_samplestraj', 'No constraints in training 3'),
# molecule='ala2', binsize=10, log=True) 
# plot_energy(axes[1], 
# ('300K_superposed_20000KLL_1', 'Constraints in training 1'),
# ('300K_superposed_20000KLL_2', 'Constraints in training 2'),
# ('300K_superposed_20000KLL_3', 'Constraints in training 3'),
# molecule='ala2', binsize=10E5, log=True) 