In [1]:
from coddiwomple.openmm.integrators import OMMLI
from coddiwomple.openmm.propagators import OMMBIP
from coddiwomple.openmm.reporters import OpenMMReporter
from openmmtools.states import ThermodynamicState, SamplerState
from simtk import unit
from copy import deepcopy
import os
import sys
import numpy as np

from openmmtools import cache, utils
from perses.dispersed.utils import check_platform, configure_platform
# Route every context created through the global openmmtools cache to the fastest platform available.
cache.global_context_cache.platform = configure_platform(
    utils.get_fastest_platform().getName()
)



conducting subsequent work with the following platform: CPU
conducting subsequent work with the following platform: CPU
conducting subsequent work with the following platform: CPU


In [357]:
class Integrator(OMMLI):
    """Langevin integrator whose V-steps use the per-dof variable ``moddi`` instead of the context force.

    ``moddi`` is a per-dof variable (initialised to 0) that can be overwritten from outside the
    integrator, e.g. with ``setPerDofVariableByName('moddi', ...)``. An extra ``F`` substep is
    registered that copies the current context force ``f`` into ``moddi``.
    """

    def __init__(self,
                 temperature=300.0 * unit.kelvin,
                 collision_rate=1.0 / unit.picoseconds,
                 timestep=1.0 * unit.femtoseconds,
                 splitting="V R O R F",
                 constraint_tolerance=1e-6,
                 **kwargs):
        """Create a Langevin integrator with the prescribed operator splitting.

        arguments
            temperature : simtk.unit.Quantity compatible with kelvin, default: 300.0*unit.kelvin
               Fictitious "bath" temperature
            collision_rate : simtk.unit.Quantity compatible with 1/picoseconds, default: 1.0/unit.picoseconds
               Collision rate
            timestep : simtk.unit.Quantity compatible with femtoseconds, default: 1.0*unit.femtoseconds
               Integration timestep
            splitting : string, default: "V R O R F"
                Sequence of "R", "V", "O", "F" (and optionally "{", "}", "V0", "V1", ...) substeps to be
                executed each timestep. V-steps apply the force stored in ``moddi`` (not the context
                force ``f``); the "F" substep copies ``f`` into ``moddi``. "V0" selects force group 0
                when computing the V-step fraction of the timestep.
            constraint_tolerance : float, default: 1.0e-6
                Tolerance for constraint solver
            **kwargs
                Forwarded unchanged to ``OMMLI.__init__``.
        """
        super().__init__(temperature,
                         collision_rate,
                         timestep,
                         splitting,
                         constraint_tolerance,
                         **kwargs)

    def _add_V_step(self, force_group="0"):
        """Deterministic velocity update driven by ``moddi`` rather than the context force.

        The kinetic-energy change caused by the update is accumulated into ``shadow_work``.

        arguments
            force_group : str, optional, default="0"
               Force-group key used to look up how many V-steps (``_force_group_nV``) share the
               timestep; each V-step advances velocities by ``dt / nV``.
        """
        self.addComputeSum("old_ke", self._kinetic_energy)

        # update velocities
        if self._mts:
            # NOTE(review): ``moddi`` is a single per-dof variable holding one combined force, so in a
            # multiple-timestep splitting (V0, V1, ...) every group's V-step applies the same ``moddi``.
            # Verify this is intended before using MTS splittings with this integrator.
            n_v_steps = self._force_group_nV[force_group]
            self.addComputePerDof("v", "v + ((dt / {}) * moddi / m)".format(n_v_steps))
        else:
            self.addComputePerDof("v", "v + (dt / {}) * moddi / m".format(self._force_group_nV["0"]))

        self.addConstrainVelocities()

        self.addComputeSum("new_ke", self._kinetic_energy)
        self.addComputeGlobal("shadow_work", "shadow_work + (new_ke - old_ke)")

    def _add_F_step(self):
        """Add a substep that copies the current context force ``f`` into ``moddi``."""
        self.addComputePerDof('moddi', 'f')

    def _add_variables(self):
        """Register the parent's variables plus the per-dof ``moddi`` force buffer (initialised to 0)."""
        super()._add_variables()
        self.addPerDofVariable('moddi', 0)

    def _add_integrator_steps(self):
        """
        Add the steps to the integrator--this can be overridden to place steps around the integration.
        """
        super()._add_integrator_steps()

    @property
    def _step_dispatch_table(self):
        """dict: the parent's step_name -> (add_step_function, flag) table, extended with the "F" substep."""
        dispatch_table = super()._step_dispatch_table
        dispatch_table['F'] = (self._add_F_step, False)  # copies the context force into moddi
        return dispatch_table
        
    
    
        
        
        

