BUG: Fix race condition with new FFT cache #7712

Merged (2 commits, Jun 9, 2016)

15 changes: 8 additions & 7 deletions numpy/fft/fftpack.py
@@ -55,12 +55,13 @@ def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti,
         raise ValueError("Invalid number of FFT data points (%d) specified."
                          % n)
 
-    try:
-        # Thread-safety note: We rely on list.pop() here to atomically
-        # retrieve-and-remove a wsave from the cache. This ensures that no
-        # other thread can get the same wsave while we're using it.
-        wsave = fft_cache.setdefault(n, []).pop()
-    except (IndexError):
+    # We have to ensure that only a single thread can access a wsave array
+    # at any given time. Thus we remove it from the cache and insert it
+    # again after it has been used. Multiple threads might create multiple
+    # copies of the wsave array. This is intentional and a limitation of
+    # the current C code.
+    wsave = fft_cache.pop_twiddle_factors(n)
+    if wsave is None:
         wsave = init_function(n)
 
     if a.shape[axis] != n:
@@ -86,7 +87,7 @@ def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti,
     # As soon as we put wsave back into the cache, another thread could pick it
     # up and start using it, so we must not do this until after we're
     # completely done using it ourselves.
-    fft_cache[n].append(wsave)
+    fft_cache.put_twiddle_factors(n, wsave)
 
     return r

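The comments above describe a check-out/check-in protocol: a thread takes
exclusive ownership of a wsave array by popping it from the cache and only
puts it back once it is completely done with it. As a rough illustration, not
part of this PR, here is a minimal sketch of several threads driving that
interface; the FFT length 64, the cache limits, and the zeros() stand-in for
init_function(n) are all arbitrary choices for the example:

    import threading

    import numpy as np
    from numpy.fft.helper import _FFTCache

    cache = _FFTCache(max_size_in_mb=4, max_item_count=32)

    def worker():
        for _ in range(1000):
            # Check out a wsave array; no other thread can hold this copy.
            wsave = cache.pop_twiddle_factors(64)
            if wsave is None:
                # Cache miss: build a fresh array (stand-in for the real
                # init_function(n) twiddle-factor setup).
                wsave = np.zeros(64)
            # ... use wsave exclusively here ...
            # Check it back in only after we are completely done with it.
            cache.put_twiddle_factors(64, wsave)

    threads = [threading.Thread(target=worker) for _ in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
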
93 changes: 65 additions & 28 deletions numpy/fft/helper.py
@@ -4,7 +4,8 @@
 """
 from __future__ import division, absolute_import, print_function
 
-from collections import OrderedDict
+import collections
+import threading
 
 from numpy.compat import integer_types
 from numpy.core import (
@@ -228,7 +229,7 @@ def rfftfreq(n, d=1.0):
 
 class _FFTCache(object):
     """
-    Cache for the FFT init functions as an LRU (least recently used) cache.
+    Cache for the FFT twiddle factors as an LRU (least recently used) cache.
 
     Parameters
     ----------
@@ -250,37 +251,73 @@ class _FFTCache(object):
     def __init__(self, max_size_in_mb, max_item_count):
         self._max_size_in_bytes = max_size_in_mb * 1024 ** 2
         self._max_item_count = max_item_count
-        # Much simpler than inheriting from it and having to work around
-        # recursive behaviour.
-        self._dict = OrderedDict()
-
-    def setdefault(self, key, value):
-        return self._dict.setdefault(key, value)
-
-    def __getitem__(self, key):
-        # pop + add to move it to the end.
-        value = self._dict.pop(key)
-        self._dict[key] = value
-        self._prune_dict()
-        return value
-
-    def __setitem__(self, key, value):
-        # Just setting is it not enough to move it to the end if it already
-        # exists.
-        try:
-            del self._dict[key]
-        except:
-            pass
-        self._dict[key] = value
-        self._prune_dict()
-
-    def _prune_dict(self):
+        self._dict = collections.OrderedDict()
+        self._lock = threading.Lock()
+
+    def put_twiddle_factors(self, n, factors):
+        """
+        Store twiddle factors for an FFT of length n in the cache.
+
+        Putting multiple twiddle factors for a certain n will store them
+        multiple times.
+
+        Parameters
+        ----------
+        n : int
+            Data length for the FFT.
+        factors : ndarray
+            The actual twiddle values.
+        """
+        with self._lock:
+            # Pop + later add to move it to the end for LRU behavior.
+            # Internally everything is stored in a dictionary whose values are
+            # lists.
+            try:
+                value = self._dict.pop(n)
+            except KeyError:
+                value = []
+            value.append(factors)
+            self._dict[n] = value
+            self._prune_cache()
+
+    def pop_twiddle_factors(self, n):
+        """
+        Pop twiddle factors for an FFT of length n from the cache.
+
+        Will return None if the requested twiddle factors are not available in
+        the cache.
+
+        Parameters
+        ----------
+        n : int
+            Data length for the FFT.
+
+        Returns
+        -------
+        out : ndarray or None
+            The retrieved twiddle factors if available, else None.
+        """
+        with self._lock:
+            if n not in self._dict or not self._dict[n]:

[Inline review comment from a numpy member on the line above]

Maybe

    try:
        all_values = self._dict.pop(n)
        value = all_values.pop()
    except (KeyError, IndexError):
        return None
    ...

[Reply from the contributor (author)]

I prefer the current solution as it's IMHO a bit clearer. Performance is also
not a concern at all, as the maximum size of the dictionary is bounded.

+                return None
+            # Pop + later add to move it to the end for LRU behavior.
+            all_values = self._dict.pop(n)
+            value = all_values.pop()
+            # Only put back if there are still some arrays left in the list.
+            if all_values:
+                self._dict[n] = all_values
+        return value
+
+    def _prune_cache(self):
         # Always keep at least one item.
         while len(self._dict) > 1 and (
                 len(self._dict) > self._max_item_count or self._check_size()):
             self._dict.popitem(last=False)
 
     def _check_size(self):
-        item_sizes = [_i[0].nbytes for _i in self._dict.values() if _i]
+        item_sizes = [sum(_j.nbytes for _j in _i)
+                      for _i in self._dict.values() if _i]
         if not item_sizes:
             return False
         max_size = max(self._max_size_in_bytes, 1.5 * max(item_sizes))
         return sum(item_sizes) > max_size
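
Concretely, each key n maps to a list of arrays, so putting twice under the
same n stores two independent copies, and pruning is LRU with at least one
entry always retained. A small usage sketch, not from the PR, assuming
_FFTCache behaves as in the diff above:

    import numpy as np
    from numpy.fft.helper import _FFTCache

    c = _FFTCache(max_size_in_mb=1, max_item_count=2)

    # Two puts under the same n keep two independent copies, so the entry
    # survives one pop and is exhausted only after the second.
    c.put_twiddle_factors(8, np.ones(8))
    c.put_twiddle_factors(8, np.ones(8))
    assert c.pop_twiddle_factors(8) is not None
    assert c.pop_twiddle_factors(8) is not None
    assert c.pop_twiddle_factors(8) is None

    # LRU pruning: with max_item_count=2, adding a third key evicts the
    # least recently used one (here 1, the oldest).
    c.put_twiddle_factors(1, np.ones(4))
    c.put_twiddle_factors(2, np.ones(4))
    c.put_twiddle_factors(3, np.ones(4))
    assert list(c._dict.keys()) == [2, 3]
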
110 changes: 51 additions & 59 deletions numpy/fft/tests/test_helper.py
@@ -79,86 +79,78 @@ class TestFFTCache(TestCase):
 
     def test_basic_behaviour(self):
         c = _FFTCache(max_size_in_mb=1, max_item_count=4)
-        # Setting
-        c[1] = [np.ones(2, dtype=np.float32)]
-        c[2] = [np.zeros(2, dtype=np.float32)]
-        # Getting
-        assert_array_almost_equal(c[1][0], np.ones(2, dtype=np.float32))
-        assert_array_almost_equal(c[2][0], np.zeros(2, dtype=np.float32))
-        # Setdefault
-        c.setdefault(1, [np.array([1, 2], dtype=np.float32)])
-        assert_array_almost_equal(c[1][0], np.ones(2, dtype=np.float32))
-        c.setdefault(3, [np.array([1, 2], dtype=np.float32)])
-        assert_array_almost_equal(c[3][0], np.array([1, 2], dtype=np.float32))
-
-        self.assertEqual(len(c._dict), 3)
+
+        # Put
+        c.put_twiddle_factors(1, np.ones(2, dtype=np.float32))
+        c.put_twiddle_factors(2, np.zeros(2, dtype=np.float32))
+
+        # Get
+        assert_array_almost_equal(c.pop_twiddle_factors(1),
+                                  np.ones(2, dtype=np.float32))
+        assert_array_almost_equal(c.pop_twiddle_factors(2),
+                                  np.zeros(2, dtype=np.float32))
+
+        # Nothing should be left.
+        self.assertEqual(len(c._dict), 0)
+
+        # Now put everything in twice so it can be retrieved once and each will
+        # still have one item left.
+        for _ in range(2):
+            c.put_twiddle_factors(1, np.ones(2, dtype=np.float32))
+            c.put_twiddle_factors(2, np.zeros(2, dtype=np.float32))
+        assert_array_almost_equal(c.pop_twiddle_factors(1),
+                                  np.ones(2, dtype=np.float32))
+        assert_array_almost_equal(c.pop_twiddle_factors(2),
+                                  np.zeros(2, dtype=np.float32))
+        self.assertEqual(len(c._dict), 2)
 
     def test_automatic_pruning(self):
-        # Thats around 2600 single precision samples.
+        # That's around 2600 single precision samples.
         c = _FFTCache(max_size_in_mb=0.01, max_item_count=4)
-        c[1] = [np.ones(200, dtype=np.float32)]
-        c[2] = [np.ones(200, dtype=np.float32)]
-
-        # Don't raise errors.
-        c[1], c[2], c[1], c[2]
+        c.put_twiddle_factors(1, np.ones(200, dtype=np.float32))
+        c.put_twiddle_factors(2, np.ones(200, dtype=np.float32))
+        self.assertEqual(list(c._dict.keys()), [1, 2])
 
         # This is larger than the limit but should still be kept.
-        c[3] = [np.ones(3000, dtype=np.float32)]
-        # Should exist.
-        c[1], c[2], c[3]
+        c.put_twiddle_factors(3, np.ones(3000, dtype=np.float32))
+        self.assertEqual(list(c._dict.keys()), [1, 2, 3])
         # Add one more.
-        c[4] = [np.ones(3000, dtype=np.float32)]
-
+        c.put_twiddle_factors(4, np.ones(3000, dtype=np.float32))
         # The other three should no longer exist.
-        with self.assertRaises(KeyError):
-            c[1]
-        with self.assertRaises(KeyError):
-            c[2]
-        with self.assertRaises(KeyError):
-            c[3]
+        self.assertEqual(list(c._dict.keys()), [4])
 
         # Now test the max item count pruning.
         c = _FFTCache(max_size_in_mb=0.01, max_item_count=2)
-        c[1] = [np.empty(2)]
-        c[2] = [np.empty(2)]
-        # Can still be accessed.
-        c[2], c[1]
-
-        c[3] = [np.empty(2)]
-        # 1 and 3 can still be accessed - c[2] has been touched least recently
-        # and is thus evicted.
-        c[1], c[3]
-
-        with self.assertRaises(KeyError):
-            c[2]
-
-        c[1], c[3]
+        c.put_twiddle_factors(2, np.empty(2))
+        c.put_twiddle_factors(1, np.empty(2))
+        self.assertEqual(list(c._dict.keys()), [2, 1])
+
+        c.put_twiddle_factors(3, np.empty(2))
+        self.assertEqual(list(c._dict.keys()), [1, 3])
 
         # One last test. We will add a single large item that is slightly
         # bigger then the cache size. Some small items can still be added.
         c = _FFTCache(max_size_in_mb=0.01, max_item_count=5)
-        c[1] = [np.ones(3000, dtype=np.float32)]
-        c[1]
-        c[2] = [np.ones(2, dtype=np.float32)]
-        c[3] = [np.ones(2, dtype=np.float32)]
-        c[4] = [np.ones(2, dtype=np.float32)]
-        c[1], c[2], c[3], c[4]
-
-        # One more big item.
-        c[5] = [np.ones(3000, dtype=np.float32)]
-
-        # c[1] no longer in the cache. Rest still in the cache.
-        c[2], c[3], c[4], c[5]
-        with self.assertRaises(KeyError):
-            c[1]
+        c.put_twiddle_factors(1, np.ones(3000, dtype=np.float32))
+        c.put_twiddle_factors(2, np.ones(2, dtype=np.float32))
+        c.put_twiddle_factors(3, np.ones(2, dtype=np.float32))
+        c.put_twiddle_factors(4, np.ones(2, dtype=np.float32))
+        self.assertEqual(list(c._dict.keys()), [1, 2, 3, 4])
 
+        # One more big item. This time it is 6 smaller ones but they are
+        # counted as one big item.
+        for _ in range(6):
+            c.put_twiddle_factors(5, np.ones(500, dtype=np.float32))
+        # '1' no longer in the cache. Rest still in the cache.
+        self.assertEqual(list(c._dict.keys()), [2, 3, 4, 5])
 
         # Another big item - should now be the only item in the cache.
-        c[6] = [np.ones(4000, dtype=np.float32)]
-        for _i in range(1, 6):
-            with self.assertRaises(KeyError):
-                c[_i]
-        c[6]
+        c.put_twiddle_factors(6, np.ones(4000, dtype=np.float32))
+        self.assertEqual(list(c._dict.keys()), [6])


if __name__ == "__main__":