Skip to content

Commit

Permalink
added a bunch of docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
lene committed Mar 3, 2016
1 parent f70dae0 commit b3f8daa
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 42 deletions.
124 changes: 99 additions & 25 deletions nn_wtf/input_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

"""Functions for downloading and reading MNIST data."""
"""Functions for reading image data from files or URLs."""

import os

Expand All @@ -23,6 +23,13 @@


def maybe_download(filename, base_url, work_directory):
"""Download a file from a URL if it is not already present in the work directory.
:param filename: Name of the file online and in work directory.
:param base_url: URL of the downloadable file minus the file name.
:param work_directory: Directory to look for or save the file in.
:return: The path to the (downloaded or already present) file.
"""
if not os.path.exists(work_directory):
os.mkdir(work_directory)
file_path = os.path.join(work_directory, filename)
Expand All @@ -33,59 +40,126 @@ def maybe_download(filename, base_url, work_directory):
return file_path


def read_images_from_file(filename, rows, cols, num_images, depth=1):
def read_one_image_from_file(filename, rows, cols, depth=1):
"""Reads one image from a file.
:param filename: The file containing the image data.
:param rows: Image height.
:param cols: Image width.
:param depth: Color depth of the image in bytes.
:return: A numpy.ndarray of shape <1, rows, cols, depth>
"""
_check_describes_image_geometry(rows, cols, depth)
with open(filename, 'rb') as bytestream:
return images_from_bytestream(bytestream, rows, cols, num_images, depth)
return _one_image_from_bytestream(bytestream, rows, cols, depth)


def read_images_from_files(rows, cols, depth, *filenames):
def read_one_image_from_url(url, rows, cols, depth=1):
"""Reads one image from a URL.
:param url: The URL containing the image data.
:param rows: Image height.
:param cols: Image width.
:param depth: Color depth of the image in bytes.
:return: A numpy.ndarray of shape <1, rows, cols, depth>
"""
_check_describes_image_geometry(rows, cols, depth)
return concatenate_images_from_input_function(read_one_image_from_file, rows, cols, depth, filenames)
with urllib.request.urlopen(url) as bytestream:
return _one_image_from_bytestream(bytestream, rows, cols, depth)


def read_images_from_urls(rows, cols, depth, *urls):
def read_images_from_file(filename, rows, cols, num_images, depth=1):
"""Reads multiple images from a single file.
:param filename: The file containing the image data.
:param rows: Image height.
:param cols: Image width.
:param num_images: Number of images to read.
:param depth: Color depth of the image in bytes.
:return: A numpy.ndarray of shape <num_images, rows, cols, depth>
"""
_check_describes_image_geometry(rows, cols, depth)
return concatenate_images_from_input_function(read_one_image_from_url, rows, cols, depth, urls)


def concatenate_images_from_input_function(input_function, rows, cols, depth, input_resources):
image_data = numpy.concatenate(
[input_function(input_resource, rows, cols, depth) for input_resource in input_resources]
)
return image_data
with open(filename, 'rb') as bytestream:
return images_from_bytestream(bytestream, rows, cols, num_images, depth)


def read_images_from_url(url, rows, cols, num_images, depth=1):
"""Reads multiple images from a single URL.
:param url: The URL containing the image data.
:param rows: Image height.
:param cols: Image width.
:param num_images: Number of images to read.
:param depth: Color depth of the image in bytes.
:return: A numpy.ndarray of shape <num_images, rows, cols, depth>
"""
_check_describes_image_geometry(rows, cols, depth)
with urllib.request.urlopen(url) as bytestream:
return images_from_bytestream(bytestream, rows, cols, num_images, depth)