In [367]:
class Propagator(OMMBIP):
    """Propagator that switches a subset of atoms from MM to ANI over one ``apply`` call.

    At iteration i of n the hybrid target is U(lambda_i) with lambda_i = i / n (see
    ``_compute_hybrid_potential``). After each step the incremental reduced work
    U(lambda_{i+1}) - U(lambda_i) is accumulated, and the hybrid force at lambda_{i+1}
    is loaded into the integrator's ``moddi`` per-dof variable.
    """

    def __init__(self,
                 openmm_pdf_state,
                 openmm_pdf_state_subset,
                 subset_indices_map,
                 integrator,
                 ani_handler,
                 context_cache=None,
                 reassign_velocities=True,
                 n_restart_attempts=0,
                 reporter = None,
                 **kwargs):
        """
        arguments
            openmm_pdf_state : openmmtools.states.ThermodynamicState
                the pdf state of the propagator
            openmm_pdf_state_subset : openmmtools.states.ThermodynamicState
                the pdf state of the atom subset
            subset_indices_map : dict
                dict of {openmm_pdf_state atom_index : openmm_pdf_state_subset atom index}
            integrator : openmm.Integrator
                integrator of dynamics; must expose a per-dof variable named 'moddi'
            ani_handler : ANI1_force_and_energy
                handler for ani forces and potential energy
            context_cache : openmmtools.cache.ContextCache, optional
                The ContextCache to use for Context creation (for both the full and the subset
                context). If None, the global cache openmmtools.cache.global_context_cache
                is used (default is None).
            reassign_velocities : bool, optional
                If True, the velocities will be reassigned from the Maxwell-Boltzmann
                distribution at the beginning of the move (default is True).
            n_restart_attempts : int, optional
                When greater than 0, if after the integration there are NaNs in energies,
                the move will restart. When the integrator has a random component, this
                may help recovering. On the last attempt, the ``Context`` is
                re-initialized in a slower process, but better than the simulation
                crashing. An IntegratorMoveError is raised after the given number of
                attempts if there are still NaNs.
            reporter : coddiwomple.openmm.reporter.OpenMMReporter, default None
                a reporter object to write trajectories
            **kwargs
                accepted for interface compatibility; currently ignored
        """
        super().__init__(openmm_pdf_state,
                         integrator,
                         context_cache,
                         reassign_velocities,
                         n_restart_attempts)
        # pdf state for the subset indices (usually a vacuum system)
        self.pdf_state_subset = openmm_pdf_state_subset
        assert self.pdf_state_subset.temperature == self.pdf_state.temperature, f"the temperatures of the pdf states do not match"

        # {full-system index: subset index}
        self._subset_indices_map = subset_indices_map
        # Full-system atom indices ordered by their subset index, so row k of any subset-sized
        # array (positions, forces) corresponds to full-system atom self._subset_atom_indices[k].
        self._subset_atom_indices = sorted(subset_indices_map, key=subset_indices_map.__getitem__)

        self.ani_handler = ani_handler

        # Context for the subset atoms; honor a user-supplied cache as documented for context_cache.
        subset_cache = cache.global_context_cache if context_cache is None else context_cache
        self.context_subset, _ = subset_cache.get_context(self.pdf_state_subset)

        # accumulated works, one list per apply() call, keyed by call counter
        self._state_works = {}
        self._state_works_counter = 0

        # optional trajectory reporter
        self.reporter = reporter
        self._write_trajectory = reporter is not None
        if self._write_trajectory:
            from coddiwomple.particles import Particle
            self.particle = Particle(0)
        else:
            self.particle = None

    def _sync_subset_state(self, particle_state, create=False):
        """Pull positions from the full context and mirror the subset atoms into the subset context.

        arguments
            particle_state : openmmtools.states.SamplerState
                full-system state; updated in place from ``self.context`` (velocities ignored)
            create : bool, default False
                if True, build a fresh ``self.particle_state_subset``; otherwise overwrite its positions
        """
        particle_state.update_from_context(self.context, ignore_velocities=True)
        subset_positions = particle_state.positions[self._subset_atom_indices]
        if create:
            self.particle_state_subset = SamplerState(positions=subset_positions)
        else:
            self.particle_state_subset.positions = subset_positions
        self.particle_state_subset.apply_to_context(self.context_subset, ignore_velocities=True)
        # read back from the subset context so the subset state's potential energy is refreshed
        self.particle_state_subset.update_from_context(self.context_subset, ignore_velocities=True)

    def _accumulate_work_and_set_forces(self, particle_state):
        """Append the work of the next lambda increment and load the next hybrid force into ``moddi``.

        With lambda_i = self._iteration / self._n_iterations, appends
        works[-1] + (U(lambda_{i+1}) - U(lambda_i)) evaluated at the current configuration, then
        writes the hybrid force at lambda_{i+1} (in md_unit_system values) to the integrator.
        """
        current_lambda = self._iteration / self._n_iterations
        next_lambda = (self._iteration + 1.0) / self._n_iterations
        reduced_potential = self._compute_hybrid_potential(_lambda=current_lambda, particle_state=particle_state)
        perturbed_reduced_potential = self._compute_hybrid_potential(_lambda=next_lambda, particle_state=particle_state)
        self._current_state_works.append(self._current_state_works[-1] + (perturbed_reduced_potential - reduced_potential))

        mm_force_matrix = self._compute_hybrid_forces(_lambda=next_lambda, particle_state=particle_state).value_in_unit_system(unit.md_unit_system)
        self.integrator.setPerDofVariableByName('moddi', mm_force_matrix)

    def _before_integration(self, *args, **kwargs):
        """Reset the work trace, take the lambda 0 -> 1/n work increment and load the first hybrid force.

        Expects ``args = (particle_state, n_iterations, ...)``.
        """
        particle_state = args[0]
        n_iterations = args[1]

        # the first incremental work is always 0 since the importance function is identical to the
        # first target distribution (i.e. fully interacting MM)
        self._current_state_works = [0.0]

        self._iteration = 0.0
        # the number of iterations in the protocol is equal to the number of steps in the application
        self._n_iterations = n_iterations

        self._sync_subset_state(particle_state, create=True)
        self._accumulate_work_and_set_forces(particle_state)

        if self._write_trajectory:
            self.particle.update_state(particle_state)
            self.reporter.record([self.particle])

    def _during_integration(self, *args, **kwargs):
        """Advance the iteration counter; unless the protocol is finished, accumulate work and refresh ``moddi``.

        Expects ``args = (particle_state, ...)``.
        """
        particle_state = args[0]
        self._iteration += 1.0

        self._sync_subset_state(particle_state)

        # after the final step there is no further lambda increment to take
        if self._iteration < self._n_iterations:
            self._accumulate_work_and_set_forces(particle_state)

        if self._write_trajectory:
            self.particle.update_state(particle_state)
            # only flush the accumulated trajectory to disk after the final step
            self.reporter.record([self.particle], save_to_disk=(self._iteration == self._n_iterations))

    def _after_integration(self, *args, **kwargs):
        """Store a copy of this application's work trace and reset the reporter, if any."""
        self._state_works[self._state_works_counter] = deepcopy(self._current_state_works)
        self._state_works_counter += 1

        if self._write_trajectory:
            self.reporter.reset()

    def _compute_hybrid_potential(self, _lambda, particle_state):
        """
        function to compute the hybrid reduced potential defined as follows:
        U(x_rec, x_lig) = u_mm(x_rec, x_lig) - lambda*u_mm,lig(x_lig) + lambda*u_ani,lig(x_lig)

        where u_mm is the reduced potential of the full system (``self.pdf_state``), u_mm,lig that
        of the subset (``self.pdf_state_subset``), and u_ani,lig the ANI energy times beta.
        NOTE(review): assumes ``ani_handler.calculate_energy`` returns an energy Quantity so that
        beta * energy is dimensionless -- confirm.
        """
        reduced_potential = (self.pdf_state.reduced_potential(particle_state)
                             - _lambda * self.pdf_state_subset.reduced_potential(self.particle_state_subset)
                             + _lambda * self.ani_handler.calculate_energy(self.particle_state_subset.positions) * self.pdf_state.beta)
        return reduced_potential

    def _compute_hybrid_forces(self, _lambda, particle_state):
        """
        function to compute a hybrid force matrix of shape num_particles x 3
        in the spirit of the _compute_hybrid_potential, we compute the forces in the following way
            F(x_rec, x_lig) = F_mm(x_rec, x_lig) - lambda * F_mm(x_lig) + lambda * F_ani(x_lig)

        Forces are read from ``self.context`` / ``self.context_subset`` and the current subset
        positions, so ``particle_state`` is not read directly. Returns the full-system force
        array as returned by ``getForces(asNumpy=True)``, with the subset rows modified.
        """
        # complex mm forces
        mm_force_matrix = self.context.getState(getForces=True).getForces(asNumpy=True)

        # ligand mm forces
        mm_force_matrix_subset = self.context_subset.getState(getForces=True).getForces(asNumpy=True)

        # ligand ani forces (the returned energy is not needed here)
        subset_ani_force_matrix, _ = self.ani_handler.calculate_force(self.particle_state_subset.positions)

        # combine the ligand forces and add them onto the matching complex rows
        subset_force_matrix = _lambda * (subset_ani_force_matrix - mm_force_matrix_subset)
        mm_force_matrix[self._subset_atom_indices, :] += subset_force_matrix

        return mm_force_matrix

    @property
    def state_works(self):
        """dict: {apply-call counter: list of cumulative reduced works for that call}."""
        return self._state_works

        
        

