Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend s2d functionality #75

Merged
merged 35 commits into from
Jul 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
d85b18a
Enable anisotropic shallow2deep training
constantinpape Jul 15, 2022
a96b87f
Update shallow2deep mito training
constantinpape Jul 15, 2022
adbb477
Implement more s2d sampling strategies
constantinpape Jul 15, 2022
b99b894
Add modelzoo config functionality WIP
constantinpape Jul 18, 2022
76f8ca9
Implement shallow2deep modelzoo config
constantinpape Jul 18, 2022
b3e7792
Update mito experiments
constantinpape Jul 19, 2022
a9933e9
Fix several issues in new mitochondria s2d training
constantinpape Jul 19, 2022
5d84059
Fix score based sampling
constantinpape Jul 19, 2022
92ac2a5
add worst_tiles sampling
JonasHell Jul 20, 2022
aa4e1db
make worst_tiles easier to read, use local maxima
JonasHell Jul 21, 2022
6c6c2bd
find worst tiles per class
JonasHell Jul 21, 2022
280c052
Merge pull request #77 from JonasHell/more-s2d
constantinpape Jul 21, 2022
36de2c9
Add urocell dataset
constantinpape Jul 21, 2022
7762327
Accumulate labels also in raw s2d rf sampling scheme
constantinpape Jul 21, 2022
cbfa6a3
Merge branch 'more-s2d' of https://github.com/constantinpape/torch-em…
constantinpape Jul 21, 2022
f1828ad
Fix small issues in mito s2d experiments
constantinpape Jul 21, 2022
b4f17c1
Add more mito datasets WIP
constantinpape Jul 22, 2022
3a315a1
Refactor em mitochondria experiments, add lucchi and kasthuri datasets
constantinpape Jul 22, 2022
1ce36f7
Fix issue in prepare s2d
constantinpape Jul 22, 2022
939590e
Merge branch 'more-s2d' of https://github.com/constantinpape/torch-em…
constantinpape Jul 22, 2022
664467c
Fix issue in kasthuri data loader
constantinpape Jul 22, 2022
7b3b09c
Merge branch 'more-s2d' of https://github.com/constantinpape/torch-em…
constantinpape Jul 22, 2022
4c47e80
Update min foreground sampler to enable multiple background values
constantinpape Jul 23, 2022
d488dca
Fix issue in worst_tile sampling
constantinpape Jul 23, 2022
e550ac6
Add binary target option to boundry transform with background
constantinpape Jul 23, 2022
cc7f39f
Add 3d mito s2d training
constantinpape Jul 23, 2022
e3ee157
Add kasthuri 3d model training
constantinpape Jul 23, 2022
ea31024
Fix more issues in worst_tile sampling
constantinpape Jul 23, 2022
44e7597
Add tests for prepare shallow2deep
constantinpape Jul 23, 2022
8c71a16
Add more s2d tests, implement s2d training with image collectiond dat…
constantinpape Jul 23, 2022
2bbafa8
Update training scripts for 2d lm membranes
constantinpape Jul 24, 2022
4a96fa0
Fix minor issues in prefab datasets
constantinpape Jul 24, 2022
03fe640
Enable worst_tiles sampling with ignore label
constantinpape Jul 24, 2022
c14e5b2
Fix typo
constantinpape Jul 24, 2022
4b65981
Update s2d mito experiments
constantinpape Jul 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from torch_em.data.datasets import get_kasthuri_loader
from torch_em.util.debug import check_loader


def check_kasthuri_loader(split):
loader = get_kasthuri_loader("./data", split=split, download=True, batch_size=1, patch_shape=(64, 256, 256))
check_loader(loader, n_samples=4, instance_labels=True)


if __name__ == "__main__":
check_kasthuri_loader(split="train")
check_kasthuri_loader(split="test")
42 changes: 42 additions & 0 deletions experiments/mitochondria-segmentation/kasthuri/train_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch_em
from torch_em.model import UNet3d
from torch_em.data.datasets import get_kasthuri_loader


def get_loader(args, split):
patch_shape = (64, 256, 256)

n_samples = 500 if split == "train" else 25
sampler = torch_em.data.sampler.MinForegroundSampler(min_fraction=0.05, background_id=[-1, 0])
label_transform = torch_em.transform.label.NoToBackgroundBoundaryTransform(ndim=3, add_binary_target=True)
loader = get_kasthuri_loader(
args.input, split=split, label_transform=label_transform,
batch_size=args.batch_size, patch_shape=patch_shape,
n_samples=n_samples, ndim=3, shuffle=True,
num_workers=12, sampler=sampler
)
return loader


def train_direct(args):
name = "kasthuri-mito-3d"
model = UNet3d(in_channels=1, out_channels=2, final_activation="Sigmoid", depth=4, initial_features=32)

