Skip manual backward for cdist with case p=2 (pytorch#31167)
Summary:
Fixes an issue with the `cdist` backward calculation for large inputs in the euclidean (p=2) case.

When the backward kernel was launched, the grid's second dimension exceeded CUDA's 65,535 (2^16 - 1) limit, resulting in `RuntimeError: CUDA error: invalid configuration argument`.

Code to reproduce:

```python
import torch

h, w, d = 800, 1216, 12
n = 133
A = torch.randn(n, d).cuda()
B = torch.randn(h, w, d).cuda()
A.requires_grad = True
B.requires_grad = True

# Flatten B into a 2D (h*w, d) matrix before calling cdist.
B = B.reshape(-1, d).contiguous()
dist = torch.cdist(A, B)
loss = dist.sum()
loss.backward()
```
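For these sizes the flattened `B` has `h * w = 972,800` rows, far beyond the 65,535 cap on a CUDA grid's second (and third) dimension. A minimal back-of-the-envelope check follows; the exact grid mapping used by the backward kernel is an assumption here, the point is only the magnitude of the row count:

```python
h, w = 800, 1216
CUDA_MAX_GRID_Y = 65_535  # hardware limit on a grid's second dimension

rows = h * w  # number of rows of B after reshape(-1, d)
print(rows)                    # 972800
print(rows > CUDA_MAX_GRID_Y)  # True -> "invalid configuration argument" if a
                               # kernel places this count on the grid's y axis
```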

Thanks to tkerola for the bug report, the reproduction, and for suggesting a solution.
Pull Request resolved: pytorch#31167

Differential Revision: D20035605

Pulled By: ngimel

fbshipit-source-id: ae28ba4b549ee07a8bd937bb1de2438dc24eaa17
Emilio Castillo authored and facebook-github-bot committed Feb 26, 2020
1 parent 9a5ea71 commit a836c4c
Showing 4 changed files with 45 additions and 1 deletion.
22 changes: 22 additions & 0 deletions aten/src/ATen/native/Distance.cpp
@@ -98,6 +98,28 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
}

Tensor cdist(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1));
auto maybe_outnames = namedinference::compute_cdist_outnames(x1, x2);
auto result = [&]() {
NoNamesGuard guard;
// This is for pytorch to figure the backward pass itself
// when p=2
int64_t r1 = x1.size(-2);
int64_t r2 = x2.size(-2);
int64_t mode = compute_mode.value_or(0);
if (p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)))) {
return cdist_impl(x1, x2, p, compute_mode);
} else {
return at::_cdist_forward(x1, x2, p, compute_mode);
}
}();
namedinference::propagate_names_if_nonempty(result, maybe_outnames);
return result;
}

Tensor _cdist_forward(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1));
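The new `cdist` wrapper above routes the euclidean case to `cdist_impl` whenever `compute_mode` forces it (`mode == 1`) or, by default (`mode == 0`), whenever either input has more than 25 rows. `cdist_impl` builds the distance from differentiable primitives, so autograd derives the backward pass and the hand-written `_cdist_backward` kernel, whose launch configuration hit the grid limit, is never launched for large inputs. The sketch below illustrates the matmul-based p=2 formulation in Python; it assumes the standard expansion `||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>` and is not the literal ATen implementation.

```python
import torch

def euclid_cdist_sketch(x1, x2):
    # ||x1_i - x2_j||^2 = ||x1_i||^2 + ||x2_j||^2 - 2 * <x1_i, x2_j>
    x1_sq = x1.pow(2).sum(-1, keepdim=True)      # (..., r1, 1)
    x2_sq = x2.pow(2).sum(-1, keepdim=True)      # (..., r2, 1)
    sq = x1_sq + x2_sq.transpose(-2, -1) - 2 * x1.matmul(x2.transpose(-2, -1))
    # clamp guards against small negative values caused by floating-point error
    return sq.clamp_min(0.0).sqrt()

# pow, sum, matmul, clamp and sqrt all have autograd-defined gradients, so
# calling .backward() on the result needs no custom backward kernel.
```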
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -2135,6 +2135,9 @@
use_c10_dispatcher: full

- func: cdist(Tensor x1, Tensor x2, float p=2, int? compute_mode=None) -> Tensor
supports_named_tensor: True

- func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
use_c10_dispatcher: full
supports_named_tensor: True

19 changes: 19 additions & 0 deletions test/test_autograd.py
@@ -4136,12 +4136,31 @@ def f(a, b):
f_args_tensor = deepcopy(unpack_variables(f_args_variable))
run_functional_checks(self, "test_cdist", "cdist", f,
True, f_args_variable, f_args_tensor)

def _test_euclidean_large_cdist(sizex, sizey=None):
if sizey is None:
sizey = sizex
x = torch.randn(sizex, device=device, dtype=torch.float)
y = torch.randn(sizey, device=device, dtype=torch.float)
eps = 1e-6
# to avoid extremum
x = x - (((x - y) < eps).float() * 2 * eps)
x.requires_grad = True
y.requires_grad = True
f_args_variable = (x, y)
dist = torch.cdist(x, y, p=2)
# Do a backward pass to check that it is valid for large
# matrices
loss = dist.sum()
loss.backward()

_test_cdist_for_size((S, S))
_test_cdist_for_size((S, S, S))
_test_cdist_for_size((3, 5))
_test_cdist_for_size((2, 3, 5))
_test_cdist_for_size((1, 2, 3))
_test_cdist_for_size((1, 1), (S, 1))
_test_euclidean_large_cdist((2000, 5))


# NOTE: flaky on ROCm CI
2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
@@ -653,7 +653,7 @@
self: not_implemented("_pdist_backward")
pdist: not_implemented("_pdist_backward")

- name: cdist(Tensor x1, Tensor x2, float p=2, int? compute_mode=None) -> Tensor
- name: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
x1: _cdist_backward(grad.contiguous(), x1, x2, p, result)
x2: _cdist_backward(grad.transpose(-1, -2).contiguous(), x2, x1, p, result.transpose(-1, -2).contiguous())

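With the derivative registration moved from `cdist` to `_cdist_forward`, the custom `_cdist_backward` is only reached through the non-mm path; the composite path introduced above is differentiated by ordinary autograd. From Python, `compute_mode` chooses between the two paths; the string values below are the public spellings of the integer modes checked in `Distance.cpp` (a usage sketch, not part of this diff):

```python
import torch

A = torch.randn(133, 12, requires_grad=True)
B = torch.randn(2000, 12, requires_grad=True)

# Force the matmul-based euclidean path (mode == 1 above): backward goes
# through autograd on the composite ops, avoiding the grid-size limit.
d_mm = torch.cdist(A, B, p=2, compute_mode='use_mm_for_euclid_dist')

# Force the direct path, which keeps the custom _cdist_backward kernel.
d_direct = torch.cdist(A, B, p=2, compute_mode='donot_use_mm_for_euclid_dist')

d_mm.sum().backward()
```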
