WinCLIP improvements (#1985)
* text embedding device handling and forward with no grad

* update changelog

* use no_grad decorator

* use decorator for no_grad

---------

Co-authored-by: Samet Akcay <samet.akcay@intel.com>
djdameln and samet-akcay committed May 16, 2024
1 parent 3472a12 commit 28e023e
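In short, the commit trades `with torch.no_grad():` blocks for the `torch.no_grad` decorator and moves tokenized prompts onto the model's device. The snippet below is an illustrative sketch rather than code from the repository; `TinyModel` is a made-up module used only to contrast the context-manager and decorator styles.

import torch

class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def encode_with_context_manager(self, batch: torch.Tensor) -> torch.Tensor:
        # Old style: wrap only the statements that must run without autograd.
        with torch.no_grad():
            return self.linear(batch)

    @torch.no_grad()
    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        # New style: the decorator disables gradient tracking for the whole call.
        # The bare ``@torch.no_grad`` form used in the diff also works on PyTorch
        # versions that accept the class itself as a decorator.
        return self.linear(batch)

model = TinyModel()
output = model(torch.randn(1, 4))
assert not output.requires_grad  # no graph is recorded in either style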
Showing 2 changed files with 9 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- WinCLIP: set device in text embedding collection and apply forward pass with no grad, by @djdameln in https://github.com/openvinotoolkit/anomalib/pull/1984
 - 🔨 Move all export functionalities to AnomalyModule as base methods by @thinhngo-x in (<https://github.com/openvinotoolkit/anomalib/pull/1803>)
 - Remove unnecessary jsonargparse dependencies by @davnn in (<https://github.com/openvinotoolkit/anomalib/pull/2046>)
 - Use default model-specific eval transform when only train_transform specified by @djdameln(https://github.com/djdameln) in (<https://github.com/openvinotoolkit/anomalib/pull/1953>)
13 changes: 8 additions & 5 deletions src/anomalib/models/image/winclip/torch_model.py
@@ -222,6 +222,7 @@ def _get_window_embeddings(self, feature_map: torch.Tensor, masks: torch.Tensor)
 
         return pooled.reshape((n_masks, batch_size, -1)).permute(1, 0, 2)
 
+    @torch.no_grad
     def forward(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         """Forward-pass through the model to obtain image and pixel scores.
@@ -314,6 +315,7 @@ def _compute_few_shot_scores(
 
         return torch.stack(multi_scale_scores).mean(dim=0)
 
+    @torch.no_grad
     def _collect_text_embeddings(self, class_name: str) -> None:
         """Collect text embeddings for the object class using a compositional prompt ensemble.
@@ -325,30 +327,31 @@ def _collect_text_embeddings(self, class_name: str) -> None:
         Args:
             class_name (str): The name of the object class used in the prompt ensemble.
         """
+        # get the device, this is to ensure that we move the text embeddings to the same device as the model
+        device = next(self.parameters()).device
         # collect prompt ensemble
         normal_prompts, anomalous_prompts = create_prompt_ensemble(class_name)
         # tokenize prompts
         normal_tokens = tokenize(normal_prompts)
         anomalous_tokens = tokenize(anomalous_prompts)
         # encode tokens to obtain prompt embeddings
-        with torch.no_grad():
-            normal_embeddings = self.clip.encode_text(normal_tokens)
-            anomalous_embeddings = self.clip.encode_text(anomalous_tokens)
+        normal_embeddings = self.clip.encode_text(normal_tokens.to(device))
+        anomalous_embeddings = self.clip.encode_text(anomalous_tokens.to(device))
         # average prompt embeddings
         normal_embeddings = torch.mean(normal_embeddings, dim=0, keepdim=True)
         anomalous_embeddings = torch.mean(anomalous_embeddings, dim=0, keepdim=True)
         # concatenate and store
         text_embeddings = torch.cat((normal_embeddings, anomalous_embeddings))
         self._text_embeddings = text_embeddings
 
+    @torch.no_grad
     def _collect_visual_embeddings(self, images: torch.Tensor) -> None:
         """Collect visual embeddings based on a set of normal reference images.
 
         Args:
             images (torch.Tensor): Tensor of shape ``(K, C, H, W)`` containing the reference images.
         """
-        with torch.no_grad():
-            _, self._visual_embeddings, self._patch_embeddings = self.encode_image(images)
+        _, self._visual_embeddings, self._patch_embeddings = self.encode_image(images)
 
     def _generate_masks(self) -> list[torch.Tensor]:
         """Prepare a set of masks that operate as multi-scale sliding windows.
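The `.to(device)` calls in the hunk above matter because `tokenize` typically returns CPU tensors even when the CLIP weights live on a GPU, which would otherwise raise a device-mismatch error. A minimal sketch of the same idiom, assuming a made-up `TextEncoderStub` in place of the real CLIP text encoder:

import torch

class TextEncoderStub(torch.nn.Module):
    # Hypothetical stand-in for the CLIP text encoder, for illustration only.
    def __init__(self) -> None:
        super().__init__()
        self.embedding = torch.nn.Embedding(100, 8)

    @torch.no_grad()
    def encode_text(self, tokens: torch.Tensor) -> torch.Tensor:
        # Look up the device of the module's own parameters, then move the
        # incoming tokens there before embedding them.
        device = next(self.parameters()).device
        return self.embedding(tokens.to(device)).mean(dim=1)

encoder = TextEncoderStub()
if torch.cuda.is_available():
    encoder = encoder.cuda()

tokens = torch.randint(0, 100, (2, 5))  # tokenizers usually return CPU tensors
embeddings = encoder.encode_text(tokens)
print(embeddings.shape)  # torch.Size([2, 8])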
