Skip to content

Commit

Permalink
Merge pull request #169 from smcantab/master
Browse files Browse the repository at this point in the history
make fft calls memory safe
  • Loading branch information
kyleabeauchamp committed Jan 30, 2015
2 parents 9dffe6f + 724c19c commit b922878
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions pymbar/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def detectEquilibration(A_t, fast=True, nskip=1):
return (t, g, Neff_max)


def statisticalInefficiency_fft(A_n, mintime=3):
def statisticalInefficiency_fft(A_n, mintime=3, memsafe=True):
"""Compute the (cross) statistical inefficiency of (two) timeseries.
Parameters
Expand All @@ -787,7 +787,10 @@ def statisticalInefficiency_fft(A_n, mintime=3):
The algorithm terminates after computing the correlation time out to mintime when the
correlation function furst goes negative. Note that this time may need to be increased
if there is a strong initial negative peak in the correlation function.
memsafe: bool, optional, default=False
If this function is used several times on arrays of comparable size then one might benefit
from setting this option to False. If set to True then clear np.fft cache to avoid a fast
increase in memory consumption when this function is called on many arrays of different sizes.
Returns
-------
g : np.ndarray,
Expand Down Expand Up @@ -822,6 +825,12 @@ def statisticalInefficiency_fft(A_n, mintime=3):
t_grid = np.arange(N).astype('float')
g_t = 2.0 * C_t * (1.0 - t_grid / float(N))

#make function memory safe by clearing np.fft cache
#this assumes that statsmodels uses np.fft
if memsafe:
np.fft.fftpack._fft_cache.clear()
np.fft.fftpack._real_fft_cache.clear()

try:
ind = np.where((C_t <= 0) & (t_grid > mintime))[0][0]
except IndexError:
Expand Down Expand Up @@ -879,7 +888,7 @@ def detectEquilibration_binary_search(A_t, bs_nodes=10):

for k, t in enumerate(time_grid):
if t < T-1:
g_t[k] = statisticalInefficiency_fft(A_t[t:])
g_t[k] = statisticalInefficiency_fft(A_t[t:], memsafe=True)
Neff_t[k] = (T - t + 1) / g_t[k]

Neff_max = Neff_t.max()
Expand Down

0 comments on commit b922878

Please sign in to comment.