Merge pull request #104 from bartvm/serializable_datasets
WIP: First version of lazy properties
rizar committed Jan 18, 2015
2 parents a0ce00d + 9b2fade commit 5fa63bc
Showing 7 changed files with 229 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -48,7 +48,7 @@ install:
script:
- |
if [ $TESTS ]; then
THEANO_FLAGS=floatX=$TESTS,blas.ldflags='-lblas -lgfortran' \
export THEANO_FLAGS=floatX=$TESTS,blas.ldflags='-lblas -lgfortran'
# Running nose2 within coverage makes imports count towards coverage
coverage run --source=blocks -m nose2.__main__ tests
fi
50 changes: 34 additions & 16 deletions blocks/config_parser.py
@@ -55,33 +55,45 @@ class ConfigurationError(Exception):

class Configuration(object):
def __init__(self):
self.config = {}

def load_yaml(self):
if 'BLOCKS_CONFIG' in os.environ:
yaml_file = os.environ['BLOCKS_CONFIG']
else:
yaml_file = os.path.expanduser('~/.blocksrc')
if os.path.isfile(yaml_file):
with open(yaml_file) as f:
self.yaml_settings = yaml.safe_load(f)
else:
self.yaml_settings = {}
self.config = {}
for key, value in yaml.safe_load(f).items():
if key not in self.config:
raise ValueError("Unrecognized config in YAML: {}"
.format(key))
self.config[key]['yaml'] = value

def __getattr__(self, key):
if key not in self.config:
raise ConfigurationError("Unknown configuration: {}".format(key))
if key == 'config' or key not in self.config:
raise AttributeError
config = self.config[key]
if config['env_var'] is not None and config['env_var'] in os.environ:
if 'value' in config:
value = config['value']
elif 'env_var' in config and config['env_var'] in os.environ:
value = os.environ[config['env_var']]
elif key in self.yaml_settings:
value = self.yaml_settings[key]
else:
elif 'yaml' in config:
value = config['yaml']
elif 'default' in config:
value = config['default']
if value is NOT_SET:
else:
raise ConfigurationError("Configuration not set and no default "
"provided: {}.".format(key))
return config['type'](value)

def add_config(self, key, type, default=NOT_SET, env_var=None):
def __setattr__(self, key, value):
if key != 'config' and key in self.config:
self.config[key]['value'] = value
else:
super(Configuration, self).__setattr__(key, value)

def add_config(self, key, type_, default=NOT_SET, env_var=None):
"""Add a configuration setting.
Parameters
@@ -105,9 +117,15 @@ def add_config(self, key, type, default=NOT_SET, env_var=None):
YAML configuration file.
"""
self.config[key] = {'default': default,
'env_var': env_var,
'type': type}
self.config[key] = {'type': type_}
if env_var is not None:
self.config[key]['env_var'] = env_var
if default is not NOT_SET:
self.config[key]['default'] = default

config = Configuration()
config.add_config('data_path', env_var='BLOCKS_DATA_PATH', type=str)

# Define configuration options
config.add_config('data_path', env_var='BLOCKS_DATA_PATH', type_=str)

config.load_yaml()
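
The lookup order implemented by the new __getattr__ is: a value assigned directly on the config object, then the environment variable, then the YAML file, then the default; keys that were never registered with add_config now raise AttributeError instead of ConfigurationError. A doctest-style sketch of that precedence (the 'batch_size' option and the BLOCKS_BATCH_SIZE variable are hypothetical, used only for illustration):

>>> import os
>>> from blocks.config_parser import Configuration
>>> config = Configuration()
>>> config.add_config('batch_size', type_=int, default=128,
...                   env_var='BLOCKS_BATCH_SIZE')
>>> config.batch_size  # nothing set yet, so the default is used
128
>>> os.environ['BLOCKS_BATCH_SIZE'] = '64'
>>> config.batch_size  # the environment variable overrides the default
64
>>> config.batch_size = 32  # __setattr__ stores an explicit value
>>> config.batch_size  # an explicit value overrides everything else
32
>>> config.nonexistent_option  # unregistered keys raise AttributeError
Traceback (most recent call last):
  ...
AttributeError
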
156 changes: 156 additions & 0 deletions blocks/datasets/__init__.py
@@ -1,3 +1,4 @@
import collections
from abc import ABCMeta, abstractmethod

import numpy
@@ -130,6 +131,161 @@ def get_default_stream(self):
return DataStream(self, iteration_scheme=self.default_scheme)


class InMemoryDataset(Dataset):
"""Datasets who hold all of their data in memory.
For small datasets like e.g. MNIST it is easiest to simply load the
entire dataset into memory. All data streams will then access the same
data in memory.
Notes
-----
Datasets which hold data in memory must be treated differently when
serializing (saving) the training progress, because it would be very
inefficient to save the data along with the training process. Hence,
in-memory datasets support the :meth:`lazy_properties` decorator. This
decorator creates a series of properties whose values won't be
serialized; instead, their values will be reloaded (e.g. from disk) by
the :meth:`load` function after deserializing the object.
If the files from which the data were loaded are no longer available,
the de-serialization could fail. Hence the reloading of these
properties happens lazily i.e. only when the properties are requested.
This allows the user to intervene and change the location from which
files are loaded after de-serialization, before the :meth:`load` method
is ever called.
>>> import pickle
>>> from blocks.datasets.mnist import MNIST
>>> mnist = MNIST('train')
>>> print("{:,d} KB".format(
... mnist.data['features'].nbytes / 1024)) # doctest: +SKIP
183,750 KB
>>> with open('mnist.pkl', 'wb') as f:
... pickle.dump(mnist, f, protocol=pickle.HIGHEST_PROTOCOL)
You will notice that the dumping of the dataset was relatively quick,
because it didn't attempt to write MNIST to disk. We can now reload it,
and if the data file has not been moved, it will be as if nothing
happened.
>>> with open('mnist.pkl', 'rb') as f:
... mnist = pickle.load(f)
>>> print(mnist.data['features'].shape)
(60000, 784)
However, if the data files can't be found on disk, accessing the data
will fail.
>>> from blocks import config
>>> correct_path = config.data_path
>>> config.data_path = '/non/existing/path'
>>> with open('mnist.pkl', 'rb') as f:
... mnist = pickle.load(f)
>>> print(mnist.data['features'].shape) # doctest: +SKIP
Traceback (most recent call last):
...
FileNotFoundError: [Errno 2] No such file or directory: ...
Because the loading happens lazily, we can still deserialize our
dataset, correct the situation, and then continue.
>>> config.data_path = correct_path
>>> print(mnist.data['features'].shape)
(60000, 784)
.. doctest::
:hide:
>>> import os
>>> os.remove('mnist.pkl')
"""
def load(self):
"""Load data from e.g. the file system.
Any interaction with the outside world e.g. the file system,
database connections, servers, etc. should be done in this method.
This allows datasets to be pickled and unpickled, even in
environments where the original data is unavailable or has changed
position.
"""
pass
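
As a sketch of that division of labour (the file names and numpy.load calls are illustrative, not part of this commit), a subclass keeps every read inside load(), so instances stay cheap to pickle and paths can still be corrected after unpickling, before the first attribute access:

>>> import numpy
>>> @lazy_properties('features', 'targets')
... class NpyDataset(InMemoryDataset):
...     def load(self):
...         # All file-system access lives here rather than in __init__.
...         self.features = numpy.load('features.npy')
...         self.targets = numpy.load('targets.npy')
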


def lazy_properties(*lazy_properties):
r"""Decorator to assign lazy properties.
Used to assign "lazy properties" on :class:`InMemoryDataset` classes.
Please see the documentation there for a discussion on what lazy
properties are and why they are needed.
Parameters
----------
\*lazy_properties : strings
The names of the attributes that are lazy.
Notes
-----
The pickling behaviour of the dataset is only overridden if the dataset
does not have a ``__getstate__`` method implemented.
Examples
--------
In order to make sure that attributes are not serialized with the
dataset, and are lazily reloaded by the :meth:`~InMemoryDataset.load`
method after deserialization, use the decorator with the names of the
attributes as an argument.
>>> @lazy_properties('features', 'targets')
... class TestDataset(InMemoryDataset):
... def load(self):
... self.features = range(10 ** 6)
... self.targets = range(10 ** 6)[::-1]
"""
def lazy_property_factory(lazy_property):
"""Create properties that perform lazy loading of attributes."""
def lazy_property_getter(self):
if not hasattr(self, '_' + lazy_property):
self.load()
if not hasattr(self, '_' + lazy_property):
raise ValueError("{} wasn't loaded".format(lazy_property))
return getattr(self, '_' + lazy_property)

def lazy_property_setter(self, value):
setattr(self, '_' + lazy_property, value)

return lazy_property_getter, lazy_property_setter

def wrap_dataset(dataset):
if not issubclass(dataset, InMemoryDataset):
raise ValueError("Only InMemoryDataset supports lazy loading")

# Attach the lazy loading properties to the class
for lazy_property in lazy_properties:
setattr(dataset, lazy_property,
property(*lazy_property_factory(lazy_property)))

# Delete the values of lazy properties when serializing
if not hasattr(dataset, '__getstate__'):
def __getstate__(self):
serializable_state = self.__dict__.copy()
for lazy_property in lazy_properties:
attr = serializable_state.get('_' + lazy_property)
# Iterators would lose their state
if isinstance(attr, collections.Iterator):
raise ValueError("Iterators can't be lazy loaded")
serializable_state.pop('_' + lazy_property, None)
return serializable_state
setattr(dataset, '__getstate__', __getstate__)

return dataset
return wrap_dataset
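
For instance, applied to the MNIST class changed in this same pull request (decorated with @lazy_properties('data')), the generated __getstate__ drops the private '_data' attribute, so pickling the dataset never writes the images out, and the next attribute access after unpickling calls load() again. A short sketch, assuming the MNIST files are available under config.data_path:

>>> from blocks.datasets.mnist import MNIST
>>> mnist = MNIST('train')
>>> mnist.data['features'].shape  # first access triggers load()
(60000, 784)
>>> '_data' in mnist.__getstate__()  # the lazy attribute is excluded from pickling
False
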


class ContainerDataset(Dataset):
"""Equips a Python container with the dataset interface.
37 changes: 22 additions & 15 deletions blocks/datasets/mnist.py
@@ -5,15 +5,17 @@
import theano

from blocks import config
from blocks.datasets import Dataset
from blocks.datasets import InMemoryDataset, lazy_properties
from blocks.datasets.schemes import SequentialScheme
from blocks.utils import update_instance


MNIST_IMAGE_MAGIC = 2051
MNIST_LABEL_MAGIC = 2049


class MNIST(Dataset):
@lazy_properties('data')
class MNIST(InMemoryDataset):
"""The MNIST dataset of handwritten digits.
.. todo::
@@ -49,30 +51,35 @@ class MNIST(Dataset):

def __init__(self, which_set, start=None, stop=None, binary=False,
**kwargs):
if which_set == 'train':
if which_set not in ('train', 'test'):
raise ValueError("MNIST only has a train and test set")
num_examples = (stop if stop else 60000) - (start if start else 0)
default_scheme = SequentialScheme(num_examples, 1)
update_instance(self, locals())
super(MNIST, self).__init__(**kwargs)

def load(self):
if self.which_set == 'train':
data = 'train-images-idx3-ubyte'
labels = 'train-labels-idx1-ubyte'
elif which_set == 'test':
elif self.which_set == 'test':
data = 't10k-images-idx3-ubyte'
labels = 't10k-labels-idx1-ubyte'
else:
raise ValueError("MNIST only has a train and test set")
data_path = os.path.join(config.data_path, 'mnist')
X = read_mnist_images(
os.path.join(data_path, data),
'bool' if binary else theano.config.floatX)[start:stop]
'bool' if self.binary
else theano.config.floatX)[self.start:self.stop]
X = X.reshape((X.shape[0], numpy.prod(X.shape[1:])))
y = read_mnist_labels(
os.path.join(data_path, labels))[start:stop, numpy.newaxis]
self.X, self.y = X, y
self.num_examples = len(X)
self.default_scheme = SequentialScheme(self.num_examples, 1)
super(MNIST, self).__init__(**kwargs)
os.path.join(data_path, labels))[self.start:self.stop,
numpy.newaxis]
self.data = {'features': X, 'targets': y}

def get_data(self, state=None, request=None):
assert state is None
data = dict(zip(('features', 'targets'), (self.X, self.y)))
return tuple(data[source][request] for source in self.sources)
if state is not None:
raise ValueError("MNIST does not have a state")
return tuple(self.data[source][request] for source in self.sources)


def read_mnist_images(filename, dtype=None):
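
With the loading moved out of __init__ and into load(), constructing an MNIST instance is cheap and the which_set check happens up front; the files are only read when the data is first requested. A doctest-style sketch of the reworked interface, assuming the MNIST files are present under config.data_path:

>>> from blocks.datasets.mnist import MNIST
>>> MNIST('valid')
Traceback (most recent call last):
  ...
ValueError: MNIST only has a train and test set
>>> mnist = MNIST('train', start=50000)  # returns immediately; nothing is read yet
>>> mnist.num_examples
10000
>>> features, targets = mnist.get_data(request=[0, 1, 2])  # first request triggers load()
>>> features.shape
(3, 784)
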
1 change: 1 addition & 0 deletions docs/conf.py
@@ -30,6 +30,7 @@
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.doctest',
'sphinxcontrib.napoleon',
'sphinx.ext.todo',
'sphinx.ext.mathjax',
13 changes: 9 additions & 4 deletions tests/datasets/test_mnist.py
@@ -1,19 +1,24 @@
import numpy
from numpy.testing import assert_raises

import theano

from blocks.datasets.mnist import MNIST


def test_mnist():
mnist_train = MNIST('train', start=20000)
assert len(mnist_train.X) == 40000
assert len(mnist_train.y) == 40000
assert len(mnist_train.data['features']) == 40000
assert len(mnist_train.data['targets']) == 40000
mnist_test = MNIST('test', sources=('targets',))
assert len(mnist_test.X) == 10000
assert len(mnist_test.y) == 10000
assert len(mnist_test.data['features']) == 10000
assert len(mnist_test.data['targets']) == 10000

first_feature, first_target = mnist_train.get_data(request=[0])
assert first_feature.shape == (1, 784)
assert first_feature.dtype is numpy.dtype(theano.config.floatX)
assert first_target.shape == (1, 1)
assert first_target.dtype is numpy.dtype('uint8')

first_target, = mnist_test.get_data(request=[0, 1])
assert first_target.shape == (2, 1)
7 changes: 6 additions & 1 deletion tests/test_config_parser.py
@@ -19,16 +19,21 @@ def test_config_parser():
config.add_config('config_with_default', int, default='1',
env_var='BLOCKS_CONFIG_TEST')
config.add_config('config_without_default', str)
config.load_yaml()
assert config.data_path == 'yaml_path'
os.environ['BLOCKS_DATA_PATH'] = 'env_path'
assert config.data_path == 'env_path'
assert config.config_with_default == 1
os.environ['BLOCKS_CONFIG_TEST'] = '2'
assert config.config_with_default == 2
assert_raises(ConfigurationError, getattr, config,
assert_raises(AttributeError, getattr, config,
'non_existing_config')
assert_raises(ConfigurationError, getattr, config,
'config_without_default')
config.data_path = 'manual_path'
assert config.data_path == 'manual_path'
config.new_config = 'new_config'
assert config.new_config == 'new_config'
finally:
os.remove(os.environ['BLOCKS_CONFIG'])
os.environ.clear()
