New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Processing model checkpointing #256
Changes from 14 commits
9bf1ff4
15bb946
06d304c
17b9bad
61ccc9f
fbbc4c2
cded43c
c7d7cfe
29991dd
8c3fc54
2f51db2
9286eb7
8e6839d
eb32be3
9fe59a6
0185eff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""Checkpointing utils""" | ||
|
||
from glob import glob | ||
import logging | ||
import os | ||
|
||
import tensorflow as tf | ||
|
||
|
||
class CheckpointTracker(tf.keras.callbacks.ModelCheckpoint):
    """Class for saving/loading estimators at/from checkpoints."""

    def __init__(self, estimator, file_path, **kwargs):
        """
        estimator: BaseEstimator, instance of an estimator (e.g., Segmentation).
        file_path: str, directory to/from which to save or load.
        """
        self.estimator = estimator
        super().__init__(file_path, **kwargs)

    def _save_model(self, epoch, batch, logs):
        """Save the current state of the estimator. This overrides the
        base class implementation to save `nobrainer` specific info.

        epoch: int, the index of the epoch that just finished.
        batch: int, the index of the batch that just finished.
        logs: dict, logging info passed into on_epoch_end or on_batch_end.
        """
        # Let the base class expand any `{epoch}`/`{batch}` placeholders in
        # the configured path, but save the whole estimator rather than only
        # the underlying Keras model.
        self.save(self._get_file_path(epoch, batch, logs))

    def save(self, directory):
        """Save the current state of the estimator.

        directory: str, path in which to save the model.
        """
        logging.info(f"Saving to dir {directory}")
        self.estimator.save(directory)

    def load(self):
        """Load the most-recently created checkpoint from the
        checkpoint directory.

        Returns the loaded estimator.

        Raises ValueError if the checkpoint directory contains no
        checkpoints.
        """
        # Each checkpoint is saved as a subdirectory; trailing "/" in the
        # glob pattern matches directories only.
        checkpoints = glob(os.path.join(os.path.dirname(self.filepath), "*/"))
        if not checkpoints:
            # Without this guard, max() raises a cryptic
            # "max() arg is an empty sequence" ValueError. Keep the
            # exception type (callers catch ValueError) but make the
            # message actionable.
            raise ValueError(
                f"No checkpoints found in {os.path.dirname(self.filepath)}"
            )
        latest = max(checkpoints, key=os.path.getctime)
        self.estimator = self.estimator.load(latest)
        logging.info(f"Loaded estimator from {latest}.")
        return self.estimator
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
"""Tests for `nobrainer.processing.checkpoint`.""" | ||
|
||
import os | ||
|
||
import numpy as np | ||
from numpy.testing import assert_allclose | ||
import tensorflow as tf | ||
|
||
from nobrainer.models import meshnet | ||
from nobrainer.processing.segmentation import Segmentation | ||
|
||
|
||
def _get_toy_dataset():
    """Build a tiny one-element dataset for checkpointing tests.

    Returns a tf.data.Dataset of (features, labels) carrying the extra
    attributes nobrainer estimators read during fit (scalar_labels,
    n_volumes, volume_shape).
    """
    data_shape = (8, 8, 8, 8, 1)
    train = tf.data.Dataset.from_tensors(
        # np.random.randint's upper bound is exclusive, so (0, 1) would
        # yield all-zero labels; use (0, 2) to get a genuine mix of 0s
        # and 1s for the segmentation targets.
        (np.random.rand(*data_shape), np.random.randint(0, 2, data_shape))
    )
    train.scalar_labels = False
    train.n_volumes = data_shape[0]
    train.volume_shape = data_shape[1:4]
    return train
|
||
|
||
def _assert_model_weights_allclose(model1, model2):
    """Assert the two estimators' layers carry (near-)identical weights."""
    for lhs, rhs in zip(model1.model.layers, model2.model.layers):
        lhs_weights = lhs.get_weights()
        rhs_weights = rhs.get_weights()
        assert len(lhs_weights) == len(rhs_weights)
        for w1, w2 in zip(lhs_weights, rhs_weights):
            assert_allclose(w1, w2, rtol=1e-06, atol=1e-08)
|
||
|
||
def test_checkpoint(tmp_path):
    """Fit with checkpointing on, reload the latest checkpoint, and resume
    training; reloaded weights must match the model they were saved from."""
    dataset = _get_toy_dataset()

    ckpt_path = os.path.join(tmp_path, "checkpoint-epoch_{epoch:03d}")
    original = Segmentation(meshnet, checkpoint_filepath=ckpt_path)
    original.fit(dataset_train=dataset, epochs=2)

    # Reload should reproduce the just-trained weights exactly.
    reloaded = Segmentation.load_latest(checkpoint_filepath=ckpt_path)
    _assert_model_weights_allclose(original, reloaded)

    # Resuming training and reloading again should also round-trip.
    reloaded.fit(dataset_train=dataset, epochs=3)
    resumed = Segmentation.load_latest(checkpoint_filepath=ckpt_path)
    _assert_model_weights_allclose(reloaded, resumed)
|
||
|
||
def test_warm_start_workflow(tmp_path):
    """Warm-start workflow: the first iteration finds no checkpoints and
    creates a fresh estimator; the second loads the latest checkpoint and
    verifies it has non-trivial weights."""
    train = _get_toy_dataset()

    # Root the checkpoints under pytest's tmp_path fixture. The original
    # code ignored tmp_path and wrote "checkpoints" into the current
    # working directory, polluting the repo and breaking isolation
    # between test runs.
    checkpoint_dir = os.path.join(tmp_path, "checkpoints")
    checkpoint_filepath = os.path.join(checkpoint_dir, "{epoch:03d}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    for iteration in range(2):
        try:
            bem = Segmentation.load_latest(checkpoint_filepath=checkpoint_filepath)
            # Loading only succeeds once a checkpoint exists (2nd pass).
            assert iteration == 1
            assert bem.model is not None
            # Restored weights should not be all zeros.
            for layer in bem.model.layers:
                for weight_array in layer.get_weights():
                    assert np.count_nonzero(weight_array)
        except (AssertionError, ValueError):
            # No checkpoint yet (1st pass): start from scratch. The model
            # is built lazily, so it is still None before fit().
            bem = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath)
            assert iteration == 0
            assert bem.model is None
        bem.fit(
            dataset_train=train,
            epochs=2,
        )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.