Skip to content

Commit

Permalink
Implement remaining classification functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed May 21, 2023
1 parent d47e7e0 commit 91922b8
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 10 deletions.
2 changes: 2 additions & 0 deletions torch_em/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .classification import default_classification_trainer, default_classification_loader

from .classification_logger import ClassificationLogger
from .classification_trainer import ClassificationTrainer
73 changes: 73 additions & 0 deletions torch_em/classification/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from functools import partial

import sklearn.metrics as metrics
import torch
import torch_em

from .classification_dataset import ClassificationDataset
from .classification_logger import ClassificationLogger
from .classification_trainer import ClassificationTrainer


class ClassificationMetric:
    """Wrap a score function from `sklearn.metrics` as an error measure.

    The wrapped score is turned into `1.0 - score`, so that a lower value
    means a better result.

    Args:
        metric_name: Name of the function to look up in `sklearn.metrics`.
        **metric_kwargs: Extra keyword arguments passed to the score function
            on every call.

    Raises:
        ValueError: If `sklearn.metrics` has no attribute called `metric_name`.
    """

    def __init__(self, metric_name="accuracy_score", **metric_kwargs):
        # Look the function up in a single step; a missing attribute is
        # reported as an invalid metric name.
        try:
            score_function = getattr(metrics, metric_name)
        except AttributeError:
            raise ValueError(f"Invalid metric_name {metric_name}") from None
        self.metric = score_function
        self.metric_kwargs = metric_kwargs

    def __call__(self, y_true, y_pred):
        """Return `1.0 - score` for the given true and predicted labels."""
        score = self.metric(y_true, y_pred, **self.metric_kwargs)
        return 1.0 - score


def default_classification_loader(
    data, target, batch_size, normalization=None, augmentation=None, image_shape=None, **loader_kwargs,
):
    """Create a data loader for classification from samples and their targets.

    Args:
        data: Indexable collection of samples. The dimensionality is derived
            from the first sample as `data[0].ndim - 1`, so the samples
            presumably carry one leading channel axis (TODO confirm).
        target: The targets, one per sample in `data`.
        batch_size: Batch size handed to `torch_em.segmentation.get_data_loader`.
        normalization: Callable applied to each sample. Defaults to
            `torch_em.transform.raw.standardize` over all axes but the first.
        augmentation: Callable applied to each sample after normalization.
            Defaults to `torch_em.transform.get_augmentations` for the
            derived dimensionality.
        image_shape: Optional shape every sample is resized to by the dataset.
        **loader_kwargs: Forwarded to `torch_em.segmentation.get_data_loader`.

    Returns:
        The loader created by `torch_em.segmentation.get_data_loader`.

    Raises:
        ValueError: If the derived dimensionality is neither 2 nor 3.
    """
    spatial_ndim = data[0].ndim - 1
    if spatial_ndim != 2 and spatial_ndim != 3:
        raise ValueError(f"Expect input data of dimensionality 2 or 3, got {spatial_ndim}")

    if normalization is None:
        # Standardize over every axis except the leading one: (1, 2) or (1, 2, 3).
        standardize_axes = tuple(range(1, spatial_ndim + 1))
        normalization = partial(torch_em.transform.raw.standardize, axis=standardize_axes)

    if augmentation is None:
        augmentation = torch_em.transform.get_augmentations(ndim=spatial_ndim)

    return torch_em.segmentation.get_data_loader(
        ClassificationDataset(data, target, normalization, augmentation, image_shape),
        batch_size,
        **loader_kwargs,
    )


# TODO
def zarr_classification_loader():
    """Placeholder for a classification loader that reads from zarr (not implemented).

    This function fails explicitly rather than forwarding an argument-less
    call to `default_classification_loader`, which could only end in an
    unrelated `TypeError` about missing positional arguments.

    Raises:
        NotImplementedError: Always, until the loader is implemented.
    """
    raise NotImplementedError(
        "zarr_classification_loader is not implemented yet; "
        "use default_classification_loader instead."
    )


def default_classification_trainer(
    name,
    model,
    train_loader,
    val_loader,
    loss=None,
    metric=None,
    logger=ClassificationLogger,
    trainer_class=ClassificationTrainer,
    **kwargs,
):
    """Create a trainer for classification with classification defaults.

    This is a thin wrapper around `torch_em.default_segmentation_trainer`
    that fills in the loss, metric, logger and trainer class.

    Args:
        name: Passed through to `torch_em.default_segmentation_trainer`.
        model: Passed through to `torch_em.default_segmentation_trainer`.
        train_loader: Passed through to `torch_em.default_segmentation_trainer`.
        val_loader: Passed through to `torch_em.default_segmentation_trainer`.
        loss: The loss function. Defaults to `torch.nn.CrossEntropyLoss()`.
        metric: The validation metric. Defaults to `ClassificationMetric()`.
        logger: The logger class, by default `ClassificationLogger`.
        trainer_class: The trainer class, by default `ClassificationTrainer`.
        **kwargs: Forwarded to `torch_em.default_segmentation_trainer`.

    Returns:
        The trainer returned by `torch_em.default_segmentation_trainer`.
    """
    # Fall back to the defaults when no loss or metric was passed.
    if loss is None:
        loss = torch.nn.CrossEntropyLoss()
    # The metric follows the convention lower = better, so the default
    # records the accuracy error (1 - accuracy) instead of the accuracy.
    if metric is None:
        metric = ClassificationMetric()

    return torch_em.default_segmentation_trainer(
        name,
        model,
        train_loader,
        val_loader,
        loss=loss,
        metric=metric,
        logger=logger,
        trainer_class=trainer_class,
        **kwargs,
    )
