diff --git a/test/self_training/test_fix_match.py b/test/self_training/test_fix_match.py
index 6d8a0b12..d375db9e 100644
--- a/test/self_training/test_fix_match.py
+++ b/test/self_training/test_fix_match.py
@@ -35,7 +35,7 @@ def _test_fix_match(
     unsupervised_loss_and_metric=None,
 ):
     model = UNet2d(in_channels=1, out_channels=1, initial_features=8, depth=3)
-    optimizer = torch.optim.Adam(model.parameters())
+    optimizer = torch.optim.AdamW(model.parameters())
 
     name = "fm-test"
     trainer = self_training.FixMatchTrainer(
diff --git a/test/self_training/test_mean_teacher.py b/test/self_training/test_mean_teacher.py
index 110e5d9e..ee907f48 100644
--- a/test/self_training/test_mean_teacher.py
+++ b/test/self_training/test_mean_teacher.py
@@ -35,7 +35,7 @@ def _test_mean_teacher(
     unsupervised_loss_and_metric=None,
 ):
     model = UNet2d(in_channels=1, out_channels=1, initial_features=8, depth=3)
-    optimizer = torch.optim.Adam(model.parameters())
+    optimizer = torch.optim.AdamW(model.parameters())
 
     name = "mt-test"
     trainer = self_training.MeanTeacherTrainer(
diff --git a/test/trainer/test_default_trainer.py b/test/trainer/test_default_trainer.py
index 65251690..d238978a 100644
--- a/test/trainer/test_default_trainer.py
+++ b/test/trainer/test_default_trainer.py
@@ -47,7 +47,7 @@ def _get_kwargs(self, with_roi=False, compile_model=False):
             "model": model,
             "loss": torch_em.loss.DiceLoss(),
             "metric": torch_em.loss.DiceLoss(),
-            "optimizer": torch.optim.Adam(model.parameters(), lr=1e-5),
+            "optimizer": torch.optim.AdamW(model.parameters(), lr=1e-5),
             "device": torch.device("cpu"),
             "mixed_precision": False,
             "compile_model": compile_model,
diff --git a/test/trainer/test_spoco_trainer.py b/test/trainer/test_spoco_trainer.py
index 195f841f..94c0811d 100644
--- a/test/trainer/test_spoco_trainer.py
+++ b/test/trainer/test_spoco_trainer.py
@@ -60,7 +60,7 @@ def _get_kwargs(self, with_roi=False):
             "model": model,
             "loss": DummySpocoLoss(),
             "metric": DummySpocoMetric(),
-            "optimizer": torch.optim.Adam(model.parameters(), lr=1e-5),
+            "optimizer": torch.optim.AdamW(model.parameters(), lr=1e-5),
             "device": torch.device("cpu"),
             "mixed_precision": False,
             "momentum": 0.95,
diff --git a/test/util/test_modelzoo.py b/test/util/test_modelzoo.py
index fa0d5f14..e560c094 100644
--- a/test/util/test_modelzoo.py
+++ b/test/util/test_modelzoo.py
@@ -63,7 +63,7 @@ def _create_checkpoint(self, n_channels):
         )
 
         model = UNet2d(in_channels=1, out_channels=n_channels, depth=2, initial_features=4, norm="BatchNorm")
-        optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
 
         trainer = DefaultTrainer(
             name=self.name, train_loader=loader, val_loader=loader, model=model, loss=DiceLoss(), metric=DiceLoss(),
diff --git a/torch_em/segmentation.py b/torch_em/segmentation.py
index 9ea62efc..b3bf8eb4 100644
--- a/torch_em/segmentation.py
+++ b/torch_em/segmentation.py
@@ -338,7 +338,7 @@ def default_segmentation_trainer(
     save_root=None,
     compile_model=None,
 ):
-    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, **optimizer_kwargs)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)
 
     loss = DiceLoss() if loss is None else loss
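
Note on the Adam to AdamW swap, with a minimal sketch below (the toy model and data are made up for illustration; the optimizer arguments are the real torch.optim ones). Two behavioral differences apply to every call site touched by this patch: AdamW decouples weight decay from the adaptive gradient update (Loshchilov & Hutter, "Decoupled Weight Decay Regularization"), and its weight_decay default is 1e-2 rather than Adam's 0, so these call sites, none of which pass weight_decay explicitly, now apply weight decay by default.

import torch

model = torch.nn.Linear(4, 1)

# Adam folds weight_decay into the gradient as an L2 penalty, so the decay
# is rescaled by the adaptive per-parameter step sizes. Its default is 0,
# meaning the call sites in this patch previously applied no decay at all.
adam = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0)

# AdamW applies weight decay directly to the parameters, decoupled from the
# moment estimates. Its default weight_decay is 1e-2, so the patched call
# sites now regularize even without an explicit weight_decay argument.
adamw = torch.optim.AdamW(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 4), torch.randn(8, 1)
torch.nn.functional.mse_loss(model(x), y).backward()
adamw.step()  # shrinks each parameter by lr * weight_decay * param, then applies the Adam step
adamw.zero_grad()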