Skip to content

Commit

Permalink
Merge pull request #115 from rizar/checkpointing
Browse files Browse the repository at this point in the history
Pickling the Main Loop and Resuming from Unpickled
  • Loading branch information
rizar committed Jan 20, 2015
2 parents 7f2813f + 5ce2b25 commit eae3f54
Show file tree
Hide file tree
Showing 10 changed files with 345 additions and 73 deletions.
30 changes: 27 additions & 3 deletions blocks/bricks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,14 +658,17 @@ def add_property(f):
return add_property

def __getattr__(self, attr):
    """Attribute fallback for delegation and pickling safety.

    Only reached when normal attribute lookup fails.

    Raises
    ------
    AttributeError
        For ``_brick`` and for any name not found via ``self.f`` or
        the delegate.
    """
    # NOTE(review): the rendered diff merged the deleted lines with the
    # added ones; this is the post-commit version only.
    if attr == 'f':
        # Default to an empty property table -- presumably this guards
        # against infinite recursion when the instance is being
        # unpickled before its attributes exist; TODO confirm.
        return {}
    elif attr == '_brick':
        # Never synthesize _brick; callers use hasattr(self, '_brick').
        raise AttributeError
    elif attr in self.f:
        # Registered dynamic properties are computed from the brick.
        return self.f[attr](self.brick)
    elif hasattr(self, '_brick') and self.delegate_method is not None:
        # Fall through to the delegated application's attributes.
        return getattr(self.delegate_method(self.brick), attr)
    else:
        raise AttributeError


def application_wrapper(**kwargs):
Expand Down Expand Up @@ -1023,6 +1026,25 @@ def apply(self, input_):
return output


class _PicklableActivation(object):
    """A base class for dynamically generated classes that can be pickled."""

    def __reduce__(self):
        # The stored activation may be wrapped (e.g. carries __func__);
        # unwrap to the plain function so it pickles cleanly.
        act = self.__class__._activation
        act = getattr(act, '__func__', act)
        # Recreate via _Initializor(name, activation), then restore the
        # instance state from __dict__.
        return (_Initializor(),
                (self.__class__.__name__, act),
                self.__dict__)


class _Initializor(object):
    """A callable object which returns a parametrized class."""

    def __call__(self, name, activation):
        # Build a bare instance, then retarget its class to the freshly
        # generated activation brick class.
        instance = _Initializor()
        instance.__class__ = _activation_factory(name, activation)
        return instance


def _activation_factory(name, activation):
"""Class factory for Bricks which perform simple Theano calls."""
class ActivationDocumentation(type):
Expand All @@ -1033,8 +1055,10 @@ def __new__(cls, name, bases, classdict):
return type.__new__(cls, name, bases, classdict)

@add_metaclass(ActivationDocumentation)
class Activation(Brick):
class Activation(Brick, _PicklableActivation):
"""Element-wise application of {0} function."""
_activation = activation

