-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from ejhumphrey/dev_20160225_py3compat
Dev 20160225 py3compat
- Loading branch information
Showing
16 changed files
with
1,264 additions
and
186 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,8 @@ | |
*.pyc | ||
*.so | ||
.ipynb_checkpoints | ||
.cache | ||
.coverage | ||
|
||
# Packages # | ||
############ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
sudo: false | ||
|
||
# addons: | ||
# apt: | ||
# packages: | ||
# - | ||
|
||
cache: | ||
directories: | ||
- $HOME/env | ||
|
||
language: python | ||
|
||
notifications: | ||
email: false | ||
|
||
python: | ||
- "2.7" | ||
- "3.4" | ||
- "3.5" | ||
|
||
before_install: | ||
- bash .travis_dependencies.sh | ||
- export PATH="$HOME/env/miniconda$TRAVIS_PYTHON_VERSION/bin:$PATH"; | ||
- hash -r | ||
- source activate test-environment | ||
|
||
install: | ||
- pip install pytest pytest-cov | ||
- pip install coveralls | ||
- pip install -e . | ||
|
||
script: | ||
- python --version | ||
- py.test -vs --cov=optimus ./tests | ||
|
||
after_success: | ||
- coveralls | ||
- pip uninstall -y optimus | ||
|
||
after_failure: | ||
- pip uninstall -y optimus |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#!/bin/sh | ||
|
||
ENV_NAME="test-environment" | ||
set -e | ||
|
||
conda_create () | ||
{ | ||
|
||
hash -r | ||
conda config --set always_yes yes --set changeps1 no | ||
conda update -q conda | ||
conda config --add channels pypi | ||
conda info -a | ||
deps='pip numpy scipy pandas theano coverage matplotlib' | ||
|
||
conda create -q -n $ENV_NAME "python=$TRAVIS_PYTHON_VERSION" $deps | ||
conda update --all | ||
} | ||
|
||
src="$HOME/env/miniconda$TRAVIS_PYTHON_VERSION" | ||
if [ ! -d "$src" ]; then | ||
mkdir -p $HOME/env | ||
pushd $HOME/env | ||
|
||
# Download miniconda packages | ||
wget http://repo.continuum.io/miniconda/Miniconda-3.16.0-Linux-x86_64.sh -O miniconda.sh; | ||
|
||
# Install both environments | ||
bash miniconda.sh -b -p $src | ||
|
||
export PATH="$src/bin:$PATH" | ||
conda_create | ||
|
||
source activate $ENV_NAME | ||
|
||
pip install python-coveralls | ||
|
||
if [ "$ENABLE_NUMBA" = true ]; then | ||
conda install numba | ||
fi | ||
|
||
source deactivate | ||
popd | ||
else | ||
echo "Using cached dependencies" | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import gzip | ||
import numpy as np | ||
import pickle | ||
|
||
|
||
def load_mnist(mnist_file): | ||
"""Load the MNIST dataset into memory. | ||
Parameters | ||
---------- | ||
mnist_file : str | ||
Path to gzipped MNIST file. | ||
Returns | ||
------- | ||
train, valid, test: tuples of np.ndarrays | ||
Each consists of (data, labels), where data.shape=(N, 1, 28, 28) and | ||
labels.shape=(N,). | ||
""" | ||
dsets = [] | ||
with gzip.open(mnist_file, 'rb') as fp: | ||
for split in pickle.load(fp): | ||
n_samples = len(split[1]) | ||
data = np.zeros([n_samples, 1, 28, 28]) | ||
labels = np.zeros([n_samples], dtype=int) | ||
for n, (x, y) in enumerate(zip(*split)): | ||
data[n, ...] = x.reshape(1, 28, 28) | ||
labels[n] = y | ||
dsets.append((data, labels)) | ||
|
||
return dsets | ||
|
||
|
||
def load_mnist_npz(mnist_file): | ||
"""Load the MNIST dataset into memory from an NPZ. | ||
Parameters | ||
---------- | ||
mnist_file : str | ||
Path to an NPZ file of MNIST data. | ||
Returns | ||
------- | ||
train, valid, test: tuples of np.ndarrays | ||
Each consists of (data, labels), where data.shape=(N, 1, 28, 28) and | ||
labels.shape=(N,). | ||
""" | ||
data = np.load(mnist_file) | ||
dsets = [] | ||
for name in 'train', 'valid', 'test': | ||
x = data['x_{}'.format(name)].reshape(-1, 1, 28, 28) | ||
y = data['y_{}'.format(name)] | ||
dsets.append([x, y]) | ||
|
||
return dsets | ||
|
||
|
||
def minibatch(data, labels, batch_size, max_iter=np.inf): | ||
"""Random mini-batches generator. | ||
Parameters | ||
---------- | ||
data : array_like, len=N | ||
Observation data. | ||
labels : array_like, len=N | ||
Labels corresponding the the given data. | ||
batch_size : int | ||
Number of datapoints to return at each iteration. | ||
max_iter : int, default=inf | ||
Number of iterations before raising a StopIteration. | ||
Yields | ||
------ | ||
batch : dict | ||
Random batch of datapoints, under the keys `data` and `labels`. | ||
""" | ||
if len(data) != len(labels): | ||
raise ValueError("data and labels must have the same number of items.") | ||
|
||
num_points = len(labels) | ||
if num_points <= batch_size: | ||
raise ValueError("batch_size cannot exceed number of data points") | ||
|
||
count = 0 | ||
order = np.random.permutation(num_points) | ||
idx = 0 | ||
while count < max_iter: | ||
x, y = [], [] | ||
while len(y) < batch_size: | ||
x.append(data[order[idx]]) | ||
y.append(labels[order[idx]]) | ||
idx += 1 | ||
if idx >= num_points: | ||
idx = 0 | ||
np.random.shuffle(order) | ||
yield dict(data=np.asarray(x), labels=np.asarray(y)) | ||
count += 1 |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.