Skip to content

Commit

Permalink
Introduce sparse initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
bartvm committed Jan 7, 2015
1 parent 71ff646 commit a21e0ff
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions blocks/initialization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Objects for encapsulating parameter initialization strategies."""
from abc import ABCMeta, abstractmethod

import numpy
import six
import theano

from blocks.utils import update_instance


class NdarrayInitialization(object):
"""Base class specifying the interface for ndarray initialization."""
Expand Down Expand Up @@ -162,3 +166,38 @@ def generate(self, rng, shape):
# Correct that NumPy doesn't force diagonal of R to be non-negative
Q = Q * numpy.sign(numpy.diag(R))
return Q


class Sparse(NdarrayInitialization):
"""Initialize only a fraction of the weights, row-wise.
Parameters
----------
num_init : int or float
If int, this is the number of weights to initialize per row. If
float, it's the fraction of the weights per row to initialize.
weights_init : :class:`NdarrayInitialization` instance
The initialization scheme to initialize the weights with.
sparse_init : :class:`NdarrayInitialization` instance, optional
What to set the non-initialized weights to (0. by default)
"""
def __init__(self, num_init, weights_init, sparse_init=None):
if sparse_init is None:
sparse_init = Constant(0.)
update_instance(self, locals())

def generate(self, rng, shape):
weights = self.sparse_init.generate(rng, shape)
if isinstance(self.num_init, six.integer_types):
assert self.num_init > 0
num_init = self.num_init
else:
assert 1 >= self.num_init > 0
num_init = int(self.num_init * shape[1])
values = self.weights_init.generate(rng, (shape[0], num_init))
for i in range(shape[0]):
random_indices = numpy.random.choice(shape[1], num_init,
replace=False)
weights[i, random_indices] = values[i]
return weights

0 comments on commit a21e0ff

Please sign in to comment.