Skip to content

Commit

Permalink
Merge pull request #146 from uppittu11/dihedral_class
Browse files Browse the repository at this point in the history
Add dihedral class
  • Loading branch information
mattwthompson committed Oct 11, 2019
2 parents 023c965 + bdca0f5 commit 393b08b
Show file tree
Hide file tree
Showing 8 changed files with 332 additions and 48 deletions.
3 changes: 3 additions & 0 deletions topology/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
from .core.connection import Connection
from .core.bond import Bond
from .core.angle import Angle
from .core.dihedral import Dihedral


from .core.potential import Potential
from .core.atom_type import AtomType
from .core.bond_type import BondType
from .core.angle_type import AngleType
from .core.dihedral_type import DihedralType
65 changes: 65 additions & 0 deletions topology/core/dihedral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import warnings

from topology.core.connection import Connection
from topology.core.dihedral_type import DihedralType
from topology.exceptions import TopologyError


class Dihedral(Connection):
"""A 4-partner connection between sites.
Partners
--------
connection_members: list of topology.Site
Should be length 4
connection_type : topology.DihedralType
name : name of the dihedral
inherits the name attribute from Connection
Notes
-----
Inherits some methods from Connection:
__eq__, __repr__, _validate methods
Addiitonal _validate methods are presented
"""

def __init__(self, connection_members=[], connection_type=None, name="Dihedral"):
connection_members = _validate_four_partners(connection_members)
connection_type = _validate_dihedraltype(connection_type)

super(Dihedral, self).__init__(connection_members=connection_members,
connection_type=connection_type, name=name)

def __eq__(self, other):
return hash(self) == hash(other)

def __hash__(self):
if self.connection_type:
return hash(
tuple(
(
self.name,
self.connection_type,
tuple(self.connection_members),
)
)
)
return hash(tuple(self.connection_members))


def _validate_four_partners(connection_members):
"""Ensure 4 partners are involved in Dihedral"""
if len(connection_members) != 4:
raise TopologyError("Trying to create an Dihedral "
"with {} connection members". format(len(connection_members)))

return connection_members


def _validate_dihedraltype(contype):
"""Ensure connection_type is a DihedralType """
if contype is None:
warnings.warn("Non-parametrized Dihedral detected")
elif not isinstance(contype, DihedralType):
raise TopologyError("Supplied non-DihedralType {}".format(contype))
return contype
67 changes: 67 additions & 0 deletions topology/core/dihedral_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import warnings
import unyt as u

from topology.core.potential import Potential
from topology.exceptions import TopologyError


class DihedralType(Potential):
"""A Potential between 4-bonded partners.
Parameters
----------
name : str
expression : str or sympy.Expression
See `Potential` documentation for more information
parameters : dict {str, unyt.unyt_quantity}
See `Potential` documentation for more information
independent vars : set of str
See `Potential` documentation for more information
member_types : list of topology.AtomType.name (str)
Notes
----
Inherits many functions from topology.Potential:
__eq__, _validate functions
"""

def __init__(self,
name='DihedralType',
expression='k * (1 + cos(n * phi - phi_eq))**2',
parameters={
'k': 1000 * u.Unit('kJ / (deg**2)'),
'phi_eq': 180 * u.deg,
'n': 1*u.dimensionless
},
independent_variables={'phi'},
member_types=[]):

super(DihedralType, self).__init__(name=name, expression=expression,
parameters=parameters, independent_variables=independent_variables)

self._member_types = _validate_four_member_type_names(member_types)

@property
def member_types(self):
return self._member_types

@member_types.setter
def member_types(self, val):
if self.member_types != val:
warnings.warn("Changing an DihedralType's constituent "
"member types: {} to {}".format(self.member_types, val))
self._member_types = _validate_four_member_type_names(val)

def __repr__(self):
return "<DihedralType {}, id {}>".format(self.name, id(self))

def _validate_four_member_type_names(types):
"""Ensure 4 partners are involved in DihedralType"""
if len(types) != 4 and len(types) != 0:
raise TopologyError("Trying to create an DihedralType "
"with {} constituent types". format(len(types)))
if not all([isinstance(t, str) for t in types]):
raise TopologyError("Types passed to DihedralType "
"need to be strings corresponding to AtomType names")

return types
51 changes: 46 additions & 5 deletions topology/core/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

from topology.core.bond import Bond
from topology.core.angle import Angle
from topology.core.dihedral import Dihedral
from topology.core.potential import Potential
from topology.core.bond_type import BondType
from topology.core.angle_type import AngleType
from topology.core.dihedral_type import DihedralType
from topology.exceptions import TopologyError


Expand All @@ -25,18 +27,18 @@ def __init__(self, name="Topology", box=None):
self._name = name
self._box = box
self._sites = IndexedSet()
self._typed = False
self._connections = IndexedSet()
self._bonds = IndexedSet()
self._angles = IndexedSet()

self._dihedrals = IndexedSet()
self._subtops = IndexedSet()

