diff --git a/topology/__init__.py b/topology/__init__.py index 020520d25..8e5f2ca33 100644 --- a/topology/__init__.py +++ b/topology/__init__.py @@ -6,8 +6,11 @@ from .core.connection import Connection from .core.bond import Bond from .core.angle import Angle +from .core.dihedral import Dihedral + from .core.potential import Potential from .core.atom_type import AtomType from .core.bond_type import BondType from .core.angle_type import AngleType +from .core.dihedral_type import DihedralType diff --git a/topology/core/dihedral.py b/topology/core/dihedral.py new file mode 100644 index 000000000..c6b713088 --- /dev/null +++ b/topology/core/dihedral.py @@ -0,0 +1,65 @@ +import warnings + +from topology.core.connection import Connection +from topology.core.dihedral_type import DihedralType +from topology.exceptions import TopologyError + + +class Dihedral(Connection): + """A 4-partner connection between sites. + + Partners + -------- + connection_members: list of topology.Site + Should be length 4 + connection_type : topology.DihedralType + name : name of the dihedral + inherits the name attribute from Connection + + Notes + ----- + Inherits some methods from Connection: + __eq__, __repr__, _validate methods + Addiitonal _validate methods are presented + """ + + def __init__(self, connection_members=[], connection_type=None, name="Dihedral"): + connection_members = _validate_four_partners(connection_members) + connection_type = _validate_dihedraltype(connection_type) + + super(Dihedral, self).__init__(connection_members=connection_members, + connection_type=connection_type, name=name) + + def __eq__(self, other): + return hash(self) == hash(other) + + def __hash__(self): + if self.connection_type: + return hash( + tuple( + ( + self.name, + self.connection_type, + tuple(self.connection_members), + ) + ) + ) + return hash(tuple(self.connection_members)) + + +def _validate_four_partners(connection_members): + """Ensure 4 partners are involved in Dihedral""" + if len(connection_members) != 4: + raise TopologyError("Trying to create an Dihedral " + "with {} connection members". format(len(connection_members))) + + return connection_members + + +def _validate_dihedraltype(contype): + """Ensure connection_type is a DihedralType """ + if contype is None: + warnings.warn("Non-parametrized Dihedral detected") + elif not isinstance(contype, DihedralType): + raise TopologyError("Supplied non-DihedralType {}".format(contype)) + return contype diff --git a/topology/core/dihedral_type.py b/topology/core/dihedral_type.py new file mode 100644 index 000000000..dc88bb701 --- /dev/null +++ b/topology/core/dihedral_type.py @@ -0,0 +1,67 @@ +import warnings +import unyt as u + +from topology.core.potential import Potential +from topology.exceptions import TopologyError + + +class DihedralType(Potential): + """A Potential between 4-bonded partners. + + Parameters + ---------- + name : str + expression : str or sympy.Expression + See `Potential` documentation for more information + parameters : dict {str, unyt.unyt_quantity} + See `Potential` documentation for more information + independent vars : set of str + See `Potential` documentation for more information + member_types : list of topology.AtomType.name (str) + + Notes + ---- + Inherits many functions from topology.Potential: + __eq__, _validate functions + """ + + def __init__(self, + name='DihedralType', + expression='k * (1 + cos(n * phi - phi_eq))**2', + parameters={ + 'k': 1000 * u.Unit('kJ / (deg**2)'), + 'phi_eq': 180 * u.deg, + 'n': 1*u.dimensionless + }, + independent_variables={'phi'}, + member_types=[]): + + super(DihedralType, self).__init__(name=name, expression=expression, + parameters=parameters, independent_variables=independent_variables) + + self._member_types = _validate_four_member_type_names(member_types) + + @property + def member_types(self): + return self._member_types + + @member_types.setter + def member_types(self, val): + if self.member_types != val: + warnings.warn("Changing an DihedralType's constituent " + "member types: {} to {}".format(self.member_types, val)) + self._member_types = _validate_four_member_type_names(val) + + def __repr__(self): + return "".format(self.name, id(self)) + +def _validate_four_member_type_names(types): + """Ensure 4 partners are involved in DihedralType""" + if len(types) != 4 and len(types) != 0: + raise TopologyError("Trying to create an DihedralType " + "with {} constituent types". format(len(types))) + if not all([isinstance(t, str) for t in types]): + raise TopologyError("Types passed to DihedralType " + "need to be strings corresponding to AtomType names") + + return types diff --git a/topology/core/topology.py b/topology/core/topology.py index f46d3d4c1..012549cfc 100644 --- a/topology/core/topology.py +++ b/topology/core/topology.py @@ -6,9 +6,11 @@ from topology.core.bond import Bond from topology.core.angle import Angle +from topology.core.dihedral import Dihedral from topology.core.potential import Potential from topology.core.bond_type import BondType from topology.core.angle_type import AngleType +from topology.core.dihedral_type import DihedralType from topology.exceptions import TopologyError @@ -25,18 +27,18 @@ def __init__(self, name="Topology", box=None): self._name = name self._box = box self._sites = IndexedSet() + self._typed = False self._connections = IndexedSet() self._bonds = IndexedSet() self._angles = IndexedSet() - + self._dihedrals = IndexedSet() self._subtops = IndexedSet() - self._typed = False self._atom_types = IndexedSet() self._connection_types = IndexedSet() self._bond_types = IndexedSet() self._angle_types = IndexedSet() - + self._dihedral_types = IndexedSet() self._combining_rule = 'lorentz' @property @@ -118,6 +120,9 @@ def add_connection(self, connection, update_types=True): self.update_bonds() elif isinstance(connection, Angle): self.update_angles() + elif isinstance(connection, Dihedral): + self.update_dihedrals() + self.update_connections() if update_types: @@ -127,6 +132,8 @@ def add_connection(self, connection, update_types=True): self.update_bond_types() elif isinstance(connection, Angle): self.update_angle_types() + elif isinstance(connection, Dihedral): + self.update_dihedral_types() self.update_connection_types() def add_subtopology(self, subtop): @@ -151,6 +158,10 @@ def n_bonds(self): def n_angles(self): return len(self.angles) + @property + def n_dihedrals(self): + return len(self.dihedrals) + @property def subtops(self): return self._subtops @@ -175,6 +186,10 @@ def bonds(self): def angles(self): return self._angles + @property + def dihedrals(self): + return self._dihedrals + @property def atom_types(self): return self._atom_types @@ -191,6 +206,10 @@ def bond_types(self): def angle_types(self): return self._angle_types + @property + def dihedral_types(self): + return self._dihedral_types + @property def atom_type_expressions(self): return list(set([atype.expression for atype in self.atom_types])) @@ -207,24 +226,30 @@ def bond_type_expressions(self): def angle_type_expressions(self): return list(set([atype.expression for atype in self.angle_types])) + @property + def dihedral_type_expressions(self): + return list(set([atype.expression for atype in self.dihedral_types])) + def update_top(self, update_types=True): """ Update the entire topology's attributes Notes ----- - Will update: sites, connections, bonds, angles, - atom_types, connectiontypes, bondtypes, angletypes + Will update: sites, connections, bonds, angles, dihedrals + atom_types, connectiontypes, bondtypes, angletypes, dihedral_types """ self.update_sites() self.update_connections() self.update_bonds() self.update_angles() + self.update_dihedrals() if update_types: self.update_atom_types() self.update_connection_types() self.update_bond_types() self.update_angle_types() + self.update_dihedral_types() self.is_typed() def update_sites(self): @@ -251,6 +276,10 @@ def update_angles(self): """ Rebuild the angle list by filtering through connection list """ self._angles = [a for a in self.connections if isinstance(a, Angle)] + def update_dihedrals(self): + """ Rebuild the dihedral list by filtering through connection list """ + self._dihedrals = [d for d in self.connections if isinstance(d, Dihedral)] + def update_atom_types(self): """ Update the atom types based on the site list """ #self._atom_types = [] @@ -296,6 +325,18 @@ def update_angle_types(self): elif a.connection_type not in self.angle_types: self.angle_types.add(a.connection_type) + def update_dihedral_types(self): + """ Update the dihedral types based on the dihedral list """ + #self._dihedral_types = [] + for d in self.dihedrals: + if d.connection_type is None: + warnings.warn("Non-parametrized Dihedral {} detected".format(d)) + elif not isinstance(d.connection_type, DihedralType): + raise TopologyError("Non-DihedralType {} found in Dihedral {}".format( + d.connection_type, d)) + elif d.connection_type not in self.dihedral_types: + self.dihedral_types.add(d.connection_type) + def __repr__(self): descr = list('<') descr.append(self.name + ' ') diff --git a/topology/tests/test_angle.py b/topology/tests/test_angle.py index b908a7993..0ffbddc01 100644 --- a/topology/tests/test_angle.py +++ b/topology/tests/test_angle.py @@ -44,7 +44,7 @@ def test_angle_parametrized(self): assert site3.n_connections == 1 assert len(connect.connection_members) == 3 assert connect.connection_type is not None - assert connect.name == "angle_name" + assert connect.name == 'angle_name' def test_angle_fake(self): site1 = Site(name='site1') @@ -67,7 +67,7 @@ def test_angle_constituent_types(self): site3 = Site(name='site3', position=[1,1,0], atom_type=AtomType(name='C')) angtype = AngleType(member_types=[site1.atom_type.name, site2.atom_type.name, site3.atom_type.name]) - ang = Angle(connection_members=[site1, site2,site3], + ang = Angle(connection_members=[site1, site2, site3], connection_type=angtype) assert 'A' in ang.connection_type.member_types assert 'B' in ang.connection_type.member_types diff --git a/topology/tests/test_bond.py b/topology/tests/test_bond.py index 8ab011d86..0ec04d270 100644 --- a/topology/tests/test_bond.py +++ b/topology/tests/test_bond.py @@ -38,7 +38,7 @@ def test_bond_parametrized(self): assert site2.n_connections == 1 assert len(connect.connection_members) == 2 assert connect.connection_type is not None - assert connect.name == "bond_name" + assert connect.name == 'bond_name' def test_bond_fake(self): site1 = Site(name='site1') diff --git a/topology/tests/test_dihedral.py b/topology/tests/test_dihedral.py new file mode 100644 index 000000000..b116acb27 --- /dev/null +++ b/topology/tests/test_dihedral.py @@ -0,0 +1,107 @@ +import pytest + +from topology.core.dihedral import Dihedral +from topology.core.dihedral_type import DihedralType +from topology.core.atom_type import AtomType +from topology.core.site import Site +from topology.tests.base_test import BaseTest +from topology.exceptions import TopologyError + + +class TestDihedral(BaseTest): + def test_dihedral_nonparametrized(self): + site1 = Site(name='site1') + site2 = Site(name='site2') + site3 = Site(name='site3') + site4 = Site(name='site4') + + assert site1.n_connections == 0 + assert site2.n_connections == 0 + assert site3.n_connections == 0 + assert site4.n_connections == 0 + + connect = Dihedral(connection_members=[site1, site2, site3, site4]) + + assert site1.n_connections == 1 + assert site2.n_connections == 1 + assert site3.n_connections == 1 + assert site4.n_connections == 1 + assert connect.connection_type is None + + def test_dihedral_parametrized(self): + site1 = Site(name='site1') + site2 = Site(name='site2') + site3 = Site(name='site3') + site4 = Site(name='site4') + + assert site1.n_connections == 0 + assert site2.n_connections == 0 + assert site3.n_connections == 0 + assert site4.n_connections == 0 + dihedral_type = DihedralType() + + connect = Dihedral(connection_members=[site1, site2, site3, site4], + connection_type=dihedral_type, + name='dihedral_name') + + assert site1.n_connections == 1 + assert site2.n_connections == 1 + assert site3.n_connections == 1 + assert site4.n_connections == 1 + assert len(connect.connection_members) == 4 + assert connect.connection_type is not None + assert connect.name == 'dihedral_name' + + def test_dihedral_fake(self): + site1 = Site(name='site1') + site2 = Site(name='site2') + site3 = Site(name='site3') + site4 = Site(name='site4') + with pytest.raises(TopologyError): + Dihedral(connection_members=['fakesite1', 'fakesite2', 4.2]) + + def test_dihedral_fake_dihedraltype(self): + site1 = Site(name='site1') + site2 = Site(name='site2') + site3 = Site(name='site3') + site4 = Site(name='site4') + with pytest.raises(TopologyError): + Dihedral(connection_members=[site1, site2, site3, site4], + connection_type='Fake dihedraltype') + + def test_dihedral_constituent_types(self): + site1 = Site(name='site1', position=[0,0,0], atom_type=AtomType(name='A')) + site2 = Site(name='site2', position=[1,0,0], atom_type=AtomType(name='B')) + site3 = Site(name='site3', position=[1,1,0], atom_type=AtomType(name='C')) + site4 = Site(name='site4', position=[1,1,4], atom_type=AtomType(name='D')) + dihtype = DihedralType(member_types=[site1.atom_type.name, + site2.atom_type.name, + site3.atom_type.name, + site4.atom_type.name]) + dih = Dihedral(connection_members=[site1, site2, site3, site4], + connection_type=dihtype) + assert 'A' in dih.connection_type.member_types + assert 'B' in dih.connection_type.member_types + assert 'C' in dih.connection_type.member_types + assert 'D' in dih.connection_type.member_types + + def test_dihedral_eq(self): + site1 = Site(name='site1', position=[0, 0, 0]) + site2 = Site(name='site2', position=[1, 0, 0]) + site3 = Site(name='site3', position=[1, 1, 0]) + site4 = Site(name='site4', position=[1, 1, 1]) + + ref_dihedral = Dihedral( + connection_members=[site1, site2, site3, site4], + ) + + same_dihedral = Dihedral( + connection_members=[site1, site2, site3, site4], + ) + + diff_dihedral = Dihedral( + connection_members=[site1, site2, site4, site3], + ) + + assert ref_dihedral == same_dihedral + assert ref_dihedral != diff_dihedral \ No newline at end of file diff --git a/topology/tests/test_topology.py b/topology/tests/test_topology.py index f631854eb..a74d99bbb 100644 --- a/topology/tests/test_topology.py +++ b/topology/tests/test_topology.py @@ -10,9 +10,11 @@ from topology.core.site import Site from topology.core.bond import Bond from topology.core.angle import Angle +from topology.core.dihedral import Dihedral from topology.core.atom_type import AtomType from topology.core.bond_type import BondType from topology.core.angle_type import AngleType +from topology.core.dihedral_type import DihedralType from topology.external.convert_parmed import from_parmed from topology.tests.base_test import BaseTest @@ -142,6 +144,21 @@ def test_eq_angles(self): assert ref != bad_angle_type + @pytest.mark.skipif(not has_parmed, reason="ParmEd is not installed") + def test_eq_dihedrals(self): + ref = pmd.load_file(get_fn('ethane.top'), + xyz=get_fn('ethane.gro')) + + missing_dihedral = deepcopy(ref) + missing_dihedral.rb_torsions[0].delete() + + assert ref != missing_dihedral + + bad_dihedral_type = deepcopy(ref) + bad_dihedral_type.rb_torsion_types[0].k = 22 + + assert ref != bad_dihedral_type + @pytest.mark.skipif(not has_parmed, reason="ParmEd is not installed") def test_eq_overall(self): ref = pmd.load_file(get_fn('ethane.top'), @@ -224,14 +241,8 @@ def test_top_update(self): top.add_site(site1) site2 = Site(name='site2', atom_type=atomtype) top.add_site(site2) + assert top.n_sites == 2 - #assert len(top.atom_types) == 0 - #assert len(top.atom_type_expressions) == 0 - #assert top.n_connections == 0 - #assert len(top.connection_types) == 0 - #assert len(top.connection_type_expressions) == 0 - #top.update_atom_types() - #assert top.n_sites == 2 assert len(top.atom_types) == 1 assert len(top.atom_type_expressions) == 1 assert top.n_connections == 0 @@ -243,13 +254,7 @@ def test_top_update(self): connection_12 = Bond(connection_members=[site1, site2], connection_type=ctype) top.add_connection(connection_12) - #assert top.n_sites == 2 - #assert len(top.atom_types) == 1 - #assert len(top.atom_type_expressions) == 1 - #assert top.n_connections == 1 - #assert len(top.connection_types) == 0 - #assert len(top.connection_type_expressions) == 0 - #top.update_connection_types() + assert top.n_sites == 2 assert len(top.atom_types) == 1 assert len(top.atom_type_expressions) == 1 @@ -285,11 +290,7 @@ def test_atomtype_update(self): site2 = Site('b', atom_type=atype2) top.add_site(site1) top.add_site(site2) - #assert top.n_sites == 2 - #assert len(top.atom_types) == 0 - #assert len(top.atom_type_expressions) == 0 - #top.update_atom_types() assert top.n_sites == 2 assert len(top.atom_types) == 2 assert len(top.atom_type_expressions) == 2 @@ -307,17 +308,6 @@ def test_bond_bondtype_update(self): top.add_site(site2) top.add_connection(bond) - #assert top.n_connections == 1 - #assert top.n_bonds == 0 - #assert len(top.bond_types) == 0 - #assert len(top.bond_type_expressions) == 0 - - #top.update_bond_list() - #assert top.n_bonds == 1 - #assert len(top.bond_types) == 0 - #assert len(top.bond_type_expressions) == 0 - - #top.update_bond_types() assert top.n_bonds == 1 assert len(top.bond_types) == 1 assert len(top.bond_type_expressions) == 1 @@ -337,22 +327,33 @@ def test_angle_angletype_update(self): top.add_site(site3) top.add_connection(angle) - #assert top.n_connections == 1 - #assert top.n_angles == 0 - #assert len(top.angle_types) == 0 - #assert len(top.angle_type_expressions) == 0 - - #top.update_angle_list() - #assert top.n_angles == 1 - #assert len(top.angle_types) == 0 - #assert len(top.angle_type_expressions) == 0 - - #top.update_angle_types() assert top.n_angles == 1 assert len(top.angle_types) == 1 assert len(top.angle_type_expressions) == 1 assert len(top.atom_type_expressions) == 2 + def test_dihedral_dihedraltype_update(self): + top = Topology() + + atype1 = AtomType(expression='sigma + epsilon') + atype2 = AtomType(expression='sigma * epsilon') + site1 = Site('a', atom_type=atype1) + site2 = Site('b', atom_type=atype2) + site3 = Site('c', atom_type=atype2) + site4 = Site('d', atom_type=atype1) + atype = DihedralType() + dihedral = Dihedral(connection_members=[site1, site2, site3, site4], connection_type=atype) + top.add_site(site1) + top.add_site(site2) + top.add_site(site3) + top.add_site(site4) + top.add_connection(dihedral) + + assert top.n_dihedrals == 1 + assert len(top.dihedral_types) == 1 + assert len(top.dihedral_type_expressions) == 1 + assert len(top.atom_type_expressions) == 2 + def test_add_subtopology(self): top = Topology() subtop = SubTopology()