def images_from_bytestream(bytestream, rows, cols, num_images, depth=1):
def read_images_from_files(rows, cols, depth, *filenames):
"""Reads multiple images from a list of files.
:param filenames: The files containing the image data.
:param rows: Image height.
:param cols: Image width.
:param depth: Color depth of the image in bytes.
:return: A numpy.ndarray of shape <num_images, rows, cols, depth>
"""
_check_describes_image_geometry(rows, cols, depth)
buf = bytestream.read(rows * cols * depth * num_images)
data = numpy.frombuffer(buf, dtype=numpy.uint8)
return data.reshape(num_images, rows, cols, depth)
return concatenate_images_from_input_function(read_one_image_from_file, rows, cols, depth, filenames)


def read_one_image_from_file(filename, rows, cols, depth=1):
def read_images_from_urls(rows, cols, depth, *urls):
"""Reads multiple images from a list of URLs.
:param urls: The URLs containing the image data.
:param rows: Image height.
:param cols: Image width.
:param depth: Color depth of the image in bytes.
:return: A numpy.ndarray of shape <num_images, rows, cols, depth>
"""
_check_describes_image_geometry(rows, cols, depth)
with open(filename, 'rb') as bytestream:
return one_image_from_bytestream(bytestream, rows, cols, depth)
return concatenate_images_from_input_function(read_one_image_from_url, rows, cols, depth, urls)


def read_one_image_from_url(url, rows, cols, depth=1):
def images_from_bytestream(bytestream, rows, cols, num_images, depth=1):
"""Reads a number of images from a byte stream.
:param bytestream: The byte stream containing the image data.
:param rows: Image height.
:param cols: Image width.
:param num_images: Number of images to read.
:param depth: Color depth of the image in bytes.
:return: A numpy.ndarray of shape <1, rows, cols, depth>
"""
_check_describes_image_geometry(rows, cols, depth)
with urllib.request.urlopen(url) as bytestream:
return one_image_from_bytestream(bytestream, rows, cols, depth)
buf = bytestream.read(rows * cols * depth * num_images)
data = numpy.frombuffer(buf, dtype=numpy.uint8)
return data.reshape(num_images, rows, cols, depth)


def one_image_from_bytestream(bytestream, rows, cols, depth=1):
def _one_image_from_bytestream(bytestream, rows, cols, depth=1):
"""Reads one image from a byte stream.
:param bytestream: The byte stream containing the image data.
:param rows: Image height.
:param cols: Image width.
:param depth: Color depth of the image in bytes.
:return: A numpy.ndarray of shape <1, rows, cols, depth>
"""
_check_describes_image_geometry(rows, cols, depth)
return images_from_bytestream(bytestream, rows, cols, depth)


def concatenate_images_from_input_function(input_function, rows, cols, depth, input_resources):
image_data = numpy.concatenate(
[input_function(input_resource, rows, cols, depth) for input_resource in input_resources]
)
return image_data


def _check_describes_image_geometry(rows, cols, depth):
assert rows > 0
assert cols > 0
Expand Down
79 changes: 62 additions & 17 deletions nn_wtf/mnist_data_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@

class MNISTDataSets(DataSets):

"""Data sets (training, validation and test data) containing the MNIST data.
MNIST data are 8-bit grayscale images of size 28x28 pixels, representing
handwritten numbers from 0 to 9.
"""

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
Expand All @@ -28,13 +34,18 @@ class MNISTDataSets(DataSets):
MNIST_MAGIC_IMAGES = 2051

def __init__(self, train_dir, one_hot=False):
"""Construct the data set, locating and if necessarily downloading the MNIST data.
:param train_dir: Where to store the MNIST data files.
:param one_hot:
"""
self.train_dir = train_dir
self.one_hot = one_hot

train_images = self.get_extracted_data(self.TRAIN_IMAGES, self.extract_images)
train_labels = self.get_extracted_data(self.TRAIN_LABELS, self.extract_labels)
test_images = self.get_extracted_data(self.TEST_IMAGES, self.extract_images)
test_labels = self.get_extracted_data(self.TEST_LABELS, self.extract_labels)
train_images = self._get_extracted_data(self.TRAIN_IMAGES, self._extract_images)
train_labels = self._get_extracted_data(self.TRAIN_LABELS, self._extract_labels)
test_images = self._get_extracted_data(self.TEST_IMAGES, self._extract_images)
test_labels = self._get_extracted_data(self.TEST_LABELS, self._extract_labels)