self._typed = False
self._atom_types = IndexedSet()
self._connection_types = IndexedSet()
self._bond_types = IndexedSet()
self._angle_types = IndexedSet()

self._dihedral_types = IndexedSet()
self._combining_rule = 'lorentz'

@property
Expand Down Expand Up @@ -118,6 +120,9 @@ def add_connection(self, connection, update_types=True):
self.update_bonds()
elif isinstance(connection, Angle):
self.update_angles()
elif isinstance(connection, Dihedral):
self.update_dihedrals()

self.update_connections()

if update_types:
Expand All @@ -127,6 +132,8 @@ def add_connection(self, connection, update_types=True):
self.update_bond_types()
elif isinstance(connection, Angle):
self.update_angle_types()
elif isinstance(connection, Dihedral):
self.update_dihedral_types()
self.update_connection_types()

def add_subtopology(self, subtop):
Expand All @@ -151,6 +158,10 @@ def n_bonds(self):
def n_angles(self):
return len(self.angles)

@property
def n_dihedrals(self):
return len(self.dihedrals)

@property
def subtops(self):
return self._subtops
Expand All @@ -175,6 +186,10 @@ def bonds(self):
def angles(self):
return self._angles

@property
def dihedrals(self):
return self._dihedrals

@property
def atom_types(self):
return self._atom_types
Expand All @@ -191,6 +206,10 @@ def bond_types(self):
def angle_types(self):
return self._angle_types

@property
def dihedral_types(self):
return self._dihedral_types

@property
def atom_type_expressions(self):
return list(set([atype.expression for atype in self.atom_types]))
Expand All @@ -207,24 +226,30 @@ def bond_type_expressions(self):
def angle_type_expressions(self):
return list(set([atype.expression for atype in self.angle_types]))

@property
def dihedral_type_expressions(self):
return list(set([atype.expression for atype in self.dihedral_types]))

def update_top(self, update_types=True):
""" Update the entire topology's attributes
Notes
-----
Will update: sites, connections, bonds, angles,
atom_types, connectiontypes, bondtypes, angletypes
Will update: sites, connections, bonds, angles, dihedrals
atom_types, connectiontypes, bondtypes, angletypes, dihedral_types
"""
self.update_sites()
self.update_connections()
self.update_bonds()
self.update_angles()
self.update_dihedrals()

if update_types:
self.update_atom_types()
self.update_connection_types()
self.update_bond_types()
self.update_angle_types()
self.update_dihedral_types()
self.is_typed()

def update_sites(self):
Expand All @@ -251,6 +276,10 @@ def update_angles(self):
""" Rebuild the angle list by filtering through connection list """
self._angles = [a for a in self.connections if isinstance(a, Angle)]

def update_dihedrals(self):
""" Rebuild the dihedral list by filtering through connection list """
self._dihedrals = [d for d in self.connections if isinstance(d, Dihedral)]

def update_atom_types(self):
""" Update the atom types based on the site list """
#self._atom_types = []
Expand Down Expand Up @@ -296,6 +325,18 @@ def update_angle_types(self):
elif a.connection_type not in self.angle_types:
self.angle_types.add(a.connection_type)

def update_dihedral_types(self):
""" Update the dihedral types based on the dihedral list """
#self._dihedral_types = []
for d in self.dihedrals:
if d.connection_type is None:
warnings.warn("Non-parametrized Dihedral {} detected".format(d))
elif not isinstance(d.connection_type, DihedralType):
raise TopologyError("Non-DihedralType {} found in Dihedral {}".format(
d.connection_type, d))
elif d.connection_type not in self.dihedral_types:
self.dihedral_types.add(d.connection_type)

def __repr__(self):
descr = list('<')
descr.append(self.name + ' ')
Expand Down
4 changes: 2 additions & 2 deletions topology/tests/test_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_angle_parametrized(self):
assert site3.n_connections == 1
assert len(connect.connection_members) == 3
assert connect.connection_type is not None
assert connect.name == "angle_name"
assert connect.name == 'angle_name'

def test_angle_fake(self):
site1 = Site(name='site1')
Expand All @@ -67,7 +67,7 @@ def test_angle_constituent_types(self):
site3 = Site(name='site3', position=[1,1,0], atom_type=AtomType(name='C'))
angtype = AngleType(member_types=[site1.atom_type.name, site2.atom_type.name,
site3.atom_type.name])
ang = Angle(connection_members=[site1, site2,site3],
ang = Angle(connection_members=[site1, site2, site3],
connection_type=angtype)
assert 'A' in ang.connection_type.member_types
assert 'B' in ang.connection_type.member_types
Expand Down
2 changes: 1 addition & 1 deletion topology/tests/test_bond.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_bond_parametrized(self):
assert site2.n_connections == 1
assert len(connect.connection_members) == 2
assert connect.connection_type is not None
assert connect.name == "bond_name"
assert connect.name == 'bond_name'

def test_bond_fake(self):
site1 = Site(name='site1')
Expand Down
Loading

0 comments on commit 393b08b

Please sign in to comment.