-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement remaining classification functionality
- Loading branch information
1 parent
d47e7e0
commit 91922b8
Showing
4 changed files
with
130 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
from .classification import default_classification_trainer, default_classification_loader | ||
|
||
from .classification_logger import ClassificationLogger | ||
from .classification_trainer import ClassificationTrainer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from functools import partial | ||
|
||
import sklearn.metrics as metrics | ||
import torch | ||
import torch_em | ||
|
||
from .classification_dataset import ClassificationDataset | ||
from .classification_logger import ClassificationLogger | ||
from .classification_trainer import ClassificationTrainer | ||
|
||
|
||
class ClassificationMetric: | ||
def __init__(self, metric_name="accuracy_score", **metric_kwargs): | ||
if not hasattr(metrics, metric_name): | ||
raise ValueError(f"Invalid metric_name {metric_name}") | ||
self.metric = getattr(metrics, metric_name) | ||
self.metric_kwargs = metric_kwargs | ||
|
||
def __call__(self, y_true, y_pred): | ||
metric_error = 1.0 - self.metric(y_true, y_pred, **self.metric_kwargs) | ||
return metric_error | ||
|
||
|
||
def default_classification_loader( | ||
data, target, batch_size, normalization=None, augmentation=None, image_shape=None, **loader_kwargs, | ||
): | ||
ndim = data[0].ndim - 1 | ||
if ndim not in (2, 3): | ||
raise ValueError(f"Expect input data of dimensionality 2 or 3, got {ndim}") | ||
|
||
if normalization is None: | ||
axis = (1, 2) if ndim == 2 else (1, 2, 3) | ||
normalization = partial(torch_em.transform.raw.standardize, axis=axis) | ||
|
||
if augmentation is None: | ||
augmentation = torch_em.transform.get_augmentations(ndim=ndim) | ||
|
||
dataset = ClassificationDataset(data, target, normalization, augmentation, image_shape) | ||
loader = torch_em.segmentation.get_data_loader(dataset, batch_size, **loader_kwargs) | ||
return loader | ||
|
||
|
||
# TODO | ||
def zarr_classification_loader(): | ||
return default_classification_loader() | ||
|
||
|
||
def default_classification_trainer( | ||
name, | ||
model, | ||
train_loader, | ||
val_loader, | ||
loss=None, | ||
metric=None, | ||
logger=ClassificationLogger, | ||
trainer_class=ClassificationTrainer, | ||
**kwargs, | ||
): | ||
""" | ||
""" | ||
# set the default loss and metric (if no values where passed) | ||
loss = torch.nn.CrossEntropyLoss() if loss is None else loss | ||
metric = ClassificationMetric() if metric is None else metric | ||
|
||
# metric: note that we use lower metric = better ! | ||
# so we record the accuracy error instead of the error rate | ||
trainer = torch_em.default_segmentation_trainer( | ||
name, model, train_loader, val_loader, | ||
loss=loss, metric=metric, | ||
logger=logger, trainer_class=trainer_class, | ||
**kwargs, | ||
) | ||
return trainer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import numpy as np | ||
import torch | ||
from skimage.transform import resize | ||
|
||
|
||
class ClassificationDataset(torch.utils.data.Dataset): | ||
def __init__(self, data, target, normalization, augmentation, image_shape): | ||
if len(data) != len(target): | ||
raise ValueError(f"Length of data and target don't agree: {len(data)} != {len(target)}") | ||
self.data = data | ||
self.target = target | ||
self.normalization = normalization | ||
self.augmentation = augmentation | ||
self.image_shape = image_shape | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
def resize(self, x): | ||
out = [resize(channel, self.image_shape, preserve_range=True)[None] for channel in x] | ||
return np.concatenate(out, axis=0) | ||
|
||
def __getitem__(self, index): | ||
x, y = self.data[index], self.target[index] | ||
|
||
# apply normalization | ||
if self.normalization is not None: | ||
x = self.normalization(x) | ||
|
||
# resize to sample shape if it was given | ||
if self.image_shape is not None: | ||
x = self.resize(x) | ||
|
||
# apply augmentations (if any) | ||
if self.augmentation is not None: | ||
_shape = x.shape | ||
# adds unwanted batch axis | ||
x = self.augmentation(x)[0][0] | ||
assert x.shape == _shape | ||
|
||
return x, y |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters