
minimal tests for saving graphs - check if saving runs at all
lene committed Mar 5, 2016
1 parent 8302f87 commit 7b310b0
Showing 5 changed files with 61 additions and 15 deletions.
2 changes: 1 addition & 1 deletion nn_wtf/neural_network_graph_mixins.py
@@ -28,7 +28,7 @@ def __init__(self, session, train_dir=DEFAULT_TRAIN_DIR):
         self.saver = tf.train.Saver()
 
     def save(self, **kwargs):
-        self.saver.save(self.session, save_path=self.train_dir, **kwargs)
+        return self.saver.save(self.session, save_path=self.train_dir, **kwargs)
 
 
 class SummaryWriterMixin(NeuralNetworkGraphMixin):
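Why this one-line change is enough for the tests below: `tf.train.Saver.save()` returns the path of the checkpoint file it just wrote, so surfacing that return value gives callers a handle they can pass straight to `Saver.restore()`. A minimal sketch against the TensorFlow-0.x-era API this repository uses; the path, graph, and variable names are illustrative, not taken from nn_wtf:

    import tensorflow as tf

    # A trivial graph so the Saver has a variable to checkpoint.
    weights = tf.Variable(tf.zeros([2, 2]), name='weights')
    session = tf.Session()
    session.run(tf.initialize_all_variables())

    saver = tf.train.Saver()
    # save() returns the path it wrote ('/tmp/model-100' here, since a
    # global_step is given) - the value SaverMixin.save() now passes on.
    checkpoint_path = saver.save(session, save_path='/tmp/model', global_step=100)

    # The returned path can later be fed straight back into restore().
    saver.restore(session, checkpoint_path)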
1 change: 1 addition & 0 deletions nn_wtf/run_tests.py → nn_wtf/tests/run_tests.py
@@ -4,6 +4,7 @@
 from nn_wtf.tests.predictor_test import PredictorTest
 from nn_wtf.tests.neural_network_optimizer_test import NeuralNetworkOptimizerTest
 from nn_wtf.tests.trainer_test import TrainerTest
+from nn_wtf.tests.save_and_restore_test import SaveAndRestoreTest
 
 import unittest
43 changes: 43 additions & 0 deletions nn_wtf/tests/save_and_restore_test.py
@@ -0,0 +1,43 @@
+from nn_wtf.neural_network_graph import NeuralNetworkGraph
+from nn_wtf.neural_network_graph_mixins import SaverMixin
+from nn_wtf.tests.util import MINIMAL_LAYER_GEOMETRY, init_graph, train_neural_network, create_train_data_set
+
+import unittest
+
+from tempfile import gettempdir
+from os import remove
+from os.path import join
+
+__author__ = 'Lene Preuss <lene.preuss@gmail.com>'
+
+
+class SavableNetwork(NeuralNetworkGraph, SaverMixin):
+    def __init__(self):
+        super().__init__(2, MINIMAL_LAYER_GEOMETRY, 2)
+
+    def set_session(self, session=None, verbose=True, train_dir=gettempdir()):
+        super().set_session()
+        SaverMixin.__init__(self, self.session, train_dir)
+
+
+class SaveAndRestoreTest(unittest.TestCase):
+
+    def setUp(self):
+        self.generated_filenames = []
+
+    def tearDown(self):
+        for filename in self.generated_filenames:
+            remove(join(gettempdir(), filename))
+
+    def test_save_untrained_network_runs(self):
+        graph = init_graph(SavableNetwork())
+        saved = graph.save(global_step=graph.trainer.num_steps())
+        self._add_savefiles_to_list(saved)
+
+    def test_save_trained_network_runs(self):
+        graph = train_neural_network(create_train_data_set(), SavableNetwork())
+        saved = graph.save(global_step=graph.trainer.num_steps())
+        self._add_savefiles_to_list(saved)
+
+    def _add_savefiles_to_list(self, savefile):
+        self.generated_filenames.extend([savefile, '{}.meta'.format(savefile), 'checkpoint'])
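A note on the cleanup bookkeeping in `_add_savefiles_to_list()`: with the era's V1 checkpoint format, one `save()` call leaves behind the variables file whose path `save()` returns, a companion `.meta` file holding the serialized MetaGraphDef, and a small `checkpoint` text file naming the most recent save. A hedged illustration with a made-up path and step:

    # If save() returned '/tmp/model-1000', tearDown() has three files to remove:
    #   /tmp/model-1000       - the saved variable values (the returned path)
    #   /tmp/model-1000.meta  - the serialized MetaGraphDef
    #   checkpoint            - index file naming the latest checkpoint
    saved = '/tmp/model-1000'
    generated_filenames = [saved, '{}.meta'.format(saved), 'checkpoint']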
14 changes: 10 additions & 4 deletions nn_wtf/tests/util.py
@@ -42,11 +42,11 @@ def train_data_input(value):
     return create_vector([value, value])
 
 
-def train_neural_network(train_data):
+def train_neural_network(train_data, graph=None):
     data_sets = DataSets(train_data, train_data, train_data)
-    graph = NeuralNetworkGraph(train_data.input.shape[0], MINIMAL_LAYER_GEOMETRY, len(train_data.labels))
-    graph.init_trainer()
-    graph.set_session()
+    if graph is None:
+        graph = NeuralNetworkGraph(train_data.input.shape[0], MINIMAL_LAYER_GEOMETRY, len(train_data.labels))
+    init_graph(graph)
 
     graph.train(
         data_sets=data_sets, steps_between_checks=50, max_steps=1000, batch_size=train_data.num_examples,
@@ -55,6 +55,12 @@ def train_neural_network(train_data):
     return graph
 
 
+def init_graph(graph):
+    graph.init_trainer()
+    graph.set_session()
+    return graph
+
+
 def allow_fail(max_times_fail=1, silent=True):
     """Runs a test, if necessary repeatedly, allowing it to fail up to max_times_fail times.
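The `graph=None` default keeps existing call sites of `train_neural_network()` working while letting tests inject a pre-built graph, and `init_graph()` factors out the trainer/session setup both paths share. A short usage sketch; `SavableNetwork` is the test helper class defined above:

    from nn_wtf.tests.util import create_train_data_set, train_neural_network

    # Old behaviour, unchanged: util constructs a plain NeuralNetworkGraph itself.
    graph = train_neural_network(create_train_data_set())

    # New behaviour: hand in a custom graph, as SaveAndRestoreTest does.
    graph = train_neural_network(create_train_data_set(), SavableNetwork())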
16 changes: 6 additions & 10 deletions nn_wtf/trainer.py
@@ -156,12 +156,10 @@ def _training(self, loss, learning_rate):
         The Op returned by this function is what must be passed to the
         `sess.run()` call to cause the model to train.
-        Args:
-          loss: Loss tensor, from loss().
-          learning_rate: The learning rate to use for gradient descent.
+        :param loss: Loss tensor, from loss().
+        :param learning_rate: The learning rate to use for gradient descent.
-        Returns:
-          train_op: The Op for training.
+        :return train_op: The Op for training.
         """
         # Add a scalar summary for the snapshot loss.
         tf.scalar_summary(loss.op.name, loss)
@@ -177,13 +175,11 @@ def _training(self, loss, learning_rate):
     def _evaluation(self, logits, labels):
         """Evaluate the quality of the logits at predicting the label.
-        Args:
-          logits: Logits tensor, float - [batch_size, NUM_CLASSES].
-          labels: Labels tensor, int32 - [batch_size], with values in the
+        :param logits: Logits tensor, float - [batch_size, NUM_CLASSES].
+        :param labels: Labels tensor, int32 - [batch_size], with values in the
             range [0, NUM_CLASSES).
-        Returns:
-          A scalar int32 tensor with the number of examples (out of batch_size)
+        :return A scalar int32 tensor with the number of examples (out of batch_size)
             that were predicted correctly.
         """
         # For a classifier model, we can use the in_top_k Op. It returns a bool
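The `_evaluation()` body falls outside the diff context, but the surviving `in_top_k` comment points at the classic TensorFlow MNIST-tutorial pattern; roughly, as a sketch in the 0.x-era API rather than a verbatim quote of this file:

    # True per example iff the correct label received the highest logit (k=1).
    correct = tf.nn.in_top_k(logits, labels, 1)
    # Count how many of the batch_size examples were classified correctly.
    return tf.reduce_sum(tf.cast(correct, tf.int32))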