train_loader = get_loader(args, "train")
val_loader = get_loader(args, "test")
loss = torch_em.loss.DiceLoss()
loss = torch_em.loss.wrapper.LossWrapper(
loss, torch_em.loss.wrapper.MaskIgnoreLabel()
)

trainer = torch_em.default_segmentation_trainer(
name, model, train_loader, val_loader,
loss=loss, learning_rate=3.0e-4, device=args.device, log_image_interval=50
)
trainer.fit(args.n_iterations)


if __name__ == "__main__":
parser = torch_em.util.parser_helper()
args = parser.parse_args()
train_direct(args)
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from torch_em.data.datasets import get_lucchi_loader
from torch_em.util.debug import check_loader


def check_lucchi_loader(split):
loader = get_lucchi_loader("./data", split=split, download=True, batch_size=1, patch_shape=(64, 256, 256))
check_loader(loader, n_samples=4, instance_labels=True)


if __name__ == "__main__":
check_lucchi_loader(split="train")
check_lucchi_loader(split="test")
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from torch_em.data.datasets import get_uro_cell_loader
from torch_em.util.debug import check_loader


def check_uro_cell_loader(target):
loader = get_uro_cell_loader("./data", target=target, download=True,
batch_size=1, patch_shape=(32, 128, 128))
check_loader(loader, n_samples=5, instance_labels=True)


if __name__ == "__main__":
check_uro_cell_loader(target="mito")
63 changes: 61 additions & 2 deletions experiments/shallow2deep/em-mitochondria/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,67 @@
# Shallow2Deep for mitochondria
# Shallow2Deep for Mitochondria in EM

## Evaluation

Evaluation of different shallow2deep setups on EM-Mitochondria. All scores are measured with a soft dice score.
Evaluation of different shallow2deep setups for mitochondria segmentation in EM.
The enhancers are (potentially) trained on multiple datasets, evaluation is done on the Kasthuri dataset (which is not part of the training set except for one last version that will be the (for now) final one to be uploaded to bioimagei.io).
All scores are measured with a soft dice score.

## Datasets

- Mito-EM
- VNC
- Lucchi
- UroCell
- Kasthuri


### V4

- 2d enhancer: trained on Mito-EM and VNC
- anisotropic enhancer: random forests are trained in 2d, enhancer trained in 3d, trained on Mito-EM
- 3d enhancer: random forests trained in 3d, enhancer trained in 3d, trained on Kasthuri
- direct-nets: 2d and anisotropic networks trained on Mito-EM, 3d network trained on Kasthuri
- different strategies for training the initial rfs:
- `vanilla`: random forests are trained on randomly sampled dense patches
- `worst_points`: initial stage of forests (25 forests) are trained on random samples, forests in the next stages add worst predictions from prev. stage to their training set
- `uncertain_worst_points`: same as `worst_points`, but points are selected based on linear combination of uncertainty and worst predictions
- `random_points`: random points sampled in each stage, points are accumulated over the stages
- `worst_tiles`: training samples are taken from worst tile predictions

| method | few-labels | medium-labels | many-labels |
|:-----------------------------------|-------------:|----------------:|--------------:|
| rf3d | 0.326 | 0.328 | 0.385 |
| 2d-random_points | 0.593 | 0.693 | 0.782 |
| 2d-uncertain_worst_points | 0.613 | 0.777 | 0.794 |
| 2d-vanilla | 0.639 | 0.717 | 0.764 |
| 2d-worst_points | 0.549 | 0.711 | 0.730 |
| 2d-worst_tiles | 0.661 | 0.796 | 0.828 |
| direct_2d | 0.849 | nan | nan |
| anisotropic-random_points | 0.521 | 0.566 | 0.671 |
| anisotropic-uncertain_worst_points | 0.530 | 0.616 | 0.711 |
| anisotropic-vanilla | 0.576 | 0.660 | 0.749 |
| anisotropic-worst_points | 0.458 | 0.568 | 0.600 |
| anisotropic-worst_tiles | 0.614 | 0.728 | 0.788 |
| direct_anisotropic | 0.467 | nan | nan |
| 3d-random_points | 0.344 | 0.381 | 0.353 |
| 3d-worst_tiles | 0.385 | 0.472 | 0.504 |


### V5

TODO: (only best sampling from V4)
- train 2d on Mito-EM, VNC, Kasthuri and UroCell
- train anisotropic on Mito-EM, Kasthuri and UroCell
- train 3d on Kasthuri and UroCell

## V6

TODO same as V5, but train everything on Lucchi as well and upload the one with best sampling strategy to bioimage.io


## Old evaluation

Evaluation of older set-ups.

### V1

Expand Down
Loading