add worst_tiles sampling #77

Merged · merged 3 commits on Jul 21, 2022
87 changes: 87 additions & 0 deletions torch_em/shallow2deep/prepare_shallow2deep.py
@@ -1,11 +1,14 @@
import os
import copy
import pickle
from concurrent import futures
from glob import glob
from functools import partial

import numpy as np
import torch_em
from scipy.ndimage import gaussian_filter, convolve
from skimage.feature import peak_local_max
from sklearn.ensemble import RandomForestClassifier
from torch_em.segmentation import check_paths, is_segmentation_dataset, samples_to_datasets
from tqdm import tqdm
@@ -504,11 +507,88 @@ def random_points(
    return features[samples], labels[samples]


def worst_tiles(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    img_shape,
    tiles_shape=[25, 25],
    smoothing_sigma=None,
    accumulate_samples=True,
):
    # check inputs
    ndim = len(img_shape)
    assert ndim in [2, 3], img_shape
    assert len(tiles_shape) == ndim, tiles_shape

    # get the corresponding random forest from the last stage
    # and predict with it
    last_forest = forests[rf_id - forests_per_stage]
    pred = last_forest.predict_proba(features)

    # labels to one-hot encoding
    unique, inverse = np.unique(labels, return_inverse=True)
    onehot = np.eye(unique.shape[0])[inverse]

    # compute the difference between labels and prediction
    diff = np.abs(onehot - pred)
    assert len(diff) == len(features)

    # reshape diff to image shape
    diff_img = diff.reshape(img_shape + (-1,))

    # sample in a class balanced way
    nc = len(np.unique(labels))
    n_samples_class = int(sample_fraction_per_stage * len(features)) // nc
    samples = []
    for class_id in range(nc):
        # smooth either with gaussian or 1-kernel
        if smoothing_sigma:
            diff_img_smooth = gaussian_filter(diff_img[..., class_id], smoothing_sigma, mode='constant')
        else:
            kernel = np.ones(tiles_shape)
            diff_img_smooth = convolve(diff_img[..., class_id], kernel, mode='constant')

        # get training samples based on tiles around maxima of the label-prediction diff
        # do this in a class-specific way to ensure that each class is sampled
        # get maxima of the label-prediction diff (they seem to be sorted already)
        max_centers = peak_local_max(
            diff_img_smooth,
            min_distance=max(tiles_shape),
            exclude_border=tuple([s // 2 for s in tiles_shape])
        )

        # get indices of tiles around maxima
        tiles = []
        for center in max_centers:
            tile_slice = tuple([slice(center[d]-tiles_shape[d]//2,
                                      center[d]+tiles_shape[d]//2 + 1, None) for d in range(ndim)])
            grid = np.mgrid[tile_slice]
            samples_in_tile = grid.reshape(ndim, -1)
            samples_in_tile = np.ravel_multi_index(samples_in_tile, img_shape)
            tiles.append(samples_in_tile)
        tiles = np.concatenate(tiles)

        # take samples that belong to the current class
        this_samples = tiles[labels[tiles] == class_id][:n_samples_class]
        samples.append(this_samples)
    samples = np.concatenate(samples)

    # get the features and labels, add from previous rf if specified
    features, labels = features[samples], labels[samples]
    if accumulate_samples:
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels


SAMPLING_STRATEGIES = {
    "random_points": random_points,
    "uncertain_points": uncertain_points,
    "uncertain_worst_points": uncertain_worst_points,
    "worst_points": worst_points,
    "worst_tiles": worst_tiles,
}


@@ -526,6 +606,7 @@ def prepare_shallow2deep_advanced(
    forests_per_stage,
    sample_fraction_per_stage,
    sampling_strategy="worst_points",
    sampling_kwargs={},
    raw_transform=None,
    label_transform=None,
    rois=None,
@@ -576,6 +657,11 @@ def _train_rf(rf_id):
        raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze()
        assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}"

        # monkey patch original shape to sampling_kwargs
        # deepcopy needed due to multithreading
        current_kwargs = copy.deepcopy(sampling_kwargs)
        current_kwargs['img_shape'] = raw.shape

        # only balance samples for the first (densely trained) rfs
        features, labels = _get_features_and_labels(
            raw, labels, filters_and_sigmas, balance_labels=False
@@ -585,6 +671,7 @@ def _train_rf(rf_id):
                features, labels, rf_id,
                forests, forests_per_stage,
                sample_fraction_per_stage,
                **current_kwargs,
            )
        else:  # sample randomly
            features, labels = random_points(
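
For reference, a minimal sketch (not part of the PR) of exercising the new `worst_tiles` strategy directly on synthetic 2d data; within the PR it is selected through `prepare_shallow2deep_advanced` via `sampling_strategy="worst_tiles"`, with extra options forwarded through `sampling_kwargs`. The array shapes, the stand-in forest `rf`, and all numeric values below are illustrative assumptions:

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from torch_em.shallow2deep.prepare_shallow2deep import worst_tiles

# synthetic per-pixel features and dense labels for a 128 x 128 image (illustrative)
img_shape = (128, 128)
n_pixels = img_shape[0] * img_shape[1]
features = np.random.rand(n_pixels, 17)
labels = np.random.randint(0, 2, size=n_pixels)

# a forest standing in for the previous stage; worst_tiles reads its
# train_features / train_labels when accumulate_samples=True (the default)
rf = RandomForestClassifier(n_estimators=50).fit(features, labels)
rf.train_features, rf.train_labels = features, labels

new_features, new_labels = worst_tiles(
    features, labels, rf_id=1,
    forests=[rf], forests_per_stage=1,
    sample_fraction_per_stage=0.05,
    img_shape=img_shape,
    tiles_shape=[25, 25],
)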