Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move cache context management to the builder #1112

Merged
merged 3 commits into from
Jun 27, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@ Release History
- Bugfixes
- Documentation

2.2.0 (unreleased)
2.1.2 (unreleased)
==================

**Bug fixes**

- The DecoderCache is now more robust when used improperly, and no longer
  requires changes to backends in order to be used properly.
(`#1112 <https://github.com/nengo/nengo/pull/1112>`_)

2.1.1 (June 24, 2016)
=====================
Expand Down
56 changes: 33 additions & 23 deletions nengo/builder/network.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import logging

import numpy as np
Expand All @@ -7,6 +8,7 @@
from nengo.network import Network

logger = logging.getLogger(__name__)
@contextlib.contextmanager
def nullcontext():
    """A no-op context manager.

    Used by the network builder when no real context (e.g. no decoder
    cache) needs to be entered.  Spelled as an explicit generator function
    rather than ``contextlib.contextmanager(lambda: (yield))``: PEP 8
    advises against binding a lambda to a name, and ``yield`` inside a
    lambda is an obscure construct that is easy to misread.
    """
    yield


@Builder.register(Network) # noqa: C901
Expand Down Expand Up @@ -63,29 +65,37 @@ def get_seed(obj, rng):
getattr(obj, 'seed', None) is not None)
model.seeds[obj] = get_seed(obj, rng)

logger.debug("Network step 1: Building ensembles and nodes")
for obj in network.ensembles + network.nodes:
model.build(obj)

logger.debug("Network step 2: Building subnetworks")
for subnetwork in network.networks:
model.build(subnetwork)

logger.debug("Network step 3: Building connections")
for conn in network.connections:
# NB: we do these in the order in which they're defined, and build the
# learning rule in the connection builder. Because learning rules are
# attached to connections, the connection that contains the learning
# rule (and the learning rule) are always built *before* a connection
# that attaches to that learning rule. Therefore, we don't have to
# worry about connection ordering here.
# TODO: Except perhaps if the connection being learned
# is in a subnetwork?
model.build(conn)

logger.debug("Network step 4: Building probes")
for probe in network.probes:
model.build(probe)
# If this is the toplevel network, enter the decoder cache
context = (model.decoder_cache if model.toplevel is network
else nullcontext())
with context:

logger.debug("Network step 1: Building ensembles and nodes")
for obj in network.ensembles + network.nodes:
model.build(obj)

logger.debug("Network step 2: Building subnetworks")
for subnetwork in network.networks:
model.build(subnetwork)

logger.debug("Network step 3: Building connections")
for conn in network.connections:
# NB: we do these in the order in which they're defined, and build
# the learning rule in the connection builder. Because learning
# rules are attached to connections, the connection that contains
# the learning rule (and the learning rule) are always built
# *before* a connection that attaches to that learning rule.
# Therefore, we don't have to worry about connection ordering here.
# TODO: Except perhaps if the connection being learned
# is in a subnetwork?
model.build(conn)

logger.debug("Network step 4: Building probes")
for probe in network.probes:
model.build(probe)

if context is model.decoder_cache:
model.decoder_cache.shrink()

# Unset config
model.config = old_config
Expand Down
174 changes: 92 additions & 82 deletions nengo/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,6 @@ def __init__(self, readonly=False, cache_dir=None):
self._index = None
self._fd = None

def _get_fd(self):
if self._fd is None:
self._fd = open(self._key2path(str(uuid1())), 'wb')
return self._fd

def _close_fd(self):
if self._fd is not None:
self._fd.close()
self._fd = None

def __enter__(self):
try:
self._remove_legacy_files()
Expand All @@ -216,7 +206,73 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
    """Leave the cache context: close the write fd and the index.

    ``self._index`` is reset to ``None`` so that later use of the cache
    outside a ``with`` block can be detected.
    """
    self._close_fd()
    if self._index is None:
        return None
    result = self._index.__exit__(exc_type, exc_value, traceback)
    self._index = None
    return result

@staticmethod
def get_default_dir():
    """Returns the default location of the cache.

    The location is read from the ``decoder_cache.path`` entry of the
    nengo RC configuration.

    Returns
    -------
    str
    """
    path = rc.get('decoder_cache', 'path')
    return path

def _close_fd(self):
    """Close the cache's write file descriptor, if one is open."""
    if self._fd is None:
        return
    self._fd.close()
    self._fd = None

