From 4bc53eff15346375f11d776c191bc868dc2f12dc Mon Sep 17 00:00:00 2001 From: Lee-Ping Wang Date: Sun, 31 Mar 2019 19:54:08 -0700 Subject: [PATCH] Further simplify AtomContact --- geometric/molecule.py | 51 +++++++++++++++++-------------------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/geometric/molecule.py b/geometric/molecule.py index b936abb9..f8b0c922 100644 --- a/geometric/molecule.py +++ b/geometric/molecule.py @@ -961,41 +961,33 @@ def AtomContact(xyz, pairs, box=None, displace=False): Parameters ---------- - xyz : list or np.ndarray - If a list, must be a list of N_atoms*3 arrays of atom positions. - If an array, must be either a N_atoms*3 (2D) or N_frames*N_atoms*3 (3D). + xyz : np.ndarray + N_frames*N_atoms*3 (3D) array of atomic positions + If you only have a single set of positions, pass in xyz[np.newaxis, :] pairs : list List of 2-tuples of atom indices box : np.ndarray, optional - An array of three numbers (xyz box vectors). + N_frames*3 (2D) array of periodic box vectors + If you only have a single set of positions, pass in box[np.newaxis, :] + displace : bool + If True, also return N_frames*N_pairs*3 array of displacement vectors Returns ------- np.ndarray - A Npairs-length array of minimum image convention distances + N_pairs*N_frames (2D) array of minimum image convention distances np.ndarray (optional) - if displace=True, return a Npairsx3 array of displacement vectors + if displace=True, N_frames*N_pairs*3 array of displacement vectors """ - if type(xyz) is list: - xyz = np.array(xyz) # Obtain atom selections for atom pairs parray = np.array(pairs) sel1 = parray[:,0] sel2 = parray[:,1] - if len(xyz.shape) not in (2, 3): - raise RuntimeError("Only provide positions in dimension (traj_length, N_atoms, 3) or (N_atoms, 3)") - single = (len(xyz.shape) == 2) - if single: - xyzpbc = xyz[np.newaxis,:,:].copy() - if box is not None: - box = box[np.newaxis,:].copy() - else: - xyzpbc = xyz.copy() + xyzpbc = xyz.copy() # Minimum image convention: Place all atoms in the box # [0, xbox); [0, ybox); [0, zbox) if box is not None: - box = box[:,np.newaxis,:] - xyzpbc /= box + xyzpbc /= box[:,np.newaxis,:] xyzpbc = xyzpbc % 1.0 # Obtain atom selections for the pairs to be computed # These are typically longer than N but shorter than N^2. @@ -1006,12 +998,9 @@ def AtomContact(xyz, pairs, box=None, displace=False): # Apply minimum image convention to displacements if box is not None: dxyz = np.mod(dxyz+0.5, 1.0) - 0.5 - dxyz *= box + dxyz *= box[:,np.newaxis,:] dr2 = np.sum(dxyz**2,axis=2) dr = np.sqrt(dr2) - if single: - dr = dr[0] - dxyz = dxyz[0] if displace: return dr, dxyz else: @@ -2072,9 +2061,9 @@ def build_bonds(self): BondThresh = (BT0+BT1) * Fac BondThresh = (BondThresh > mindist) * BondThresh + (BondThresh < mindist) * mindist if hasattr(self, 'boxes') and toppbc: - dxij = AtomContact(self.xyzs[sn], AtomIterator, box=np.array([self.boxes[sn].a, self.boxes[sn].b, self.boxes[sn].c])) + dxij = AtomContact(self.xyzs[sn][np.newaxis, :], AtomIterator, box=np.array([[self.boxes[sn].a, self.boxes[sn].b, self.boxes[sn].c]]))[0] else: - dxij = AtomContact(self.xyzs[sn], AtomIterator) + dxij = AtomContact(self.xyzs[sn][np.newaxis, :], AtomIterator)[0] # Update topology settings with what we learned self.top_settings['toppbc'] = toppbc @@ -2160,9 +2149,9 @@ def distance_matrix(self, pbc=True): np.fromiter(itertools.chain(*[range(i+1,self.na) for i in range(self.na)]),dtype=np.int32))).T) if hasattr(self, 'boxes') and pbc: boxes = np.array([[self.boxes[i].a, self.boxes[i].b, self.boxes[i].c] for i in range(len(self))]) - drij = AtomContact(self.xyzs, AtomIterator, box=boxes) + drij = AtomContact(np.array(self.xyzs), AtomIterator, box=boxes) else: - drij = AtomContact(self.xyzs, AtomIterator) + drij = AtomContact(np.array(self.xyzs), AtomIterator) return AtomIterator, list(drij) def distance_displacement(self): @@ -2171,9 +2160,9 @@ def distance_displacement(self): np.fromiter(itertools.chain(*[range(i+1,self.na) for i in range(self.na)]),dtype=np.int32))).T) if hasattr(self, 'boxes') and pbc: boxes = np.array([[self.boxes[i].a, self.boxes[i].b, self.boxes[i].c] for i in range(len(self))]) - drij, dxij = AtomContact(self.xyzs, AtomIterator, box=boxes, displace=True) + drij, dxij = AtomContact(np.array(self.xyzs), AtomIterator, box=boxes, displace=True) else: - drij, dxij = AtomContact(self.xyzs, AtomIterator, box=None, displace=True) + drij, dxij = AtomContact(np.array(self.xyzs), AtomIterator, box=None, displace=True) return AtomIterator, list(drij), list(dxij) def rotate_bond(self, frame, aj, ak, increment=15): @@ -2306,9 +2295,9 @@ def find_clashes(self, thre=0.0, pbc=True, groups=None): clashDists_frames = [] if hasattr(self, 'boxes') and pbc: boxes = np.array([[self.boxes[i].a, self.boxes[i].b, self.boxes[i].c] for i in range(len(self))]) - drij = AtomContact(self.xyzs, AtomIterator_nb, box=boxes) + drij = AtomContact(np.array(self.xyzs), AtomIterator_nb, box=boxes) else: - drij = AtomContact(self.xyzs, AtomIterator_nb) + drij = AtomContact(np.array(self.xyzs), AtomIterator_nb) for frame in range(len(self)): clashPairIdx = np.where(drij[frame] < thre)[0] clashPairs = AtomIterator_nb[clashPairIdx]