# vmap support for torch.tril and torch.triu (#94287)
Summary:
Add vmap support for torch.tril and torch.triu.

Fix: #91403

Test Plan: GitHub pipeline

Differential Revision: D43016624

### Expected behavior
Same as iterating with a for-loop over the batch dimension; here each slice of `x` is 1-D, so `torch.triu` raises:

```python
import torch

x = torch.randn(32, 3)
results = []
for xi in x:
  y = torch.triu(xi)
  results.append(y)
"""
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-d726203efb0e> in <module>
      4 results = []
      5 for xi in x:
----> 6   y = torch.triu(xi)
      7   results.append(y)
RuntimeError: triu: input tensor must have at least 2 dimensions
"""
```
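
For a quick sanity check of the batched call, here is a minimal sketch (the shapes are illustrative; it assumes a PyTorch build that includes this change, and `torch.vmap`, which ships as of PyTorch 2.0):

```python
import torch

# Batch of 4x4 matrices: vmap dispatches to the new batch rule and
# should match stacking the per-slice results of the for-loop.
x = torch.randn(32, 4, 4)
batched = torch.vmap(torch.triu)(x)
looped = torch.stack([torch.triu(xi) for xi in x])
assert torch.allclose(batched, looped)

# Per the Expected behavior above, 1-D slices are expected to fail the
# same way the for-loop does.
try:
    torch.vmap(torch.triu)(torch.randn(32, 3))
except RuntimeError as e:
    print(e)
```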

Pull Request resolved: pytorch/pytorch#94287
Approved by: https://github.com/Skylion007, https://github.com/zou3519
isdanni authored and cyyever committed Mar 27, 2023 (commit e9e3f67, 1 parent 47eb184).
Showing 2 changed files with 22 additions and 4 deletions.
aten/src/ATen/functorch/BatchRulesViews.cpp (22 additions, 2 deletions):

```diff
@@ -563,12 +563,32 @@ Tensor trace_decomp(const Tensor& tensor) {
   return tensor.diagonal().sum();
 }
 
+std::tuple<Tensor,optional<int64_t>> tril_batch_rule(
+    const Tensor& self,
+    optional<int64_t> self_bdim,
+    int64_t diagonal = 0) {
+  TORCH_CHECK(self.dim() >= 2, "tril: The input tensor must have at least 2 dimensions.");
+  auto self_ = moveBatchDimToFront(self, self_bdim);
+  auto result = at::tril(self_, diagonal);
+  return std::make_tuple(std::move(result), 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> triu_batch_rule(
+    const Tensor& self,
+    optional<int64_t> self_bdim,
+    int64_t diagonal = 0) {
+  TORCH_CHECK(self.dim() >= 2, "triu: The input tensor must have at least 2 dimensions.");
+  auto self_ = moveBatchDimToFront(self, self_bdim);
+  auto result = at::triu(self_, diagonal);
+  return std::make_tuple(std::move(result), 0);
+}
+
 TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
   VMAP_SUPPORT(flip, flip_batch_rule);
   m.impl("trace", trace_decomp);
-  VMAP_SUPPORT(tril, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(tril)));
-  VMAP_SUPPORT(triu, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(triu)));
+  VMAP_SUPPORT(tril, tril_batch_rule);
+  VMAP_SUPPORT(triu, triu_batch_rule);
   VMAP_SUPPORT(repeat, repeat_batch_rule);
   VMAP_SUPPORT(_unsafe_view, _unsafe_view_batch_rule);
   VMAP_SUPPORT(unsqueeze, unsqueeze_batch_rule);
```
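
Both new rules follow the usual functorch pattern: check the input rank, move the vmap batch dimension to the front, apply the regular op (tril/triu only act on the last two dimensions, so a leading batch dimension passes through untouched), and return the result with its batch dimension at index 0. A rough Python analogue, for illustration only; the function name and the `movedim` call stand in for the C++ helpers, not a real API:

```python
from typing import Optional, Tuple
import torch

def triu_batch_rule(self: torch.Tensor,
                    self_bdim: Optional[int],
                    diagonal: int = 0) -> Tuple[torch.Tensor, Optional[int]]:
    # Mirrors the C++ rule above; `self` is the physical tensor, i.e. it
    # already carries the vmap batch dimension when self_bdim is not None.
    if self.dim() < 2:
        raise RuntimeError("triu: The input tensor must have at least 2 dimensions.")
    # moveBatchDimToFront: put the batch dimension at index 0.
    if self_bdim is not None:
        self = self.movedim(self_bdim, 0)
    # torch.triu touches only the last two dims, so the leading batch
    # dimension is carried through unchanged.
    return torch.triu(self, diagonal), 0
```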
test/functorch/test_vmap.py (0 additions, 2 deletions):

```diff
@@ -3796,8 +3796,6 @@ def test_op_has_batch_rule(self, device, dtype, op):
     'scatter',
     'square',
     'sub',
-    'tril',
-    'triu',
     'trunc',
     'xlogy',
 )
```
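
With dedicated batch rules registered, `tril` and `triu` come off this exclusion list, so `test_op_has_batch_rule` now verifies that vmapping them uses the new rules rather than functorch's slow for-loop fallback.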
