Skip to content

Commit

Permalink
Fix up regularizer progress-assert utility.
Browse files Browse the repository at this point in the history
  • Loading branch information
Leif Johnson committed Dec 28, 2015
1 parent 1325080 commit de67837
Showing 1 changed file with 20 additions and 25 deletions.
45 changes: 20 additions & 25 deletions test/regularizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,27 @@
import util


class TestNetwork(util.Base):
class Mixin:
def assert_progress(self, **kwargs):
start = best = None
for _, val in self.exp.itertrain(
[self.INPUTS, self.OUTPUTS],
algorithm='sgd',
patience=2,
min_improvement=0.01,
batch_size=self.NUM_EXAMPLES,
**kwargs):
if start is None:
start = best = val['loss']
if val['loss'] < best:
best = val['loss']
assert best < start # should have made progress!


class TestNetwork(Mixin, util.Base):
def setUp(self):
self.exp = theanets.Regressor([self.NUM_INPUTS, 10, self.NUM_OUTPUTS])

def assert_progress(self, **kwargs):
train0, valid0 = next(self.exp.itertrain([self.INPUTS, self.OUTPUTS]))
trainN, validN = self.exp.train(
[self.INPUTS, self.OUTPUTS],
algorithm='sgd',
patience=2,
min_improvement=0.01,
batch_size=self.NUM_EXAMPLES,
**kwargs)
assert trainN['loss'] < valid0['loss'] # should have made progress!

def test_input_noise(self):
self.assert_progress(input_noise=0.001)

Expand Down Expand Up @@ -49,21 +55,10 @@ def test_contractive(self):
self.assert_progress(contractive=0.001)


class TestRecurrent(util.RecurrentBase):
class TestRecurrent(Mixin, util.RecurrentBase):
def setUp(self):
self.exp = theanets.recurrent.Regressor([
self.NUM_INPUTS, (10, 'rnn'), self.NUM_OUTPUTS])
self.NUM_INPUTS, (20, 'rnn'), self.NUM_OUTPUTS])

def test_recurrent_l2(self):
self.assert_progress(recurrent_l2=0.001)

def assert_progress(self, **kwargs):
train0, valid0 = next(self.exp.itertrain([self.INPUTS, self.OUTPUTS]))
trainN, validN = self.exp.train(
[self.INPUTS, self.OUTPUTS],
algorithm='sgd',
patience=2,
min_improvement=0.01,
batch_size=self.NUM_EXAMPLES,
**kwargs)
assert trainN['loss'] < valid0['loss'] # should have made progress!

0 comments on commit de67837

Please sign in to comment.