def _get_fd(self):
    """Return the file descriptor used to write new cache entries.

    The backing file is opened lazily on first use, with a uuid1-based
    name so concurrent writers do not collide.
    """
    if self._fd is None:
        filename = self._key2path(str(uuid1()))
        self._fd = open(filename, 'wb')
    return self._fd

def _check_legacy_file(self):
    """Checks if the legacy file is up to date.

    The legacy file stores ``"<legacy version>.<pickle protocol>"``;
    both numbers must match the current class values for the cache
    layout to be considered valid.
    """
    lv = pp = -1  # defaults: missing or unreadable marker
    legacy_file = os.path.join(self.cache_dir, self._LEGACY)
    if os.path.exists(legacy_file):
        with open(legacy_file, 'r') as lf:
            content = lf.read()
        try:
            lv, pp = tuple(int(part.strip()) for part in content.split('.'))
        except ValueError:
            # Raised for the old legacy.txt format; treat as out of date.
            lv = pp = -1
    return lv == self._LEGACY_VERSION and pp == self._PICKLE_PROTOCOL

def _remove_legacy_files(self):
    """Remove files from now invalid locations in the cache.

    This will not remove any files if a legacy file exists and is
    up to date. Once legacy files are removed, a legacy file will be
    written to avoid a costly ``os.listdir`` after calling this.
    """
    lock_filename = 'legacy.lock'
    # Hold the lock for the whole check-and-clear sequence so another
    # process cannot clear the directory concurrently.
    with FileLock(os.path.join(self.cache_dir, lock_filename)):
        if self._check_legacy_file():
            # Marker is current; nothing to clean up.
            return

        for f in os.listdir(self.cache_dir):
            if f == lock_filename:
                # Never delete the lock file we are currently holding.
                continue
            path = os.path.join(self.cache_dir, f)
            if os.path.isdir(path):
                shutil.rmtree(path)
            else:
                os.remove(path)

        # Record the now-current format so future calls return early.
        self._write_legacy_file()

def _write_legacy_file(self):
    """Writes a legacy file, indicating that legacy files do not exist."""
    content = "%d.%d\n" % (self._LEGACY_VERSION, self._PICKLE_PROTOCOL)
    legacy_file = os.path.join(self.cache_dir, self._LEGACY)
    with open(legacy_file, 'w') as lf:
        lf.write(content)

def get_files(self):
"""Returns all of the files in the cache.
Expand All @@ -232,6 +288,15 @@ def get_files(self):
files.extend(os.path.join(path, f) for f in os.listdir(path))
return files

def get_size(self):
    """Returns the size of the cache with units as a string.

    Returns
    -------
    str
    """
    size_in_bytes = self.get_size_in_bytes()
    return bytes2human(size_in_bytes)

def get_size_in_bytes(self):
"""Returns the size of the cache in bytes as an int.

Expand All @@ -243,16 +308,13 @@ def get_size_in_bytes(self):
return sum(byte_align(st.st_size, self._fragment_size)
for st in stats if st is not None)

def get_size(self):
"""Returns the size of the cache with units as a string.

Returns
-------
str
"""
return bytes2human(self.get_size_in_bytes())
def invalidate(self):
    """Invalidates the cache (i.e. removes all cache files)."""
    # Close the open write fd first so its file is also removable.
    self._close_fd()
    for cache_file in self.get_files():
        safe_remove(cache_file)

def shrink(self, limit=None):
def shrink(self, limit=None): # noqa: C901
"""Reduces the size of the cache to meet a limit.

Parameters
Expand All @@ -261,6 +323,11 @@ def shrink(self, limit=None):
Maximum size of the cache in bytes.
"""
if self.readonly:
logger.info("Tried to shrink a readonly cache.")
return

if self._index is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._index won't be reset to None in __exit__. Maybe we should do that for consistency? (I think technically shrink would work outside of the with block if there has been one call to __enter__ before even after the __exit__.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, we should probably reset self._index to None. It feels like things should only be able to happen inside the context manager; if it's just that kind of one-time thing then it should be in the __init__ (but yes, I know __init__ and __del__ are not feasible for the cache).

warnings.warn("Cannot shrink outside of a `with cache` block.")
return

if limit is None:
Expand Down Expand Up @@ -292,67 +359,7 @@ def shrink(self, limit=None):

self._index.sync()

def invalidate(self):
"""Invalidates the cache (i.e. removes all cache files)."""
self._close_fd()
for path in self.get_files():
safe_remove(path)

