Skip to content

Commit

Permalink
Write out rigid water for GROMACS top file (#771)
Browse files Browse the repository at this point in the history
* support to write out settles section for gromacs

* add isrigid to molecule

* add default val for moleculeType isrigid

* revert change to moleculetype

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* update molecule and residue class, fix unit tests

* use site.name if no element info is found

* add patch for getting non-element site

* let element.py handle non-element issue

* minor fix

* add set_rigid method

* add rigid water fixture for settles tests, update mcf test

* update test for water settles section in gromacs top

* Update gmso/abc/abstract_site.py

Co-authored-by: CalCraven <54594941+CalCraven@users.noreply.github.com>

* Update gmso/abc/abstract_site.py

Co-authored-by: CalCraven <54594941+CalCraven@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* remove isrigid for residue

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* add more check when setting molecule/residue

* Update gmso/formats/top.py

Co-authored-by: CalCraven <54594941+CalCraven@users.noreply.github.com>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: CalCraven <54594941+CalCraven@users.noreply.github.com>
  • Loading branch information
3 people committed Jun 18, 2024
1 parent 5b3040a commit 444d9e1
Show file tree
Hide file tree
Showing 16 changed files with 332 additions and 44 deletions.
147 changes: 139 additions & 8 deletions gmso/abc/abstract_site.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""Basic interaction site in GMSO that all other sites will derive from."""

import warnings
from typing import Any, ClassVar, NamedTuple, Optional, Sequence, TypeVar, Union
from typing import Any, ClassVar, Optional, Sequence, TypeVar, Union

import numpy as np
import unyt as u
from pydantic import (
ConfigDict,
Field,
StrictInt,
StrictStr,
field_serializer,
field_validator,
Expand All @@ -20,8 +19,124 @@
from gmso.exceptions import GMSOError

PositionType = Union[Sequence[float], np.ndarray, u.unyt_array]
MoleculeType = NamedTuple("Molecule", name=StrictStr, number=StrictInt)
ResidueType = NamedTuple("Residue", name=StrictStr, number=StrictInt)


class Molecule(GMSOBase):
def __repr__(self):
return (
f"Molecule(name={self.name}, residue={self.residue}, isrigid={self.isrigid}"
)

__iterable_attributes__: ClassVar[set] = {
"name",
"number",
"isrigid",
}

__base_doc__: ClassVar[str] = "Molecule label for interaction sites."

name_: str = Field(
"",
validate_default=True,
description="Name of the molecule",
alias="name",
)
number_: int = Field(
0,
description="The index/number of the molecule",
alias="number",
)
isrigid_: bool = Field(
False,
description="Indicate whether the molecule is rigid",
)
model_config = ConfigDict(
alias_to_fields={
"name": "name_",
"number": "number_",
"isrigid": "isrigid_",
}
)

@property
def name(self) -> str:
"""Return the name of the molecule."""
return self.__dict__.get("name_")

@property
def number(self) -> int:
"""Return the index/number of the moleucle."""
return self.__dict__.get("number_")

@property
def isrigid(self) -> bool:
"""Return the rigid label of the molecule."""
return self.__dict__.get("isrigid_")

def __hash__(self):
return hash(tuple([(name, val) for name, val in self.__dict__.items()]))

def __eq__(self, other):
"""Test if two objects are equivalent."""
if isinstance(other, (list, tuple)):
return all(
[val1 == val2 for val1, val2 in zip(self.__dict__.values(), other)]
)
else:
return self.__dict__ == other.__dict__


class Residue(GMSOBase):
def __repr__(self):
return f"Residue(name={self.name}, residue={self.residue}"

__iterable_attributes__: ClassVar[set] = {
"name",
"number",
}

__base_doc__: ClassVar[str] = "Residue label for interaction sites."

name_: str = Field(
"",
validate_default=True,
description="Name of the residue",
alias="name",
)
number_: int = Field(
0,
description="The index/number of the residue",
alias="number",
)
model_config = ConfigDict(
alias_to_fields={
"name": "name_",
"number": "number_",
}
)

@property
def name(self) -> str:
"""Return the name of the residue."""
return self.__dict__.get("name_")

@property
def number(self) -> int:
"""Return the index/number of the residue."""
return self.__dict__.get("number_")

def __hash__(self):
return hash(tuple([(name, val) for name, val in self.__dict__.items()]))

def __eq__(self, other):
"""Test if two objects are equivalent."""
if isinstance(other, (list, tuple)):
return all(
[val1 == val2 for val1, val2 in zip(self.__dict__.values(), other)]
)
else:
return self.__dict__ == other.__dict__


SiteT = TypeVar("SiteT", bound="Site")

Expand Down Expand Up @@ -76,13 +191,13 @@ class Site(GMSOBase):
alias="group",
)

molecule_: Optional[MoleculeType] = Field(
molecule_: Optional[Union[Molecule, list, tuple]] = Field(
None,
description="Molecule label for the site, format of (molecule_name, molecule_number)",
alias="molecule",
)

residue_: Optional[ResidueType] = Field(
residue_: Optional[Union[Residue, list, tuple]] = Field(
None,
description="Residue label for the site, format of (residue_name, residue_number)",
alias="residue",
Expand Down Expand Up @@ -126,7 +241,7 @@ def group(self) -> str:
return self.__dict__.get("group_")

@property
def molecule(self) -> tuple:
def molecule(self):
"""Return the molecule of the site."""
return self.__dict__.get("molecule_")

Expand Down Expand Up @@ -185,12 +300,28 @@ def is_valid_position(cls, position):
return position

@field_validator("name_")
def inject_name(cls, value):
def parse_name(cls, value):
if value == "" or value is None:
return cls.__name__
else:
return value

@field_validator("residue_")
def parse_residue(cls, value):
if isinstance(value, (tuple, list)):
assert len(value) == 2
value = Residue(name=value[0], number=value[1])
return value

@field_validator("molecule_")
def parse_molecule(cls, value):
if isinstance(value, (tuple, list)):
if len(value) == 2:
value = Molecule(name=value[0], number=value[1])
elif len(value) == 3:
value = Molecule(name=value[0], number=value[1], isrigid=value[2])
return value

@classmethod
def __new__(cls, *args: Any, **kwargs: Any) -> SiteT:
if cls is Site:
Expand Down
10 changes: 9 additions & 1 deletion gmso/core/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,22 @@ def element_by_smarts_string(smarts_string, verbose=False):
GMSOError
If no matching element is found for the provided smarts string
"""
from lark import UnexpectedCharacters

from gmso.utils.io import import_

foyer = import_("foyer")
SMARTS = foyer.smarts.SMARTS

PARSER = SMARTS()

symbols = PARSER.parse(smarts_string).iter_subtrees_topdown()
try:
symbols = PARSER.parse(smarts_string).iter_subtrees_topdown()
except UnexpectedCharacters:
raise GMSOError(
f"Failed to find an element from SMARTS string {smarts_string}. "
f"The SMARTS string contained unexpected characters."
)

first_symbol = None
for symbol in symbols:
Expand Down
44 changes: 39 additions & 5 deletions gmso/core/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import itertools
import warnings
from copy import copy
from pathlib import Path

import numpy as np
import unyt as u
from boltons.setutils import IndexedSet

import gmso
from gmso.abc.abstract_site import Site
from gmso.abc.abstract_site import Molecule, Residue, Site
from gmso.abc.serialization_utils import unyt_to_dict
from gmso.core.angle import Angle
from gmso.core.angle_type import AngleType
Expand Down Expand Up @@ -313,7 +314,7 @@ def unique_site_labels(self, label_type="molecule", name_only=False):
unique_tags.add(label.name if label else None)
else:
for site in self.sites:
unique_tags.add(getattr(site, label_type))
unique_tags.add(copy(getattr(site, label_type)))
return unique_tags

@property
Expand Down Expand Up @@ -642,6 +643,18 @@ def get_scaling_factors(self, *, molecule_id=None):
]
)