In [368]:
from pkg_resources import resource_filename
import pickle

# pkg_resources resource names must be relative, '/'-separated paths (no leading '/').
vac_factory_filename = resource_filename('coddiwomple', 'data/perses_data/benzene_methylbenzene.vacuum.factory.pkl')
solvent_factory_filename = resource_filename('coddiwomple', 'data/perses_data/benzene_methylbenzene.solvent.factory.pkl')


def _load_pickle(path):
    """Unpickle ``path``. Only use on trusted, package-bundled files: unpickling can execute code."""
    with open(path, 'rb') as f:
        return pickle.load(f)


vacuum_factory = _load_pickle(vac_factory_filename)
solvent_factory = _load_pickle(solvent_factory_filename)



In [369]:
import torchani
from execute import ANI1_force_and_energy

# Concatenated element symbols of the vacuum (subset) system, in topology atom order.
species_str = ''.join(atom.element.symbol
                      for atom in vacuum_factory._topology_proposal._old_topology.atoms())

ani_handler = ANI1_force_and_energy(
    model=torchani.models.ANI1ccx(),
    atoms=species_str,
    platform='cpu',
    temperature=300 * unit.kelvin,
)

In [370]:
integrator = Integrator(splitting="V R O R F")  # the class default splitting, spelled out: one moddi-driven V step, then R O R, then F

