Processing model checkpointing #256

Merged
merged 16 commits into from Aug 25, 2023

Changes from 14 commits
21 changes: 18 additions & 3 deletions nobrainer/processing/base.py
@@ -20,7 +20,13 @@ class BaseEstimator:
state_variables = []
model_ = None

def __init__(self, multi_gpu=False):
def __init__(self, checkpoint_filepath=None, multi_gpu=False):
self.checkpoint_tracker = None
if checkpoint_filepath:
from .checkpoint import CheckpointTracker

self.checkpoint_tracker = CheckpointTracker(self, checkpoint_filepath)

self.strategy = get_strategy(multi_gpu)

@property
@@ -38,7 +44,7 @@ def save(self, save_dir):
# are stored as members, which doesn't leave room for
# parameters that are specific to the runtime context.
# (e.g. multi_gpu).
if key == "multi_gpu":
if key == "multi_gpu" or key == "checkpoint_filepath":
continue
model_info["__init__"][key] = getattr(self, key)
for val in self.state_variables:
@@ -49,7 +55,7 @@ def save(self, save_dir):

@classmethod
def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
"""Saves a trained model"""
"""Loads a trained model from a save directory"""
model_dir = Path(str(model_dir).rstrip(os.pathsep))
assert model_dir.exists() and model_dir.is_dir()
model_file = model_dir / "model_params.pkl"
@@ -70,6 +76,15 @@ def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
)
return klass

@classmethod
def load_latest(cls, checkpoint_filepath):
from .checkpoint import CheckpointTracker

checkpoint_tracker = CheckpointTracker(cls, checkpoint_filepath)
estimator = checkpoint_tracker.load()
estimator.checkpoint_tracker = checkpoint_tracker
return estimator


class TransformerMixin:
"""Mixin class for all transformers in scikit-learn."""
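
For orientation, a minimal sketch of the workflow the two additions above enable (`checkpoint_filepath` in `__init__` plus `load_latest`), based on the tests added later in this PR. `dataset_train` and the `checkpoints/{epoch:03d}` template are illustrative placeholders, not part of this diff:

# Sketch only: resume-from-checkpoint usage of the new BaseEstimator API.
# "dataset_train" is assumed to be a dataset with the attributes nobrainer
# expects (n_volumes, volume_shape, scalar_labels).
from nobrainer.models import meshnet
from nobrainer.processing.segmentation import Segmentation

checkpoint_filepath = "checkpoints/{epoch:03d}"

# Fresh run: BaseEstimator.__init__ attaches a CheckpointTracker callback.
model = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath)
model.fit(dataset_train=dataset_train, epochs=2)

# After a restart: reload the most recently created checkpoint directory
# and continue training; load_latest re-attaches the tracker.
model = Segmentation.load_latest(checkpoint_filepath=checkpoint_filepath)
model.fit(dataset_train=dataset_train, epochs=3)
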
46 changes: 46 additions & 0 deletions nobrainer/processing/checkpoint.py
@@ -0,0 +1,46 @@
"""Checkpointing utils"""

from glob import glob
import logging
import os

import tensorflow as tf


class CheckpointTracker(tf.keras.callbacks.ModelCheckpoint):
"""Class for saving/loading estimators at/from checkpoints."""

def __init__(self, estimator, file_path, **kwargs):
"""
estimator: BaseEstimator, instance of an estimator (e.g., Segmentation).
file_path: str, directory to/from which to save or load.
"""
self.estimator = estimator
super().__init__(file_path, **kwargs)

def _save_model(self, epoch, batch, logs):
"""Save the current state of the estimator. This overrides the
base class implementation to save `nobrainer` specific info.

epoch: int, the index of the epoch that just finished.
batch: int, the index of the batch that just finished.
logs: dict, logging info passed into on_epoch_end or on_batch_end.
"""
self.save(self._get_file_path(epoch, batch, logs))

def save(self, directory):
"""Save the current state of the estimator.
directory: str, path in which to save the model.
"""
logging.info(f"Saving to dir {directory}")
self.estimator.save(directory)

def load(self):
"""Loads the most-recently created checkpoint from the
checkpoint directory.
"""
checkpoints = glob(os.path.join(os.path.dirname(self.filepath), "*/"))
latest = max(checkpoints, key=os.path.getctime)
self.estimator = self.estimator.load(latest)
logging.info(f"Loaded estimator from {latest}.")
return self.estimator
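
A note on how checkpoints land on disk (a sketch, assuming the `checkpoints/{epoch:03d}` template used in the tests below): each time the callback fires, `_save_model` calls `estimator.save()` into a formatted directory, and `load()` simply picks the most recently created sibling directory.

# Illustrative layout for checkpoint_filepath = "checkpoints/{epoch:03d}":
#
#   checkpoints/001/   <- estimator.save() output after epoch 1
#   checkpoints/002/   <- after epoch 2
#   checkpoints/003/   <- newest; CheckpointTracker.load() restores this one
#
# The "latest" resolution mirrors the load() method above: a glob over the
# sibling directories plus max() by creation time.
import os
from glob import glob

checkpoints = glob(os.path.join("checkpoints", "*/"))
latest = max(checkpoints, key=os.path.getctime)
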
30 changes: 16 additions & 14 deletions nobrainer/processing/segmentation.py
@@ -1,20 +1,24 @@
import importlib
import os
import logging

import tensorflow as tf

from .base import BaseEstimator
from .. import losses, metrics
from ..dataset import get_steps_per_epoch

logging.getLogger().setLevel(logging.INFO)


class Segmentation(BaseEstimator):
"""Perform segmentation type operations"""

state_variables = ["block_shape_", "volume_shape_", "scalar_labels_"]

def __init__(self, base_model, model_args=None, multi_gpu=False):
super().__init__(multi_gpu=multi_gpu)
def __init__(
self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=False
):
super().__init__(checkpoint_filepath=checkpoint_filepath, multi_gpu=multi_gpu)

if not isinstance(base_model, str):
self.base_model = base_model.__name__
@@ -31,8 +35,6 @@ def fit(
dataset_train,
dataset_validate=None,
epochs=1,
checkpoint_dir=os.getcwd(),
warm_start=False,
# TODO: figure out whether optimizer args should be flattened
optimizer=None,
opt_args=None,
@@ -73,21 +75,17 @@ def _compile():
metrics=metrics,
)

if warm_start:
if self.model is None:
raise ValueError("warm_start requested, but model is undefined")
with self.strategy.scope():
_compile()
else:
if self.model is None:
mod = importlib.import_module("..models", "nobrainer.processing")
base_model = getattr(mod, self.base_model)
if batch_size % self.strategy.num_replicas_in_sync:
raise ValueError("batch size must be a multiple of the number of GPUs")

with self.strategy.scope():
_create(base_model)
_compile()
print(self.model_.summary())
with self.strategy.scope():
_compile()
self.model_.summary()

train_steps = get_steps_per_epoch(
n_volumes=dataset_train.n_volumes,
@@ -105,13 +103,17 @@ def _compile():
batch_size=batch_size,
)

# TODO add checkpoint
callbacks = []
if self.checkpoint_tracker:
callbacks.append(self.checkpoint_tracker)

self.model_.fit(
dataset_train,
epochs=epochs,
steps_per_epoch=train_steps,
validation_data=dataset_validate,
validation_steps=evaluate_steps,
callbacks=callbacks,
)

return self
77 changes: 77 additions & 0 deletions nobrainer/tests/checkpoint_test.py
@@ -0,0 +1,77 @@
"""Tests for `nobrainer.processing.checkpoint`."""

import os
Review thread on this line — Contributor suggested removing `import os`.

import numpy as np
from numpy.testing import assert_allclose
import tensorflow as tf

from nobrainer.models import meshnet
from nobrainer.processing.segmentation import Segmentation


def _get_toy_dataset():
data_shape = (8, 8, 8, 8, 1)
train = tf.data.Dataset.from_tensors(
(np.random.rand(*data_shape), np.random.randint(0, 2, data_shape))
)
train.scalar_labels = False
train.n_volumes = data_shape[0]
train.volume_shape = data_shape[1:4]
return train


def _assert_model_weights_allclose(model1, model2):
for layer1, layer2 in zip(model1.model.layers, model2.model.layers):
weights1 = layer1.get_weights()
weights2 = layer2.get_weights()
assert len(weights1) == len(weights2)
for index in range(len(weights1)):
assert_allclose(weights1[index], weights2[index], rtol=1e-06, atol=1e-08)


def test_checkpoint(tmp_path):
train = _get_toy_dataset()

checkpoint_filepath = os.path.join(tmp_path, "checkpoint-epoch_{epoch:03d}")
model1 = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath)
model1.fit(
dataset_train=train,
epochs=2,
)

model2 = Segmentation.load_latest(checkpoint_filepath=checkpoint_filepath)
_assert_model_weights_allclose(model1, model2)
model2.fit(
dataset_train=train,
epochs=3,
)

model3 = Segmentation.load_latest(checkpoint_filepath=checkpoint_filepath)
_assert_model_weights_allclose(model2, model3)


def test_warm_start_workflow(tmp_path):
train = _get_toy_dataset()

checkpoint_dir = os.path.join(tmp_path, "checkpoints")
checkpoint_filepath = os.path.join(checkpoint_dir, "{epoch:03d}")
if not os.path.exists(checkpoint_dir):
os.mkdir(checkpoint_dir)

for iteration in range(2):
try:
bem = Segmentation.load_latest(checkpoint_filepath=checkpoint_filepath)
assert iteration == 1
assert bem.model is not None
for layer in bem.model.layers:
for weight_array in layer.get_weights():
assert np.count_nonzero(weight_array)
except (AssertionError, ValueError):
bem = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath)
Review thread on this line:

Contributor: how about simply:

bem = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath, warm_start=True)

and have the try-except be a function of warm_start inside the BaseEstimator and checkpoint_filepath.

Contributor Author: Because this isn't a warm start? That seems confusing.

Contributor Author (@ohinds, Aug 25, 2023): We could do something like

bem = Segmentation.load_or_init_with_checkpoints(meshnet, checkpoint_filepath=checkpoint_filepath)

which could initialize from zero if no checkpoints are found?

Contributor: It is a warm start in a way. Doing init_with_checkpoints sounds good.

Contributor Author: Great, pushed that change.

assert iteration == 0
assert bem.model is None
bem.fit(
dataset_train=train,
epochs=2,
)
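
Following the review thread above, a hedged sketch of the load-or-initialize helper discussed there. The name `load_or_init_with_checkpoints` and the fallback behavior come from the discussion, not from this diff; only `load_latest` is added here:

# Hypothetical helper mirroring the try/except in test_warm_start_workflow:
# resume from the newest checkpoint if one exists, otherwise start fresh.
# load_latest() raises ValueError (max() over an empty glob) when the
# checkpoint directory contains no saved estimators yet.
from nobrainer.models import meshnet
from nobrainer.processing.segmentation import Segmentation


def load_or_init_with_checkpoints(base_model, checkpoint_filepath):
    try:
        return Segmentation.load_latest(checkpoint_filepath=checkpoint_filepath)
    except ValueError:
        return Segmentation(base_model, checkpoint_filepath=checkpoint_filepath)


bem = load_or_init_with_checkpoints(meshnet, checkpoint_filepath="checkpoints/{epoch:03d}")
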