Skip to content

Commit

Permalink
Add test for saving model periodically.
Browse files Browse the repository at this point in the history
  • Loading branch information
Leif Johnson committed Jul 11, 2015
1 parent 70a47c5 commit c34094f
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions test/graph_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import numpy as np
import os
import tempfile
import theanets

import util


class TestNetwork:
def test_updates(self):
Expand Down Expand Up @@ -95,3 +99,19 @@ def test_param(self):
def test_wildcard(self):
self.assert_monitors({'*.w': 1}, ['err', 'hid1.w<1', 'hid2.w<1', 'out.w<1'])
self.assert_monitors({'hid?.w': 1}, ['err', 'hid1.w<1', 'hid2.w<1'])


class TestSaving(util.Base):
def test_save_every(self):
net = theanets.Autoencoder((self.NUM_INPUTS, (3, 'prelu'), self.NUM_INPUTS))
f, p = tempfile.mkstemp(suffix='pkl')
os.close(f)
os.unlink(p)
train = net.itertrain([self.INPUTS], save_every=2, save_progress=p)
for i, _ in enumerate(zip(train, range(9))):
if i == 3 or i == 5 or i == 7:
assert os.path.isfile(p)
else:
assert not os.path.isfile(p)
if os.path.exists(p):
os.unlink(p)

0 comments on commit c34094f

Please sign in to comment.