-
Notifications
You must be signed in to change notification settings - Fork 5
/
pairing.py
164 lines (133 loc) · 5.45 KB
/
pairing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import numpy as np
import mdtraj as md
import itertools
from scipy.optimize import curve_fit
def calc_pairing(trj, cutoff, names, chunk_size=100,
                 check_reform=False, normalize=False):
    """Calculate the number of molecular pairs over a trajectory.

    The trajectory is processed in windows of ``chunk_size`` frames. On the
    first frame of every window the pair list is rebuilt with
    ``build_initial_state``; on every frame each tracked pair is re-tested
    with ``get_paired_state`` and the number of intact pairs is recorded.
    The per-window counts are finally summed element-wise over all complete
    windows (a trailing, incomplete window is discarded).

    Parameters
    ----------
    trj : sequence of single-frame trajectories
        Iterated frame by frame; each frame must expose ``time`` and be
        accepted by ``build_initial_state`` / ``get_paired_state``
        (presumably an ``mdtraj.Trajectory`` -- confirm against callers).
    cutoff : float
        Distance below which two atoms count as paired. It is forwarded to
        ``build_initial_state`` and used for every per-frame re-check.
    names : container of str
        Atom names that take part in pairing.
    chunk_size : int, optional
        Number of frames per window.
    check_reform : bool, optional
        If True, pairs that have broken are re-tested and may count again;
        if False, a broken pair stays broken until the next window.
    normalize : bool, optional
        If True, divide the summed counts by their value at the first row.
        A zero first-row count is not special-cased, so the division
        follows NumPy's divide-by-zero semantics.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(chunk_size, 2)``: column 0 is the time relative to
        the start of the window, column 1 the (optionally normalized) pair
        count summed over windows.
    """
    c = np.zeros(shape=(len(trj), 2))
    for i, frame in enumerate(trj):
        # Start of a new window: rebuild the list of tracked pairs
        if i % chunk_size == 0:
            pairs = build_initial_state(frame, frame_index=0,
                                        names=names, cutoff=cutoff)
        # If no pairs were found, set number of pairs to 0
        if len(pairs) == 0:
            c[i] = [frame.time[0], 0]
            continue
        for pair in pairs:
            # Check unless both pair[2] and check_reform are False
            if pair[2] or check_reform:
                # Use the caller's cutoff rather than a hard-coded distance
                pair[2] = get_paired_state(frame, pair[0], pair[1],
                                           frame_index=0, cutoff=cutoff,
                                           use_mdtraj=True)
        n_paired = sum(1 for pair in pairs if pair[2])
        c[i] = [frame.time[0], n_paired]

    hbar = np.zeros(shape=(chunk_size, 2))
    for chunk in chunks(c, chunk_size):
        if len(chunk) < chunk_size:
            # Neglect data in remainder of modulus
            continue
        hbar[:, 0] = chunk[:, 0] - chunk[0, 0]
        hbar[:, 1] += chunk[:, 1]
    if normalize:
        hbar[:, 1] /= hbar[0, 1]
    return hbar
