Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
vmap support for torch.tril and torch.triu (#94287)
Summary: Add vmap support for torch.tril and torch.triu. Fix: #91403 Test Plan: GitHub pipeline Differential Revision: D43016624 ### Expected behavior Same as using for-loop: ```python import torch x = torch.randn(32, 3) results = [] for xi in x: y = torch.triu(xi) results.append(y) """ triu: input tensor must have at least 2 dimensions --------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-7-d726203efb0e> in <module> 4 results = [] 5 for xi in x: ----> 6 y = torch.triu(xi) 7 results.append(y) RuntimeError: triu: input tensor must have at least 2 dimensions """ ``` Pull Request resolved: pytorch/pytorch#94287 Approved by: https://github.com/Skylion007, https://github.com/zou3519
- Loading branch information