Skip to content

Commit

Permalink
Merge pull request #1111 from rizar/repr_for_init_scheme
Browse files Browse the repository at this point in the history
__repr__ for initialization schemes
  • Loading branch information
rizar committed Jun 7, 2016
2 parents 3a7dde7 + f5d9330 commit d7a3060
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 11 deletions.
37 changes: 27 additions & 10 deletions blocks/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import theano
from six import add_metaclass

from blocks.utils import repr_attrs


@add_metaclass(ABCMeta)
class NdarrayInitialization(object):
Expand Down Expand Up @@ -62,13 +64,16 @@ class Constant(NdarrayInitialization):
"""
def __init__(self, constant):
self._constant = numpy.asarray(constant)
self.constant = numpy.asarray(constant)

def generate(self, rng, shape):
dest = numpy.empty(shape, dtype=theano.config.floatX)
dest[...] = self._constant
dest[...] = self.constant
return dest

def __repr__(self):
return repr_attrs(self, 'constant')


class IsotropicGaussian(NdarrayInitialization):
"""Initialize parameters from an isotropic Gaussian distribution.
Expand All @@ -87,13 +92,16 @@ class IsotropicGaussian(NdarrayInitialization):
"""
def __init__(self, std=1, mean=0):
self._mean = mean
self._std = std
self.mean = mean
self.std = std

def generate(self, rng, shape):
m = rng.normal(self._mean, self._std, size=shape)
m = rng.normal(self.mean, self.std, size=shape)
return m.astype(theano.config.floatX)

def __repr__(self):
return repr_attrs(self, 'mean', 'std')


class Uniform(NdarrayInitialization):
"""Initialize parameters from a uniform distribution.
Expand All @@ -120,16 +128,19 @@ def __init__(self, mean=0., width=None, std=None):
"but not both")
if std is not None:
# Variance of a uniform is 1/12 * width^2
self._width = numpy.sqrt(12) * std
self.width = numpy.sqrt(12) * std
else:
self._width = width
self._mean = mean
self.width = width
self.mean = mean

def generate(self, rng, shape):
w = self._width / 2
m = rng.uniform(self._mean - w, self._mean + w, size=shape)
w = self.width / 2
m = rng.uniform(self.mean - w, self.mean + w, size=shape)
return m.astype(theano.config.floatX)

def __repr__(self):
return repr_attrs(self, 'mean', 'width')


class Identity(NdarrayInitialization):
"""Initialize to the identity matrix.
Expand All @@ -152,6 +163,9 @@ def generate(self, rng, shape):
rows, cols = shape
return self.mult * numpy.eye(rows, cols, dtype=theano.config.floatX)

def __repr__(self):
return repr_attrs(self, 'mult')


class Orthogonal(NdarrayInitialization):
"""Initialize a random orthogonal matrix.
Expand Down Expand Up @@ -199,6 +213,9 @@ def generate(self, rng, shape):
n_min = min(shape[0], shape[1])
return numpy.dot(Q1[:, :n_min], Q2[:n_min, :]) * self.scale

def __repr__(self):
return repr_attrs(self, 'scale')


class Sparse(NdarrayInitialization):
"""Initialize only a fraction of the weights, row-wise.
Expand Down
14 changes: 13 additions & 1 deletion tests/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numpy.testing import assert_equal, assert_allclose, assert_raises

from blocks.initialization import Constant, IsotropicGaussian, Sparse
from blocks.initialization import Uniform, Orthogonal
from blocks.initialization import Uniform, Orthogonal, Identity


def test_constant():
Expand All @@ -25,6 +25,12 @@ def check_constant(const, shape, ground_truth):
yield (check_constant, numpy.array([[1], [2], [3]]), (3, 2),
numpy.array([[1, 1], [2, 2], [3, 3]], dtype=theano.config.floatX))

assert str(Constant(1.0)).endswith(' constant=1.0>')


def test_identity():
assert str(Identity(2.0)).endswith(' mult=2.0>')


def test_gaussian():
rng = numpy.random.RandomState(1)
Expand All @@ -38,6 +44,8 @@ def check_gaussian(rng, mean, std, shape):
yield check_gaussian, rng, 0, 1, (500, 600)
yield check_gaussian, rng, 5, 3, (600, 500)

assert str(IsotropicGaussian(1.0, 2.0)).endswith(' mean=2.0, std=1.0>')


def test_uniform():
rng = numpy.random.RandomState(1)
Expand All @@ -59,6 +67,8 @@ def check_uniform(rng, mean, width, std, shape):

assert_raises(ValueError, Uniform, 0, 1, 1)

assert str(Uniform(1.0, 2.0)).endswith(' mean=1.0, width=2.0>')


def test_sparse():
rng = numpy.random.RandomState(1)
Expand Down Expand Up @@ -124,3 +134,5 @@ def check_orthogonal(rng, shape, scale=1.0):
yield check_orthogonal, rng, (50, 50), .5
yield check_orthogonal, rng, (50, 51), .5
yield check_orthogonal, rng, (51, 50), .5

assert str(Orthogonal(3.0)).endswith(' scale=3.0>')

0 comments on commit d7a3060

Please sign in to comment.