This notebook uses `ray` actors to scale a hydration free energy (HFE) calculation in two ways:

1. Create one actor that runs the whole HFE calculation (all three alchemical legs) sequentially on the same resource.
2. Create three actors that run the `decoupling vdw`, `annihilate elec`, and `recharging elec` legs on three different nodes.

Dependencies:
- ray
- openmm
- pymbar

In [None]:
import os
import ray
import copy
import numpy as np

from pymbar import MBAR, timeseries

from simtk.unit import *
from simtk.openmm.app import (
    AmberPrmtopFile,
    AmberInpcrdFile,
    Simulation,
    HBonds,
    PME,
    DCDReporter,
    StateDataReporter,
    PDBFile,
)
from simtk.openmm import (
    XmlSerializer,
    LangevinIntegrator,
    MonteCarloBarostat,
    CustomNonbondedForce,
    CustomBondForce,
    NonbondedForce,
    Platform,
)

In [None]:
@ray.remote(num_cpus=2, num_gpus=1)
class HFE_actor:
    """Ray actor that builds and samples the alchemical legs of a hydration
    free energy (HFE) calculation with OpenMM.

    Each instance reserves 2 CPUs and 1 GPU (see the decorator) and runs on the
    CUDA platform in mixed precision. The ``*_system`` methods return
    serialized alchemical systems, ``alchemy_worker`` samples one of them over
    a lambda schedule, and ``analyst`` turns the samples into a free energy
    with MBAR.

    Parameters
    ----------
    system_xml: str
        Path to the serialized (XML) OpenMM System of the solvated molecule.
    system_pdb: str
        Path to the PDB file with the solvated topology and positions.
    system_vac_xml: str
        Path to the serialized (XML) OpenMM System of the molecule in vacuum.
    system_vac_pdb: str
        Path to the PDB file with the vacuum topology and positions.
    """

    def __init__(self, system_xml, system_pdb, system_vac_xml, system_vac_pdb):
        self.system_pdb = PDBFile(system_pdb)
        # Keep the vacuum PDB in a local: ``self.system_vac`` holds the vacuum
        # System loaded below, so it must not also name the PDB object.
        vac_pdb = PDBFile(system_vac_pdb)
        self.topology = self.system_pdb.topology
        self.positions = self.system_pdb.positions
        self.topology_vac = vac_pdb.topology
        self.positions_vac = vac_pdb.positions

        self.platform = Platform.getPlatformByName("CUDA")
        self.platform.setPropertyDefaultValue("Precision", "mixed")

        # Solute = every atom whose residue is not named "HOH" (so any ions
        # would count as solute); solvent = the "HOH" atoms.
        # NOTE(review): the vacuum system reuses these solute indices, which
        # assumes the solute atoms have the same indices in both PDBs.
        self.solute = [
            atom.index for atom in self.topology.atoms() if atom.residue.name != "HOH"
        ]
        self.solvent = [
            atom.index for atom in self.topology.atoms() if atom.residue.name == "HOH"
        ]

        self.temperature = 298 * kelvin
        # Molar thermal energy; potential energies are divided by it to get
        # reduced (dimensionless) potentials.
        self.kT = self.temperature * BOLTZMANN_CONSTANT_kB * AVOGADRO_CONSTANT_NA

        with open(system_xml, "r") as f:
            self.system = XmlSerializer.deserialize(f.read())
        with open(system_vac_xml, "r") as f:
            self.system_vac = XmlSerializer.deserialize(f.read())
        print("This actor is allowed to use GPUs {}.".format(ray.get_gpu_ids()))

    def _get_nonbonded_force(self, this_system):
        """Return the first `NonbondedForce` of `this_system`.

        Raises
        ------
        ValueError
            If the system has no `NonbondedForce`.
        """
        for force in this_system.getForces():
            if isinstance(force, NonbondedForce):
                return force
        raise ValueError("The system contains no NonbondedForce.")

    @ray.method(num_returns=1)
    def add_elec_parm(self, this_system, lambda_value=1.0):
        """
        * Adds lambda scaling to both inter- and intramolecular electrostatic
          interactions in `NonbondedForce` for the molecule.

        Every solute atom gets a charge offset equal to its current charge,
        and every solute-solute exception gets a charge-product offset equal
        to its current charge product, both tied to the global parameter
        ``lambda_electrostatics``. OpenMM evaluates ``base + lambda * offset``,
        so once the base values are zeroed (``turn_off_nonbonded``) the solute
        charges are scaled by lambda.
        NOTE(review): non-exception solute-solute pairs then scale as
        lambda**2, while solute exceptions scale linearly.

        Parameters
        ----------
        this_system: openmm.System
            The OpenMM system containing the nonbonded force; modified in place.
        lambda_value: float
            The starting value of ``lambda_electrostatics``.

        Returns
        -------
        openmm.System
            ``this_system``.
        """
        nonbonded = self._get_nonbonded_force(this_system)
        solute = set(self.solute)

        nonbonded.addGlobalParameter("lambda_electrostatics", lambda_value)

        for atom_i in self.solute:
            charge, _sigma, _epsilon = nonbonded.getParticleParameters(atom_i)
            nonbonded.addParticleParameterOffset(
                "lambda_electrostatics", atom_i, charge, 0.0, 0.0
            )

        for exception_i in range(nonbonded.getNumExceptions()):
            atom_i, atom_j, chargeprod, _sigma, _epsilon = (
                nonbonded.getExceptionParameters(exception_i)
            )
            if atom_i in solute and atom_j in solute:
                nonbonded.addExceptionParameterOffset(
                    "lambda_electrostatics", exception_i, chargeprod, 0.0, 0.0
                )
        return this_system

    @ray.method(num_returns=1)
    def add_intra_LJ(self, this_system):
        """
        * Move the intramolecular LJ of solute-solute exceptions to a
          `CustomBondForce`.

        Only pairs listed as exceptions in the `NonbondedForce` get a bond.
        NOTE(review): solute pairs that are not exceptions receive no bond, so
        their LJ is lost once ``turn_off_nonbonded`` zeroes the solute epsilons.

        Parameters
        ----------
        this_system: openmm.System
            The OpenMM system containing the nonbonded force; modified in place.
            Must be called before ``turn_off_nonbonded``, which zeroes the
            exception epsilons read here.

        Returns
        -------
        openmm.System
            ``this_system``, with the new force added.
        """
        nonbonded = self._get_nonbonded_force(this_system)
        solute = set(self.solute)

        # Intramolecular LJ
        intramol_LJ = CustomBondForce("4*epsilon*((sigma/r)^12 - (sigma/r)^6);")
        intramol_LJ.addPerBondParameter("sigma")
        intramol_LJ.addPerBondParameter("epsilon")

        for exception_i in range(nonbonded.getNumExceptions()):
            atom_i, atom_j, _chargeprod, sigma, epsilon = (
                nonbonded.getExceptionParameters(exception_i)
            )
            if atom_i in solute and atom_j in solute:
                intramol_LJ.addBond(atom_i, atom_j, [sigma, epsilon])

        this_system.addForce(intramol_LJ)
        return this_system

    @ray.method(num_returns=1)
    def turn_off_nonbonded(self, this_system):
        """
        * Turn off the solute charges and LJ well depths in `NonbondedForce`,
          including solute-solute exceptions (sigma values are kept).

        Parameters
        ----------
        this_system: openmm.System
            The OpenMM system containing the nonbonded force; modified in place.

        Returns
        -------
        openmm.System
            ``this_system``.
        """
        nonbonded = self._get_nonbonded_force(this_system)
        solute = set(self.solute)

        # Turn off Elec and LJ parameters
        for atom_i in self.solute:
            _charge, sigma, _epsilon = nonbonded.getParticleParameters(atom_i)
            nonbonded.setParticleParameters(atom_i, 0.0, sigma, 0.0)

        # Turn off Elec and LJ Exceptions
        for exception_i in range(nonbonded.getNumExceptions()):
            atom_i, atom_j, _chargeprod, sigma, _epsilon = (
                nonbonded.getExceptionParameters(exception_i)
            )
            if atom_i in solute and atom_j in solute:
                nonbonded.setExceptionParameters(
                    exception_i, atom_i, atom_j, 0.0, sigma, 0.0
                )
        return this_system

    @ray.method(num_returns=1)
    def add_intre_LJ(self, this_system, lambda_value=1.0):
        """
        * Move the intermolecular (solute-solvent) LJ to a `CustomNonbondedForce`.
        * Add softcore lambda scaling (``lambda_sterics``) to that LJ.

        The new force copies the cutoff, the long-range (dispersion)
        correction and, when enabled, the switching function of the existing
        `NonbondedForce`, and excludes every pair that is an exception there.

        Parameters
        ----------
        this_system: openmm.System
            The OpenMM system containing the nonbonded force; modified in place.
        lambda_value: float
            The starting value of ``lambda_sterics``.

        Returns
        -------
        openmm.System
            ``this_system``, with the new force added.
        """
        nonbonded = self._get_nonbonded_force(this_system)

        # Intermolecular LJ - softcore potential added
        intermol_LJ = CustomNonbondedForce(
            "(lambda_sterics^softcore_a)*4*epsilon*x*(x-1.0);"
            "x=(sigma/reff_sterics)^6;"
            "reff_sterics=sigma*((softcore_alpha*(1.0-lambda_sterics)^softcore_b+(r/sigma)^softcore_c))^(1/softcore_c);"
            "sigma=0.5*(sigma1+sigma2);"
            "epsilon=sqrt(epsilon1*epsilon2);"
        )
        intermol_LJ.addGlobalParameter("lambda_sterics", lambda_value)
        intermol_LJ.addGlobalParameter("softcore_a", 1)
        intermol_LJ.addGlobalParameter("softcore_b", 1)
        intermol_LJ.addGlobalParameter("softcore_c", 6)
        intermol_LJ.addGlobalParameter("softcore_alpha", 0.5)
        intermol_LJ.addPerParticleParameter("sigma")
        intermol_LJ.addPerParticleParameter("epsilon")
        intermol_LJ.setCutoffDistance(nonbonded.getCutoffDistance())
        intermol_LJ.setNonbondedMethod(CustomNonbondedForce.CutoffPeriodic)
        intermol_LJ.setUseLongRangeCorrection(nonbonded.getUseDispersionCorrection())
        # Mirror the LJ switching of the original force; otherwise the moved
        # solute-solvent LJ would be truncated abruptly at the cutoff.
        if nonbonded.getUseSwitchingFunction():
            intermol_LJ.setUseSwitchingFunction(True)
            intermol_LJ.setSwitchingDistance(nonbonded.getSwitchingDistance())
        intermol_LJ.addInteractionGroup(self.solute, self.solvent)

        # Add Particles to the `intermol_LJ`
        for atom_i in range(nonbonded.getNumParticles()):
            _charge, sigma, epsilon = nonbonded.getParticleParameters(atom_i)
            intermol_LJ.addParticle([sigma, epsilon])

        # Copy Exceptions to `intermol_LJ` as exclusions
        for exception_i in range(nonbonded.getNumExceptions()):
            atom_i, atom_j, _chargeprod, _sigma, _epsilon = (
                nonbonded.getExceptionParameters(exception_i)
            )
            intermol_LJ.addExclusion(atom_i, atom_j)

        this_system.addForce(intermol_LJ)
        return this_system

    @ray.method(num_returns=1)
    def decoupling_vdw_system(self):
        """Return the serialized solvated system for the vdW decoupling leg.

        Solute-solvent LJ is softcore-scaled by ``lambda_sterics``. Solute
        charges are switched off, and solute-solute exception LJ lives in a
        `CustomBondForce`.
        """
        dvdw_system = copy.deepcopy(self.system)
        dvdw_system = self.add_intre_LJ(dvdw_system)
        # Must precede turn_off_nonbonded, which zeroes the exception epsilons.
        dvdw_system = self.add_intra_LJ(dvdw_system)
        dvdw_system = self.turn_off_nonbonded(dvdw_system)

        return XmlSerializer.serialize(dvdw_system)

    @ray.method(num_returns=1)
    def annihilate_elec_system(self):
        """Return the serialized solvated system for the electrostatics leg.

        Solute charges are scaled by ``lambda_electrostatics``, and solute-solvent
        LJ is in the softcore force at ``lambda_sterics = 1``.
        """
        aelec_system = copy.deepcopy(self.system)
        aelec_system = self.add_elec_parm(aelec_system)
        aelec_system = self.add_intre_LJ(aelec_system)
        aelec_system = self.add_intra_LJ(aelec_system)
        aelec_system = self.turn_off_nonbonded(aelec_system)

        return XmlSerializer.serialize(aelec_system)

    @ray.method(num_returns=1)
    def recharging_ele_system(self):
        """Return the serialized vacuum system for the electrostatics leg.

        Solute charges are scaled by ``lambda_electrostatics``.
        """
        relec_system = copy.deepcopy(self.system_vac)
        relec_system = self.add_elec_parm(relec_system)
        relec_system = self.add_intra_LJ(relec_system)
        relec_system = self.turn_off_nonbonded(relec_system)

        return XmlSerializer.serialize(relec_system)

    @ray.method(num_returns=1)
    def alchemy_worker(
        self, this_system, parameter_name, lambda_values, prod_steps, niterations
    ):
        """Sample one alchemical leg and return reduced potentials for MBAR.

        Parameters
        ----------
        this_system: str
            Serialized OpenMM System XML (as returned by a ``*_system`` method).
        parameter_name: str
            Leg identifier. ``"dvdw"`` scans ``lambda_sterics``; any other value
            scans ``lambda_electrostatics``. ``"dvdw"`` and ``"aelec"`` use the
            solvated topology/positions, and any other value uses the vacuum ones.
        lambda_values: sequence of float
            Lambda schedule.
        prod_steps: int
            MD steps between stored samples; also the length of the short
            equilibration run at each lambda.
        niterations: int
            Number of samples stored at each lambda.

        Returns
        -------
        numpy.ndarray
            ``u_kln`` of shape ``(n_lambda, n_lambda, niterations)``: energy / kT
            of sample ``n`` drawn at ``lambda_values[k]`` and evaluated at
            ``lambda_values[l]``.
        """
        # Deserialization already yields a fresh System, so no copy is needed.
        this_system = XmlSerializer.deserialize(this_system)

        this_integrator = LangevinIntegrator(
            self.temperature, 1 / picosecond, 2 * femtoseconds
        )

        if parameter_name in ["dvdw", "aelec"]:
            topology, positions = self.topology, self.positions
        else:
            topology, positions = self.topology_vac, self.positions_vac
        simulation = Simulation(topology, this_system, this_integrator, self.platform)
        simulation.context.setPositions(positions)

        if parameter_name == "dvdw":
            decoupled_parameter = "lambda_sterics"
        else:
            decoupled_parameter = "lambda_electrostatics"

        simulation.minimizeEnergy(
            tolerance=1.0 * kilojoules_per_mole, maxIterations=5000
        )

        # Initial equilibration at the lambda values stored in the system.
        simulation.step(5000)

        n_lambda = len(lambda_values)

        u_kln = np.zeros((n_lambda, n_lambda, niterations))

        for k, lb1 in enumerate(lambda_values):

            simulation.context.setParameter(decoupled_parameter, lb1)

            # Short equilibration
            simulation.step(prod_steps)

            # Run production
            for n in range(niterations):
                # The energy scan below leaves the context at the last lambda,
                # so restore the sampling lambda before every segment.
                simulation.context.setParameter(decoupled_parameter, lb1)
                simulation.step(prod_steps)

                for l, lb2 in enumerate(lambda_values):
                    simulation.context.setParameter(decoupled_parameter, lb2)
                    u_kln[k, l, n] = (
                        simulation.context.getState(getEnergy=True).getPotentialEnergy()
                        / self.kT
                    )

        return u_kln

    @ray.method(num_returns=3)
    def analyst(self, u_kln, lambda_values):
        """Estimate the free energy of one leg with MBAR.

        For each lambda state, the equilibration time and statistical
        inefficiency are detected from ``u_kln[k, k, :]``. Samples before the
        equilibration time are discarded, and the rest are subsampled to
        (approximately) uncorrelated samples.

        Parameters
        ----------
        u_kln: numpy.ndarray
            Reduced potentials of shape ``(n_lambda, n_lambda, niterations)``,
            as returned by ``alchemy_worker``. Not modified.
        lambda_values: sequence of float
            Lambda schedule used to produce ``u_kln``; only its length is used.

        Returns
        -------
        tuple
            ``(DeltaF, dDeltaF, overlap)``: the free energy difference between
            the first and last lambda state and its uncertainty, both in
            kcal/mol, and the MBAR overlap matrix.
        """
        n_lambda = len(lambda_values)
        u_kln_sub = np.array(u_kln, dtype=float, copy=True)
        N_k = np.zeros([n_lambda], np.int32)

        for k in range(n_lambda):
            nequil, g, _neff_max = timeseries.detectEquilibration(u_kln_sub[k, k, :])
            # Discard the pre-equilibration samples before subsampling; the
            # returned indices are relative to `nequil`, so shift them back.
            uncorrelated = timeseries.subsampleCorrelatedData(
                u_kln_sub[k, k, nequil:], g=g
            )
            indices = nequil + np.asarray(uncorrelated, dtype=int)
            N_k[k] = len(indices)
            # Pack the kept samples at the front; MBAR reads only N_k[k] of them.
            u_kln_sub[k, :, 0 : N_k[k]] = u_kln_sub[k, :, indices].T

        mbar = MBAR(u_kln_sub, N_k)
        [DeltaF_ij, dDeltaF_ij, Theta_ij] = mbar.getFreeEnergyDifferences(
            return_theta=True
        )
        ODeltaF_ij = mbar.computeOverlap()["matrix"]
        convert_kcalmol = self.kT.value_in_unit(kilocalorie_per_mole)
        return (
            DeltaF_ij[0, -1] * convert_kcalmol,
            dDeltaF_ij[0, -1] * convert_kcalmol,
            ODeltaF_ij,
        )