def calc_caging(trj, cutoff, names, chunk_size=100, normalize=False):
    """Calculate the number of molecular cages over a trajectory.

    The trajectory is processed in windows of ``chunk_size`` frames. On the
    first frame of every window the cage list is rebuilt with
    ``build_initial_cages``; on every frame each cage is re-tested with
    ``check_cage`` and the number of cages still flagged as intact is
    recorded. The per-window counts are summed element-wise over all
    complete windows (a trailing, incomplete window is discarded).

    Parameters
    ----------
    trj : sequence of single-frame trajectories
        Iterated frame by frame; each frame must expose ``time`` and be
        accepted by ``build_initial_cages`` / ``check_cage``
        (presumably an ``mdtraj.Trajectory`` -- confirm against callers).
    cutoff : float
        Distance cutoff forwarded to ``build_initial_cages``. The per-frame
        re-check is delegated to ``check_cage``, which is called here
        without a cutoff argument.
    names : container of str
        Atom names that take part in caging.
    chunk_size : int, optional
        Number of frames per window.
    normalize : bool, optional
        If True, divide the summed counts by their value at the first row.
        A zero first-row count is not special-cased, so the division
        follows NumPy's divide-by-zero semantics.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(chunk_size, 2)``: column 0 is the time relative to
        the start of the window, column 1 the (optionally normalized) cage
        count summed over windows.
    """
    c = np.zeros(shape=(len(trj), 2))
    for i, frame in enumerate(trj):
        # Start of a new window: rebuild the list of tracked cages
        if i % chunk_size == 0:
            cages = build_initial_cages(frame, frame_index=0,
                                        names=names, cutoff=cutoff)
        counter = 0
        for cage in cages:
            # The last element of a cage is its "still intact" flag
            cage = check_cage(frame, cage, names)
            if cage[-1]:
                counter += 1
        c[i] = [frame.time[0], counter]

    hbar = np.zeros(shape=(chunk_size, 2))
    for chunk in chunks(c, chunk_size):
        if len(chunk) < chunk_size:
            # Neglect data in remainder of modulus
            continue
        hbar[:, 0] = chunk[:, 0] - chunk[0, 0]
        hbar[:, 1] += chunk[:, 1]
    if normalize:
        hbar[:, 1] /= hbar[0, 1]
    return hbar
def get_paired_state(trj, id_i, id_j, frame_index=0, cutoff=1, use_mdtraj=True):
    """Check to see if a given pair is still paired.

    Parameters
    ----------
    trj : trajectory-like
        Passed to ``md.compute_distances`` when ``use_mdtraj`` is True;
        otherwise only ``trj.xyz[frame, atom]`` is read.
    id_i, id_j : int
        Indices of the two atoms.
    frame_index : int, optional
        Frame of ``trj`` in which the distance is evaluated.
    cutoff : float, optional
        The pair counts as paired when its distance is strictly below this
        value (same length units as the trajectory coordinates).
    use_mdtraj : bool, optional
        If True, use ``md.compute_distances`` with ``periodic=True``;
        if False, use the plain Euclidean distance of the stored
        coordinates, with no periodic-image handling.

    Returns
    -------
    bool
        True if the two atoms are closer than ``cutoff``.
    """
    if use_mdtraj:
        # compute_distances returns an (n_frames, n_pairs) array; select the
        # requested frame and the single pair so that multi-frame input
        # does not produce an ambiguous array truth value.
        dist = md.compute_distances(traj=trj,
                                    atom_pairs=[(id_i, id_j)],
                                    periodic=True,
                                    opt=True)[frame_index, 0]
    else:
        delta = trj.xyz[frame_index, id_i] - trj.xyz[frame_index, id_j]
        dist = np.sqrt(np.sum(delta ** 2))
    return bool(dist < cutoff)
def build_initial_state(trj, names, frame_index=0, cutoff=1):
    """Build initial pair list. See 10.1021/acs.jpclett.5b00003 for a
    definition. The re-forming of pairs is supported with a flag.

    Parameters
    ----------
    trj : trajectory-like
        Must expose ``topology.atoms`` (each atom with ``index`` and
        ``name``) and be accepted by ``get_paired_state``.
    names : container of str
        Atom names that take part in pairing.
    frame_index : int, optional
        Frame forwarded to ``get_paired_state``.
    cutoff : float, optional
        Distance cutoff forwarded to ``get_paired_state``.

    Returns
    -------
    list of list
        One ``[id_i, id_j, True]`` entry per unordered pair of selected
        atoms that ``get_paired_state`` reports as paired. The trailing
        boolean is the mutable "currently paired" flag that callers update.
    """
    atom_ids = [a.index for a in trj.topology.atoms if a.name in names]
    pairs = []
    # combinations() never yields an atom paired with itself
    for id_i, id_j in itertools.combinations(atom_ids, r=2):
        if get_paired_state(trj, id_i, id_j, frame_index=frame_index,
                            cutoff=cutoff, use_mdtraj=True):
            pairs.append([id_i, id_j, True])
    return pairs
