Skip to content

Commit

Permalink
Expand test coverage!
Browse files Browse the repository at this point in the history
  • Loading branch information
Leif Johnson committed Jan 18, 2016
1 parent 4fd051f commit 2469350
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
7 changes: 6 additions & 1 deletion test/feedforward_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def test_predict_proba_onelayer(self):
z = net.predict_proba(self.INPUTS)
self.assert_shape(z.shape, self.NUM_CLASSES)

def test_predict_logit_onelayer(self):
net = self._build(13)
z = net.predict_logit(self.INPUTS)
self.assert_shape(z.shape, self.NUM_CLASSES)

def test_predict_twolayer(self):
net = self._build(13, 14)
z = net.predict(self.INPUTS)
Expand Down Expand Up @@ -136,7 +141,7 @@ def test_score_onelayer(self):

def test_encode_onelayer(self):
net = self._build(13)
z = net.encode(self.INPUTS)
z = net.encode(self.INPUTS, 'hid1')
self.assert_shape(z.shape, 13)

def test_encode_twolayer(self):
Expand Down
12 changes: 12 additions & 0 deletions test/graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ def test_layer_tied(self):
assert isinstance(m.layers[2], theanets.layers.feedforward.Tied)
assert m.layers[2].partner is m.layers[1]

def test_layer_tied_partner(self):
m = theanets.Regressor((1, 2, dict(size=1, form='tied', partner='hid1')))
assert len(m.layers) == 3
assert isinstance(m.layers[2], theanets.layers.feedforward.Tied)
assert m.layers[2].partner is m.layers[1]

def test_layer_tied_no_partner(self):
try:
theanets.Regressor((1, (2, 'tied'), (2, 'tied'), (1, 'tied')))
Expand Down Expand Up @@ -99,6 +105,12 @@ def test_find_missing(self):
except KeyError:
pass

def test_train(self):
m = theanets.Regressor((1, 2, 1))
tm, vm = m.train([np.random.randn(100, 1).astype('f'),
np.random.randn(100, 1).astype('f')])
assert tm['loss'] > 0


class TestMonitors:
def setUp(self):
Expand Down
18 changes: 17 additions & 1 deletion test/regularizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import util


class Mixin:
class Mixin(object):
def assert_progress(self, **kwargs):
start = best = None
for _, val in self.exp.itertrain(
Expand All @@ -21,6 +21,22 @@ def assert_progress(self, **kwargs):
assert best < start # should have made progress!


class TestBuild(util.Base):
def setUp(self):
self.exp = theanets.Regressor(
[self.NUM_INPUTS, 20, self.NUM_OUTPUTS], rng=131)

def test_regularizers_dict(self):
regs = theanets.regularizers.from_kwargs(
self.exp, regularizers=dict(input_noise=0.01))
assert len(regs) == 1

def test_regularizers_list(self):
reg = theanets.regularizers.Regularizer.build('weight_l2', 0.01)
regs = theanets.regularizers.from_kwargs(self.exp, regularizers=[reg])
assert len(regs) == 1


class TestNetwork(Mixin, util.Base):
def setUp(self):
self.exp = theanets.Regressor(
Expand Down

0 comments on commit 2469350

Please sign in to comment.