DEBUG:openmm_propagators:initializing Integrator...
DEBUG:openmm_propagators:Integrator: metropolization is False
DEBUG:openmm_propagators:Integrator: successfully parsed splitting string
DEBUG:openmm_propagators:Integrator: adding global variables...
DEBUG:openmm_propagators:Integrator: adding integrator steps...
DEBUG:openmm_propagators:Integrator: adding substep functions...


In [371]:
import mdtraj as md

TEMPERATURE = 300 * unit.kelvin

openmm_pdf_state = ThermodynamicState(solvent_factory._topology_proposal._old_system, temperature=TEMPERATURE)
openmm_pdf_state_subset = ThermodynamicState(vacuum_factory._topology_proposal._old_system, temperature=TEMPERATURE)
reporter = OpenMMReporter('test', 'benzene', md.Topology.from_openmm(solvent_factory._topology_proposal._old_topology))

# Map every atom of the subset (vacuum) system onto the solvated system, instead of hard-coding
# the atom count. This assumes the ligand occupies the first indices of the solvated system in the
# same order as the vacuum system -- TODO confirm for systems other than benzene.
n_subset_atoms = vacuum_factory._topology_proposal._old_system.getNumParticles()

propagator = Propagator(openmm_pdf_state=openmm_pdf_state,
                        openmm_pdf_state_subset=openmm_pdf_state_subset,
                        subset_indices_map={i: i for i in range(n_subset_atoms)},
                        integrator=integrator,
                        ani_handler=ani_handler,
                        context_cache=None,
                        reassign_velocities=True,
                        n_restart_attempts=0,
                        reporter=reporter)