def build_initial_cages(trj, names, frame_index=0, cutoff=1):
    """Build initial cage list. See 10.1021/acs.jpclett.5b00003 for a
    definition. The re-forming of cages is not permitted.

    Parameters
    ----------
    trj : trajectory-like
        Must expose ``topology.atoms`` (each atom with ``index`` and
        ``name``) and be accepted by ``get_paired_state``.
    names : container of str
        Atom names that take part in caging.
    frame_index : int, optional
        Frame forwarded to ``get_paired_state``.
    cutoff : float, optional
        Distance cutoff forwarded to ``get_paired_state``.

    Returns
    -------
    list of list
        One ``[center, member_1, ..., member_n, True]`` entry per selected
        atom that has at least one other selected atom within ``cutoff``.
        The trailing boolean is the mutable "still intact" flag.
    """
    atom_ids = [a.index for a in trj.topology.atoms if a.name in names]
    cages = []
    for id_i in atom_ids:
        # Skip the central atom itself: it is always at distance zero, which
        # would make every atom a non-empty cage and defeat the filter below.
        neighbors = [
            id_j for id_j in atom_ids
            if id_j != id_i
            and get_paired_state(trj, id_i, id_j, frame_index=frame_index,
                                 cutoff=cutoff, use_mdtraj=True)
        ]
        # Only atoms with at least one neighbor form a cage
        if neighbors:
            cages.append([id_i, *neighbors, True])
    return cages
def check_cage(trj, cage, names, cutoff=0.8):
    """Check if a given cage still meets its defined criteria.

    A cage is a list ``[center, member_1, ..., member_n, flag]``. It is
    marked as broken (``flag`` set to False, in place) as soon as any member
    is no longer paired with the center, or any selected atom that is not
    already part of the cage is paired with the center. An intact cage is
    returned unchanged; a flag that is already False is never reset.

    Parameters
    ----------
    trj : trajectory-like
        Must expose ``topology.atoms`` (each atom with ``index`` and
        ``name``) and be accepted by ``get_paired_state``.
    cage : list
        Cage to test; modified in place.
    names : container of str
        Atom names that take part in caging.
    cutoff : float, optional
        Distance cutoff forwarded to ``get_paired_state``.

    Returns
    -------
    list
        The same ``cage`` object that was passed in.
    """
    atom_ids = [a.index for a in trj.topology.atoms if a.name in names]
    center = cage[0]

    # Check to see if any ions left the cage. Only the trailing flag is
    # excluded, so the last real member is verified as well.
    for id_j in cage[1:-1]:
        still_in_shell = get_paired_state(trj, center, id_j, frame_index=0,
                                          cutoff=cutoff, use_mdtraj=True)
        if not still_in_shell:
            cage[-1] = False
            return cage

    # See if any new ions entered the shell. Everything except the flag
    # (center and all members) is already known to the cage.
    known = set(cage[:-1])
    for id_k in atom_ids:
        if id_k in known:
            continue
        entered = get_paired_state(trj, center, id_k, frame_index=0,
                                   cutoff=cutoff, use_mdtraj=True)
        if entered:
            cage[-1] = False
            return cage
    return cage
def chunks(l, n):
    """Lazily split ``l`` into consecutive slices of length ``n``.

    The final slice is shorter when ``len(l)`` is not a multiple of ``n``.
    """
    starts = range(0, len(l), n)
    yield from (l[start:start + n] for start in starts)
def stretched_exp(x, a, b):
    """Evaluate the stretched exponential ``exp(-b * x**a)``.

    ``a`` is the stretching exponent and ``b`` the rate prefactor.
    """
    powered = x ** a
    return np.exp(-1 * b * powered)
def fit(func, t, n_pairs):
    """Fit pairing data to a model function.

    Parameters
    ----------
    func : callable or None
        Model ``func(x, *params)`` handed to ``curve_fit``. If it is not
        callable (e.g. ``None``), ``stretched_exp`` is used instead.
    t : array-like
        Independent variable (times).
    n_pairs : array-like
        Dependent variable (pair counts) measured at ``t``.

    Returns
    -------
    numpy.ndarray
        Optimal parameter values found by ``curve_fit``.
    """
    # Honour the caller's model; fall back to the stretched exponential
    model = func if callable(func) else stretched_exp
    popt, _pcov = curve_fit(model, t, n_pairs)
    return popt