diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7a1c8b2e..355f7bc5 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -11,7 +11,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: [3.8, 3.9] + python-version: ["3.9", "3.10"] steps: - name: Checkout diff --git a/environment_cpu.yaml b/environment_cpu.yaml index 061986bd..e93d7320 100644 --- a/environment_cpu.yaml +++ b/environment_cpu.yaml @@ -8,7 +8,7 @@ dependencies: - bioimageio.core >=0.5.0 - cpuonly - python-elf - - pytorch + - pytorch >=2.0 - tensorboard - tifffile - torchvision diff --git a/environment_gpu.yaml b/environment_gpu.yaml index 4aa5c5ef..6b2fdf6d 100644 --- a/environment_gpu.yaml +++ b/environment_gpu.yaml @@ -8,7 +8,8 @@ dependencies: - affogato - bioimageio.core >=0.5.0 - python-elf - - pytorch-cuda=11.6 # you may need to update the pytorch version to match your system + - pytorch >=2.0 + - pytorch-cuda>=11.7 # you may need to update the cuda version to match your system - tensorboard - tifffile - torchvision diff --git a/experiments/livecell/train_boundaries.py b/experiments/livecell/train_boundaries.py index ab5d2427..f54fb78a 100644 --- a/experiments/livecell/train_boundaries.py +++ b/experiments/livecell/train_boundaries.py @@ -62,7 +62,7 @@ def check_loader(args, train=True, val=True, n_images=5): check_loader(loader, n_images) -if __name__ == '__main__': +if __name__ == "__main__": parser = torch_em.util.parser_helper(default_batch_size=8) parser.add_argument("--cell_type", default=None) args = parser.parse_args() diff --git a/experiments/probabilistic_domain_adaptation/README.md b/experiments/probabilistic_domain_adaptation/README.md new file mode 100644 index 00000000..e69de29b diff --git a/experiments/probabilistic_domain_adaptation/livecell/README.md b/experiments/probabilistic_domain_adaptation/livecell/README.md new file mode 100644 index 00000000..38424eb1 --- /dev/null +++ b/experiments/probabilistic_domain_adaptation/livecell/README.md @@ -0,0 +1,4 @@ +# Preliminary results + +-> The UNet results are a bit worse than Anwai's; double-check how the training differs. +-> Mean Teacher improves the results.
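For context on the Mean Teacher entry above: the teacher network is an exponential moving average (EMA) of the student weights and is used to produce pseudo-labels on the unlabeled target cell types. The snippet below is a minimal, self-contained sketch of that EMA update; it mirrors the `_momentum_update` method added in `torch_em/self_training/mean_teacher.py` further down in this diff, but is an illustration rather than the trainer implementation itself.

```python
import copy
import torch

# Toy student/teacher pair; the actual experiments use the 2D U-Net from torch_em.model.
student = torch.nn.Linear(4, 4)
teacher = copy.deepcopy(student)
for param in teacher.parameters():
    param.requires_grad = False  # the teacher is never updated by the optimizer


def momentum_update(student, teacher, momentum=0.999):
    """Exponential moving average of the student weights (Tarvainen & Valpola, 2017)."""
    with torch.no_grad():
        for p_student, p_teacher in zip(student.parameters(), teacher.parameters()):
            p_teacher.data = p_teacher.data * momentum + p_student.data * (1.0 - momentum)


# Call this after every optimizer step on the student to refresh the teacher.
momentum_update(student, teacher)
```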
diff --git a/experiments/probabilistic_domain_adaptation/livecell/check_result.py b/experiments/probabilistic_domain_adaptation/livecell/check_result.py new file mode 100644 index 00000000..884bc5c1 --- /dev/null +++ b/experiments/probabilistic_domain_adaptation/livecell/check_result.py @@ -0,0 +1,19 @@ +import argparse +import pandas as pd + + +def check_result(path): + table = pd.read_csv(path) + print(table) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("paths", nargs="+") + args = parser.parse_args() + for path in args.paths: + check_result(path) + + +if __name__ == "__main__": + main() diff --git a/experiments/probabilistic_domain_adaptation/livecell/common.py b/experiments/probabilistic_domain_adaptation/livecell/common.py new file mode 100644 index 00000000..cd7ad1c9 --- /dev/null +++ b/experiments/probabilistic_domain_adaptation/livecell/common.py @@ -0,0 +1,234 @@ +import argparse +import os +from glob import glob + +try: + import imageio.v2 as imageio +except ImportError: + import imageio +import numpy as np +import pandas as pd +import torch +import torch_em + +from elf.evaluation import dice_score +from torch_em.data.datasets.livecell import (get_livecell_loader, + _download_livecell_images, + _download_livecell_annotations) +from torch_em.model import UNet2d +from torch_em.util.prediction import predict_with_padding +from torchvision import transforms +from tqdm import tqdm + +CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"] + + +# +# The augmentations we use for the LiveCELL experiments: +# - weak augmenations: blurring and additive gaussian noise +# - strong augmentations: TODO +# + + +def weak_augmentations(p=0.25): + norm = torch_em.transform.raw.standardize + aug = transforms.Compose([ + norm, + transforms.RandomApply([torch_em.transform.raw.GaussianBlur()], p=p), + transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise( + scale=(0, 0.15), clip_kwargs=False)], p=p + ), + ]) + return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug) + + +# TODO +def strong_augmentations(): + pass + + +# +# Model and prediction functionality: the models we use in all experiments +# + +def get_unet(): + return UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid", depth=4) + + +def load_model(model, ckpt, state="model_state", device=None): + state = torch.load(os.path.join(ckpt, "best.pt"))[state] + model.load_state_dict(state) + if device is not None: + model.to(device) + return model + + +# use get_model and prediction_function to customize this, e.g. 
for using it with the PUNet +# set model_state to "teacher_state" when using this with a mean-teacher method +def evaluate_transfered_model( + args, ct_src, method, get_model=get_unet, prediction_function=None, model_state="model_state" +): + image_folder = os.path.join(args.input, "images", "livecell_test_images") + label_root = os.path.join(args.input, "annotations", "livecell_test_images") + + results = {"src": [ct_src]} + device = torch.device("cuda") + + thresh = args.confidence_threshold + with torch.no_grad(): + for ct_trg in CELL_TYPES: + + if ct_trg == ct_src: + results[ct_trg] = None + continue + + out_folder = None if args.output is None else os.path.join( + args.output, f"thresh-{thresh}", ct_src, ct_trg + ) + if out_folder is not None: + os.makedirs(out_folder, exist_ok=True) + + ckpt = f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}" + model = get_model() + model = load_model(model, ckpt, device=device, state=model_state) + + label_paths = glob(os.path.join(label_root, ct_trg, "*.tif")) + scores = [] + for label_path in tqdm(label_paths, desc=f"Predict for src={ct_src}, trgt={ct_trg}"): + + labels = imageio.imread(label_path) + if out_folder is None: + out_path = None + else: + out_path = os.path.join(out_folder, os.path.basename(label_path)) + if os.path.exists(out_path): + pred = imageio.imread(out_path) + score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0) + scores.append(score) + continue + + image_path = os.path.join(image_folder, os.path.basename(label_path)) + assert os.path.exists(image_path) + image = imageio.imread(image_path) + image = torch_em.transform.raw.standardize(image) + pred = predict_with_padding( + model, image, min_divisible=(16, 16), device=device, prediction_function=prediction_function, + ).squeeze() + assert image.shape == labels.shape + score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0) + if out_path is not None: + imageio.imwrite(out_path, pred) + scores.append(score) + + results[ct_trg] = np.mean(scores) + return pd.DataFrame(results) + + +# use get_model and prediction_function to customize this, e.g. 
for using it with the PUNet +def evaluate_source_model(args, ct_src, method, get_model=get_unet, prediction_function=None): + ckpt = f"checkpoints/{method}/{ct_src}" + model = get_model() + model = torch_em.util.get_trainer(ckpt).model + + image_folder = os.path.join(args.input, "images", "livecell_test_images") + label_root = os.path.join(args.input, "annotations", "livecell_test_images") + + results = {"src": [ct_src]} + device = torch.device("cuda") + + with torch.no_grad(): + for ct_trg in CELL_TYPES: + + out_folder = None if args.output is None else os.path.join(args.output, ct_src, ct_trg) + if out_folder is not None: + os.makedirs(out_folder, exist_ok=True) + + label_paths = glob(os.path.join(label_root, ct_trg, "*.tif")) + scores = [] + for label_path in tqdm(label_paths, desc=f"Predict for src={ct_src}, trgt={ct_trg}"): + + labels = imageio.imread(label_path) + if out_folder is None: + out_path = None + else: + out_path = os.path.join(out_folder, os.path.basename(label_path)) + if os.path.exists(out_path): + pred = imageio.imread(out_path) + score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0) + scores.append(score) + continue + + image_path = os.path.join(image_folder, os.path.basename(label_path)) + assert os.path.exists(image_path) + image = imageio.imread(image_path) + image = torch_em.transform.raw.standardize(image) + pred = predict_with_padding( + model, image, min_divisible=(16, 16), device=device, prediction_function=prediction_function + ).squeeze() + assert image.shape == labels.shape + score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0) + if out_path is not None: + imageio.imwrite(out_path, pred) + scores.append(score) + + results[ct_trg] = np.mean(scores) + return pd.DataFrame(results) + + +# +# Other utility functions: loaders, parser +# + + +def _get_image_paths(args, split, cell_type): + _download_livecell_images(args.input, download=True) + image_paths, _ = _download_livecell_annotations(args.input, split, download=True, + cell_types=[cell_type], label_path=None) + return image_paths + + +def get_unsupervised_loader(args, split, cell_type, teacher_augmentation, student_augmentation): + patch_shape = (512, 512) + + def _parse_aug(aug): + if aug == "weak": + return weak_augmentations() + elif aug == "strong": + return strong_augmentations() + assert callable(aug) + return aug + + raw_transform = torch_em.transform.get_raw_transform() + transform = torch_em.transform.get_augmentations(ndim=2) + + image_paths = _get_image_paths(args, split, cell_type) + + augmentations = (_parse_aug(teacher_augmentation), _parse_aug(student_augmentation)) + ds = torch_em.data.RawImageCollectionDataset( + image_paths, patch_shape, raw_transform, transform, + augmentations=augmentations + ) + loader = torch_em.segmentation.get_data_loader(ds, batch_size=args.batch_size, num_workers=8, shuffle=True) + return loader + + +def get_supervised_loader(args, split, cell_type): + patch_shape = (512, 512) + loader = get_livecell_loader( + args.input, patch_shape, split, + download=True, binary=True, batch_size=args.batch_size, + cell_types=[cell_type], num_workers=8, shuffle=True, + ) + return loader + + +def get_parser(default_batch_size=8, default_iterations=int(1e5)): + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input", required=True) + parser.add_argument("-p", "--phase", required=True) + parser.add_argument("-b", "--batch_size", default=default_batch_size, type=int) + parser.add_argument("-n", "--n_iterations", 
default=default_iterations, type=int) + parser.add_argument("-s", "--save_root") + parser.add_argument("-c", "--cell_types", nargs="+", default=CELL_TYPES) + parser.add_argument("-o", "--output") + return parser diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py new file mode 100644 index 00000000..f3587d5c --- /dev/null +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py @@ -0,0 +1,114 @@ +import os + +import pandas as pd +import torch +import torch_em.self_training as self_training + +import common + + +def check_loader(args, n_images=5): + from torch_em.util.debug import check_loader + + cell_types = args.cell_types + print("The cell types", cell_types, "were selected.") + print("Checking the unsupervised loader for the first cell type", cell_types[0]) + + loader = common.get_unsupervised_loader( + args, "train", cell_types[0], + teacher_augmentation="weak", student_augmentation="weak", + ) + check_loader(loader, n_images) + + +def _train_source_target(args, source_cell_type, target_cell_type): + model = common.get_unet() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5) + + # self training functionality + thresh = args.confidence_threshold + pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=thresh) + loss = self_training.DefaultSelfTrainingLoss() + loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric() + + # data loaders + supervised_train_loader = common.get_supervised_loader(args, "train", source_cell_type) + supervised_val_loader = common.get_supervised_loader(args, "val", source_cell_type) + unsupervised_train_loader = common.get_unsupervised_loader( + args, "train", target_cell_type, + teacher_augmentation="weak", student_augmentation="weak", + ) + unsupervised_val_loader = common.get_unsupervised_loader( + args, "val", target_cell_type, + teacher_augmentation="weak", student_augmentation="weak", + ) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + name = f"unet_adamt/thresh-{thresh}/{source_cell_type}/{target_cell_type}" + trainer = self_training.MeanTeacherTrainer( + name=name, + model=model, + optimizer=optimizer, + lr_scheduler=scheduler, + pseudo_labeler=pseudo_labeler, + unsupervised_loss=loss, + unsupervised_loss_and_metric=loss_and_metric, + supervised_train_loader=supervised_train_loader, + unsupervised_train_loader=unsupervised_train_loader, + supervised_val_loader=supervised_val_loader, + unsupervised_val_loader=unsupervised_val_loader, + supervised_loss=loss, + supervised_loss_and_metric=loss_and_metric, + logger=self_training.SelfTrainingTensorboardLogger, + mixed_precision=True, + device=device, + log_image_interval=100, + save_root=args.save_root, + ) + trainer.fit(args.n_iterations) + + +def _train_source(args, cell_type): + for target_cell_type in common.CELL_TYPES: + if target_cell_type == cell_type: + continue + _train_source_target(args, cell_type, target_cell_type) + + +def run_training(args): + for cell_type in args.cell_types: + print("Start training for cell type:", cell_type) + _train_source(args, cell_type) + + +def run_evaluation(args): + results = [] + for ct in args.cell_types: + res = common.evaluate_transfered_model(args, ct, "unet_adamt", model_state="teacher_state") + results.append(res) + results = pd.concat(results) + print("Evaluation 
results:") + print(results) + result_folder = "./results" + os.makedirs(result_folder, exist_ok=True) + results.to_csv(os.path.join(result_folder, "unet_adamt.csv"), index=False) + + +def main(): + parser = common.get_parser(default_iterations=75000, default_batch_size=4) + parser.add_argument("--confidence_threshold", default=0.9) + args = parser.parse_args() + if args.phase in ("c", "check"): + check_loader(args) + elif args.phase in ("t", "train"): + run_training(args) + elif args.phase in ("e", "evaluate"): + run_evaluation(args) + else: + raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.") + + +if __name__ == "__main__": + main() diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py new file mode 100644 index 00000000..79bc309b --- /dev/null +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py @@ -0,0 +1,113 @@ +import os + +import pandas as pd +import torch +import torch_em.self_training as self_training + +import common + + +def check_loader(args, n_images=5): + from torch_em.util.debug import check_loader + + cell_types = args.cell_types + print("The cell types", cell_types, "were selected.") + print("Checking the unsupervised loader for the first cell type", cell_types[0]) + + loader = common.get_unsupervised_loader( + args, "train", cell_types[0], + teacher_augmentation="weak", student_augmentation="weak", + ) + check_loader(loader, n_images) + + +def _train_source_target(args, source_cell_type, target_cell_type): + model = common.get_unet() + src_checkpoint = f"./checkpoints/unet_source/{source_cell_type}" + model = common.load_model(model, src_checkpoint) + + optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5) + + # self training functionality + thresh = args.confidence_threshold + pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=thresh) + loss = self_training.DefaultSelfTrainingLoss() + loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric() + + # data loaders + unsupervised_train_loader = common.get_unsupervised_loader( + args, "train", target_cell_type, + teacher_augmentation="weak", student_augmentation="weak", + ) + unsupervised_val_loader = common.get_unsupervised_loader( + args, "val", target_cell_type, + teacher_augmentation="weak", student_augmentation="weak", + ) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + name = f"unet_mean_teacher/thresh-{thresh}/{source_cell_type}/{target_cell_type}" + trainer = self_training.MeanTeacherTrainer( + name=name, + model=model, + optimizer=optimizer, + lr_scheduler=scheduler, + pseudo_labeler=pseudo_labeler, + unsupervised_loss=loss, + unsupervised_loss_and_metric=loss_and_metric, + unsupervised_train_loader=unsupervised_train_loader, + unsupervised_val_loader=unsupervised_val_loader, + supervised_loss=loss, + supervised_loss_and_metric=loss_and_metric, + logger=self_training.SelfTrainingTensorboardLogger, + mixed_precision=True, + device=device, + log_image_interval=100, + save_root=args.save_root, + ) + trainer.fit(args.n_iterations) + + +def _train_source(args, cell_type): + for target_cell_type in common.CELL_TYPES: + if target_cell_type == cell_type: + continue + _train_source_target(args, cell_type, target_cell_type) + + +def run_training(args): + for cell_type in 
args.cell_types: + print("Start training for cell type:", cell_type) + _train_source(args, cell_type) + + +def run_evaluation(args): + results = [] + for ct in args.cell_types: + res = common.evaluate_transfered_model(args, ct, "unet_mean_teacher", model_state="teacher_state") + results.append(res) + results = pd.concat(results) + print("Evaluation results:") + print(results) + result_folder = "./results" + os.makedirs(result_folder, exist_ok=True) + results.to_csv(os.path.join(result_folder, "unet_mean_teacher.csv"), index=False) + + +def main(): + parser = common.get_parser(default_iterations=25000, default_batch_size=8) + parser.add_argument("--confidence_threshold", default=0.9) + args = parser.parse_args() + if args.phase in ("c", "check"): + check_loader(args) + elif args.phase in ("t", "train"): + run_training(args) + elif args.phase in ("e", "evaluate"): + run_evaluation(args) + else: + raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.") + + +if __name__ == "__main__": + main() diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_source.py b/experiments/probabilistic_domain_adaptation/livecell/unet_source.py new file mode 100644 index 00000000..4347a8e5 --- /dev/null +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_source.py @@ -0,0 +1,71 @@ +import os + +import torch_em +import pandas as pd + +import common + + +def _train_cell_type(args, cell_type): + model = common.get_unet() + train_loader = common.get_supervised_loader(args, "train", cell_type) + val_loader = common.get_supervised_loader(args, "val", cell_type) + name = f"unet_source/{cell_type}" + trainer = torch_em.default_segmentation_trainer( + name=name, + model=model, + train_loader=train_loader, + val_loader=val_loader, + learning_rate=1e-4, + mixed_precision=True, + log_image_interval=100, + save_root=args.save_root, + ) + trainer.fit(iterations=args.n_iterations) + + +def run_training(args): + for cell_type in args.cell_types: + print("Start training for cell type:", cell_type) + _train_cell_type(args, cell_type) + + +def check_loader(args, n_images=5): + from torch_em.util.debug import check_loader + + cell_types = args.cell_types + print("The cell types", cell_types, "were selected.") + print("Checking the loader for the first cell type", cell_types[0]) + + loader = common.get_supervised_loader(args) + check_loader(loader, n_images) + + +def run_evaluation(args): + results = [] + for ct in args.cell_types: + res = common.evaluate_source_model(args, ct, "unet_source") + results.append(res) + results = pd.concat(results) + print("Evaluation results:") + print(results) + result_folder = "./results" + os.makedirs(result_folder, exist_ok=True) + results.to_csv(os.path.join(result_folder, "unet_source.csv"), index=False) + + +def main(): + parser = common.get_parser(default_iterations=50000) + args = parser.parse_args() + if args.phase in ("c", "check"): + check_loader(args) + elif args.phase in ("t", "train"): + run_training(args) + elif args.phase in ("e", "evaluate"): + run_evaluation(args) + else: + raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.") + + +if __name__ == "__main__": + main() diff --git a/experiments/probabilistic_domain_adaptation/mitochondria/README.md b/experiments/probabilistic_domain_adaptation/mitochondria/README.md new file mode 100644 index 00000000..e69de29b diff --git a/experiments/probabilistic_domain_adaptation/mitochondria/unet_adamt.py 
b/experiments/probabilistic_domain_adaptation/mitochondria/unet_adamt.py new file mode 100644 index 00000000..e69de29b diff --git a/experiments/probabilistic_domain_adaptation/mitochondria/unet_mean_teacher.py b/experiments/probabilistic_domain_adaptation/mitochondria/unet_mean_teacher.py new file mode 100644 index 00000000..e69de29b diff --git a/experiments/probabilistic_domain_adaptation/mitochondria/unet_source.py b/experiments/probabilistic_domain_adaptation/mitochondria/unet_source.py new file mode 100644 index 00000000..e69de29b diff --git a/test/self_training/__init__.py b/test/self_training/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/self_training/test_mean_teacher.py b/test/self_training/test_mean_teacher.py new file mode 100644 index 00000000..361a9e61 --- /dev/null +++ b/test/self_training/test_mean_teacher.py @@ -0,0 +1,130 @@ +import os +import unittest +from shutil import rmtree + +import torch +import torch_em +import torch_em.self_training as self_training +from torch_em.model import UNet2d +from torch_em.util.test import create_segmentation_test_data + + +class TestMeanTeacher(unittest.TestCase): + tmp_folder = "./tmp" + data_path = "./tmp/data.h5" + raw_key = "raw" + label_key = "labels" + + def setUp(self): + os.makedirs(self.tmp_folder, exist_ok=True) + create_segmentation_test_data(self.data_path, self.raw_key, self.label_key, shape=(128,) * 3, chunks=(32,) * 3) + + def tearDown(self): + + def _remove(folder): + try: + rmtree(folder) + except OSError: + pass + + _remove(self.tmp_folder) + _remove("./logs") + _remove("./checkpoints") + + def _test_mean_teacher( + self, + unsupervised_train_loader, + supervised_train_loader=None, + unsupervised_val_loader=None, + supervised_val_loader=None, + supervised_loss=None, + supervised_loss_and_metric=None, + unsupervised_loss_and_metric=None, + ): + model = UNet2d(in_channels=1, out_channels=1, initial_features=8, depth=3) + optimizer = torch.optim.Adam(model.parameters()) + + name = "mt-test" + trainer = self_training.MeanTeacherTrainer( + name=name, + model=model, + optimizer=optimizer, + pseudo_labeler=self_training.DefaultPseudoLabeler(), + unsupervised_loss=self_training.DefaultSelfTrainingLoss(), + unsupervised_loss_and_metric=unsupervised_loss_and_metric, + unsupervised_train_loader=unsupervised_train_loader, + supervised_train_loader=supervised_train_loader, + unsupervised_val_loader=unsupervised_val_loader, + supervised_val_loader=supervised_val_loader, + supervised_loss=supervised_loss, + supervised_loss_and_metric=supervised_loss_and_metric, + mixed_precision=False, + device=torch.device("cpu"), + ) + trainer.fit(53) + self.assertTrue(os.path.exists(f"./checkpoints/{name}/best.pt")) + self.assertTrue(os.path.exists(f"./checkpoints/{name}/latest.pt")) + + # make sure that the trainer can be deserialized from the checkpoint + trainer2 = self_training.MeanTeacherTrainer.from_checkpoint(os.path.join("./checkpoints", name), name="latest") + self.assertEqual(trainer.iteration, trainer2.iteration) + self.assertTrue(torch_em.util.model_is_equal(trainer.model, trainer2.model)) + self.assertTrue(torch_em.util.model_is_equal(trainer.teacher, trainer2.teacher)) + self.assertEqual(len(trainer.unsupervised_train_loader), len(trainer2.unsupervised_train_loader)) + if supervised_train_loader is not None: + self.assertEqual(len(trainer.supervised_train_loader), len(trainer2.supervised_train_loader)) + + # and that it can be trained further + trainer2.fit(10) + self.assertEqual(trainer2.iteration, 63) + 
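A usage note on the checkpoints written by `MeanTeacherTrainer`: besides the student (`model_state`), the EMA teacher is stored under `teacher_state`, which is why the evaluation scripts above pass `model_state="teacher_state"` to `common.load_model`. Below is a minimal sketch of loading the teacher by hand; the checkpoint path is only a placeholder and assumes a finished `unet_adamt` run.

```python
import os
import torch
from torch_em.model import UNet2d

# Placeholder path: source cell type A172 adapted to target BT474 with threshold 0.9.
checkpoint = "./checkpoints/unet_adamt/thresh-0.9/A172/BT474"

# Same architecture as common.get_unet().
model = UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid", depth=4)

state = torch.load(os.path.join(checkpoint, "best.pt"), map_location="cpu")
# "model_state" holds the student weights, "teacher_state" the EMA teacher.
model.load_state_dict(state["teacher_state"])
model.eval()
```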
+ def get_unsupervised_loader(self, n_samples): + augmentations = ( + torch_em.transform.raw.GaussianBlur(), + torch_em.transform.raw.GaussianBlur(), + ) + ds = torch_em.data.RawDataset( + raw_path=self.data_path, + raw_key=self.raw_key, + patch_shape=(1, 64, 64), + n_samples=n_samples, + ndim=2, + augmentations=augmentations, + ) + loader = torch_em.segmentation.get_data_loader(ds, batch_size=1, shuffle=True) + return loader + + def get_supervised_loader(self, n_samples): + ds = torch_em.data.SegmentationDataset( + raw_path=self.data_path, raw_key=self.raw_key, + label_path=self.data_path, label_key=self.label_key, + patch_shape=(1, 64, 64), ndim=2, + n_samples=n_samples, + ) + loader = torch_em.segmentation.get_data_loader(ds, batch_size=1, shuffle=True) + return loader + + def test_mean_teacher_unsupervised(self): + unsupervised_train_loader = self.get_unsupervised_loader(n_samples=50) + unsupervised_val_loader = self.get_unsupervised_loader(n_samples=4) + self._test_mean_teacher( + unsupervised_train_loader=unsupervised_train_loader, + unsupervised_val_loader=unsupervised_val_loader, + unsupervised_loss_and_metric=self_training.DefaultSelfTrainingLossAndMetric(), + ) + + def test_mean_teacher_semisupervised(self): + unsupervised_train_loader = self.get_unsupervised_loader(n_samples=50) + supervised_train_loader = self.get_supervised_loader(n_samples=51) + supervised_val_loader = self.get_supervised_loader(n_samples=4) + self._test_mean_teacher( + unsupervised_train_loader=unsupervised_train_loader, + supervised_train_loader=supervised_train_loader, + supervised_val_loader=supervised_val_loader, + supervised_loss=self_training.DefaultSelfTrainingLoss(), + supervised_loss_and_metric=self_training.DefaultSelfTrainingLossAndMetric(), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_segmentation.py b/test/test_segmentation.py index b480da32..0fa6108f 100644 --- a/test/test_segmentation.py +++ b/test/test_segmentation.py @@ -57,8 +57,7 @@ def _test_training(self, model_class, model_kwargs, mixed_precision=False, device=torch.device("cpu"), logger=None) - train_iters = 51 - trainer.fit(train_iters) + trainer.fit(n_iterations) def _test_checkpoint(cp_path, check_progress): self.assertTrue(os.path.exists(cp_path)) @@ -71,7 +70,7 @@ def _test_checkpoint(cp_path, check_progress): loaded_model.load_state_dict(checkpoint["model_state"]) if check_progress: - self.assertEqual(checkpoint["iteration"], train_iters) + self.assertEqual(checkpoint["iteration"], n_iterations) self.assertEqual(checkpoint["epoch"], 2) _test_checkpoint("./checkpoints/test/latest.pt", True) diff --git a/test/util/test_modelzoo.py b/test/util/test_modelzoo.py index 78ca700e..38f2a826 100644 --- a/test/util/test_modelzoo.py +++ b/test/util/test_modelzoo.py @@ -11,6 +11,11 @@ from torch_em.model import UNet2d from torch_em.trainer import DefaultTrainer +try: + import onnx +except ImportError: + onnx = None + class ExpandChannels: def __init__(self, n_channels): @@ -89,14 +94,21 @@ def test_export_single_channel(self): def test_export_multi_channel(self): self._test_export(4) - def test_add_weights(self): + def test_add_weights_torchscript(self): from torch_em.util.modelzoo import add_weight_formats self._test_export(1) - additional_formats = ["onnx", "torchscript"] + additional_formats = ["torchscript"] add_weight_formats(self.save_folder, additional_formats) - self.assertTrue(os.path.join(self.save_folder, "weigths.onnx")) self.assertTrue(os.path.join(self.save_folder, "weigths-torchscript.pt")) + 
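Stepping back to the mean-teacher tests above: their unsupervised loaders rely on the new `augmentations` argument of `RawDataset` (see the dataset changes below). When a pair of augmentations is passed, every sample is returned as two differently augmented views of the same patch, which is exactly what `MeanTeacherTrainer` unpacks as `(xu1, xu2)`. A small sanity-check sketch, reusing the same test data layout as in the test case:

```python
import torch_em

# Assumes ./tmp/data.h5 with a "raw" dataset, e.g. as written by create_segmentation_test_data.
augmentations = (
    torch_em.transform.raw.GaussianBlur(),
    torch_em.transform.raw.GaussianBlur(),
)
ds = torch_em.data.RawDataset(
    raw_path="./tmp/data.h5", raw_key="raw",
    patch_shape=(1, 64, 64), ndim=2, n_samples=4,
    augmentations=augmentations,
)
loader = torch_em.segmentation.get_data_loader(ds, batch_size=1, shuffle=True)

x1, x2 = next(iter(loader))  # two augmented views of the same patch
assert x1.shape == x2.shape
```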
@unittest.skipIf(onnx is None, "Needs onnx") + def test_add_weights_onnx(self): + from torch_em.util.modelzoo import add_weight_formats + self._test_export(1) + additional_formats = ["onnx"] + add_weight_formats(self.save_folder, additional_formats) + self.assertTrue(os.path.join(self.save_folder, "weigths.onnx")) + if __name__ == "__main__": unittest.main() diff --git a/torch_em/data/raw_dataset.py b/torch_em/data/raw_dataset.py index 3bd16cf1..d54fd083 100644 --- a/torch_em/data/raw_dataset.py +++ b/torch_em/data/raw_dataset.py @@ -31,6 +31,7 @@ def __init__( sampler=None, ndim=None, with_channels=False, + augmentations=None, ): self.raw_path = raw_path self.raw_key = raw_key @@ -60,6 +61,10 @@ def __init__( self.sampler = sampler self.dtype = dtype + if augmentations is not None: + assert len(augmentations) == 2 + self.augmentations = augmentations + self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples self.sample_shape = patch_shape @@ -124,10 +129,18 @@ def __getitem__(self, index): if self.transform is not None: raw = self.transform(raw) + if isinstance(raw, list): + assert len(raw) == 1 + raw = raw[0] if self.trafo_halo is not None: raw = self.crop(raw) raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype) + if self.augmentations is not None: + aug1, aug2 = self.augmentations + raw1, raw2 = aug1(raw), aug2(raw) + return raw1, raw2 + return raw # need to overwrite pickle to support h5py diff --git a/torch_em/data/raw_image_collection_dataset.py b/torch_em/data/raw_image_collection_dataset.py index 741a05bb..d6462a05 100644 --- a/torch_em/data/raw_image_collection_dataset.py +++ b/torch_em/data/raw_image_collection_dataset.py @@ -36,6 +36,7 @@ def __init__( dtype=torch.float32, n_samples=None, sampler=None, + augmentations=None, ): self._check_inputs(raw_image_paths) self.raw_images = raw_image_paths @@ -56,6 +57,10 @@ def __init__( self._len = n_samples self.sample_random_index = True + if augmentations is not None: + assert len(augmentations) == 2 + self.augmentations = augmentations + def __len__(self): return self._len @@ -104,13 +109,21 @@ def _get_sample(self, index): def __getitem__(self, index): raw = self._get_sample(index) + if self.raw_transform is not None: raw = self.raw_transform(raw) + if self.transform is not None: raw = self.transform(raw) assert len(raw) == 1 raw = raw[0] # if self.trafo_halo is not None: # raw = self.crop(raw) + raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype) + if self.augmentations is not None: + aug1, aug2 = self.augmentations + raw1, raw2 = aug1(raw), aug2(raw) + return raw1, raw2 + return raw diff --git a/torch_em/self_training/__init__.py b/torch_em/self_training/__init__.py new file mode 100644 index 00000000..f5d2a38e --- /dev/null +++ b/torch_em/self_training/__init__.py @@ -0,0 +1,4 @@ +from .logger import SelfTrainingTensorboardLogger +from .loss import DefaultSelfTrainingLoss, DefaultSelfTrainingLossAndMetric +from .mean_teacher import MeanTeacherTrainer +from .pseudo_labeling import DefaultPseudoLabeler diff --git a/torch_em/self_training/augmentations.py b/torch_em/self_training/augmentations.py new file mode 100644 index 00000000..f6d81a22 --- /dev/null +++ b/torch_em/self_training/augmentations.py @@ -0,0 +1 @@ +# TODO build up weak and strong augmentation libraries diff --git a/torch_em/self_training/fix_match.py b/torch_em/self_training/fix_match.py new file mode 100644 index 00000000..e69de29b diff --git a/torch_em/self_training/logger.py 
b/torch_em/self_training/logger.py new file mode 100644 index 00000000..3ecbfa9d --- /dev/null +++ b/torch_em/self_training/logger.py @@ -0,0 +1,84 @@ +import os + +import torch +import torch_em + +from torchvision.utils import make_grid + + +class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger): + def __init__(self, trainer, save_root, **unused_kwargs): + super().__init__(trainer, save_root) + self.my_root = save_root + self.log_dir = f"./logs/{trainer.name}" if self.my_root is None else\ + os.path.join(self.my_root, "logs", trainer.name) + os.makedirs(self.log_dir, exist_ok=True) + + self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir) + self.log_image_interval = trainer.log_image_interval + + def _add_supervised_images(self, step, name, x, y, pred): + if x.ndim == 5: + assert y.ndim == pred.ndim == 5 + zindex = x.shape[2] // 2 + x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex] + + grid = make_grid( + [torch_em.transform.raw.normalize(x[0]), y[0, 0:1], pred[0, 0:1]], + padding=8 + ) + self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step) + + def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter): + if x1.ndim == 5: + assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5 + zindex = x1.shape[2] // 2 + x1, x2, pred = x1[:, :, zindex], x2[:, :, zindex], pred[:, :, zindex] + pseudo_labels = pseudo_labels[:, :, zindex] + if label_filter is not None: + assert label_filter.ndim == 5 + label_filter = label_filter[:, :, zindex] + + images = [ + torch_em.transform.raw.normalize(x1[0]), + torch_em.transform.raw.normalize(x2[0]), + pred[0, 0:1], pseudo_labels[0, 0:1], + ] + im_name = f"{name}/unsupervised/aug1-aug2-prediction-pseudolabels" + if label_filter is not None: + images.append(label_filter[0, 0:1]) + name += "-labelfilter" + grid = make_grid(images, nrow=2, padding=8) + self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step) + + def log_combined_loss(self, step, loss): + self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step) + + def log_lr(self, step, lr): + self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step) + + def log_train_supervised(self, step, loss, x, y, pred): + self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step) + if step % self.log_image_interval == 0: + self._add_supervised_images(step, "validation", x, y, pred) + + def log_validation_supervised(self, step, metric, loss, x, y, pred): + self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step) + self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step) + self._add_supervised_images(step, "validation", x, y, pred) + + def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None): + self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step) + if step % self.log_image_interval == 0: + self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter) + + def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None): + self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step) + self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step) + self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter) + + def 
log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None): + self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step) + self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step) + if gt_metric is not None: + self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step) diff --git a/torch_em/self_training/loss.py b/torch_em/self_training/loss.py new file mode 100644 index 00000000..030c116b --- /dev/null +++ b/torch_em/self_training/loss.py @@ -0,0 +1,53 @@ +import torch.nn as nn +import torch_em + + +class DefaultSelfTrainingLoss(nn.Module): + """Loss function for self training. + + Parameters: + loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss) + activation [nn.Module, callable] - the activation function to be applied to the prediction + before passing it to the loss. (default: None) + """ + def __init__(self, loss=torch_em.loss.DiceLoss(), activation=None): + super().__init__() + self.activation = activation + self.loss = loss + + def __call__(self, model, input_, labels, label_filter=None): + prediction = model(input_) + if self.activation is not None: + prediction = self.activation(prediction) + if label_filter is None: + loss = self.loss(prediction, labels) + else: + loss = self.loss(prediction * label_filter, labels * label_filter) + return loss + + +class DefaultSelfTrainingLossAndMetric(nn.Module): + """Loss and metric function for self training. + + Parameters: + loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss) + metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss) + activation [nn.Module, callable] - the activation function to be applied to the prediction + before passing it to the loss. (default: None) + """ + def __init__(self, loss=torch_em.loss.DiceLoss(), metric=torch_em.loss.DiceLoss(), activation=None): + super().__init__() + self.activation = activation + self.loss = loss + self.metric = metric + + def __call__(self, model, input_, labels, label_filter=None): + prediction = model(input_) + if self.activation is not None: + prediction = self.activation(prediction) + if label_filter is None: + loss = self.loss(prediction, labels) + else: + loss = self.loss(prediction * label_filter, labels * label_filter) + metric = self.metric(prediction, labels) + return loss, metric diff --git a/torch_em/self_training/mean_teacher.py b/torch_em/self_training/mean_teacher.py new file mode 100644 index 00000000..2aa760d3 --- /dev/null +++ b/torch_em/self_training/mean_teacher.py @@ -0,0 +1,349 @@ +import time +from copy import deepcopy + +import torch +import torch_em + +from .logger import SelfTrainingTensorboardLogger + + +class Dummy(torch.nn.Module): + pass + + +class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer): + """This trainer implements self-traning for semi-supervised learning and domain following the 'MeanTeacher' approach + of Tarvainen & Vapola (https://arxiv.org/abs/1703.01780). This approach uses a teacher model derived from the + student model via EMA of weights to predict pseudo-labels on unlabeled data. + We support two training strategies: joint training on labeled and unlabeled data + (with a supervised and unsupervised loss function). And training only on the unsupervised data. + + This class expects the following data loaders: + - unsupervised_train_loader: Returns two augmentations of the same input. 
+ - supervised_train_loader (optional): Returns input and labels. + - unsupervised_val_loader (optional): Same as unsupervised_train_loader + - supervised_val_loader (optional): Same as supervised_train_loader + At least one of unsupervised_val_loader and supervised_val_loader must be given. + + And the following elements to customize the pseudo labeling: + - pseudo_labeler: to compute the psuedo-labels + - Parameters: teacher, teacher_input + - Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None) + - unsupervised_loss: the loss between model predictions and pseudo labels + - Parameters: model, model_input, pseudo_labels, label_filter + - Returns: loss + - supervised_loss (optional): the supervised loss function + - Parameters: model, input, labels + - Returns: loss + - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric + - Parameters: model, model_input, pseudo_labels, label_filter + - Returns: loss, metric + - supervised_loss_and_metric (optional): the supervised loss function and metric + - Parameters: model, input, labels + - Returns: loss, metric + At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given. + + If the parameter reinit_teacher is set to true, the teacher weights are re-initialized. + If it is None, the most appropriate initialization scheme for the training approach is chosen: + - semi-supervised training -> reinit, because we usually train a model from scratch + - unsupervised training -> do not reinit, because we usually fine-tune a model + + Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader' + for setting the ratio between supervised and unsupervised training samples + + Parameters: + model [nn.Module] - + unsupervised_train_loader [torch.DataLoader] - + unsupervised_loss [callable] - + supervised_train_loader [torch.DataLoader] - (default: None) + supervised_loss [callable] - (default: None) + unsupervised_loss_and_metric [callable] - (default: None) + supervised_loss_and_metric [callable] - (default: None) + logger [TorchEmLogger] - (default: SelfTrainingTensorboardLogger) + momentum [float] - (default: 0.999) + reinit_teacher [bool] - (default: None) + **kwargs - keyword arguments for torch_em.DataLoader + """ + + def __init__( + self, + model, + unsupervised_train_loader, + unsupervised_loss, + pseudo_labeler, + supervised_train_loader=None, + unsupervised_val_loader=None, + supervised_val_loader=None, + supervised_loss=None, + unsupervised_loss_and_metric=None, + supervised_loss_and_metric=None, + logger=SelfTrainingTensorboardLogger, + momentum=0.999, + reinit_teacher=None, + **kwargs + ): + # Do we have supervised data or not? + if supervised_train_loader is None: + # No. -> We use the unsupervised training logic. + train_loader = unsupervised_train_loader + self._train_epoch_impl = self._train_epoch_unsupervised + else: + # Yes. -> We use the semi-supervised training logic. + assert supervised_loss is not None + train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\ + else unsupervised_train_loader + self._train_epoch_impl = self._train_epoch_semisupervised + self.unsupervised_train_loader = unsupervised_train_loader + self.supervised_train_loader = supervised_train_loader + + # Check that we have at least one of supvervised / unsupervised val loader. 
+ assert sum(( + supervised_val_loader is not None, + unsupervised_val_loader is not None, + )) > 0 + self.supervised_val_loader = supervised_val_loader + self.unsupervised_val_loader = unsupervised_val_loader + + if self.unsupervised_val_loader is None: + val_loader = self.supervised_val_loader + else: + val_loader = self.unsupervised_val_loader + + # Check that we have at least one of supervised / unsupervised loss and metric. + assert sum(( + supervised_loss_and_metric is not None, + unsupervised_loss_and_metric is not None, + )) > 0 + self.supervised_loss_and_metric = supervised_loss_and_metric + self.unsupervised_loss_and_metric = unsupervised_loss_and_metric + + super().__init__( + model=model, train_loader=train_loader, val_loader=val_loader, + loss=Dummy(), metric=Dummy(), logger=logger, **kwargs + ) + + self.unsupervised_loss = unsupervised_loss + self.supervised_loss = supervised_loss + + self.pseudo_labeler = pseudo_labeler + + self.momentum = momentum + + # determine how we initialize the teacher weights (copy or reinitialization) + if reinit_teacher is None: + # semisupervised training: reinitialize + # unsupervised training: copy + self.reinit_teacher = supervised_train_loader is not None + else: + self.reinit_teacher = reinit_teacher + + with torch.no_grad(): + self.teacher = deepcopy(self.model) + if self.reinit_teacher: + for layer in self.teacher.children(): + if hasattr(layer, "reset_parameters"): + layer.reset_parameters() + for param in self.teacher.parameters(): + param.requires_grad = False + + self._kwargs = kwargs + + def _momentum_update(self): + # if we reinit the teacher we perform much faster updates (low momentum) in the first iterations + # to avoid a large gap between teacher and student weights, leading to inconsistent predictions + # if we don't reinit this is not necessary + if self.reinit_teacher: + current_momentum = min(1 - 1 / (self._iteration + 1), self.momentum) + else: + current_momentum = self.momentum + + for param, param_teacher in zip(self.model.parameters(), self.teacher.parameters()): + param_teacher.data = param_teacher.data * current_momentum + param.data * (1. - current_momentum) + + # TODO I think we need to serialize more things here for all the loaders etc. + # + # functionality for saving checkpoints and initialization + # + + def save_checkpoint(self, name, best_metric): + teacher_state = {"teacher_state": self.teacher.state_dict()} + super().save_checkpoint(name, best_metric, **teacher_state) + + def load_checkpoint(self, checkpoint="best"): + save_dict = super().load_checkpoint(checkpoint) + self.teacher.load_state_dict(save_dict["teacher_state"]) + self.teacher.to(self.device) + return save_dict + + def _initialize(self, iterations, load_from_checkpoint): + best_metric = super()._initialize(iterations, load_from_checkpoint) + self.teacher.to(self.device) + return best_metric + + # + # training and validation functionality + # + + def _train_epoch_unsupervised(self, progress, forward_context, backprop): + self.model.train() + + n_iter = 0 + t_per_iter = time.time() + + # Sample from the unsupervised loader. + for xu1, xu2 in self.unsupervised_train_loader: + xu1, xu2 = xu1.to(self.device), xu2.to(self.device) + + teacher_input, model_input = xu1, xu2 + + self.optimizer.zero_grad() + # Perform unsupervised training + with forward_context(): + # Compute the pseudo labels. 
+ pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input) + loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter) + backprop(loss) + + if self.logger is not None: + with torch.no_grad(), forward_context(): + pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None + self.logger.log_train_unsupervised( + self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter + ) + lr = [pm["lr"] for pm in self.optimizer.param_groups][0] + self.logger.log_lr(self._iteration, lr) + + with torch.no_grad(): + self._momentum_update() + + self._iteration += 1 + n_iter += 1 + if self._iteration >= self.max_iteration: + break + progress.update(1) + + t_per_iter = (time.time() - t_per_iter) / n_iter + return t_per_iter + + def _train_epoch_semisupervised(self, progress, forward_context, backprop): + self.model.train() + + n_iter = 0 + t_per_iter = time.time() + + # Sample from both the supervised and unsupervised loader. + for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader): + xs, ys = xs.to(self.device), ys.to(self.device) + xu1, xu2 = xu1.to(self.device), xu2.to(self.device) + + # Perform supervised training. + self.optimizer.zero_grad() + with forward_context(): + # We pass the model, the input and the labels to the supervised loss function, + # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet. + supervised_loss = self.supervised_loss(self.model, xs, ys) + + teacher_input, model_input = xu1, xu2 + # Perform unsupervised training + with forward_context(): + pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input) + unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter) + + loss = (supervised_loss + unsupervised_loss) / 2 + backprop(loss) + + if self.logger is not None: + with torch.no_grad(), forward_context(): + unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None + supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None + + self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred) + self.logger.log_train_unsupervised( + self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter + ) + + self.logger.log_combined_loss(self._iteration, loss) + lr = [pm["lr"] for pm in self.optimizer.param_groups][0] + self.logger.log_lr(self._iteration, lr) + + with torch.no_grad(): + self._momentum_update() + + self._iteration += 1 + n_iter += 1 + if self._iteration >= self.max_iteration: + break + progress.update(1) + + t_per_iter = (time.time() - t_per_iter) / n_iter + return t_per_iter + + def _validate_supervised(self, forward_context): + metric_val = 0.0 + loss_val = 0.0 + + for x, y in self.supervised_val_loader: + x, y = x.to(self.device), y.to(self.device) + with forward_context(): + loss, metric = self.supervised_loss_and_metric(self.model, x, y) + loss_val += loss.item() + metric_val += metric.item() + + metric_val /= len(self.supervised_val_loader) + loss_val /= len(self.supervised_val_loader) + + if self.logger is not None: + with forward_context(): + pred = self.model(x) + self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred) + + return metric_val + + def _validate_unsupervised(self, forward_context): + metric_val = 0.0 + loss_val = 0.0 + + for x1, x2 in self.unsupervised_val_loader: + x1, x2 = 
x1.to(self.device), x2.to(self.device) + teacher_input, model_input = x1, x2 + with forward_context(): + pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input) + loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter) + loss_val += loss.item() + metric_val += metric.item() + + metric_val /= len(self.unsupervised_val_loader) + loss_val /= len(self.unsupervised_val_loader) + + if self.logger is not None: + with forward_context(): + pred = self.model(model_input) + self.logger.log_validation_unsupervised( + self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter + ) + + return metric_val + + def _validate_impl(self, forward_context): + self.model.eval() + + with torch.no_grad(): + + if self.supervised_val_loader is None: + supervised_metric = None + else: + supervised_metric = self._validate_supervised(forward_context) + + if self.unsupervised_val_loader is None: + unsupervised_metric = None + else: + unsupervised_metric = self._validate_unsupervised(forward_context) + + if unsupervised_metric is None: + metric = supervised_metric + elif supervised_metric is None: + metric = unsupervised_metric + else: + metric = (supervised_metric + unsupervised_metric) / 2 + + return metric diff --git a/torch_em/self_training/pseudo_labeling.py b/torch_em/self_training/pseudo_labeling.py new file mode 100644 index 00000000..1d84f66d --- /dev/null +++ b/torch_em/self_training/pseudo_labeling.py @@ -0,0 +1,39 @@ +import torch + + +class DefaultPseudoLabeler: + """Compute pseudo labels. + + Parameters: + activation [nn.Module, callable] - activation function applied to the teacher prediction. + confidence_threshold [float] - threshold for computing a mask for filterign the pseudo labels. + If none is given no mask will be computed (default: None) + threshold_from_both_sides [bool] - whether to include both values bigger than the threshold + and smaller than 1 - it, or only values bigger than it in the mask. 
+ The former should be used for binary labels, the latter for for multiclass labels (default: False) + """ + def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True): + self.activation = activation + self.confidence_threshold = confidence_threshold + self.threshold_from_both_sides = threshold_from_both_sides + + def _compute_label_mask_both_sides(self, pseudo_labels): + upper_threshold = self.confidence_threshold + lower_threshold = 1.0 - self.confidence_threshold + mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32) + return mask + + def _compute_label_mask_one_side(self, pseudo_labels): + mask = (pseudo_labels >= self.confidence_threshold) + return mask + + def __call__(self, teacher, input_): + pseudo_labels = teacher(input_) + if self.activation is not None: + pseudo_labels = self.activation(pseudo_labels) + if self.confidence_threshold is None: + label_mask = None + else: + label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides\ + else self._compute_label_mask_one_side(pseudo_labels) + return pseudo_labels, label_mask diff --git a/torch_em/trainer/spoco_trainer.py b/torch_em/trainer/spoco_trainer.py index 3646465f..1d46dd17 100644 --- a/torch_em/trainer/spoco_trainer.py +++ b/torch_em/trainer/spoco_trainer.py @@ -2,8 +2,8 @@ from copy import deepcopy import torch -import torch.cuda.amp as amp from .default_trainer import DefaultTrainer +from .tensorboard_logger import TensorboardLogger class SPOCOTrainer(DefaultTrainer): @@ -13,9 +13,10 @@ def __init__( momentum=0.999, semisupervised_loss=None, semisupervised_loader=None, + logger=TensorboardLogger, **kwargs, ): - super().__init__(model=model, **kwargs) + super().__init__(model=model, logger=logger, **kwargs) self.momentum = momentum # copy the model and don"t require gradients for it self.model2 = deepcopy(self.model) @@ -46,7 +47,7 @@ def _initialize(self, iterations, load_from_checkpoint): self.model2.to(self.device) return best_metric - def _train_epoch_semisupervised(self, progress): + def _train_epoch_semisupervised(self, progress, forward_context, backprop): self.model.train() self.model2.train() progress.set_description( @@ -57,75 +58,17 @@ def _train_epoch_semisupervised(self, progress): x = x.to(self.device) self.optimizer.zero_grad() - prediction = self.model(x) - with torch.no_grad(): - self._momentum_update() - prediction2 = self.model2(x) - loss = self.semisupervised_loss(prediction, prediction2) - loss.backward() - self.optimizer.step() - - def _train_epoch(self, progress): - self.model.train() - self.model2.train() - - n_iter = 0 - t_per_iter = time.time() - for x, y in self.train_loader: - x, y = x.to(self.device), y.to(self.device) - - self.optimizer.zero_grad() - - prediction = self.model(x) - with torch.no_grad(): - self._momentum_update() - prediction2 = self.model2(x) - if self._iteration % self.log_image_interval == 0: - prediction.retain_grad() - loss = self.loss((prediction, prediction2), y) - - loss.backward() - self.optimizer.step() - - lr = [pm["lr"] for pm in self.optimizer.param_groups][0] - if self.logger is not None: - self.logger.log_train(self._iteration, loss, lr, - x, y, prediction, - log_gradients=True) - - self._iteration += 1 - n_iter += 1 - if self._iteration >= self.max_iteration: - break - progress.update(1) - - if self.semisupervised_loader is not None: - self._train_epoch_semisupervised(progress) - t_per_iter = (time.time() - t_per_iter) / n_iter - return 
t_per_iter - - def _train_epoch_semisupervised_mixed(self, progress): - self.model.train() - self.model2.train() - progress.set_description( - f"Run semi-supervised training for {len(self.semisupervised_loader)} iterations", refresh=True - ) - - for x in self.semisupervised_loader: - x = x.to(self.device) - self.optimizer.zero_grad() - - with amp.autocast(): + with forward_context(): prediction = self.model(x) with torch.no_grad(): - self._momentum_update() prediction2 = self.model2(x) loss = self.semisupervised_loss(prediction, prediction2) - self.scaler.scale(loss).backward() - self.scaler.step(self.optimizer) - self.scaler.update() + backprop(loss) - def _train_epoch_mixed(self, progress): + with torch.no_grad(): + self._momentum_update() + + def _train_epoch_impl(self, progress, forward_context, backprop): self.model.train() self.model2.train() @@ -136,21 +79,25 @@ def _train_epoch_mixed(self, progress): self.optimizer.zero_grad() - with amp.autocast(): + with forward_context(): prediction = self.model(x) with torch.no_grad(): - self._momentum_update() prediction2 = self.model2(x) loss = self.loss((prediction, prediction2), y) - self.scaler.scale(loss).backward() - self.scaler.step(self.optimizer) - self.scaler.update() + if self._iteration % self.log_image_interval == 0: + prediction.retain_grad() + + backprop(loss) + + with torch.no_grad(): + self._momentum_update() lr = [pm["lr"] for pm in self.optimizer.param_groups][0] if self.logger is not None: self.logger.log_train(self._iteration, loss, lr, - x, y, prediction) + x, y, prediction, + log_gradients=True) self._iteration += 1 n_iter += 1 @@ -159,11 +106,11 @@ def _train_epoch_mixed(self, progress): progress.update(1) if self.semisupervised_loader is not None: - self._train_epoch_semisupervised_mixed(progress) + self._train_epoch_semisupervised(progress, forward_context, backprop) t_per_iter = (time.time() - t_per_iter) / n_iter return t_per_iter - def _validate(self): + def _validate_impl(self, forward_context): self.model.eval() self.model2.eval() @@ -173,39 +120,14 @@ def _validate(self): with torch.no_grad(): for x, y in self.val_loader: x, y = x.to(self.device), y.to(self.device) - prediction = self.model(x) - prediction2 = self.model2(x) + with forward_context(): + prediction = self.model(x) + prediction2 = self.model2(x) loss += self.loss((prediction, prediction2), y).item() metric += self.metric(prediction, y).item() metric /= len(self.val_loader) loss /= len(self.val_loader) if self.logger is not None: - self.logger.log_validation(self._iteration, metric, loss, - x, y, prediction) + self.logger.log_validation(self._iteration, metric, loss, x, y, prediction) return metric - - def _validate_mixed(self): - self.model.eval() - self.model2.eval() - - metric_val = 0.0 - loss_val = 0.0 - - with torch.no_grad(): - for x, y in self.val_loader: - x, y = x.to(self.device), y.to(self.device) - with amp.autocast(): - prediction = self.model(x) - prediction2 = self.model2(x) - loss = self.loss((prediction, prediction2), y) - metric = self.metric(prediction, y) - loss_val += loss - metric_val += metric - - metric_val /= len(self.val_loader) - loss_val /= len(self.val_loader) - if self.logger is not None: - self.logger.log_validation(self._iteration, metric, loss, - x, y, prediction) - return metric_val diff --git a/torch_em/transform/raw.py b/torch_em/transform/raw.py index e20344c1..cd89a10e 100644 --- a/torch_em/transform/raw.py +++ b/torch_em/transform/raw.py @@ -8,18 +8,6 @@ # -def standardize(raw, mean=None, std=None, axis=None, 
eps=1e-7): - raw = raw.astype("float32") - - mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean - raw -= mean - - std = raw.std(axis=axis, keepdims=True) if std is None else std - raw /= (std + eps) - - return raw - - TORCH_DTYPES = { "float16": torch.float16, "float32": torch.float32, @@ -42,7 +30,21 @@ def cast(inpt, typestring): return inpt.astype(typestring) -def _normalize_torch(tensor, minval=None, maxval=None, axis=None, eps=1e-7): +def standardize(raw, mean=None, std=None, axis=None, eps=1e-7): + raw = cast(raw, "float32") + + # mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean + mean = raw.mean() if mean is None else mean + raw -= mean + + # std = raw.std(axis=axis, keepdims=True) if std is None else std + std = raw.std() if std is None else std + raw /= (std + eps) + + return raw + + +def _normalize_torch(tensor, minval, maxval, axis, eps): if axis: # torch returns torch.return_types.min or torch.return_types.max minval = torch.amin(tensor, dim=axis, keepdim=True) if minval is None else minval tensor -= minval diff --git a/torch_em/util/modelzoo.py b/torch_em/util/modelzoo.py index 2f2a1693..5c0273f3 100644 --- a/torch_em/util/modelzoo.py +++ b/torch_em/util/modelzoo.py @@ -118,7 +118,9 @@ def _write_depedencies(export_folder, dependencies): if dependencies is None: ver = torch.__version__ major, minor = list(map(int, ver.split(".")[:2])) - assert major == 1 + assert major in (1, 2) + if major == 2: + warn("Modelzoo functionality is not fully tested for PyTorch 2") # the torch zip layout changed for a few versions: torch_min_version = "1.0" if minor > 6 and minor < 10: @@ -129,7 +131,7 @@ def _write_depedencies(export_folder, dependencies): dependencies = { "channels": ["pytorch", "conda-forge"], "name": "torch-em-deploy", - "dependencies": [f"pytorch>={torch_min_version},<2.0"] + "dependencies": [f"pytorch>={torch_min_version}"] } with open(dep_path, "w") as f: yaml.dump(dependencies, f)
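Finally, a note on the reworked `standardize` at the end of the diff: it now goes through `cast`, so torch tensors are handled as well, but mean and standard deviation are computed over the whole array; the `axis` argument is still accepted while the per-axis (`keepdims`) computation stays commented out. A quick sketch of the expected behaviour:

```python
import numpy as np
from torch_em.transform.raw import standardize

x = 100 * np.random.rand(1, 256, 256).astype("float32")
y = standardize(x)

# Global statistics: approximately zero mean and unit standard deviation,
# regardless of what is passed for the (currently unused) axis argument.
print(float(y.mean()), float(y.std()))
```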