diff --git a/test/test_classification.py b/test/test_classification.py index 1cf7df50..dde749e8 100644 --- a/test/test_classification.py +++ b/test/test_classification.py @@ -51,7 +51,8 @@ def test_classification_2d(self): trainer.fit(n_iterations) self._check_checkpoint( - "./checkpoints/test-model-2d/latest.pt", 18, trainer.model, resnet18, num_classes=n_classes + "./checkpoints/test-model-2d/latest.pt", 18, trainer.model, resnet18, num_classes=n_classes, + compile_model=False, ) def test_classification_3d(self): @@ -72,6 +73,7 @@ def test_classification_3d(self): model = resnet3d_18(in_channels=1, out_channels=n_classes) trainer = default_classification_trainer( name="test-model-3d", model=model, train_loader=loader, val_loader=loader, + compile_model=False, ) trainer.fit(12)