Processing model checkpointing #256

Merged: 16 commits, Aug 25, 2023
Changes from 1 commit
nobrainer/processing/base.py: 2 additions & 0 deletions
@@ -24,6 +24,7 @@ def __init__(self, checkpoint_filepath=None, multi_gpu=False):
        self.checkpoint_tracker = None
        if checkpoint_filepath:
            from .checkpoint import CheckpointTracker
+
            self.checkpoint_tracker = CheckpointTracker(self, checkpoint_filepath)

        self.strategy = get_strategy(multi_gpu)
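Note that CheckpointTracker is imported inside the method body rather than at module level; deferring an import this way is a common fix for an import cycle, and checkpoint.py does import from base.py at module level (see the next file). A minimal sketch of the pattern, with a hypothetical module a.py whose counterpart b.py imports a at the top:

# a.py (hypothetical): avoids importing b at module level
def use_b():
    # Deferred import: resolved only when the function is called, so this
    # module stays importable even though b imports a at its top level.
    from b import helper
    return helper()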
@@ -78,6 +79,7 @@ def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
    @classmethod
    def load_latest(cls, checkpoint_filepath):
        from .checkpoint import CheckpointTracker
+
        checkpoint_tracker = CheckpointTracker(cls, checkpoint_filepath)
        estimator = checkpoint_tracker.load()
        estimator.checkpoint_tracker = checkpoint_tracker
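Together with the constructor change above, load_latest gives a one-call way to resume from the newest checkpoint. A minimal usage sketch, assuming a meshnet-based Segmentation estimator; the path template is illustrative and mirrors the test at the bottom of this diff:

from nobrainer.processing.segmentation import Segmentation

# Segmentation subclasses BaseEstimator, so it inherits load_latest.
# "ckpts/checkpoint-epoch_{epoch:03d}" is an assumed, illustrative template.
model = Segmentation.load_latest("ckpts/checkpoint-epoch_{epoch:03d}")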
nobrainer/processing/checkpoint.py: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
from glob import glob
import logging
import os
+
import tensorflow as tf

from .base import BaseEstimator
Contributor suggested change (remove this line):
-from .base import BaseEstimator
@@ -40,7 +41,7 @@ def load(self):
"""Loads the most-recently created checkpoint from the
checkpoint directory.
"""
checkpoints = glob(os.path.join(os.path.dirname(self.filepath), '*/'))
checkpoints = glob(os.path.join(os.path.dirname(self.filepath), "*/"))
latest = max(checkpoints, key=os.path.getctime)
self.estimator = self.estimator.load(latest)
logging.info(f"Loaded estimator from {latest}.")
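Quote style aside, the selection logic here is glob-plus-max: list the subdirectories next to the filepath template and take the most recently created one. A standalone sketch of the same pattern, with an illustrative path:

import os
from glob import glob

filepath = "ckpts/checkpoint-epoch_{epoch:03d}"  # assumed template path
# Each saved checkpoint is a subdirectory of dirname(filepath).
checkpoints = glob(os.path.join(os.path.dirname(filepath), "*/"))
latest = max(checkpoints, key=os.path.getctime)  # newest by creation time
print(f"would load from {latest}")

One caveat: max raises ValueError on an empty sequence, so as written the load path assumes at least one checkpoint already exists.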
nobrainer/processing/segmentation.py: 3 additions & 2 deletions
@@ -8,7 +8,6 @@
from .. import losses, metrics
from ..dataset import get_steps_per_epoch

-
logging.getLogger().setLevel(logging.INFO)


@@ -17,7 +16,9 @@ class Segmentation(BaseEstimator):

    state_variables = ["block_shape_", "volume_shape_", "scalar_labels_"]

-    def __init__(self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=False):
+    def __init__(
+        self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=False
+    ):
        super().__init__(checkpoint_filepath=checkpoint_filepath, multi_gpu=multi_gpu)

        if not isinstance(base_model, str):
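The signature reflow is cosmetic (Black-style wrapping); the constructor's behavior is unchanged. For orientation, constructing an estimator with checkpointing enabled looks like this, with an illustrative path mirroring the test below:

from nobrainer.models import meshnet
from nobrainer.processing.segmentation import Segmentation

# Passing checkpoint_filepath makes BaseEstimator.__init__ attach a
# CheckpointTracker (see base.py above). The path here is illustrative.
model = Segmentation(
    meshnet, checkpoint_filepath="ckpts/checkpoint-epoch_{epoch:03d}"
)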
nobrainer/tests/checkpoint_test.py: 8 additions & 6 deletions
@@ -1,13 +1,15 @@
"""Tests for `nobrainer.processing.checkpoint`."""

from nobrainer.processing.segmentation import Segmentation
from nobrainer.models import meshnet
import os
Contributor suggested change (remove this line):
-import os

import numpy as np
from numpy.testing import assert_allclose
-import os
import pytest
Contributor suggested change (remove this line):
-import pytest
import tensorflow as tf
+
+from nobrainer.models import meshnet
+from nobrainer.processing.segmentation import Segmentation


def _assert_model_weights_allclose(model1, model2):
    for layer1, layer2 in zip(model1.model.layers, model2.model.layers):
@@ -17,17 +19,17 @@ def _assert_model_weights_allclose(model1, model2):
        for index in range(len(weights1)):
            assert_allclose(weights1[index], weights2[index], rtol=1e-06, atol=1e-08)

+
def test_checkpoint(tmp_path):
    data_shape = (8, 8, 8, 8, 1)
    train = tf.data.Dataset.from_tensors(
-        (np.random.rand(*data_shape),
-        np.random.randint(0, 1, data_shape))
+        (np.random.rand(*data_shape), np.random.randint(0, 1, data_shape))
    )
    train.scalar_labels = False
    train.n_volumes = data_shape[0]
    train.volume_shape = data_shape[1:4]

-    checkpoint_filepath = os.path.join(tmp_path, 'checkpoint-epoch_{epoch:03d}')
+    checkpoint_filepath = os.path.join(tmp_path, "checkpoint-epoch_{epoch:03d}")
    model1 = Segmentation(meshnet, checkpoint_filepath=checkpoint_filepath)
    model1.fit(
        dataset_train=train,
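The diff is truncated here. For orientation only, a hedged sketch of how the round trip plausibly concludes, reusing just the names defined above (this is not the actual test body):

# Hypothetical continuation: reload the newest checkpoint and verify the
# restored weights match the trained estimator's.
model2 = Segmentation.load_latest(checkpoint_filepath)
_assert_model_weights_allclose(model1, model2)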