Skip to content

Commit

Permalink
Merge pull request #959 from rizar/lookup_table_update
Browse files Browse the repository at this point in the history
Lookup table changes
  • Loading branch information
rizar committed Jan 28, 2016
2 parents 1bce70e + abecbff commit fa3a290
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
10 changes: 9 additions & 1 deletion blocks/bricks/lookup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Introduces Lookup brick."""
from blocks.bricks import Initializable
from blocks.bricks.base import application, lazy
from blocks.roles import WEIGHT, add_role
from blocks.utils import check_theano_variable, shared_floatx_nans


Expand Down Expand Up @@ -35,11 +36,12 @@ def W(self):
def _allocate(self):
self.parameters.append(shared_floatx_nans((self.length, self.dim),
name='W'))
add_role(self.parameters[-1], WEIGHT)

def _initialize(self):
self.weights_init.initialize(self.W, self.rng)

@application
@application(inputs=['indices'], outputs=['output'])
def apply(self, indices):
"""Perform lookup.
Expand All @@ -61,3 +63,9 @@ def apply(self, indices):
output_shape = [indices.shape[i]
for i in range(indices.ndim)] + [self.dim]
return self.W[indices.flatten()].reshape(output_shape)

def get_dim(self, name):
if name == 'output':
return self.dim
if name == 'indices':
return 0
4 changes: 4 additions & 0 deletions tests/bricks/test_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ def test_lookup_table():
desired = numpy.array([[[3, 4, 5], [6, 7, 8]], [[0, 1, 2], [9, 10, 11]]],
dtype=theano.config.floatX)
assert_equal(f(x_val)[0], desired)

# Test get_dim
assert_equal(lt.get_dim(lt.apply.inputs[0]), 0)
assert_equal(lt.get_dim(lt.apply.outputs[0]), lt.dim)

0 comments on commit fa3a290

Please sign in to comment.