-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/constantinpape/torch-em
- Loading branch information
Showing
17 changed files
with
542 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import os | ||
import unittest | ||
from shutil import rmtree | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from torch_em.util import model_is_equal | ||
|
||
|
||
class TestClassification(unittest.TestCase): | ||
def tearDown(self): | ||
if os.path.exists("./checkpoints"): | ||
rmtree("./checkpoints") | ||
if os.path.exists("./logs"): | ||
rmtree("./logs") | ||
|
||
def _check_checkpoint(self, path, expected_iterations, expected_model, model_class, **model_kwargs): | ||
self.assertTrue(os.path.exists(path)) | ||
checkpoint = torch.load(path) | ||
|
||
self.assertIn("optimizer_state", checkpoint) | ||
self.assertIn("model_state", checkpoint) | ||
|
||
loaded_model = model_class(**model_kwargs) | ||
loaded_model.load_state_dict(checkpoint["model_state"]) | ||
self.assertTrue(model_is_equal(expected_model, loaded_model)) | ||
|
||
self.assertEqual(checkpoint["iteration"], expected_iterations) | ||
|
||
def test_classification_2d(self): | ||
from torch_em.classification import default_classification_loader, default_classification_trainer | ||
from torchvision.models.resnet import resnet18 | ||
|
||
shape = (3, 256, 256) | ||
image_shape = (128, 128) | ||
|
||
n_samples = 15 | ||
data = [np.random.rand(*shape) for _ in range(n_samples)] | ||
|
||
n_classes = 8 | ||
target = np.random.randint(0, n_classes, size=n_samples) | ||
|
||
loader = default_classification_loader(data, target, batch_size=1, image_shape=image_shape) | ||
|
||
model = resnet18(num_classes=n_classes) | ||
trainer = default_classification_trainer( | ||
name="test-model-2d", model=model, train_loader=loader, val_loader=loader, | ||
compile_model=False, | ||
) | ||
n_iterations = 18 | ||
trainer.fit(n_iterations) | ||
|
||
self._check_checkpoint( | ||
"./checkpoints/test-model-2d/latest.pt", 18, trainer.model, resnet18, num_classes=n_classes, | ||
) | ||
|
||
def test_classification_3d(self): | ||
from torch_em.classification import default_classification_loader, default_classification_trainer | ||
from torch_em.model.resnet3d import resnet3d_18 | ||
|
||
shape = (1, 128, 128, 128) | ||
image_shape = (64, 64, 64) | ||
|
||
n_samples = 10 | ||
data = [np.random.rand(*shape) for _ in range(n_samples)] | ||
|
||
n_classes = 8 | ||
target = np.random.randint(0, n_classes, size=n_samples) | ||
|
||
loader = default_classification_loader(data, target, batch_size=1, image_shape=image_shape) | ||
|
||
model = resnet3d_18(in_channels=1, out_channels=n_classes) | ||
trainer = default_classification_trainer( | ||
name="test-model-3d", model=model, train_loader=loader, val_loader=loader, | ||
compile_model=False, | ||
) | ||
trainer.fit(12) | ||
|
||
self._check_checkpoint( | ||
"./checkpoints/test-model-3d/latest.pt", 12, trainer.model, resnet3d_18, | ||
in_channels=1, out_channels=n_classes | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import unittest | ||
from copy import deepcopy | ||
|
||
import numpy as np | ||
import torch | ||
|
||
|
||
class TestRaw(unittest.TestCase): | ||
def _test_standardize(self, input_): | ||
from torch_em.transform.raw import standardize | ||
|
||
def check_out(out): | ||
self.assertEqual(out.shape, input_.shape) | ||
if torch.is_tensor(out): | ||
mean, std = out.mean().numpy(), out.std().numpy() | ||
else: | ||
mean, std = out.mean(), out.std() | ||
self.assertLess(mean, 0.001) | ||
self.assertLess(np.abs(1.0 - std), 0.001) | ||
|
||
# test standardize without arguments | ||
out = standardize(deepcopy(input_)) | ||
check_out(out) | ||
|
||
# test standardize with axis | ||
out = standardize(deepcopy(input_), axis=(1, 2)) | ||
check_out(out) | ||
|
||
# test standardize with fixed mean and std | ||
mean, std = input_.mean(), input_.std() | ||
out = standardize(deepcopy(input_), mean=mean, std=std) | ||
check_out(out) | ||
|
||
def test_standardize_torch(self): | ||
input_ = torch.rand(3, 128, 128) | ||
self._test_standardize(input_) | ||
|
||
def test_standardize_numpy(self): | ||
input_ = np.random.rand(3, 128, 128) | ||
self._test_standardize(input_) | ||
|
||
def _test_normalize(self, input_): | ||
from torch_em.transform.raw import normalize | ||
|
||
def check_out(out): | ||
self.assertEqual(out.shape, input_.shape) | ||
if torch.is_tensor(out): | ||
min_, max_ = out.min().numpy(), out.max().numpy() | ||
else: | ||
min_, max_ = out.min(), out.max() | ||
self.assertLess(min_, 0.001) | ||
self.assertLess(np.abs(1.0 - max_), 0.001) | ||
|
||
# test normalize without arguments | ||
out = normalize(deepcopy(input_)) | ||
check_out(out) | ||
|
||
# test normalize with axis | ||
out = normalize(deepcopy(input_), axis=(1, 2)) | ||
check_out(out) | ||
|
||
# test normalize with fixed min, max | ||
min_, max_ = input_.min(), input_.max() - input_.min() | ||
out = normalize(deepcopy(input_), minval=min_, maxval=max_) | ||
check_out(out) | ||
|
||
def test_normalize_torch(self): | ||
input_ = torch.randn(3, 128, 128) | ||
self._test_normalize(input_) | ||
|
||
def test_normalize_numpy(self): | ||
input_ = np.random.randn(3, 128, 128) | ||
self._test_normalize(input_) | ||
|
||
def _test_normalize_percentile(self, input_): | ||
from torch_em.transform.raw import normalize_percentile | ||
|
||
def check_out(out): | ||
self.assertEqual(out.shape, input_.shape) | ||
|
||
# test normalize without arguments | ||
out = normalize_percentile(deepcopy(input_)) | ||
check_out(out) | ||
|
||
# test normalize with axis | ||
out = normalize_percentile(deepcopy(input_), axis=(1, 2)) | ||
check_out(out) | ||
|
||
# test normalize with percentile arguments | ||
out = normalize_percentile(deepcopy(input_), lower=5.0, upper=95.0) | ||
check_out(out) | ||
|
||
def test_normalize_percentile_torch(self): | ||
input_ = torch.randn(3, 128, 128) | ||
self._test_normalize_percentile(input_) | ||
|
||
def test_normalize_percentile_numpy(self): | ||
input_ = np.random.randn(3, 128, 128) | ||
self._test_normalize_percentile(input_) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "0.4.1" | ||
__version__ = "0.5.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .classification import default_classification_trainer, default_classification_loader | ||
|
||
from .classification_logger import ClassificationLogger | ||
from .classification_trainer import ClassificationTrainer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from functools import partial | ||
|
||
import sklearn.metrics as metrics | ||
import torch | ||
import torch_em | ||
|
||
from .classification_dataset import ClassificationDataset | ||
from .classification_logger import ClassificationLogger | ||
from .classification_trainer import ClassificationTrainer | ||
|
||
|
||
class ClassificationMetric: | ||
def __init__(self, metric_name="accuracy_score", **metric_kwargs): | ||
if not hasattr(metrics, metric_name): | ||
raise ValueError(f"Invalid metric_name {metric_name}") | ||
self.metric = getattr(metrics, metric_name) | ||
self.metric_kwargs = metric_kwargs | ||
|
||
def __call__(self, y_true, y_pred): | ||
metric_error = 1.0 - self.metric(y_true, y_pred, **self.metric_kwargs) | ||
return metric_error | ||
|
||
|
||
def default_classification_loader( | ||
data, target, batch_size, normalization=None, augmentation=None, image_shape=None, **loader_kwargs, | ||
): | ||
ndim = data[0].ndim - 1 | ||
if ndim not in (2, 3): | ||
raise ValueError(f"Expect input data of dimensionality 2 or 3, got {ndim}") | ||
|
||
if normalization is None: | ||
axis = (1, 2) if ndim == 2 else (1, 2, 3) | ||
normalization = partial(torch_em.transform.raw.standardize, axis=axis) | ||
|
||
if augmentation is None: | ||
augmentation = torch_em.transform.get_augmentations(ndim=ndim) | ||
|
||
dataset = ClassificationDataset(data, target, normalization, augmentation, image_shape) | ||
loader = torch_em.segmentation.get_data_loader(dataset, batch_size, **loader_kwargs) | ||
return loader | ||
|
||
|
||
# TODO | ||
def zarr_classification_loader(): | ||
return default_classification_loader() | ||
|
||
|
||
def default_classification_trainer( | ||
name, | ||
model, | ||
train_loader, | ||
val_loader, | ||
loss=None, | ||
metric=None, | ||
logger=ClassificationLogger, | ||
trainer_class=ClassificationTrainer, | ||
**kwargs, | ||
): | ||
""" | ||
""" | ||
# set the default loss and metric (if no values where passed) | ||
loss = torch.nn.CrossEntropyLoss() if loss is None else loss | ||
metric = ClassificationMetric() if metric is None else metric | ||
|
||
# metric: note that we use lower metric = better ! | ||
# so we record the accuracy error instead of the error rate | ||
trainer = torch_em.default_segmentation_trainer( | ||
name, model, train_loader, val_loader, | ||
loss=loss, metric=metric, | ||
logger=logger, trainer_class=trainer_class, | ||
**kwargs, | ||
) | ||
return trainer |
Oops, something went wrong.