Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement MeanTeacher Trainer #112

Merged
merged 22 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
47a4b51
Implement MeanTeacher Trainer WIP
constantinpape Mar 11, 2023
a03bf55
Merge branch 'main' of https://github.com/constantinpape/torch-em int…
constantinpape Mar 11, 2023
3ad4488
Mean teacher is running (not properly tested, not serializable or des…
constantinpape Mar 11, 2023
54e6b9c
MeanTeacher semi-supervised training runs (not properly tested)
constantinpape Mar 11, 2023
3090e1d
Update pseudo-labeling functionality
constantinpape Mar 11, 2023
451553e
Start implementing self-trainign experiments
constantinpape Mar 11, 2023
384eb54
Update livecell source training
constantinpape Mar 11, 2023
8a91b90
Livecell training updates
constantinpape Mar 11, 2023
e5a7db2
Fix issues in raw transforms
constantinpape Mar 11, 2023
8269508
Fix several issues in self training
constantinpape Mar 11, 2023
8ba48af
Update domain adaptation experiments
constantinpape Mar 11, 2023
6c8f824
Minor fixes
constantinpape Mar 11, 2023
020086a
Implement mean teacer training for livecell
constantinpape Mar 12, 2023
ef78bbc
Enable deserialization for mean teacher trainer
constantinpape Mar 12, 2023
d73b67d
Simplify spoco trainer
constantinpape Mar 12, 2023
66796d5
Fix issue in mean teacher trainer
constantinpape Mar 12, 2023
1644d18
Implement unet source eval
constantinpape Mar 13, 2023
39b00f9
Update livecell domain adaptation training
constantinpape Mar 14, 2023
a7057c4
Enable self-training for 3d data
constantinpape Mar 14, 2023
d41122d
Enable pytorch 2
constantinpape Mar 14, 2023
9e6208e
Bump python versions in CI and use pytorch 2 in the env files
constantinpape Mar 15, 2023
a785af9
Fix python version names in CI
constantinpape Mar 15, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: [3.8, 3.9]
python-version: ["3.9", "3.10"]

steps:
- name: Checkout
Expand Down
2 changes: 1 addition & 1 deletion environment_cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ dependencies:
- bioimageio.core >=0.5.0
- cpuonly
- python-elf
- pytorch
- pytorch >=2.0
- tensorboard
- tifffile
- torchvision
Expand Down
3 changes: 2 additions & 1 deletion environment_gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ dependencies:
- affogato
- bioimageio.core >=0.5.0
- python-elf
- pytorch-cuda=11.6 # you may need to update the pytorch version to match your system
- pytorch >=2.0
- pytorch-cuda>=11.7 # you may need to update the pytorch version to match your system
- tensorboard
- tifffile
- torchvision
Expand Down
2 changes: 1 addition & 1 deletion experiments/livecell/train_boundaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def check_loader(args, train=True, val=True, n_images=5):
check_loader(loader, n_images)


if __name__ == '__main__':
if __name__ == "__main__":
parser = torch_em.util.parser_helper(default_batch_size=8)
parser.add_argument("--cell_type", default=None)
args = parser.parse_args()
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Prelim. results

-> UNet results are a bit worse than from Anwai, double check how the training differs.
-> Mean Teacher improves results.
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import argparse
import pandas as pd


def check_result(path):
table = pd.read_csv(path)
print(table)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("paths", nargs="+")
args = parser.parse_args()
for path in args.paths:
check_result(path)


if __name__ == "__main__":
main()
234 changes: 234 additions & 0 deletions experiments/probabilistic_domain_adaptation/livecell/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
import argparse
import os
from glob import glob

try:
import imageio.v2 as imageio
except ImportError:
import imageio
import numpy as np
import pandas as pd
import torch
import torch_em

from elf.evaluation import dice_score
from torch_em.data.datasets.livecell import (get_livecell_loader,
_download_livecell_images,
_download_livecell_annotations)
from torch_em.model import UNet2d
from torch_em.util.prediction import predict_with_padding
from torchvision import transforms
from tqdm import tqdm

CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]


#
# The augmentations we use for the LiveCELL experiments:
# - weak augmenations: blurring and additive gaussian noise
# - strong augmentations: TODO
#


def weak_augmentations(p=0.25):
norm = torch_em.transform.raw.standardize
aug = transforms.Compose([
norm,
transforms.RandomApply([torch_em.transform.raw.GaussianBlur()], p=p),
transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
scale=(0, 0.15), clip_kwargs=False)], p=p
),
])
return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)


# TODO
def strong_augmentations():
pass


#
# Model and prediction functionality: the models we use in all experiments
#

def get_unet():
return UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid", depth=4)


def load_model(model, ckpt, state="model_state", device=None):
state = torch.load(os.path.join(ckpt, "best.pt"))[state]
model.load_state_dict(state)
if device is not None:
model.to(device)
return model


# use get_model and prediction_function to customize this, e.g. for using it with the PUNet
# set model_state to "teacher_state" when using this with a mean-teacher method
def evaluate_transfered_model(
args, ct_src, method, get_model=get_unet, prediction_function=None, model_state="model_state"
):
image_folder = os.path.join(args.input, "images", "livecell_test_images")
label_root = os.path.join(args.input, "annotations", "livecell_test_images")

