
Processing model checkpointing #256

Merged: 16 commits into master from ohinds-model-checkpointing on Aug 25, 2023

Conversation

@ohinds (Contributor) commented on Aug 21, 2023

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Summary

This PR adds checkpointing to the segmentation processing estimator. It introduces a class that derives from the TensorFlow ModelCheckpoint callback and saves nobrainer-specific information at each checkpoint.
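The thread does not show the class itself, so the sketch below is only an assumption of its general shape: a subclass of `tf.keras.callbacks.ModelCheckpoint` that lets Keras save the model as usual, then writes a small JSON metadata file next to each checkpoint. The name `NobrainerCheckpoint` and the metadata fields are illustrative, not the PR's actual code.

```python
import json

import tensorflow as tf


class NobrainerCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    """ModelCheckpoint that also persists estimator-level metadata.

    Illustrative sketch only; the PR's real class and saved fields
    are not shown in this thread.
    """

    def __init__(self, filepath, estimator_info=None, **kwargs):
        super().__init__(filepath, **kwargs)
        self.estimator_info = estimator_info or {}

    def on_epoch_end(self, epoch, logs=None):
        # Let Keras save the model checkpoint first ...
        super().on_epoch_end(epoch, logs)
        # ... then write nobrainer-specific state next to it. Keras
        # uses 1-based epoch numbers when formatting `filepath`.
        meta_path = self.filepath.format(epoch=epoch + 1) + ".json"
        with open(meta_path, "w") as f:
            json.dump({"epoch": epoch + 1, **self.estimator_info}, f)
```

Because the subclass calls `super().on_epoch_end()` before writing its own file, the standard Keras checkpointing behavior (`save_best_only`, `save_weights_only`, etc.) is untouched.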

Checklist

  • I have added tests to cover my changes
  • I have updated documentation (if necessary)

Acknowledgment

  • I acknowledge that this contribution will be available under the Apache 2 license.

@ohinds ohinds requested a review from satra August 21, 2023 23:29
@satra (Contributor) left a comment


comments on a few design choices that are not clear.

@@ -74,8 +82,13 @@ def _compile():
)

if warm_start:
    if checkpoint_tracker:
        self = checkpoint_tracker.load()
satra (Contributor) commented:

shouldn't this be a static method or classmethod if it overwrites self? perhaps take a look at the save/load method in base estimator.

@ohinds (Author) replied on Aug 22, 2023:

Yes, this was bothering me, too. Let me try reworking the interface a bit. The decisions I'm wrestling with are mostly related to how much to hide behind the scenes while still providing flexibility and consistency.

The way things are now, one decides to warm start and load a checkpoint when calling fit(), which seems logical to me, but then overwriting self with the loaded checkpoint seems like the right thing to do. I will try moving all this to static construction to see how the interface looks.
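For reference, the classmethod-based interface the reviewer points to could look roughly like this (a minimal stand-in, not nobrainer's actual BaseEstimator): loading constructs and returns a new instance, so nothing ever rebinds `self`.

```python
import json


class Estimator:
    """Minimal stand-in showing the save()/load() classmethod pattern."""

    def __init__(self, config=None):
        self.config = config or {}

    def save(self, path):
        # Persist whatever state the estimator needs to resume.
        with open(path, "w") as f:
            json.dump(self.config, f)

    @classmethod
    def load(cls, path):
        # A classmethod builds and returns the restored instance,
        # so callers never need `self = ...` inside fit().
        with open(path) as f:
            return cls(config=json.load(f))
```

With this shape, warm starting becomes `est = Estimator.load(path)` at construction time rather than a mutation inside fit().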

Comment on lines 34 to 38
checkpoint_dir=os.getcwd(),
checkpoint_file_path=None,
satra (Contributor) commented:

isn't checkpoint often a directory? why is this asking for a file path?

@ohinds (Author) replied:

This is the tf/keras approach, which I'm extending here (see https://keras.io/guides/training_with_built_in_methods/#checkpointing-models).

filepath can be a directory, but the relevant epoch or batch variables are also formatted into it. See the test below for using filepath to save into a directory.
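Concretely, Keras substitutes fields like `epoch` (and logged metrics) into the filepath template at save time, so a single template names a per-epoch file inside a directory. The directory and template names here are illustrative:

```python
import os

checkpoint_dir = "checkpoints"  # illustrative directory name

# Keras calls .format() on this template with the current epoch
# (and the logs dict) each time it saves.
template = os.path.join(checkpoint_dir, "ckpt_{epoch:03d}")

print(template.format(epoch=5))  # -> checkpoints/ckpt_005 (on POSIX)
```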

@satra (Contributor) left a comment

a few suggestions for the failing pre-commit test.

i think this looks cleaner than before. it would be good to consider what a version of this script would look like in a slurm setting where the job is cancelled and requeued. how would that do it?

import numpy as np
from numpy.testing import assert_allclose
import os
import pytest
satra (Contributor) commented:

Suggested change
import pytest

@@ -1,13 +1,15 @@
"""Tests for `nobrainer.processing.checkpoint`."""

from nobrainer.processing.segmentation import Segmentation
from nobrainer.models import meshnet
import os
satra (Contributor) commented:

Suggested change
import os

@@ -3,6 +3,7 @@
from glob import glob
import logging
import os

import tensorflow as tf

from .base import BaseEstimator
satra (Contributor) commented:

Suggested change
from .base import BaseEstimator

@ohinds (Author) commented on Aug 24, 2023

> a few suggestions for the failing pre-commit test.
>
> i think this looks cleaner than before. it would be good to consider what a version of this script would look like in a slurm setting where the job is cancelled and requeued. how would that do it?

I added a test test_warm_start_workflow that demonstrates how this would be done. I could wrap this inside the BaseEstimator if you think that would be cleaner.

    for weight_array in layer.get_weights():
        assert np.count_nonzero(weight_array)
except (AssertionError, ValueError):
    bem = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath)
satra (Contributor) commented:

how about simply:

bem = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath, warm_start=True)

satra (Contributor) added:

and have the try-except be a function of warm_start inside the baseestimator and checkpoint_filepath

@ohinds (Author) replied:

Because this isn't a warm start? That seems confusing.

@ohinds (Author) replied on Aug 25, 2023:

We could do something like

bem = Segmentation.load_or_init_with_checkpoints(meshnet, checkpoint_filepath=checkpoint_filepath)

which could initialize from zero if no checkpoints are found?

satra (Contributor) replied:

it is a warm start in a way. doing init_with_checkpoints sounds good.

@ohinds (Author) replied:

Great, pushed that change.
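The init-with-checkpoints approach settled on above can be sketched generically. This is a stand-in estimator, not nobrainer's actual implementation: look for existing files matching the checkpoint template, load the latest one if found, otherwise initialize from scratch.

```python
from glob import glob
import json
import os


class Estimator:
    """Stand-in estimator; nobrainer's real class carries more state."""

    def __init__(self, config=None):
        self.config = config or {}

    def save(self, path):
        with open(path, "w") as f:
            json.dump(self.config, f)

    @classmethod
    def init_with_checkpoints(cls, checkpoint_filepath, config=None):
        # Replace the epoch placeholder with a wildcard so any saved
        # epoch matches (assumes an {epoch:03d}-style template).
        pattern = checkpoint_filepath.replace("{epoch:03d}", "*")
        existing = sorted(glob(pattern))
        if existing:
            # Warm start: resume from the most recent checkpoint.
            with open(existing[-1]) as f:
                return cls(config=json.load(f))
        # Cold start: no checkpoints found, initialize from zero.
        return cls(config=config)
```

A cancelled-and-requeued slurm job can then call `init_with_checkpoints` unconditionally at the top of the script: the first run initializes fresh, and every requeued run resumes from whatever checkpoint exists.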

@satra (Contributor) commented on Aug 25, 2023

@ohinds - guide notebook still not running.

@ohinds (Author) commented on Aug 25, 2023

> @ohinds - guide notebook still not running.

That was a transient error. I restarted.

@satra satra merged commit 99eaf9d into master Aug 25, 2023
7 checks passed
@hvgazula hvgazula deleted the ohinds-model-checkpointing branch March 8, 2024 15:59