Skip to content

Commit

Permalink
Further simplify AtomContact
Browse files Browse the repository at this point in the history
  • Loading branch information
leeping committed Apr 1, 2019
1 parent f34c93f commit 4bc53ef
Showing 1 changed file with 20 additions and 31 deletions.
51 changes: 20 additions & 31 deletions geometric/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,41 +961,33 @@ def AtomContact(xyz, pairs, box=None, displace=False):
Parameters
----------
xyz : list or np.ndarray
If a list, must be a list of N_atoms*3 arrays of atom positions.
If an array, must be either a N_atoms*3 (2D) or N_frames*N_atoms*3 (3D).
xyz : np.ndarray
N_frames*N_atoms*3 (3D) array of atomic positions
If you only have a single set of positions, pass in xyz[np.newaxis, :]
pairs : list
List of 2-tuples of atom indices
box : np.ndarray, optional
An array of three numbers (xyz box vectors).
N_frames*3 (2D) array of periodic box vectors
If you only have a single set of positions, pass in box[np.newaxis, :]
displace : bool
If True, also return N_frames*N_pairs*3 array of displacement vectors
Returns
-------
np.ndarray
A Npairs-length array of minimum image convention distances
N_pairs*N_frames (2D) array of minimum image convention distances
np.ndarray (optional)
if displace=True, return a Npairsx3 array of displacement vectors
if displace=True, N_frames*N_pairs*3 array of displacement vectors
"""
if type(xyz) is list:
xyz = np.array(xyz)
# Obtain atom selections for atom pairs
parray = np.array(pairs)
sel1 = parray[:,0]
sel2 = parray[:,1]
if len(xyz.shape) not in (2, 3):
raise RuntimeError("Only provide positions in dimension (traj_length, N_atoms, 3) or (N_atoms, 3)")
single = (len(xyz.shape) == 2)
if single:
xyzpbc = xyz[np.newaxis,:,:].copy()
if box is not None:
box = box[np.newaxis,:].copy()
else:
xyzpbc = xyz.copy()
xyzpbc = xyz.copy()
# Minimum image convention: Place all atoms in the box
# [0, xbox); [0, ybox); [0, zbox)
if box is not None:
box = box[:,np.newaxis,:]
xyzpbc /= box
xyzpbc /= box[:,np.newaxis,:]
xyzpbc = xyzpbc % 1.0
# Obtain atom selections for the pairs to be computed
# These are typically longer than N but shorter than N^2.
Expand All @@ -1006,12 +998,9 @@ def AtomContact(xyz, pairs, box=None, displace=False):
# Apply minimum image convention to displacements
if box is not None:
dxyz = np.mod(dxyz+0.5, 1.0) - 0.5
dxyz *= box
dxyz *= box[:,np.newaxis,:]
dr2 = np.sum(dxyz**2,axis=2)
dr = np.sqrt(dr2)
if single:
dr = dr[0]
dxyz = dxyz[0]
if displace:
return dr, dxyz
else:
Expand Down Expand Up @@ -2072,9 +2061,9 @@ def build_bonds(self):
BondThresh = (BT0+BT1) * Fac
BondThresh = (BondThresh > mindist) * BondThresh + (BondThresh < mindist) * mindist
if hasattr(self, 'boxes') and toppbc:
dxij = AtomContact(self.xyzs[sn], AtomIterator, box=np.array([self.boxes[sn].a, self.boxes[sn].b, self.boxes[sn].c]))
dxij = AtomContact(self.xyzs[sn][np.newaxis, :], AtomIterator, box=np.array([[self.boxes[sn].a, self.boxes[sn].b, self.boxes[sn].c]]))[0]
else:
dxij = AtomContact(self.xyzs[sn], AtomIterator)
dxij = AtomContact(self.xyzs[sn][np.newaxis, :], AtomIterator)[0]

# Update topology settings with what we learned
self.top_settings['toppbc'] = toppbc
Expand Down Expand Up @@ -2160,9 +2149,9 @@ def distance_matrix(self, pbc=True):
np.fromiter(itertools.chain(*[range(i+1,self.na) for i in range(self.na)]),dtype=np.int32))).T)
if hasattr(self, 'boxes') and pbc:
boxes = np.array([[self.boxes[i].a, self.boxes[i].b, self.boxes[i].c] for i in range(len(self))])
drij = AtomContact(self.xyzs, AtomIterator, box=boxes)
drij = AtomContact(np.array(self.xyzs), AtomIterator, box=boxes)
else:
drij = AtomContact(self.xyzs, AtomIterator)
drij = AtomContact(np.array(self.xyzs), AtomIterator)
return AtomIterator, list(drij)

def distance_displacement(self):
Expand All @@ -2171,9 +2160,9 @@ def distance_displacement(self):
np.fromiter(itertools.chain(*[range(i+1,self.na) for i in range(self.na)]),dtype=np.int32))).T)
if hasattr(self, 'boxes') and pbc:
boxes = np.array([[self.boxes[i].a, self.boxes[i].b, self.boxes[i].c] for i in range(len(self))])
drij, dxij = AtomContact(self.xyzs, AtomIterator, box=boxes, displace=True)
drij, dxij = AtomContact(np.array(self.xyzs), AtomIterator, box=boxes, displace=True)
else:
drij, dxij = AtomContact(self.xyzs, AtomIterator, box=None, displace=True)
drij, dxij = AtomContact(np.array(self.xyzs), AtomIterator, box=None, displace=True)
return AtomIterator, list(drij), list(dxij)

def rotate_bond(self, frame, aj, ak, increment=15):
Expand Down Expand Up @@ -2306,9 +2295,9 @@ def find_clashes(self, thre=0.0, pbc=True, groups=None):
clashDists_frames = []
if hasattr(self, 'boxes') and pbc:
boxes = np.array([[self.boxes[i].a, self.boxes[i].b, self.boxes[i].c] for i in range(len(self))])
drij = AtomContact(self.xyzs, AtomIterator_nb, box=boxes)
drij = AtomContact(np.array(self.xyzs), AtomIterator_nb, box=boxes)
else:
drij = AtomContact(self.xyzs, AtomIterator_nb)
drij = AtomContact(np.array(self.xyzs), AtomIterator_nb)
for frame in range(len(self)):
clashPairIdx = np.where(drij[frame] < thre)[0]
clashPairs = AtomIterator_nb[clashPairIdx]
Expand Down

0 comments on commit 4bc53ef

Please sign in to comment.