validation_images = train_images[:self.VALIDATION_SIZE]
validation_labels = train_labels[:self.VALIDATION_SIZE]
Expand All @@ -49,33 +60,67 @@ def __init__(self, train_dir, one_hot=False):

@classmethod
def read_one_image_from_file(cls, filename):
"""Reads one MNIST image from a file.
:param filename: The file containing the image data.
:return: A numpy.ndarray of shape <1, 28, 28, 1>
"""
return read_one_image_from_file(filename, MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE)

@classmethod
def read_one_image_from_url(cls, url):
"""Reads one MNIST image from a URL.
:param url: The URL containing the image data.
:return: A numpy.ndarray of shape <1, 28, 28, 1>
"""
return read_one_image_from_url(url, MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE)

@classmethod
def read_images_from_file(cls, filename, num_images):
"""Reads multiple MNIST images from a single file.
:param filename: The file containing the image data.
:param num_images: Number of images to read.
:return: A numpy.ndarray of shape <num_images, 28, 28, 1>
"""
return read_images_from_file(filename, MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE, num_images)

@classmethod
def read_images_from_url(cls, url, num_images):
"""Reads multiple MNIST images from a single URL.
:param url: The URL containing the image data.
:param num_images: Number of images to read.
:return: A numpy.ndarray of shape <num_images, 28, 28, 1>
"""
return read_images_from_url(url, MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE, num_images)

@classmethod
def read_images_from_files(cls, *filenames):
"""Reads multiple MNIST images from a list of files.
:param filenames: The files containing the image data.
:return: A numpy.ndarray of shape <num_images, 28, 28, 1>
"""
return read_images_from_files(MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE, 1, *filenames)

@classmethod
def read_images_from_urls(cls, *filenames):
return read_images_from_urls(MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE, 1, *filenames)
def read_images_from_urls(cls, *urls):
"""Reads multiple MNIST images from a list of URLs.
:param urls: The URLs containing the image data.
:return: A numpy.ndarray of shape <num_images, 28, 28, 1>
"""
return read_images_from_urls(MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE, 1, *urls)

############################################################################

def get_extracted_data(self, file_name, extract_function):
def _get_extracted_data(self, file_name, extract_function):
local_file = maybe_download(file_name, self.SOURCE_URL, self.train_dir)
return extract_function(local_file, one_hot=self.one_hot)

def extract_images(self, filename, one_hot=False):
def _extract_images(self, filename, one_hot=False):
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
Expand All @@ -89,7 +134,7 @@ def extract_images(self, filename, one_hot=False):
cols = _read32(bytestream)
return images_from_bytestream(bytestream, rows, cols, num_images)

def extract_labels(self, filename, one_hot=False):
def _extract_labels(self, filename, one_hot=False):
"""Extract the labels into a 1D uint8 numpy array [index]."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
Expand All @@ -102,17 +147,17 @@ def extract_labels(self, filename, one_hot=False):
buf = bytestream.read(num_items)
labels = numpy.frombuffer(buf, dtype=numpy.uint8)
if one_hot:
return self.dense_to_one_hot(labels)
return _dense_to_one_hot(labels)
return labels

def dense_to_one_hot(self, labels_dense, num_classes=10):
"""Convert class labels from scalars to one-hot vectors."""
num_labels = labels_dense.shape[0]
index_offset = numpy.arange(num_labels) * num_classes
labels_one_hot = numpy.zeros((num_labels, num_classes))
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
return labels_one_hot

def _dense_to_one_hot(labels_dense, num_classes=10):
"""Convert class labels from scalars to one-hot vectors."""
num_labels = labels_dense.shape[0]
index_offset = numpy.arange(num_labels) * num_classes
labels_one_hot = numpy.zeros((num_labels, num_classes))
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
return labels_one_hot


def _read32(bytestream):
Expand Down

0 comments on commit b3f8daa

Please sign in to comment.