Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed May 26, 2023
2 parents a983c19 + a93ff57 commit c2eaed8
Show file tree
Hide file tree
Showing 17 changed files with 542 additions and 43 deletions.
31 changes: 29 additions & 2 deletions test/data/test_segmentation_dataset.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import os
import unittest
from shutil import rmtree

import h5py
import numpy as np
from torch_em.util.test import create_segmentation_test_data


class TestSegmentationDataset(unittest.TestCase):
path = "./data.h5"
tmp_folder = "./tmp"
path = "./tmp/data.h5"

def setUp(self):
os.makedirs(self.tmp_folder, exist_ok=True)

def tearDown(self):
os.remove(self.path)
rmtree(self.tmp_folder)

def create_default_data(self, raw_key, label_key):
shape = (128,) * 3
Expand Down Expand Up @@ -178,6 +183,28 @@ def test_with_raw_and_label_channels(self):
self.assertEqual(x.shape, expected_raw_shape)
self.assertEqual(y.shape, expected_label_shape)

def test_tif(self):
    """Check that SegmentationDataset yields correctly shaped patches from tif files.

    Two random float32 volumes (raw and labels) are written as tif files into
    the temporary test folder and loaded without an internal dataset key.
    """
    import imageio.v3 as imageio
    from torch_em.data import SegmentationDataset

    volume_shape = (128, 128, 128)
    patch_shape = (32, 32, 32)

    raw_path = os.path.join(self.tmp_folder, "raw.tif")
    label_path = os.path.join(self.tmp_folder, "labels.tif")
    # Write the raw volume first and the label volume second.
    for tif_path in (raw_path, label_path):
        volume = np.random.rand(*volume_shape).astype("float32")
        imageio.imwrite(tif_path, volume)

    # The keys are None for both the raw and the label file.
    ds = SegmentationDataset(
        raw_path, None, label_path, None, patch_shape=patch_shape
    )

    # Each sample is expected to carry one leading axis of size 1.
    expected_patch_shape = (1,) + patch_shape
    for index in range(10):
        raw_patch, label_patch = ds[index]
        self.assertEqual(raw_patch.shape, expected_patch_shape)
        self.assertEqual(label_patch.shape, expected_patch_shape)


if __name__ == "__main__":
unittest.main()
87 changes: 87 additions & 0 deletions test/test_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os
import unittest
from shutil import rmtree

import numpy as np
import torch

from torch_em.util import model_is_equal


