Skip to content

Commit

Permalink
Merge pull request #89 from bartvm/sparse_init
Browse files Browse the repository at this point in the history
Sparse init
  • Loading branch information
bartvm committed Jan 7, 2015
2 parents 7cfdffc + 618f263 commit 884830a
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 10 deletions.
39 changes: 39 additions & 0 deletions blocks/initialization.py
@@ -1,8 +1,12 @@
"""Objects for encapsulating parameter initialization strategies."""
from abc import ABCMeta, abstractmethod

import numpy
import six
import theano

from blocks.utils import update_instance


class NdarrayInitialization(object):
"""Base class specifying the interface for ndarray initialization."""
Expand Down Expand Up @@ -162,3 +166,38 @@ def generate(self, rng, shape):
# Correct that NumPy doesn't force diagonal of R to be non-negative
Q = Q * numpy.sign(numpy.diag(R))
return Q


class Sparse(NdarrayInitialization):
"""Initialize only a fraction of the weights, row-wise.
Parameters
----------
num_init : int or float
If int, this is the number of weights to initialize per row. If
float, it's the fraction of the weights per row to initialize.
weights_init : :class:`NdarrayInitialization` instance
The initialization scheme to initialize the weights with.
sparse_init : :class:`NdarrayInitialization` instance, optional
What to set the non-initialized weights to (0. by default)
"""
def __init__(self, num_init, weights_init, sparse_init=None):
if sparse_init is None:
sparse_init = Constant(0.)
update_instance(self, locals())

def generate(self, rng, shape):
weights = self.sparse_init.generate(rng, shape)
if isinstance(self.num_init, six.integer_types):
assert self.num_init > 0
num_init = self.num_init
else:
assert 1 >= self.num_init > 0
num_init = int(self.num_init * shape[1])
values = self.weights_init.generate(rng, (shape[0], num_init))
for i in range(shape[0]):
random_indices = numpy.random.choice(shape[1], num_init,
replace=False)
weights[i, random_indices] = values[i]
return weights
48 changes: 38 additions & 10 deletions tests/test_initialization.py
@@ -1,16 +1,17 @@
import numpy
from numpy.testing import assert_, assert_equal, assert_allclose, assert_raises
import six
import theano
from numpy.testing import assert_equal, assert_allclose, assert_raises

from blocks.initialization import Constant, IsotropicGaussian, Uniform
from blocks.initialization import Constant, IsotropicGaussian, Sparse, Uniform


def test_constant():
def check_constant(const, shape, ground_truth):
# rng unused, so pass None.
init = Constant(const).generate(None, ground_truth.shape)
assert_(ground_truth.dtype == theano.config.floatX)
assert_(ground_truth.shape == init.shape)
assert ground_truth.dtype == theano.config.floatX
assert ground_truth.shape == init.shape
assert_equal(ground_truth, init)

# Test scalar init.
Expand All @@ -24,26 +25,26 @@ def check_constant(const, shape, ground_truth):


def test_gaussian():
rng = numpy.random.RandomState([2014, 1, 20])
rng = numpy.random.RandomState(1)

def check_gaussian(rng, mean, std, shape):
weights = IsotropicGaussian(mean, std).generate(rng, shape)
assert_(weights.shape == shape)
assert_(weights.dtype == theano.config.floatX)
assert weights.shape == shape
assert weights.dtype == theano.config.floatX
assert_allclose(weights.mean(), mean, atol=1e-2)
assert_allclose(weights.std(), std, atol=1e-2)
yield check_gaussian, rng, 0, 1, (500, 600)
yield check_gaussian, rng, 5, 3, (600, 500)


def test_uniform():
rng = numpy.random.RandomState([2014, 1, 20])
rng = numpy.random.RandomState(1)

def check_uniform(rng, mean, width, std, shape):
weights = Uniform(mean=mean, width=width,
std=std).generate(rng, shape)
assert_(weights.shape == shape)
assert_(weights.dtype == theano.config.floatX)
assert weights.shape == shape
assert weights.dtype == theano.config.floatX
assert_allclose(weights.mean(), mean, atol=1e-2)
if width is not None:
std_ = width / numpy.sqrt(12)
Expand All @@ -55,3 +56,30 @@ def check_uniform(rng, mean, width, std, shape):
yield check_uniform, rng, 5, None, 0.004, (700, 300)

assert_raises(ValueError, Uniform, 0, 1, 1)


def test_sparse():
rng = numpy.random.RandomState(1)

def check_sparse(rng, num_init, weights_init, sparse_init, shape, total):
weights = Sparse(num_init=num_init, weights_init=weights_init,
sparse_init=sparse_init).generate(rng, shape)
assert weights.shape == shape
assert weights.dtype == theano.config.floatX
if sparse_init is None:
if isinstance(num_init, six.integer_types):
assert (numpy.count_nonzero(weights) <=
weights.size - num_init * weights.shape[0])
else:
assert (numpy.count_nonzero(weights) <=
weights.size - num_init * weights.shape[1])
if total is not None:
assert numpy.sum(weights) == total

yield check_sparse, rng, 5, Constant(1.), None, (10, 10), None
yield check_sparse, rng, 0.5, Constant(1.), None, (10, 10), None
yield check_sparse, rng, 0.5, Constant(1.), Constant(1.), (10, 10), None
yield check_sparse, rng, 3, Constant(1.), None, (10, 10), 30
yield check_sparse, rng, 3, Constant(0.), Constant(1.), (10, 10), 70
yield check_sparse, rng, 0.3, Constant(1.), None, (10, 10), 30
yield check_sparse, rng, 0.3, Constant(0.), Constant(1.), (10, 10), 70

0 comments on commit 884830a

Please sign in to comment.