Skip to content

Commit

Permalink
Sparse initialization unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
bartvm committed Jan 7, 2015
1 parent a21e0ff commit 618f263
Showing 1 changed file with 38 additions and 10 deletions.
48 changes: 38 additions & 10 deletions tests/test_initialization.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import numpy
from numpy.testing import assert_, assert_equal, assert_allclose, assert_raises
import six
import theano
from numpy.testing import assert_equal, assert_allclose, assert_raises

from blocks.initialization import Constant, IsotropicGaussian, Uniform
from blocks.initialization import Constant, IsotropicGaussian, Sparse, Uniform


def test_constant():
def check_constant(const, shape, ground_truth):
# rng unused, so pass None.
init = Constant(const).generate(None, ground_truth.shape)
assert_(ground_truth.dtype == theano.config.floatX)
assert_(ground_truth.shape == init.shape)
assert ground_truth.dtype == theano.config.floatX
assert ground_truth.shape == init.shape
assert_equal(ground_truth, init)

# Test scalar init.
Expand All @@ -24,26 +25,26 @@ def check_constant(const, shape, ground_truth):


def test_gaussian():
rng = numpy.random.RandomState([2014, 1, 20])
rng = numpy.random.RandomState(1)

def check_gaussian(rng, mean, std, shape):
weights = IsotropicGaussian(mean, std).generate(rng, shape)
assert_(weights.shape == shape)
assert_(weights.dtype == theano.config.floatX)
assert weights.shape == shape
assert weights.dtype == theano.config.floatX
assert_allclose(weights.mean(), mean, atol=1e-2)
assert_allclose(weights.std(), std, atol=1e-2)
yield check_gaussian, rng, 0, 1, (500, 600)
yield check_gaussian, rng, 5, 3, (600, 500)


def test_uniform():
rng = numpy.random.RandomState([2014, 1, 20])
rng = numpy.random.RandomState(1)

def check_uniform(rng, mean, width, std, shape):
weights = Uniform(mean=mean, width=width,
std=std).generate(rng, shape)
assert_(weights.shape == shape)
assert_(weights.dtype == theano.config.floatX)
assert weights.shape == shape
assert weights.dtype == theano.config.floatX
assert_allclose(weights.mean(), mean, atol=1e-2)
if width is not None:
std_ = width / numpy.sqrt(12)
Expand All @@ -55,3 +56,30 @@ def check_uniform(rng, mean, width, std, shape):
yield check_uniform, rng, 5, None, 0.004, (700, 300)

assert_raises(ValueError, Uniform, 0, 1, 1)


def test_sparse():
rng = numpy.random.RandomState(1)

def check_sparse(rng, num_init, weights_init, sparse_init, shape, total):
weights = Sparse(num_init=num_init, weights_init=weights_init,
sparse_init=sparse_init).generate(rng, shape)
assert weights.shape == shape
assert weights.dtype == theano.config.floatX
if sparse_init is None:
if isinstance(num_init, six.integer_types):
assert (numpy.count_nonzero(weights) <=
weights.size - num_init * weights.shape[0])
else:
assert (numpy.count_nonzero(weights) <=
weights.size - num_init * weights.shape[1])
if total is not None:
assert numpy.sum(weights) == total

yield check_sparse, rng, 5, Constant(1.), None, (10, 10), None
yield check_sparse, rng, 0.5, Constant(1.), None, (10, 10), None
yield check_sparse, rng, 0.5, Constant(1.), Constant(1.), (10, 10), None
yield check_sparse, rng, 3, Constant(1.), None, (10, 10), 30
yield check_sparse, rng, 3, Constant(0.), Constant(1.), (10, 10), 70
yield check_sparse, rng, 0.3, Constant(1.), None, (10, 10), 30
yield check_sparse, rng, 0.3, Constant(0.), Constant(1.), (10, 10), 70

0 comments on commit 618f263

Please sign in to comment.