@application(inputs=['input_'], outputs=['output'])
def apply(self, input_):
"""Apply the {0} function element-wise.
Expand Down
17 changes: 9 additions & 8 deletions blocks/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import six
from six import add_metaclass

from blocks.utils import update_instance
from blocks.utils import update_instance, LambdaIterator, SequenceIterator


@add_metaclass(ABCMeta)
Expand Down Expand Up @@ -155,22 +155,22 @@ class InMemoryDataset(Dataset):
files are loaded after de-serialization, before the :meth:`load` method
is ever called.
>>> import pickle
>>> import dill
>>> from blocks.datasets.mnist import MNIST
>>> mnist = MNIST('train')
>>> print("{:,d} KB".format(
... mnist.data['features'].nbytes / 1024)) # doctest: +SKIP
183,750 KB
>>> with open('mnist.pkl', 'wb') as f:
... pickle.dump(mnist, f, protocol=pickle.HIGHEST_PROTOCOL)
... dill.dump(mnist, f)
You will notice that the dumping of the dataset was relatively quick,
because it didn't attempt to write MNIST to disk. We can now reload it,
and if the data file has not been moved, it will be as if nothing
happened.
>>> with open('mnist.pkl', 'rb') as f:
... mnist = pickle.load(f)
... mnist = dill.load(f)
>>> print(mnist.data['features'].shape)
(60000, 784)
Expand All @@ -181,7 +181,7 @@ class InMemoryDataset(Dataset):
>>> correct_path = config.data_path
>>> config.data_path = '/non/existing/path'
>>> with open('mnist.pkl', 'rb') as f:
... mnist = pickle.load(f)
... mnist = dill.load(f)
>>> print(mnist.data['features'].shape) # doctest: +SKIP
Traceback (most recent call last):
...
Expand Down Expand Up @@ -318,9 +318,10 @@ def __init__(self, container, sources=None):
self.data_channels = [container]

def open(self):
    """Return an iterator over the zipped data channels.

    Uses :class:`SequenceIterator`/:class:`LambdaIterator` instead of a
    generator -- presumably because generators cannot be pickled, which
    this commit's checkpointing requires; TODO confirm.
    """
    # NOTE(review): the rendered diff kept the deleted generator body
    # alongside the new one; this is the post-commit version only.
    iterators = [SequenceIterator(channel)
                 for channel in self.data_channels]
    return LambdaIterator(
        lambda: tuple([next(iterator) for iterator in iterators]))

def get_data(self, state, request=None):
if request is not None:
Expand Down
67 changes: 47 additions & 20 deletions blocks/extensions.py → blocks/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,27 @@ class TrainingExtension(object):
typically add a certain functionality to the training procedure,
e.g. running validation on auxiliary datasets or early stopping.
Parameters
----------
name : str, optional
The name of the extension. The names are useful in order to
distinguish between several extensions of the same type that
belongs to the same main loop. By default the name is set to
the name of the class.
Attributes
----------
main_loop : :class:`MainLoop`
The main loop to which the extension belongs.
name : str
The name of the extension.
"""
def __init__(self, name=None):
    # Fall back to the class name when no explicit name is given
    # (empty string and None both trigger the fallback).
    self.name = name if name else self.__class__.__name__

def dispatch(self, callback_name, *args):
"""Runs callback with the given name.
Expand Down Expand Up @@ -92,34 +107,37 @@ class SimpleExtension(TrainingExtension):
If ``True``, :meth:`do` is invoked before the first epoch.
after_every_epoch : bool
If ``True``, :meth:`do` is invoked after every epoch.
after_every_iteration : bool
If ``True``, :meth:`do` is invoked after every iteration.
after_every_batch : bool
If ``True``, :meth:`do` is invoked after every batch.
after_training : bool
If ``True``, :meth:`do` is invoked after training.
after_n_epochs : int, optional
If not ``None``, :meth:`do` is invoked when `after_n_epochs` are
done.
If not ``None``, :meth:`do` is invoked when `after_n_epochs`
epochs are done.
after_n_batches : int, optional
If not ``None``, :meth:`do` is invoked when `after_n_batches`
batches are processed.
"""
def __init__(self, before_first_epoch=False, after_every_epoch=False,
             after_every_batch=False, after_training=False,
             after_n_epochs=None, after_n_batches=None, **kwargs):
    """Register trigger conditions from the keyword flags.

    See the class docstring for the meaning of each flag; remaining
    keyword arguments (e.g. ``name``) go to ``TrainingExtension``.
    """
    # NOTE(review): the rendered diff merged the old signature and
    # branches (after_every_iteration, inline after_n_epochs predicate)
    # with the new ones; this is the post-commit version only.
    super(SimpleExtension, self).__init__(**kwargs)
    self._conditions = []
    if before_first_epoch:
        self.add_condition(
            "before_epoch",
            predicate=lambda log: log.status.epochs_done == 0)
    if after_every_epoch:
        self.add_condition("after_epoch")
    if after_every_batch:
        self.add_condition("after_batch")
    if after_training:
        self.add_condition("after_training")
    if after_n_epochs:
        self.invoke_after_n_epochs(after_n_epochs)
    if after_n_batches:
        self.invoke_after_n_batches(after_n_batches)

