Skip to content

Commit

Permalink
Merge pull request #1096 from rizar/fix_unique_scans
Browse files Browse the repository at this point in the history
Compare scan Ops by their ids
  • Loading branch information
dmitriy-serdyuk committed May 30, 2016
2 parents 4458a6e + 78d1235 commit 449a80d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
3 changes: 2 additions & 1 deletion blocks/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ def _get_variables(self):
inputs = graph.inputs(self.outputs)
sorted_apply_nodes = graph.io_toposort(inputs, usual_outputs)
self.scans = list(unique([node.op for node in sorted_apply_nodes
if isinstance(node.op, Scan)]))
if isinstance(node.op, Scan)],
key=lambda op: id(op)))
self._scan_graphs = [ComputationGraph(scan.outputs)
for scan in self.scans]

Expand Down
13 changes: 12 additions & 1 deletion tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from theano import tensor
from theano.sandbox.rng_mrg import MRG_RandomStreams

from blocks.bricks import MLP, Identity, Logistic
from blocks.bricks import MLP, Identity, Logistic, Tanh
from blocks.bricks.cost import SquaredError
from blocks.bricks.recurrent import SimpleRecurrent
from blocks.filter import VariableFilter
from blocks.graph import (apply_dropout, apply_noise, collect_parameters,
ComputationGraph)
Expand Down Expand Up @@ -203,3 +204,13 @@ def test_collect():
assert numpy.all(W1.eval() == 1.)
assert W2.eval().shape == (100, 784)
assert numpy.all(W2.eval() == 2.)


def test_similar_scans():
x = tensor.tensor3('x')
r1 = SimpleRecurrent(activation=Tanh(), dim=10)
y1 = r1.apply(x)
r2 = SimpleRecurrent(activation=Tanh(), dim=10)
y2 = r2.apply(x)
cg = ComputationGraph([y1, y2])
assert len(cg.scans) == 2

0 comments on commit 449a80d

Please sign in to comment.