def set_rigid(self, molecule):
"""Set molecule tags to rigid if they match the name or number specified.
Parameters
----------
molecule : str, Molecule, or tuple of 2
Specified the molecule name and number to be set rigid.
If only string is provided, make all molecule of that name rigid.
"""
for site in self.iter_sites(key="molecule", value=molecule):
site.molecule.isrigid = True

def remove_site(self, site):
"""Remove a site from the topology.
Expand Down Expand Up @@ -1382,9 +1395,30 @@ def iter_sites(self, key, value):
for site in self._sites:
if getattr(site, key) and getattr(site, key).name == value:
yield site
for site in self._sites:
if getattr(site, key) == value:
yield site
elif isinstance(value, (tuple, list)):
containers_dict = {"molecule": Molecule, "residue": Residue}
if len(value) == 2:
tmp = containers_dict[key](name=value[0], number=value[1])
elif len(value) == 3:
tmp = containers_dict[key](
name=value[0], number=value[1], isrigid=value[2]
)
else:
raise ValueError(
f"""
Argument value was passed as {value},
but should be an indexible iterable of
[name, number, isrigid] where name is type string,
number is type int, and isrigid is type bool.
"""
)
for site in self._sites:
if getattr(site, key) and getattr(site, key) == tmp:
yield site
else:
for site in self._sites:
if getattr(site, key) == value:
yield site

def iter_sites_by_residue(self, residue_tag):
"""Iterate through this topology's sites which contain this specific residue name.
Expand Down
7 changes: 5 additions & 2 deletions gmso/external/convert_mbuild.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from boltons.setutils import IndexedSet
from unyt import Unit

