Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand Potential functionality for parameter interpolation #137

Merged
merged 8 commits into from Jul 28, 2021
75 changes: 66 additions & 9 deletions openff/interchange/components/potentials.py
@@ -1,9 +1,9 @@
import ast
from typing import TYPE_CHECKING, Dict, List, Set
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union

from openff.toolkit.typing.engines.smirnoff.parameters import ParameterHandler
from openff.utilities.utilities import requires_package
from pydantic import Field, validator
from pydantic import Field, PrivateAttr, validator

from openff.interchange.models import DefaultModel, PotentialKey, TopologyKey
from openff.interchange.types import ArrayQuantity, FloatQuantity
Expand All @@ -16,6 +16,7 @@ class Potential(DefaultModel):
"""Base class for storing applied parameters"""

parameters: Dict[str, FloatQuantity] = dict()
map_key: Optional[int] = None

@validator("parameters")
def validate_parameters(cls, v):
Expand All @@ -26,6 +27,41 @@ def validate_parameters(cls, v):
v[key] = FloatQuantity.validate_type(val)
return v

def __hash__(self):
return hash(tuple(self.parameters.values()))


class WrappedPotential(DefaultModel):
"""Model storing other Potential model(s) inside inner data"""

class InnerData(DefaultModel):
data: Dict[Potential, float]

_inner_data: InnerData = PrivateAttr()

def __init__(self, data):
if isinstance(data, Potential):
self._inner_data = self.InnerData(data={data: 1.0})
elif isinstance(data, dict):
self._inner_data = self.InnerData(data=data)

@property
def parameters(self):
keys = {
pot for pot in self._inner_data.data.keys() for pot in pot.parameters.keys()
}

params = dict()
for key in keys:
sum_ = 0.0
for pot, coeff in self._inner_data.data.items():
sum_ += coeff * pot.parameters[key]
params.update({key: sum_})
return params

def __repr__(self):
return str(self._inner_data.data)


