Skip to content

Commit

Permalink
added a function to predict multiple input images at once
Browse files Browse the repository at this point in the history
  • Loading branch information
lene committed Mar 1, 2016
1 parent 365868f commit 2d639a2
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 5 deletions.
14 changes: 12 additions & 2 deletions nn_wtf/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,26 @@ def predict(self, image):
predictions = self._run_prediction_op(image, self.session)
return predictions[0][0]

def predict_multiple(self, images, num):
predictions = self._run_multi_prediction_op(images, num, self.session)
return predictions[0].tolist()

def prediction_probabilities(self, image):
predictions = self._run_prediction_op(image, self.session)
return predictions[1][0]

def _run_prediction_op(self, image, session):
image_data = image.reshape(self.graph.input_size)
def _run_prediction_op(self, images, session):
image_data = images.reshape(self.graph.input_size)
feed_dict = {self.graph.input_placeholder: [image_data]}
best = session.run([self.prediction_op, self.probabilities_op], feed_dict)
return best

def _run_multi_prediction_op(self, images, num_images, session):
image_data = images.reshape(self.graph.input_size, num_images)
feed_dict = {self.graph.input_placeholder: image_data}
best = session.run([self.prediction_op], feed_dict)
return best

def _setup_prediction(self):
if self.prediction_op is not None:
return
Expand Down
4 changes: 2 additions & 2 deletions nn_wtf/tests/input_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_read_images_from_file_two(self):

def test_read_images_from_file_fails_if_file_too_short(self):
with self.assertRaises(ValueError):
data = read_images_from_file(
read_images_from_file(
get_project_root_folder()+'/nn_wtf/data/7_2.raw',
MNISTGraph.IMAGE_SIZE, MNISTGraph.IMAGE_SIZE, 3
)
Expand All @@ -56,7 +56,7 @@ def test_read_images_from_file_two_using_mnist_data_sets(self):

def test_read_images_from_file_using_mnist_data_sets_fails_if_file_too_short(self):
with self.assertRaises(ValueError):
data = MNISTDataSets.read_images_from_file(
MNISTDataSets.read_images_from_file(
get_project_root_folder()+'/nn_wtf/data/7_2.raw', 3
)

Expand Down
16 changes: 15 additions & 1 deletion nn_wtf/tests/predictor_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from nn_wtf.data_sets import DataSets
from nn_wtf.neural_network_graph import NeuralNetworkGraph
from nn_wtf.tests.util import MINIMAL_LAYER_GEOMETRY, create_train_data_set, train_data_input
from nn_wtf.tests.util import MINIMAL_LAYER_GEOMETRY, create_train_data_set, train_data_input, train_data_0_1, \
train_data_1_0

import unittest

Expand All @@ -11,14 +12,27 @@
class PredictorTest(unittest.TestCase):

def test_all_prediction_functions_at_once_to_save_computing_time(self):
"""Training takes time, if I run tests separately I have to train for each test."""

graph = train_neural_network(create_train_data_set())

self.assertEqual(0, graph.get_predictor().predict(train_data_input(0)))
self.assertEqual(1, graph.get_predictor().predict(train_data_input(1)))

probabilities_for_0 = graph.get_predictor().prediction_probabilities(train_data_input(0))
self.assertGreater(probabilities_for_0[0], probabilities_for_0[1])

probabilities_for_1 = graph.get_predictor().prediction_probabilities(train_data_input(1))
self.assertGreater(probabilities_for_1[1], probabilities_for_1[0])

def test_predict_multiple(self):
graph = train_neural_network(create_train_data_set())

predictions = graph.get_predictor().predict_multiple(train_data_0_1(), 2)
self.assertListEqual([0, 1], predictions)

predictions = graph.get_predictor().predict_multiple(train_data_1_0(), 2)
self.assertListEqual([1, 0], predictions)

def train_neural_network(train_data):
data_sets = DataSets(train_data, train_data, train_data)
Expand Down
6 changes: 6 additions & 0 deletions nn_wtf/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,9 @@ def train_data_input(value):
return numpy.fromiter([value, value], numpy.dtype(numpy.float32))


def train_data_0_1():
return numpy.fromiter([0, 0, 1, 1], numpy.dtype(numpy.float32))


def train_data_1_0():
return numpy.fromiter([1, 1, 0, 0], numpy.dtype(numpy.float32))

0 comments on commit 2d639a2

Please sign in to comment.