def add_condition(self, callback_name, predicate=None, arguments=None):
"""Adds a condition under which a :meth:`do` is called.
Expand All @@ -145,6 +163,18 @@ def add_condition(self, callback_name, predicate=None, arguments=None):
predicate = lambda log: True
self._conditions.append((callback_name, predicate, arguments))

def invoke_after_n_epochs(self, n_epochs):
    """Schedule :meth:`do` for when exactly `n_epochs` epochs are done."""
    def epochs_reached(log):
        return log.status.epochs_done == n_epochs
    self.add_condition("after_epoch", predicate=epochs_reached)

def invoke_after_n_batches(self, n_batches):
    """Schedule :meth:`do` for when exactly `n_batches` batches are done."""
    def batches_reached(log):
        return log.status.iterations_done == n_batches
    self.add_condition("after_batch", predicate=batches_reached)

@abstractmethod
def do(self, which_callback, *args):
"""Does the job of the training extension.
Expand Down Expand Up @@ -183,27 +213,24 @@ class FinishAfter(SimpleExtension):
def __init__(self, **kwargs):
    # All scheduling options (e.g. after_n_epochs) are handled by the
    # SimpleExtension base; this subclass only defines the do() action.
    super(FinishAfter, self).__init__(**kwargs)

def do(self, which_callback, *args):
    """Request that the main loop finish training.

    Accepts ``*args`` so it can be dispatched from any callback.
    """
    # NOTE(review): the rendered diff left the deleted old signature
    # above the new one; this is the post-commit version only.
    self.main_loop.log.current_row.training_finish_requested = True


class Printing(SimpleExtension):
"""Prints log messages to the screen."""
def __init__(self, **kwargs):
    """Print by default at the start, after each epoch and at the end."""
    # NOTE(review): the rendered diff kept the deleted set_if_absent
    # helper alongside the new setdefault calls; post-commit version only.
    kwargs.setdefault("before_first_epoch", True)
    kwargs.setdefault("after_training", True)
    kwargs.setdefault("after_every_epoch", True)
    super(Printing, self).__init__(**kwargs)

def _print_attributes(self, attribute_tuples):
    """Print each public (attribute, value) pair on an indented line."""
    for name, value in attribute_tuples:
        if name.startswith("_"):
            continue  # private entries are not shown
        print("\t", "{}:".format(name), value)

def do(self, which_callback):
def do(self, which_callback, *args):
log = self.main_loop.log
print("".join(79 * "-"))
if which_callback == "before_epoch" and log.status.epochs_done == 0:
Expand Down
61 changes: 61 additions & 0 deletions blocks/extensions/saveload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Extensions for saving and loading the state of a training process."""
import dill

from blocks.extensions import SimpleExtension


class SaveLoadBase(SimpleExtension):
    """The base class for save-load extensions.

    Contains the logic that can be shared by different save-load
    extensions.

    """
    def log_saving_done(self, destination):
        """Record in the log that saving has been done.

        Parameters
        ----------
        destination : str
            The destination where the state of the training process
            was saved.

        """
        current_row = self.main_loop.log.current_row
        current_row.saving_done_to = destination


class SerializeMainLoop(SaveLoadBase):
    """Saves a pickled version of the main loop to the disk.

    The pickled main loop can be later reloaded and training can be
    resumed.

    Parameters
    ----------
    path : str
        The destination path for pickling.

    Notes
    -----
    Instead of the standard pickling library, the dill package is used.

    Using pickling for saving the whole main loop object comes with
    certain limitations:

    * Theano computation graphs built in the GPU mode
      (`theano.config.device == "gpu"`) can not be used in the usual
      mode (and vice-versa). Therefore using this extension binds you
      to using only one kind of device.

    """
    def __init__(self, path, **kwargs):
        # Unless configured otherwise, save at least once: when
        # training finishes.
        kwargs.setdefault("after_training", True)
        super(SerializeMainLoop, self).__init__(**kwargs)
        self.path = path

    def do(self, callback_name, *args):
        """Pickle the main loop object to the disk."""
        destination = open(self.path, "wb")
        try:
            dill.dump(self.main_loop, destination)
        finally:
            destination.close()
        self.log_saving_done(self.path)

0 comments on commit eae3f54

Please sign in to comment.