results = {"src": [ct_src]}
device = torch.device("cuda")

thresh = args.confidence_threshold
with torch.no_grad():
for ct_trg in CELL_TYPES:

if ct_trg == ct_src:
results[ct_trg] = None
continue

out_folder = None if args.output is None else os.path.join(
args.output, f"thresh-{thresh}", ct_src, ct_trg
)
if out_folder is not None:
os.makedirs(out_folder, exist_ok=True)

ckpt = f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}"
model = get_model()
model = load_model(model, ckpt, device=device, state=model_state)

label_paths = glob(os.path.join(label_root, ct_trg, "*.tif"))
scores = []
for label_path in tqdm(label_paths, desc=f"Predict for src={ct_src}, trgt={ct_trg}"):

labels = imageio.imread(label_path)
if out_folder is None:
out_path = None
else:
out_path = os.path.join(out_folder, os.path.basename(label_path))
if os.path.exists(out_path):
pred = imageio.imread(out_path)
score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0)
scores.append(score)
continue

image_path = os.path.join(image_folder, os.path.basename(label_path))
assert os.path.exists(image_path)
image = imageio.imread(image_path)
image = torch_em.transform.raw.standardize(image)
pred = predict_with_padding(
model, image, min_divisible=(16, 16), device=device, prediction_function=prediction_function,
).squeeze()
assert image.shape == labels.shape
score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0)
if out_path is not None:
imageio.imwrite(out_path, pred)
scores.append(score)

results[ct_trg] = np.mean(scores)
return pd.DataFrame(results)


# use get_model and prediction_function to customize this, e.g. for using it with the PUNet
def evaluate_source_model(args, ct_src, method, get_model=get_unet, prediction_function=None):
ckpt = f"checkpoints/{method}/{ct_src}"
model = get_model()
model = torch_em.util.get_trainer(ckpt).model

image_folder = os.path.join(args.input, "images", "livecell_test_images")
label_root = os.path.join(args.input, "annotations", "livecell_test_images")

results = {"src": [ct_src]}
device = torch.device("cuda")

with torch.no_grad():
for ct_trg in CELL_TYPES:

out_folder = None if args.output is None else os.path.join(args.output, ct_src, ct_trg)
if out_folder is not None:
os.makedirs(out_folder, exist_ok=True)

label_paths = glob(os.path.join(label_root, ct_trg, "*.tif"))
scores = []
for label_path in tqdm(label_paths, desc=f"Predict for src={ct_src}, trgt={ct_trg}"):

labels = imageio.imread(label_path)
if out_folder is None:
out_path = None
else:
out_path = os.path.join(out_folder, os.path.basename(label_path))
if os.path.exists(out_path):
pred = imageio.imread(out_path)
score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0)
scores.append(score)
continue

image_path = os.path.join(image_folder, os.path.basename(label_path))
assert os.path.exists(image_path)
image = imageio.imread(image_path)
image = torch_em.transform.raw.standardize(image)
pred = predict_with_padding(
model, image, min_divisible=(16, 16), device=device, prediction_function=prediction_function
).squeeze()
assert image.shape == labels.shape
score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0)
if out_path is not None:
imageio.imwrite(out_path, pred)
scores.append(score)

results[ct_trg] = np.mean(scores)
return pd.DataFrame(results)


#
# Other utility functions: loaders, parser
#


def _get_image_paths(args, split, cell_type):
_download_livecell_images(args.input, download=True)
image_paths, _ = _download_livecell_annotations(args.input, split, download=True,
cell_types=[cell_type], label_path=None)
return image_paths


def get_unsupervised_loader(args, split, cell_type, teacher_augmentation, student_augmentation):
patch_shape = (512, 512)

def _parse_aug(aug):
if aug == "weak":
return weak_augmentations()
elif aug == "strong":
return strong_augmentations()
assert callable(aug)
return aug

raw_transform = torch_em.transform.get_raw_transform()
transform = torch_em.transform.get_augmentations(ndim=2)

image_paths = _get_image_paths(args, split, cell_type)

augmentations = (_parse_aug(teacher_augmentation), _parse_aug(student_augmentation))
ds = torch_em.data.RawImageCollectionDataset(
image_paths, patch_shape, raw_transform, transform,
augmentations=augmentations
)
loader = torch_em.segmentation.get_data_loader(ds, batch_size=args.batch_size, num_workers=8, shuffle=True)
return loader


def get_supervised_loader(args, split, cell_type):
patch_shape = (512, 512)
loader = get_livecell_loader(
args.input, patch_shape, split,
download=True, binary=True, batch_size=args.batch_size,
cell_types=[cell_type], num_workers=8, shuffle=True,
)
return loader


def get_parser(default_batch_size=8, default_iterations=int(1e5)):
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", required=True)
parser.add_argument("-p", "--phase", required=True)
parser.add_argument("-b", "--batch_size", default=default_batch_size, type=int)
parser.add_argument("-n", "--n_iterations", default=default_iterations, type=int)
parser.add_argument("-s", "--save_root")
parser.add_argument("-c", "--cell_types", nargs="+", default=CELL_TYPES)
parser.add_argument("-o", "--output")
return parser
Loading