From 47a4b51c9092e1953636dd3a971cdbc954743bf5 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sat, 11 Mar 2023 09:52:51 +0100
Subject: [PATCH 01/21] Implement MeanTeacher Trainer WIP

---
 torch_em/self_training/fix_match.py    |   0
 torch_em/self_training/logger.py       |  56 +++++
 torch_em/self_training/mean_teacher.py | 315 +++++++++++++++++++++++++
 3 files changed, 371 insertions(+)
 create mode 100644 torch_em/self_training/fix_match.py
 create mode 100644 torch_em/self_training/logger.py
 create mode 100644 torch_em/self_training/mean_teacher.py

diff --git a/torch_em/self_training/fix_match.py b/torch_em/self_training/fix_match.py
new file mode 100644
index 00000000..e69de29b
diff --git a/torch_em/self_training/logger.py b/torch_em/self_training/logger.py
new file mode 100644
index 00000000..bddc8c3c
--- /dev/null
+++ b/torch_em/self_training/logger.py
@@ -0,0 +1,56 @@
+import os
+
+import torch
+import torch_em
+
+
+class SelfTrainingTensorboardLogger(torch_em.traner.logger_base.TorchEmLogger):
+    def __init__(self, trainer, save_root, **unused_kwargs):
+        super().__init__(trainer, save_root)
+        self.my_root = save_root
+        self.log_dir = f"./logs/{trainer.name}" if self.my_root is None else\
+            os.path.join(self.my_root, "logs", trainer.name)
+        os.makedirs(self.log_dir, exist_ok=True)
+
+        self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
+        self.log_image_interval = trainer.log_image_interval
+
+    # TODO make a grid image
+    def _add_supervised_images(self):
+        pass
+
+    # TODO make a grid image
+    def _add_unsupervised_images(self):
+        pass
+
+    def log_combined_loss(self, step, loss):
+        self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)
+
+    def log_lr(self, step, lr):
+        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
+
+    def log_train_supervised(self, step, loss, x, y, pred):
+        self.tb.add_scalar(tag="train/supervised/loss", scalar_value=loss, global_step=step)
+        if step % self.log_image_interval == 0:
+            self._add_supervised_images(step, "train", x, y, pred)
+
+    def log_validation_supervised(self, step, metric, loss, x, y, pred):
+        self.tb.add_scalar(tag="validation/supervised/loss", scalar_value=loss, global_step=step)
+        self.tb.add_scalar(tag="validation/supervised/metric", scalar_value=metric, global_step=step)
+        self._add_supervised_images(step, "validation", x, y, pred)
+
+    def log_train_unsupervised(self, step, loss, x1, x2, pred, pseudo_labels, label_filter=None):
+        self.tb.add_scalar(tag="train/unsupervised/loss", scalar_value=loss, global_step=step)
+        if step % self.log_image_interval == 0:
+            self._add_unsupervised_images(step, "train", x1, x2, pred, pseudo_labels, label_filter)
+
+    def log_validation_unsupervised(self, step, metric, loss, x1, x2, pred, pseudo_labels, label_filter=None):
+        self.tb.add_scalar(tag="validation/unsupervised/loss", scalar_value=loss, global_step=step)
+        self.tb.add_scalar(tag="validation/unsupervised/metric", scalar_value=metric, global_step=step)
+        self._add_unsupervised_images(step, "validation", x1, x2, pred, pseudo_labels, label_filter)
+
+    def log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt_metric=None):
+        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
+        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
+        if gt_metric is not None:
+            self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)
diff --git a/torch_em/self_training/mean_teacher.py b/torch_em/self_training/mean_teacher.py
new file mode 100644
index 00000000..97ad0580
--- /dev/null
+++ b/torch_em/self_training/mean_teacher.py
@@ -0,0 +1,315 @@
+import time
+from copy import deepcopy
+
+import torch
+import torch_em
+
+
+class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
+    """This trainer implements self-training for semi-supervised learning and domain adaptation,
+    following the 'MeanTeacher' approach of Tarvainen & Valpola (https://arxiv.org/abs/1703.01780).
+    This approach uses a teacher model derived from the student model via EMA of weights
+    to predict pseudo-labels on unlabeled data.
+    We support two training strategies: joint training on labeled and unlabeled data
+    (with a supervised and unsupervised loss function). And training only on the unsupervised data.
+
+    This class expects the following data loaders:
+    - unsupervised_train_loader: Returns two augmentations of the same input.
+    - supervised_train_loader (optional): Returns input and labels.
+
+    And the following elements to customize the pseudo labeling:
+    - pseudo_labeler: to compute the pseudo-labels
+        - Parameters: teacher, teacher_input
+        - Returns: pseudo_labels, label_filter (<- the label filter can for example be a mask, a weight or None)
+    - unsupervised_loss: the loss between model predictions and pseudo labels
+        - Parameters: model, model_input, pseudo_labels, label_filter
+        - Returns: loss
+    - supervised_loss (optional): the supervised loss function
+        - Parameters: model, input, labels
+        - Returns: loss
+    - unsupervised_loss_and_metric (optional): the unsupervised loss function and metric
+        - Parameters: model, model_input, pseudo_labels, label_filter
+        - Returns: loss, metric
+    - supervised_loss_and_metric (optional): the supervised loss function and metric
+        - Parameters: model, input, labels
+        - Returns: loss, metric
+    At least one of unsupervised_loss_and_metric and supervised_loss_and_metric must be given.
+
+    If the parameter reinit_teacher is set to True, the teacher weights are re-initialized.
+    If it is None, the most appropriate initialization scheme for the training approach is chosen:
+    - semi-supervised training -> reinit, because we usually train a model from scratch
+    - unsupervised training -> do not reinit, because we usually fine-tune a model
+
+    Note: adjust the batch size ratio between the 'unsupervised_train_loader' and 'supervised_train_loader'
+    for setting the ratio between supervised and unsupervised training samples.
+
+    Parameters:
+        unsupervised_train_loader [torch.DataLoader] -
+        unsupervised_loss [callable] -
+        supervised_train_loader [torch.DataLoader] - (default: None)
+        supervised_loss [callable] - (default: None)
+        unsupervised_loss_and_metric [callable] - (default: None)
+        supervised_loss_and_metric [callable] - (default: None)
+        momentum [float] - (default: 0.999)
+        reinit_teacher [bool] - (default: None)
+        **kwargs - keyword arguments for torch_em.trainer.DefaultTrainer
+    """
+
+    def __init__(
+        self,
+        unsupervised_train_loader,
+        unsupervised_loss,
+        pseudo_labeler,
+        supervised_train_loader=None,
+        supervised_loss=None,
+        supervised_loss_and_metric=None,
+        unsupervised_loss_and_metric=None,
+        momentum=0.999,
+        reinit_teacher=None,
+        **kwargs
+    ):
+        # Do we have supervised data or not?
+        if supervised_train_loader is None:
+            # No. -> We use the unsupervised training logic.
+            train_loader = unsupervised_train_loader
+            self._train_epoch_impl = self._train_epoch_unsupervised
+        else:
+            # Yes. -> We use the semi-supervised training logic.
+            assert supervised_loss is not None
+            train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
+                else unsupervised_train_loader
+            self._train_epoch_impl = self._train_epoch_semisupervised
+
+        # Check that we have at least one of supervised / unsupervised loss and metric
+        assert sum((
+            supervised_loss_and_metric is None,
+            unsupervised_loss_and_metric is None,
+        )) > 0
+        self.supervised_loss_and_metric = supervised_loss_and_metric
+        self.unsupervised_loss_and_metric = unsupervised_loss_and_metric
+
+        super().__init__(train_loader=train_loader, **kwargs)
+        self.unsupervised_train_loader = unsupervised_train_loader
+        self.supervised_train_loader = supervised_train_loader
+
+        self.unsupervised_loss = unsupervised_loss
+        self.supervised_loss = supervised_loss
+
+        self.pseudo_labeler = pseudo_labeler
+
+        self.momentum = momentum
+        self._kwargs = {"momentum": momentum, "reinit_teacher": reinit_teacher, **kwargs}
+
+        # determine how we initialize the teacher weights (copy or reinitialization)
+        if reinit_teacher is None:
+            # semisupervised training: reinitialize
+            # unsupervised training: copy
+            self.reinit_teacher = supervised_train_loader is not None
+        else:
+            self.reinit_teacher = reinit_teacher
+
+        with torch.no_grad():
+            self.teacher = deepcopy(self.model)
+            if self.reinit_teacher:
+                for layer in self.teacher.children():
+                    if hasattr(layer, "reset_parameters"):
+                        layer.reset_parameters()
+            for param in self.teacher.parameters():
+                param.requires_grad = False
+
+    def _momentum_update(self):
+        # if we reinit the teacher we perform much faster updates (low momentum) in the first iterations
+        # to avoid a large gap between teacher and student weights, leading to inconsistent predictions
+        # if we don't reinit this is not necessary
+        if self.reinit_teacher:
+            current_momentum = min(1 - 1 / (self._iteration + 1), self.momentum)
+        else:
+            current_momentum = self.momentum
+
+        for param, param_teacher in zip(self.model.parameters(), self.teacher.parameters()):
+            param_teacher.data = param_teacher.data * current_momentum + param.data * (1. - current_momentum)
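
[Editor's note, not part of the patch: a small standalone sketch of the teacher-momentum warm-up that _momentum_update implements above. With reinit_teacher=True the effective momentum ramps up from zero and is capped at the configured value, so the freshly re-initialized teacher first tracks the student closely and only later averages over many iterations:

    momentum = 0.999
    for iteration in (0, 1, 9, 99, 999, 9999):
        print(iteration, min(1 - 1 / (iteration + 1), momentum))
    # prints 0.0, 0.5, 0.9, 0.99, then 0.999 once the cap is reached

With reinit_teacher=False the momentum stays constant, which matches the fine-tuning use case described in the docstring.]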
+
+    # TODO I think we need to serialize more things here for all the loaders etc.
+    #
+    # functionality for saving checkpoints and initialization
+    #
+
+    def save_checkpoint(self, name, best_metric):
+        teacher_state = {"teacher_state": self.teacher.state_dict()}
+        super().save_checkpoint(name, best_metric, **teacher_state)
+
+    def load_checkpoint(self, checkpoint="best"):
+        save_dict = super().load_checkpoint(checkpoint)
+        self.teacher.load_state_dict(save_dict["teacher_state"])
+        self.teacher.to(self.device)
+        return save_dict
+
+    def _initialize(self, iterations, load_from_checkpoint):
+        best_metric = super()._initialize(iterations, load_from_checkpoint)
+        self.teacher.to(self.device)
+        return best_metric
+
+    #
+    # training and validation functionality
+    #
+
+    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
+        self.model.train()
+
+        n_iter = 0
+        t_per_iter = time.time()
+
+        # Sample from the unsupervised loader.
+        for xu1, xu2 in self.target_train_loader:
+            xu1, xu2 = xu1.to(self.device), xu2.to(self.device)
+
+            teacher_input, model_input = xu1, xu2
+            # Perform unsupervised training.
+            with forward_context():
+                # Compute the pseudo labels.
+                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
+                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
+            backprop(loss)
+
+            if self.logger is not None:
+                with torch.no_grad(), forward_context():
+                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
+                self.logger.log_train_unsupervised(
+                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
+                )
+                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
+                self.logger.log_lr(lr)
+
+            with torch.no_grad():
+                self._momentum_update()
+
+            self._iteration += 1
+            n_iter += 1
+            if self._iteration >= self.max_iteration:
+                break
+            progress.update(1)
+
+        t_per_iter = (time.time() - t_per_iter) / n_iter
+        return t_per_iter
+
+    def _train_epoch_semisupervised(self, progress, forward_context, backprop):
+        self.model.train()
+
+        n_iter = 0
+        t_per_iter = time.time()
+
+        # Sample from both the supervised and unsupervised loader.
+        for (xs, ys), (xu1, xu2) in zip(self.source_train_loader, self.target_train_loader):
+            xs, ys = xs.to(self.device), ys.to(self.device)
+            xu1, xu2 = xu1.to(self.device), xu2.to(self.device)
+
+            # Perform supervised training.
+            self.optimizer.zero_grad()
+            with forward_context():
+                # We pass the model, the input and the labels to the supervised loss function,
+                # so that how the loss is calculated stays flexible, e.g. to enable ELBO for PUNet.
+                supervised_loss = self.supervised_loss(self.model, xs, ys)
+
+            teacher_input, model_input = xu1, xu2
+            # Perform unsupervised training.
+            with forward_context():
+                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
+                unsupervised_loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
+
+            loss = (supervised_loss + unsupervised_loss) / 2
+            backprop(loss)
+
+            if self.logger is not None:
+                with torch.no_grad(), forward_context():
+                    unsup_pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
+                    supervised_pred = self.model(xs) if self._iteration % self.log_image_interval == 0 else None
+
+                self.logger.log_train_supervised(self._iteration, supervised_loss, xs, ys, supervised_pred)
+                self.logger.log_train_unsupervised(
+                    self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
+                )
+
+                self.logger.log_combined_loss(loss)
+                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
+                self.logger.log_lr(lr)
+
+            with torch.no_grad():
+                self._momentum_update()
+
+            self._iteration += 1
+            n_iter += 1
+            if self._iteration >= self.max_iteration:
+                break
+            progress.update(1)
+
+        t_per_iter = (time.time() - t_per_iter) / n_iter
+        return t_per_iter
+
+    def _validate_supervised(self, forward_context):
+        metric_val = 0.0
+        loss_val = 0.0
+
+        for x, y in self.supervised_val_loader:
+            x, y = x.to(self.device), y.to(self.device)
+            with forward_context():
+                loss, metric = self.supervised_loss_and_metric(self.model, x, y)
+            loss_val += loss.item()
+            metric_val += metric.item()
+
+        metric_val /= len(self.supervised_val_loader)
+        loss_val /= len(self.supervised_val_loader)
+
+        if self.logger is not None:
+            with forward_context():
+                pred = self.model(x)
+            self.logger.log_validation_supervised(self._iteration, metric_val, loss_val, x, y, pred)
+
+        return metric_val
+
+    def _validate_unsupervised(self, forward_context):
+        metric_val = 0.0
+        loss_val = 0.0
+
+        for x1, x2 in self.unsupervised_val_loader:
+            x1, x2 = x1.to(self.device), x2.to(self.device)
+            teacher_input, model_input = x1, x2
+            with forward_context():
+                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
+                loss, metric = self.unsupervised_loss_and_metric(self.model, model_input, pseudo_labels, label_filter)
+            loss_val += loss.item()
+            metric_val += metric.item()
+
+        metric_val /= len(self.supervised_val_loader)
+        loss_val /= len(self.supervised_val_loader)
+
+        if self.logger is not None:
+            with forward_context():
+                pred = self.model(model_input)
+            self.logger.log_validation_unsupervised(
+                self._iteration, metric_val, loss_val, x1, x2, pred, pseudo_labels, label_filter
+            )
+
+        return metric_val
+
+    def _validate_impl(self, forward_context):
+        self.model.eval()
+
+        with torch.no_grad():
+
+            if self.supervised_val_loader is None:
+                supervised_metric = None
+            else:
+                supervised_metric = self._validate_supervised(forward_context)
+
+            if self.unsupervised_val_loader is None:
+                unsupervised_metric = None
+            else:
+                unsuperised_metric = self._validate_unsupervised(forward_context)
+
+        if unsuperised_metric is None:
+            metric = supervised_metric
+        elif supervised_metric is None:
+            metric = unsupervised_metric
+        else:
+            metric = (supervised_metric + unsupervised_metric) / 2
+
+        return metric

From 3ad4488870c81a806b70fca1ef0fb4d794baec6a Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sat, 11 Mar 2023 20:09:17 +0100
Subject: [PATCH 02/21] Mean teacher is running (not properly tested, not serializable or deserializable)

---
 test/self_training/__init__.py                |   0
 test/self_training/test_mean_teacher.py       | 114 ++++++++++++++++++
 test/test_segmentation.py                     |   5 +-
 torch_em/data/raw_dataset.py                  |  10 ++
 torch_em/data/raw_image_collection_dataset.py |  10 ++
 torch_em/self_training/__init__.py            |   1 +
 torch_em/self_training/augmentations.py       |   1 +
 torch_em/self_training/mean_teacher.py        |  41 +++++--
 8 files changed, 170 insertions(+), 12 deletions(-)
 create mode 100644 test/self_training/__init__.py
 create mode 100644 test/self_training/test_mean_teacher.py
 create mode 100644 torch_em/self_training/__init__.py
 create mode 100644 torch_em/self_training/augmentations.py

diff --git a/test/self_training/__init__.py b/test/self_training/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/self_training/test_mean_teacher.py b/test/self_training/test_mean_teacher.py
new file mode 100644
index 00000000..613ecd2d
--- /dev/null
+++ b/test/self_training/test_mean_teacher.py
@@ -0,0 +1,114 @@
+import os
+import unittest
+from shutil import rmtree
+
+import torch
+import torch_em
+from torch_em.model import UNet2d
+from torch_em.util.test import create_segmentation_test_data
+
+
+def simple_pseudo_labeler(teacher, input_):
+    pseudo_labels = teacher(input_)
+    return pseudo_labels, None
+
+
+def simple_unsupservised_loss(model, model_input, pseudo_labels, label_filter):
+    assert label_filter is None
+    pred = model(model_input)
+    loss = torch_em.loss.dice_score(pred, pseudo_labels, invert=True)
+    return loss
+
+
+def simple_unsupervised_loss_and_metric(model, model_input, pseudo_labels, label_filter):
+    assert label_filter is None
+    pred = model(model_input)
+    loss = torch_em.loss.dice_score(pred, pseudo_labels, invert=True)
+    return loss, loss
+
+
+class TestMeanTeacher(unittest.TestCase):
+    tmp_folder = "./tmp"
+    data_path = "./tmp/data.h5"
+    raw_key = "raw"
+    label_key = "labels"
+
+    def setUp(self):
+        os.makedirs(self.tmp_folder, exist_ok=True)
+        create_segmentation_test_data(self.data_path, self.raw_key, self.label_key, shape=(128,) * 3, chunks=(32,) * 3)
+
+    def tearDown(self):
+
+        def _remove(folder):
+            try:
+                rmtree(folder)
+            except OSError:
+                pass
+
+        _remove(self.tmp_folder)
+        _remove("./logs")
+        _remove("./checkpoints")
+
+    def _test_mean_teacher(
+        self,
+        unsupervised_train_loader,
+        supervised_train_loader=None,
+        unsupervised_val_loader=None,
+        supervised_val_loader=None,
+        supervised_loss=None,
+        supervised_loss_and_metric=None,
+    ):
+        from torch_em.self_training import MeanTeacherTrainer
+
+        model = UNet2d(in_channels=1, out_channels=1, initial_features=8, depth=3)
+        optimizer = torch.optim.Adam(model.parameters())
+
+        trainer = MeanTeacherTrainer(
+            name="mt-test",
+            model=model,
+            optimizer=optimizer,
+            device=torch.device("cpu"),
+            unsupervised_train_loader=unsupervised_train_loader,
+            supervised_train_loader=supervised_train_loader,
+            unsupervised_val_loader=unsupervised_val_loader,
+            supervised_val_loader=supervised_val_loader,
+            pseudo_labeler=simple_pseudo_labeler,
+            unsupervised_loss=simple_unsupservised_loss,
+            unsupervised_loss_and_metric=simple_unsupervised_loss_and_metric,
+            supervised_loss=supervised_loss,
+            supervised_loss_and_metric=supervised_loss_and_metric,
+            logger=None,
+            mixed_precision=False,
+        )
+        trainer.fit(53)
+        self.assertTrue("./checkpoints/best.pt")
+        self.assertTrue("./checkpoints/latest.pt")
+        # TODO check that deserializing and continuing to train works
+
+    def get_unsupervised_loader(self, n_samples):
+        augmentations = (
+            torch_em.transform.raw.GaussianBlur(),
+            torch_em.transform.raw.GaussianBlur(),
+        )
+        ds = torch_em.data.RawDataset(
+            raw_path=self.data_path,
+            raw_key=self.raw_key,
+            patch_shape=(1, 64, 64),
+            n_samples=n_samples,
+            ndim=2,
+            augmentations=augmentations,
+        )
+        loader = torch_em.segmentation.get_data_loader(ds, batch_size=1, shuffle=True)
+        return loader
+
+    def test_mean_teacher_unsupervised(self):
+        unsupervised_train_loader = self.get_unsupervised_loader(n_samples=50)
+        unsupervised_val_loader = self.get_unsupervised_loader(n_samples=4)
+        self._test_mean_teacher(
+            unsupervised_train_loader,
+            unsupervised_val_loader=unsupervised_val_loader
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/test_segmentation.py b/test/test_segmentation.py
index b480da32..0fa6108f 100644
--- a/test/test_segmentation.py
+++ b/test/test_segmentation.py
@@ -57,8 +57,7 @@ def _test_training(self, model_class, model_kwargs,
             mixed_precision=False,
             device=torch.device("cpu"),
             logger=None)
-        train_iters = 51
-        trainer.fit(train_iters)
+        trainer.fit(n_iterations)
 
         def _test_checkpoint(cp_path, check_progress):
             self.assertTrue(os.path.exists(cp_path))
@@ -71,7 +70,7 @@ def _test_checkpoint(cp_path, check_progress):
             loaded_model.load_state_dict(checkpoint["model_state"])
 
             if check_progress:
-                self.assertEqual(checkpoint["iteration"], train_iters)
+                self.assertEqual(checkpoint["iteration"], n_iterations)
                 self.assertEqual(checkpoint["epoch"], 2)
 
         _test_checkpoint("./checkpoints/test/latest.pt", True)
diff --git a/torch_em/data/raw_dataset.py b/torch_em/data/raw_dataset.py
index 3bd16cf1..ae1ab034 100644
--- a/torch_em/data/raw_dataset.py
+++ b/torch_em/data/raw_dataset.py
@@ -31,6 +31,7 @@ def __init__(
         sampler=None,
         ndim=None,
         with_channels=False,
+        augmentations=None,
     ):
         self.raw_path = raw_path
         self.raw_key = raw_key
@@ -60,6 +61,10 @@ def __init__(
         self.sampler = sampler
         self.dtype = dtype
 
+        if augmentations is not None:
+            assert len(augmentations) == 2
+        self.augmentations = augmentations
+
         self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
 
         self.sample_shape = patch_shape
@@ -128,6 +133,11 @@ def __getitem__(self, index):
             raw = self.crop(raw)
 
         raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
+        if self.augmentations is not None:
+            aug1, aug2 = self.augmentations
+            raw1, raw2 = aug1(raw), aug2(raw)
+            return raw1, raw2
+
         return raw
 
     # need to overwrite pickle to support h5py
diff --git a/torch_em/data/raw_image_collection_dataset.py b/torch_em/data/raw_image_collection_dataset.py
index 741a05bb..7dfecf6d 100644
--- a/torch_em/data/raw_image_collection_dataset.py
+++ b/torch_em/data/raw_image_collection_dataset.py
@@ -36,6 +36,7 @@ def __init__(
         dtype=torch.float32,
         n_samples=None,
         sampler=None,
+        augmentations=None,
     ):
         self._check_inputs(raw_image_paths)
         self.raw_images = raw_image_paths
@@ -56,6 +57,10 @@ def __init__(
             self._len = n_samples
         self.sample_random_index = True
 
+        if augmentations is not None:
+            assert len(augmentations) == 2
+        self.augmentations = augmentations
+
     def __len__(self):
         return self._len
 
@@ -100,6 +105,11 @@ def _get_sample(self, index):
         if have_raw_channels:
             raw = raw.transpose((2, 0, 1))
 
+        if self.augmentations is not None:
+            aug1, aug2 = self.augmentations
+            raw1, raw2 = aug1(raw), aug2(raw)
+            return raw1, raw2
+
         return raw
 
     def __getitem__(self, index):
diff --git a/torch_em/self_training/__init__.py b/torch_em/self_training/__init__.py
new file mode 100644
index 00000000..68f93adc
--- /dev/null
+++ b/torch_em/self_training/__init__.py
@@ -0,0 +1 @@
+from .mean_teacher import MeanTeacherTrainer
diff --git a/torch_em/self_training/augmentations.py b/torch_em/self_training/augmentations.py
new file mode 100644
index 00000000..f6d81a22
--- /dev/null
+++ b/torch_em/self_training/augmentations.py
@@ -0,0 +1 @@
+# TODO build up weak and strong augmentation libraries
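
[Editor's note, not part of the patch: with the dataset changes above, passing a pair of augmentations makes the raw datasets return two independently augmented views of the same patch. A minimal sketch of that contract, assuming a hypothetical HDF5 file data.h5 with a raw dataset:

    import torch_em

    augmentations = (
        torch_em.transform.raw.GaussianBlur(),
        torch_em.transform.raw.GaussianBlur(),
    )
    ds = torch_em.data.RawDataset(
        raw_path="data.h5", raw_key="raw",
        patch_shape=(1, 64, 64), ndim=2,
        augmentations=augmentations,
    )
    view1, view2 = ds[0]  # two stochastic views of the same patch

This pair is what MeanTeacherTrainer unpacks as xu1 (teacher input) and xu2 (student input) in _train_epoch_unsupervised.]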
diff --git a/torch_em/self_training/mean_teacher.py b/torch_em/self_training/mean_teacher.py
index 97ad0580..a4fa5a94 100644
--- a/torch_em/self_training/mean_teacher.py
+++ b/torch_em/self_training/mean_teacher.py
@@ -5,6 +5,10 @@
 import torch_em
 
 
+class Dummy(torch.nn.Module):
+    pass
+
+
 class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
     """This trainer implements self-training for semi-supervised learning and domain adaptation,
     following the 'MeanTeacher' approach of Tarvainen & Valpola (https://arxiv.org/abs/1703.01780).
@@ -15,6 +19,9 @@ class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
     This class expects the following data loaders:
     - unsupervised_train_loader: Returns two augmentations of the same input.
     - supervised_train_loader (optional): Returns input and labels.
+    - unsupervised_val_loader (optional): Same as unsupervised_train_loader
+    - supervised_val_loader (optional): Same as supervised_train_loader
+    At least one of unsupervised_val_loader and supervised_val_loader must be given.
 
     And the following elements to customize the pseudo labeling:
     - pseudo_labeler: to compute the pseudo-labels
@@ -60,6 +67,8 @@ def __init__(
         unsupervised_loss,
         pseudo_labeler,
         supervised_train_loader=None,
+        unsupervised_val_loader=None,
+        supervised_val_loader=None,
         supervised_loss=None,
         supervised_loss_and_metric=None,
         unsupervised_loss_and_metric=None,
@@ -78,8 +87,23 @@ def __init__(
             train_loader = supervised_train_loader if len(supervised_train_loader) < len(unsupervised_train_loader)\
                 else unsupervised_train_loader
             self._train_epoch_impl = self._train_epoch_semisupervised
+            self.unsupervised_train_loader = unsupervised_train_loader
+            self.supervised_train_loader = supervised_train_loader
 
-        # Check that we have at least one of supervised / unsupervised loss and metric
+        # Check that we have at least one of supervised / unsupervised val loader.
+        assert sum((
+            supervised_val_loader is None,
+            unsupervised_val_loader is None,
+        )) > 0
+        self.supervised_val_loader = supervised_val_loader
+        self.unsupervised_val_loader = unsupervised_val_loader
+
+        if self.unsupervised_val_loader is None:
+            val_loader = self.supervised_val_loader
+        else:
+            val_loader = self.unsupervised_val_loader
+
+        # Check that we have at least one of supervised / unsupervised loss and metric.
         assert sum((
             supervised_loss_and_metric is None,
             unsupervised_loss_and_metric is None,
@@ -87,9 +111,8 @@ def __init__(
         self.supervised_loss_and_metric = supervised_loss_and_metric
         self.unsupervised_loss_and_metric = unsupervised_loss_and_metric
 
-        super().__init__(train_loader=train_loader, **kwargs)
-        self.unsupervised_train_loader = unsupervised_train_loader
-        self.supervised_train_loader = supervised_train_loader
+        # TODO we need to recover which loader is which, and take care of the correct deserialization!
+        super().__init__(train_loader=train_loader, val_loader=val_loader, loss=Dummy(), metric=Dummy(), **kwargs)
 
         self.unsupervised_loss = unsupervised_loss
         self.supervised_loss = supervised_loss
@@ -159,7 +182,7 @@ def _train_epoch_unsupervised(self, progress, forward_context, backprop):
         t_per_iter = time.time()
 
         # Sample from the unsupervised loader.
-        for xu1, xu2 in self.target_train_loader:
+        for xu1, xu2 in self.unsupervised_train_loader:
             xu1, xu2 = xu1.to(self.device), xu2.to(self.device)
 
             teacher_input, model_input = xu1, xu2
@@ -278,8 +301,8 @@ def _validate_unsupervised(self, forward_context):
             loss_val += loss.item()
             metric_val += metric.item()
 
-        metric_val /= len(self.supervised_val_loader)
-        loss_val /= len(self.supervised_val_loader)
+        metric_val /= len(self.unsupervised_val_loader)
+        loss_val /= len(self.unsupervised_val_loader)
 
         if self.logger is not None:
             with forward_context():
@@ -303,9 +326,9 @@ def _validate_impl(self, forward_context):
             if self.unsupervised_val_loader is None:
                 unsupervised_metric = None
             else:
-                unsuperised_metric = self._validate_unsupervised(forward_context)
+                unsupervised_metric = self._validate_unsupervised(forward_context)
 
-        if unsuperised_metric is None:
+        if unsupervised_metric is None:
             metric = supervised_metric
         elif supervised_metric is None:
             metric = unsupervised_metric

From 54e6b9c6549971ac17ed675401c9764508af8e45 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sat, 11 Mar 2023 20:30:40 +0100
Subject: [PATCH 03/21] MeanTeacher semi-supervised training runs (not properly tested)

---
 test/self_training/test_mean_teacher.py | 55 ++++++++++++++++++++++---
 torch_em/self_training/mean_teacher.py  | 10 ++---
 2 files changed, 55 insertions(+), 10 deletions(-)

diff --git a/test/self_training/test_mean_teacher.py b/test/self_training/test_mean_teacher.py
index 613ecd2d..db325ebc 100644
--- a/test/self_training/test_mean_teacher.py
+++ b/test/self_training/test_mean_teacher.py
@@ -27,6 +27,18 @@ def simple_unsupervised_loss_and_metric(model, model_input, pseudo_labels, label
     return loss, loss
 
 
+def simple_supervised_loss(model, input_, labels):
+    pred = model(input_)
+    loss = torch_em.loss.dice_score(pred, labels, invert=True)
+    return loss
+
+
+def simple_supervised_loss_and_metric(model, input_, labels):
+    pred = model(input_)
+    loss = torch_em.loss.dice_score(pred, labels, invert=True)
+    return loss, loss
+
+
 class TestMeanTeacher(unittest.TestCase):
     tmp_folder = "./tmp"
     data_path = "./tmp/data.h5"
     raw_key = "raw"
     label_key = "labels"
@@ -63,8 +75,9 @@ def _test_mean_teacher(
 
         model = UNet2d(in_channels=1, out_channels=1, initial_features=8, depth=3)
         optimizer = torch.optim.Adam(model.parameters())
 
+        name = "mt-test"
         trainer = MeanTeacherTrainer(
-            name="mt-test",
+            name=name,
             model=model,
             optimizer=optimizer,
             device=torch.device("cpu"),
@@ -81,9 +94,19 @@ def _test_mean_teacher(
             mixed_precision=False,
         )
         trainer.fit(53)
-        self.assertTrue("./checkpoints/best.pt")
-        self.assertTrue("./checkpoints/latest.pt")
-        # TODO check that deserializing and continuing to train works
+        self.assertTrue(os.path.exists(f"./checkpoints/{name}/best.pt"))
+        self.assertTrue(os.path.exists(f"./checkpoints/{name}/latest.pt"))
+
+        # TODO
+        # # make sure that the trainer can be deserialized from the checkpoint
+        # trainer2 = MeanTeacherTrainer.from_checkpoint(os.path.join("./checkpoints", name), name="latest")
+        # self.assertEqual(trainer.iteration, trainer2.iteration)
+        # self.assertTrue(torch_em.util.model_is_equal(trainer.model, trainer2.model))
+        # self.assertTrue(torch_em.util.model_is_equal(trainer.teacher, trainer2.teacher))
+
+        # # and that it can be trained further
+        # trainer2.fit(10)
+        # self.assertEqual(trainer2.iteration, 63)
 
     def get_unsupervised_loader(self, n_samples):
         augmentations = (
             torch_em.transform.raw.GaussianBlur(),
             torch_em.transform.raw.GaussianBlur(),
         )
         ds = torch_em.data.RawDataset(
             raw_path=self.data_path,
             raw_key=self.raw_key,
             patch_shape=(1, 64, 64),
             n_samples=n_samples,
             ndim=2,
             augmentations=augmentations,
         )
         loader = torch_em.segmentation.get_data_loader(ds, batch_size=1, shuffle=True)
         return loader
 
+    def get_supervised_loader(self, n_samples):
+        ds = torch_em.data.SegmentationDataset(
+            raw_path=self.data_path, raw_key=self.raw_key,
+            label_path=self.data_path, label_key=self.label_key,
+            patch_shape=(1, 64, 64), ndim=2,
+            n_samples=n_samples,
+        )
+        loader = torch_em.segmentation.get_data_loader(ds, batch_size=1, shuffle=True)
+        return loader
+
     def test_mean_teacher_unsupervised(self):
         unsupervised_train_loader = self.get_unsupervised_loader(n_samples=50)
         unsupervised_val_loader = self.get_unsupervised_loader(n_samples=4)
         self._test_mean_teacher(
             unsupervised_train_loader,
             unsupervised_val_loader=unsupervised_val_loader
         )
 
+    def test_mean_teacher_semisupervised(self):
+        unsupervised_train_loader = self.get_unsupervised_loader(n_samples=50)
+        supervised_train_loader = self.get_supervised_loader(n_samples=50)
+        supervised_val_loader = self.get_supervised_loader(n_samples=4)
+        self._test_mean_teacher(
+            unsupervised_train_loader=unsupervised_train_loader,
+            supervised_train_loader=supervised_train_loader,
+            supervised_val_loader=supervised_val_loader,
+            supervised_loss=simple_supervised_loss,
+            supervised_loss_and_metric=simple_supervised_loss_and_metric,
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torch_em/self_training/mean_teacher.py b/torch_em/self_training/mean_teacher.py
index a4fa5a94..fc16102d 100644
--- a/torch_em/self_training/mean_teacher.py
+++ b/torch_em/self_training/mean_teacher.py
@@ -92,8 +92,8 @@ def __init__(
         # Check that we have at least one of supervised / unsupervised val loader.
         assert sum((
-            supervised_val_loader is None,
-            unsupervised_val_loader is None,
+            supervised_val_loader is not None,
+            unsupervised_val_loader is not None,
         )) > 0
         self.supervised_val_loader = supervised_val_loader
         self.unsupervised_val_loader = unsupervised_val_loader
@@ -105,8 +105,8 @@ def __init__(
         # Check that we have at least one of supervised / unsupervised loss and metric.
         assert sum((
-            supervised_loss_and_metric is None,
-            unsupervised_loss_and_metric is None,
+            supervised_loss_and_metric is not None,
+            unsupervised_loss_and_metric is not None,
         )) > 0
         self.supervised_loss_and_metric = supervised_loss_and_metric
         self.unsupervised_loss_and_metric = unsupervised_loss_and_metric
@@ -221,7 +221,7 @@ def _train_epoch_semisupervised(self, progress, forward_context, backprop):
         t_per_iter = time.time()
 
         # Sample from both the supervised and unsupervised loader.
-        for (xs, ys), (xu1, xu2) in zip(self.source_train_loader, self.target_train_loader):
+        for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader):
             xs, ys = xs.to(self.device), ys.to(self.device)
             xu1, xu2 = xu1.to(self.device), xu2.to(self.device)
 

From 3090e1def43dc236ff02a4d03b00387a6f5c94e8 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sat, 11 Mar 2023 21:41:58 +0100
Subject: [PATCH 04/21] Update pseudo-labeling functionality

---
 test/self_training/test_mean_teacher.py   | 52 +++++-----------------
 torch_em/self_training/__init__.py        |  2 +
 torch_em/self_training/loss.py            | 53 +++++++++++++++++++++++
 torch_em/self_training/pseudo_labeling.py | 38 ++++++++++++++++
 4 files changed, 104 insertions(+), 41 deletions(-)
 create mode 100644 torch_em/self_training/loss.py
 create mode 100644 torch_em/self_training/pseudo_labeling.py

diff --git a/test/self_training/test_mean_teacher.py b/test/self_training/test_mean_teacher.py
index db325ebc..b3b406e6 100644
--- a/test/self_training/test_mean_teacher.py
+++ b/test/self_training/test_mean_teacher.py
@@ -4,41 +4,11 @@
 
 import torch
 import torch_em
+import torch_em.self_training as self_training
 from torch_em.model import UNet2d
 from torch_em.util.test import create_segmentation_test_data
 
 
-def simple_pseudo_labeler(teacher, input_):
-    pseudo_labels = teacher(input_)
-    return pseudo_labels, None
-
-
 class TestMeanTeacher(unittest.TestCase):
     tmp_folder = "./tmp"
     data_path = "./tmp/data.h5"
     raw_key = "raw"
     label_key = "labels"
@@ -69,29 +39,28 @@ def _test_mean_teacher(
         supervised_val_loader=None,
         supervised_loss=None,
         supervised_loss_and_metric=None,
+        unsupervised_loss_and_metric=None,
     ):
-        from torch_em.self_training import MeanTeacherTrainer
-
         model = UNet2d(in_channels=1, out_channels=1, initial_features=8, depth=3)
         optimizer = torch.optim.Adam(model.parameters())
 
         name = "mt-test"
-        trainer = MeanTeacherTrainer(
+        trainer = self_training.MeanTeacherTrainer(
             name=name,
             model=model,
             optimizer=optimizer,
-            device=torch.device("cpu"),
+            pseudo_labeler=self_training.DefaultPseudoLabeler(),
+            unsupervised_loss=self_training.DefaultSelfTrainingLoss(),
+            unsupervised_loss_and_metric=unsupervised_loss_and_metric,
             unsupervised_train_loader=unsupervised_train_loader,
             supervised_train_loader=supervised_train_loader,
             unsupervised_val_loader=unsupervised_val_loader,
             supervised_val_loader=supervised_val_loader,
-            pseudo_labeler=simple_pseudo_labeler,
-            unsupervised_loss=simple_unsupservised_loss,
-            unsupervised_loss_and_metric=simple_unsupervised_loss_and_metric,
             supervised_loss=supervised_loss,
             supervised_loss_and_metric=supervised_loss_and_metric,
             logger=None,
             mixed_precision=False,
+            device=torch.device("cpu"),
         )
         trainer.fit(53)
         self.assertTrue(os.path.exists(f"./checkpoints/{name}/best.pt"))
@@ -139,7 +108,8 @@ def test_mean_teacher_unsupervised(self):
         unsupervised_val_loader = self.get_unsupervised_loader(n_samples=4)
         self._test_mean_teacher(
-            unsupervised_train_loader,
-            unsupervised_val_loader=unsupervised_val_loader
+            unsupervised_train_loader=unsupervised_train_loader,
+            unsupervised_val_loader=unsupervised_val_loader,
+            unsupervised_loss_and_metric=self_training.DefaultSelfTrainingLossAndMetric(),
         )
 
     def test_mean_teacher_semisupervised(self):
@@ -150,8 +120,8 @@ def test_mean_teacher_semisupervised(self):
             unsupervised_train_loader=unsupervised_train_loader,
             supervised_train_loader=supervised_train_loader,
             supervised_val_loader=supervised_val_loader,
-            supervised_loss=simple_supervised_loss,
-            supervised_loss_and_metric=simple_supervised_loss_and_metric,
+            supervised_loss=self_training.DefaultSelfTrainingLoss(),
+            supervised_loss_and_metric=self_training.DefaultSelfTrainingLossAndMetric(),
         )
 
 
diff --git a/torch_em/self_training/__init__.py b/torch_em/self_training/__init__.py
index 68f93adc..5de64128 100644
--- a/torch_em/self_training/__init__.py
+++ b/torch_em/self_training/__init__.py
@@ -1 +1,3 @@
+from .loss import DefaultSelfTrainingLoss, DefaultSelfTrainingLossAndMetric
 from .mean_teacher import MeanTeacherTrainer
+from .pseudo_labeling import DefaultPseudoLabeler
diff --git a/torch_em/self_training/loss.py b/torch_em/self_training/loss.py
new file mode 100644
index 00000000..030c116b
--- /dev/null
+++ b/torch_em/self_training/loss.py
@@ -0,0 +1,53 @@
+import torch.nn as nn
+import torch_em
+
+
+class DefaultSelfTrainingLoss(nn.Module):
+    """Loss function for self training.
+
+    Parameters:
+        loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss)
+        activation [nn.Module, callable] - the activation function to be applied to the prediction
+            before passing it to the loss. (default: None)
+    """
+    def __init__(self, loss=torch_em.loss.DiceLoss(), activation=None):
+        super().__init__()
+        self.activation = activation
+        self.loss = loss
+
+    def __call__(self, model, input_, labels, label_filter=None):
+        prediction = model(input_)
+        if self.activation is not None:
+            prediction = self.activation(prediction)
+        if label_filter is None:
+            loss = self.loss(prediction, labels)
+        else:
+            loss = self.loss(prediction * label_filter, labels * label_filter)
+        return loss
+
+
+class DefaultSelfTrainingLossAndMetric(nn.Module):
+    """Loss and metric function for self training.
+
+    Parameters:
+        loss [nn.Module] - the loss function to be used. (default: torch_em.loss.DiceLoss)
+        metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
+        activation [nn.Module, callable] - the activation function to be applied to the prediction
+            before passing it to the loss. (default: None)
+    """
+    def __init__(self, loss=torch_em.loss.DiceLoss(), metric=torch_em.loss.DiceLoss(), activation=None):
+        super().__init__()
+        self.activation = activation
+        self.loss = loss
+        self.metric = metric
+
+    def __call__(self, model, input_, labels, label_filter=None):
+        prediction = model(input_)
+        if self.activation is not None:
+            prediction = self.activation(prediction)
+        if label_filter is None:
+            loss = self.loss(prediction, labels)
+        else:
+            loss = self.loss(prediction * label_filter, labels * label_filter)
+        metric = self.metric(prediction, labels)
+        return loss, metric
diff --git a/torch_em/self_training/pseudo_labeling.py b/torch_em/self_training/pseudo_labeling.py
new file mode 100644
index 00000000..9e2204c5
--- /dev/null
+++ b/torch_em/self_training/pseudo_labeling.py
@@ -0,0 +1,38 @@
+
+
+class DefaultPseudoLabeler:
+    """Compute pseudo labels.
+
+    Parameters:
+        activation [nn.Module, callable] - activation function applied to the teacher prediction.
+        confidence_threshold [float] - threshold for computing a mask for filtering the pseudo labels.
+            If None is given, no mask will be computed. (default: None)
+        threshold_from_both_sides [bool] - whether to include both values bigger than the threshold
+            and smaller than 1 - it, or only values bigger than it in the mask.
+            The former should be used for binary labels, the latter for multiclass labels. (default: True)
+    """
+    def __init__(self, activation=None, confidence_threshold=None, threshold_from_both_sides=True):
+        self.activation = activation
+        self.confidence_threshold = confidence_threshold
+        self.threshold_from_both_sides = threshold_from_both_sides
+
+    def _compute_label_mask_both_sides(self, pseudo_labels):
+        upper_threshold = self.confidence_threshold
+        lower_threshold = 1.0 - self.confidence_threshold
+        mask = (pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)
+        return mask
+
+    def _compute_label_mask_one_sides(self, pseudo_labels):
+        mask = (pseudo_labels >= self.confidence_threshold)
+        return mask
+
+    def __call__(self, teacher, input_):
+        pseudo_labels = teacher(input_)
+        if self.activation is not None:
+            pseudo_labels = self.activation(pseudo_labels)
+        if self.confidence_threshold is None:
+            label_mask = None
+        else:
+            label_mask = self._compute_label_mask_both_side(pseudo_labels) if self.threshold_from_both_sides\
+                else self._compute_label_mask_one_sides(pseudo_labels)
+        return pseudo_labels, label_mask
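
[Editor's note, not part of the patch: a tiny numeric check of the two masking modes in DefaultPseudoLabeler, assuming confidence_threshold=0.9. Thresholding from both sides keeps confident background and confident foreground, one-sided thresholding keeps only confident foreground:

    import torch

    pseudo_labels = torch.tensor([0.05, 0.30, 0.70, 0.95])
    both_sides = (pseudo_labels >= 0.9) + (pseudo_labels <= 0.1)
    one_side = pseudo_labels >= 0.9
    print(both_sides)  # tensor([ True, False, False,  True])
    print(one_side)    # tensor([False, False, False,  True])

Adding two boolean tensors acts as an elementwise OR, which is what _compute_label_mask_both_sides relies on.]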
From 451553e9d693ce4467900cf9e05bd768454e1ccd Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sat, 11 Mar 2023 22:05:59 +0100
Subject: [PATCH 05/21] Start implementing self-training experiments

---
 .../probabilistic_domain_adaptation/README.md |  0
 .../livecell/README.md                        |  0
 .../livecell/unet_adamt.py                    |  0
 .../livecell/unet_mean_teacher.py             |  0
 .../livecell/unet_source.py                   | 80 +++++++++++++++++++
 .../mitochondria/README.md                    |  0
 .../mitochondria/unet_adamt.py                |  0
 .../mitochondria/unet_mean_teacher.py         |  0
 .../mitochondria/unet_source.py               |  0
 9 files changed, 80 insertions(+)
 create mode 100644 experiments/probabilistic_domain_adaptation/README.md
 create mode 100644 experiments/probabilistic_domain_adaptation/livecell/README.md
 create mode 100644 experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py
 create mode 100644 experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py
 create mode 100644 experiments/probabilistic_domain_adaptation/livecell/unet_source.py
 create mode 100644 experiments/probabilistic_domain_adaptation/mitochondria/README.md
 create mode 100644 experiments/probabilistic_domain_adaptation/mitochondria/unet_adamt.py
 create mode 100644 experiments/probabilistic_domain_adaptation/mitochondria/unet_mean_teacher.py
 create mode 100644 experiments/probabilistic_domain_adaptation/mitochondria/unet_source.py

diff --git a/experiments/probabilistic_domain_adaptation/README.md b/experiments/probabilistic_domain_adaptation/README.md
new file mode 100644
index 00000000..e69de29b
diff --git a/experiments/probabilistic_domain_adaptation/livecell/README.md b/experiments/probabilistic_domain_adaptation/livecell/README.md
new file mode 100644
index 00000000..e69de29b
diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py
new file mode 100644
index 00000000..e69de29b
diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py
new file mode 100644
index 00000000..e69de29b
diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_source.py b/experiments/probabilistic_domain_adaptation/livecell/unet_source.py
new file mode 100644
index 00000000..a42664ff
--- /dev/null
+++ b/experiments/probabilistic_domain_adaptation/livecell/unet_source.py
@@ -0,0 +1,80 @@
+import argparse
+
+import torch_em
+from torch_em.model import UNet2d
+from torch_em.data.datasets import get_livecell_loader
+
+CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]
+
+
+def _get_loader(args, split, cell_type):
+    patch_shape = (512, 512)
+    loader = get_livecell_loader(
+        args.input, patch_shape, split,
+        download=True, binary=True, batch_size=args.batch_size,
+        cell_types=[cell_type]
+    )
+    return loader
+
+
+def _train_cell_type(args, cell_type):
+    model = UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid")
+    train_loader = _get_loader(args, "train", cell_type)
+    val_loader = _get_loader(args, "train", cell_type)
+    name = f"unet_source/{cell_type}"
+    trainer = torch_em.default_segmentation_trainer(
+        name=name,
+        model=model,
+        train_loader=train_loader,
+        val_loader=val_loader,
+        learning_rate=1e-4,
+        mixed_precision=True,
+        log_image_interval=100,
+    )
+    trainer.fit(iterations=args.n_iterations)
+
+
+def run_training(args):
+    for cell_type in args.cell_types:
+        print("Start training for cell type:", cell_type)
+        _train_cell_type(args, cell_type)
+
+
+def check_loader(args, n_images=5):
+    from torch_em.util.debug import check_loader
+
+    cell_types = args.cell_types
+    print("The cell types", cell_types, "were selected.")
+    print("Checking the loader for the first cell type", cell_types[0])
+
+    loader = _get_loader(args, "train", cell_types[0])
+    check_loader(loader, n_images)
+
+
+# TODO
+def run_evaluation(args):
+    pass
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--input", required=True)
+    parser.add_argument("-p", "--phase", required=True)
+    parser.add_argument("-b", "--batch_size", default=8, type=int)
+    parser.add_argument("-n", "--n_iterations", default=int(1e5), type=int)
+    parser.add_argument("--cell_types", nargs="+", default=CELL_TYPES)
+    args = parser.parse_args()
+
+    phase = args.phase
+    if phase in ("c", "check"):
+        check_loader(args)
+    elif phase in ("t", "train"):
+        run_training(args)
+    elif phase in ("e", "evaluate"):
+        run_evaluation(args)
+    else:
+        raise ValueError(f"Got phase={phase}, expect one of check, train, evaluate.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/experiments/probabilistic_domain_adaptation/mitochondria/README.md b/experiments/probabilistic_domain_adaptation/mitochondria/README.md
new file mode 100644
index 00000000..e69de29b
diff --git a/experiments/probabilistic_domain_adaptation/mitochondria/unet_adamt.py b/experiments/probabilistic_domain_adaptation/mitochondria/unet_adamt.py
new file mode 100644
index 00000000..e69de29b
diff --git a/experiments/probabilistic_domain_adaptation/mitochondria/unet_mean_teacher.py b/experiments/probabilistic_domain_adaptation/mitochondria/unet_mean_teacher.py
new file mode 100644
index 00000000..e69de29b
diff --git a/experiments/probabilistic_domain_adaptation/mitochondria/unet_source.py b/experiments/probabilistic_domain_adaptation/mitochondria/unet_source.py
new file mode 100644
index 00000000..e69de29b

From 384eb548e1404d141ce9423116e29b6522e94fb1 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sat, 11 Mar 2023 22:23:21 +0100
Subject: [PATCH 06/21] Update livecell source training

---
 .../livecell/common.py      | 14 +++++++++++
 .../livecell/unet_source.py | 25 ++++++-------------
 2 files changed, 22 insertions(+), 17 deletions(-)
 create mode 100644 experiments/probabilistic_domain_adaptation/livecell/common.py

diff --git a/experiments/probabilistic_domain_adaptation/livecell/common.py b/experiments/probabilistic_domain_adaptation/livecell/common.py
new file mode 100644
index 00000000..cf3adbd8
--- /dev/null
+++ b/experiments/probabilistic_domain_adaptation/livecell/common.py
@@ -0,0 +1,14 @@
+import argparse
+
+CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]
+
+
+def get_parser(default_batch_size=8, default_iterations=int(1e5)):
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-i", "--input", required=True)
+    parser.add_argument("-p", "--phase", required=True)
+    parser.add_argument("-b", "--batch_size", default=default_batch_size, type=int)
+    parser.add_argument("-n", "--n_iterations", default=default_iterations, type=int)
+    parser.add_argument("-s", "--save_root")
+    parser.add_argument("-c", "--cell_types", nargs="+", default=CELL_TYPES)
+    return parser
diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_source.py b/experiments/probabilistic_domain_adaptation/livecell/unet_source.py
index a42664ff..03477501 100644
--- a/experiments/probabilistic_domain_adaptation/livecell/unet_source.py
+++ b/experiments/probabilistic_domain_adaptation/livecell/unet_source.py
@@ -1,10 +1,7 @@
-import argparse
-
 import torch_em
 from torch_em.model import UNet2d
 from torch_em.data.datasets import get_livecell_loader
-
-CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]
+from common import get_parser
 
 
 def _get_loader(args, split, cell_type):
@@ -12,7 +9,7 @@ def _get_loader(args, split, cell_type):
     loader = get_livecell_loader(
         args.input, patch_shape, split,
         download=True, binary=True, batch_size=args.batch_size,
-        cell_types=[cell_type]
+        cell_types=[cell_type], num_workers=8, shuffle=True,
     )
     return loader
 
@@ -30,6 +27,7 @@ def _train_cell_type(args, cell_type):
         learning_rate=1e-4,
         mixed_precision=True,
         log_image_interval=100,
+        save_root=args.save_root,
     )
     trainer.fit(iterations=args.n_iterations)
 
@@ -57,23 +55,16 @@ def run_evaluation(args):
 def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-i", "--input", required=True)
-    parser.add_argument("-p", "--phase", required=True)
-    parser.add_argument("-b", "--batch_size", default=8, type=int)
-    parser.add_argument("-n", "--n_iterations", default=int(1e5), type=int)
-    parser.add_argument("--cell_types", nargs="+", default=CELL_TYPES)
+    parser = get_parser(default_iterations=50000)
     args = parser.parse_args()
-
-    phase = args.phase
-    if phase in ("c", "check"):
+    if args.phase in ("c", "check"):
         check_loader(args)
-    elif phase in ("t", "train"):
+    elif args.phase in ("t", "train"):
         run_training(args)
-    elif phase in ("e", "evaluate"):
+    elif args.phase in ("e", "evaluate"):
         run_evaluation(args)
     else:
-        raise ValueError(f"Got phase={phase}, expect one of check, train, evaluate.")
+        raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")
 
 
 if __name__ == "__main__":

From 8a91b9035fce80e39cdb43cd93e8ae57c0d19569 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sat, 11 Mar 2023 22:48:39 +0100
Subject: [PATCH 07/21] Livecell training updates

---
 .../livecell/common.py      | 42 +++++++++++++++++++
 .../livecell/unet_source.py | 19 ++-------
 2 files changed, 46 insertions(+), 15 deletions(-)

diff --git a/experiments/probabilistic_domain_adaptation/livecell/common.py b/experiments/probabilistic_domain_adaptation/livecell/common.py
index cf3adbd8..495503fe 100644
--- a/experiments/probabilistic_domain_adaptation/livecell/common.py
+++ b/experiments/probabilistic_domain_adaptation/livecell/common.py
@@ -1,8 +1,36 @@
 import argparse
 
+import torch_em
+from torch_em.data.datasets import get_livecell_loader
+from torchvision import transforms
+
 CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]
 
 
+#
+# The augmentations we use for the LiveCELL experiments:
+# - weak augmentations: blurring and additive gaussian noise
+# - strong augmentations: TODO
+#
+
+
+def weak_augmentations(p=0.25):
+    norm = torch_em.transform.raw.standardize
+    aug = transforms.Compose([
+        norm,
+        transforms.RandomApply([torch_em.transform.raw.GaussianBlur()], p=p),
+        transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
+            scale=(0, 0.15), clip_kwargs=False)], p=p
+        ),
+    ])
+    return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)
+
+
+#
+# Other utility functions: parser, loaders
+#
+
+
 def get_parser(default_batch_size=8, default_iterations=int(1e5)):
     parser = argparse.ArgumentParser()
     parser.add_argument("-i", "--input", required=True)
@@ -12,3 +40,17 @@ def get_parser(default_batch_size=8, default_iterations=int(1e5)):
     parser.add_argument("-s", "--save_root")
     parser.add_argument("-c", "--cell_types", nargs="+", default=CELL_TYPES)
     return parser
+
+
+def get_supervised_loader(args, split, cell_type):
+    patch_shape = (512, 512)
+    loader = get_livecell_loader(
+        args.input, patch_shape, split,
+        download=True, binary=True, batch_size=args.batch_size,
+        cell_types=[cell_type], num_workers=8, shuffle=True,
+    )
+    return loader
+
+
+def get_unsupervised_loader(args, split, cell_type):
+    pass
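
[Editor's note, not part of the patch: get_unsupervised_loader is still a stub here and is filled in a later commit. As a rough, hypothetical usage sketch of weak_augmentations, the returned raw transform is stochastic, so applying it twice to the same patch yields the two views needed for mean-teacher training:

    import numpy as np

    transform = weak_augmentations(p=0.5)
    patch = np.random.rand(1, 256, 256).astype("float32")
    view1, view2 = transform(patch), transform(patch)  # differ where blur/noise fired

This mirrors how the augmentations pair added to the raw datasets in PATCH 02 is meant to be used.]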
diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_source.py b/experiments/probabilistic_domain_adaptation/livecell/unet_source.py
index 03477501..63bf1312 100644
--- a/experiments/probabilistic_domain_adaptation/livecell/unet_source.py
+++ b/experiments/probabilistic_domain_adaptation/livecell/unet_source.py
@@ -1,23 +1,12 @@
 import torch_em
 from torch_em.model import UNet2d
-from torch_em.data.datasets import get_livecell_loader
-from common import get_parser
-
-
-def _get_loader(args, split, cell_type):
-    patch_shape = (512, 512)
-    loader = get_livecell_loader(
-        args.input, patch_shape, split,
-        download=True, binary=True, batch_size=args.batch_size,
-        cell_types=[cell_type], num_workers=8, shuffle=True,
-    )
-    return loader
+from common import get_parser, get_supervised_loader
 
 
 def _train_cell_type(args, cell_type):
     model = UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid")
-    train_loader = _get_loader(args, "train", cell_type)
-    val_loader = _get_loader(args, "train", cell_type)
+    train_loader = get_supervised_loader(args, "train", cell_type)
+    val_loader = get_supervised_loader(args, "val", cell_type)
     name = f"unet_source/{cell_type}"
     trainer = torch_em.default_segmentation_trainer(
         name=name,
@@ -45,7 +34,7 @@ def check_loader(args, n_images=5):
     print("The cell types", cell_types, "were selected.")
     print("Checking the loader for the first cell type", cell_types[0])
 
-    loader = _get_loader(args, "train", cell_types[0])
+    loader = get_supervised_loader(args, "train", cell_types[0])
     check_loader(loader, n_images)
 

From e5a7db230a83d9aa940682429ba1c3d35e3c259f Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sat, 11 Mar 2023 23:57:28 +0100
Subject: [PATCH 08/21] Fix issues in raw transforms

---
 torch_em/data/raw_image_collection_dataset.py | 13 ++++++----
 torch_em/transform/raw.py                     | 26 +++++++++----------
 2 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/torch_em/data/raw_image_collection_dataset.py b/torch_em/data/raw_image_collection_dataset.py
index 7dfecf6d..d6462a05 100644
--- a/torch_em/data/raw_image_collection_dataset.py
+++ b/torch_em/data/raw_image_collection_dataset.py
@@ -105,22 +105,25 @@ def _get_sample(self, index):
         if have_raw_channels:
             raw = raw.transpose((2, 0, 1))
 
-        if self.augmentations is not None:
-            aug1, aug2 = self.augmentations
-            raw1, raw2 = aug1(raw), aug2(raw)
-            return raw1, raw2
-
         return raw
 
     def __getitem__(self, index):
         raw = self._get_sample(index)
+
         if self.raw_transform is not None:
             raw = self.raw_transform(raw)
+
         if self.transform is not None:
             raw = self.transform(raw)
             assert len(raw) == 1
             raw = raw[0]
             # if self.trafo_halo is not None:
             #     raw = self.crop(raw)
+
         raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
+        if self.augmentations is not None:
+            aug1, aug2 = self.augmentations
+            raw1, raw2 = aug1(raw), aug2(raw)
+            return raw1, raw2
+
         return raw
diff --git a/torch_em/transform/raw.py b/torch_em/transform/raw.py
index e20344c1..cd43c379 100644
--- a/torch_em/transform/raw.py
+++ b/torch_em/transform/raw.py
@@ -8,18 +8,6 @@
 #
 
-def standardize(raw, mean=None, std=None, axis=None, eps=1e-7):
-    raw = raw.astype("float32")
-
-    mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean
-    raw -= mean
-
-    std = raw.std(axis=axis, keepdims=True) if std is None else std
-    raw /= (std + eps)
-
-    return raw
-
-
 TORCH_DTYPES = {
     "float16": torch.float16,
     "float32": torch.float32,
@@ -42,7 +30,19 @@ def cast(inpt, typestring):
     return inpt.astype(typestring)
 
 
-def _normalize_torch(tensor, minval=None, maxval=None, axis=None, eps=1e-7):
+def standardize(raw, mean=None, std=None, axis=None, eps=1e-7):
+    raw = cast(raw, "float32")
+
+    mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean
+    raw -= mean
+
+    std = raw.std(axis=axis, keepdims=True) if std is None else std
+    raw /= (std + eps)
+
+    return raw
+
+
+def _normalize_torch(tensor, minval, maxval, axis, eps):
     if axis:  # torch returns torch.return_types.min or torch.return_types.max
         minval = torch.amin(tensor, dim=axis, keepdim=True) if minval is None else minval
         tensor -= minval
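
[Editor's note, not part of the patch: a quick numeric check of the relocated standardize, which now goes through the shared cast helper but keeps the same (x - mean) / (std + eps) behavior:

    import numpy as np

    raw = np.array([0, 2, 4, 6], dtype="uint8")
    x = raw.astype("float32")
    x = (x - x.mean()) / (x.std() + 1e-7)
    print(x)  # approx. [-1.342, -0.447, 0.447, 1.342]: zero mean, unit std

Only the float32 cast changed in the refactor; the normalization itself is untouched.]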
From 82695084403e1d4d4914b8573d01748a06bb2dc6 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Sat, 11 Mar 2023 23:57:52 +0100
Subject: [PATCH 09/21] Fix several issues in self training

---
 torch_em/self_training/__init__.py        |  1 +
 torch_em/self_training/logger.py          | 28 +++++++++++++++++------
 torch_em/self_training/mean_teacher.py    |  6 ++---
 torch_em/self_training/pseudo_labeling.py |  6 ++---
 4 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/torch_em/self_training/__init__.py b/torch_em/self_training/__init__.py
index 5de64128..f5d2a38e 100644
--- a/torch_em/self_training/__init__.py
+++ b/torch_em/self_training/__init__.py
@@ -1,3 +1,4 @@
+from .logger import SelfTrainingTensorboardLogger
 from .loss import DefaultSelfTrainingLoss, DefaultSelfTrainingLossAndMetric
 from .mean_teacher import MeanTeacherTrainer
 from .pseudo_labeling import DefaultPseudoLabeler
diff --git a/torch_em/self_training/logger.py b/torch_em/self_training/logger.py
index bddc8c3c..a05e6285 100644
--- a/torch_em/self_training/logger.py
+++ b/torch_em/self_training/logger.py
@@ -3,8 +3,10 @@
 import torch
 import torch_em
 
+from torchvision.utils import make_grid
 
-class SelfTrainingTensorboardLogger(torch_em.traner.logger_base.TorchEmLogger):
+
+class SelfTrainingTensorboardLogger(torch_em.trainer.logger_base.TorchEmLogger):
     def __init__(self, trainer, save_root, **unused_kwargs):
         super().__init__(trainer, save_root)
         self.my_root = save_root
@@ -15,13 +17,25 @@ def __init__(self, trainer, save_root, **unused_kwargs):
         self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
         self.log_image_interval = trainer.log_image_interval
 
-    # TODO make a grid image
-    def _add_supervised_images(self):
-        pass
+    # TODO deal with 3d data
+    def _add_supervised_images(self, step, name, x, y, pred):
+        grid = make_grid([x[0], y[0], pred[0]], padding=8)
+        self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step)
 
-    # TODO make a grid image
-    def _add_unsupervised_images(self):
-        pass
+    # TODO deal with 3d data
+    def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter):
+        # from torch_em.transform.raw import _normalize_torch
+        images = [
+            torch_em.transform.raw.normalize(x1[0]),
+            torch_em.transform.raw.normalize(x2[0]),
+            pred[0], pseudo_labels[0],
+        ]
+        im_name = f"{name}/unsupervised/aug1-aug2-prediction-pseudolabels"
+        if label_filter is not None:
+            images.append(label_filter[0])
+            im_name += "-labelfilter"
+        grid = make_grid(images, nrow=2, padding=8)
+        self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step)
 
     def log_combined_loss(self, step, loss):
         self.tb.add_scalar(tag="train/combined_loss", scalar_value=loss, global_step=step)
diff --git a/torch_em/self_training/mean_teacher.py b/torch_em/self_training/mean_teacher.py
index fc16102d..a6076912 100644
--- a/torch_em/self_training/mean_teacher.py
+++ b/torch_em/self_training/mean_teacher.py
@@ -200,7 +200,7 @@ def _train_epoch_unsupervised(self, progress, forward_context, backprop):
                     self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
                 )
                 lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
-                self.logger.log_lr(lr)
+                self.logger.log_lr(self._iteration, lr)
 
             with torch.no_grad():
                 self._momentum_update()
@@ -251,9 +251,9 @@ def _train_epoch_semisupervised(self, progress, forward_context, backprop):
                     self._iteration, unsupervised_loss, xu1, xu2, unsup_pred, pseudo_labels, label_filter
                 )
 
-                self.logger.log_combined_loss(loss)
+                self.logger.log_combined_loss(self._iteration, loss)
                 lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
-                self.logger.log_lr(lr)
+                self.logger.log_lr(self._iteration, lr)
 
torch.no_grad(): self._momentum_update() diff --git a/torch_em/self_training/pseudo_labeling.py b/torch_em/self_training/pseudo_labeling.py index 9e2204c5..dd846e13 100644 --- a/torch_em/self_training/pseudo_labeling.py +++ b/torch_em/self_training/pseudo_labeling.py @@ -22,7 +22,7 @@ def _compute_label_mask_both_sides(self, pseudo_labels): mask = (pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold) return mask - def _compute_label_mask_one_sides(self, pseudo_labels): + def _compute_label_mask_one_side(self, pseudo_labels): mask = (pseudo_labels >= self.confidence_threshold) return mask @@ -33,6 +33,6 @@ def __call__(self, teacher, input_): if self.confidence_threshold is None: label_mask = None else: - label_mask = self._compute_label_mask_both_side(pseudo_labels) if self.threshold_from_both_sides\ - else self._compute_label_mask_one_sides(pseudo_labels) + label_mask = self._compute_label_mask_both_sides(pseudo_labels) if self.threshold_from_both_sides\ + else self._compute_label_mask_one_side(pseudo_labels) return pseudo_labels, label_mask From 8ba48af4c09b17c4d2106df3467b19be9bea74ee Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 11 Mar 2023 23:59:13 +0100 Subject: [PATCH 10/21] Update domain adaptation experiments --- experiments/livecell/train_boundaries.py | 2 +- .../livecell/common.py | 61 ++++++++--- .../livecell/unet_adamt.py | 102 ++++++++++++++++++ 3 files changed, 151 insertions(+), 14 deletions(-) diff --git a/experiments/livecell/train_boundaries.py b/experiments/livecell/train_boundaries.py index ab5d2427..f54fb78a 100644 --- a/experiments/livecell/train_boundaries.py +++ b/experiments/livecell/train_boundaries.py @@ -62,7 +62,7 @@ def check_loader(args, train=True, val=True, n_images=5): check_loader(loader, n_images) -if __name__ == '__main__': +if __name__ == "__main__": parser = torch_em.util.parser_helper(default_batch_size=8) parser.add_argument("--cell_type", default=None) args = parser.parse_args() diff --git a/experiments/probabilistic_domain_adaptation/livecell/common.py b/experiments/probabilistic_domain_adaptation/livecell/common.py index 495503fe..898ed643 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/common.py +++ b/experiments/probabilistic_domain_adaptation/livecell/common.py @@ -1,7 +1,9 @@ import argparse import torch_em -from torch_em.data.datasets import get_livecell_loader +from torch_em.data.datasets.livecell import (get_livecell_loader, + _download_livecell_images, + _download_livecell_annotations) from torchvision import transforms CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"] @@ -26,20 +28,46 @@ def weak_augmentations(p=0.25): return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug) +# TODO +def strong_augmentations(): + pass + + # -# Other utility functions: parser, loaders +# Other utility functions: loaders, parser # -def get_parser(default_batch_size=8, default_iterations=int(1e5)): - parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input", required=True) - parser.add_argument("-p", "--phase", required=True) - parser.add_argument("-b", "--batch_size", default=default_batch_size, type=int) - parser.add_argument("-n", "--n_iterations", default=default_iterations, type=int) - parser.add_argument("-s", "--save_root") - parser.add_argument("-c", "--cell_types", nargs="+", default=CELL_TYPES) - return parser +def _get_image_paths(args, split, cell_type): + _download_livecell_images(args.input, download=True) + image_paths, _ 
= _download_livecell_annotations(args.input, split, download=True, + cell_types=[cell_type], label_path=None) + return image_paths + + +def get_unsupervised_loader(args, split, cell_type, teacher_augmentation, student_augmentation): + patch_shape = (512, 512) + + def _parse_aug(aug): + if aug == "weak": + return weak_augmentations() + elif aug == "strong": + return strong_augmentations() + assert callable(aug) + return aug + + raw_transform = torch_em.transform.get_raw_transform() + transform = torch_em.transform.get_augmentations(ndim=2) + + image_paths = _get_image_paths(args, split, cell_type) + + augmentations = (_parse_aug(teacher_augmentation), _parse_aug(student_augmentation)) + ds = torch_em.data.RawImageCollectionDataset( + image_paths, patch_shape, raw_transform, transform, + augmentations=augmentations + ) + loader = torch_em.segmentation.get_data_loader(ds, batch_size=args.batch_size, num_workers=8, shuffle=True) + return loader def get_supervised_loader(args, split, cell_type): @@ -52,5 +80,12 @@ def get_supervised_loader(args, split, cell_type): return loader -def get_unsupervised_loader(args, split, cell_type): - pass +def get_parser(default_batch_size=8, default_iterations=int(1e5)): + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input", required=True) + parser.add_argument("-p", "--phase", required=True) + parser.add_argument("-b", "--batch_size", default=default_batch_size, type=int) + parser.add_argument("-n", "--n_iterations", default=default_iterations, type=int) + parser.add_argument("-s", "--save_root") + parser.add_argument("-c", "--cell_types", nargs="+", default=CELL_TYPES) + return parser diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py index e69de29b..e4215554 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py @@ -0,0 +1,102 @@ +import torch +import torch_em.self_training as self_training +from torch_em.model import UNet2d + +from common import CELL_TYPES, get_parser, get_supervised_loader, get_unsupervised_loader + + +def check_loader(args, n_images=5): + from torch_em.util.debug import check_loader + + cell_types = args.cell_types + print("The cell types", cell_types, "were selected.") + print("Checking the unsupervised loader for the first cell type", cell_types[0]) + + loader = get_unsupervised_loader( + args, "train", cell_types[0], + teacher_augmentation="weak", student_augmentation="weak", + ) + check_loader(loader, n_images) + + +def _train_source_target(args, source_cell_type, target_cell_type): + model = UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid") + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5) + + # self training functionality + thresh = args.confidence_threshold + pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=thresh) + loss = self_training.DefaultSelfTrainingLoss() + loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric() + + # data loaders + supervised_train_loader = get_supervised_loader(args, "train", source_cell_type) + supervised_val_loader = get_supervised_loader(args, "val", source_cell_type) + unsupervised_train_loader = get_unsupervised_loader( + args, "train", target_cell_type, + teacher_augmentation="weak", 
student_augmentation="weak", + ) + unsupervised_val_loader = get_unsupervised_loader( + args, "val", target_cell_type, + teacher_augmentation="weak", student_augmentation="weak", + ) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + name = f"unet_adamt/thresh-{thresh}/{source_cell_type}/{target_cell_type}" + trainer = self_training.MeanTeacherTrainer( + name=name, + model=model, + optimizer=optimizer, + lr_scheduler=scheduler, + pseudo_labeler=pseudo_labeler, + unsupervised_loss=loss, + unsupervised_loss_and_metric=loss_and_metric, + supervised_train_loader=supervised_train_loader, + unsupervised_train_loader=unsupervised_train_loader, + supervised_val_loader=supervised_val_loader, + unsupervised_val_loader=unsupervised_val_loader, + supervised_loss=loss, + supervised_loss_and_metric=loss_and_metric, + logger=self_training.SelfTrainingTensorboardLogger, + mixed_precision=True, + device=device, + log_image_interval=100, + ) + trainer.fit(args.n_iterations) + + +def _train_source(args, cell_type): + for target_cell_type in CELL_TYPES: + if target_cell_type == cell_type: + continue + _train_source_target(args, cell_type, target_cell_type) + + +def run_training(args): + for cell_type in args.cell_types: + print("Start training for cell type:", cell_type) + _train_source(args, cell_type) + + +def run_evaluation(args): + pass + + +def main(): + parser = get_parser(default_iterations=75000) + parser.add_argument("--confidence_threshold", default=0.9) + args = parser.parse_args() + if args.phase in ("c", "check"): + check_loader(args) + elif args.phase in ("t", "train"): + run_training(args) + elif args.phase in ("e", "evaluate"): + run_evaluation(args) + else: + raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.") + + +if __name__ == "__main__": + main() From 6c8f82458d0c502739edb46e2aeecb26a014e9b7 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 12 Mar 2023 00:15:34 +0100 Subject: [PATCH 11/21] Minor fixes --- .../probabilistic_domain_adaptation/livecell/unet_adamt.py | 2 +- torch_em/self_training/logger.py | 5 ++++- torch_em/self_training/pseudo_labeling.py | 2 +- torch_em/transform/raw.py | 6 ++++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py index e4215554..5ffe9532 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py @@ -85,7 +85,7 @@ def run_evaluation(args): def main(): - parser = get_parser(default_iterations=75000) + parser = get_parser(default_iterations=75000, default_batch_size=4) parser.add_argument("--confidence_threshold", default=0.9) args = parser.parse_args() if args.phase in ("c", "check"): diff --git a/torch_em/self_training/logger.py b/torch_em/self_training/logger.py index a05e6285..83c4babc 100644 --- a/torch_em/self_training/logger.py +++ b/torch_em/self_training/logger.py @@ -19,7 +19,10 @@ def __init__(self, trainer, save_root, **unused_kwargs): # TODO deal with 3d data def _add_supervised_images(self, step, name, x, y, pred): - grid = make_grid([x[0], y[0], pred[0]], padding=8) + grid = make_grid( + [torch_em.transform.raw.normalize(x[0]), y[0], pred[0]], + padding=8 + ) self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step) # TODO deal with 3d data diff --git 
a/torch_em/self_training/pseudo_labeling.py b/torch_em/self_training/pseudo_labeling.py index dd846e13..26f9d4ef 100644 --- a/torch_em/self_training/pseudo_labeling.py +++ b/torch_em/self_training/pseudo_labeling.py @@ -19,7 +19,7 @@ def __init__(self, activation=None, confidence_threshold=None, threshold_from_bo def _compute_label_mask_both_sides(self, pseudo_labels): upper_threshold = self.confidence_threshold lower_threshold = 1.0 - self.confidence_threshold - mask = (pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold) + mask = ((pseudo_labels >= upper_threshold) + (pseudo_labels <= lower_threshold)).to(dtype=torch.float32) return mask def _compute_label_mask_one_side(self, pseudo_labels): diff --git a/torch_em/transform/raw.py b/torch_em/transform/raw.py index cd43c379..cd89a10e 100644 --- a/torch_em/transform/raw.py +++ b/torch_em/transform/raw.py @@ -33,10 +33,12 @@ def cast(inpt, typestring): def standardize(raw, mean=None, std=None, axis=None, eps=1e-7): raw = cast(raw, "float32") - mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean + # mean = raw.mean(axis=axis, keepdims=True) if mean is None else mean + mean = raw.mean() if mean is None else mean raw -= mean - std = raw.std(axis=axis, keepdims=True) if std is None else std + # std = raw.std(axis=axis, keepdims=True) if std is None else std + std = raw.std() if std is None else std raw /= (std + eps) return raw From 020086aa87ca02ab1012b5342fef93ecbe364fd4 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 12 Mar 2023 20:59:01 +0100 Subject: [PATCH 12/21] Implement mean teacher training for livecell --- .../livecell/unet_mean_teacher.py | 109 ++++++++++++++++++ torch_em/self_training/pseudo_labeling.py | 1 + 2 files changed, 110 insertions(+) diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py index e69de29b..46b41ea9 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py @@ -0,0 +1,109 @@ +import os + +import torch +import torch_em.self_training as self_training +from torch_em.model import UNet2d + +from common import CELL_TYPES, get_parser, get_unsupervised_loader + + +def check_loader(args, n_images=5): + from torch_em.util.debug import check_loader + + cell_types = args.cell_types + print("The cell types", cell_types, "were selected.") + print("Checking the unsupervised loader for the first cell type", cell_types[0]) + + loader = get_unsupervised_loader( + args, "train", cell_types[0], + teacher_augmentation="weak", student_augmentation="weak", + ) + check_loader(loader, n_images) + + +def _load_model(model, checkpoint): + state = torch.load(os.path.join(checkpoint, "best.pt"))["model_state"] + model.load_state_dict(state) + return model + + +def _train_source_target(args, source_cell_type, target_cell_type): + model = UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid") + optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5) + + # self training functionality + thresh = args.confidence_threshold + pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=thresh) + loss = self_training.DefaultSelfTrainingLoss() + loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric() + + # data loaders +
unsupervised_train_loader = get_unsupervised_loader( + args, "train", target_cell_type, + teacher_augmentation="weak", student_augmentation="weak", + ) + unsupervised_val_loader = get_unsupervised_loader( + args, "val", target_cell_type, + teacher_augmentation="weak", student_augmentation="weak", + ) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + src_checkpoint = f"./checkpoints/unet_source/{source_cell_type}" + model = _load_model(model, src_checkpoint) + + name = f"unet_mean_teacher/thresh-{thresh}/{source_cell_type}/{target_cell_type}" + trainer = self_training.MeanTeacherTrainer( + name=name, + model=model, + optimizer=optimizer, + lr_scheduler=scheduler, + pseudo_labeler=pseudo_labeler, + unsupervised_loss=loss, + unsupervised_loss_and_metric=loss_and_metric, + unsupervised_train_loader=unsupervised_train_loader, + unsupervised_val_loader=unsupervised_val_loader, + supervised_loss=loss, + supervised_loss_and_metric=loss_and_metric, + logger=self_training.SelfTrainingTensorboardLogger, + mixed_precision=True, + device=device, + log_image_interval=100, + ) + trainer.fit(args.n_iterations) + + +def _train_source(args, cell_type): + for target_cell_type in CELL_TYPES: + if target_cell_type == cell_type: + continue + _train_source_target(args, cell_type, target_cell_type) + + +def run_training(args): + for cell_type in args.cell_types: + print("Start training for cell type:", cell_type) + _train_source(args, cell_type) + + +def run_evaluation(args): + pass + + +def main(): + parser = get_parser(default_iterations=25000, default_batch_size=8) + parser.add_argument("--confidence_threshold", default=0.9) + args = parser.parse_args() + if args.phase in ("c", "check"): + check_loader(args) + elif args.phase in ("t", "train"): + run_training(args) + elif args.phase in ("e", "evaluate"): + run_evaluation(args) + else: + raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.") + + +if __name__ == "__main__": + main() diff --git a/torch_em/self_training/pseudo_labeling.py b/torch_em/self_training/pseudo_labeling.py index 26f9d4ef..1d84f66d 100644 --- a/torch_em/self_training/pseudo_labeling.py +++ b/torch_em/self_training/pseudo_labeling.py @@ -1,3 +1,4 @@ +import torch class DefaultPseudoLabeler: From ef78bbc4f2b957c324466cf1ab951d61fe24a474 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 12 Mar 2023 21:47:02 +0100 Subject: [PATCH 13/21] Enable deserialization for mean teacher trainer --- test/self_training/test_mean_teacher.py | 23 ++++++++++++----------- torch_em/self_training/mean_teacher.py | 17 +++++++++++++---- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/test/self_training/test_mean_teacher.py b/test/self_training/test_mean_teacher.py index b3b406e6..361a9e61 100644 --- a/test/self_training/test_mean_teacher.py +++ b/test/self_training/test_mean_teacher.py @@ -58,7 +58,6 @@ def _test_mean_teacher( supervised_val_loader=supervised_val_loader, supervised_loss=supervised_loss, supervised_loss_and_metric=supervised_loss_and_metric, - logger=None, mixed_precision=False, device=torch.device("cpu"), ) @@ -66,16 +65,18 @@ def _test_mean_teacher( self.assertTrue(os.path.exists(f"./checkpoints/{name}/best.pt")) self.assertTrue(os.path.exists(f"./checkpoints/{name}/latest.pt")) - # TODO - # # make sure that the trainer can be deserialized from the checkpoint - # trainer2 = MeanTeacherTrainer.from_checkpoint(os.path.join("./checkpoints", name), name="latest") - # self.assertEqual(trainer.iteration, 
trainer2.iteration) - # self.assertTrue(torch_em.util.model_is_equal(trainer.model, trainer2.model)) - # self.assertTrue(torch_em.util.model_is_equal(trainer.teacher, trainer2.teacher)) + # make sure that the trainer can be deserialized from the checkpoint + trainer2 = self_training.MeanTeacherTrainer.from_checkpoint(os.path.join("./checkpoints", name), name="latest") + self.assertEqual(trainer.iteration, trainer2.iteration) + self.assertTrue(torch_em.util.model_is_equal(trainer.model, trainer2.model)) + self.assertTrue(torch_em.util.model_is_equal(trainer.teacher, trainer2.teacher)) + self.assertEqual(len(trainer.unsupervised_train_loader), len(trainer2.unsupervised_train_loader)) + if supervised_train_loader is not None: + self.assertEqual(len(trainer.supervised_train_loader), len(trainer2.supervised_train_loader)) - # # and that it can be trained further - # trainer2.fit(10) - # self.assertEqual(trainer2.iteration, 63) + # and that it can be trained further + trainer2.fit(10) + self.assertEqual(trainer2.iteration, 63) def get_unsupervised_loader(self, n_samples): augmentations = ( @@ -114,7 +115,7 @@ def test_mean_teacher_unsupervised(self): def test_mean_teacher_semisupervised(self): unsupervised_train_loader = self.get_unsupervised_loader(n_samples=50) - supervised_train_loader = self.get_supervised_loader(n_samples=50) + supervised_train_loader = self.get_supervised_loader(n_samples=51) supervised_val_loader = self.get_supervised_loader(n_samples=4) self._test_mean_teacher( unsupervised_train_loader=unsupervised_train_loader, diff --git a/torch_em/self_training/mean_teacher.py b/torch_em/self_training/mean_teacher.py index a6076912..4e52f5f8 100644 --- a/torch_em/self_training/mean_teacher.py +++ b/torch_em/self_training/mean_teacher.py @@ -4,6 +4,8 @@ import torch import torch_em +from .logger import SelfTrainingTensorboardLogger + class Dummy(torch.nn.Module): pass @@ -50,12 +52,14 @@ class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer): for setting the ratio between supervised and unsupervised training samples Parameters: + model [nn.Module] - unsupervised_train_loader [torch.DataLoader] - unsupervised_loss [callable] - supervised_train_loader [torch.DataLoader] - (default: None) supervised_loss [callable] - (default: None) unsupervised_loss_and_metric [callable] - (default: None) supervised_loss_and_metric [callable] - (default: None) + logger [TorchEmLogger] - (default: SelfTrainingTensorboardLogger) momentum [float] - (default: 0.999) reinit_teacher [bool] - (default: None) **kwargs - keyword arguments for torch_em.DataLoader @@ -63,6 +67,7 @@ class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer): def __init__( self, + model, unsupervised_train_loader, unsupervised_loss, pseudo_labeler, @@ -70,8 +75,9 @@ def __init__( unsupervised_val_loader=None, supervised_val_loader=None, supervised_loss=None, - supervised_loss_and_metric=None, unsupervised_loss_and_metric=None, + supervised_loss_and_metric=None, + logger=SelfTrainingTensorboardLogger, momentum=0.999, reinit_teacher=None, **kwargs @@ -111,8 +117,10 @@ def __init__( self.supervised_loss_and_metric = supervised_loss_and_metric self.unsupervised_loss_and_metric = unsupervised_loss_and_metric - # TODO we need to recover which loadder is which, and take care of the correct deserialization! 
- super().__init__(train_loader=train_loader, val_loader=val_loader, loss=Dummy(), metric=Dummy(), **kwargs) + super().__init__( + model=model, train_loader=train_loader, val_loader=val_loader, + loss=Dummy(), metric=Dummy(), logger=logger, **kwargs + ) self.unsupervised_loss = unsupervised_loss self.supervised_loss = supervised_loss @@ -120,7 +128,6 @@ def __init__( self.pseudo_labeler = pseudo_labeler self.momentum = momentum - self._kwargs = {"momentum": momentum, "reinit_teacher": reinit_teacher, **kwargs} # determine how we initialize the teacher weights (copy or reinitialization) if reinit_teacher is None: @@ -139,6 +146,8 @@ def __init__( for param in self.teacher.parameters(): param.requires_grad = False + self._kwargs = kwargs + def _momentum_update(self): # if we reinit the teacher we perform much faster updates (low momentum) in the first iterations # to avoid a large gap between teacher and student weights, leading to inconsistent predictions From d73b67dd64a8795b534a1083698ed44df8c2c2a5 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 12 Mar 2023 22:51:47 +0100 Subject: [PATCH 14/21] Simplify spoco trainer --- torch_em/trainer/spoco_trainer.py | 130 ++++++------------------------ 1 file changed, 26 insertions(+), 104 deletions(-) diff --git a/torch_em/trainer/spoco_trainer.py b/torch_em/trainer/spoco_trainer.py index 3646465f..1d46dd17 100644 --- a/torch_em/trainer/spoco_trainer.py +++ b/torch_em/trainer/spoco_trainer.py @@ -2,8 +2,8 @@ from copy import deepcopy import torch -import torch.cuda.amp as amp from .default_trainer import DefaultTrainer +from .tensorboard_logger import TensorboardLogger class SPOCOTrainer(DefaultTrainer): @@ -13,9 +13,10 @@ def __init__( momentum=0.999, semisupervised_loss=None, semisupervised_loader=None, + logger=TensorboardLogger, **kwargs, ): - super().__init__(model=model, **kwargs) + super().__init__(model=model, logger=logger, **kwargs) self.momentum = momentum # copy the model and don"t require gradients for it self.model2 = deepcopy(self.model) @@ -46,7 +47,7 @@ def _initialize(self, iterations, load_from_checkpoint): self.model2.to(self.device) return best_metric - def _train_epoch_semisupervised(self, progress): + def _train_epoch_semisupervised(self, progress, forward_context, backprop): self.model.train() self.model2.train() progress.set_description( @@ -57,75 +58,17 @@ def _train_epoch_semisupervised(self, progress): x = x.to(self.device) self.optimizer.zero_grad() - prediction = self.model(x) - with torch.no_grad(): - self._momentum_update() - prediction2 = self.model2(x) - loss = self.semisupervised_loss(prediction, prediction2) - loss.backward() - self.optimizer.step() - - def _train_epoch(self, progress): - self.model.train() - self.model2.train() - - n_iter = 0 - t_per_iter = time.time() - for x, y in self.train_loader: - x, y = x.to(self.device), y.to(self.device) - - self.optimizer.zero_grad() - - prediction = self.model(x) - with torch.no_grad(): - self._momentum_update() - prediction2 = self.model2(x) - if self._iteration % self.log_image_interval == 0: - prediction.retain_grad() - loss = self.loss((prediction, prediction2), y) - - loss.backward() - self.optimizer.step() - - lr = [pm["lr"] for pm in self.optimizer.param_groups][0] - if self.logger is not None: - self.logger.log_train(self._iteration, loss, lr, - x, y, prediction, - log_gradients=True) - - self._iteration += 1 - n_iter += 1 - if self._iteration >= self.max_iteration: - break - progress.update(1) - - if self.semisupervised_loader is not None: - 
self._train_epoch_semisupervised(progress) - t_per_iter = (time.time() - t_per_iter) / n_iter - return t_per_iter - - def _train_epoch_semisupervised_mixed(self, progress): - self.model.train() - self.model2.train() - progress.set_description( - f"Run semi-supervised training for {len(self.semisupervised_loader)} iterations", refresh=True - ) - - for x in self.semisupervised_loader: - x = x.to(self.device) - self.optimizer.zero_grad() - - with amp.autocast(): + with forward_context(): prediction = self.model(x) with torch.no_grad(): - self._momentum_update() prediction2 = self.model2(x) loss = self.semisupervised_loss(prediction, prediction2) - self.scaler.scale(loss).backward() - self.scaler.step(self.optimizer) - self.scaler.update() + backprop(loss) - def _train_epoch_mixed(self, progress): + with torch.no_grad(): + self._momentum_update() + + def _train_epoch_impl(self, progress, forward_context, backprop): self.model.train() self.model2.train() @@ -136,21 +79,25 @@ def _train_epoch_mixed(self, progress): self.optimizer.zero_grad() - with amp.autocast(): + with forward_context(): prediction = self.model(x) with torch.no_grad(): - self._momentum_update() prediction2 = self.model2(x) loss = self.loss((prediction, prediction2), y) - self.scaler.scale(loss).backward() - self.scaler.step(self.optimizer) - self.scaler.update() + if self._iteration % self.log_image_interval == 0: + prediction.retain_grad() + + backprop(loss) + + with torch.no_grad(): + self._momentum_update() lr = [pm["lr"] for pm in self.optimizer.param_groups][0] if self.logger is not None: self.logger.log_train(self._iteration, loss, lr, - x, y, prediction) + x, y, prediction, + log_gradients=True) self._iteration += 1 n_iter += 1 @@ -159,11 +106,11 @@ def _train_epoch_mixed(self, progress): progress.update(1) if self.semisupervised_loader is not None: - self._train_epoch_semisupervised_mixed(progress) + self._train_epoch_semisupervised(progress, forward_context, backprop) t_per_iter = (time.time() - t_per_iter) / n_iter return t_per_iter - def _validate(self): + def _validate_impl(self, forward_context): self.model.eval() self.model2.eval() @@ -173,39 +120,14 @@ def _validate(self): with torch.no_grad(): for x, y in self.val_loader: x, y = x.to(self.device), y.to(self.device) - prediction = self.model(x) - prediction2 = self.model2(x) + with forward_context(): + prediction = self.model(x) + prediction2 = self.model2(x) loss += self.loss((prediction, prediction2), y).item() metric += self.metric(prediction, y).item() metric /= len(self.val_loader) loss /= len(self.val_loader) if self.logger is not None: - self.logger.log_validation(self._iteration, metric, loss, - x, y, prediction) + self.logger.log_validation(self._iteration, metric, loss, x, y, prediction) return metric - - def _validate_mixed(self): - self.model.eval() - self.model2.eval() - - metric_val = 0.0 - loss_val = 0.0 - - with torch.no_grad(): - for x, y in self.val_loader: - x, y = x.to(self.device), y.to(self.device) - with amp.autocast(): - prediction = self.model(x) - prediction2 = self.model2(x) - loss = self.loss((prediction, prediction2), y) - metric = self.metric(prediction, y) - loss_val += loss - metric_val += metric - - metric_val /= len(self.val_loader) - loss_val /= len(self.val_loader) - if self.logger is not None: - self.logger.log_validation(self._iteration, metric, loss, - x, y, prediction) - return metric_val From 66796d54baceebc755a4371636ee48f6d23e7a2b Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 12 Mar 2023 22:52:37 +0100 
Subject: [PATCH 15/21] Fix issue in mean teacher trainer --- torch_em/self_training/mean_teacher.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_em/self_training/mean_teacher.py b/torch_em/self_training/mean_teacher.py index 4e52f5f8..2aa760d3 100644 --- a/torch_em/self_training/mean_teacher.py +++ b/torch_em/self_training/mean_teacher.py @@ -195,6 +195,8 @@ def _train_epoch_unsupervised(self, progress, forward_context, backprop): xu1, xu2 = xu1.to(self.device), xu2.to(self.device) teacher_input, model_input = xu1, xu2 + + self.optimizer.zero_grad() # Perform unsupervised training with forward_context(): # Compute the pseudo labels. From 1644d180cdc9edd2ae6b9c9919043b9a3796f1bb Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 13 Mar 2023 21:40:05 +0100 Subject: [PATCH 16/21] Implement unet source eval --- .../livecell/README.md | 2 + .../livecell/unet_source.py | 57 ++++++++++++++++++- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/experiments/probabilistic_domain_adaptation/livecell/README.md b/experiments/probabilistic_domain_adaptation/livecell/README.md index e69de29b..fc9e1774 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/README.md +++ b/experiments/probabilistic_domain_adaptation/livecell/README.md @@ -0,0 +1,2 @@ +TODO: double check the unet-src results with the results from Anwai. +Results are in `results/unet_source.csv`. diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_source.py b/experiments/probabilistic_domain_adaptation/livecell/unet_source.py index 63bf1312..e2cfe192 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_source.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_source.py @@ -1,6 +1,21 @@ +import os +from glob import glob + +import torch import torch_em +import numpy as np +import pandas as pd +try: + import imageio.v2 as imageio +except ImportError: + import imageio + +from elf.evaluation import dice_score from torch_em.model import UNet2d -from common import get_parser, get_supervised_loader +from torch_em.util.prediction import predict_with_padding +from tqdm import tqdm + +from common import CELL_TYPES, get_parser, get_supervised_loader def _train_cell_type(args, cell_type): @@ -38,9 +53,45 @@ def check_loader(args, n_images=5): check_loader(loader, n_images) -# TODO +def _eval_src(args, ct_src): + ckpt = f"checkpoints/unet_source/{ct_src}" + model = torch_em.util.get_trainer(ckpt).model + + image_folder = os.path.join(args.input, "images", "livecell_test_images") + label_root = os.path.join(args.input, "annotations", "livecell_test_images") + + results = {"src": [ct_src]} + device = torch.device("cuda") + + with torch.no_grad(): + for ct_trg in CELL_TYPES: + label_paths = glob(os.path.join(label_root, ct_trg, "*.tif")) + scores = [] + for label_path in tqdm(label_paths, desc=f"Predict for src={ct_src}, trgt={ct_trg}"): + image_path = os.path.join(image_folder, os.path.basename(label_path)) + assert os.path.exists(image_path) + image = imageio.imread(image_path) + image = torch_em.transform.raw.standardize(image) + pred = predict_with_padding(model, image, min_divisible=(16, 16), device=device).squeeze() + labels = imageio.imread(label_path) + assert image.shape == labels.shape + score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0) + scores.append(score) + results[ct_trg] = np.mean(scores) + return pd.DataFrame(results) + + def run_evaluation(args): - pass + results = [] + for ct in args.cell_types: + res = _eval_src(args, ct) + 
results.append(res) + results = pd.concat(results) + print("Evaluation results:") + print(results) + result_folder = "./results" + os.makedirs(result_folder, exist_ok=True) + results.to_csv(os.path.join(result_folder, "unet_source.csv"), index=False) def main(): From 39b00f9c4a86e36502bb99921c85e3afa2e5fcc1 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 14 Mar 2023 20:19:41 +0100 Subject: [PATCH 17/21] Update livecell domain adaptation training --- .../livecell/check_result.py | 19 +++ .../livecell/common.py | 145 +++++++++++++++++- .../livecell/unet_adamt.py | 30 ++-- .../livecell/unet_mean_teacher.py | 39 ++--- .../livecell/unet_source.py | 54 +------ 5 files changed, 212 insertions(+), 75 deletions(-) create mode 100644 experiments/probabilistic_domain_adaptation/livecell/check_result.py diff --git a/experiments/probabilistic_domain_adaptation/livecell/check_result.py b/experiments/probabilistic_domain_adaptation/livecell/check_result.py new file mode 100644 index 00000000..884bc5c1 --- /dev/null +++ b/experiments/probabilistic_domain_adaptation/livecell/check_result.py @@ -0,0 +1,19 @@ +import argparse +import pandas as pd + + +def check_result(path): + table = pd.read_csv(path) + print(table) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("paths", nargs="+") + args = parser.parse_args() + for path in args.paths: + check_result(path) + + +if __name__ == "__main__": + main() diff --git a/experiments/probabilistic_domain_adaptation/livecell/common.py b/experiments/probabilistic_domain_adaptation/livecell/common.py index 898ed643..cd7ad1c9 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/common.py +++ b/experiments/probabilistic_domain_adaptation/livecell/common.py @@ -1,10 +1,24 @@ import argparse - +import os +from glob import glob + +try: + import imageio.v2 as imageio +except ImportError: + import imageio +import numpy as np +import pandas as pd +import torch import torch_em + +from elf.evaluation import dice_score from torch_em.data.datasets.livecell import (get_livecell_loader, _download_livecell_images, _download_livecell_annotations) +from torch_em.model import UNet2d +from torch_em.util.prediction import predict_with_padding from torchvision import transforms +from tqdm import tqdm CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"] @@ -33,6 +47,134 @@ def strong_augmentations(): pass +# +# Model and prediction functionality: the models we use in all experiments +# + +def get_unet(): + return UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid", depth=4) + + +def load_model(model, ckpt, state="model_state", device=None): + state = torch.load(os.path.join(ckpt, "best.pt"))[state] + model.load_state_dict(state) + if device is not None: + model.to(device) + return model + + +# use get_model and prediction_function to customize this, e.g. 
for using it with the PUNet +# set model_state to "teacher_state" when using this with a mean-teacher method +def evaluate_transfered_model( + args, ct_src, method, get_model=get_unet, prediction_function=None, model_state="model_state" +): + image_folder = os.path.join(args.input, "images", "livecell_test_images") + label_root = os.path.join(args.input, "annotations", "livecell_test_images") + + results = {"src": [ct_src]} + device = torch.device("cuda") + + thresh = args.confidence_threshold + with torch.no_grad(): + for ct_trg in CELL_TYPES: + + if ct_trg == ct_src: + results[ct_trg] = None + continue + + out_folder = None if args.output is None else os.path.join( + args.output, f"thresh-{thresh}", ct_src, ct_trg + ) + if out_folder is not None: + os.makedirs(out_folder, exist_ok=True) + + ckpt = f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}" + model = get_model() + model = load_model(model, ckpt, device=device, state=model_state) + + label_paths = glob(os.path.join(label_root, ct_trg, "*.tif")) + scores = [] + for label_path in tqdm(label_paths, desc=f"Predict for src={ct_src}, trgt={ct_trg}"): + + labels = imageio.imread(label_path) + if out_folder is None: + out_path = None + else: + out_path = os.path.join(out_folder, os.path.basename(label_path)) + if os.path.exists(out_path): + pred = imageio.imread(out_path) + score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0) + scores.append(score) + continue + + image_path = os.path.join(image_folder, os.path.basename(label_path)) + assert os.path.exists(image_path) + image = imageio.imread(image_path) + image = torch_em.transform.raw.standardize(image) + pred = predict_with_padding( + model, image, min_divisible=(16, 16), device=device, prediction_function=prediction_function, + ).squeeze() + assert image.shape == labels.shape + score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0) + if out_path is not None: + imageio.imwrite(out_path, pred) + scores.append(score) + + results[ct_trg] = np.mean(scores) + return pd.DataFrame(results) + + +# use get_model and prediction_function to customize this, e.g. 
for using it with the PUNet +def evaluate_source_model(args, ct_src, method, get_model=get_unet, prediction_function=None): + ckpt = f"checkpoints/{method}/{ct_src}" + model = get_model() + model = torch_em.util.get_trainer(ckpt).model + + image_folder = os.path.join(args.input, "images", "livecell_test_images") + label_root = os.path.join(args.input, "annotations", "livecell_test_images") + + results = {"src": [ct_src]} + device = torch.device("cuda") + + with torch.no_grad(): + for ct_trg in CELL_TYPES: + + out_folder = None if args.output is None else os.path.join(args.output, ct_src, ct_trg) + if out_folder is not None: + os.makedirs(out_folder, exist_ok=True) + + label_paths = glob(os.path.join(label_root, ct_trg, "*.tif")) + scores = [] + for label_path in tqdm(label_paths, desc=f"Predict for src={ct_src}, trgt={ct_trg}"): + + labels = imageio.imread(label_path) + if out_folder is None: + out_path = None + else: + out_path = os.path.join(out_folder, os.path.basename(label_path)) + if os.path.exists(out_path): + pred = imageio.imread(out_path) + score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0) + scores.append(score) + continue + + image_path = os.path.join(image_folder, os.path.basename(label_path)) + assert os.path.exists(image_path) + image = imageio.imread(image_path) + image = torch_em.transform.raw.standardize(image) + pred = predict_with_padding( + model, image, min_divisible=(16, 16), device=device, prediction_function=prediction_function + ).squeeze() + assert image.shape == labels.shape + score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0) + if out_path is not None: + imageio.imwrite(out_path, pred) + scores.append(score) + + results[ct_trg] = np.mean(scores) + return pd.DataFrame(results) + + # # Other utility functions: loaders, parser # @@ -88,4 +230,5 @@ def get_parser(default_batch_size=8, default_iterations=int(1e5)): parser.add_argument("-n", "--n_iterations", default=default_iterations, type=int) parser.add_argument("-s", "--save_root") parser.add_argument("-c", "--cell_types", nargs="+", default=CELL_TYPES) + parser.add_argument("-o", "--output") return parser diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py index 5ffe9532..dd4a7749 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py @@ -1,8 +1,11 @@ +import os + +import pandas as pd import torch import torch_em.self_training as self_training from torch_em.model import UNet2d -from common import CELL_TYPES, get_parser, get_supervised_loader, get_unsupervised_loader +import common def check_loader(args, n_images=5): @@ -12,7 +15,7 @@ def check_loader(args, n_images=5): print("The cell types", cell_types, "were selected.") print("Checking the unsupervised loader for the first cell type", cell_types[0]) - loader = get_unsupervised_loader( + loader = common.get_unsupervised_loader( args, "train", cell_types[0], teacher_augmentation="weak", student_augmentation="weak", ) @@ -31,13 +34,13 @@ def _train_source_target(args, source_cell_type, target_cell_type): loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric() # data loaders - supervised_train_loader = get_supervised_loader(args, "train", source_cell_type) - supervised_val_loader = get_supervised_loader(args, "val", source_cell_type) - unsupervised_train_loader = get_unsupervised_loader( + supervised_train_loader = 
common.get_supervised_loader(args, "train", source_cell_type) + supervised_val_loader = common.get_supervised_loader(args, "val", source_cell_type) + unsupervised_train_loader = common.get_unsupervised_loader( args, "train", target_cell_type, teacher_augmentation="weak", student_augmentation="weak", ) - unsupervised_val_loader = get_unsupervised_loader( + unsupervised_val_loader = common.get_unsupervised_loader( args, "val", target_cell_type, teacher_augmentation="weak", student_augmentation="weak", ) @@ -68,7 +71,7 @@ def _train_source_target(args, source_cell_type, target_cell_type): def _train_source(args, cell_type): - for target_cell_type in CELL_TYPES: + for target_cell_type in common.CELL_TYPES: if target_cell_type == cell_type: continue _train_source_target(args, cell_type, target_cell_type) @@ -81,11 +84,20 @@ def run_training(args): def run_evaluation(args): - pass + results = [] + for ct in args.cell_types: + res = common.evaluate_transfered_model(args, ct, "unet_adamt", model_state="teacher_state") + results.append(res) + results = pd.concat(results) + print("Evaluation results:") + print(results) + result_folder = "./results" + os.makedirs(result_folder, exist_ok=True) + results.to_csv(os.path.join(result_folder, "unet_adamt.csv"), index=False) def main(): - parser = get_parser(default_iterations=75000, default_batch_size=4) + parser = common.get_parser(default_iterations=75000, default_batch_size=4) parser.add_argument("--confidence_threshold", default=0.9) args = parser.parse_args() if args.phase in ("c", "check"): diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py index 46b41ea9..4f6e9dc7 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py @@ -1,10 +1,10 @@ import os +import pandas as pd import torch import torch_em.self_training as self_training -from torch_em.model import UNet2d -from common import CELL_TYPES, get_parser, get_unsupervised_loader +import common def check_loader(args, n_images=5): @@ -14,21 +14,18 @@ def check_loader(args, n_images=5): print("The cell types", cell_types, "were selected.") print("Checking the unsupervised loader for the first cell type", cell_types[0]) - loader = get_unsupervised_loader( + loader = common.get_unsupervised_loader( args, "train", cell_types[0], teacher_augmentation="weak", student_augmentation="weak", ) check_loader(loader, n_images) -def _load_model(model, checkpoint): - state = torch.load(os.path.join(checkpoint, "best.pt"))["model_state"] - model.load_state_dict(state) - return model - - def _train_source_target(args, source_cell_type, target_cell_type): - model = UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid") + model = common.get_unet() + src_checkpoint = f"./checkpoints/unet_source/{source_cell_type}" + model = common.load_model(model, src_checkpoint) + optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5) @@ -39,20 +36,17 @@ def _train_source_target(args, source_cell_type, target_cell_type): loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric() # data loaders - unsupervised_train_loader = get_unsupervised_loader( + unsupervised_train_loader = common.get_unsupervised_loader( args, "train", target_cell_type, teacher_augmentation="weak", 
student_augmentation="weak", ) - unsupervised_val_loader = get_unsupervised_loader( + unsupervised_val_loader = common.get_unsupervised_loader( args, "val", target_cell_type, teacher_augmentation="weak", student_augmentation="weak", ) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - src_checkpoint = f"./checkpoints/unet_source/{source_cell_type}" - model = _load_model(model, src_checkpoint) - name = f"unet_mean_teacher/thresh-{thresh}/{source_cell_type}/{target_cell_type}" trainer = self_training.MeanTeacherTrainer( name=name, @@ -75,7 +69,7 @@ def _train_source_target(args, source_cell_type, target_cell_type): def _train_source(args, cell_type): - for target_cell_type in CELL_TYPES: + for target_cell_type in common.CELL_TYPES: if target_cell_type == cell_type: continue _train_source_target(args, cell_type, target_cell_type) @@ -88,11 +82,20 @@ def run_training(args): def run_evaluation(args): - pass + results = [] + for ct in args.cell_types: + res = common.evaluate_transfered_model(args, ct, "unet_mean_teacher", model_state="teacher_state") + results.append(res) + results = pd.concat(results) + print("Evaluation results:") + print(results) + result_folder = "./results" + os.makedirs(result_folder, exist_ok=True) + results.to_csv(os.path.join(result_folder, "unet_mean_teacher.csv"), index=False) def main(): - parser = get_parser(default_iterations=25000, default_batch_size=8) + parser = common.get_parser(default_iterations=25000, default_batch_size=8) parser.add_argument("--confidence_threshold", default=0.9) args = parser.parse_args() if args.phase in ("c", "check"): diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_source.py b/experiments/probabilistic_domain_adaptation/livecell/unet_source.py index e2cfe192..4347a8e5 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_source.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_source.py @@ -1,27 +1,15 @@ import os -from glob import glob -import torch import torch_em -import numpy as np import pandas as pd -try: - import imageio.v2 as imageio -except ImportError: - import imageio -from elf.evaluation import dice_score -from torch_em.model import UNet2d -from torch_em.util.prediction import predict_with_padding -from tqdm import tqdm - -from common import CELL_TYPES, get_parser, get_supervised_loader +import common def _train_cell_type(args, cell_type): - model = UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid") - train_loader = get_supervised_loader(args, "train", cell_type) - val_loader = get_supervised_loader(args, "val", cell_type) + model = common.get_unet() + train_loader = common.get_supervised_loader(args, "train", cell_type) + val_loader = common.get_supervised_loader(args, "val", cell_type) name = f"unet_source/{cell_type}" trainer = torch_em.default_segmentation_trainer( name=name, @@ -49,42 +37,14 @@ def check_loader(args, n_images=5): print("The cell types", cell_types, "were selected.") print("Checking the loader for the first cell type", cell_types[0]) - loader = get_supervised_loader(args) + loader = common.get_supervised_loader(args) check_loader(loader, n_images) -def _eval_src(args, ct_src): - ckpt = f"checkpoints/unet_source/{ct_src}" - model = torch_em.util.get_trainer(ckpt).model - - image_folder = os.path.join(args.input, "images", "livecell_test_images") - label_root = os.path.join(args.input, "annotations", "livecell_test_images") - - results = {"src": [ct_src]} - device = 
torch.device("cuda") - - with torch.no_grad(): - for ct_trg in CELL_TYPES: - label_paths = glob(os.path.join(label_root, ct_trg, "*.tif")) - scores = [] - for label_path in tqdm(label_paths, desc=f"Predict for src={ct_src}, trgt={ct_trg}"): - image_path = os.path.join(image_folder, os.path.basename(label_path)) - assert os.path.exists(image_path) - image = imageio.imread(image_path) - image = torch_em.transform.raw.standardize(image) - pred = predict_with_padding(model, image, min_divisible=(16, 16), device=device).squeeze() - labels = imageio.imread(label_path) - assert image.shape == labels.shape - score = dice_score(pred, labels, threshold_seg=None, threshold_gt=0) - scores.append(score) - results[ct_trg] = np.mean(scores) - return pd.DataFrame(results) - - def run_evaluation(args): results = [] for ct in args.cell_types: - res = _eval_src(args, ct) + res = common.evaluate_source_model(args, ct, "unet_source") results.append(res) results = pd.concat(results) print("Evaluation results:") @@ -95,7 +55,7 @@ def run_evaluation(args): def main(): - parser = get_parser(default_iterations=50000) + parser = common.get_parser(default_iterations=50000) args = parser.parse_args() if args.phase in ("c", "check"): check_loader(args) From a7057c4af19a4e14215a0ef867fef750dc041854 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 14 Mar 2023 21:25:55 +0100 Subject: [PATCH 18/21] Enable self-training for 3d data --- .../livecell/README.md | 6 +++-- .../livecell/unet_adamt.py | 4 ++-- .../livecell/unet_mean_teacher.py | 1 + torch_em/data/raw_dataset.py | 3 +++ torch_em/self_training/logger.py | 23 ++++++++++++++----- 5 files changed, 27 insertions(+), 10 deletions(-) diff --git a/experiments/probabilistic_domain_adaptation/livecell/README.md b/experiments/probabilistic_domain_adaptation/livecell/README.md index fc9e1774..38424eb1 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/README.md +++ b/experiments/probabilistic_domain_adaptation/livecell/README.md @@ -1,2 +1,4 @@ -TODO: double check the unet-src results with the results from Anwai. -Results are in `results/unet_source.csv`. +# Prelim. results + +-> UNet results are a bit worse than from Anwai, double check how the training differs. +-> Mean Teacher improves results. 
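The gain reported for the mean teacher above comes from the exponential moving average with which the trainer's _momentum_update pulls the teacher weights towards the student after every optimizer step. A minimal sketch of that update, assuming two architecturally identical torch.nn.Module instances; the standalone function form and names are illustrative, not the trainer's exact code:

    import torch

    @torch.no_grad()
    def momentum_update(teacher, student, momentum=0.999):
        # teacher <- momentum * teacher + (1 - momentum) * student, per parameter
        for param_t, param_s in zip(teacher.parameters(), student.parameters()):
            param_t.mul_(momentum).add_(param_s, alpha=1.0 - momentum)

With the default momentum of 0.999 the teacher is a smoothed, lagging copy of the student, which makes its predictions more stable targets for pseudo-labeling.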
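The pseudo-labels themselves are filtered with the two-sided confidence mask from PATCH 11: predictions close to 0 or 1 count as confident, everything in between is masked out. A self-contained illustration of that masking logic, with made-up values:

    import torch

    pseudo_labels = torch.tensor([0.02, 0.45, 0.97])
    threshold = 0.9  # the experiments' default confidence_threshold
    # confident = close to 0 or close to 1; bool-tensor addition acts as logical or
    mask = ((pseudo_labels >= threshold) + (pseudo_labels <= 1.0 - threshold)).to(dtype=torch.float32)
    print(mask)  # tensor([1., 0., 1.]) -> the uncertain 0.45 is excluded from the loss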
diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py index dd4a7749..f3587d5c 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py @@ -3,7 +3,6 @@ import pandas as pd import torch import torch_em.self_training as self_training -from torch_em.model import UNet2d import common @@ -23,7 +22,7 @@ def check_loader(args, n_images=5): def _train_source_target(args, source_cell_type, target_cell_type): - model = UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid") + model = common.get_model() optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5) @@ -66,6 +65,7 @@ def _train_source_target(args, source_cell_type, target_cell_type): mixed_precision=True, device=device, log_image_interval=100, + save_root=args.save_root, ) trainer.fit(args.n_iterations) diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py index 4f6e9dc7..79bc309b 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py @@ -64,6 +64,7 @@ def _train_source_target(args, source_cell_type, target_cell_type): mixed_precision=True, device=device, log_image_interval=100, + save_root=args.save_root, ) trainer.fit(args.n_iterations) diff --git a/torch_em/data/raw_dataset.py b/torch_em/data/raw_dataset.py index ae1ab034..d54fd083 100644 --- a/torch_em/data/raw_dataset.py +++ b/torch_em/data/raw_dataset.py @@ -129,6 +129,9 @@ def __getitem__(self, index): if self.transform is not None: raw = self.transform(raw) + if isinstance(raw, list): + assert len(raw) == 1 + raw = raw[0] if self.trafo_halo is not None: raw = self.crop(raw) diff --git a/torch_em/self_training/logger.py b/torch_em/self_training/logger.py index 83c4babc..3ecbfa9d 100644 --- a/torch_em/self_training/logger.py +++ b/torch_em/self_training/logger.py @@ -17,25 +17,36 @@ def __init__(self, trainer, save_root, **unused_kwargs): self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir) self.log_image_interval = trainer.log_image_interval - # TODO deal with 3d data def _add_supervised_images(self, step, name, x, y, pred): + if x.ndim == 5: + assert y.ndim == pred.ndim == 5 + zindex = x.shape[2] // 2 + x, y, pred = x[:, :, zindex], y[:, :, zindex], pred[:, :, zindex] + grid = make_grid( - [torch_em.transform.raw.normalize(x[0]), y[0], pred[0]], + [torch_em.transform.raw.normalize(x[0]), y[0, 0:1], pred[0, 0:1]], padding=8 ) self.tb.add_image(tag=f"{name}/supervised/input-labels-prediction", img_tensor=grid, global_step=step) - # TODO deal with 3d data def _add_unsupervised_images(self, step, name, x1, x2, pred, pseudo_labels, label_filter): - # from torch_em.transform.raw import _normalize_torch + if x1.ndim == 5: + assert x2.ndim == pred.ndim == pseudo_labels.ndim == 5 + zindex = x1.shape[2] // 2 + x1, x2, pred = x1[:, :, zindex], x2[:, :, zindex], pred[:, :, zindex] + pseudo_labels = pseudo_labels[:, :, zindex] + if label_filter is not None: + assert label_filter.ndim == 5 + label_filter = label_filter[:, :, zindex] + images = [ torch_em.transform.raw.normalize(x1[0]), torch_em.transform.raw.normalize(x2[0]), - pred[0], 
pseudo_labels[0], + pred[0, 0:1], pseudo_labels[0, 0:1], ] im_name = f"{name}/unsupervised/aug1-aug2-prediction-pseudolabels" if label_filter is not None: - images.append(label_filter[0]) + images.append(label_filter[0, 0:1]) name += "-labelfilter" grid = make_grid(images, nrow=2, padding=8) self.tb.add_image(tag=im_name, img_tensor=grid, global_step=step) From d41122dce1969a0005745b345b45d11289fd8af2 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 14 Mar 2023 21:41:34 +0100 Subject: [PATCH 19/21] Enable pytorch 2 --- torch_em/util/modelzoo.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_em/util/modelzoo.py b/torch_em/util/modelzoo.py index 2f2a1693..5c0273f3 100644 --- a/torch_em/util/modelzoo.py +++ b/torch_em/util/modelzoo.py @@ -118,7 +118,9 @@ def _write_depedencies(export_folder, dependencies): if dependencies is None: ver = torch.__version__ major, minor = list(map(int, ver.split(".")[:2])) - assert major == 1 + assert major in (1, 2) + if major == 2: + warn("Modelzoo functionality is not fully tested for PyTorch 2") # the torch zip layout changed for a few versions: torch_min_version = "1.0" if minor > 6 and minor < 10: @@ -129,7 +131,7 @@ def _write_depedencies(export_folder, dependencies): dependencies = { "channels": ["pytorch", "conda-forge"], "name": "torch-em-deploy", - "dependencies": [f"pytorch>={torch_min_version},<2.0"] + "dependencies": [f"pytorch>={torch_min_version}"] } with open(dep_path, "w") as f: yaml.dump(dependencies, f) From 9e6208e3b69620a7d1651b498dd6fc44036fc1b4 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 15 Mar 2023 09:17:01 +0100 Subject: [PATCH 20/21] Bump python versions in CI and use pytorch 2 in the env files --- .github/workflows/test.yaml | 2 +- environment_cpu.yaml | 2 +- environment_gpu.yaml | 3 ++- test/util/test_modelzoo.py | 18 +++++++++++++++--- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7a1c8b2e..abe6c349 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -11,7 +11,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: [3.8, 3.9] + python-version: [3.9, 3.10] steps: - name: Checkout diff --git a/environment_cpu.yaml b/environment_cpu.yaml index 061986bd..e93d7320 100644 --- a/environment_cpu.yaml +++ b/environment_cpu.yaml @@ -8,7 +8,7 @@ dependencies: - bioimageio.core >=0.5.0 - cpuonly - python-elf - - pytorch + - pytorch >=2.0 - tensorboard - tifffile - torchvision diff --git a/environment_gpu.yaml b/environment_gpu.yaml index 4aa5c5ef..6b2fdf6d 100644 --- a/environment_gpu.yaml +++ b/environment_gpu.yaml @@ -8,7 +8,8 @@ dependencies: - affogato - bioimageio.core >=0.5.0 - python-elf - - pytorch-cuda=11.6 # you may need to update the pytorch version to match your system + - pytorch >=2.0 + - pytorch-cuda>=11.7 # you may need to update the pytorch version to match your system - tensorboard - tifffile - torchvision diff --git a/test/util/test_modelzoo.py b/test/util/test_modelzoo.py index 78ca700e..38f2a826 100644 --- a/test/util/test_modelzoo.py +++ b/test/util/test_modelzoo.py @@ -11,6 +11,11 @@ from torch_em.model import UNet2d from torch_em.trainer import DefaultTrainer +try: + import onnx +except ImportError: + onnx = None + class ExpandChannels: def __init__(self, n_channels): @@ -89,14 +94,21 @@ def test_export_single_channel(self): def test_export_multi_channel(self): self._test_export(4) - def test_add_weights(self): + def 
test_add_weights_torchscript(self): from torch_em.util.modelzoo import add_weight_formats self._test_export(1) - additional_formats = ["onnx", "torchscript"] + additional_formats = ["torchscript"] add_weight_formats(self.save_folder, additional_formats) - self.assertTrue(os.path.join(self.save_folder, "weigths.onnx")) self.assertTrue(os.path.join(self.save_folder, "weigths-torchscript.pt")) + @unittest.skipIf(onnx is None, "Needs onnx") + def test_add_weights_onnx(self): + from torch_em.util.modelzoo import add_weight_formats + self._test_export(1) + additional_formats = ["onnx"] + add_weight_formats(self.save_folder, additional_formats) + self.assertTrue(os.path.join(self.save_folder, "weigths.onnx")) + if __name__ == "__main__": unittest.main() From a785af9bf436b7bb838c48980442d8a4bfb9b537 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 15 Mar 2023 09:21:21 +0100 Subject: [PATCH 21/21] Fix python version names in CI --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index abe6c349..355f7bc5 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -11,7 +11,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: [3.9, 3.10] + python-version: ["3.9", "3.10"] steps: - name: Checkout
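The quoting added in this last patch matters because unquoted YAML scalars are parsed as numbers: 3.10 is read as the float 3.1, so the CI matrix from PATCH 20 would have requested Python 3.1. A quick demonstration, assuming PyYAML is installed:

    import yaml

    # without quotes the version is parsed as a float and loses its trailing zero
    print(yaml.safe_load("python-version: [3.9, 3.10]"))      # {'python-version': [3.9, 3.1]}
    print(yaml.safe_load('python-version: ["3.9", "3.10"]'))  # {'python-version': ['3.9', '3.10']}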