Skip to content

Commit

Permalink
Add extra checks to TorchInferencer model and metadata loading (#1350)
Browse files Browse the repository at this point in the history
* Fix metadata path

* Ignore hidden directories in folder dataset

* Add check for mask_dir for segmentation tasks in Folder dataset

* Add extra checks to ensure that torch model has metadata and model information
  • Loading branch information
samet-akcay committed Sep 22, 2023
1 parent 4eaa9be commit 5eb20c6
Showing 1 changed file with 44 additions and 5 deletions.
49 changes: 44 additions & 5 deletions src/anomalib/deploy/inferencers/torch_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __init__(
) -> None:
self.device = self._get_device(device)

# Load the model weights.
# Load the model weights, metadata and data transforms.
self.checkpoint = self._load_checkpoint(path)
self.model = self.load_model(path)
self.metadata = self._load_metadata(path)
self.transform = A.from_dict(self.metadata["transform"])
Expand All @@ -60,6 +61,24 @@ def _get_device(device: str) -> torch.device:
device = "cuda"
return torch.device(device)

def _load_checkpoint(self, path: str | Path) -> dict:
"""Load the checkpoint.
Args:
path (str | Path): Path to the torch ckpt file.
Returns:
dict: Dictionary containing the model and metadata.
"""
if isinstance(path, str):
path = Path(path)

if path.suffix not in (".pt", ".pth"):
raise ValueError(f"Unknown torch checkpoint file format {path.suffix}. Make sure you save the Torch model.")

checkpoint = torch.load(path, map_location=self.device)
return checkpoint

def _load_metadata(self, path: str | Path | dict | None = None) -> dict | DictConfig:
"""Load metadata from file.
Expand All @@ -69,20 +88,40 @@ def _load_metadata(self, path: str | Path | dict | None = None) -> dict | DictCo
Returns:
dict: Dictionary containing the metadata.
"""
metadata = torch.load(path, map_location=self.device)["metadata"] if path else {}
metadata: dict | DictConfig

if isinstance(path, dict):
metadata = path
elif isinstance(path, (str, Path)):
checkpoint = self._load_checkpoint(path)

# Torch model should ideally contain the metadata in the checkpoint.
# Check if the metadata is present in the checkpoint.
if "metadata" not in checkpoint.keys():
raise KeyError(
"``metadata`` is not found in the checkpoint. Please ensure that you save the model as Torch model."
)
metadata = checkpoint["metadata"]
else:
raise ValueError(f"Unknown ``path`` type {type(path)}")

return metadata

def load_model(self, path: str | Path) -> nn.Module:
"""Load the PyTorch model.
Args:
path (str | Path): Path to model ckpt file.
path (str | Path): Path to the Torch model.
Returns:
(AnomalyModule): PyTorch Lightning model.
(nn.Module): Torch model.
"""

model = torch.load(path, map_location=self.device)["model"]
checkpoint = self._load_checkpoint(path)
if "model" not in checkpoint.keys():
raise KeyError("``model`` is not found in the checkpoint. Please check the checkpoint file.")

model = checkpoint["model"]
model.eval()
return model.to(self.device)

Expand Down

0 comments on commit 5eb20c6

Please sign in to comment.