diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8f02beab0f..29cff8fc46 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- WinCLIP: set device in text embedding collection and apply forward pass with no grad, by @djdameln in https://github.com/openvinotoolkit/anomalib/pull/1984
 - 🔨 Move all export functionalities to AnomalyModule as base methods by @thinhngo-x in ()
 - Remove unnecessary jsonargparse dependencies by @davnn in ()
 - Use default model-specific eval transform when only train_transform specified by @djdameln(https://github.com/djdameln) in ()
diff --git a/src/anomalib/models/image/winclip/torch_model.py b/src/anomalib/models/image/winclip/torch_model.py
index e8dbc10b51..5c69853db6 100644
--- a/src/anomalib/models/image/winclip/torch_model.py
+++ b/src/anomalib/models/image/winclip/torch_model.py
@@ -222,6 +222,7 @@ def _get_window_embeddings(self, feature_map: torch.Tensor, masks: torch.Tensor)
 
         return pooled.reshape((n_masks, batch_size, -1)).permute(1, 0, 2)
 
+    @torch.no_grad
     def forward(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         """Forward-pass through the model to obtain image and pixel scores.
 
@@ -314,6 +315,7 @@ def _compute_few_shot_scores(
 
         return torch.stack(multi_scale_scores).mean(dim=0)
 
+    @torch.no_grad
     def _collect_text_embeddings(self, class_name: str) -> None:
         """Collect text embeddings for the object class using a compositional prompt ensemble.
 
@@ -325,15 +327,16 @@ def _collect_text_embeddings(self, class_name: str) -> None:
         Args:
             class_name (str): The name of the object class used in the prompt ensemble.
         """
+        # get the device, this is to ensure that we move the text embeddings to the same device as the model
+        device = next(self.parameters()).device
         # collect prompt ensemble
         normal_prompts, anomalous_prompts = create_prompt_ensemble(class_name)
         # tokenize prompts
         normal_tokens = tokenize(normal_prompts)
         anomalous_tokens = tokenize(anomalous_prompts)
         # encode tokens to obtain prompt embeddings
-        with torch.no_grad():
-            normal_embeddings = self.clip.encode_text(normal_tokens)
-            anomalous_embeddings = self.clip.encode_text(anomalous_tokens)
+        normal_embeddings = self.clip.encode_text(normal_tokens.to(device))
+        anomalous_embeddings = self.clip.encode_text(anomalous_tokens.to(device))
         # average prompt embeddings
         normal_embeddings = torch.mean(normal_embeddings, dim=0, keepdim=True)
         anomalous_embeddings = torch.mean(anomalous_embeddings, dim=0, keepdim=True)
@@ -341,14 +344,14 @@ def _collect_text_embeddings(self, class_name: str) -> None:
         text_embeddings = torch.cat((normal_embeddings, anomalous_embeddings))
         self._text_embeddings = text_embeddings
 
+    @torch.no_grad
     def _collect_visual_embeddings(self, images: torch.Tensor) -> None:
         """Collect visual embeddings based on a set of normal reference images.
 
         Args:
             images (torch.Tensor): Tensor of shape ``(K, C, H, W)`` containing the reference images.
         """
-        with torch.no_grad():
-            _, self._visual_embeddings, self._patch_embeddings = self.encode_image(images)
+        _, self._visual_embeddings, self._patch_embeddings = self.encode_image(images)
 
     def _generate_masks(self) -> list[torch.Tensor]:
         """Prepare a set of masks that operate as multi-scale sliding windows.
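
Not part of the patch itself: a minimal, self-contained sketch of the two patterns the diff applies, namely using `torch.no_grad` directly as a method decorator (supported in recent PyTorch versions) and moving CPU-created inputs to the device the model's parameters live on before the forward pass. The `TextEncoderExample` class and its `encode_text` method below are hypothetical stand-ins for illustration, not anomalib or CLIP APIs.

import torch
from torch import nn


class TextEncoderExample(nn.Module):
    """Toy stand-in for a CLIP-style text encoder (hypothetical, for illustration only)."""

    def __init__(self, vocab_size: int = 100, embed_dim: int = 8) -> None:
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim)

    @torch.no_grad  # decorator form: no gradients are tracked inside this method (recent PyTorch)
    def encode_text(self, tokens: torch.Tensor) -> torch.Tensor:
        # look up the device of the module's parameters, mirroring the PR's
        # `next(self.parameters()).device` pattern, and move the CPU tokens there
        device = next(self.parameters()).device
        return self.embedding(tokens.to(device))


if __name__ == "__main__":
    model = TextEncoderExample()
    if torch.cuda.is_available():
        model = model.cuda()
    tokens = torch.randint(0, 100, (2, 5))  # tokenized prompts, created on CPU
    embeddings = model.encode_text(tokens)
    print(embeddings.requires_grad)  # False: the forward pass ran under no_grad

The decorator removes the need for a `with torch.no_grad():` block inside each method, and the explicit `.to(device)` keeps the text branch working when the model has been moved to GPU while the tokenizer output remains on CPU.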