41 changes: 41 additions & 0 deletions torch_em/classification/classification_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
import torch
from skimage.transform import resize


class ClassificationDataset(torch.utils.data.Dataset):
    """Dataset that yields `(sample, target)` pairs for classification.

    Each sample is optionally normalized, resized and augmented, in this
    order, when it is accessed.

    Args:
        data: Indexable collection of samples.
        target: Indexable collection of targets; must have the same length as `data`.
        normalization: Callable applied to each sample, or None to skip.
        augmentation: Callable applied to each sample, or None to skip. Its
            result is indexed with `[0][0]` to get the augmented sample back.
        image_shape: Shape every channel is resized to, or None to skip.

    Raises:
        ValueError: If `data` and `target` differ in length.
    """

    def __init__(self, data, target, normalization, augmentation, image_shape):
        n_samples, n_targets = len(data), len(target)
        if n_samples != n_targets:
            raise ValueError(f"Length of data and target don't agree: {n_samples} != {n_targets}")
        self.data = data
        self.target = target
        self.normalization = normalization
        self.augmentation = augmentation
        self.image_shape = image_shape

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def resize(self, x):
        """Resize each entry along the first axis of `x` to `self.image_shape`."""
        resized_channels = []
        for channel in x:
            resized = resize(channel, self.image_shape, preserve_range=True)
            # Restore the leading axis so the channels can be joined again.
            resized_channels.append(resized[np.newaxis])
        return np.concatenate(resized_channels, axis=0)

    def __getitem__(self, index):
        """Return the processed sample and its target for `index`."""
        sample = self.data[index]
        label = self.target[index]

        if self.normalization is not None:
            sample = self.normalization(sample)

        # Bring the sample to the requested shape, if one was given.
        if self.image_shape is not None:
            sample = self.resize(sample)

        if self.augmentation is None:
            return sample, label

        expected_shape = sample.shape
        # The augmentation adds an unwanted batch axis, which is removed again here.
        sample = self.augmentation(sample)[0][0]
        assert sample.shape == expected_shape
        return sample, label
24 changes: 14 additions & 10 deletions torch_em/classification/classification_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@ def confusion_matrix(y_true, y_pred, class_labels=None, title=None, save_path=No
# TODO normalization and stuff
# TODO get the class names
def make_grid(images, target=None, prediction=None, images_per_row=8, **kwargs):
assert images.ndim == 4
assert images.shape[1] in (1, 3)
assert images.ndim in (4, 5)
assert images.shape[1] in (1, 3), f"{images.shape}"

if images.ndim == 5:
is_3d = True
z = images.shape[2] // 2
else:
is_3d = False

n_images = images.shape[0]
n_rows = n_images // images_per_row
Expand All @@ -55,9 +61,11 @@ def make_grid(images, target=None, prediction=None, images_per_row=8, **kwargs):
for r in range(n_rows):
for c in range(images_per_row):
i = r * images_per_row + c
ax = axes[r, c]
if i == len(images):
break
ax = axes[r, c] if n_rows > 1 else axes[r]
ax.set_axis_off()
im = images[i]
im = images[i, :, z] if is_3d else images[i]
im = im.transpose((1, 2, 0))
if im.shape[-1] == 3: # rgb
ax.imshow(im)
Expand Down Expand Up @@ -95,12 +103,8 @@ def __init__(self, trainer, save_root, **unused_kwargs):

def add_image(self, x, y, pred, name, step):
scale_each = False
marker = make_grid(x[:, 0:1], y, pred, padding=4, normalize=True, scale_each=scale_each)
self.tb.add_image(tag=f"{name}/marker", img_tensor=marker, global_step=step)
nucleus = make_grid(x[:, 1:2], padding=4, normalize=True, scale_each=scale_each)
self.tb.add_image(tag=f"{name}/nucleus", img_tensor=nucleus, global_step=step)
mask = make_grid(x[:, 2:], padding=4)
self.tb.add_image(tag=f"{name}/mask", img_tensor=mask, global_step=step)
grid = make_grid(x, y, pred, padding=4, normalize=True, scale_each=scale_each)
self.tb.add_image(tag=f"{name}/images_and_predictions", img_tensor=grid, global_step=step)

def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
Expand Down

0 comments on commit 91922b8

Please sign in to comment.