diff --git a/experiments/probabilistic_domain_adaptation/livecell/common.py b/experiments/probabilistic_domain_adaptation/livecell/common.py
index 099875f9..1ebe6acd 100644
--- a/experiments/probabilistic_domain_adaptation/livecell/common.py
+++ b/experiments/probabilistic_domain_adaptation/livecell/common.py
@@ -112,6 +112,18 @@ def compute_class_distribution(root_folder, label_threshold=0.5):
     return [bg_frequency, fg_frequency]


+def get_punet_predictions(model, inputs):
+    activation = torch.nn.Sigmoid()
+    prior_samples = 16
+
+    with torch.no_grad():
+        model.forward(inputs)
+        samples_per_input = [activation(model.sample(testing=True)) for _ in range(prior_samples)]
+        avg_pred = torch.stack(samples_per_input, dim=0).sum(dim=0) / prior_samples
+
+    return avg_pred
+
+
 # use get_model and prediction_function to customize this, e.g. for using it with the PUNet
 # set model_state to "teacher_state" when using this with a mean-teacher method
 def evaluate_transfered_model(
@@ -121,9 +133,12 @@ def evaluate_transfered_model(
     label_root = os.path.join(args.input, "annotations", "livecell_test_images")
     results = {"src": [ct_src]}

-    device = torch.device("cuda")
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     thresh = args.confidence_threshold

+    if thresh is None:
+        assert args.consensus_masking is False, "Provide a confidence threshold to use consensus masking"
+
     with torch.no_grad():
         for ct_trg in CELL_TYPES:
@@ -134,10 +149,15 @@ def evaluate_transfered_model(
             if args.output is None:
                 out_folder = None
             else:
+                out_folder = os.path.join(args.output, f"thresh-{thresh}")
+
+                if args.consensus_masking:
+                    out_folder = out_folder + "-masking"
+
                 if args.distribution_alignment:
-                    out_folder = os.path.join(args.output, f"thresh-{thresh}-distro-align/", ct_src, ct_trg)
-                else:
-                    out_folder = os.path.join(args.output, f"thresh-{thresh}", ct_src, ct_trg)
+                    out_folder = out_folder + "-distro-align"
+
+                out_folder = os.path.join(out_folder, ct_src, ct_trg)

             if out_folder is not None:
                 os.makedirs(out_folder, exist_ok=True)
@@ -145,10 +165,16 @@ def evaluate_transfered_model(
             if args.save_root is None:
                 ckpt = f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}"
             else:
+                ckpt = args.save_root + f"checkpoints/{method}/thresh-{thresh}"
+
+                if args.consensus_masking:
+                    ckpt = ckpt + "-masking"
+
                 if args.distribution_alignment:
-                    ckpt = args.save_root + f"checkpoints/{method}/thresh-{thresh}-distro-align/{ct_src}/{ct_trg}"
-                else:
-                    ckpt = args.save_root + f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}"
+                    ckpt = ckpt + "-distro-align"
+
+                ckpt = os.path.join(ckpt, ct_src, ct_trg)
+
             model = get_model()
             model = load_model(checkpoint=ckpt, model=model, state_key=model_state, device=device)
@@ -184,21 +210,9 @@ def evaluate_transfered_model(
     return pd.DataFrame(results)


-def get_punet_predictions(model, inputs):
-    activation = torch.nn.Sigmoid()
-    prior_samples = 16
-
-    with torch.no_grad():
-        model.forward(inputs)
-        samples_per_input = [activation(model.sample(testing=True))for _ in range(prior_samples)]
-        avg_pred = torch.stack(samples_per_input, dim=0).sum(dim=0) / prior_samples
-
-    return avg_pred
-
-
 # use get_model and prediction_function to customize this, e.g. for using it with the PUNet
 def evaluate_source_model(args, ct_src, method, get_model=get_unet, prediction_function=None):
-    device = torch.device("cuda")
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

     if args.save_root is None:
         ckpt = f"checkpoints/{method}/{ct_src}"
@@ -310,4 +324,5 @@ def get_parser(default_batch_size=8, default_iterations=int(1e5)):
     parser.add_argument("-c", "--cell_types", nargs="+", default=CELL_TYPES)
     parser.add_argument("--target_ct", nargs="+", default=None)
    parser.add_argument("-o", "--output")
+    parser.add_argument("--distribution_alignment", action='store_true')
     return parser
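For reference, the relocated `get_punet_predictions` helper implements test-time prediction by averaging sigmoid-activated samples drawn from the prior. Below is a minimal, self-contained sketch of that logic; `DummyPUNet` is a hypothetical stand-in that only mimics the `forward`/`sample(testing=True)` interface the helper assumes, not the real Probabilistic UNet.

```python
import torch


class DummyPUNet(torch.nn.Module):
    # hypothetical stand-in: caches the input on forward and returns a noisy
    # logit map per sample (the real model decodes a latent drawn from its prior)
    def forward(self, inputs):
        self._inputs = inputs

    def sample(self, testing=True):
        return self._inputs + 0.1 * torch.randn_like(self._inputs)


def average_prior_samples(model, inputs, prior_samples=16):
    activation = torch.nn.Sigmoid()
    with torch.no_grad():
        model.forward(inputs)
        samples = [activation(model.sample(testing=True)) for _ in range(prior_samples)]
    # averaging the per-sample probabilities gives the final prediction
    return torch.stack(samples, dim=0).mean(dim=0)


pred = average_prior_samples(DummyPUNet(), torch.randn(1, 1, 64, 64))
print(pred.shape)  # torch.Size([1, 1, 64, 64]); values lie in (0, 1)
```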
diff --git a/experiments/probabilistic_domain_adaptation/livecell/punet_adamt.py b/experiments/probabilistic_domain_adaptation/livecell/punet_adamt.py
new file mode 100644
index 00000000..5c9fbf8c
--- /dev/null
+++ b/experiments/probabilistic_domain_adaptation/livecell/punet_adamt.py
@@ -0,0 +1,134 @@
+import os
+
+import pandas as pd
+import torch
+import torch_em.self_training as self_training
+
+import common
+
+
+def check_loader(args, n_images=5):
+    from torch_em.util.debug import check_loader
+
+    cell_types = args.cell_types
+    print("The cell types", cell_types, "were selected.")
+    print("Checking the unsupervised loader for the first cell type", cell_types[0])
+
+    loader = common.get_unsupervised_loader(
+        args, args.batch_size, "train", cell_types[0],
+        teacher_augmentation="weak", student_augmentation="weak",
+    )
+    check_loader(loader, n_images)
+
+
+def _train_source_target(args, source_cell_type, target_cell_type):
+    model = common.get_punet()
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)
+
+    # self-training functionality:
+    # - when thresh is None, we don't mask the reconstruction loss (RL) with label filters
+    # - when thresh is passed (float), we mask the RL with weighted consensus label filters
+    # - when thresh is passed (float) and args.consensus_masking is set, we mask the RL
+    #   with masked consensus label filters
+    thresh = args.confidence_threshold
+    if thresh is None:
+        assert args.consensus_masking is False, "Provide a confidence threshold to use consensus masking"
+
+    pseudo_labeler = self_training.ProbabilisticPseudoLabeler(
+        activation=torch.nn.Sigmoid(), confidence_threshold=thresh,
+        prior_samples=16, consensus_masking=args.consensus_masking,
+    )
+    loss = self_training.ProbabilisticUNetLoss()
+    loss_and_metric = self_training.ProbabilisticUNetLossAndMetric()
+
+    # data loaders
+    supervised_train_loader = common.get_supervised_loader(args, "train", source_cell_type, args.batch_size)
+    supervised_val_loader = common.get_supervised_loader(args, "val", source_cell_type, 1)
+    unsupervised_train_loader = common.get_unsupervised_loader(
+        args, args.batch_size, "train", target_cell_type,
+        teacher_augmentation="weak", student_augmentation="weak",
+    )
+    unsupervised_val_loader = common.get_unsupervised_loader(
+        args, 1, "val", target_cell_type,
+        teacher_augmentation="weak", student_augmentation="weak",
+    )
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+    # include the cell types in the name so that checkpoints of different
+    # source-target pairs don't overwrite each other and so that the paths match
+    # what evaluate_transfered_model expects
+    if args.consensus_masking:
+        name = f"punet_adamt/thresh-{thresh}-masking/{source_cell_type}/{target_cell_type}"
+    else:
+        name = f"punet_adamt/thresh-{thresh}/{source_cell_type}/{target_cell_type}"
+
+    trainer = self_training.MeanTeacherTrainer(
+        name=name,
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=scheduler,
+        pseudo_labeler=pseudo_labeler,
+        unsupervised_loss=loss,
+        unsupervised_loss_and_metric=loss_and_metric,
+        supervised_train_loader=supervised_train_loader,
+        unsupervised_train_loader=unsupervised_train_loader,
+        supervised_val_loader=supervised_val_loader,
+        unsupervised_val_loader=unsupervised_val_loader,
+        supervised_loss=loss,
+        supervised_loss_and_metric=loss_and_metric,
+        logger=None,
+        mixed_precision=True,
+        device=device,
+        log_image_interval=100,
+        save_root=args.save_root,
+        compile_model=False
+    )
+    trainer.fit(args.n_iterations)
+
+
+def _train_source(args, cell_type):
+    if args.target_ct is None:
+        target_cell_list = common.CELL_TYPES
+    else:
+        target_cell_list = args.target_ct
+
+    for target_cell_type in target_cell_list:
+        if target_cell_type == cell_type:
+            continue
+        print("Training on target cell type:", target_cell_type)
+        _train_source_target(args, cell_type, target_cell_type)
+
+
+def run_training(args):
+    for cell_type in args.cell_types:
+        print("Start training for source cell type:", cell_type)
+        _train_source(args, cell_type)
+
+
+def run_evaluation(args):
+    results = []
+    for ct in args.cell_types:
+        res = common.evaluate_transfered_model(
+            args, ct, "punet_adamt", get_model=common.get_punet,
+            model_state="teacher_state", prediction_function=common.get_punet_predictions,
+        )
+        results.append(res)
+    results = pd.concat(results)
+    print("Evaluation results:")
+    print(results)
+    result_folder = "./results"
+    os.makedirs(result_folder, exist_ok=True)
+    results.to_csv(os.path.join(result_folder, "punet_adamt.csv"), index=False)
+
+
+def main():
+    parser = common.get_parser(default_iterations=100000, default_batch_size=4)
+    parser.add_argument("--confidence_threshold", default=None, type=float)
+    parser.add_argument("--consensus_masking", action='store_true')
+    args = parser.parse_args()
+    if args.phase in ("c", "check"):
+        check_loader(args)
+    elif args.phase in ("t", "train"):
+        run_training(args)
+    elif args.phase in ("e", "evaluate"):
+        run_evaluation(args)
+    else:
+        raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")
+
+
+if __name__ == "__main__":
+    main()
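The trainer name above has to line up with the checkpoint path that `evaluate_transfered_model` rebuilds from the same flags at evaluation time. A small illustration of that composition; `compose_ckpt` and the concrete flag values are hypothetical, not part of the code base:

```python
import os


def compose_ckpt(method, thresh, consensus_masking, distribution_alignment, ct_src, ct_trg):
    # mirrors the suffix logic in evaluate_transfered_model: base name, then
    # optional "-masking" and "-distro-align" suffixes, then source/target cell types
    ckpt = f"checkpoints/{method}/thresh-{thresh}"
    if consensus_masking:
        ckpt += "-masking"
    if distribution_alignment:
        ckpt += "-distro-align"
    return os.path.join(ckpt, ct_src, ct_trg)


print(compose_ckpt("punet_adamt", 0.9, True, False, "A172", "BT474"))
# checkpoints/punet_adamt/thresh-0.9-masking/A172/BT474
```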
diff --git a/experiments/probabilistic_domain_adaptation/livecell/punet_mean_teacher.py b/experiments/probabilistic_domain_adaptation/livecell/punet_mean_teacher.py
new file mode 100644
index 00000000..9866b3ab
--- /dev/null
+++ b/experiments/probabilistic_domain_adaptation/livecell/punet_mean_teacher.py
@@ -0,0 +1,141 @@
+import os
+
+import pandas as pd
+import torch
+import torch_em.self_training as self_training
+from torch_em.util import load_model
+
+import common
+
+
+def check_loader(args, n_images=5):
+    from torch_em.util.debug import check_loader
+
+    cell_types = args.cell_types
+    print("The cell types", cell_types, "were selected.")
+    print("Checking the unsupervised loader for the first cell type", cell_types[0])
+
+    loader = common.get_unsupervised_loader(
+        args, args.batch_size, "train", cell_types[0],
+        teacher_augmentation="weak", student_augmentation="weak",
+    )
+    check_loader(loader, n_images)
+
+
+def _train_source_target(args, source_cell_type, target_cell_type):
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+    model = common.get_punet()
+    if args.save_root is None:
+        src_checkpoint = f"./checkpoints/punet_source/{source_cell_type}"
+    else:
+        src_checkpoint = args.save_root + f"checkpoints/punet_source/{source_cell_type}"
+    model = load_model(checkpoint=src_checkpoint, model=model, device=device)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)
+
+    # self-training functionality:
+    # - when thresh is None, we don't mask the reconstruction loss (RL) with label filters
+    # - when thresh is passed (float), we mask the RL with weighted consensus label filters
+    # - when thresh is passed (float) and args.consensus_masking is set, we mask the RL
+    #   with masked consensus label filters
+    thresh = args.confidence_threshold
+    if thresh is None:
+        assert args.consensus_masking is False, "Provide a confidence threshold to use consensus masking"
+
+    pseudo_labeler = self_training.ProbabilisticPseudoLabeler(
+        activation=torch.nn.Sigmoid(), confidence_threshold=thresh,
+        prior_samples=16, consensus_masking=args.consensus_masking,
+    )
+    loss = self_training.ProbabilisticUNetLoss()
+    loss_and_metric = self_training.ProbabilisticUNetLossAndMetric()
+
+    # data loaders
+    unsupervised_train_loader = common.get_unsupervised_loader(
+        args, args.batch_size, "train", target_cell_type,
+        teacher_augmentation="weak", student_augmentation="weak",
+    )
+    unsupervised_val_loader = common.get_unsupervised_loader(
+        args, 1, "val", target_cell_type,
+        teacher_augmentation="weak", student_augmentation="weak",
+    )
+
+    if args.consensus_masking:
+        name = f"punet_mean_teacher/thresh-{thresh}-masking/{source_cell_type}/{target_cell_type}"
+    else:
+        name = f"punet_mean_teacher/thresh-{thresh}/{source_cell_type}/{target_cell_type}"
+
+    trainer = self_training.MeanTeacherTrainer(
+        name=name,
+        model=model,
+        optimizer=optimizer,
+        lr_scheduler=scheduler,
+        pseudo_labeler=pseudo_labeler,
+        unsupervised_loss=loss,
+        unsupervised_loss_and_metric=loss_and_metric,
+        unsupervised_train_loader=unsupervised_train_loader,
+        unsupervised_val_loader=unsupervised_val_loader,
+        supervised_loss=loss,
+        supervised_loss_and_metric=loss_and_metric,
+        logger=None,
+        mixed_precision=True,
+        device=device,
+        log_image_interval=100,
+        save_root=args.save_root,
+        compile_model=False
+    )
+    trainer.fit(args.n_iterations)
+
+
+def _train_source(args, cell_type):
+    if args.target_ct is None:
+        target_cell_list = common.CELL_TYPES
+    else:
+        target_cell_list = args.target_ct
+
+    for target_cell_type in target_cell_list:
+        if target_cell_type == cell_type:
+            continue
+        print("Training on target cell type:", target_cell_type)
+        _train_source_target(args, cell_type, target_cell_type)
+
+
+def run_training(args):
+    for cell_type in args.cell_types:
+        print("Start training for source cell type:", cell_type)
+        _train_source(args, cell_type)
+
+
+def run_evaluation(args):
+    results = []
+    for ct in args.cell_types:
+        res = common.evaluate_transfered_model(args, ct, "punet_mean_teacher",
+                                               get_model=common.get_punet,
+                                               model_state="teacher_state",
+                                               prediction_function=common.get_punet_predictions)
+        results.append(res)
+    results = pd.concat(results)
+    print("Evaluation results:")
+    print(results)
+    result_folder = "./results"
+    os.makedirs(result_folder, exist_ok=True)
+    results.to_csv(os.path.join(result_folder, "punet_mean_teacher.csv"), index=False)
+
+
+def main():
+    parser = common.get_parser(default_iterations=10000, default_batch_size=4)
+    parser.add_argument("--confidence_threshold", default=None, type=float)
+    parser.add_argument("--consensus_masking", action='store_true')
+    args = parser.parse_args()
+    if args.phase in ("c", "check"):
+        check_loader(args)
+    elif args.phase in ("t", "train"):
+        run_training(args)
+    elif args.phase in ("e", "evaluate"):
+        run_evaluation(args)
+    else:
+        raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")
+
+
+if __name__ == "__main__":
+    main()
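Both PUNet scripts distinguish a weighted consensus response from a masked consensus response when filtering pseudo labels. A toy example with plain tensors, assuming a both-sided threshold of 0.9 and three hypothetical prior samples over three pixels, shows the difference:

```python
import torch

threshold = 0.9
# three "prior samples", each holding predictions for the same three pixels
samples = [torch.tensor([0.95, 0.95, 0.50]),
           torch.tensor([0.96, 0.40, 0.05]),
           torch.tensor([0.99, 0.97, 0.45])]

# per-sample confidence masks, thresholding from both sides
masks = [((s >= threshold) | (s <= 1 - threshold)).float() for s in samples]

weighted = torch.stack(masks).mean(dim=0)    # fraction of confident samples per pixel
masked = torch.where(weighted == 1, 1., 0.)  # keep only pixels on which all samples agree

print(weighted)  # tensor([1.0000, 0.6667, 0.3333])
print(masked)    # tensor([1., 0., 0.])
```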
diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py
index 31350d8c..9b9f9efa 100644
--- a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py
+++ b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py
@@ -118,5 +118,4 @@ def main():


 if __name__ == "__main__":
-    # break
     main()
diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py b/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py
index 36496f7a..8ee5413e 100644
--- a/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py
+++ b/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py
@@ -52,8 +52,6 @@ def _train_source_target(args, source_cell_type, target_cell_type):
         teacher_augmentation="weak", student_augmentation="strong-separate",
     )

-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
     name = f"unet_fixmatch/thresh-{thresh}"
     if args.distribution_alignment:
diff --git a/torch_em/self_training/__init__.py b/torch_em/self_training/__init__.py
index a0b3e4ad..89227a0d 100644
--- a/torch_em/self_training/__init__.py
+++ b/torch_em/self_training/__init__.py
@@ -3,5 +3,5 @@ ProbabilisticUNetLossAndMetric
 from .mean_teacher import MeanTeacherTrainer
 from .fix_match import FixMatchTrainer
-from .pseudo_labeling import DefaultPseudoLabeler
+from .pseudo_labeling import DefaultPseudoLabeler, ProbabilisticPseudoLabeler
 from .probabilistic_unet_trainer import ProbabilisticUNetTrainer, DummyLoss
diff --git a/torch_em/self_training/loss.py b/torch_em/self_training/loss.py
index a574fe80..963e95d8 100644
--- a/torch_em/self_training/loss.py
+++ b/torch_em/self_training/loss.py
@@ -82,11 +82,11 @@ def __init__(self, loss=None):
         super().__init__()
         self.loss = loss

-    def __call__(self, model, input_, labels):
+    def __call__(self, model, input_, labels, label_filter=None):
         model.forward(input_, labels)

         if self.loss is None:
-            elbo = model.elbo(labels)
+            elbo = model.elbo(labels, label_filter)
             reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
                 l2_regularisation(model.fcomb.layers)
             loss = -elbo + 1e-5 * reg_loss
@@ -112,11 +112,11 @@ def __init__(self, loss=None, metric=DiceLoss(), activation=torch.nn.Sigmoid(),
         self.loss = loss
         self.prior_samples = prior_samples

-    def __call__(self, model, input_, labels):
+    def __call__(self, model, input_, labels, label_filter=None):
         model.forward(input_, labels)

         if self.loss is None:
-            elbo = model.elbo(labels)
+            elbo = model.elbo(labels, label_filter)
             reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
                 l2_regularisation(model.fcomb.layers)
             loss = -elbo + 1e-5 * reg_loss
diff --git a/torch_em/self_training/pseudo_labeling.py b/torch_em/self_training/pseudo_labeling.py
index dc3d2a9d..85081da9 100644
--- a/torch_em/self_training/pseudo_labeling.py
+++ b/torch_em/self_training/pseudo_labeling.py
@@ -42,3 +42,66 @@ def __call__(self, teacher, input_):
         label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides\
             else self._compute_label_mask_one_side(pseudo_labels)
         return pseudo_labels, label_mask
+
+
+class ProbabilisticPseudoLabeler:
+    """Compute pseudo labels from the Probabilistic UNet.
+
+    Parameters:
+        activation [nn.Module, callable] - activation function applied to the teacher prediction.
+        confidence_threshold [float] - threshold for computing a mask for filtering the pseudo labels.
+            If None is given, no mask will be computed (default: None).
+        threshold_from_both_sides [bool] - whether to include both values bigger than the threshold
+            and smaller than 1 - it, or only values bigger than it, in the mask.
+            The former should be used for binary labels, the latter for multiclass labels (default: True).
+        prior_samples [int] - the number of samples to draw from the prior distribution per input (default: 16).
+        consensus_masking [bool] - whether to activate consensus masking in the label filter (default: False).
+            If False, the weighted consensus response (weighted per-pixel response) is returned.
+            If True, the masked consensus response (complete agreement of pixels) is returned.
+    """
+    def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True,
+                 prior_samples=16, consensus_masking=False):
+        self.activation = activation
+        self.confidence_threshold = confidence_threshold
+        self.threshold_from_both_sides = threshold_from_both_sides
+        self.prior_samples = prior_samples
+        self.consensus_masking = consensus_masking
+        # TODO serialize the class names and kwargs for activation instead
+        self.init_kwargs = {
+            "activation": None, "confidence_threshold": confidence_threshold,
+            "threshold_from_both_sides": threshold_from_both_sides,
+            "prior_samples": prior_samples, "consensus_masking": consensus_masking,
+        }
+
+    def _compute_label_mask_both_sides(self, samples):
+        upper_threshold = self.confidence_threshold
+        lower_threshold = 1.0 - self.confidence_threshold
+        mask = [torch.where((sample >= upper_threshold) + (sample <= lower_threshold),
+                            torch.tensor(1.),
+                            torch.tensor(0.)) for sample in samples]
+        return mask
+
+    def _compute_label_mask_one_side(self, samples):
+        mask = [torch.where((sample >= self.confidence_threshold),
+                            torch.tensor(1.),
+                            torch.tensor(0.)) for sample in samples]
+        return mask
+
+    def __call__(self, teacher, input_):
+        teacher.forward(input_)
+        if self.activation is not None:
+            samples = [self.activation(teacher.sample()) for _ in range(self.prior_samples)]
+        else:
+            samples = [teacher.sample() for _ in range(self.prior_samples)]
+        # average over the prior samples to get the pseudo labels
+        pseudo_labels = torch.stack(samples, dim=0).sum(dim=0) / self.prior_samples
+
+        if self.confidence_threshold is None:
+            label_mask = None
+        else:
+            # compute one confidence mask per prior sample; their average is the
+            # weighted consensus response
+            label_mask = self._compute_label_mask_both_sides(samples) if self.threshold_from_both_sides \
+                else self._compute_label_mask_one_side(samples)
+            label_mask = torch.stack(label_mask, dim=0).sum(dim=0) / self.prior_samples
+            if self.consensus_masking:
+                # keep only the pixels on which all samples agree
+                label_mask = torch.where(label_mask == 1, 1, 0)
+
+        return pseudo_labels, label_mask
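Putting the pieces together, this is how a trainer would invoke the new pseudo labeler. The import matches the export added to `torch_em/self_training/__init__.py` above; `DummyTeacher` is a hypothetical stand-in that only mimics the `forward`/`sample` interface of the Probabilistic UNet teacher:

```python
import torch
from torch_em.self_training import ProbabilisticPseudoLabeler


class DummyTeacher(torch.nn.Module):
    # hypothetical stand-in: caches the input and returns a noisy logit map per sample
    def forward(self, input_):
        self._input = input_

    def sample(self):
        return self._input + 0.1 * torch.randn_like(self._input)


pseudo_labeler = ProbabilisticPseudoLabeler(
    activation=torch.nn.Sigmoid(), confidence_threshold=0.9,
    prior_samples=16, consensus_masking=True,
)
input_ = torch.randn(4, 1, 256, 256)
pseudo_labels, label_mask = pseudo_labeler(DummyTeacher(), input_)
# pseudo_labels: probabilities averaged over the 16 prior samples
# label_mask: binary mask of the pixels on which all samples are confident
print(pseudo_labels.shape, label_mask.shape)
```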