Skip to content

Commit

Permalink
Merge pull request #107 from JonasHell/rf_sampling_strats
Browse files Browse the repository at this point in the history
add dense accumulate as additional baseline
  • Loading branch information
constantinpape committed Feb 22, 2023
2 parents 5544225 + 65aa38f commit b6efdcd
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions torch_em/shallow2deep/prepare_shallow2deep.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,12 +639,44 @@ def worst_tiles(
return features, labels


def balanced_dense_accumulate(
features, labels, rf_id,
forests, forests_per_stage,
sample_fraction_per_stage,
accumulate_samples=True,
**kwargs,
):
samples = []
nc = len(np.unique(labels))
# sample in a class balanced way
# take all pixels from minority class
# and choose same amount from other classes randomly
n_samples_class = np.unique(labels, return_counts=True)[1].min()
for class_id in range(nc):
class_indices = np.where(labels == class_id)[0]
this_samples = np.random.choice(
class_indices, size=n_samples_class, replace=len(class_indices) < n_samples_class
)
samples.append(this_samples)
samples = np.concatenate(samples)
features, labels = features[samples], labels[samples]

# accumulate
if accumulate_samples and rf_id >= forests_per_stage:
last_forest = forests[rf_id - forests_per_stage]
features = np.concatenate([last_forest.train_features, features], axis=0)
labels = np.concatenate([last_forest.train_labels, labels], axis=0)

return features, labels


SAMPLING_STRATEGIES = {
"random_points": random_points,
"uncertain_points": uncertain_points,
"uncertain_worst_points": uncertain_worst_points,
"worst_points": worst_points,
"worst_tiles": worst_tiles,
"balanced_dense_accumulate": balanced_dense_accumulate,
}


Expand Down

0 comments on commit b6efdcd

Please sign in to comment.