New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Processing model checkpointing #256
Conversation
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comments on a few design choices that are not clear.
nobrainer/processing/segmentation.py
Outdated
@@ -74,8 +82,13 @@ def _compile(): | |||
) | |||
|
|||
if warm_start: | |||
if checkpoint_tracker: | |||
self = checkpoint_tracker.load() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't this be a static method or classmethod if it overwrites self
? perhaps take a look at the save/load method in base estimator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this was bothering me, too. Let me try reworking the interface a bit. The decisions I'm wrestling with are mostly related to how much to hide behind the scenes while still providing flexibility and consistency.
The way things are now, one decides to warm start and load a checkpoint when calling fit()
, which seems logical to me, but then this overwriting self part with the loaded checkpoint seems like the right thing to do. I will try moving all this to static construction to see how the interface looks.
nobrainer/processing/segmentation.py
Outdated
checkpoint_dir=os.getcwd(), | ||
checkpoint_file_path=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isn't checkpoint often a directory? why is this asking for a file path?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the tf/keras approach, which I'm extending here (see https://keras.io/guides/training_with_built_in_methods/#checkpointing-models).
filepath
can be a directory, but it's used to format in the relevant epoch
or batch
variable as well. See the test below for using filepath
to save in a directory.
…nets/nobrainer into ohinds-model-checkpointing
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a few suggestions for the failing pre-commit test.
i think this looks cleaner than before. it would be good to consider what a version of script in a slurm setting where the job is cancelled and requeued. how would that do it?
nobrainer/tests/checkpoint_test.py
Outdated
import numpy as np | ||
from numpy.testing import assert_allclose | ||
import os | ||
import pytest |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import pytest |
@@ -1,13 +1,15 @@ | |||
"""Tests for `nobrainer.processing.checkpoint`.""" | |||
|
|||
from nobrainer.processing.segmentation import Segmentation | |||
from nobrainer.models import meshnet | |||
import os |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import os |
nobrainer/processing/checkpoint.py
Outdated
@@ -3,6 +3,7 @@ | |||
from glob import glob | |||
import logging | |||
import os | |||
|
|||
import tensorflow as tf | |||
|
|||
from .base import BaseEstimator |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from .base import BaseEstimator |
…nets/nobrainer into ohinds-model-checkpointing
for more information, see https://pre-commit.ci
I added a test |
nobrainer/tests/checkpoint_test.py
Outdated
for weight_array in layer.get_weights(): | ||
assert np.count_nonzero(weight_array) | ||
except (AssertionError, ValueError): | ||
bem = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how about simply:
bem = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath, warm_start=True)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and have the try-except be a function of warm_start
inside the baseestimator and checkpoint_filepath
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because this isn't a warm start? That seems confusing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could do something like
bem = Segmentation.load_or_init_with_checkpoints(meshnet, checkpoint_filepath=checkpoint_filepath)
which could initialize from zero if no checkpoints are found?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is a warm start in a way. doing init_with_checkpoints
sounds good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great, pushed that change.
@ohinds - guide notebook still not running. |
That was a transient error. I restarted. |
Types of changes
Summary
This PR adds checkpointing to the segmentation processing estimator. It introduces a class that derives from the tensorflow
ModelCheckpoint
that will savenobrainer
-specific information at the checkpoint.Checklist
Acknowledgment