In [3]:
# Imports for the whole notebook, plus one global configuration statement (see below).
import pickle
import os
from perses.annihilation.rest import RESTTopologyFactory
from perses.annihilation.lambda_protocol import RESTState
from openmmtools.states import SamplerState, ThermodynamicState, CompoundThermodynamicState
from openmmtools import cache, utils
from perses.dispersed.utils import configure_platform
# Point openmmtools' global context cache at the platform returned by configure_platform for the
# fastest available OpenMM platform name; contexts created through that cache use it.
cache.global_context_cache.platform = configure_platform(utils.get_fastest_platform().getName())
from simtk import openmm, unit
import math
from openmmtools.constants import kB
from openmmtools import mcmc, multistate
import argparse
import copy
from perses.dispersed import feptasks
import numpy as np
from simtk.openmm import app
from openmmforcefields.generators import SystemGenerator
import pickle  # NOTE(review): duplicate of the `import pickle` at the top of this cell (harmless)
import mdtraj as md
import itertools

conducting subsequent work with the following platform: CUDA


Create the solvated alanine "vanilla" (non-REST) system

In [4]:
pdb = app.PDBFile("../../input/ala_vacuum.pdb")

# Amber ff14SB protein force field + TIP3P water; 1 atm / 298 K Monte Carlo barostat (frequency 50)
forcefield_files = ['amber14/protein.ff14SB.xml', 'amber14/tip3p.xml']
barostat = openmm.MonteCarloBarostat(1.0 * unit.atmosphere, 298 * unit.kelvin, 50)
forcefield_kwargs = {
    'removeCMMotion': False,
    'ewaldErrorTolerance': 1e-4,
    'constraints': app.HBonds,
    'hydrogenMass': 4 * unit.amus,
}
system_generator = SystemGenerator(
    forcefields=forcefield_files,
    barostat=barostat,
    forcefield_kwargs=forcefield_kwargs,
    periodic_forcefield_kwargs={'nonbondedMethod': app.PME},
    small_molecule_forcefield='gaff-2.11',
    nonperiodic_forcefield_kwargs=None,
    molecules=None,
    cache=None,
)

# Solvate with TIP3P water (9 A padding) and 0.15 M ionic strength
modeller = app.Modeller(pdb.topology, pdb.positions)
modeller.addSolvent(system_generator.forcefield, model='tip3p', padding=9 * unit.angstroms, ionicStrength=0.15 * unit.molar)
solvated_topology = modeller.getTopology()
solvated_positions = modeller.getPositions()

# Canonicalize the solvated positions: turn the per-atom 3-vectors into an (n_atoms, 3) array in nm
positions_nm = np.array([list(xyz) for xyz in solvated_positions.value_in_unit_system(unit.md_unit_system)])
positions = unit.quantity.Quantity(value=positions_nm, unit=unit.nanometers)

# Vanilla (non-REST) system for the solvated topology
sys = system_generator.create_system(solvated_topology)



DEBUG:openmmforcefields.system_generators:Trying GAFFTemplateGenerator to load gaff-2.11


In [3]:
# Show every force in the vanilla system
for force in sys.getForces():
    print(force)

<simtk.openmm.openmm.HarmonicBondForce; proxy of <Swig Object of type 'OpenMM::HarmonicBondForce *' at 0x2b40d5c03420> >
<simtk.openmm.openmm.HarmonicAngleForce; proxy of <Swig Object of type 'OpenMM::HarmonicAngleForce *' at 0x2b40d5c03840> >
<simtk.openmm.openmm.PeriodicTorsionForce; proxy of <Swig Object of type 'OpenMM::PeriodicTorsionForce *' at 0x2b40d5c03420> >
<simtk.openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x2b40d5c03840> >
<simtk.openmm.openmm.MonteCarloBarostat; proxy of <Swig Object of type 'OpenMM::MonteCarloBarostat *' at 0x2b40d5c03420> >


