Skip to content

Commit

Permalink
added tests for normalize()
Browse files Browse the repository at this point in the history
  • Loading branch information
lene committed Feb 28, 2016
1 parent 4bc12e9 commit 44c217f
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
11 changes: 8 additions & 3 deletions nn_wtf/images_labels_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy

__author__ = 'Lene Preuss <lp@sinnwerkstatt.com>'
__author__ = 'Lene Preuss <lene.preuss@gmail.com>'


class ImagesLabelsDataSet(DataSetBase):
Expand Down Expand Up @@ -33,11 +33,16 @@ def normalize(ndarray):
:param ndarray:
:return:
"""
ndarray = ndarray.astype(numpy.float32)
return numpy.multiply(ndarray, 1.0 / 255.0)
assert isinstance(ndarray, numpy.ndarray)
assert ndarray.dtype == numpy.uint8

return numpy.multiply(ndarray.astype(numpy.float32), 1.0/255.0)


def invert(ndarray):
assert isinstance(ndarray, numpy.ndarray)
assert ndarray.dtype == numpy.float32

return numpy.subtract(1.0, ndarray)


Expand Down
31 changes: 28 additions & 3 deletions nn_wtf/tests/images_labels_data_set_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from nn_wtf.images_labels_data_set import ImagesLabelsDataSet
from nn_wtf.images_labels_data_set import ImagesLabelsDataSet, normalize
from .util import create_minimal_input_placeholder

import numpy

import unittest

__author__ = 'Lene Preuss <lp@sinnwerkstatt.com>'
__author__ = 'Lene Preuss <lene.preuss@gmail.com>'
# pylint: disable=missing-docstring

NUM_TRAINING_SAMPLES = 20
Expand Down Expand Up @@ -58,6 +58,21 @@ def test_next_batch_sets_epochs_completed(self):
_, _ = data_set.next_batch(batch_size)
self.assertEqual(1, data_set.epochs_completed)

def test_normalize_dtype(self):
data = create_empty_image_data()
normalized = normalize(data)
self.assertEqual(normalized.dtype, numpy.float32)

def test_normalize_range(self):
data = create_random_image_data(0, 255)
normalized = normalize(data)
self.assertLessEqual(normalized.max(), 1.)
self.assertGreaterEqual(normalized.min(), 0.)

def test_normalize_bad_input(self):
data = create_random_image_data(0, 255).astype(numpy.float32)
with self.assertRaises(AssertionError):
normalize(data)

def _create_empty_data_set():
images = create_empty_image_data()
Expand All @@ -66,7 +81,17 @@ def _create_empty_data_set():


def create_empty_image_data():
buf = [0] * NUM_TRAINING_SAMPLES * IMAGE_WIDTH * IMAGE_HEIGHT
return image_data_from_list([0] * NUM_TRAINING_SAMPLES * IMAGE_WIDTH * IMAGE_HEIGHT)


def create_random_image_data(min_val, max_val):
from random import randrange
return image_data_from_list(
[randrange(min_val, max_val+1) for _ in range(NUM_TRAINING_SAMPLES * IMAGE_WIDTH * IMAGE_HEIGHT)]
)


def image_data_from_list(buf):
data = numpy.fromiter(buf, dtype=numpy.uint8)
return data.reshape(NUM_TRAINING_SAMPLES, IMAGE_WIDTH, IMAGE_HEIGHT, 1)

Expand Down

0 comments on commit 44c217f

Please sign in to comment.