From 023cc8f51a90a4c20329bdc67a7c3a3310db5936 Mon Sep 17 00:00:00 2001
From: Eunwoo Shin
Date: Thu, 11 Apr 2024 20:36:15 +0900
Subject: [PATCH] Fix a bug that engine.test doesn't work with XPU (#3293)

* fix bug

* align with pre-commit

---------

Co-authored-by: Emily
---
 src/otx/algo/strategies/xpu_single.py | 12 ++++++++----
 src/otx/engine/engine.py              |  2 +-
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/otx/algo/strategies/xpu_single.py b/src/otx/algo/strategies/xpu_single.py
index aa1ecf3d559..4b9501dd36f 100644
--- a/src/otx/algo/strategies/xpu_single.py
+++ b/src/otx/algo/strategies/xpu_single.py
@@ -53,13 +53,17 @@ def is_distributed(self) -> bool:
     def setup_optimizers(self, trainer: pl.Trainer) -> None:
         """Sets up optimizers."""
         super().setup_optimizers(trainer)
-        if len(self.optimizers) != 1:  # type: ignore[has-type]
+        if len(self.optimizers) > 1:  # type: ignore[has-type]
             msg = "XPU strategy doesn't support multiple optimizers"
             raise RuntimeError(msg)
         if trainer.task != "SEMANTIC_SEGMENTATION":
-            model, optimizer = torch.xpu.optimize(trainer.model, optimizer=self.optimizers[0])  # type: ignore[has-type]
-            self.optimizers = [optimizer]
-            self.model = model
+            if len(self.optimizers) == 1:  # type: ignore[has-type]
+                model, optimizer = torch.xpu.optimize(trainer.model, optimizer=self.optimizers[0])  # type: ignore[has-type]
+                self.optimizers = [optimizer]
+                self.model = model
+            else:  # for inference
+                trainer.model.eval()
+                self.model = torch.xpu.optimize(trainer.model)


 StrategyRegistry.register(
diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py
index 6f6240cad4e..cf220b23d8b 100644
--- a/src/otx/engine/engine.py
+++ b/src/otx/engine/engine.py
@@ -886,7 +886,7 @@ def _build_trainer(self, **kwargs) -> None:
         if self._device.accelerator == DeviceType.xpu:
             self._cache.update(strategy="xpu_single")
             # add plugin for Automatic Mixed Precision on XPU
-            if self._cache.args["precision"] == 16:
+            if self._cache.args.get("precision", 32) == 16:
                 self._cache.update(plugins=[MixedPrecisionXPUPlugin()])
                 self._cache.args["precision"] = None

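For context, the first hunk can be read as a small dispatch on the optimizer count: exactly one optimizer means training, an empty list means inference (the engine.test path), and more than one is rejected. Below is a minimal, runnable sketch of that branching; FakeXPU, FakeModel, FakeTrainer, and setup_optimizers_sketch are hypothetical stubs (not OTX or torch.xpu APIs), with FakeXPU.optimize mimicking the two call shapes the patch uses.

class FakeXPU:
    """Stub for torch.xpu: optimize() returns (model, optimizer) when an
    optimizer is passed, and just the model otherwise."""

    @staticmethod
    def optimize(model, optimizer=None):
        return (model, optimizer) if optimizer is not None else model


def setup_optimizers_sketch(trainer, optimizers):
    # More than one optimizer is still unsupported on XPU.
    if len(optimizers) > 1:
        msg = "XPU strategy doesn't support multiple optimizers"
        raise RuntimeError(msg)
    if trainer.task != "SEMANTIC_SEGMENTATION":
        if len(optimizers) == 1:
            # Training path: optimize the model and optimizer together.
            model, optimizer = FakeXPU.optimize(trainer.model, optimizer=optimizers[0])
            return model, [optimizer]
        # Inference path: no optimizer exists, so only the model is
        # optimized. The old `len(...) != 1` check raised here.
        trainer.model.eval()
        return FakeXPU.optimize(trainer.model), []
    return trainer.model, optimizers


class FakeModel:
    def eval(self):
        return self


class FakeTrainer:
    task = "CLASSIFICATION"  # arbitrary task other than SEMANTIC_SEGMENTATION
    model = FakeModel()


# engine.test() reaches setup_optimizers with no optimizers configured;
# with the patch this takes the inference branch instead of raising.
model, optimizers = setup_optimizers_sketch(FakeTrainer(), [])
assert optimizers == []

The second hunk fixes the same test path in a different place: on engine.test() the cached trainer kwargs may never receive a "precision" entry, so indexing self._cache.args["precision"] could raise KeyError before the trainer was built. Using dict.get with a full-precision default keeps the AMP plugin opt-in, e.g.:

args = {}  # hypothetical cache args on the test path
assert args.get("precision", 32) != 16  # no KeyError; AMP plugin skipped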