Merge pull request #120 from rizar/dataset_evaluator
Validation is up and running.
Showing 10 changed files with 279 additions and 54 deletions.
blocks/extensions/monitoring.py (new file)
@@ -0,0 +1,37 @@
"""Extensions for monitoring the training process."""
from blocks.extensions import SimpleExtension
from blocks.monitoring.evaluators import DatasetEvaluator


class DataStreamMonitoring(SimpleExtension):
    """Monitors values of Theano expressions on a data stream.

    Parameters
    ----------
    expressions : list of Theano variables
        The expressions to monitor. The variable names are used as
        expression names.
    data_stream : instance of :class:`DataStream`
        The data stream to monitor on. A data epoch is requested
        each time monitoring is done.
    prefix : str
        A prefix to add to expression names when adding records to the
        log. An underscore will be used to separate the prefix.

    """
    PREFIX_SEPARATOR = '_'

    def __init__(self, expressions, data_stream, prefix, **kwargs):
        kwargs.setdefault("after_every_epoch", True)
        kwargs.setdefault("before_first_epoch", True)
        super(DataStreamMonitoring, self).__init__(**kwargs)
        self._evaluator = DatasetEvaluator(expressions)
        self.data_stream = data_stream
        self.prefix = prefix

    def do(self, callback_name, *args):
        """Write the values of monitored expressions to the log."""
        value_dict = self._evaluator.evaluate(self.data_stream)
        for name, value in value_dict.items():
            prefixed_name = self.prefix + self.PREFIX_SEPARATOR + name
            setattr(self.main_loop.log.current_row, prefixed_name, value)
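
For orientation, a minimal usage sketch (hypothetical: `cost` is assumed to be a named Theano variable and `test_stream` a `DataStream`; with the prefix "test", a variable named "cost" is logged as "test_cost"):

    # Hypothetical sketch; cost and test_stream are placeholders.
    monitor = DataStreamMonitoring(
        expressions=[cost],       # each variable's .name becomes its record name
        data_stream=test_stream,  # one epoch is requested per monitoring pass
        prefix="test")            # records appear as "test_cost", etc.
    # By default the extension runs before the first epoch and after every epoch.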
blocks/monitoring/evaluators.py (new file)
@@ -0,0 +1,142 @@
from collections import OrderedDict
import logging

import theano

from blocks.utils import dict_subset
from blocks.monitoring.aggregation import _DataIndependent, Mean
from blocks.graph import ComputationGraph

logger = logging.getLogger()


class DatasetEvaluator(object):
    """A DatasetEvaluator evaluates many Theano expressions on a dataset.

    The DatasetEvaluator provides a do-it-all method, :meth:`evaluate`,
    which computes values of ``expressions`` on a dataset.

    Alternatively, the methods :meth:`_initialize_computation`,
    :meth:`_process_batch` and :meth:`_readout_expressions` can be used
    with a custom loop over data.

    The values computed on subsets of the given dataset are aggregated
    using the :class:`AggregationScheme`s provided in the
    `aggregation_scheme` tags. If no tag is given, the value is **averaged
    over minibatches**. However, care is taken to ensure that variables
    which do not depend on data are not unnecessarily recomputed.

    Parameters
    ----------
    expressions : list of Theano variables
        The expressions to evaluate. The variable names are used as
        expression names, so all the variable names must be different.
        Each variable can be tagged with an :class:`AggregationScheme`
        that specifies how the value can be computed for a data set by
        aggregating minibatches.

    """
    def __init__(self, expressions):
        self.expressions = OrderedDict(
            [(var.name, var) for var in expressions])
        if len(self.expressions) < len(expressions):
            raise ValueError(
                "Expression variables should have different names")

        self.inputs = ComputationGraph(
            list(self.expressions.values())).inputs
        self._compile()

    def _compile(self):
        """Compiles Theano functions.

        .. todo::

            The current compilation method does not account for updates
            attached to `ComputationGraph` elements. Compiling should
            be out-sourced to `ComputationGraph` to deal with it.

        """
        self._initialize_updates = []
        self._accumulate_updates = []
        self._readout = OrderedDict()

        for k, v in self.expressions.items():
            logger.debug('Expression to evaluate: %s', v.name)
            if not hasattr(v.tag, 'aggregation_scheme'):
                if ComputationGraph([v]).inputs == []:
                    logger.debug('Using the _DataIndependent aggregation'
                                 ' scheme for %s since it does not depend'
                                 ' on the data', k)
                    v.tag.aggregation_scheme = _DataIndependent(variable=v)
                else:
                    logger.debug('Using the default (average over'
                                 ' minibatches) aggregation scheme for %s', k)
                    v.tag.aggregation_scheme = Mean(v, 1.0)

            aggregator = v.tag.aggregation_scheme.get_aggregator()
            self._initialize_updates.extend(aggregator.initialization_updates)
            self._accumulate_updates.extend(aggregator.accumulation_updates)
            self._readout[k] = aggregator.readout_expression

        if self._initialize_updates:
            self._initialize_fun = theano.function(
                [], [], updates=self._initialize_updates)
        else:
            self._initialize_fun = None

        self._initialized = False
        self._input_names = [v.name for v in self.inputs]

        if self._accumulate_updates:
            self._accumulate_fun = theano.function(
                self.inputs, [], updates=self._accumulate_updates)
        else:
            self._accumulate_fun = None

        self._readout_fun = theano.function([], list(self._readout.values()))

    def _initialize_computation(self):
        """Initialize the aggregators to process a dataset."""
        self._initialized = True
        if self._initialize_fun is not None:
            self._initialize_fun()

    def _process_batch(self, batch):
        batch = dict_subset(batch, self._input_names)
        if self._accumulate_fun is not None:
            self._accumulate_fun(**batch)

    def _readout_expressions(self):
        if not self._initialized:
            raise Exception("To read out you must first initialize, then"
                            " process batches!")
        self._initialized = False
        ret_vals = self._readout_fun()
        return dict(zip(self.expressions.keys(), ret_vals))

    def evaluate(self, data_stream):
        """Compute the expressions over a data stream.

        Parameters
        ----------
        data_stream : instance of :class:`DataStream`
            The data stream. Only the first epoch of data is used.

        Returns
        -------
        A mapping from expression names to the values computed on the
        provided dataset.

        """
        self._initialize_computation()

        if self._accumulate_fun is not None:
            for batch in data_stream.get_epoch_iterator(as_dict=True):
                self._process_batch(batch)
        else:
            logger.debug('Only constant monitors are used, will not'
                         ' iterate over the data!')

        return self._readout_expressions()
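
The docstring above also mentions a custom loop over data as an alternative to `evaluate`. A minimal sketch of that pattern, assuming `expressions` is a list of uniquely named Theano variables and `stream` is a `DataStream` (both placeholders here):

    # Hypothetical sketch of the custom-loop pattern from the docstring.
    evaluator = DatasetEvaluator(expressions)
    evaluator._initialize_computation()        # reset the aggregators
    for batch in stream.get_epoch_iterator(as_dict=True):
        evaluator._process_batch(batch)        # accumulate one minibatch
    values = evaluator._readout_expressions()  # {expression name: value}

This is exactly the loop that `evaluate` runs internally, so both paths produce the same aggregated values.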
MNIST example script (modified)
@@ -1,36 +1,60 @@
 #!/usr/bin/env python

+from argparse import ArgumentParser
+
 from theano import tensor

 from blocks.algorithms import GradientDescent, SteepestDescent
-from blocks.bricks import MLP, Tanh, Identity
-from blocks.bricks.cost import CategoricalCrossEntropy
+from blocks.bricks import MLP, Tanh, Softmax
+from blocks.bricks.cost import CategoricalCrossEntropy, MisclassificationRate
+from blocks.initialization import IsotropicGaussian, Constant
 from blocks.datasets import DataStream
 from blocks.datasets.mnist import MNIST
 from blocks.datasets.schemes import SequentialScheme
 from blocks.extensions import FinishAfter, Printing
 from blocks.extensions.saveload import SerializeMainLoop
+from blocks.extensions.monitoring import DataStreamMonitoring
 from blocks.main_loop import MainLoop


-def main(save_to="mnist.pkl", num_epochs=2):
-    mlp = MLP([Tanh(), Identity()], [784, 100, 10])
+def main(save_to, num_epochs):
+    mlp = MLP([Tanh(), Softmax()], [784, 100, 10],
+              weights_init=IsotropicGaussian(0, 0.01),
+              biases_init=Constant(0))
     mlp.initialize()
     x = tensor.matrix('features')
     y = tensor.lmatrix('targets')
-    cost = CategoricalCrossEntropy().apply(y.flatten() - 1, mlp.apply(x))
+    probs = mlp.apply(x)
+    cost = CategoricalCrossEntropy().apply(y.flatten(), probs)
+    error_rate = MisclassificationRate().apply(y.flatten(), probs)

-    mnist = MNIST("train")
+    mnist_train = MNIST("train")
+    mnist_test = MNIST("test")

     main_loop = MainLoop(
         mlp,
-        DataStream(mnist,
-                   iteration_scheme=SequentialScheme(mnist.num_examples, 50)),
+        DataStream(mnist_train,
+                   iteration_scheme=SequentialScheme(
+                       mnist_train.num_examples, 50)),
         GradientDescent(cost=cost,
                         step_rule=SteepestDescent(learning_rate=0.1)),
         extensions=[FinishAfter(after_n_epochs=num_epochs),
+                    DataStreamMonitoring(
+                        [cost, error_rate],
+                        DataStream(mnist_test,
+                                   iteration_scheme=SequentialScheme(
+                                       mnist_test.num_examples, 500)),
+                        prefix="test"),
                     SerializeMainLoop(save_to),
                     Printing()])
     main_loop.run()

 if __name__ == "__main__":
-    main()
+    parser = ArgumentParser("An example of training an MLP on"
+                            " the MNIST dataset.")
+    parser.add_argument("--num-epochs", type=int, default=2,
+                        help="Number of training epochs to do.")
+    parser.add_argument("save_to", default="mnist.pkl", nargs="?",
+                        help="Destination to save the state of"
+                             " the training process.")
+    args = parser.parse_args()
+    main(args.save_to, args.num_epochs)
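
With the new argument parsing, the script can be run as, for example (the file name `mnist.py` is an assumption; the positional `save_to` argument is optional and defaults to `mnist.pkl`):

    python mnist.py --num-epochs 5 mnist.pkl

During training, the `DataStreamMonitoring` extension evaluates `cost` and `error_rate` on the test stream before the first epoch and after every epoch, writing the results to the log with the `test` prefix.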