class TestClassification(unittest.TestCase):
    """End-to-end tests that train small classification models and inspect the saved checkpoint."""

    def tearDown(self):
        """Remove the checkpoint and log folders from the working directory if they exist."""
        for folder in ("./checkpoints", "./logs"):
            if os.path.exists(folder):
                rmtree(folder)

    def _check_checkpoint(self, path, expected_iterations, expected_model, model_class, **model_kwargs):
        """Verify that the checkpoint at `path` exists and matches the trained model.

        Args:
            path: Location of the checkpoint file.
            expected_iterations: Value expected under the checkpoint's "iteration" key.
            expected_model: Model the reloaded weights are compared against.
            model_class: Callable used to build a fresh model that receives the saved state.
            **model_kwargs: Keyword arguments forwarded to `model_class`.
        """
        self.assertTrue(os.path.exists(path))
        # The checkpoint is written by the test itself, so full unpickling is safe here.
        # Recent PyTorch releases default to weights_only=True for torch.load.
        # NOTE(review): presumably the checkpoint holds more than plain tensors - confirm.
        checkpoint = torch.load(path, weights_only=False)

        self.assertIn("optimizer_state", checkpoint)
        self.assertIn("model_state", checkpoint)

        # Rebuild the architecture and check that the stored weights match the trained model.
        loaded_model = model_class(**model_kwargs)
        loaded_model.load_state_dict(checkpoint["model_state"])
        self.assertTrue(model_is_equal(expected_model, loaded_model))

        self.assertEqual(checkpoint["iteration"], expected_iterations)

    def test_classification_2d(self):
        """Train a resnet18 on random 2d data and check the latest checkpoint."""
        from torch_em.classification import default_classification_loader, default_classification_trainer
        from torchvision.models.resnet import resnet18

        shape = (3, 256, 256)
        image_shape = (128, 128)

        n_samples = 15
        data = [np.random.rand(*shape) for _ in range(n_samples)]

        n_classes = 8
        target = np.random.randint(0, n_classes, size=n_samples)

        loader = default_classification_loader(data, target, batch_size=1, image_shape=image_shape)

        model = resnet18(num_classes=n_classes)
        trainer = default_classification_trainer(
            name="test-model-2d", model=model, train_loader=loader, val_loader=loader,
            compile_model=False,
        )
        # Use one variable for training and for the check so the two cannot drift apart.
        n_iterations = 18
        trainer.fit(n_iterations)

        self._check_checkpoint(
            "./checkpoints/test-model-2d/latest.pt", n_iterations, trainer.model, resnet18, num_classes=n_classes,
        )

    def test_classification_3d(self):
        """Train a resnet3d_18 on random 3d data and check the latest checkpoint."""
        from torch_em.classification import default_classification_loader, default_classification_trainer
        from torch_em.model.resnet3d import resnet3d_18

        shape = (1, 128, 128, 128)
        image_shape = (64, 64, 64)

        n_samples = 10
        data = [np.random.rand(*shape) for _ in range(n_samples)]

        n_classes = 8
        target = np.random.randint(0, n_classes, size=n_samples)

        loader = default_classification_loader(data, target, batch_size=1, image_shape=image_shape)

        model = resnet3d_18(in_channels=1, out_channels=n_classes)
        trainer = default_classification_trainer(
            name="test-model-3d", model=model, train_loader=loader, val_loader=loader,
            compile_model=False,
        )
        # Same pattern as the 2d test: a single source of truth for the iteration count.
        n_iterations = 12
        trainer.fit(n_iterations)

        self._check_checkpoint(
            "./checkpoints/test-model-3d/latest.pt", n_iterations, trainer.model, resnet3d_18,
            in_channels=1, out_channels=n_classes
        )


if __name__ == "__main__":
unittest.main()
19 changes: 11 additions & 8 deletions test/transform/test_generic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import numpy
import unittest

import numpy as np
import torch

from torch_em.transform import Tile


from unittest import TestCase


class TestTile(TestCase):
class TestTile(unittest.TestCase):
def test_tile(self):
for ndim, reps in [(1, (4, 2)), (2, (4, 2)), (3, (4, 2))]:
with self.subTest():
Expand All @@ -17,11 +16,11 @@ def test_tile(self):
def _test_tile_impl(ndim, reps):
tile_aug = Tile(reps, match_shape_exactly=len(reps) == ndim)
test_shape = [2, 3, 4][:ndim]
data = numpy.random.random(test_shape)
data = np.random.random(test_shape)

x = torch.tensor(data)

expected = numpy.tile(x.numpy(), reps)
expected = np.tile(x.numpy(), reps)
if len(reps) == ndim:
expected_torch = x.repeat(*reps)
assert expected.shape == expected_torch.shape
Expand All @@ -30,7 +29,11 @@ def _test_tile_impl(ndim, reps):

assert actual.shape == expected.shape

a = numpy.array(data)
a = np.array(data)

actual = tile_aug(a)
assert actual.shape == expected.shape


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion test/transform/test_label_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,5 @@ def test_distance_transform_empty_labels(self):
self.assertTrue(np.allclose(tnew, 1.0))


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
103 changes: 103 additions & 0 deletions test/transform/test_raw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import unittest
from copy import deepcopy

import numpy as np
import torch


