Merge pull request #112 from bartvm/pickling_tests
Serializable datasets
bartvm committed Jan 18, 2015
2 parents 5fa63bc + e399e9e commit 3f4d5e4
Showing 4 changed files with 64 additions and 6 deletions.
8 changes: 5 additions & 3 deletions .travis.yml
@@ -34,10 +34,10 @@ install:
   - |
     if [ $TESTS ]; then
       conda install -q --yes python=$TRAVIS_PYTHON_VERSION nose numpy pip coverage six scipy
-      pip install -q --no-deps git+git://github.com/Theano/Theano.git
+      pip install -q --no-deps git+git://github.com/Theano/Theano.git # Development version
       pip install -q nose2[coverage-plugin] coveralls
-      git clone -q git://github.com/lisa-lab/pylearn2.git
-      (cd pylearn2
+      (git clone -q git://github.com/lisa-lab/pylearn2.git # Pylearn2 doesn't support pip
+      cd pylearn2
       python setup.py -q develop)
     fi
   - |
@@ -48,6 +48,8 @@ install:
 script:
   - |
     if [ $TESTS ]; then
+      python setup.py -q install # Tests setup.py (also installs dill)
+      # Must export environment variable so that the subprocess is aware of it
       export THEANO_FLAGS=floatX=$TESTS,blas.ldflags='-lblas -lgfortran'
       # Running nose2 within coverage makes imports count towards coverage
       coverage run --source=blocks -m nose2.__main__ tests
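
The export in the new lines matters because the test command runs as a child process: a shell variable that is merely assigned, not exported, never reaches the child's environment. A minimal Python sketch of the same mechanism; the flag value here is illustrative, not the CI matrix value:

    import os
    import subprocess

    # Assigning into os.environ updates this process's environment, which
    # child processes inherit -- the Python analogue of `export` in the shell.
    os.environ['THEANO_FLAGS'] = 'floatX=float32'
    subprocess.check_call(
        ['python', '-c', "import os; print(os.environ['THEANO_FLAGS'])"])
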
22 changes: 20 additions & 2 deletions blocks/datasets/schemes.py
@@ -1,6 +1,7 @@
 import itertools
 from abc import ABCMeta, abstractmethod

+import six
 from six import add_metaclass

@@ -91,5 +92,22 @@ def __init__(self, num_examples, batch_size):
         self.batch_size = batch_size

     def get_request_iterator(self):
-        return (slice(x, min(self.num_examples, x + self.batch_size))
-                for x in range(0, self.num_examples, self.batch_size))
+        return SequentialIterator(self.num_examples, self.batch_size)
+
+
+class SequentialIterator(six.Iterator):
+    def __init__(self, num_examples, batch_size):
+        self.num_examples = num_examples
+        self.batch_size = batch_size
+        self.current = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.current >= self.num_examples:
+            raise StopIteration
+        slice_ = slice(self.current, min(self.num_examples,
+                                         self.current + self.batch_size))
+        self.current += self.batch_size
+        return slice_
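
This swap is the core of the change: a generator expression cannot be pickled, while an instance whose position lives in an ordinary attribute can, and subclassing six.Iterator lets the __next__ method work on both Python 2 and 3. A minimal sketch of the resulting behaviour, assuming blocks is installed so the class is importable as in the diff:

    import pickle

    from blocks.datasets.schemes import SequentialIterator

    it = SequentialIterator(num_examples=10, batch_size=4)
    assert next(it) == slice(0, 4)

    # The iterator's state (current=4) survives a pickle round trip,
    # so iteration resumes exactly where it stopped.
    resumed = pickle.loads(pickle.dumps(it))
    assert next(resumed) == slice(4, 8)
    assert next(resumed) == slice(8, 10)  # final batch is truncated
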
2 changes: 1 addition & 1 deletion setup.py
@@ -33,7 +33,7 @@
     ],
     keywords='theano machine learning neural networks deep learning',
     packages=find_packages(exclude=['docs', 'tests']),
-    install_requires=['numpy', 'theano', 'six'],
+    install_requires=['dill', 'numpy', 'theano', 'six'],
     extras_require={
         'test': ['nose', 'nose2'],
     },
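
The likely motivation for making dill a hard dependency is that the standard pickle module refuses object kinds that can end up inside an epoch iterator, while dill handles them. A small illustration of the difference; the lambda is just a stand-in for any unpicklable callable:

    import pickle

    import dill

    square = lambda x: x * x

    try:
        pickle.dumps(square)
    except Exception as exc:
        # The standard library cannot pickle a plain lambda
        print('pickle failed:', exc)

    restored = dill.loads(dill.dumps(square))  # dill serializes it fine
    assert restored(3) == 9
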
Expand Down
38 changes: 38 additions & 0 deletions tests/datasets/test_serialization.py
@@ -0,0 +1,38 @@
+import os
+
+import dill
+import numpy
+
+from blocks.datasets import DataStream
+from blocks.datasets.mnist import MNIST
+from blocks.datasets.schemes import SequentialScheme
+
+
+def test_in_memory():
+    # Load MNIST and get two batches
+    mnist = MNIST('train')
+    data_stream = DataStream(mnist, iteration_scheme=SequentialScheme(
+        num_examples=mnist.num_examples, batch_size=256))
+    epoch = data_stream.get_epoch_iterator()
+    for i, (features, targets) in enumerate(epoch):
+        if i == 1:
+            break
+    assert numpy.all(features == mnist.data['features'][256:512])
+
+    # Pickle the epoch and make sure that the data wasn't dumped
+    filename = 'epoch_test.pkl'
+    assert not os.path.exists(filename)
+    with open(filename, 'wb') as f:
+        dill.dump(epoch, f, protocol=dill.HIGHEST_PROTOCOL)
+    try:
+        assert os.path.getsize(filename) < 1024 * 1024  # Less than 1MB
+
+        # Reload the epoch and make sure that the state was maintained
+        del epoch
+        with open(filename, 'rb') as f:
+            epoch = dill.load(f)
+        features, targets = next(epoch)
+        assert numpy.all(features == mnist.data['features'][512:768])
+    finally:
+        # Clean up
+        os.remove(filename)
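
Beyond the test, the same property enables mid-epoch checkpointing: the iterator can be saved alongside a model and resumed at the exact batch where training stopped. A hypothetical sketch built on this commit's serializable epochs; run_with_checkpoint, train_on_batch, and the checkpoint path are illustrative names, not part of the codebase:

    import dill

    def run_with_checkpoint(data_stream, train_on_batch, checkpoint='epoch.pkl'):
        """Resume a partially completed epoch if a checkpoint exists."""
        try:
            with open(checkpoint, 'rb') as f:
                epoch = dill.load(f)  # picks up at the saved batch position
        except IOError:
            epoch = data_stream.get_epoch_iterator()
        for batch in epoch:
            train_on_batch(*batch)  # hypothetical training step
            with open(checkpoint, 'wb') as f:
                dill.dump(epoch, f, protocol=dill.HIGHEST_PROTOCOL)
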
