Skip to content

Commit

Permalink
using torch.cdist for batch pairwise distance
Browse files Browse the repository at this point in the history
  • Loading branch information
justanhduc committed Nov 27, 2020
1 parent 9c5ec45 commit cbb0c5a
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions neuralnet_pytorch/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,10 @@ def batch_pairwise_sqdist(x: T.Tensor, y: T.Tensor, c_code=cuda_ext_available):
to the output.
:param x:
a tensor of shape ``(m, nx, d)`` or ``(nx, d)``.
a tensor of shape ``(..., nx, d)``.
If the tensor dimension is 2, the tensor batch dim is broadcasted.
:param y:
a tensor of shape ``(m, ny, d)`` or ``(ny, d)``.
a tensor of shape ``(..., ny, d)``.
If the tensor dimension is 2, the tensor batch dim is broadcasted.
:param c_code:
whether to use a C++ implementation.
Expand All @@ -536,14 +536,8 @@ def batch_pairwise_sqdist(x: T.Tensor, y: T.Tensor, c_code=cuda_ext_available):
from ..extensions import batch_pairwise_dist
return batch_pairwise_dist(x, y)
else:
xx = T.sum(x ** 2, -1)
yy = T.sum(y ** 2, -1)
zz = T.matmul(x, y.transpose(-1, -2).contiguous())

rx = xx.unsqueeze(-2).expand_as(zz.transpose(-2, -1))
ry = yy.unsqueeze(-2).expand_as(zz)
P = (rx.transpose(-2, -1) + ry - 2. * zz)
return P
P = T.cdist(x, y)
return P ** 2


def gram_matrix(x: T.Tensor) -> T.Tensor:
Expand Down

0 comments on commit cbb0c5a

Please sign in to comment.