From 54625a6e48b82489544561a6b591ce9c0c9da1ca Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 29 Mar 2023 22:48:27 +0200 Subject: [PATCH 1/7] Update Mean-Teacher Self-Training Scheme(s) --- .../livecell/common.py | 17 ++++++++++------- .../livecell/unet_adamt.py | 8 ++++---- .../livecell/unet_mean_teacher.py | 13 ++++++++----- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/experiments/probabilistic_domain_adaptation/livecell/common.py b/experiments/probabilistic_domain_adaptation/livecell/common.py index c402c4fe..7c6ae2a1 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/common.py +++ b/experiments/probabilistic_domain_adaptation/livecell/common.py @@ -55,9 +55,9 @@ def get_unet(): return UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid", depth=4) -def load_model(model, ckpt, state="model_state", device=None): - state = torch.load(os.path.join(ckpt, "best.pt"))[state] - model.load_state_dict(state) +def load_model(model, ckpt, get_model=get_unet, device=None): + model = get_model() + model = torch_em.util.get_trainer(ckpt).model if device is not None: model.to(device) return model @@ -88,9 +88,12 @@ def evaluate_transfered_model( if out_folder is not None: os.makedirs(out_folder, exist_ok=True) - ckpt = f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}" + if args.save_root is None: + ckpt = f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}" + else: + ckpt = args.save_root + f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}" model = get_model() - model = load_model(model, ckpt, device=device, state=model_state) + model = load_model(model, ckpt, device=device) label_paths = glob(os.path.join(label_root, ct_trg, "*.tif")) scores = [] @@ -190,7 +193,7 @@ def _get_image_paths(args, split, cell_type): return image_paths -def get_unsupervised_loader(args, split, cell_type, teacher_augmentation, student_augmentation): +def get_unsupervised_loader(args, batch_size, split, cell_type, teacher_augmentation, student_augmentation): patch_shape = (256, 256) def _parse_aug(aug): @@ -211,7 +214,7 @@ def _parse_aug(aug): image_paths, patch_shape, raw_transform, transform, augmentations=augmentations ) - loader = torch_em.segmentation.get_data_loader(ds, batch_size=args.batch_size, num_workers=8, shuffle=True) + loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=8, shuffle=True) return loader diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py index ff9590f2..e3d7c2fd 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py @@ -22,7 +22,7 @@ def check_loader(args, n_images=5): def _train_source_target(args, source_cell_type, target_cell_type): - model = common.get_model() + model = common.get_unet() optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5) @@ -36,11 +36,11 @@ def _train_source_target(args, source_cell_type, target_cell_type): supervised_train_loader = common.get_supervised_loader(args, "train", source_cell_type, args.batch_size) supervised_val_loader = common.get_supervised_loader(args, "val", source_cell_type, 1) unsupervised_train_loader = common.get_unsupervised_loader( - args, "train", target_cell_type, + args, args.batch_size, "train", target_cell_type, 
teacher_augmentation="weak", student_augmentation="weak", ) unsupervised_val_loader = common.get_unsupervised_loader( - args, "val", target_cell_type, + args, 1, "val", target_cell_type, teacher_augmentation="weak", student_augmentation="weak", ) @@ -98,7 +98,7 @@ def run_evaluation(args): def main(): parser = common.get_parser(default_iterations=75000, default_batch_size=4) - parser.add_argument("--confidence_threshold", default=0.9) + parser.add_argument("--confidence_threshold", default=None, type=float) args = parser.parse_args() if args.phase in ("c", "check"): check_loader(args) diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py index 79bc309b..d4202659 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py @@ -23,10 +23,13 @@ def check_loader(args, n_images=5): def _train_source_target(args, source_cell_type, target_cell_type): model = common.get_unet() - src_checkpoint = f"./checkpoints/unet_source/{source_cell_type}" + if args.save_root is None: + src_checkpoint = f"./checkpoints/unet_source/{source_cell_type}" + else: + src_checkpoint = args.save_root + f"checkpoints/unet_source/{source_cell_type}" model = common.load_model(model, src_checkpoint) - optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5) # self training functionality @@ -37,11 +40,11 @@ def _train_source_target(args, source_cell_type, target_cell_type): # data loaders unsupervised_train_loader = common.get_unsupervised_loader( - args, "train", target_cell_type, + args, args.batch_size, "train", target_cell_type, teacher_augmentation="weak", student_augmentation="weak", ) unsupervised_val_loader = common.get_unsupervised_loader( - args, "val", target_cell_type, + args, 1, "val", target_cell_type, teacher_augmentation="weak", student_augmentation="weak", ) @@ -97,7 +100,7 @@ def run_evaluation(args): def main(): parser = common.get_parser(default_iterations=25000, default_batch_size=8) - parser.add_argument("--confidence_threshold", default=0.9) + parser.add_argument("--confidence_threshold", default=None, type=float) args = parser.parse_args() if args.phase in ("c", "check"): check_loader(args) From 18f289e6d6d8067ed625d80be0107b01bf9ab7ac Mon Sep 17 00:00:00 2001 From: anwai98 Date: Fri, 31 Mar 2023 13:47:20 +0200 Subject: [PATCH 2/7] Update Mean-Teacher Training + Add FixMatch-based Trainings --- .../livecell/common.py | 54 ++- .../livecell/unet_adamatch.py | 128 ++++++++ .../livecell/unet_fixmatch.py | 134 ++++++++ torch_em/self_training/__init__.py | 1 + torch_em/self_training/fix_match.py | 309 ++++++++++++++++++ 5 files changed, 614 insertions(+), 12 deletions(-) create mode 100644 experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py create mode 100644 experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py diff --git a/experiments/probabilistic_domain_adaptation/livecell/common.py b/experiments/probabilistic_domain_adaptation/livecell/common.py index 7c6ae2a1..aa0c07a5 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/common.py +++ b/experiments/probabilistic_domain_adaptation/livecell/common.py @@ -19,14 +19,18 @@ from torch_em.util.prediction import predict_with_padding from 
torchvision import transforms
 from tqdm import tqdm
+from torch_em.util import load_model

 CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]


 #
 # The augmentations we use for the LiveCELL experiments:
-# - weak augmenations: blurring and additive gaussian noise
-# - strong augmentations: TODO
+# - weak augmentations:
+# blurring and additive gaussian noise
+#
+# - strong augmentations:
+# blurring, additive gaussian noise and random contrast adjustment (FixMatch expects stronger parameters for each)
 #
@@ -42,9 +46,19 @@ def weak_augmentations(p=0.25):
     return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)


-# TODO
-def strong_augmentations():
-    pass
+def strong_augmentations(p=0.5):
+    norm = torch_em.transform.raw.standardize
+    aug1 = transforms.Compose([
+        norm,
+        transforms.RandomApply([torch_em.transform.raw.GaussianBlur(sigma=(0.6, 3.0))], p=p),
+        transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
+            scale=(0.05, 0.25), clip_kwargs=False)], p=p/2
+        ),
+        transforms.RandomApply([torch_em.transform.raw.RandomContrast(
+            mean=0.0, alpha=(0.33, 3.0), clip_kwargs=False)], p=p
+        )
+    ])
+    return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug1)


 #
@@ -55,12 +69,28 @@ def get_unet():
     return UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid", depth=4)


-def load_model(model, ckpt, get_model=get_unet, device=None):
-    model = get_model()
-    model = torch_em.util.get_trainer(ckpt).model
-    if device is not None:
-        model.to(device)
-    return model
+# Computing the Source Distribution for Distribution Alignment
+def compute_class_distribution(root_folder, label_threshold=0.5):
+
+    bg_list, fg_list = [], []
+    total = 0
+
+    files = glob(os.path.join(root_folder, "*"))
+    assert len(files) > 0, f"Did not find predictions @ {root_folder}"
+
+    for pl_path in files:
+        img = imageio.imread(pl_path)
+        img = np.where(img >= label_threshold, 1, 0)
+        _, counts = np.unique(img, return_counts=True)
+        assert len(counts) == 2
+        bg_list.append(counts[0])
+        fg_list.append(counts[1])
+        total += img.size
+
+    bg_frequency = sum(bg_list) / float(total)
+    fg_frequency = sum(fg_list) / float(total)
+    assert np.isclose(bg_frequency + fg_frequency, 1.0)
+    return [bg_frequency, fg_frequency]


 # use get_model and prediction_function to customize this, e.g. for using it with the PUNet
@@ -93,7 +123,7 @@ def evaluate_transfered_model(
     else:
         ckpt = args.save_root + f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}"
     model = get_model()
-    model = load_model(model, ckpt, device=device)
+    model = load_model(checkpoint=ckpt, model=model, state_key=model_state, device=device)

     label_paths = glob(os.path.join(label_root, ct_trg, "*.tif"))
     scores = []
diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py b/experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py
new file mode 100644
index 00000000..08989be9
--- /dev/null
+++ b/experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py
@@ -0,0 +1,128 @@
+import os
+
+import pandas as pd
+import torch
+import torch_em.self_training as self_training
+
+import common
+
+
+def check_loader(args, n_images=5):
+    from torch_em.util.debug import check_loader
+
+    cell_types = args.cell_types
+    print("The cell types", cell_types, "were selected.")
+    print("Checking the unsupervised loader for the first cell type", cell_types[0])
+
+    loader = common.get_unsupervised_loader(
+        args, args.batch_size, "train", cell_types[0],
+        teacher_augmentation="weak", student_augmentation="strong",
+    )
+    check_loader(loader, n_images)
+
+
+def _train_source_target(args, source_cell_type, target_cell_type):
+    model = common.get_unet()
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
+
+    # self training functionality
+    thresh = args.confidence_threshold
+    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=thresh)
+    loss = self_training.DefaultSelfTrainingLoss()
+    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
+
+    # data loaders
+    supervised_train_loader = common.get_supervised_loader(args, "train", source_cell_type, args.batch_size)
+    supervised_val_loader = common.get_supervised_loader(args, "val", source_cell_type, 1)
+    unsupervised_train_loader = common.get_unsupervised_loader(
+        args, args.batch_size, "train", target_cell_type,
+        teacher_augmentation="weak", student_augmentation="strong",
+    )
+    unsupervised_val_loader = common.get_unsupervised_loader(
+        args, 1, "val", target_cell_type,
+        teacher_augmentation="weak", student_augmentation="strong",
+    )
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+    name = f"unet_adamatch/thresh-{thresh}"
+
+    if args.distribution_alignment:
+        assert args.output is not None
+        print(f"Getting scores for Source {source_cell_type} at Target {target_cell_type}")
+        pred_folder = args.output + f"unet_source/{source_cell_type}/{target_cell_type}/"
+        src_dist = common.compute_class_distribution(pred_folder)
+        name = f"{name}-distro-align"
+    else:
+        src_dist = None
+
+    name = name + f"/{source_cell_type}/{target_cell_type}"
+
+    trainer = self_training.FixMatchTrainer(
+        name=name,
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=scheduler,
+        pseudo_labeler=pseudo_labeler,
+        unsupervised_loss=loss,
+        unsupervised_loss_and_metric=loss_and_metric,
+        supervised_train_loader=supervised_train_loader,
+        unsupervised_train_loader=unsupervised_train_loader,
+        supervised_val_loader=supervised_val_loader,
+        unsupervised_val_loader=unsupervised_val_loader,
+        supervised_loss=loss,
+        supervised_loss_and_metric=loss_and_metric,
+        logger=self_training.SelfTrainingTensorboardLogger,
+        mixed_precision=True,
+        device=device,
+        log_image_interval=100,
+        save_root=args.save_root,
+        source_distribution=src_dist
+    )
+    trainer.fit(args.n_iterations)
+
+
+def _train_source(args, cell_type):
+    for target_cell_type in common.CELL_TYPES:
+        if target_cell_type == cell_type:
+            continue
+        _train_source_target(args, cell_type, target_cell_type)
+
+
+def run_training(args):
+    for cell_type in args.cell_types:
+        print("Start training for cell type:", cell_type)
+        _train_source(args, cell_type)
+
+
+def run_evaluation(args):
+    results = []
+    for ct in args.cell_types:
+        res = common.evaluate_transfered_model(args, ct, "unet_adamatch")
+        results.append(res)
+    results = pd.concat(results)
+    print("Evaluation results:")
+    print(results)
+    result_folder = "./results"
+    os.makedirs(result_folder, exist_ok=True)
+    results.to_csv(os.path.join(result_folder, "unet_adamatch.csv"), index=False)
+
+
+def main():
+    parser = common.get_parser(default_iterations=25000, default_batch_size=8)
+    parser.add_argument("--confidence_threshold", default=None, type=float)
+    parser.add_argument("--distribution_alignment", action='store_true', help="Activates Distribution Alignment")
+    args = parser.parse_args()
+    if args.phase in ("c", "check"):
+        check_loader(args)
+    elif args.phase in ("t", "train"):
+        run_training(args)
+    elif args.phase in ("e", "evaluate"):
+        run_evaluation(args)
+    else:
+        raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py b/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py
new file mode 100644
index 00000000..d5160156
--- /dev/null
+++ b/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py
@@ -0,0 +1,134 @@
+import os
+
+import pandas as pd
+import torch
+import torch_em.self_training as self_training
+from torch_em.util import load_model
+
+import common
+
+
+def check_loader(args, n_images=5):
+    from torch_em.util.debug import check_loader
+
+    cell_types = args.cell_types
+    print("The cell types", cell_types, "were selected.")
+    print("Checking the unsupervised loader for the first cell type", cell_types[0])
+
+    loader = common.get_unsupervised_loader(
+        args, args.batch_size, "train", cell_types[0],
+        teacher_augmentation="weak", student_augmentation="strong",
+    )
+    check_loader(loader, n_images)
+
+
+def _train_source_target(args, source_cell_type, target_cell_type):
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+    model = common.get_unet()
+    if args.save_root is None:
+        src_checkpoint = f"./checkpoints/unet_source/{source_cell_type}"
+    else:
+        src_checkpoint = args.save_root + f"checkpoints/unet_source/{source_cell_type}"
+    model = load_model(checkpoint=src_checkpoint, model=model, device=device)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
+
+    # self training functionality
+    thresh = args.confidence_threshold
+    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=thresh)
+    loss = self_training.DefaultSelfTrainingLoss()
+    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
+
+    # data loaders
+    unsupervised_train_loader = common.get_unsupervised_loader(
+        args, args.batch_size, "train", target_cell_type,
+        teacher_augmentation="weak", student_augmentation="strong",
+    )
+    unsupervised_val_loader = common.get_unsupervised_loader(
+        args, 1, "val", target_cell_type,
+        teacher_augmentation="weak", student_augmentation="strong",
+    )
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+    name = f"unet_fixmatch/thresh-{thresh}"
+
+    if args.distribution_alignment:
+        assert args.output is not None
+        print(f"Getting scores for Source {source_cell_type} at Target {target_cell_type}")
+        pred_folder = args.output + f"unet_source/{source_cell_type}/{target_cell_type}/"
+        src_dist = common.compute_class_distribution(pred_folder)
+        name = f"{name}-distro-align"
+    else:
+        src_dist = None
+
+    name = name + f"/{source_cell_type}/{target_cell_type}"
+
+    trainer = self_training.FixMatchTrainer(
+        name=name,
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=scheduler,
+        pseudo_labeler=pseudo_labeler,
+        unsupervised_loss=loss,
+        unsupervised_loss_and_metric=loss_and_metric,
+        unsupervised_train_loader=unsupervised_train_loader,
+        unsupervised_val_loader=unsupervised_val_loader,
+        supervised_loss=loss,
+        supervised_loss_and_metric=loss_and_metric,
+        logger=self_training.SelfTrainingTensorboardLogger,
+        mixed_precision=True,
+        device=device,
+        log_image_interval=100,
+        save_root=args.save_root,
+        source_distribution=src_dist
+    )
+    trainer.fit(args.n_iterations)
+
+
+def _train_source(args, cell_type):
+    for target_cell_type in common.CELL_TYPES:
+        if target_cell_type == cell_type:
+            continue
+        _train_source_target(args, cell_type, target_cell_type)
+
+
+def run_training(args):
+    for cell_type in args.cell_types:
+        print("Start training for cell type:", cell_type)
+        _train_source(args, cell_type)
+
+
+def run_evaluation(args):
+    results = []
+    for ct in args.cell_types:
+        res = common.evaluate_transfered_model(args, ct, "unet_fixmatch")
+        results.append(res)
+    results = pd.concat(results)
+    print("Evaluation results:")
+    print(results)
+    result_folder = "./results"
+    os.makedirs(result_folder, exist_ok=True)
+    results.to_csv(os.path.join(result_folder, "unet_fixmatch.csv"), index=False)
+
+
+def main():
+    parser = common.get_parser(default_iterations=25000, default_batch_size=8)
+    parser.add_argument("--confidence_threshold", default=None, type=float)
+    parser.add_argument("--distribution_alignment", action='store_true', help="Activates Distribution Alignment")
+    args = parser.parse_args()
+    if args.phase in ("c", "check"):
+        check_loader(args)
+    elif args.phase in ("t", "train"):
+        run_training(args)
+    elif args.phase in ("e", "evaluate"):
+        run_evaluation(args)
+    else:
+        raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/torch_em/self_training/__init__.py b/torch_em/self_training/__init__.py
index f5d2a38e..03b5304c 100644
--- a/torch_em/self_training/__init__.py
+++ b/torch_em/self_training/__init__.py
@@ -1,4 +1,5 @@
 from .logger import SelfTrainingTensorboardLogger
 from .loss import DefaultSelfTrainingLoss, DefaultSelfTrainingLossAndMetric
 from .mean_teacher import MeanTeacherTrainer
+from .fix_match import FixMatchTrainer
 from .pseudo_labeling import DefaultPseudoLabeler
diff --git a/torch_em/self_training/fix_match.py b/torch_em/self_training/fix_match.py
index e69de29b..3c90b7ab 100644
--- a/torch_em/self_training/fix_match.py
+++ b/torch_em/self_training/fix_match.py
@@ -0,0 +1,309 @@
+import time
+
+import torch
+import torch_em
+
+from .logger import SelfTrainingTensorboardLogger
+
+
+class Dummy(torch.nn.Module):
+    pass
+
+
+class FixMatchTrainer(torch_em.trainer.DefaultTrainer):
+    """This trainer implements self-training for semi-supervised learning and domain adaptation
+    following the 'FixMatch' approach
+    of Sohn et al. (https://arxiv.org/abs/2001.07685). This approach uses a (teacher) model derived from the
+    student model via sharing the weights to predict pseudo-labels on unlabeled data.
+    We support two training strategies: joint training on labeled and unlabeled data
+    (with a supervised and an unsupervised loss function), and training only on the unlabeled data.
+
+    This class expects the following data loaders:
+    - unsupervised_train_loader: Returns two augmentations (weak and strong) of the same input.
+    - supervised_train_loader (optional): Returns input and labels.
+    - unsupervised_val_loader (optional): Same as unsupervised_train_loader.
+    - supervised_val_loader (optional): Same as supervised_train_loader.
+    At least one of unsupervised_val_loader and supervised_val_loader must be given.
+
+    And the following elements to customize the pseudo labeling:
+    - pseudo_labeler: to compute the pseudo-labels
+        - Parameters: model, teacher_input
+        - Returns: pseudo_labels, label_filter (<- label filter can for example be mask, weight or None)
+    - unsupervised_loss: the loss between model predictions and pseudo labels
+        - Parameters: model, model_input, pseudo_labels, label_filter
+        - Returns: loss
+    - supervised_loss (optional): the supervised loss function
+        - Parameters: model, input, labels
+        - Returns: loss
+    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
+        - Parameters: model, model_input, pseudo_labels, label_filter
+        - Returns: loss, metric
+    - supervised_loss_and_metric (optional): the supervised loss function and metric
+        - Parameters: model, input, labels
+        - Returns: loss, metric
+    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
+
+    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
+    to set the ratio between supervised and unsupervised training samples.
+
+    Parameters:
+        model [nn.Module] -
+        unsupervised_train_loader [torch.DataLoader] -
+        unsupervised_loss [callable] -
+        supervised_train_loader [torch.DataLoader] - (default: None)
+        supervised_loss [callable] - (default: None)
+        unsupervised_loss_and_metric [callable] - (default: None)
+        supervised_loss_and_metric [callable] - (default: None)
+        logger [TorchEmLogger] - (default: SelfTrainingTensorboardLogger)
+        momentum [float] - (default: 0.999)
+        **kwargs - keyword arguments for torch_em.DataLoader
+    """
+
+    def __init__(
+        self,
+        model,
+        unsupervised_train_loader,
+        unsupervised_loss,
+        pseudo_labeler,
+        supervised_train_loader=None,
+        unsupervised_val_loader=None,
+        supervised_val_loader=None,
+        supervised_loss=None,
+        unsupervised_loss_and_metric=None,
+        supervised_loss_and_metric=None,
+        logger=SelfTrainingTensorboardLogger,
+        source_distribution=None,
+        **kwargs
+    ):
+        # Do we have supervised data or not?
+        if supervised_train_loader is None:
+            # No. -> We use the unsupervised training logic.
+            train_loader = unsupervised_train_loader
+            self._train_epoch_impl = self._train_epoch_unsupervised
+        else:
+            # Yes. -> We use the semi-supervised training logic.
+            assert supervised_loss is not None
+            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
+                else unsupervised_train_loader
+            self._train_epoch_impl = self._train_epoch_semisupervised
+
+        self.unsupervised_train_loader = unsupervised_train_loader
+        self.supervised_train_loader = supervised_train_loader
+
+        # Check that we have at least one of supervised / unsupervised val loader.
+        assert sum((
+            supervised_val_loader is not None,
+            unsupervised_val_loader is not None,
+        )) > 0
+        self.supervised_val_loader = supervised_val_loader
+        self.unsupervised_val_loader = unsupervised_val_loader
+
+        if self.unsupervised_val_loader is None:
+            val_loader = self.supervised_val_loader
+        else:
+            val_loader = self.unsupervised_val_loader
+
+        # Check that we have at least one of supervised / unsupervised loss and metric.
+        assert sum((
+            supervised_loss_and_metric is not None,
+            unsupervised_loss_and_metric is not None,
+        )) > 0
+        self.supervised_loss_and_metric = supervised_loss_and_metric
+        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric
+
+        super().__init__(
+            model=model, train_loader=train_loader, val_loader=val_loader,
+            loss=Dummy(), metric=Dummy(), logger=logger, **kwargs
+        )
+
+        self.unsupervised_loss = unsupervised_loss
+        self.supervised_loss = supervised_loss
+
+        self.pseudo_labeler = pseudo_labeler
+
+        if source_distribution is None:
+            self.source_distribution = None
+        else:
+            self.source_distribution = torch.FloatTensor(source_distribution).to(self.device)
+
+        self._kwargs = kwargs
+
+    def get_distribution_alignment(self, pseudo_labels, label_threshold=0.5):
+        if self.source_distribution is not None:
+            pseudo_labels_binary = torch.where(pseudo_labels >= label_threshold, 1, 0)
+            _, target_distribution = torch.unique(pseudo_labels_binary, return_counts=True)
+            target_distribution = target_distribution / target_distribution.sum()
+            distribution_ratio = self.source_distribution / target_distribution
+            pseudo_labels = torch.where(
+                pseudo_labels < label_threshold,
+                pseudo_labels * distribution_ratio[0],
+                pseudo_labels * distribution_ratio[1]
+            ).clip(0, 1)
+
+        return pseudo_labels
+
+    #
+    # training and validation functionality
+    #
+
+    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
+        self.model.train()
+
+        n_iter = 0
+        t_per_iter = time.time()
+
+        # Sample from the unsupervised loader.
+        for xu1, xu2 in self.unsupervised_train_loader:
+            xu1, xu2 = xu1.to(self.device), xu2.to(self.device)
+
+            teacher_input, model_input = xu1, xu2
+
+            self.optimizer.zero_grad()
+            # Perform unsupervised training
+            with forward_context():
+                # Compute the pseudo labels.
+ pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input) + # Perform distribution alignment for pseudo labels + pseudo_labels = self.get_distribution_alignment(pseudo_labels) + loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter) + backprop(loss) + + if self.logger is not None: + with torch.no_grad(), forward_context(): + pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None + self.logger.log_train_unsupervised( + self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter + ) + lr = [pm["lr"] for pm in self.optimizer.param_groups][0] + self.logger.log_lr(self._iteration, lr) + + self._iteration += 1 + n_iter += 1 + if self._iteration >= self.max_iteration: + break + progress.update(1) + + t_per_iter = (time.time() - t_per_iter) / n_iter + return t_per_iter + + def _train_epoch_semisupervised(self, progress, forward_context, backprop): + self.model.train() + + n_iter = 0 + t_per_iter = time.time() + + # Sample from both the supervised and unsupervised loader. + for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader): + xs, ys = xs.to(self.device), ys.to(self.device) + xu1, xu2 = xu1.to(self.device), xu2.to(self.device) + + # Perform supervised training. + self.optimizer.zero_grad() + with forward_context(): + # We pass the model, the input and the labels to the supervised loss function, + # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet. + supervised_loss = self.supervised_loss(self.model, xs, ys) + + teacher_input, model_input = xu1, xu2 + # Perform unsupervised training + with forward_context(): + # Compute the pseudo labels. + pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input) + # Perform distribution alignment for pseudo labels + pseudo_labels = self.get_distribution_alignment(pseudo_labels) + unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter) + + loss = (supervised_loss + unsupervised_loss) / 2 + backprop(loss) + + if self.logger is not None: + with torch.no_grad(), forward_context(): + unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None + supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None + + self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred) + self.logger.log_train_unsupervised( + self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter + ) + + self.logger.log_combined_loss(self._iteration, loss) + lr = [pm["lr"] for pm in self.optimizer.param_groups][0] + self.logger.log_lr(self._iteration, lr) + + self._iteration += 1 + n_iter += 1 + if self._iteration >= self.max_iteration: + break + progress.update(1) + + t_per_iter = (time.time() - t_per_iter) / n_iter + return t_per_iter + + def _validate_supervised(self, forward_context): + metric_val = 0.0 + loss_val = 0.0 + + for x, y in self.supervised_val_loader: + x, y = x.to(self.device), y.to(self.device) + with forward_context(): + loss, metric = self.supervised_loss_and_metric(self.model, x, y) + loss_val += loss.item() + metric_val += metric.item() + + metric_val /= len(self.supervised_val_loader) + loss_val /= len(self.supervised_val_loader) + + if self.logger is not None: + with forward_context(): + pred = self.model(x) + self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred) + + return metric_val + + def 
_validate_unsupervised(self, forward_context):
+        metric_val = 0.0
+        loss_val = 0.0
+
+        for x1, x2 in self.unsupervised_val_loader:
+            x1, x2 = x1.to(self.device), x2.to(self.device)
+            teacher_input, model_input = x1, x2
+            with forward_context():
+                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
+                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
+            loss_val += loss.item()
+            metric_val += metric.item()
+
+        metric_val /= len(self.unsupervised_val_loader)
+        loss_val /= len(self.unsupervised_val_loader)
+
+        if self.logger is not None:
+            with forward_context():
+                pred = self.model(model_input)
+            self.logger.log_validation_unsupervised(
+                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
+            )
+
+        return metric_val
+
+    def _validate_impl(self, forward_context):
+        self.model.eval()
+
+        with torch.no_grad():
+
+            if self.supervised_val_loader is None:
+                supervised_metric = None
+            else:
+                supervised_metric = self._validate_supervised(forward_context)
+
+            if self.unsupervised_val_loader is None:
+                unsupervised_metric = None
+            else:
+                unsupervised_metric = self._validate_unsupervised(forward_context)
+
+            if unsupervised_metric is None:
+                metric = supervised_metric
+            elif supervised_metric is None:
+                metric = unsupervised_metric
+            else:
+                metric = (supervised_metric + unsupervised_metric) / 2
+
+        return metric

From bc0a3501e61ee857b8304a46685d99db7de378bd Mon Sep 17 00:00:00 2001
From: anwai98
Date: Fri, 31 Mar 2023 14:18:04 +0200
Subject: [PATCH 3/7] Update Strong Augmentations - for FixMatch (both Separate and Joint Setups)

---
 .../livecell/common.py        | 44 +++++++++++++------
 .../livecell/unet_adamatch.py |  4 +-
 .../livecell/unet_fixmatch.py |  4 +-
 3 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/experiments/probabilistic_domain_adaptation/livecell/common.py b/experiments/probabilistic_domain_adaptation/livecell/common.py
index aa0c07a5..d60d08ca 100644
--- a/experiments/probabilistic_domain_adaptation/livecell/common.py
+++ b/experiments/probabilistic_domain_adaptation/livecell/common.py
@@ -30,7 +30,7 @@
 # blurring and additive gaussian noise
 #
 # - strong augmentations:
-# blurring, additive gaussian noise and random contrast adjustment (FixMatch expects stronger parameters for each)
+# blurring, additive gaussian noise and random contrast adjustment
 #
@@ -46,18 +46,32 @@ def weak_augmentations(p=0.25):
     return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)


-def strong_augmentations(p=0.5):
+def strong_augmentations(p=0.5, mode=None):
+    assert mode is not None
     norm = torch_em.transform.raw.standardize
-    aug1 = transforms.Compose([
-        norm,
-        transforms.RandomApply([torch_em.transform.raw.GaussianBlur(sigma=(0.6, 3.0))], p=p),
-        transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
-            scale=(0.05, 0.25), clip_kwargs=False)], p=p/2
-        ),
-        transforms.RandomApply([torch_em.transform.raw.RandomContrast(
-            mean=0.0, alpha=(0.33, 3.0), clip_kwargs=False)], p=p
-        )
-    ])
+
+    if mode == "separate":
+        aug1 = transforms.Compose([
+            norm,
+            transforms.RandomApply([torch_em.transform.raw.GaussianBlur(sigma=(1.0, 4.0))], p=p),
+            transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
+                scale=(0.1, 0.35), clip_kwargs=False)], p=p),
+            transforms.RandomApply([torch_em.transform.raw.RandomContrast(
+                mean=0.0, alpha=(0.33, 3), clip_kwargs=False)], p=p),
+        ])
+
+    elif mode == "joint":
+        aug1 = transforms.Compose([
+            norm,
transforms.RandomApply([torch_em.transform.raw.GaussianBlur(sigma=(0.6, 3.0))], p=p), + transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise( + scale=(0.05, 0.25), clip_kwargs=False)], p=p/2 + ), + transforms.RandomApply([torch_em.transform.raw.RandomContrast( + mean=0.0, alpha=(0.33, 3.0), clip_kwargs=False)], p=p + ) + ]) + return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug1) @@ -229,8 +243,10 @@ def get_unsupervised_loader(args, batch_size, split, cell_type, teacher_augmenta def _parse_aug(aug): if aug == "weak": return weak_augmentations() - elif aug == "strong": - return strong_augmentations() + elif aug == "strong-separate": + return strong_augmentations(mode="separate") + elif aug == "strong-joint": + return strong_augmentations(mode="joint") assert callable(aug) return aug diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py b/experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py index 08989be9..ef723c1e 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py @@ -37,11 +37,11 @@ def _train_source_target(args, source_cell_type, target_cell_type): supervised_val_loader = common.get_supervised_loader(args, "val", source_cell_type, 1) unsupervised_train_loader = common.get_unsupervised_loader( args, args.batch_size, "train", target_cell_type, - teacher_augmentation="weak", student_augmentation="strong", + teacher_augmentation="weak", student_augmentation="strong-joint", ) unsupervised_val_loader = common.get_unsupervised_loader( args, 1, "val", target_cell_type, - teacher_augmentation="weak", student_augmentation="strong", + teacher_augmentation="weak", student_augmentation="strong-joint", ) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py b/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py index d5160156..6ebad36b 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py @@ -45,11 +45,11 @@ def _train_source_target(args, source_cell_type, target_cell_type): # data loaders unsupervised_train_loader = common.get_unsupervised_loader( args, args.batch_size, "train", target_cell_type, - teacher_augmentation="weak", student_augmentation="strong", + teacher_augmentation="weak", student_augmentation="strong-separate", ) unsupervised_val_loader = common.get_unsupervised_loader( args, 1, "val", target_cell_type, - teacher_augmentation="weak", student_augmentation="strong", + teacher_augmentation="weak", student_augmentation="strong-separate", ) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") From fe9dcb5ab82cea8b21c43d44dba4297941001275 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Fri, 31 Mar 2023 15:46:19 +0200 Subject: [PATCH 4/7] Update Gradient Accumulation from Pseudo Labels (+ Add Parameter Descriptions) --- torch_em/self_training/fix_match.py | 33 +++++++++++++++++++------- torch_em/self_training/mean_teacher.py | 13 +++++++--- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/torch_em/self_training/fix_match.py b/torch_em/self_training/fix_match.py index 3c90b7ab..bdb5d495 100644 --- a/torch_em/self_training/fix_match.py +++ b/torch_em/self_training/fix_match.py @@ -49,12 +49,14 @@ class 
FixMatchTrainer(torch_em.trainer.DefaultTrainer):
         model [nn.Module] -
         unsupervised_train_loader [torch.DataLoader] -
         unsupervised_loss [callable] -
+        pseudo_labeler [callable] -
         supervised_train_loader [torch.DataLoader] - (default: None)
         supervised_loss [callable] - (default: None)
         unsupervised_loss_and_metric [callable] - (default: None)
         supervised_loss_and_metric [callable] - (default: None)
         logger [TorchEmLogger] - (default: SelfTrainingTensorboardLogger)
         momentum [float] - (default: 0.999)
+        source_distribution [list] - (default: None)
         **kwargs - keyword arguments for torch_em.DataLoader
     """
@@ -127,6 +129,9 @@ def __init__(
         self._kwargs = kwargs

+    # distribution alignment - encourages the distribution of the model's generated pseudo labels to match the
+    # marginal class distribution predicted by the source model
+    # (key idea: to maximize the mutual information)
     def get_distribution_alignment(self, pseudo_labels, label_threshold=0.5):
         if self.source_distribution is not None:
             pseudo_labels_binary = torch.where(pseudo_labels >= label_threshold, 1, 0)
@@ -157,14 +162,20 @@ def _train_epoch_unsupervised(self, progress, forward_context, backprop):
             teacher_input, model_input = xu1, xu2

+            with torch.no_grad():
+                # Compute the pseudo labels.
+                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
+
+            pseudo_labels, label_filter = pseudo_labels.detach(), label_filter.detach()
+
+            # Perform distribution alignment for pseudo labels
+            pseudo_labels = self.get_distribution_alignment(pseudo_labels)
+
             self.optimizer.zero_grad()
             # Perform unsupervised training
             with forward_context():
-                # Compute the pseudo labels.
-                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
-                # Perform distribution alignment for pseudo labels
-                pseudo_labels = self.get_distribution_alignment(pseudo_labels)
                 loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
+
             backprop(loss)

             if self.logger is not None:
@@ -204,12 +215,18 @@ def _train_epoch_semisupervised(self, progress, forward_context, backprop):
                 supervised_loss = self.supervised_loss(self.model, xs, ys)

             teacher_input, model_input = xu1, xu2
-            # Perform unsupervised training
-            with forward_context():
+
+            with torch.no_grad():
                 # Compute the pseudo labels.
pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input) - # Perform distribution alignment for pseudo labels - pseudo_labels = self.get_distribution_alignment(pseudo_labels) + + pseudo_labels, label_filter = pseudo_labels.detach(), label_filter.detach() + + # Perform distribution alignment for pseudo labels + pseudo_labels = self.get_distribution_alignment(pseudo_labels) + + # Perform unsupervised training + with forward_context(): unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter) loss = (supervised_loss + unsupervised_loss) / 2 diff --git a/torch_em/self_training/mean_teacher.py b/torch_em/self_training/mean_teacher.py index 2aa760d3..e675ce81 100644 --- a/torch_em/self_training/mean_teacher.py +++ b/torch_em/self_training/mean_teacher.py @@ -55,6 +55,7 @@ class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer): model [nn.Module] - unsupervised_train_loader [torch.DataLoader] - unsupervised_loss [callable] - + pseudo_labeler [callable] - supervised_train_loader [torch.DataLoader] - (default: None) supervised_loss [callable] - (default: None) unsupervised_loss_and_metric [callable] - (default: None) @@ -196,11 +197,13 @@ def _train_epoch_unsupervised(self, progress, forward_context, backprop): teacher_input, model_input = xu1, xu2 + with torch.no_grad(): + # Compute the pseudo labels. + pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input) + self.optimizer.zero_grad() # Perform unsupervised training with forward_context(): - # Compute the pseudo labels. - pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input) loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter) backprop(loss) @@ -244,9 +247,13 @@ def _train_epoch_semisupervised(self, progress, forward_context, backprop): supervised_loss = self.supervised_loss(self.model, xs, ys) teacher_input, model_input = xu1, xu2 + + with torch.no_grad(): + # Compute the pseudo labels. + pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input) + # Perform unsupervised training with forward_context(): - pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input) unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter) loss = (supervised_loss + unsupervised_loss) / 2 From 1ceabf64992dc5967878465f9a4c3d3d1ceecb8b Mon Sep 17 00:00:00 2001 From: anwai98 Date: Fri, 31 Mar 2023 18:18:22 +0200 Subject: [PATCH 5/7] Update forward_context for Pseudo Labels --- torch_em/self_training/fix_match.py | 4 ++-- torch_em/self_training/mean_teacher.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_em/self_training/fix_match.py b/torch_em/self_training/fix_match.py index bdb5d495..c3fd6ac5 100644 --- a/torch_em/self_training/fix_match.py +++ b/torch_em/self_training/fix_match.py @@ -162,7 +162,7 @@ def _train_epoch_unsupervised(self, progress, forward_context, backprop): teacher_input, model_input = xu1, xu2 - with torch.no_grad(): + with forward_context(), torch.no_grad(): # Compute the pseudo labels. pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input) @@ -216,7 +216,7 @@ def _train_epoch_semisupervised(self, progress, forward_context, backprop): teacher_input, model_input = xu1, xu2 - with torch.no_grad(): + with forward_context(), torch.no_grad(): # Compute the pseudo labels. 
                pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input)
diff --git a/torch_em/self_training/mean_teacher.py b/torch_em/self_training/mean_teacher.py
index e675ce81..2b37c875 100644
--- a/torch_em/self_training/mean_teacher.py
+++ b/torch_em/self_training/mean_teacher.py
@@ -197,7 +197,7 @@ def _train_epoch_unsupervised(self, progress, forward_context, backprop):
             teacher_input, model_input = xu1, xu2

-            with torch.no_grad():
+            with forward_context(), torch.no_grad():
                 # Compute the pseudo labels.
                 pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
@@ -248,7 +248,7 @@ def _train_epoch_semisupervised(self, progress, forward_context, backprop):
             teacher_input, model_input = xu1, xu2

-            with torch.no_grad():
+            with forward_context(), torch.no_grad():
                 # Compute the pseudo labels.
                 pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)

From 1be1963a9b66b4fecc50717cd6057babd8ef3073 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Sat, 1 Apr 2023 21:35:05 +0200
Subject: [PATCH 6/7] Update Mean-Teacher Training - Loading the Checkpoints

---
 .../livecell/unet_mean_teacher.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py
index d4202659..5a5f08a4 100644
--- a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py
+++ b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py
@@ -3,6 +3,7 @@
 import pandas as pd
 import torch
 import torch_em.self_training as self_training
+from torch_em.util import load_model

 import common

@@ -22,14 +23,17 @@ def check_loader(args, n_images=5):


 def _train_source_target(args, source_cell_type, target_cell_type):
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
     model = common.get_unet()
     if args.save_root is None:
         src_checkpoint = f"./checkpoints/unet_source/{source_cell_type}"
     else:
         src_checkpoint = args.save_root + f"checkpoints/unet_source/{source_cell_type}"
-    model = common.load_model(model, src_checkpoint)
+    model = load_model(checkpoint=src_checkpoint, model=model, device=device)

-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

     # self training functionality
@@ -48,8 +52,6 @@ def _train_source_target(args, source_cell_type, target_cell_type):
         teacher_augmentation="weak", student_augmentation="weak",
     )

-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
     name = f"unet_mean_teacher/thresh-{thresh}/{source_cell_type}/{target_cell_type}"
     trainer = self_training.MeanTeacherTrainer(
         name=name,

From 6be42fa0557816850382c48b6c973d76c5e6f336 Mon Sep 17 00:00:00 2001
From: Anwai Archit <52396323+anwai98@users.noreply.github.com>
Date: Sat, 1 Apr 2023 22:02:30 +0200
Subject: [PATCH 7/7] Update README.md - Script Documentation

---
 .../probabilistic_domain_adaptation/README.md | 76 +++++++++++++++++++
 1 file changed, 76 insertions(+)

diff --git a/experiments/probabilistic_domain_adaptation/README.md b/experiments/probabilistic_domain_adaptation/README.md
index e69de29b..d7d76f87 100644
--- a/experiments/probabilistic_domain_adaptation/README.md
+++ b/experiments/probabilistic_domain_adaptation/README.md
@@ -0,0 +1,76 @@
+# Probabilistic Domain Adaptation
+
+Implementation of [Probabilistic Domain Adaptation for Biomedical Image Segmentation](https://arxiv.org/abs/2303.11790) in `torch_em`.
+Please cite the paper if you are using these approaches in your research.
+
+## Self-Training Approaches
+
+The subfolders contain the training scripts for both separate and joint training setups:
+
+- `unet_source.py` (UNet Source Training):
+```
+python unet_source.py -p [check / train / evaluate]
+                      -c <CELL_TYPE(S)>
+                      -i <PATH_TO_LIVECELL_DATA>
+                      -s <PATH_TO_SAVE_CHECKPOINTS>
+                      -o <PATH_TO_SAVE_PREDICTIONS>
+```
+
+- `unet_mean_teacher.py` (UNet Mean-Teacher Separate Training):
+```
+python unet_mean_teacher.py -p [check / train / evaluate]
+                            -c <CELL_TYPE(S)>
+                            -i <PATH_TO_LIVECELL_DATA>
+                            -s <PATH_TO_SAVE_CHECKPOINTS>
+                            -o <PATH_TO_SAVE_PREDICTIONS>
+                            [(optional) --confidence_threshold <THRESHOLD>]
+```
+
+- `unet_adamt.py` (UNet Mean-Teacher Joint Training):
+```
+python unet_adamt.py -p [check / train / evaluate]
+                     -c <CELL_TYPE(S)>
+                     -i <PATH_TO_LIVECELL_DATA>
+                     -s <PATH_TO_SAVE_CHECKPOINTS>
+                     -o <PATH_TO_SAVE_PREDICTIONS>
+                     [(optional) --confidence_threshold <THRESHOLD>]
+```
+
+- `unet_fixmatch.py` (UNet FixMatch Separate Training):
+```
+python unet_fixmatch.py -p [check / train / evaluate]
+                        -c <CELL_TYPE(S)>
+                        -i <PATH_TO_LIVECELL_DATA>
+                        -s <PATH_TO_SAVE_CHECKPOINTS>
+                        -o <PATH_TO_SAVE_PREDICTIONS>
+                        [(optional) --confidence_threshold <THRESHOLD>]
+                        [(optional) --distribution_alignment]
+```
+
+- `unet_adamatch.py` (UNet FixMatch Joint Training):
+```
+python unet_adamatch.py -p [check / train / evaluate]
+                        -c <CELL_TYPE(S)>
+                        -i <PATH_TO_LIVECELL_DATA>
+                        -s <PATH_TO_SAVE_CHECKPOINTS>
+                        -o <PATH_TO_SAVE_PREDICTIONS>
+                        [(optional) --confidence_threshold <THRESHOLD>]
+                        [(optional) --distribution_alignment]
+```
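+
+For example, a hypothetical FixMatch run (separate training) could look like the one below.
+The paths, cell type and threshold are placeholders for your own setup, not prescribed values,
+and distribution alignment assumes that `unet_source.py` predictions are already stored under the output folder:
+
+```
+python unet_fixmatch.py -p train \
+                        -c A172 \
+                        -i /path/to/livecell \
+                        -s /path/to/save_root/ \
+                        -o /path/to/predictions/ \
+                        --confidence_threshold 0.9 \
+                        --distribution_alignment
+```
+
+Note that the scripts concatenate the save root (`-s`) and the output folder (`-o`) directly with sub-paths
+(e.g. `args.save_root + "checkpoints/..."`), so both should end in a trailing slash.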