Create a subclass of `RESTTopologyFactory` so that we can experiment with its methods and test the resulting energies.

In [5]:
class REST2(RESTTopologyFactory):
    """Notebook subclass of ``RESTTopologyFactory`` for experimenting with the REST2 force construction.

    Only ``__init__`` and ``_add_nonbondeds`` are (re)defined here; the other helpers called below
    (``_handle_constraints``, ``_add_*_force_terms``, ``_add_bonds``/``_add_angles``/``_add_torsions``,
    ``_generate_dict_from_exceptions``, ``get_identifier``) are not defined in this class, so they
    come from the parent, as does the ``REST_system`` attribute read by callers.
    """
    # Force class names this factory knows how to rewrite; any other force type raises ValueError.
    _known_forces = {'HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce', 'MonteCarloBarostat'}

    def __init__(self, system, solute_region, use_dispersion_correction=False):
        """
        arguments
            system : simtk.openmm.System
                system that will be rewritten
            solute_region : list of int
                particle indices of the solute (REST) region; every index must be a valid
                particle index of ``system``. All remaining particles form the solvent region.
            use_dispersion_correction : bool, default False
                whether to use a dispersion correction

        Properties
        ----------
            REST_system : simtk.openmm.System
                the REST2-implemented system

        Raises
        ------
            AssertionError
                if ``solute_region`` contains an index outside ``system``
            ValueError
                if ``system`` contains a force type not listed in ``_known_forces``
        """
        # Local import: this notebook defines no module-level `_logger`, so referencing one
        # only worked with leftover kernel state.
        import logging

        self._use_dispersion_correction = use_dispersion_correction
        self._num_particles = system.getNumParticles()
        self._og_system = system
        self._og_system_forces = {type(force).__name__: force for force in self._og_system.getForces()}
        self._out_system_forces = {}
        self._solute_region = solute_region
        # Every particle not in the solute region is treated as solvent
        self._solvent_region = list(set(range(self._num_particles)).difference(set(self._solute_region)))

        assert set(solute_region).issubset(set(range(self._num_particles))), "the solute region is not a subset of the system particles"
        self._nonbonded_method = self._og_system_forces['NonbondedForce'].getNonbondedMethod()
        self._out_system = openmm.System()

        # Mirror every particle (with its mass) in the output system so particle indices are preserved
        for particle_idx in range(self._num_particles):
            self._out_system.addParticle(self._og_system.getParticleMass(particle_idx))

        # The barostat is carried over unchanged (as a deep copy, so the original system's force is not shared)
        if "MonteCarloBarostat" in self._og_system_forces:
            barostat = copy.deepcopy(self._og_system_forces["MonteCarloBarostat"])
            self._out_system.addForce(barostat)
            self._out_system_forces[barostat.__class__.__name__] = barostat

        # Copy over the box vectors
        self._out_system.setDefaultPeriodicBoxVectors(*self._og_system.getDefaultPeriodicBoxVectors())

        self._og_system_exceptions = self._generate_dict_from_exceptions(self._og_system_forces['NonbondedForce'])

        # Check that there are no unknown forces in the original ("og") system
        unknown_forces = set(self._og_system_forces) - set(self._known_forces)
        if unknown_forces:
            raise ValueError(f"Unknown forces {unknown_forces} encountered in og system")
        logging.getLogger(__name__).info("No unknown forces.")

        self._handle_constraints()

        self._add_bond_force_terms()
        self._add_bonds()

        self._add_angle_force_terms()
        self._add_angles()

        self._add_torsion_force_terms()
        self._add_torsions()

        self._add_nonbonded_force_terms()
        self._add_nonbondeds()

    def _add_nonbondeds(self):
        """Populate the output NonbondedForce, CustomNonbondedForce and CustomExceptionForce.

        The output NonbondedForce only carries non-zero parameters for particles and exceptions
        whose identifier is 1; all other particles and exceptions are added with zeroed charge
        (product) and epsilon so their interactions come from the custom forces instead.

        Particles: ``get_identifier(idx)`` returns 0 for solute and 1 for solvent particles (as seen
        in the CustomNonbondedForce parameter dump). Every particle is added to the
        CustomNonbondedForce with its unscaled ``[q, sigma, epsilon, identifier]``.

        Exceptions: ``get_identifier([p1, p2])`` labels pairs 0, 1 or 2 (presumably solute-solute,
        solvent-solvent and solute-solvent -- TODO confirm against the parent class). Each such
        exception becomes an exclusion in the CustomNonbondedForce; label-0 and label-2 exceptions
        are recorded in ``self._solute_exceptions`` / ``self._interexceptions`` and, when they have
        a non-zero charge product or epsilon, added to the CustomExceptionForce with their label.

        Requires ``_add_nonbonded_force_terms`` to have created the output forces first.
        """
        self._solute_exceptions, self._interexceptions = [], []

        og_nb_force = self._og_system_forces['NonbondedForce']
        out_nb_force = self._out_system_forces['NonbondedForce']
        custom_nb_force = self._out_system_forces['CustomNonbondedForce']

        # Particles: only identifier-1 particles keep their charge/epsilon in the NonbondedForce
        for particle_idx in range(self._num_particles):
            q, sigma, epsilon = og_nb_force.getParticleParameters(particle_idx)
            identifier = self.get_identifier(particle_idx)
            if identifier == 1:
                out_nb_force.addParticle(q, sigma, epsilon)
            else:
                out_nb_force.addParticle(q*0.0, sigma, epsilon*0.0)
            custom_nb_force.addParticle([q, sigma, epsilon, identifier])

        # The custom force only evaluates solute-solvent and solute-solute pairs
        solute_ig, solvent_ig = set(self._solute_region), set(self._solvent_region)
        custom_nb_force.addInteractionGroup(solute_ig, solvent_ig)
        custom_nb_force.addInteractionGroup(solute_ig, solute_ig)

        # Exceptions: mirror each one in the NonbondedForce (zeroed unless identifier 1) and
        # exclude the pair from the CustomNonbondedForce to keep exclusions consistent with exceptions
        for exception_idx in range(og_nb_force.getNumExceptions()):
            p1, p2, chargeProd, sigma, epsilon = og_nb_force.getExceptionParameters(exception_idx)
            identifier = self.get_identifier([p1, p2])
            if identifier == 1:
                out_nb_force.addException(p1, p2, chargeProd, sigma, epsilon)
            elif identifier in (0, 2):
                term = [p1, p2, [chargeProd, sigma, epsilon]]
                if identifier == 0:
                    self._solute_exceptions.append(term)
                else:
                    self._interexceptions.append(term)
                out_nb_force.addException(p1, p2, chargeProd*0.0, sigma, epsilon*0.0)
            else:
                # Unrecognized identifier: the exception is dropped entirely
                continue
            custom_nb_force.addExclusion(p1, p2)

        # Now add the zeroed-out exceptions to the CustomBondForce that handles exceptions,
        # skipping terms that contribute no energy
        exception_force = self._out_system_forces['CustomExceptionForce']
        for identifier, terms in ((0, self._solute_exceptions), (2, self._interexceptions)):
            for p1, p2, (chargeProd, sigma, epsilon) in terms:
                if (chargeProd.value_in_unit_system(unit.md_unit_system) != 0.0) or (epsilon.value_in_unit_system(unit.md_unit_system) != 0.0):
                    exception_force.addBond(p1, p2, [chargeProd, sigma, epsilon, identifier])

In [6]:
# Build the REST2 system, using particles 0-21 as the solute region
solute_indices = list(range(22))
factory = REST2(sys, solute_region=solute_indices)
REST_system = factory.REST_system


getting into the nb version


In [7]:
# Forces of the REST system, in index order
[REST_system.getForce(force_index) for force_index in range(REST_system.getNumForces())]


[<simtk.openmm.openmm.MonteCarloBarostat; proxy of <Swig Object of type 'OpenMM::MonteCarloBarostat *' at 0x2ba54b032ea0> >,
 <simtk.openmm.openmm.CustomBondForce; proxy of <Swig Object of type 'OpenMM::CustomBondForce *' at 0x2ba54b032f00> >,
 <simtk.openmm.openmm.CustomAngleForce; proxy of <Swig Object of type 'OpenMM::CustomAngleForce *' at 0x2ba54a60c2a0> >,
 <simtk.openmm.openmm.CustomTorsionForce; proxy of <Swig Object of type 'OpenMM::CustomTorsionForce *' at 0x2ba54a60c240> >,
 <simtk.openmm.openmm.NonbondedForce; proxy of <Swig Object of type 'OpenMM::NonbondedForce *' at 0x2ba54abb12d0> >,
 <simtk.openmm.openmm.CustomNonbondedForce; proxy of <Swig Object of type 'OpenMM::CustomNonbondedForce *' at 0x2ba54abb15d0> >,
 <simtk.openmm.openmm.CustomBondForce; proxy of <Swig Object of type 'OpenMM::CustomBondForce *' at 0x2ba54abb1900> >]

