got rid of the fake_data flag, it's rubbish
lene committed Feb 26, 2016
1 parent a58c510 commit 11fb2f7
Showing 5 changed files with 33 additions and 45 deletions.
31 changes: 8 additions & 23 deletions nn_wtf/images_labels_data_set.py
@@ -5,26 +5,22 @@
 
 class ImagesLabelsDataSet:
 
-    def __init__(self, images, labels, fake_data=False, one_hot=False):
+    def __init__(self, images, labels):
         """Construct a DataSet. one_hot arg is used only if fake_data is true.
         Args:
             images: 4D numpy.ndarray of shape (num images, image height, image width, image depth)
             labels: 1D numpy.ndarray of shape (num images)
         """
 
-        if fake_data:
-            self._num_examples = 10000
-            self.one_hot = one_hot
-        else:
-            _check_constructor_arguments_valid(images, labels)
+        _check_constructor_arguments_valid(images, labels)
 
-            self._num_examples = images.shape[0]
+        self._num_examples = images.shape[0]
 
-            # Convert shape from [num examples, rows, columns, depth] to [num examples, rows*columns]
-            # TODO: assumes depth == 1
-            images = images.reshape(images.shape[0], images.shape[1] * images.shape[2])
-            images = normalize(images)
+        # Convert shape from [num examples, rows, columns, depth] to [num examples, rows*columns]
+        # TODO: assumes depth == 1
+        images = images.reshape(images.shape[0], images.shape[1] * images.shape[2])
+        images = normalize(images)
 
         self._images = images
         self._labels = labels
@@ -47,11 +43,8 @@ def num_examples(self):
     def epochs_completed(self):
         return self._epochs_completed
 
-    def next_batch(self, batch_size, fake_data=False):
+    def next_batch(self, batch_size):
         """Return the next `batch_size` examples from this data set."""
-        if fake_data:
-            return self._fake_batch(batch_size)
-
         return self._next_batch_in_epoch(batch_size)
 
     def _next_batch_in_epoch(self, batch_size):
@@ -75,14 +68,6 @@ def _shuffle_data(self):
         self._images = self._images[perm]
         self._labels = self._labels[perm]
 
-    def _fake_batch(self, batch_size):
-        fake_image = [1] * 784
-        if self.one_hot:
-            fake_label = [1] + [0] * 9
-        else:
-            fake_label = 0
-        return [fake_image for _ in range(batch_size)], [fake_label for _ in range(batch_size)]
-
 
 def normalize(ndarray):
     """Transform a ndarray that contains uint8 values to floats between 0. and 1.
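After this commit the constructor only accepts real data. A minimal usage sketch of the simplified API (the MNIST-like 28x28x1 uint8 shapes are an assumption, not taken from the diff):

import numpy

from nn_wtf.images_labels_data_set import ImagesLabelsDataSet

# assumed input: ten 28x28 single-channel uint8 images with one integer label each
images = numpy.zeros((10, 28, 28, 1), dtype=numpy.uint8)
labels = numpy.zeros((10,), dtype=numpy.uint8)

data_set = ImagesLabelsDataSet(images, labels)       # fake_data/one_hot arguments are gone
batch_images, batch_labels = data_set.next_batch(5)  # images come back flattened and normalized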
7 changes: 0 additions & 7 deletions nn_wtf/input_data.py
@@ -146,10 +146,3 @@ def read_data_sets(train_dir, one_hot=False):
         ImagesLabelsDataSet(test_images, test_labels)
     )
 
-
-def fake_data_sets(one_hot):
-    return DataSets(
-        ImagesLabelsDataSet([], [], fake_data=True, one_hot=one_hot),
-        ImagesLabelsDataSet([], [], fake_data=True, one_hot=one_hot),
-        ImagesLabelsDataSet([], [], fake_data=True, one_hot=one_hot)
-    )
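With fake_data_sets() removed, callers always go through read_data_sets(). A hedged sketch of that path (the train field name on DataSets is assumed from context, not shown in this diff):

from nn_wtf.input_data import read_data_sets

# reads MNIST from the given directory and wraps it in ImagesLabelsDataSet objects
data_sets = read_data_sets('/tmp/mnist-data', one_hot=False)
print(data_sets.train.num_examples)  # num_examples property from images_labels_data_set.py above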
24 changes: 13 additions & 11 deletions nn_wtf/mnist_graph.py
@@ -32,7 +32,6 @@ def __init__(
         self.hidden += (hidden3,)
         self.batch_size = batch_size
         self.train_dir = ensure_is_dir(train_dir)
-        self.fake_data = False
 
         self._build_graph()
         self._setup_summaries()
@@ -45,15 +44,7 @@ def train(self, data_sets, max_steps, precision=None, steps_between_checks=100):
         for self.step in range(max_steps):
             start_time = time.time()
 
-            # Fill a feed dictionary with the actual set of images and labels for this particular
-            # training step.
-            feed_dict = self.fill_feed_dict(data_sets.train)
-
-            # Run one step of the model. The return values are the activations from the `train_op`
-            # (which is discarded) and the `loss` Op. To inspect the values of your Ops or
-            # variables, you may include them in the list passed to session.run() and the value
-            # tensors will be returned in the tuple from the call.
-            _, loss_value = self.session.run([self.train_op, self.loss], feed_dict=feed_dict)
+            feed_dict, loss_value = self.run_training_steps(data_sets)
 
             duration = time.time() - start_time
 
@@ -70,6 +61,17 @@ def train(self, data_sets, max_steps, precision=None, steps_between_checks=100):
         self.saver.save(self.session, save_path=self.train_dir, global_step=self.step)
         self.print_evaluations(data_sets)
 
+    def run_training_steps(self, data_sets):
+        # Fill a feed dictionary with the actual set of images and labels for this particular
+        # training step.
+        feed_dict = self.fill_feed_dict(data_sets.train)
+        # Run one step of the model. The return values are the activations from the `train_op`
+        # (which is discarded) and the `loss` Op. To inspect the values of your Ops or
+        # variables, you may include them in the list passed to session.run() and the value
+        # tensors will be returned in the tuple from the call.
+        _, loss_value = self.session.run([self.train_op, self.loss], feed_dict=feed_dict)
+        return feed_dict, loss_value
+
     def print_evaluations(self, data_sets):
         if self.verbose: print('Training Data Eval:')
         self.print_eval(data_sets.train)
@@ -106,7 +108,7 @@ def fill_feed_dict(self, data_set):
             feed_dict: The feed dictionary mapping from placeholders to values.
         """
         # Create the feed_dict for the placeholders filled with the next `batch size ` examples.
-        images_feed, labels_feed = data_set.next_batch(self.batch_size, self.fake_data)
+        images_feed, labels_feed = data_set.next_batch(self.batch_size)
         feed_dict = {
             self.images_placeholder: images_feed,
             self.labels_placeholder: labels_feed,
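run_training_steps() now holds the feed-dict/session.run pair, so train() only handles timing, checkpoints and evaluation. A self-contained toy showing the same extract-method shape (names are illustrative only, not the nn_wtf API):

import time


class ToyTrainer:
    """Stand-in for MNISTGraph, used only to illustrate the refactoring."""

    def train(self, max_steps):
        for step in range(max_steps):
            start_time = time.time()
            loss_value = self.run_training_step(step)  # extracted, like run_training_steps()
            duration = time.time() - start_time
            print('step %d: loss %.3f (%.4f sec)' % (step, loss_value, duration))

    def run_training_step(self, step):
        # In MNISTGraph this fills the feed dict and calls session.run();
        # here it just returns a fake, decreasing loss.
        return 1.0 / (step + 1)


ToyTrainer().train(3)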
4 changes: 0 additions & 4 deletions nn_wtf/tests/images_labels_data_set_test.py
@@ -17,10 +17,6 @@ class ImagesLabelsDataSetTest(unittest.TestCase):
     def test_init_without_fake_data_runs(self):
         _create_empty_data_set()
 
-    def test_init_with_fake_data_runs(self):
-        images = create_minimal_input_placeholder()
-        ImagesLabelsDataSet(images, images, fake_data=True)
-
     def test_init_with_different_label_size_fails(self):
         images = create_empty_image_data()
         labels = create_empty_label_dataof_size(NUM_TRAINING_SAMPLES+1)
12 changes: 12 additions & 0 deletions nn_wtf/tests/predictor_test.py
@@ -0,0 +1,12 @@
+from nn_wtf.neural_network_graph import NeuralNetworkGraph
+
+from .util import MINIMAL_INPUT_SIZE, MINIMAL_OUTPUT_SIZE, MINIMAL_LAYER_GEOMETRY, MINIMAL_BATCH_SIZE
+from .util import create_minimal_input_placeholder
+
+import tensorflow as tf
+
+import unittest
+
+__author__ = 'Lene Preuss <lene.preuss@gmail.com>'
+
+class NeuralNetworkGraphTest(unittest.TestCase):
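The new test module so far only declares imports and an empty test case. One hypothetical test it could grow, with an assumed NeuralNetworkGraph constructor signature (input size, layer geometry, output size) that this commit does not show:

    def test_graph_can_be_constructed(self):
        # hypothetical: constructor arguments are assumed, not taken from this commit
        graph = NeuralNetworkGraph(MINIMAL_INPUT_SIZE, MINIMAL_LAYER_GEOMETRY, MINIMAL_OUTPUT_SIZE)
        self.assertIsNotNone(graph)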