def _check_legacy_file(self):
"""Checks if the legacy file is up to date."""
legacy_file = os.path.join(self.cache_dir, self._LEGACY)
if os.path.exists(legacy_file):
with open(legacy_file, 'r') as lf:
text = lf.read()
try:
lv, pp = tuple(int(x.strip()) for x in text.split('.'))
except ValueError:
# Will be raised with old legacy.txt format
lv = pp = -1
else:
lv = pp = -1
return lv == self._LEGACY_VERSION and pp == self._PICKLE_PROTOCOL

def _write_legacy_file(self):
"""Writes a legacy file, indicating that legacy files do not exist."""
legacy_file = os.path.join(self.cache_dir, self._LEGACY)
with open(legacy_file, 'w') as lf:
lf.write("%d.%d\n" % (self._LEGACY_VERSION, self._PICKLE_PROTOCOL))

def _remove_legacy_files(self):
"""Remove files from now invalid locations in the cache.

This will not remove any files if a legacy file exists and is
up to date. Once legacy files are removed, a legacy file will be
written to avoid a costly ``os.listdir`` after calling this.
"""
lock_filename = 'legacy.lock'
with FileLock(os.path.join(self.cache_dir, lock_filename)):
if self._check_legacy_file():
return

for f in os.listdir(self.cache_dir):
if f == lock_filename:
continue
path = os.path.join(self.cache_dir, f)
if os.path.isdir(path):
shutil.rmtree(path)
else:
os.remove(path)

self._write_legacy_file()

@staticmethod
def get_default_dir():
"""Returns the default location of the cache.

Returns
-------
str
"""
return rc.get('decoder_cache', 'path')

def wrap_solver(self, solver_fn):
def wrap_solver(self, solver_fn): # noqa: C901
"""Takes a decoder solver and wraps it to use caching.

Parameters
Expand Down Expand Up @@ -388,9 +395,12 @@ def cached_solver(solver, neuron_type, gain, bias, x, targets,
solver_info, decoders = nco.read(f)
except:
logger.debug("Cache miss [%s].", key)
if self._index is None:
warnings.warn("Cannot use cached solver outside of "
"`with cache` block.")
decoders, solver_info = solver_fn(
solver, neuron_type, gain, bias, x, targets, rng=rng, E=E)
if not self.readonly:
if not self.readonly and self._index is not None:
fd = self._get_fd()
start = fd.tell()
nco.write(fd, solver_info, decoders)
Expand Down
24 changes: 8 additions & 16 deletions nengo/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,24 +119,16 @@ class Simulator(object):
def __init__(self, network, dt=0.001, seed=None, model=None):
self.closed = False

if model is None or model.decoder_cache is None:
cache = get_default_decoder_cache()
if model is None:
self.model = Model(dt=float(dt),
label="%s, dt=%f" % (network, dt),
decoder_cache=get_default_decoder_cache())
else:
cache = model.decoder_cache
self.model = model

with cache:
if model is None:
self.model = Model(dt=float(dt),
label="%s, dt=%f" % (network, dt),
decoder_cache=cache)
else:
self.model = model

if network is not None:
# Build the network into the model
self.model.build(network)

cache.shrink()
if network is not None:
# Build the network into the model
self.model.build(network)

# -- map from Signal.base -> ndarray
self.signals = SignalDict()
Expand Down
15 changes: 15 additions & 0 deletions nengo/tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from nengo.exceptions import FingerprintError
from nengo.solvers import LstsqL2
from nengo.utils.compat import int_types
from nengo.utils.testing import warns


class SolverMock(object):
Expand Down Expand Up @@ -546,3 +547,17 @@ def test_compare_cache_shrink_benchmark(self, analytics_data, plt, logger):

plt.scatter(np.ones_like(d1), d1, c='b')
plt.scatter(2 * np.ones_like(d2), d2, c='g')


def test_warns_out_of_context(tmpdir):
    """Using the cache outside a ``with`` block warns instead of failing."""
    cache = DecoderCache(cache_dir=str(tmpdir))

    # shrink is a warned no-op outside the context manager.
    with warns(UserWarning):
        cache.shrink()

    # A wrapped solver still solves (uncached), but warns.
    solver_mock = SolverMock()
    wrapped = cache.wrap_solver(solver_mock)
    with warns(UserWarning):
        wrapped(**get_solver_test_args())
    assert SolverMock.n_calls[solver_mock] == 1
2 changes: 1 addition & 1 deletion nengo/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

name = "nengo"
version_info = (2, 2, 0) # (major, minor, patch)
version_info = (2, 1, 2) # (major, minor, patch)
dev = 0

version = "{v}{dev}".format(v='.'.join(str(v) for v in version_info),
Expand Down