class TestRaw(unittest.TestCase):
    """Tests for `standardize`, `normalize` and `normalize_percentile` from `torch_em.transform.raw`.

    Every transform is exercised with a numpy array and with a torch tensor.
    """

    def _check_standardized(self, out, expected_shape):
        """Assert that `out` has `expected_shape`, a mean close to 0 and a std close to 1."""
        self.assertEqual(out.shape, expected_shape)
        if torch.is_tensor(out):
            mean, std = out.mean().numpy(), out.std().numpy()
        else:
            mean, std = out.mean(), out.std()
        # Compare the absolute value: a strongly negative mean must not pass the check.
        self.assertLess(np.abs(mean), 0.001)
        self.assertLess(np.abs(1.0 - std), 0.001)

    def _check_normalized(self, out, expected_shape):
        """Assert that `out` has `expected_shape`, a minimum close to 0 and a maximum close to 1."""
        self.assertEqual(out.shape, expected_shape)
        if torch.is_tensor(out):
            min_, max_ = out.min().numpy(), out.max().numpy()
        else:
            min_, max_ = out.min(), out.max()
        # Compare the absolute value: a strongly negative minimum must not pass the check.
        self.assertLess(np.abs(min_), 0.001)
        self.assertLess(np.abs(1.0 - max_), 0.001)

    def _test_standardize(self, input_):
        """Run `standardize` with default, per-axis and fixed-statistics arguments."""
        from torch_em.transform.raw import standardize

        # test standardize without arguments
        out = standardize(deepcopy(input_))
        self._check_standardized(out, input_.shape)

        # test standardize with axis
        out = standardize(deepcopy(input_), axis=(1, 2))
        self._check_standardized(out, input_.shape)

        # test standardize with fixed mean and std
        mean, std = input_.mean(), input_.std()
        out = standardize(deepcopy(input_), mean=mean, std=std)
        self._check_standardized(out, input_.shape)

    def test_standardize_torch(self):
        input_ = torch.rand(3, 128, 128)
        self._test_standardize(input_)

    def test_standardize_numpy(self):
        input_ = np.random.rand(3, 128, 128)
        self._test_standardize(input_)

    def _test_normalize(self, input_):
        """Run `normalize` with default, per-axis and fixed min/max arguments."""
        from torch_em.transform.raw import normalize

        # test normalize without arguments
        out = normalize(deepcopy(input_))
        self._check_normalized(out, input_.shape)

        # test normalize with axis
        out = normalize(deepcopy(input_), axis=(1, 2))
        self._check_normalized(out, input_.shape)

        # test normalize with fixed min, max
        # NOTE(review): maxval is passed as the value range (max - min); presumably
        # normalize divides by maxval after subtracting minval - confirm.
        min_, max_ = input_.min(), input_.max() - input_.min()
        out = normalize(deepcopy(input_), minval=min_, maxval=max_)
        self._check_normalized(out, input_.shape)

    def test_normalize_torch(self):
        input_ = torch.randn(3, 128, 128)
        self._test_normalize(input_)

    def test_normalize_numpy(self):
        input_ = np.random.randn(3, 128, 128)
        self._test_normalize(input_)

    def _test_normalize_percentile(self, input_):
        """Run `normalize_percentile` with several argument sets; only the output shape is checked."""
        from torch_em.transform.raw import normalize_percentile

        def check_out(out):
            self.assertEqual(out.shape, input_.shape)

        # test normalize without arguments
        out = normalize_percentile(deepcopy(input_))
        check_out(out)

        # test normalize with axis
        out = normalize_percentile(deepcopy(input_), axis=(1, 2))
        check_out(out)

        # test normalize with percentile arguments
        out = normalize_percentile(deepcopy(input_), lower=5.0, upper=95.0)
        check_out(out)

    def test_normalize_percentile_torch(self):
        input_ = torch.randn(3, 128, 128)
        self._test_normalize_percentile(input_)

    def test_normalize_percentile_numpy(self):
        input_ = np.random.randn(3, 128, 128)
        self._test_normalize_percentile(input_)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion torch_em/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.1"
__version__ = "0.5.0"
4 changes: 4 additions & 0 deletions torch_em/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .classification import default_classification_trainer, default_classification_loader

from .classification_logger import ClassificationLogger
from .classification_trainer import ClassificationTrainer
73 changes: 73 additions & 0 deletions torch_em/classification/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from functools import partial

import sklearn.metrics as metrics
import torch
import torch_em

from .classification_dataset import ClassificationDataset
from .classification_logger import ClassificationLogger
from .classification_trainer import ClassificationTrainer


class ClassificationMetric:
    """Turn a function from `sklearn.metrics` into an error measure (1 - score).

    Args:
        metric_name: Name of the attribute to look up in `sklearn.metrics`.
        **metric_kwargs: Extra keyword arguments passed to the metric on every call.

    Raises:
        ValueError: If `sklearn.metrics` has no attribute called `metric_name`.
    """

    def __init__(self, metric_name="accuracy_score", **metric_kwargs):
        metric_exists = hasattr(metrics, metric_name)
        if not metric_exists:
            raise ValueError(f"Invalid metric_name {metric_name}")
        self.metric_kwargs = metric_kwargs
        self.metric = getattr(metrics, metric_name)

    def __call__(self, y_true, y_pred):
        """Return one minus the metric evaluated on `y_true` and `y_pred`."""
        score = self.metric(y_true, y_pred, **self.metric_kwargs)
        return 1.0 - score


