def balanced_dense_accumulate(
    features, labels, rf_id,
    forests, forests_per_stage,
    sample_fraction_per_stage,
    accumulate_samples=True,
    **kwargs,
):
    """Draw a class-balanced subsample and optionally accumulate earlier samples.

    Every class present in ``labels`` contributes the same number of samples,
    namely the size of the smallest class: the minority class is taken in
    full and the other classes are subsampled at random without replacement.
    The sampled rows are emitted class by class, in ascending class-id order.

    This strategy is registered in ``SAMPLING_STRATEGIES`` under the key
    ``"balanced_dense_accumulate"``.

    Args:
        features: Array indexed along its first axis by sample.
        labels: Array of class ids, one per sample. The ids do not need to be
            contiguous or start at zero.
        rf_id: Index of the forest that is currently being trained.
        forests: Sequence of previously trained forests. Only
            ``forests[rf_id - forests_per_stage]`` is read, and only its
            ``train_features`` and ``train_labels`` attributes.
        forests_per_stage: Number of forests per stage; accumulation starts
            once ``rf_id >= forests_per_stage``.
        sample_fraction_per_stage: Unused here; presumably accepted so that
            all sampling strategies share one call signature.
        accumulate_samples: If True, prepend the training samples of the
            corresponding forest from the previous stage to the new samples.
        **kwargs: Ignored.

    Returns:
        Tuple ``(features, labels)`` with the selected (and, if requested,
        accumulated) samples. Empty input yields an empty selection.
    """
    # A single np.unique call provides both the ids that actually occur and
    # their counts. Iterating over these ids (instead of range(n_classes))
    # keeps the sampling correct for label sets like {1, 2} or {0, 2}.
    class_ids, class_counts = np.unique(labels, return_counts=True)

    if class_ids.size == 0:
        # Nothing to sample from: select no rows instead of failing on min().
        samples = np.zeros(0, dtype=np.intp)
    else:
        n_samples_class = class_counts.min()
        # Each class has at least n_samples_class members by construction,
        # so sampling without replacement is always possible.
        samples = np.concatenate([
            np.random.choice(
                np.where(labels == class_id)[0], size=n_samples_class, replace=False
            )
            for class_id in class_ids
        ])
    features, labels = features[samples], labels[samples]

    # Accumulate: carry over the training set of the matching forest from the
    # previous stage, placing the old samples before the new ones.
    if accumulate_samples and rf_id >= forests_per_stage:
        last_forest = forests[rf_id - forests_per_stage]
        features = np.concatenate([last_forest.train_features, features], axis=0)
        labels = np.concatenate([last_forest.train_labels, labels], axis=0)

    return features, labels