Skip to content

Commit

Permalink
Merge pull request #976 from Thrandis/ccw2
Browse files Browse the repository at this point in the history
WIP: New Serialization 2
  • Loading branch information
rizar committed Feb 24, 2016
2 parents 5107756 + 6ee25c9 commit 568e7a7
Show file tree
Hide file tree
Showing 8 changed files with 688 additions and 266 deletions.
85 changes: 40 additions & 45 deletions blocks/extensions/saveload.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
import os.path
import logging

from six.moves import cPickle

from blocks.extensions import SimpleExtension, TrainingExtension
from blocks.utils import reraise_as
from blocks.serialization import (
secure_dump, load, load_parameter_values, DEFAULT_PROTOCOL)
from blocks.serialization import (secure_dump, load, dump_and_add_to_dump,
load_parameters)

logger = logging.getLogger(__name__)

Expand All @@ -31,12 +29,18 @@ class Checkpoint(SimpleExtension):
----------
path : str
The destination path for pickling.
parameters : list, optional
The parameters to save separately. If None, the parameters from
the model (main_loop.model.parameters) are saved.
save_separately : list of str, optional
The list of the main loop's attributes to be pickled separately
to their own files. The paths will be formed by adding
the attribute name preceded by an underscore before the
`path` extension. The whole main loop will still be pickled
as usual.
The list of the main loop's attributes to be saved (copied)
in a separate file in the tar archive. It may be used for example
to save the log separetely. The name of the attribute will be used
as name in the tar file.
save_main_loop : bool
Choose whether to save the main loop or not. This can be useful
for example if you are only interested in saving the parameters,
but not the whole main loop. Defaults to `True`.
use_cpickle : bool
See documentation of :func:`~blocks.serialization.dump`.
Expand All @@ -52,35 +56,16 @@ class Checkpoint(SimpleExtension):
"""
def __init__(self, path, save_separately=None, use_cpickle=False,
**kwargs):
def __init__(self, path, parameters=None, save_separately=None,
save_main_loop=True, use_cpickle=False, **kwargs):
kwargs.setdefault("after_training", True)
super(Checkpoint, self).__init__(**kwargs)
if not save_separately:
save_separately = []
self.path = path
self.parameters = parameters
self.save_separately = save_separately
self.save_main_loop = save_main_loop
self.use_cpickle = use_cpickle

def save_separately_filenames(self, path):
"""Compute paths for separately saved attributes.
Parameters
----------
path : str
Path to which the main checkpoint file is being saved.
Returns
-------
paths : dict
A dictionary mapping attribute names to derived paths
based on the `path` passed in as an argument.
"""
root, ext = os.path.splitext(path)
return {attribute: root + "_" + attribute + ext
for attribute in self.save_separately}

def do(self, callback_name, *args):
"""Pickle the main loop object to the disk.
Expand All @@ -94,12 +79,21 @@ def do(self, callback_name, *args):
path = self.path
if from_user:
path, = from_user
secure_dump(self.main_loop, path, use_cpickle=self.use_cpickle)
filenames = self.save_separately_filenames(path)
for attribute in self.save_separately:
secure_dump(getattr(self.main_loop, attribute),
filenames[attribute], cPickle.dump,
protocol=DEFAULT_PROTOCOL)
to_add = None
if self.save_separately:
to_add = {attr: getattr(self.main_loop, attr) for attr in
self.save_separately}
if self.parameters is None:
if hasattr(self.main_loop, 'model'):
self.parameters = self.main_loop.model.parameters
object_ = None
if self.save_main_loop:
object_ = self.main_loop
secure_dump(object_, path,
dump_function=dump_and_add_to_dump,
parameters=self.parameters,
to_add=to_add,
use_cpickle=self.use_cpickle)
except Exception:
path = None
raise
Expand Down Expand Up @@ -146,14 +140,15 @@ def __init__(self, path, load_iteration_state=False, load_log=False,
self.load_log = load_log

def load_to(self, main_loop):
main_loop.model.set_parameter_values(load_parameter_values(self.path))
if self.load_iteration_state or self.load_log:
with open(self.path, "rb") as source:
with open(self.path, "rb") as source:
main_loop.model.set_parameter_values(load_parameters(source))
if self.load_iteration_state or self.load_log:
loaded_main_loop = load(source)
if self.load_log:
main_loop.log = loaded_main_loop.log
if self.load_iteration_state:
main_loop.iteration_state = loaded_main_loop.iteration_state
if self.load_log:
main_loop.log = loaded_main_loop.log
if self.load_iteration_state:
main_loop.iteration_state = \
loaded_main_loop.iteration_state

def before_training(self):
if not os.path.exists(self.path):
Expand Down

0 comments on commit 568e7a7

Please sign in to comment.