Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace cdist in Patchcore #1267

Merged
merged 1 commit into from
Aug 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/anomalib/models/patchcore/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,28 @@ def subsample_embedding(self, embedding: Tensor, sampling_ratio: float) -> None:
coreset = sampler.sample_coreset()
self.memory_bank = coreset

@staticmethod
def euclidean_dist(x: Tensor, y: Tensor) -> Tensor:
"""
Calculates pair-wise distance between row vectors in x and those in y.

Replaces torch cdist with p=2, as cdist is not properly exported to onnx and openvino format.
Resulting matrix is indexed by x vectors in rows and y vectors in columns.

Args:
x: input tensor 1
y: input tensor 2

Returns:
Matrix of distances between row vectors in x and y.
"""
x_norm = x.pow(2).sum(dim=-1, keepdim=True) # |x|
y_norm = y.pow(2).sum(dim=-1, keepdim=True) # |y|
# row distance can be rewritten as sqrt(|x| - 2 * x @ y.T + |y|.T)
res = x_norm - 2 * torch.matmul(x, y.transpose(-2, -1)) + y_norm.transpose(-2, -1)
res = res.clamp_min_(0).sqrt_()
return res

def nearest_neighbors(self, embedding: Tensor, n_neighbors: int) -> tuple[Tensor, Tensor]:
"""Nearest Neighbours using brute force method and euclidean norm.

Expand All @@ -153,7 +175,7 @@ def nearest_neighbors(self, embedding: Tensor, n_neighbors: int) -> tuple[Tensor
Tensor: Patch scores.
Tensor: Locations of the nearest neighbor(s).
"""
distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm
distances = self.euclidean_dist(embedding, self.memory_bank)
if n_neighbors == 1:
# when n_neighbors is 1, speed up computation by using min instead of topk
patch_scores, locations = distances.min(1)
Expand Down Expand Up @@ -188,7 +210,7 @@ def compute_anomaly_score(self, patch_scores: Tensor, locations: Tensor, embeddi
# indices of N_b(m^*) in the paper
_, support_samples = self.nearest_neighbors(nn_sample, n_neighbors=self.num_neighbors)
# 4. Find the distance of the patch features to each of the support samples
distances = torch.cdist(max_patches_features.unsqueeze(1), self.memory_bank[support_samples], p=2.0)
distances = self.euclidean_dist(max_patches_features.unsqueeze(1), self.memory_bank[support_samples])
# 5. Apply softmax to find the weights
weights = (1 - F.softmax(distances.squeeze(1), 1))[..., 0]
# 6. Apply the weight factor to the score
Expand Down