Skip to content

Commit

Permalink
Merge pull request #232 from nlesc-nano/guess
Browse files Browse the repository at this point in the history
BUG: Fixed an issue wherein guessed parameters would overwrite ones that were explicitly specified
  • Loading branch information
BvB93 committed Apr 8, 2021
2 parents 3f5c0cc + 650fec6 commit 6ae3fa5
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 34 deletions.
77 changes: 45 additions & 32 deletions FOX/armc/guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,18 @@
"""

from __future__ import annotations

from types import MappingProxyType
from itertools import chain
from typing import (
Union,
Iterable,
Mapping,
MutableMapping,
Tuple,
Optional,
Dict,
Set,
Container,
FrozenSet,
TYPE_CHECKING
)

Expand All @@ -39,14 +38,12 @@
from ..ff import UFF_DF, SIGMA_DF, LJDataFrame, estimate_lj

if TYPE_CHECKING:
from ..classes import MultiMolecule
else:
from ..type_alias import MultiMolecule
from FOX import MultiMolecule

__all__ = ['guess_param']

Param = Literal['epsilon', 'sigma']
Mode = Literal[
ParamKind = Literal['epsilon', 'sigma']
ModeKind = Literal[
'ionic_radius',
'ion_radius',
'ionic_radii',
Expand All @@ -58,36 +55,40 @@
]

#: A :class:`frozenset` with alias for the :code:`"ion_radius"` guessing mode.
ION_SET: FrozenSet[str] = frozenset({
ION_SET = frozenset({
'ionic_radius',
'ion_radius',
'ionic_radii',
'ion_radii'
})

#: A :class:`frozenset` with alias for the :code:`"crystal_radius"` guessing mode.
CRYSTAL_SET: FrozenSet[str] = frozenset({
CRYSTAL_SET = frozenset({
'crystal_radius',
'crystal_radii'
})

#: A :class:`frozenset` containing all allowed values for the ``mode`` parameter.
MODE_SET: FrozenSet[str] = ION_SET | CRYSTAL_SET | {'rdf', 'uff'}
MODE_SET = ION_SET | CRYSTAL_SET | {'rdf', 'uff'}

#: A :class:`~collections.abc.Mapping` containing the default unit for each ``param`` value.
DEFAULT_UNIT: Mapping[Param, str] = MappingProxyType({
DEFAULT_UNIT: MappingProxyType[ParamKind, str] = MappingProxyType({
'epsilon': 'kcal/mol',
'sigma': 'angstrom',
})


def guess_param(mol_list: Iterable[MultiMolecule], param: Param,
mode: Mode = 'rdf',
cp2k_settings: Optional[MutableMapping] = None,
prm: Union[None, PathType, PRMContainer] = None,
psf_list: Optional[Iterable[Union[PathType, PSFContainer]]] = None,
unit: Optional[str] = None
) -> Dict[Tuple[str, str], float]:
def guess_param(
mol_list: Iterable[MultiMolecule],
param: ParamKind,
mode: ModeKind = 'rdf',
*,
cp2k_settings: None | MutableMapping = None,
prm: None | PathType | PRMContainer = None,
psf_list: None | Iterable[PathType | PSFContainer] = None,
unit: None | str = None,
param_mapping: None | Mapping[tuple[str, str], float] = None,
) -> Dict[Tuple[str, str], float]:
"""Estimate all Lennard-Jones missing forcefield parameters.
Examples
Expand All @@ -101,11 +102,11 @@ def guess_param(mol_list: Iterable[MultiMolecule], param: Param,
>>> prm = str(...)
>>> psf_list = [str(...), ...]
>>> epsilon_dict = guess_Param(mol_list, 'epsilon', prm=prm, psf_list=psf_list)
>>> sigma_dict = guess_Param(mol_list, 'sigma', prm=prm, psf_list=psf_list)
>>> epsilon_dict = guess_ParamKind(mol_list, 'epsilon', prm=prm, psf_list=psf_list)
>>> sigma_dict = guess_ParamKind(mol_list, 'sigma', prm=prm, psf_list=psf_list)
Parameters
ParamKindeters
----------
mol_list : :class:`Iterable[FOX.MultiMolecule] <collections.abc.Iterable>`
An iterable of molecules.
Expand Down Expand Up @@ -135,6 +136,10 @@ def guess_param(mol_list: Iterable[MultiMolecule], param: Param,
# Validate param and mode
param = _validate_arg(param, name='param', ref={'epsilon', 'sigma'}) # type: ignore
mode = _validate_arg(mode, name='mode', ref=MODE_SET) # type: ignore
if unit is not None:
convert_unit = Units.conversion_ratio(DEFAULT_UNIT[param], unit)
else:
convert_unit = 1

# Construct a set with all valid atoms types
mol_list = [mol.copy() for mol in mol_list]
Expand All @@ -152,6 +157,10 @@ def guess_param(mol_list: Iterable[MultiMolecule], param: Param,
if cp2k_settings is not None:
df.overlay_cp2k_settings(cp2k_settings)

if param_mapping is not None:
for k, v in param_mapping.items():
df.loc[k, param] = v / convert_unit

if prm is not None:
prm_: PRMContainer = prm if isinstance(prm, PRMContainer) else PRMContainer.read(prm)
df.overlay_prm(prm_)
Expand All @@ -165,8 +174,7 @@ def guess_param(mol_list: Iterable[MultiMolecule], param: Param,

# Construct the to-be returned series and set them to the correct units
ret = _guess_param(series, mode, mol_list=mol_list, prm_dict=prm_dict)
if unit is not None:
ret *= Units.conversion_ratio(DEFAULT_UNIT[param], unit)
ret *= convert_unit
return ret


Expand All @@ -189,10 +197,13 @@ def _validate_arg(value: str, name: str, ref: Container[str]) -> str:
return ret


def _guess_param(series: pd.Series, mode: Mode,
mol_list: Iterable[MultiMolecule],
prm_dict: MutableMapping[str, float],
unit: Optional[str] = None) -> pd.Series:
def _guess_param(
series: pd.Series,
mode: ModeKind,
mol_list: Iterable[MultiMolecule],
prm_dict: MutableMapping[str, float],
unit: None | str = None,
) -> pd.Series:
"""Perform the parameter guessing as specified by **mode**.
Returns
Expand Down Expand Up @@ -295,9 +306,11 @@ def _arithmetic_mean(a, b):


@prepend_exception('No reference parameters available for atom type: ', exception=KeyError)
def _set_radii(series: pd.Series,
prm_mapping: Mapping[str, float],
ref_mapping: Mapping[str, float]) -> None:
def _set_radii(
series: pd.Series,
prm_mapping: Mapping[str, float],
ref_mapping: Mapping[str, float],
) -> None:
if series.name == 'epsilon':
func = _geometric_mean
elif series.name == 'sigma':
Expand All @@ -319,7 +332,7 @@ def _set_radii(series: pd.Series,
series[i, j] = func(value_i, value_j)


def _nb_from_prm(prm: PRMContainer, param: Param) -> Dict[str, float]:
def _nb_from_prm(prm: PRMContainer, param: ParamKind) -> Dict[str, float]:
r"""Extract a dict from **prm** with all :math:`\varepsilon` or :math:`\sigma` values."""
if prm.nonbonded is None:
return {}
Expand Down
12 changes: 10 additions & 2 deletions FOX/armc/sanitization.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,16 @@ def _guess_param(mc: MonteCarloABC, prm: dict,
unit = UNIT_MAP[v.get('unit') or ('k_e' if param == 'epsilon' else 'angstrom')]
unit_lst.append(unit)

prm_series = guess_param(mc.molecule, param, mode=mode,
psf_list=psf, prm=prm_file, unit=unit)
try:
param_mapping = mc.param.param.loc[(k, param), 0].copy()
param_mapping.index = pd.MultiIndex.from_tuples(i.split() for i in param_mapping.index)
except KeyError:
param_mapping = None

prm_series = guess_param(
mc.molecule, param, mode,
psf_list=psf, prm=prm_file, unit=unit, param_mapping=param_mapping,
)
prm_dict = {' '.join(_k for _k in sorted(k)): v for k, v in prm_series.items()}
prm_dict['param'] = param
seq.append((k, prm_dict))
Expand Down

0 comments on commit 6ae3fa5

Please sign in to comment.