In [None]:
# Input naming: every file name is derived from the system name and the
# lower-cased force-field label.
system_name = "molecule"
system_smiles = "CCO"  # not used by the cells below
forcefield_name = "SAGE"

system_pdb = "{}.sol.pdb".format(system_name)
system_xml = "{}_sol_system_{}.xml".format(system_name, forcefield_name.lower())
system_vac_pdb = "{}.vac.pdb".format(system_name)
system_vac_xml = "{}_vac_system_{}.xml".format(system_name, forcefield_name.lower())

Connect this runtime to a running `ray` cluster. Here, the head node was started beforehand on the local machine. Each actor requests `num_cpus=2, num_gpus=1` from the cluster (see the `@ray.remote` decorator on `HFE_actor`). The current working directory is shipped to the workers so they can read the input files. More information about starting a cluster can be found in the [Ray documentation](https://docs.ray.io/en/master/?badge=master).

In [None]:
# Ship this directory to the workers so the actors can open the input files
# by the relative names defined above.
wd = os.getcwd()
runtime_env = dict(working_dir=wd)
# Drop any connection left over from a previous run of this cell.
ray.shutdown()
ray.init(address="auto", runtime_env=runtime_env)

## A. Create one actor handle and run the HFE calculation sequentially

In [None]:
# A single actor that will build and run all three alchemical legs.
hfe_actor_handler = HFE_actor.remote(
    system_xml=system_xml,
    system_pdb=system_pdb,
    system_vac_xml=system_vac_xml,
    system_vac_pdb=system_vac_pdb,
)

### Create alchemical systems using the same handler

In [None]:
# Each call returns an ObjectRef to a serialized System; all three calls are
# submitted, in this order, to the same actor.
dsystem, asystem, rsystem = (
    hfe_actor_handler.decoupling_vdw_system.remote(),
    hfe_actor_handler.annihilate_elec_system.remote(),
    hfe_actor_handler.recharging_ele_system.remote(),
)

### Start alchemy workers on the handler

In [None]:
# Each row pairs a system ObjectRef with its leg name.
keywords = [
    [dsystem, "dvdw"],
    [asystem, "aelec"],
    [rsystem, "relec"],
]
lambda_values = [i / 10 for i in range(11)]  # 0.0, 0.1, ..., 1.0
prod_steps = 50_000  # MD steps between stored samples
niterations = 100  # samples stored per lambda

In [None]:
# Queue all three legs on the single actor; the system ObjectRefs are resolved
# by Ray before alchemy_worker runs.
alchemy_workers = [
    hfe_actor_handler.alchemy_worker.remote(
        system_ref, leg_name, lambda_values, prod_steps, niterations
    )
    for system_ref, leg_name in keywords
]

### Analyze the result

In [None]:
# Block until every leg has finished and fetch the u_kln arrays.
alchemy_results = ray.get(alchemy_workers, timeout=None)

In [None]:
# One analyst call per leg. Because analyst declares num_returns=3, each
# entry is a list of three ObjectRefs.
analysis_worker = [
    hfe_actor_handler.analyst.remote(u_kln, lambda_values)
    for u_kln in alchemy_results
]

### Print out the result

In [None]:
# analyst declares num_returns=3, so every entry of `analysis_worker` is a list
# of three ObjectRefs. ray.get rejects nested lists, so resolve them per leg.
tmp_results = [ray.get(refs) for refs in analysis_worker]
results = {}
for (_system_ref, leg), (fe, sem, _overlap) in zip(keywords, tmp_results):
    results[leg] = {"fe": fe, "sem": sem}
    print(f"{leg:24}{fe: .3f} +- {sem: .3f} kcal/mol")
# Combine the legs: vdW coupling and charging in solvent, minus charging in vacuum.
total_fe = results["aelec"]["fe"] - results["relec"]["fe"] + results["dvdw"]["fe"]
# The legs are independent estimates, so their uncertainties add in quadrature.
total_sem = np.sqrt(
    results["aelec"]["sem"] ** 2
    + results["relec"]["sem"] ** 2
    + results["dvdw"]["sem"] ** 2
)
print("-" * 50)
print(f"solvation free energy:\t{total_fe: .3f} +- {total_sem: .3f} kcal/mol")

## B. Create three actor handles and run the HFE calculation on multiple nodes

This example requires three GPUs, one per actor.

In [None]:
# One actor per leg, all built from the same input files.
dvdw_handler, aele_handler, rele_handler = (
    HFE_actor.remote(system_xml, system_pdb, system_vac_xml, system_vac_pdb)
    for _ in range(3)
)

### Create an alchemical system for each handler

In [None]:
# Each actor builds only the alchemical system of its own leg.
dsystem, asystem, rsystem = (
    dvdw_handler.decoupling_vdw_system.remote(),
    aele_handler.annihilate_elec_system.remote(),
    rele_handler.recharging_ele_system.remote(),
)

### Start alchemy workers on handlers

In [None]:
# Each row: [actor handle, system ObjectRef, leg name].
keywords = [
    [dvdw_handler, dsystem, "dvdw"],
    [aele_handler, asystem, "aelec"],
    [rele_handler, rsystem, "relec"],
]
lambda_values = [i / 10 for i in range(11)]  # 0.0, 0.1, ..., 1.0
prod_steps = 50_000  # MD steps between stored samples
niterations = 100  # samples stored per lambda

In [None]:
# Launch every leg on its own actor, so the legs can run concurrently.
alchemy_workers = [
    handler.alchemy_worker.remote(
        system_ref, leg_name, lambda_values, prod_steps, niterations
    )
    for handler, system_ref, leg_name in keywords
]

### Analyze the result

In [None]:
# Block until all three legs have finished and fetch the u_kln arrays.
alchemy_results = ray.get(alchemy_workers, timeout=None)

In [None]:
# Analyze each leg on the actor that produced it. Rows of `keywords` are
# [handler, system_ref, leg_name], in the same order as `alchemy_results`.
# (The previous `k[idx][0]` indexed a stale/undefined `k`, not `keywords`.)
analysis_worker = [
    handler.analyst.remote(u_kln, lambda_values)
    for (handler, _system_ref, _leg), u_kln in zip(keywords, alchemy_results)
]

### Print out the result

In [None]:
# analyst declares num_returns=3, so every entry of `analysis_worker` is a list
# of three ObjectRefs. ray.get rejects nested lists, so resolve them per leg.
tmp_results = [ray.get(refs) for refs in analysis_worker]
results = {}
for (_handler, _system_ref, leg), (fe, sem, _overlap) in zip(keywords, tmp_results):
    results[leg] = {"fe": fe, "sem": sem}
    print(f"{leg:24}{fe: .3f} +- {sem: .3f} kcal/mol")
# Combine the legs: vdW coupling and charging in solvent, minus charging in vacuum.
total_fe = results["aelec"]["fe"] - results["relec"]["fe"] + results["dvdw"]["fe"]
# The legs are independent estimates, so their uncertainties add in quadrature.
total_sem = np.sqrt(
    results["aelec"]["sem"] ** 2
    + results["relec"]["sem"] ** 2
    + results["dvdw"]["sem"] ** 2
)
print("-" * 50)
print(f"solvation free energy:\t{total_fe: .3f} +- {total_sem: .3f} kcal/mol")