def default_classification_loader(
    data, target, batch_size, normalization=None, augmentation=None, image_shape=None, **loader_kwargs,
):
    """Build a data loader for classification from in-memory samples.

    Args:
        data: Sequence of arrays; the first element's `ndim` minus one is taken
            as the spatial dimensionality (presumably the first axis holds channels).
        target: Targets passed on to `ClassificationDataset`.
        batch_size: Batch size passed on to the data loader.
        normalization: Normalization callable. Defaults to
            `torch_em.transform.raw.standardize` applied over all axes but the first.
        augmentation: Augmentation callable. Defaults to
            `torch_em.transform.get_augmentations` for the detected dimensionality.
        image_shape: Passed on to `ClassificationDataset`.
        **loader_kwargs: Extra keyword arguments for `torch_em.segmentation.get_data_loader`.

    Returns:
        The loader created by `torch_em.segmentation.get_data_loader`.

    Raises:
        ValueError: If the detected dimensionality is neither 2 nor 3.
    """
    spatial_ndim = data[0].ndim - 1
    if spatial_ndim != 2 and spatial_ndim != 3:
        raise ValueError(f"Expect input data of dimensionality 2 or 3, got {spatial_ndim}")

    if normalization is None:
        # Every axis except the first one: (1, 2) in 2d and (1, 2, 3) in 3d.
        spatial_axes = tuple(range(1, spatial_ndim + 1))
        normalization = partial(torch_em.transform.raw.standardize, axis=spatial_axes)

    if augmentation is None:
        augmentation = torch_em.transform.get_augmentations(ndim=spatial_ndim)

    dataset = ClassificationDataset(data, target, normalization, augmentation, image_shape)
    return torch_em.segmentation.get_data_loader(dataset, batch_size, **loader_kwargs)


# TODO implement (the name suggests a classification loader for zarr-backed data)
def zarr_classification_loader():
    """Placeholder for a classification loader that is not implemented yet.

    Raises:
        NotImplementedError: Always. This states the missing implementation
            explicitly instead of failing with a TypeError from calling
            `default_classification_loader` without its required arguments.
    """
    raise NotImplementedError("zarr_classification_loader is not implemented yet")


def default_classification_trainer(
    name,
    model,
    train_loader,
    val_loader,
    loss=None,
    metric=None,
    logger=ClassificationLogger,
    trainer_class=ClassificationTrainer,
    **kwargs,
):
    """Create a trainer for a classification model.

    Fills in a default loss and metric when none are given and delegates the
    construction to `torch_em.default_segmentation_trainer`.

    Args:
        name: Forwarded to `torch_em.default_segmentation_trainer`.
        model: Forwarded to `torch_em.default_segmentation_trainer`.
        train_loader: Forwarded to `torch_em.default_segmentation_trainer`.
        val_loader: Forwarded to `torch_em.default_segmentation_trainer`.
        loss: Loss to use. Defaults to `torch.nn.CrossEntropyLoss()`.
        metric: Metric to use. Defaults to `ClassificationMetric()`.
        logger: Forwarded as `logger`. Defaults to `ClassificationLogger`.
        trainer_class: Forwarded as `trainer_class`. Defaults to `ClassificationTrainer`.
        **kwargs: Additional keyword arguments forwarded unchanged.

    Returns:
        The object returned by `torch_em.default_segmentation_trainer`.
    """
    # Fall back to the defaults if no loss or metric was passed.
    if loss is None:
        loss = torch.nn.CrossEntropyLoss()
    if metric is None:
        # ClassificationMetric reports 1 - score, so a lower metric value is better.
        metric = ClassificationMetric()

    return torch_em.default_segmentation_trainer(
        name, model, train_loader, val_loader,
        loss=loss, metric=metric,
        logger=logger, trainer_class=trainer_class,
        **kwargs,
    )
Loading

0 comments on commit c2eaed8

Please sign in to comment.