from gmso.abc.abstract_site import Residue
from gmso.core.atom import Atom
from gmso.core.bond import Bond
from gmso.core.box import Box
Expand Down Expand Up @@ -179,12 +180,14 @@ def to_mbuild(topology, infer_hierarchy=True):
particle = _parse_particle(particle_map, site)
# Try to add the particle to a residue level
residue_tag = (
site.residue if site.residue else ("DefaultResidue", 0)
site.residue
if site.residue
else Residue(name="DefaultResidue", number=0)
) # the 0 idx is placeholder and does nothing
if residue_tag in residue_dict:
residue_dict_particles[residue_tag] += [particle]
else:
residue_dict[residue_tag] = mb.Compound(name=residue_tag[0])
residue_dict[residue_tag] = mb.Compound(name=residue_tag.name)
residue_dict_particles[residue_tag] = [particle]

for key, item in residue_dict.items():
Expand Down
6 changes: 4 additions & 2 deletions gmso/formats/lammpsdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from unyt.array import allclose_units

import gmso
from gmso.abc.abstract_site import MoleculeType
from gmso.abc.abstract_site import Molecule
from gmso.core.angle import Angle
from gmso.core.atom import Atom
from gmso.core.atom_type import AtomType
Expand Down Expand Up @@ -488,7 +488,9 @@ def _get_atoms(filename, topology, base_unyts, type_list):
charge=charge,
position=coord,
atom_type=copy.deepcopy(type_list[int(atom_type) - 1]), # 0-index
molecule=MoleculeType(atom_line[1], int(atom_line[1]) - 1), # 0-index
molecule=Molecule(
name=atom_line[1], number=int(atom_line[1]) - 1
), # 0-index
)
element = element_by_mass(site.atom_type.mass.value)
site.name = element.name if element else site.atom_type.name
Expand Down
6 changes: 3 additions & 3 deletions gmso/formats/mol2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import unyt as u

from gmso import Atom, Bond, Box, Topology
from gmso.abc.abstract_site import MoleculeType, ResidueType
from gmso.abc.abstract_site import Molecule, Residue
from gmso.core.element import element_by_name, element_by_symbol
from gmso.formats.formats_registry import loads_as

Expand Down Expand Up @@ -151,8 +151,8 @@ def parse_ele(*symbols):
position=position.to("nm"),
element=element,
charge=charge,
residue=ResidueType(content[7], int(content[6])),
molecule=MoleculeType(molecule, 0),
residue=Residue(name=content[7], number=int(content[6])),
molecule=Molecule(name=molecule, number=0),
)
top.add_site(atom)

Expand Down
Loading

0 comments on commit 444d9e1

Please sign in to comment.