Skip to content

Commit

Permalink
Replace cdist in Patchcore (#1267)
Browse files Browse the repository at this point in the history
  • Loading branch information
blaz-r committed Aug 14, 2023
1 parent f2c5a01 commit 4f51177
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions src/anomalib/models/patchcore/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,28 @@ def subsample_embedding(self, embedding: Tensor, sampling_ratio: float) -> None:
coreset = sampler.sample_coreset()
self.memory_bank = coreset

@staticmethod
def euclidean_dist(x: Tensor, y: Tensor) -> Tensor:
"""
Calculates pair-wise distance between row vectors in x and those in y.
Replaces torch cdist with p=2, as cdist is not properly exported to onnx and openvino format.
Resulting matrix is indexed by x vectors in rows and y vectors in columns.
Args:
x: input tensor 1
y: input tensor 2
Returns:
Matrix of distances between row vectors in x and y.
"""
x_norm = x.pow(2).sum(dim=-1, keepdim=True) # |x|
y_norm = y.pow(2).sum(dim=-1, keepdim=True) # |y|
# row distance can be rewritten as sqrt(|x| - 2 * x @ y.T + |y|.T)
res = x_norm - 2 * torch.matmul(x, y.transpose(-2, -1)) + y_norm.transpose(-2, -1)
res = res.clamp_min_(0).sqrt_()
return res

def nearest_neighbors(self, embedding: Tensor, n_neighbors: int) -> tuple[Tensor, Tensor]:
"""Nearest Neighbours using brute force method and euclidean norm.
Expand All @@ -153,7 +175,7 @@ def nearest_neighbors(self, embedding: Tensor, n_neighbors: int) -> tuple[Tensor
Tensor: Patch scores.
Tensor: Locations of the nearest neighbor(s).
"""
distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm
distances = self.euclidean_dist(embedding, self.memory_bank)
if n_neighbors == 1:
# when n_neighbors is 1, speed up computation by using min instead of topk
patch_scores, locations = distances.min(1)
Expand Down Expand Up @@ -188,7 +210,7 @@ def compute_anomaly_score(self, patch_scores: Tensor, locations: Tensor, embeddi
# indices of N_b(m^*) in the paper
_, support_samples = self.nearest_neighbors(nn_sample, n_neighbors=self.num_neighbors)
# 4. Find the distance of the patch features to each of the support samples
distances = torch.cdist(max_patches_features.unsqueeze(1), self.memory_bank[support_samples], p=2.0)
distances = self.euclidean_dist(max_patches_features.unsqueeze(1), self.memory_bank[support_samples])
# 5. Apply softmax to find the weights
weights = (1 - F.softmax(distances.squeeze(1), 1))[..., 0]
# 6. Apply the weight factor to the score
Expand Down

0 comments on commit 4f51177

Please sign in to comment.