Commit c152a8f
Merge pull request #120 from rizar/dataset_evaluator

Validation is up and running.

rizar committed Jan 20, 2015
2 parents 7bbd347 + 7bdb346
Showing 10 changed files with 279 additions and 54 deletions.
15 changes: 13 additions & 2 deletions blocks/bricks/cost.py
@@ -1,8 +1,12 @@
import theano
from theano import tensor

from blocks.bricks import application, Brick


floatX = theano.config.floatX


class Cost(Brick):
pass

@@ -13,7 +17,7 @@ class CostMatrix(Cost):
Assumes that the data has format (batch, features).
"""
@application
@application(outputs=["cost"])
def apply(self, y, y_hat):
return self.cost_matrix.application_method(
self, y, y_hat).sum(axis=1).mean()
@@ -41,7 +45,14 @@ def cost_matrix(self, y, y_hat):


class CategoricalCrossEntropy(Cost):
@application
@application(outputs=["cost"])
def apply(self, y, y_hat):
cost = tensor.nnet.categorical_crossentropy(y_hat, y).mean()
return cost


class MisclassificationRate(Cost):
@application(outputs=["error_rate"])
def apply(self, y, y_hat):
return (tensor.sum(tensor.neq(y, y_hat.argmax(axis=1)))
/ y.shape[0].astype(floatX))
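
As a quick sanity check on the new error-rate expression (not part of the commit), the same computation can be mirrored in plain numpy: count the rows where the argmax of the predictions disagrees with the target, then divide by the batch size.

import numpy

y = numpy.array([2, 0, 1])                      # true labels
y_hat = numpy.array([[0.1, 0.2, 0.7],           # argmax 2 -> correct
                     [0.3, 0.5, 0.2],           # argmax 1 -> wrong
                     [0.1, 0.8, 0.1]])          # argmax 1 -> correct
error_rate = (y != y_hat.argmax(axis=1)).sum() / float(len(y))
assert abs(error_rate - 1.0 / 3) < 1e-9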
6 changes: 5 additions & 1 deletion blocks/datasets/mnist.py
@@ -55,7 +55,11 @@ def __init__(self, which_set, start=None, stop=None, binary=False,

if which_set not in ('train', 'test'):
raise ValueError("MNIST only has a train and test set")
self.num_examples = (stop if stop else 60000) - (start if start else 0)
if not self.stop:
self.stop = 60000 if which_set == "train" else 10000
if not self.start:
self.start = 0
self.num_examples = self.stop - self.start
self.default_scheme = SequentialScheme(self.num_examples, 1)
super(MNIST, self).__init__(**kwargs)

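With these defaults, stop now falls back to the size of the selected split (60000 for "train", 10000 for "test") instead of always 60000. A minimal sketch of the resulting behaviour, assuming the start/stop keyword arguments shown in the hunk above:

from blocks.datasets.mnist import MNIST

full_test = MNIST("test")                 # start=0, stop=10000
validation = MNIST("train", start=50000)  # last 10000 training examples
assert validation.num_examples == 10000
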
2 changes: 1 addition & 1 deletion blocks/extensions/__init__.py
@@ -226,7 +226,7 @@ def __init__(self, **kwargs):
super(Printing, self).__init__(**kwargs)

def _print_attributes(self, attribute_tuples):
for attr, value in attribute_tuples:
for attr, value in sorted(attribute_tuples, key=lambda t: t[0]):
if not attr.startswith("_"):
print("\t", "{}:".format(attr), value)

37 changes: 37 additions & 0 deletions blocks/extensions/monitoring.py
@@ -0,0 +1,37 @@
"""Extensions for monitoring the training process."""
from blocks.extensions import SimpleExtension
from blocks.monitoring.evaluators import DatasetEvaluator


class DataStreamMonitoring(SimpleExtension):
"""Monitors values of Theano expressions on a data stream.
Parameters
----------
expressions : list of Theano variables
The expressions to monitor. The variable names are used as
expression names.
data_stream : instance of :class:`DataStream`
The data stream to monitor on. A data epoch is requsted
each time monitoring is done.
prefix : str
A prefix to add to expression names when adding records to the
log. An underscore will be used to separate the prefix.
"""
PREFIX_SEPARATOR = '_'

def __init__(self, expressions, data_stream, prefix, **kwargs):
kwargs.setdefault("after_every_epoch", True)
kwargs.setdefault("before_first_epoch", True)
super(DataStreamMonitoring, self).__init__(**kwargs)
self._evaluator = DatasetEvaluator(expressions)
self.data_stream = data_stream
self.prefix = prefix

def do(self, callback_name, *args):
"""Write the values of monitored expressions to the log."""
value_dict = self._evaluator.evaluate(self.data_stream)
for name, value in value_dict.items():
prefixed_name = self.prefix + self.PREFIX_SEPARATOR + name
setattr(self.main_loop.log.current_row, prefixed_name, value)
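
For orientation, a minimal usage sketch of the new extension (the variables cost, error_rate and test_stream are illustrative placeholders): records show up in the log with the prefix attached.

monitor = DataStreamMonitoring(
    expressions=[cost, error_rate],
    data_stream=test_stream,
    prefix="test")
# After each epoch the log's current row gains entries such as
# test_cost and test_error_rate.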
142 changes: 142 additions & 0 deletions blocks/monitoring/evaluators.py
@@ -0,0 +1,142 @@
from collections import OrderedDict
import logging

import theano

from blocks.utils import dict_subset
from blocks.monitoring.aggregation import _DataIndependent, Mean
from blocks.graph import ComputationGraph

logger = logging.getLogger()


class DatasetEvaluator(object):
"""A DatasetEvaluator evaluates many Theano expressions on a dataset.
The DatasetEvaluator provides a do-it-all method, :meth:`evaluate`,
which computes values of ``expressions`` on a dataset.
Alternatively, methods :meth:`_initialize_computation`,
:meth:`_process_batch`, :meth:`_readout_expressions` can be used with a
custom loop over data.
The values computed on subsets of the given dataset are aggregated
using the :class:`AggregationScheme`s provided in the
`aggregation_scheme` tags. If no tag is given, the value is **averaged
over minibatches**. However, care is taken to ensure that variables
which do not depend on data are not unnecessarily recomputed.
Parameters
----------
expressions : dict or list
If a list of Theano variables. The variable names are used as
expression names. All the variables names must be different.
Each variable can be tagged with an :class:`AggregationScheme` that
specifies how the value can be computed for a data set by
aggregating minibatches.
"""
def __init__(self, expressions):
self.expressions = OrderedDict(
[(var.name, var) for var in expressions])
if len(self.expressions) < len(expressions):
raise ValueError(
"Expression variables should have different names")

self.inputs = ComputationGraph(
list(self.expressions.values())).inputs
self._compile()

def _compile(self):
"""Compiles Theano functions.

.. todo::

The current compilation method does not account for updates
attached to `ComputationGraph` elements. Compilation should
be delegated to `ComputationGraph` to handle this.

"""
self._initialize_updates = []
self._accumulate_updates = []
self._readout = OrderedDict()

for k, v in self.expressions.items():
logger.debug('Expression to evaluate: %s', v.name)
if not hasattr(v.tag, 'aggregation_scheme'):
if ComputationGraph([v]).inputs == []:
logger.debug('Using _DataIndependent aggregation scheme'
' for %s since it does not depend on'
' the data', k)
v.tag.aggregation_scheme = _DataIndependent(variable=v)
else:
logger.debug('Using the default (average over minibatches)'
' aggregation scheme for %s', k)
v.tag.aggregation_scheme = Mean(v, 1.0)

aggregator = v.tag.aggregation_scheme.get_aggregator()
self._initialize_updates.extend(aggregator.initialization_updates)
self._accumulate_updates.extend(aggregator.accumulation_updates)
self._readout[k] = aggregator.readout_expression

if self._initialize_updates:
self._initialize_fun = theano.function(
[], [], updates=self._initialize_updates)
else:
self._initialize_fun = None

self._initialized = False
self._input_names = [v.name for v in self.inputs]

if self._accumulate_updates:
self._accumulate_fun = theano.function(
self.inputs, [], updates=self._accumulate_updates)
else:
self._accumulate_fun = None

self._readout_fun = theano.function([], list(self._readout.values()))

def _initialize_computation(self):
"""Initialize the aggregators to process a dataset."""
self._initialized = True
if self._initialize_fun is not None:
self._initialize_fun()

def _process_batch(self, batch):
batch = dict_subset(batch, self._input_names)
if self._accumulate_fun is not None:
self._accumulate_fun(**batch)

def _readout_expressions(self):
if not self._initialized:
raise Exception("To read out you must first initialize, then "
"process batches!")
self._initialized = False
ret_vals = self._readout_fun()
return dict(zip(self.expressions.keys(), ret_vals))

def evaluate(self, data_stream):
"""Compute the expressions over a data stream.
Parameters
----------
data_stream : instance of :class:`DataStream`
The data stream. Only the first epoch of data is used.
Returns
-------
A mapping from expression names to the values computed on the
provided dataset.
"""
self._initialize_computation()

if self._accumulate_fun is not None:
for batch in data_stream.get_epoch_iterator(as_dict=True):
self._process_batch(batch)
else:
logger.debug('Only constant monitors are used, will not '
'iterate over the data!')

return self._readout_expressions()
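
A standalone sketch of the intended use (names are illustrative): the evaluator compiles its aggregation functions once in the constructor and can then be reused for repeated passes over any compatible data stream.

# cost and error_rate are named Theano variables, test_stream a DataStream.
evaluator = DatasetEvaluator([cost, error_rate])
values = evaluator.evaluate(test_stream)  # one epoch over the stream
print(values)  # e.g. {'cost': 0.31, 'error_rate': 0.087}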
31 changes: 7 additions & 24 deletions blocks/utils.py
@@ -223,6 +223,13 @@ def check_theano_variable(variable, n_dim, dtype_prefix):
dtype_prefix, variable.dtype))


def named_copy(variable, new_name):
"""Clones a variable and set a new name to the clone."""
result = variable.copy()
result.name = new_name
return result
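
For example (assuming the usual from theano import tensor):

x = tensor.vector('x')
cost = named_copy((x ** 2).sum(), 'squared_norm')
assert cost.name == 'squared_norm'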


def is_graph_input(variable):
"""Check if variable is a user-provided graph input.
@@ -250,30 +257,6 @@ def is_shared_variable(variable):
return isinstance(variable, SharedVariable)


def graph_inputs(variables, blockers=None):
"""Compute inputs needed to compute values in variables.
This function is similar to :meth:`theano.gof.graph.inputs`. However,
it doesn't treat shared and constant values as inputs.
Parameters
----------
variables : list of theano variables
The outputs whose inputs are sought for.
blockers : list of theano variables
See :meth:`theano.gof.graph.inputs` for documentation.
Returns
-------
list
Theano variables which are non-constant and non-shared inputs to
the computational graph.
"""
inps = theano.gof.graph.inputs(variables, blockers=blockers)
return [i for i in inps if is_graph_input(i)]


def dict_subset(dict_, keys, pop=False, must_have=True):
"""Return a subset of a dictionary corresponding to a set of keys.
42 changes: 33 additions & 9 deletions examples/mnist.py
@@ -1,36 +1,60 @@
#!/usr/bin/env python

from argparse import ArgumentParser

from theano import tensor

from blocks.algorithms import GradientDescent, SteepestDescent
from blocks.bricks import MLP, Tanh, Identity
from blocks.bricks.cost import CategoricalCrossEntropy
from blocks.bricks import MLP, Tanh, Softmax
from blocks.bricks.cost import CategoricalCrossEntropy, MisclassificationRate
from blocks.initialization import IsotropicGaussian, Constant
from blocks.datasets import DataStream
from blocks.datasets.mnist import MNIST
from blocks.datasets.schemes import SequentialScheme
from blocks.extensions import FinishAfter, Printing
from blocks.extensions.saveload import SerializeMainLoop
from blocks.extensions.monitoring import DataStreamMonitoring
from blocks.main_loop import MainLoop


def main(save_to="mnist.pkl", num_epochs=2):
mlp = MLP([Tanh(), Identity()], [784, 100, 10])
def main(save_to, num_epochs):
mlp = MLP([Tanh(), Softmax()], [784, 100, 10],
weights_init=IsotropicGaussian(0, 0.01),
biases_init=Constant(0))
mlp.initialize()
x = tensor.matrix('features')
y = tensor.lmatrix('targets')
cost = CategoricalCrossEntropy().apply(y.flatten() - 1, mlp.apply(x))
probs = mlp.apply(x)
cost = CategoricalCrossEntropy().apply(y.flatten(), probs)
error_rate = MisclassificationRate().apply(y.flatten(), probs)

mnist = MNIST("train")
mnist_train = MNIST("train")
mnist_test = MNIST("test")

main_loop = MainLoop(
mlp,
DataStream(mnist,
iteration_scheme=SequentialScheme(mnist.num_examples, 50)),
DataStream(mnist_train,
iteration_scheme=SequentialScheme(
mnist_train.num_examples, 50)),
GradientDescent(cost=cost,
step_rule=SteepestDescent(learning_rate=0.1)),
extensions=[FinishAfter(after_n_epochs=num_epochs),
DataStreamMonitoring(
[cost, error_rate],
DataStream(mnist_test,
iteration_scheme=SequentialScheme(
mnist_test.num_examples, 500)),
prefix="test"),
SerializeMainLoop(save_to),
Printing()])
main_loop.run()

if __name__ == "__main__":
main()
parser = ArgumentParser(description="An example of training an MLP on"
" the MNIST dataset.")
parser.add_argument("--num-epochs", type=int, default=2,
help="Number of training epochs to do.")
parser.add_argument("save_to", default="mnist.pkl", nargs="?",
help="Destination to save the state of the training process.")
args = parser.parse_args()
main(args.save_to, args.num_epochs)
12 changes: 6 additions & 6 deletions tests/monitoring/test_aggregation.py
@@ -17,11 +17,11 @@ def _allocate(self):
@application(inputs=['input_'], outputs=['output'])
def apply(self, input_, application_call):
V = self.params[0]
application_call.add_monitor((V ** 2).sum(), name='V_monitor')
mean_input = mean(input_.mean(axis=1).sum(), input_.shape[0])
application_call.add_monitor(mean_input, name='mean_mean_input')
mean_row_mean = mean(input_.mean(axis=1).sum(), input_.shape[0])
application_call.add_monitor((V ** 2).sum(), name='V_squared')
application_call.add_monitor(mean_row_mean, name='mean_row_mean')
application_call.add_monitor(input_.mean(),
name='per_batch_mean_input')
name='mean_batch_element')
return input_ + V


@@ -50,5 +50,5 @@ def test_param_monitor():
initialize()
accumulate = theano.function([X], updates=aggregator.accumulation_updates)
accumulate(numpy.arange(4, dtype=theano.config.floatX).reshape(2, 2))
accumulate(numpy.arange(4, 8, dtype=theano.config.floatX).reshape(2, 2))
assert_allclose(aggregator.readout_expression.eval(), 3.5)
accumulate(numpy.arange(4, 10, dtype=theano.config.floatX).reshape(3, 2))
assert_allclose(aggregator.readout_expression.eval(), 4.5)
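
The updated expectation follows directly from the running mean over all five rows: the first batch contributes row means (0.5, 2.5) and the second (4.5, 6.5, 8.5), so the aggregate is (0.5 + 2.5 + 4.5 + 6.5 + 8.5) / 5 = 4.5. A short numpy check of the same arithmetic:

import numpy

rows = numpy.arange(10, dtype='float64').reshape(5, 2)
assert rows.mean(axis=1).mean() == 4.5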
