Skip to content

Commit

Permalink
fix unit test wrt deprecated functions and trying to remove nonexiste…
Browse files Browse the repository at this point in the history
…nt files
  • Loading branch information
lene committed Dec 22, 2016
1 parent d5a0e14 commit 2395f12
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ graph = MNISTGraph(
)
graph.train(data_sets, max_steps=5000, precision=0.95)

# verify the training worked by testing if a handwritten number is recognized
image_data = MNISTDataSets.read_one_image_from_url(
'http://github.com/lene/nn-wtf/blob/master/nn_wtf/data/7_from_test_set.raw?raw=true'
)
Expand Down
14 changes: 8 additions & 6 deletions nn_wtf/tests/save_and_restore_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from tempfile import gettempdir
from os import remove
from os.path import join
from os.path import join, isfile

__author__ = 'Lene Preuss <lene.preuss@gmail.com>'

Expand All @@ -30,7 +30,8 @@ def setUp(self):

def tearDown(self):
for filename in self.generated_filenames:
remove(join(gettempdir(), filename))
if isfile(join(gettempdir(), filename)):
remove(join(gettempdir(), filename))

def test_save_untrained_network_runs(self):
graph = init_graph(SavableNetwork())
Expand All @@ -46,22 +47,23 @@ def test_prediction_with_trained_graph(self):
savefile = self._save_trained_graph()

new_graph = init_graph(SavableNetwork(), session=tf.Session())
for v in tf.all_variables():
for v in tf.global_variables():
self.assertFalse(self.is_ndarray_equal(v.eval(new_graph.session), self.variables[v.op.name]))
new_graph.restore()#savefile)
for v in tf.all_variables():
for v in tf.global_variables():
self.assertTrue(self.is_ndarray_equal(v.eval(new_graph.session), self.variables[v.op.name]))

def _save_trained_graph(self):
tf.reset_default_graph()
with train_neural_network(create_train_data_set(), SavableNetwork()) as graph:
for v in tf.all_variables():
for v in tf.global_variables():
self.variables[v.op.name] = v.eval(graph.session)
return graph.save(global_step=graph.trainer.num_steps())

def _add_savefiles_to_list(self, savefile):
self.generated_filenames.extend([savefile, '{}.meta'.format(savefile), 'checkpoint'])

def is_ndarray_equal(self, array_1, array_2):
@staticmethod
def is_ndarray_equal(array_1, array_2):
import numpy
return numpy.array_equal(array_1, array_2)

0 comments on commit 2395f12

Please sign in to comment.