DEBUG:openmm_reporters:creating trajectory storage object...
DEBUG:openmm_propagators:successfully executed ABCMeta init.
DEBUG:openmm_propagators:successfully equipped integrator: Integrator
DEBUG:openmm_propagators:integrator printable: None


step      0 : allow forces to update the context state
step      1 : if(has_kT_changed = 1):
step      2 :    sigma <- sqrt(kT/m)
step      3 :    has_kT_changed <- 0
step      4 : end
step      5 : old_ke <- sum(0.5 * m * v * v)
step      6 : v <- v + (dt / 1) * moddi / m
step      7 : constrain velocities
step      8 : new_ke <- sum(0.5 * m * v * v)
step      9 : shadow_work <- shadow_work + (new_ke - old_ke)
step     10 : old_pe <- energy
step     11 : old_ke <- sum(0.5 * m * v * v)
step     12 : x <- x + ((dt / 2) * v)
step     13 : x1 <- x
step     14 : constrain positions
step     15 : v <- v + ((x - x1) / (dt / 2))
step     16 : constrain velocities
step     17 : new_pe <- energy
step     18 : new_ke <- sum(0.5 * m * v * v)
step     19 : shadow_work <- shadow_work + (new_ke + new_pe) - (old_ke + old_pe)
step     20 : old_ke <- sum(0.5 * m * v * v)
step     21 : v <- (a * v) + (b * sigma * gaussian)
step     22 : constrain velocities
step     23 : new_ke <- sum(0.5 * m * v * v)
s

In [372]:
# NOTE(review): this resolves to the solvent *factory* pickle (same file as the factory loaded above),
# not an equilibrium cache, and it is not used in the visible cells -- confirm the intended file.
solvent_eq_cache = resource_filename('coddiwomple', 'data/perses_data/benzene_methylbenzene.solvent.factory.pkl')

In [373]:
from perses.dispersed.feptasks import minimize

# Starting configuration: the solvated system's old positions with its default periodic box.
particle_state = SamplerState(
    positions=solvent_factory._old_positions,
    box_vectors=np.array(solvent_factory._old_system.getDefaultPeriodicBoxVectors()),
)
# The return value is discarded, so this relies on minimize updating particle_state in place -- TODO confirm.
minimize(propagator.pdf_state, particle_state)

In [375]:
# n_steps is also the number of lambda increments in the MM -> ANI switching protocol.
propagator.apply(particle_state,
                 n_steps=100,
                 reset_integrator=True,
                 apply_pdf_to_context=True)

(<openmmtools.states.SamplerState at 0x7f5ea44f4828>,
 {'_restorable__class_hash': 368641056.0,
  'kT': 2.4943417413660645,
  'a': 0.999000499833375,
  'b': 0.044699008184376096,
  'shadow_work': 14.152860867272466,
  'proposal_work': -76.9203169131124,
  'old_ke': 2293.6963482357933,
  'new_ke': 2293.69634823571,
  'old_pe': -15266.879245570948,
  'new_pe': -15263.147635303738,
  'accept': 0.0,
  'ntrials': 0.0,
  'nreject': 0.0,
  'naccept': 0.0,
  'has_kT_changed': 0.0})

In [356]:
len(propagator.state_works[0])  # a leading 0.0 plus one cumulative work per lambda increment

101

In [388]:
reporter.md_topology.atom(0).element.symbol  # element of the first atom only; no need to list every atom

'C'