Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement advanced rf training for s2d #63

Merged
merged 9 commits into from
Jul 5, 2022
Merged

Conversation

constantinpape
Copy link
Owner

Hey @JonasHell @k-dominik, @akreshuk
I have implemented a training routine for random forests for shallow 2 deep here that trains RFs in stages, and selects the training examples for each stage by taking the worst predictions from forests from the previous stage, similar to what we discussed during the retreat. This is implemented with prepare_shallow2deep_advanced by adding a sampling_strategy parameter that enables customizing the sampling of random forest training samples. (And using the strategy described above by default).

I will train a RF for mito segmentation based on this shortly and share it, to see if it improves results. And if you have other sampling strategies (e.g. sampling scribbles instead of points), we could eventually think about adding them here.

In more detail, here''s how the sampling_strategy is implemented (copied from the docstring):

    This function accepts the 'sampling_strategy' parameter, which allows to implement custom
    sampling strategies for the samples used for training the random forests.
    Training operates in stages, the parameter 'forests_per_stage' determines how many forests
    are trained in each stage, and 'sample_fraction_per_stage' which fraction of the samples is
    taken per stage. The random forests in stage 0 are always trained from balanced dense labels.
    For the other stages 'sampling_strategy' enables specifying the strategy; it has to be a function
    with signature '(features, labels, forests, forests_per_stage, sample_fraction_per_stage)',
    and return the sampled features and labels. See thw 'worst_points' function
    in this file for an example implementation.

@constantinpape constantinpape merged commit 9458d39 into main Jul 5, 2022
@constantinpape constantinpape deleted the more-s2d-training branch July 5, 2022 21:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant