Skip to content

Commit

Permalink
[PyTorch][Tensor] Introduce tensor.dim_order (pytorch#106835)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#106835

This is a stride based attribute for a tensor available in Python.

This can help inspect tensors generated using `torch.empty_permuted(.., physical_layout, ...)`, where physical_layout should match the dim_order returned here. `empty_permuted` will be renamed to use dim_order as the param name in the future. And also help Executorch export pipeline with implementing dim_order based tensors.

Differential Revision: D48134476

fbshipit-source-id: 12855d49038e76e27714cae554d3730bcf83f5d7
  • Loading branch information
digantdesai authored and facebook-github-bot committed Aug 9, 2023
1 parent c379d62 commit dcf26e8
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
14 changes: 14 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7734,6 +7734,20 @@ def test_helper(dim1, dim2, memory_format):
test_helper((3, 3), (3, 3, 3, 3), torch.channels_last)
test_helper((3, 3, 3), (3, 3, 3, 3, 3), torch.channels_last_3d)

def test_dim_order(self):
shape = (2, 3, 5, 7)

self.assertSequenceEqual(torch.empty(shape).dim_order(), (0, 1, 2, 3), seq_type=tuple)
self.assertSequenceEqual(torch.empty(shape).transpose(0, 1).dim_order(), (1, 0, 2, 3))
self.assertSequenceEqual(torch.empty(shape, memory_format=torch.channels_last).dim_order(), (0, 2, 3, 1))
self.assertSequenceEqual(torch.empty((2, 3, 5, 7, 8), memory_format=torch.channels_last_3d).dim_order(), (0, 2, 3, 4, 1))

for dim_order in itertools.permutations(range(4)):
self.assertSequenceEqual(dim_order, torch.empty_permuted(shape, dim_order).dim_order())

self.assertRaises(RuntimeError, lambda: torch.empty(shape).to_sparse().dim_order())
self.assertRaises(RuntimeError, lambda: torch.empty(shape).to_mkldnn().dim_order())

def test_subclass_tensors(self):
# raise an error when trying to subclass FloatTensor
with self.assertRaisesRegex(TypeError, "type 'torch.FloatTensor' is not an acceptable base type"):
Expand Down
48 changes: 48 additions & 0 deletions torch/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,54 @@ def to_sparse_coo(self):
"""
return self.to_sparse()

def dim_order(self):
"""
dim_order() -> tuple
Returns a tuple of int describing the dim order or physical layout of :attr:`self`.
Args:
None
Dim order represents how dimensions are laid out in memory,
starting from the outermost to the innermost dimension.
Thus, the conversion from strides is done by sorting the strides
from larger to smaller since the dimension with the largest stride
is the outermost and the dimension with the smallest stride is the innermost.
For example, tensor with sizes = (3, 5, 2) and strides = (5, 1, 15), implies
physical layout of (2, 0, 1). Dimension order of (2, 0, 1) can be obtained
by sorting strides from large to smaller.
When strides do not convey dim order unambiguously, returned value is dependent
on stability of sort. In python same key elements are kept
in original order. Thus when strides = (4, 3, 1, 1) returned value is (0, 1, 2, 3)
Another example is: sizes = (1, 3, 1, 1) with strides = (3, 1, 3, 3), returned
value is (0, 2, 3, 1)
Example::
>>> torch.empty((2, 3, 5, 7)).dim_order()
(0, 1, 2, 3)
>>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order()
(0, 2, 3, 1)
.. warning::
The dim_order tensor API is experimental and subject to change.
"""
if self.layout == torch.strided:
return tuple(
[
i[0]
for i in sorted(
enumerate(self.stride()), key=lambda x: x[1], reverse=True
)
]
)
raise RuntimeError("dim_order is only supported for strided tensors.")

def _update_names(self, names, inplace):
if has_torch_function_unary(self):
return handle_torch_function(
Expand Down
2 changes: 2 additions & 0 deletions torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12436,6 +12436,8 @@ def merge_dicts(*dicts):
(105, 1, 21, 3)
>>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).stride()
(105, 1, 21, 3)
>>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).dim_order()
(0, 2, 3, 1)
""".format(
**factory_common_args
),
Expand Down

0 comments on commit dcf26e8

Please sign in to comment.