Skip to content

Commit

Permalink
Merge pull request #98 from bartvm/dataset_tests
Browse files Browse the repository at this point in the history
Fixes MNIST and runs test on Travis
  • Loading branch information
bartvm committed Jan 15, 2015
2 parents f61f51d + ad367ab commit a6169bf
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 12 deletions.
28 changes: 25 additions & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,51 @@ env:
before_install:
- |
if [ $TESTS ]; then
# Setup Python environment with BLAS libraries
sudo apt-get install -qq libatlas3gf-base libatlas-dev liblapack-dev gfortran
wget -q http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh
chmod +x miniconda.sh
./miniconda.sh -b
export PATH=/home/travis/miniconda/bin:$PATH
conda update -q --yes conda
# Download MNIST for tests
(mkdir mnist
cd mnist
curl -O http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz \
-O http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz \
-O http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz \
-O http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
gunzip *-ubyte.gz)
export BLOCKS_DATA_PATH=$PWD
fi
install:
# Install all Python dependencies
- |
if [ $TESTS ]; then
conda install -q --yes python=$TRAVIS_PYTHON_VERSION nose numpy pip coverage six scipy
pip install -q --no-deps git+git://github.com/Theano/Theano.git
pip install -q nose2[coverage-plugin] coveralls
git clone -q git://github.com/lisa-lab/pylearn2.git
(cd pylearn2; python setup.py -q develop)
(cd pylearn2
python setup.py -q develop)
fi
- |
if [ $FORMAT ]; then
pip install -q flake8
pip install -q git+git://github.com/bartvm/pep257.git@numpy
fi
script:
- if [ $TESTS ]; then THEANO_FLAGS=floatX=$TESTS,blas.ldflags='-lblas -lgfortran' coverage run --source=blocks -m nose2.__main__ tests; fi
- if [ $FORMAT ]; then flake8 blocks tests; pep257 blocks --numpy --ignore=D100,D101,D102,D103; fi
- |
if [ $TESTS ]; then
THEANO_FLAGS=floatX=$TESTS,blas.ldflags='-lblas -lgfortran' \
# Running nose2 within coverage makes imports count towards coverage
coverage run --source=blocks -m nose2.__main__ tests
fi
- |
if [ $FORMAT ]; then
flake8 blocks tests
# Ignore D100-103 errors (non-existing docstrings)
pep257 blocks --numpy --ignore=D100,D101,D102,D103
fi
after_script:
- if [ $TESTS ]; then coveralls; fi
1 change: 0 additions & 1 deletion blocks/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __init__(self, sources=None):
raise ValueError("Unable to provide requested sources")
self.sources = sources

@abstractmethod
def open(self):
"""Return the state if the dataset requires one.
Expand Down
20 changes: 12 additions & 8 deletions blocks/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,15 @@ class MNIST(Dataset):
The first example to load
stop : int, optional
The last example to load
binary : bool, optional
If ``True``, returns binary (black/white) images instead of
grayscale. ``False`` by default.
"""
sources = ('features', 'targets')

def __init__(self, which_set, start=None, stop=None, **kwargs):
def __init__(self, which_set, start=None, stop=None, binary=False,
**kwargs):
if which_set == 'train':
data = 'train-images-idx3-ubyte'
labels = 'train-labels-idx1-ubyte'
Expand All @@ -53,21 +57,21 @@ def __init__(self, which_set, start=None, stop=None, **kwargs):
labels = 't10k-labels-idx1-ubyte'
else:
raise ValueError("MNIST only has a train and test set")
data_path = os.path.join(config.data_path, 'mnist')
X = read_mnist_images(
os.path.join(config.data_path, data),
theano.config.floatX)[start:stop]
os.path.join(data_path, data),
'bool' if binary else theano.config.floatX)[start:stop]
X = X.reshape((X.shape[0], numpy.prod(X.shape[1:])))
y = read_mnist_labels(
os.path.join(config.data_path, labels))[start:stop, numpy.newaxis]
os.path.join(data_path, labels))[start:stop, numpy.newaxis]
self.X, self.y = X, y
self.num_examples = len(X)
self.default_scheme = SequentialScheme(self.num_examples, 1)
super(MNIST, self).__init__(**kwargs)

def get_data(self, state=None, request=None, sources=None):
data = dict(zip(self.sources, (self.X, self.y)))
sources = self.sources if sources is None else sources
return tuple(data[source][request] for source in sources)
def get_data(self, request=None):
data = dict(zip(('features', 'targets'), (self.X, self.y)))
return tuple(data[source][request] for source in self.sources)


def read_mnist_images(filename, dtype=None):
Expand Down
24 changes: 24 additions & 0 deletions tests/datasets/test_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from numpy.testing import assert_raises

from blocks.datasets.mnist import MNIST


def test_mnist():
mnist_train = MNIST('train', start=20000)
assert len(mnist_train.X) == 40000
assert len(mnist_train.y) == 40000
mnist_test = MNIST('test', sources=('targets',))
assert len(mnist_test.X) == 10000
assert len(mnist_test.y) == 10000

first_feature, first_target = mnist_train.get_data(request=[0])
assert first_feature.shape == (1, 784)
assert first_target.shape == (1, 1)

first_target, = mnist_test.get_data(request=[0, 1])
assert first_target.shape == (2, 1)

binary_mnist = MNIST('test', binary=True, sources=('features',))
first_feature, = binary_mnist.get_data(request=[0])
assert first_feature.dtype.kind == 'b'
assert_raises(ValueError, MNIST, 'valid')

0 comments on commit a6169bf

Please sign in to comment.