class PotentialHandler(DefaultModel):
"""Base class for storing parametrized force field data"""
Expand All @@ -39,7 +75,7 @@ class PotentialHandler(DefaultModel):
dict(),
description="A mapping between TopologyKey objects and PotentialKey objects.",
)
potentials: Dict[PotentialKey, Potential] = Field(
potentials: Dict[PotentialKey, Union[Potential, WrappedPotential]] = Field(
dict(),
description="A mapping between PotentialKey objects and Potential objects.",
)
Expand Down Expand Up @@ -73,8 +109,14 @@ def get_force_field_parameters(self):

params: list = list()
for potential in self.potentials.values():
row = [val.magnitude for val in potential.parameters.values()]
params.append(row)
if isinstance(potential, Potential):
params.append([val.magnitude for val in potential.parameters.values()])
elif isinstance(potential, WrappedPotential):
for inner_pot in potential._inner_data.data.keys():
if inner_pot not in params:
params.append(
[val.magnitude for val in inner_pot.parameters.values()]
)

return jax.numpy.array(params)

Expand All @@ -89,17 +131,32 @@ def get_system_parameters(self, p=None):
mapping = self.get_mapping()
q: List = list()

for key in self.slot_map.keys():
q.append(p[mapping[self.slot_map[key]]])
for val in self.slot_map.values():
if val.bond_order:
p_ = p[0] * 0.0
for inner_pot, coeff in self.potentials[val]._inner_data.data.items():
p_ += p[mapping[inner_pot]] * coeff
q.append(p_)
else:
q.append(p[mapping[self.potentials[val]]])

return jax.numpy.array(q)

def get_mapping(self) -> Dict:
mapping: Dict = dict()
for idx, key in enumerate(self.potentials.keys()):
idx = 0
for key, pot in self.potentials.items():
for p in self.slot_map.values():
if key == p:
mapping.update({key: idx})
if isinstance(pot, Potential):
if pot not in mapping:
mapping.update({pot: idx})
idx += 1
elif isinstance(pot, WrappedPotential):
for inner_pot in pot._inner_data.data:
if inner_pot not in mapping:
mapping.update({inner_pot: idx})
idx += 1

return mapping

Expand Down
103 changes: 92 additions & 11 deletions openff/interchange/components/smirnoff.py
Expand Up @@ -26,9 +26,15 @@
from simtk import unit as omm_unit
from typing_extensions import Literal

from openff.interchange.components.potentials import Potential, PotentialHandler
from openff.interchange.components.potentials import (
Potential,
PotentialHandler,
WrappedPotential,
)
from openff.interchange.exceptions import (
InvalidParameterHandlerError,
MissingBondOrdersError,
MissingParametersError,
SMIRNOFFParameterAttributeNotImplementedError,
)
from openff.interchange.models import PotentialKey, TopologyKey
Expand Down Expand Up @@ -140,12 +146,55 @@ def allowed_parameter_handlers(cls):

@classmethod
def supported_parameters(cls):
return ["smirks", "id", "k", "length"]
return ["smirks", "id", "k", "length", "k_bondorder", "length_bondorder"]

@classmethod
def valence_terms(cls, topology):
return [list(b.atoms) for b in topology.topology_bonds]

def store_matches(
self,
parameter_handler: ParameterHandler,
topology: Union["Topology", "OFFBioTop"],
) -> None:
"""
Populate self.slot_map with key-val pairs of slots
and unique potential identifiers

"""
parameter_handler_name = getattr(parameter_handler, "_TAGNAME", None)
if self.slot_map:
# TODO: Should the slot_map always be reset, or should we be able to partially
# update it? Also Note the duplicated code in the child classes
self.slot_map = dict()
matches = parameter_handler.find_matches(topology)
for key, val in matches.items():
param = val.parameter_type
if param.k_bondorder or param.length_bondorder:
top_bond = topology.get_bond_between(*key) # type: ignore[union-attr]
fractional_bond_order = top_bond.bond.fractional_bond_order
if not fractional_bond_order:
raise MissingBondOrdersError(
"Interpolation currently requires bond orders pre-specified"
)
else:
fractional_bond_order = None
topology_key = TopologyKey(
atom_indices=key, bond_order=fractional_bond_order
)
potential_key = PotentialKey(
id=val.parameter_type.smirks, associated_handler=parameter_handler_name
)
self.slot_map[topology_key] = potential_key

valence_terms = self.valence_terms(topology)

parameter_handler._check_all_valence_terms_assigned(
assigned_terms=matches,
valence_terms=valence_terms,
exception_cls=UnassignedValenceParameterException,
)

def store_potentials(self, parameter_handler: "BondHandler") -> None:
"""
Populate self.potentials with key-val pairs of unique potential
Expand All @@ -154,15 +203,41 @@ def store_potentials(self, parameter_handler: "BondHandler") -> None:
"""
if self.potentials:
self.potentials = dict()
for potential_key in self.slot_map.values():
for topology_key, potential_key in self.slot_map.items():
smirks = potential_key.id
parameter_type = parameter_handler.get_parameter({"smirks": smirks})[0]
potential = Potential(
parameters={
"k": parameter_type.k,
"length": parameter_type.length,
},
)
if topology_key.bond_order:
bond_order = topology_key.bond_order
if parameter_type.k_bondorder:
data = parameter_type.k_bondorder
else:
data = parameter_type.length_bondorder
coeffs = _get_interpolation_coeffs(
fractional_bond_order=bond_order,
data=data,
)
pots = []
map_keys = [*data.keys()]
for map_key in map_keys:
pots.append(
Potential(
parameters={
"k": parameter_type.k_bondorder[map_key],
"length": parameter_type.length_bondorder[map_key],
},
map_key=map_key,
)
)
potential = WrappedPotential(
{pot: coeff for pot, coeff in zip(pots, coeffs)}
)
else:
potential = Potential( # type: ignore[assignment]
parameters={
"k": parameter_type.k,
"length": parameter_type.length,
},
)
self.potentials[potential_key] = potential

@classmethod
Expand Down Expand Up @@ -269,8 +344,6 @@ def store_constraints(
else:
# This constraint parameter depends on the BondHandler ...
if bond_handler is None:
from openff.interchange.exceptions import MissingParametersError

raise MissingParametersError(
f"Constraint with SMIRKS pattern {smirks} found with no distance "
"specified, and no corresponding bond parameters were found. The distance "
Expand Down Expand Up @@ -1006,6 +1079,14 @@ def library_charge_from_molecule(
return library_charge_type


def _get_interpolation_coeffs(fractional_bond_order, data):
x1, x2 = data.keys()
coeff1 = (x2 - fractional_bond_order) / (x2 - x1)
coeff2 = (fractional_bond_order - x1) / (x2 - x1)

return coeff1, coeff2


SMIRNOFF_POTENTIAL_HANDLERS = [
SMIRNOFFBondHandler,
SMIRNOFFConstraintHandler,
Expand Down
6 changes: 6 additions & 0 deletions openff/interchange/exceptions.py
Expand Up @@ -130,6 +130,12 @@ class MissingParametersError(BaseException):
"""


class MissingBondOrdersError(BaseException):
"""
Exception for when a parameter handler needs fractional bond orders but they are missing
"""


class MissingUnitError(ValueError):
"""
Exception for data missing a unit tag
Expand Down
5 changes: 5 additions & 0 deletions openff/interchange/models.py
Expand Up @@ -64,6 +64,11 @@ class TopologyKey(DefaultModel):
mult: Optional[int] = Field(
None, description="The index of this duplicate interaction"
)
bond_order: Optional[float] = Field(
None,
description="If this is a key to a WrappedPotential interpolating multiple parameter(s), "
"the bond order determining the coefficients of the wrapped potentials.",
)

def __hash__(self):
return hash((self.atom_indices, self.mult))
Expand Down
40 changes: 39 additions & 1 deletion openff/interchange/tests/components/test_potentials.py
@@ -1,7 +1,45 @@
from openff.interchange.components.potentials import PotentialHandler
from openff.toolkit.typing.engines.smirnoff.parameters import BondHandler
from openff.units import unit

from openff.interchange.components.potentials import (
Potential,
PotentialHandler,
WrappedPotential,
)
from openff.interchange.tests import BaseTest


class TestWrappedPotential(BaseTest):
def test_interpolated_potentials(self):
"""Test the construction of and .parameters getter of WrappedPotential"""

bt = BondHandler.BondType(
smirks="[#6X4:1]~[#8X2:2]",
id="bbo1",
k_bondorder1="100.0 * kilocalories_per_mole/angstrom**2",
k_bondorder2="200.0 * kilocalories_per_mole/angstrom**2",
length_bondorder1="1.4 * angstrom",
length_bondorder2="1.3 * angstrom",
)

pot1 = Potential(
parameters={"k": bt.k_bondorder[1], "length": bt.length_bondorder[1]}
)
pot2 = Potential(
parameters={"k": bt.k_bondorder[2], "length": bt.length_bondorder[2]}
)

interp_pot = WrappedPotential(data={pot1: 0.2, pot2: 0.8})
assert interp_pot.parameters == {
"k": 180 * unit.Unit("kilocalorie / angstrom ** 2 / mole"),
"length": 1.32 * unit.angstrom,
}

# Ensure a single Potential object can be wrapped with similar behavior
simple = WrappedPotential(data=pot2)
assert simple.parameters == pot2.parameters


class TestPotentialHandlerSubclassing(BaseTest):
def test_dummy_potential_handler(self):
handler = PotentialHandler(
Expand Down