Skip to content

Commit

Permalink
Fix mass and charge in GSD writer; Add angles and dihedrals to GSD wr…
Browse files Browse the repository at this point in the history
…iter (#680)

* Set default mass and charge to 0.0

* fix typo in doc string

* fix attribute error when setting names of bonds

* trying to get gsd writer to work when not using parmed

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* fix naming when bond member is an AtomType

* fix angle writer, use f-strings

* rewrite bond parser to use same process as angles

* remove gsd import; change hoomd version in doc strings

* fix write_dihedrals

* re-use old dihedral sorting logic, remove commented out dihedral lines

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* change default mass and charge values back to None; fix handling of None mass and charges in gsd writer

* add missing f in warning message

* remove periods for consistency

* fix issue of setting typeid index values

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* adding unit test that checks contents of gsd file

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* minor change in docstring gmso/formats/gsd.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Co Quach <43968221+daico007@users.noreply.github.com>
  • Loading branch information
3 people committed Aug 3, 2022
1 parent 0a11627 commit dfda87f
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 132 deletions.
243 changes: 115 additions & 128 deletions gmso/formats/gsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
from gmso.formats.formats_registry import saves_as
from gmso.utils.geometry import coord_shift
from gmso.utils.io import has_gsd
from gmso.utils.sorting import natural_sort

__all__ = ["write_gsd"]

if has_gsd:
import gsd
import gsd.hoomd


Expand All @@ -31,7 +31,7 @@ def write_gsd(
shift_coords=True,
write_special_pairs=True,
):
"""Output a GSD file (HOOMD v2 default data format).
"""Output a GSD file (HOOMD v3 default data format).
The `GSD` binary file format is the native format of HOOMD-Blue. This file
can be used as a starting point for a HOOMD-Blue simulation, for analysis,
Expand Down Expand Up @@ -85,21 +85,21 @@ def write_gsd(
gsd_snapshot.configuration.box = np.array([lx, ly, lz, xy, xz, yz])

warnings.warn(
"Only writing particle and bond information."
" Angle and dihedral is not currently written to GSD files",
"Only writing particle, bond, angle, and dihedral information."
"Impropers and special pairs are not currently written to GSD files",
NotYetImplementedWarning,
)
_write_particle_information(
gsd_snapshot, top, xyz, ref_distance, ref_mass, ref_energy, rigid_bodies
)
# if write_special_pairs:
# _write_pair_information(gsd_snapshot, top)
if top.n_bonds > 0:
_write_bond_information(gsd_snapshot, top)
# if structure.angles:
# _write_angle_information(gsd_snapshot, top)
# if structure.rb_torsions:
# _write_dihedral_information(gsd_snapshot, top)
if top.n_angles > 0:
_write_angle_information(gsd_snapshot, top)
if top.n_dihedrals > 0:
_write_dihedral_information(gsd_snapshot, top)
# if write_special_pairs:
# _write_pair_information(gsd_snapshot, top)

with gsd.hoomd.open(filename, mode="wb") as gsd_file:
gsd_file.append(gsd_snapshot)
Expand All @@ -110,7 +110,7 @@ def _write_particle_information(
):
"""Write out the particle information."""
gsd_snapshot.particles.N = top.n_sites
warnings.warn("{} particles detected".format(top.n_sites))
warnings.warn(f"{top.n_sites} particles detected")
gsd_snapshot.particles.position = xyz / ref_distance

types = [
Expand All @@ -121,16 +121,18 @@ def _write_particle_information(
unique_types = list(set(types))
unique_types = sorted(unique_types)
gsd_snapshot.particles.types = unique_types
warnings.warn("{} unique particle types detected".format(len(unique_types)))
warnings.warn(f"{len(unique_types)} unique particle types detected")

typeids = np.array([unique_types.index(t) for t in types])
gsd_snapshot.particles.typeid = typeids

masses = np.array([site.mass for site in top.sites])
masses[masses == 0] = 1.0
masses[masses == None] = 1.0
gsd_snapshot.particles.mass = masses / ref_mass

charges = np.array([site.charge for site in top.sites])
charges[charges == None] = 0.0
e0 = u.physical_constants.eps_0.in_units(
u.elementary_charge**2 / u.Unit("kcal*angstrom/mol")
)
Expand Down Expand Up @@ -196,144 +198,129 @@ def _write_bond_information(gsd_snapshot, top):
"""
gsd_snapshot.bonds.N = top.n_bonds
warnings.warn("{} bonds detected".format(top.n_bonds))

unique_bond_types = set()
for bond in top.connections:
if isinstance(bond, Bond):
t1, t2 = (
bond.connection_members[0].atom_type,
bond.connection_members[1].atom_type,
)
if t1 is None or t2 is None:
t1, t2 = (
bond.connection_members[0].name,
bond.connection_members[1].name,
)
t1, t2 = sorted([t1, t2], key=lambda x: x.name)
bond_type = "-".join((t1.name, t2.name))

unique_bond_types.add(bond_type)
unique_bond_types = sorted(list(unique_bond_types))
gsd_snapshot.bonds.types = unique_bond_types
warnings.warn(
"{} unique bond types detected".format(len(unique_bond_types))
)

bond_typeids = []
warnings.warn(f"{top.n_bonds} bonds detected")
bond_groups = []
for bond in top.bonds:
if isinstance(bond, Bond):
t1, t2 = (
bond.connection_members[0].atom_type,
bond.connection_members[1].atom_type,
)
if t1 is None or t2 is None:
t1, t2 = (
bond.connection_members[0].name,
bond.connection_members[1].name,
)
t1, t2 = sorted([t1, t2], key=lambda x: x.name)

bond_type = "-".join((t1.name, t2.name))
bond_typeids.append(unique_bond_types.index(bond_type))
bond_groups.append(
(
top.sites.index(bond.connection_members[0]),
top.sites.index(bond.connection_members[1]),
)
)
bond_typeids = []
bond_types = []

for bond in top.bonds:
t1, t2 = list(bond.connection_members)
if all([t1.atom_type, t2.atom_type]):
_t1 = t1.atom_type.name
_t2 = t2.atom_type.name
else:
_t1 = t1.name
_t2 = t2.name
_t1, _t2 = sorted([_t1, _t2], key=lambda x: x)
bond_type = "-".join((_t1, _t2))
bond_types.append(bond_type)
bond_groups.append(sorted([top.sites.index(t1), top.sites.index(t2)]))

unique_bond_types = list(set(bond_types))
bond_typeids = [unique_bond_types.index(i) for i in bond_types]
gsd_snapshot.bonds.types = unique_bond_types
gsd_snapshot.bonds.typeid = bond_typeids
gsd_snapshot.bonds.group = bond_groups
warnings.warn(f"{len(unique_bond_types)} unique bond types detected")


def _write_angle_information(gsd_snapshot, structure):
def _write_angle_information(gsd_snapshot, top):
"""Write the angles in the system.
Parameters
----------
gsd_snapshot :
The file object of the GSD file being written
structure : parmed.Structure
Parmed structure object holding system information
Warnings
--------
Not yet implemented for gmso.core.topology objects
top : gmso.Topology
Topology object holding system information
"""
# gsd_snapshot.angles.N = len(structure.angles)

# unique_angle_types = set()
# for angle in structure.angles:
# t1, t2, t3 = angle.atom1.type, angle.atom2.type, angle.atom3.type
# t1, t3 = sorted([t1, t3], key=natural_sort)
# angle_type = ('-'.join((t1, t2, t3)))
# unique_angle_types.add(angle_type)
# unique_angle_types = sorted(list(unique_angle_types), key=natural_sort)
# gsd_snapshot.angles.types = unique_angle_types

# angle_typeids = []
# angle_groups = []
# for angle in structure.angles:
# t1, t2, t3 = angle.atom1.type, angle.atom2.type, angle.atom3.type
# t1, t3 = sorted([t1, t3], key=natural_sort)
# angle_type = ('-'.join((t1, t2, t3)))
# angle_typeids.append(unique_angle_types.index(angle_type))
# angle_groups.append((angle.atom1.idx, angle.atom2.idx,
# angle.atom3.idx))

# gsd_snapshot.angles.typeid = angle_typeids
# gsd_snapshot.angles.group = angle_groups
pass


def _write_dihedral_information(gsd_snapshot, structure):
gsd_snapshot.angles.N = top.n_angles
unique_angle_types = set()
angle_typeids = []
angle_groups = []
angle_types = []

for angle in top.angles:
t1, t2, t3 = list(angle.connection_members)
if all([t1.atom_type, t2.atom_type, t3.atom_type]):
_t1, _t3 = sorted(
[t1.atom_type.name, t3.atom_type.name], key=natural_sort
)
_t2 = t2.atom_type.name
else:
_t1, _t3 = sorted([t1.name, t3.name], key=natural_sort)
_t2 = t2.name

angle_type = "-".join((_t1, _t2, _t3))
angle_types.append(angle_type)
angle_groups.append(
(top.sites.index(t1), top.sites.index(t2), top.sites.index(t3))
)

unique_angle_types = list(set(angle_types))
angle_typeids = [unique_angle_types.index(i) for i in angle_types]
gsd_snapshot.angles.types = unique_angle_types
gsd_snapshot.angles.typeid = angle_typeids
gsd_snapshot.angles.group = angle_groups

warnings.warn(f"{top.n_angles} angles detected")
warnings.warn(f"{len(unique_angle_types)} unique angle types detected")


def _write_dihedral_information(gsd_snapshot, top):
"""Write the dihedrals in the system.
Parameters
----------
gsd_snapshot :
The file object of the GSD file being written
structure : parmed.Structure
Parmed structure object holding system information
Warnings
--------
Not yet implemented for gmso.core.topology objects
top : gmso.Topology
Topology object holding system information
"""
# gsd_snapshot.dihedrals.N = len(structure.rb_torsions)

# unique_dihedral_types = set()
# for dihedral in structure.rb_torsions:
# t1, t2 = dihedral.atom1.type, dihedral.atom2.type
# t3, t4 = dihedral.atom3.type, dihedral.atom4.type
# if [t2, t3] == sorted([t2, t3], key=natural_sort):
# dihedral_type = ('-'.join((t1, t2, t3, t4)))
# else:
# dihedral_type = ('-'.join((t4, t3, t2, t1)))
# unique_dihedral_types.add(dihedral_type)
# unique_dihedral_types = sorted(list(unique_dihedral_types), key=natural_sort)
# gsd_snapshot.dihedrals.types = unique_dihedral_types

# dihedral_typeids = []
# dihedral_groups = []
# for dihedral in structure.rb_torsions:
# t1, t2 = dihedral.atom1.type, dihedral.atom2.type
# t3, t4 = dihedral.atom3.type, dihedral.atom4.type
# if [t2, t3] == sorted([t2, t3], key=natural_sort):
# dihedral_type = ('-'.join((t1, t2, t3, t4)))
# else:
# dihedral_type = ('-'.join((t4, t3, t2, t1)))
# dihedral_typeids.append(unique_dihedral_types.index(dihedral_type))
# dihedral_groups.append((dihedral.atom1.idx, dihedral.atom2.idx,
# dihedral.atom3.idx, dihedral.atom4.idx))

# gsd_snapshot.dihedrals.typeid = dihedral_typeids
# gsd_snapshot.dihedrals.group = dihedral_groups
pass
gsd_snapshot.dihedrals.N = top.n_dihedrals
dihedral_groups = []
dihedral_types = []

for dihedral in top.dihedrals:
t1, t2, t3, t4 = list(dihedral.connection_members)
if all([t.atom_type for t in [t1, t2, t3, t4]]):
_t1, _t4 = sorted(
[t1.atom_type.name, t4.atom_type.name], key=natural_sort
)
_t3 = t3.atom_type.name
_t2 = t2.atom_type.name
else:
_t1, _t4 = sorted([t1.name, t4.name], key=natural_sort)
_t2 = t2.name
_t3 = t3.name

if [_t2, _t3] == sorted([_t2, _t3], key=natural_sort):
dihedral_type = "-".join((_t1, _t2, _t3, _t4))
else:
dihedral_type = "-".join((_t4, _t3, _t2, _t1))

dihedral_types.append(dihedral_type)
dihedral_groups.append(
(
top.sites.index(t1),
top.sites.index(t2),
top.sites.index(t3),
top.sites.index(t4),
)
)

unique_dihedral_types = list(set(dihedral_types))
dihedral_typeids = [unique_dihedral_types.index(i) for i in dihedral_types]
gsd_snapshot.dihedrals.types = unique_dihedral_types
gsd_snapshot.dihedrals.typeid = dihedral_typeids
gsd_snapshot.dihedrals.group = dihedral_groups

warnings.warn(f"{top.n_dihedrals} dihedrals detected")
warnings.warn(
f"{len(unique_dihedral_types)} unique dihedral types detected"
)


def _prepare_box_information(top):
Expand Down
27 changes: 23 additions & 4 deletions gmso/tests/test_gsd.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import gsd.hoomd
import mbuild as mb
import pytest
import unyt as u

from gmso.external.convert_mbuild import from_mbuild
from gmso.external.convert_parmed import from_parmed
from gmso.tests.base_test import BaseTest
from gmso.utils.io import get_fn, has_gsd, has_parmed, import_
Expand All @@ -12,18 +15,34 @@
@pytest.mark.skipif(not has_gsd, reason="gsd is not installed")
@pytest.mark.skipif(not has_parmed, reason="ParmEd is not installed")
class TestGsd(BaseTest):
# TODO: Have these tests not depend on parmed
def test_write_gsd(self):
def test_write_gsd_untyped(self):
comp = mb.load("CCCC", smiles=True)
system = mb.fill_box(comp, n_compounds=3, density=100)
top = from_mbuild(system)
top.identify_connections()
top.save("out.gsd")
with gsd.hoomd.open("out.gsd") as traj:
snap = traj[0]
assert all([i in snap.particles.types for i in ["C", "H"]])
assert all([i in snap.bonds.types for i in ["C-C", "C-H"]])
assert all([i in snap.angles.types for i in ["C-C-C", "C-C-H"]])
assert all(
[i in snap.dihedrals.types for i in ["C-C-C-C", "C-C-C-H"]]
)

def test_write_gsd(self, hierarchical_compound):
top = from_mbuild(hierarchical_compound)
top.save("out.gsd")

def test_write_gsd_pmd(self):
top = from_parmed(
pmd.load_file(get_fn("ethane.top"), xyz=get_fn("ethane.gro"))
)

top.save("out.gsd")

def test_write_gsd_non_orthogonal(self):
top = from_parmed(
pmd.load_file(get_fn("ethane.top"), xyz=get_fn("ethane.gro"))
)
top.box.angles = u.degree * [90, 90, 120]

top.save("out.gsd")

0 comments on commit dfda87f

Please sign in to comment.