In [8]:
# Locate the CustomNonbondedForce by type rather than by its position (index 5) in the force list,
# which would silently point at a different force if forces were added or reordered.
custom_nb_forces = [force for force in REST_system.getForces() if type(force).__name__ == 'CustomNonbondedForce']
assert len(custom_nb_forces) == 1, f"expected exactly one CustomNonbondedForce, found {len(custom_nb_forces)}"
custom_nb_force = custom_nb_forces[0]

# Per-particle parameters as added by REST2._add_nonbondeds: (charge, sigma, epsilon, identifier)
for particle_index in range(custom_nb_force.getNumParticles()):
    print(particle_index, custom_nb_force.getParticleParameters(particle_index))

0 (0.1123, 0.2649532787749369, 0.06568879999999999, 0.0)
1 (-0.3662, 0.3399669508423535, 0.4577296, 0.0)
2 (0.1123, 0.2649532787749369, 0.06568879999999999, 0.0)
3 (0.1123, 0.2649532787749369, 0.06568879999999999, 0.0)
4 (0.5972, 0.3399669508423535, 0.359824, 0.0)
5 (-0.5679, 0.2959921901149463, 0.87864, 0.0)
6 (-0.4157, 0.3249998523775958, 0.7112800000000001, 0.0)
7 (0.2719, 0.10690784617684071, 0.06568879999999999, 0.0)
8 (0.0337, 0.3399669508423535, 0.4577296, 0.0)
9 (0.0823, 0.2471353044121301, 0.06568879999999999, 0.0)
10 (-0.1825, 0.3399669508423535, 0.4577296, 0.0)
11 (0.0603, 0.2649532787749369, 0.06568879999999999, 0.0)
12 (0.0603, 0.2649532787749369, 0.06568879999999999, 0.0)
13 (0.5973, 0.3399669508423535, 0.359824, 0.0)
14 (-0.5679, 0.2959921901149463, 0.87864, 0.0)
15 (0.0603, 0.2649532787749369, 0.06568879999999999, 0.0)
16 (-0.4157, 0.3249998523775958, 0.7112800000000001, 0.0)
17 (0.2719, 0.10690784617684071, 0.06568879999999999, 0.0)
18 (-0.149, 0.3399669508423535, 0.45

965 (0.417, 1.0, 0.0, 1.0)
966 (0.417, 1.0, 0.0, 1.0)
967 (-0.834, 0.3150752406575124, 0.635968, 1.0)
968 (0.417, 1.0, 0.0, 1.0)
969 (0.417, 1.0, 0.0, 1.0)
970 (-0.834, 0.3150752406575124, 0.635968, 1.0)
971 (0.417, 1.0, 0.0, 1.0)
972 (0.417, 1.0, 0.0, 1.0)
973 (-0.834, 0.3150752406575124, 0.635968, 1.0)
974 (0.417, 1.0, 0.0, 1.0)
975 (0.417, 1.0, 0.0, 1.0)
976 (-0.834, 0.3150752406575124, 0.635968, 1.0)
977 (0.417, 1.0, 0.0, 1.0)
978 (0.417, 1.0, 0.0, 1.0)
979 (-0.834, 0.3150752406575124, 0.635968, 1.0)
980 (0.417, 1.0, 0.0, 1.0)
981 (0.417, 1.0, 0.0, 1.0)
982 (-0.834, 0.3150752406575124, 0.635968, 1.0)
983 (0.417, 1.0, 0.0, 1.0)
984 (0.417, 1.0, 0.0, 1.0)
985 (-0.834, 0.3150752406575124, 0.635968, 1.0)
986 (0.417, 1.0, 0.0, 1.0)
987 (0.417, 1.0, 0.0, 1.0)
988 (-0.834, 0.3150752406575124, 0.635968, 1.0)
989 (0.417, 1.0, 0.0, 1.0)
990 (0.417, 1.0, 0.0, 1.0)
991 (-0.834, 0.3150752406575124, 0.635968, 1.0)
992 (0.417, 1.0, 0.0, 1.0)
993 (0.417, 1.0, 0.0, 1.0)
994 (-0.834, 0.3150752406575

Get the reduced potential of the REST system at 298 K, with the solute region's energies scaled to an effective temperature of 600 K

In [7]:
T_min = 298.0 * unit.kelvin  # reference (simulation) temperature
T = 600.0 * unit.kelvin      # effective temperature applied to the REST region

# Inverse temperatures beta = 1 / (kB * T) for the reference and the scaled temperature
beta_0, beta_m = (1 / (kB * temperature) for temperature in (T_min, T))

# Combine the plain thermodynamic state with the REST alchemical state and apply the scaling
lambda_zero_alchemical_state = RESTState.from_system(REST_system)
thermostate = ThermodynamicState(REST_system, temperature=T_min)
compound_thermodynamic_state = CompoundThermodynamicState(
    thermostate,
    composable_states=[lambda_zero_alchemical_state],
)
compound_thermodynamic_state.set_alchemical_parameters(beta_0, beta_m)

# Evaluate the reduced potential at the solvated starting positions
integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond)
context = compound_thermodynamic_state.create_context(integrator)
context.setPositions(positions)
sampler_state = SamplerState.from_context(context)
compound_thermodynamic_state.reduced_potential(sampler_state)


51.291644849804186

Get the reduced potential of the vanilla system (after manually applying the REST2 scaling to its parameters) at the same thermodynamic state as above

In [9]:
# Work on a deep copy: plain aliasing (`system = sys`) mutated the vanilla system in place, so
# re-running this cell (or the REST factory cell) would see already-scaled parameters and
# already-added exceptions.
system = copy.deepcopy(sys)

# Solute ("protein") and solvent ("environment") particle indices
protein = list(range(0, 22))
environment = list(range(22, system.getNumParticles()))
protein_set = set(protein)  # O(1) membership tests

# REST2 scale factors: solute-solute terms scale by beta_m / beta_0,
# solute-solvent terms by sqrt(beta_m / beta_0); solvent-only terms are left unscaled.
protein_scaling = beta_m / beta_0
inter_scaling = np.sqrt(beta_m / beta_0)


def region_scaling(*particles):
    """Return the REST2 scale factor for a term acting on ``particles``.

    ``protein_scaling`` if every particle is in the solute, ``inter_scaling`` if the term mixes
    solute and solvent particles, and ``None`` for solvent-only terms (which stay unscaled).
    Every particle index is either in ``protein`` or ``environment``, so "some but not all in
    protein" is exactly the mixed solute/solvent case.
    """
    in_protein = [particle in protein_set for particle in particles]
    if all(in_protein):
        return protein_scaling
    if any(in_protein):
        return inter_scaling
    return None


def add_scaled_exception(nb_force, p1, p2, scale):
    """Add a p1-p2 exception equal to the Lorentz-Berthelot-combined pair interaction
    (charge product, arithmetic-mean sigma, geometric-mean epsilon) with charge product and
    epsilon multiplied by ``scale``."""
    p1_charge, p1_sigma, p1_epsilon = nb_force.getParticleParameters(p1)
    p2_charge, p2_sigma, p2_epsilon = nb_force.getParticleParameters(p2)
    nb_force.addException(p1, p2, p1_charge*p2_charge*scale, 0.5*(p1_sigma+p2_sigma), np.sqrt(p1_epsilon*p2_epsilon)*scale)


# Look forces up by type instead of by their position in the force list
forces_by_name = {type(force).__name__: force for force in system.getForces()}

# Scale the terms in the bond force appropriately
bond_force = forces_by_name['HarmonicBondForce']
for bond_index in range(bond_force.getNumBonds()):
    p1, p2, length, k = bond_force.getBondParameters(bond_index)
    scale = region_scaling(p1, p2)
    if scale is not None:
        bond_force.setBondParameters(bond_index, p1, p2, length, k * scale)

# Scale the terms in the angle force appropriately
angle_force = forces_by_name['HarmonicAngleForce']
for angle_index in range(angle_force.getNumAngles()):
    p1, p2, p3, angle, k = angle_force.getAngleParameters(angle_index)
    scale = region_scaling(p1, p2, p3)
    if scale is not None:
        angle_force.setAngleParameters(angle_index, p1, p2, p3, angle, k * scale)

# Scale the terms in the torsion force appropriately
torsion_force = forces_by_name['PeriodicTorsionForce']
for torsion_index in range(torsion_force.getNumTorsions()):
    p1, p2, p3, p4, periodicity, phase, k = torsion_force.getTorsionParameters(torsion_index)
    scale = region_scaling(p1, p2, p3, p4)
    if scale is not None:
        torsion_force.setTorsionParameters(torsion_index, p1, p2, p3, p4, periodicity, phase, k * scale)

# Scale the existing exceptions in the nonbonded force, remembering which pairs already have one
nb_force = forces_by_name['NonbondedForce']
exception_pairs = set()
for nb_index in range(nb_force.getNumExceptions()):
    p1, p2, chargeProd, sigma, epsilon = nb_force.getExceptionParameters(nb_index)
    exception_pairs.add(tuple(sorted((p1, p2))))
    scale = region_scaling(p1, p2)
    if scale is not None:
        nb_force.setExceptionParameters(nb_index, p1, p2, chargeProd * scale, sigma, epsilon * scale)

# Scale nonbonded interactions within the solute by adding exceptions for every remaining solute pair
for p1, p2 in itertools.combinations(protein, 2):
    if (p1, p2) not in exception_pairs:
        add_scaled_exception(nb_force, p1, p2, protein_scaling)

# Scale solute-solvent nonbonded interactions by adding exceptions for every solute-solvent pair
# NOTE(review): NonbondedForce exceptions have no cutoff and by default do not use periodic boundary
# conditions, so these pair terms may be treated differently from the REST system's custom forces -- TODO confirm.
for p1, p2 in itertools.product(protein, environment):
    if tuple(sorted((p1, p2))) not in exception_pairs:
        add_scaled_exception(nb_force, p1, p2, inter_scaling)

# Evaluate the reduced potential of the manually scaled vanilla system at T_min
thermostate = ThermodynamicState(system, temperature=T_min)
integrator = openmm.VerletIntegrator(1.0*unit.femtosecond)
context = thermostate.create_context(integrator)
context.setPositions(positions)
sampler_state = SamplerState.from_context(context)
thermostate.reduced_potential(sampler_state)

-65.13044057879681