Skip to content

Commit

Permalink
added functions to read multiple images from file or URL
Browse files Browse the repository at this point in the history
added test files and tests for reading multiple images from file
  • Loading branch information
lene committed Mar 1, 2016
1 parent 81bc58c commit 365868f
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 10 deletions.
Binary file added nn_wtf/data/7_2.raw
Binary file not shown.
Binary file added nn_wtf/data/7_2_1.raw
Binary file not shown.
Binary file added nn_wtf/data/7_2_1_0.raw
Binary file not shown.
16 changes: 13 additions & 3 deletions nn_wtf/input_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ def maybe_download(filename, base_url, work_directory):
return file_path


def read_images_from_file(filename, rows, cols, num_images, depth=1):
with open(filename, 'rb') as bytestream:
return images_from_bytestream(bytestream, rows, cols, num_images, depth)


def read_images_from_url(url, rows, cols, num_images, depth=1):
with urllib.request.urlopen(url) as bytestream:
return images_from_bytestream(bytestream, rows, cols, num_images, depth)


def images_from_bytestream(bytestream, rows, cols, num_images, depth=1):
buf = bytestream.read(rows * cols * depth * num_images)
data = numpy.frombuffer(buf, dtype=numpy.uint8)
Expand All @@ -41,14 +51,14 @@ def images_from_bytestream(bytestream, rows, cols, num_images, depth=1):

def read_one_image_from_file(filename, rows, cols, depth=1):
with open(filename, 'rb') as bytestream:
return read_one_image_from_bytestream(bytestream, rows, cols)
return one_image_from_bytestream(bytestream, rows, cols, depth)


def read_one_image_from_url(url, rows, cols, depth=1):
with urllib.request.urlopen(url) as bytestream:
return read_one_image_from_bytestream(bytestream, rows, cols)
return one_image_from_bytestream(bytestream, rows, cols, depth)


def read_one_image_from_bytestream(bytestream, rows, cols, depth=1):
def one_image_from_bytestream(bytestream, rows, cols, depth=1):
return images_from_bytestream(bytestream, rows, cols, depth)

10 changes: 9 additions & 1 deletion nn_wtf/mnist_data_sets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from nn_wtf.data_sets import DataSets
from nn_wtf.images_labels_data_set import ImagesLabelsDataSet
from nn_wtf.input_data import maybe_download, images_from_bytestream, \
read_one_image_from_file, read_one_image_from_url
read_one_image_from_file, read_one_image_from_url, read_images_from_file, read_images_from_url

import numpy

Expand Down Expand Up @@ -54,6 +54,14 @@ def read_one_image_from_file(cls, filename):
def read_one_image_from_url(cls, url):
return read_one_image_from_url(url, MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE)

@classmethod
def read_images_from_file(cls, filename, num_images):
return read_images_from_file(filename, MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE, num_images)

@classmethod
def read_images_from_url(cls, url, num_images):
return read_images_from_url(url, MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE, num_images)

def get_extracted_data(self, file_name, extract_function):
local_file = maybe_download(file_name, self.SOURCE_URL, self.train_dir)
return extract_function(local_file, one_hot=self.one_hot)
Expand Down
51 changes: 45 additions & 6 deletions nn_wtf/tests/input_data_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from nn_wtf.images_labels_data_set import ImagesLabelsDataSet
from nn_wtf.input_data import read_one_image_from_file
from nn_wtf.input_data import read_images_from_file
from nn_wtf.mnist_data_sets import MNISTDataSets

from nn_wtf.mnist_graph import MNISTGraph
Expand All @@ -21,13 +21,52 @@ def setUp(self):

def test_read_one_image_from_file(self):
self.assertIsInstance(self.data, numpy.ndarray)
self.assertEqual(4, len(self.data.shape))
self.assertEqual(1, self.data.shape[0])
self.assertEqual(MNISTGraph.IMAGE_SIZE, self.data.shape[1])
self.assertEqual(MNISTGraph.IMAGE_SIZE, self.data.shape[2])
self.assertEqual(1, self.data.shape[3])
self._check_is_one_mnist_image(self.data)

def test_image_labels_data_set_from_image(self):
labels = numpy.fromiter([0], dtype=numpy.uint8)
data_set = ImagesLabelsDataSet(self.data, labels)

def test_read_images_from_file_one(self):
data = read_images_from_file(
get_project_root_folder()+'/nn_wtf/data/0.raw',
MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE, 1
)
self._check_is_one_mnist_image(data)

def test_read_images_from_file_two(self):
data = read_images_from_file(
get_project_root_folder()+'/nn_wtf/data/7_2.raw',
MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE, 2
)
self._check_is_n_mnist_images(2, data)

def test_read_images_from_file_fails_if_file_too_short(self):
with self.assertRaises(ValueError):
data = read_images_from_file(
get_project_root_folder()+'/nn_wtf/data/7_2.raw',
MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE, 3
)

def test_read_images_from_file_two_using_mnist_data_sets(self):
data = MNISTDataSets.read_images_from_file(
get_project_root_folder()+'/nn_wtf/data/7_2.raw', 2
)
self._check_is_n_mnist_images(2, data)

def test_read_images_from_file_using_mnist_data_sets_fails_if_file_too_short(self):
with self.assertRaises(ValueError):
data = MNISTDataSets.read_images_from_file(
get_project_root_folder()+'/nn_wtf/data/7_2.raw', 3
)

def _check_is_one_mnist_image(self, data):
self._check_is_n_mnist_images(1, data)

def _check_is_n_mnist_images(self, n, data):
self.assertEqual(4, len(data.shape))
self.assertEqual(n, data.shape[0])
self.assertEqual(MNISTGraph.IMAGE_SIZE, data.shape[1])
self.assertEqual(MNISTGraph.IMAGE_SIZE, data.shape[2])
self.assertEqual(1, data.shape[3])

0 comments on commit 365868f

Please sign in to comment.