[PyTorch][Tensor] Introduce tensor.dim_order (pytorch#106835)
Summary:
Pull Request resolved: pytorch#106835

This is a stride-based attribute for a tensor, available in Python.

This can help inspect tensors generated using `torch.empty_permuted(..., physical_layout, ...)`, where `physical_layout` should match the dim order returned here (`empty_permuted` will be renamed to use `dim_order` as the parameter name in the future). It also helps the Executorch export pipeline implement dim-order-based tensors.
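
For example, the physical layout passed to `torch.empty_permuted` round-trips through `dim_order()` (a quick illustration built only from calls shown in this diff):

    >>> import torch
    >>> t = torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1))  # NHWC-style layout
    >>> t.stride()
    (105, 1, 21, 3)
    >>> t.dim_order()
    (0, 2, 3, 1)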

Differential Revision: D48134476

fbshipit-source-id: 484a3b9fd4bc62d096f66b9f8a4d3179c7930c2c
digantdesai authored and facebook-github-bot committed Aug 22, 2023
1 parent 5025fb9 commit c3d8a1d
Showing 6 changed files with 161 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/tensors.rst
@@ -339,6 +339,7 @@ Tensor class reference
Tensor.digamma
Tensor.digamma_
Tensor.dim
Tensor.dim_order
Tensor.dist
Tensor.div
Tensor.div_
66 changes: 66 additions & 0 deletions test/test_torch.py
@@ -7735,6 +7735,72 @@ def test_helper(dim1, dim2, memory_format):
        test_helper((3, 3), (3, 3, 3, 3), torch.channels_last)
        test_helper((3, 3, 3), (3, 3, 3, 3, 3), torch.channels_last_3d)

    def _check_dim_order_against_prim_impl(self, tensor):
        import torch._prims_common as prim_utils

        a = tuple(
            prim_utils.compute_elementwise_output_logical_to_physical_perm(tensor, True)
        )
        b = tensor.dim_order()
        self.assertSequenceEqual(a, b, seq_type=tuple)

    def test_dim_order(self):
        shape = (2, 3, 5, 7)

        t = torch.empty(shape)
        self.assertSequenceEqual(t.dim_order(), (0, 1, 2, 3), seq_type=tuple)
        # transpose doesn't change the underlying physical memory,
        # so we expect dim_order to reflect that (like strides)
        self.assertSequenceEqual(t.transpose(0, 1).dim_order(), (1, 0, 2, 3))
        self._check_dim_order_against_prim_impl(t)

        t = torch.empty(shape, memory_format=torch.channels_last)
        self.assertSequenceEqual(t.dim_order(), (0, 2, 3, 1))
        self._check_dim_order_against_prim_impl(t)

        t = torch.empty((2, 3, 5, 7, 8), memory_format=torch.channels_last_3d)
        self.assertSequenceEqual(t.dim_order(), (0, 2, 3, 4, 1))
        self._check_dim_order_against_prim_impl(t)

        for dim_order in itertools.permutations(range(4)):
            self.assertSequenceEqual(
                dim_order, torch.empty_permuted(shape, dim_order).dim_order()
            )

        self.assertRaises(
            RuntimeError, lambda: torch.empty(shape).to_sparse().dim_order()
        )
        self.assertRaises(
            RuntimeError, lambda: torch.empty(shape).to_mkldnn().dim_order()
        )

    def test_dim_order_ambiguous_cases(self):
        K = 6
        for ndim in range(K):
            # generate an ndim-length shape with roughly the same number
            # of 1s and 2s, e.g. [1, 1, 2, 2]
            random_shape = random.sample([1, 2] * ndim, k=ndim)

            for shape in set(itertools.permutations(random_shape)):
                # cases with equal strides and size == 1 dimensions
                self._check_dim_order_against_prim_impl(torch.empty(shape))

                # cases with zero strides
                self._check_dim_order_against_prim_impl(
                    torch.broadcast_to(torch.empty(shape), tuple([2] * K))
                )

                # cases with memory formats other than contiguous
                if ndim == 4:
                    self._check_dim_order_against_prim_impl(
                        torch.empty(shape, memory_format=torch.channels_last)
                    )
                elif ndim == 5:
                    self._check_dim_order_against_prim_impl(
                        torch.empty(shape, memory_format=torch.channels_last_3d)
                    )

    def test_subclass_tensors(self):
        # raise an error when trying to subclass FloatTensor
        with self.assertRaisesRegex(TypeError, "type 'torch.FloatTensor' is not an acceptable base type"):
1 change: 0 additions & 1 deletion torch/_prims_common/__init__.py
@@ -406,7 +406,6 @@ def should_swap(idx_a, idx_b):
    #
    # also, note this returns the logical to physical shape permutation
    perm = list(reversed(range(ndim)))

    # insertion sort with support for ambiguous comparisons
    for i in range(1, ndim):
        dim1 = i
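
As the comment notes, both this helper and the new `Tensor.dim_order` return a logical-to-physical permutation: entry i names the logical dimension that sits i-th outermost in memory. One consequence, shown below as a hedged sanity check (not part of this diff, and assuming no ambiguous zero or equal strides), is that permuting a tensor by its own dim order yields a contiguous view:

    >>> import torch
    >>> t = torch.empty((2, 3, 5, 7), memory_format=torch.channels_last)
    >>> t.dim_order()
    (0, 2, 3, 1)
    >>> t.permute(t.dim_order()).is_contiguous()
    True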
91 changes: 91 additions & 0 deletions torch/_tensor.py
@@ -1318,6 +1318,97 @@ def to_sparse_coo(self):
"""
return self.to_sparse()

def dim_order(self):
"""
dim_order() -> tuple
Returns a tuple of int describing the dim order or physical layout of :attr:`self`.
Args:
None
Dim order represents how dimensions are laid out in memory,
starting from the outermost to the innermost dimension.
Thus, the conversion from strides is done by sorting the strides
from larger to smaller since the dimension with the largest stride
is the outermost and the dimension with the smallest stride is the innermost.
For example, tensor with sizes = (3, 5, 2) and strides = (5, 1, 15), implies
physical layout of (2, 0, 1). Dimension order of (2, 0, 1) can be obtained
by sorting strides from large to smaller.
Example::
>>> torch.empty((2, 3, 5, 7)).dim_order()
(0, 1, 2, 3)
>>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order()
(0, 2, 3, 1)
.. warning::
The dim_order tensor API is experimental and subject to change.
"""

# NB:
# - Based on the implementation in TensorIterator.cpp
# - Should have similar behavior (esp in ambiguous cases) to,
# torch._prims_common.compute_elementwise_output_logical_to_physical_perm

if self.layout != torch.strided:
raise RuntimeError("dim_order is only supported for strided tensors.")

ndim = self.ndim

if ndim == 0:
return tuple()

if ndim == 1:
return tuple([0])

if self.is_contiguous(memory_format=torch.contiguous_format):
return tuple(range(ndim))

shape = self.shape

def should_swap(idx_a, idx_b):
stride_a = self.stride()[idx_a]
stride_b = self.stride()[idx_b]

if stride_a == 0 or stride_b == 0:
return 0

if stride_a < stride_b:
return -1

if stride_a > stride_b:
return 1

# stride_a == stride_b
if shape[idx_a] > shape[idx_b]:
return 1

# Note: this case is hit when all strides are non-zero and equal, and all
# dimensions have the same length
return 0

# The "sort" order for the permutation is back-to-front, but
# the natural order for permutations is front-to-back. Do the
# sorting back-to-front and then reverse it on output.
perm = list(reversed(range(ndim)))
# insertion sort with support for ambiguous comparisons
for i in range(1, ndim):
dim1 = i
for dim0 in reversed(range(i)):
comparison = should_swap(perm[dim0], perm[dim1])
if comparison > 0:
perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
dim1 = dim0
elif comparison < 0:
break

return tuple(reversed(perm))

def _update_names(self, names, inplace):
if has_torch_function_unary(self):
return handle_torch_function(
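
For unambiguous strided tensors, the permutation computed above matches a plain descending sort of dimension indices by stride; the tie-breaking in `should_swap` only matters when strides are zero or equal. A minimal sketch (the helper `naive_dim_order` is hypothetical, not part of this diff) showing both the agreement and one divergence:

    import torch

    def naive_dim_order(t):
        # Sort dimension indices by stride, largest (outermost) first.
        # No tie-breaking, so this can differ from Tensor.dim_order on
        # ambiguous tensors.
        return tuple(sorted(range(t.ndim), key=lambda d: t.stride(d), reverse=True))

    t = torch.empty((2, 3, 5, 7), memory_format=torch.channels_last)
    assert naive_dim_order(t) == t.dim_order() == (0, 2, 3, 1)

    # Zero strides from broadcasting are treated as ambiguous by should_swap
    # (it returns 0, leaving the relative order unchanged), while the plain
    # sort pushes the zero-stride dimension innermost.
    b = torch.empty(3).broadcast_to(2, 3)
    print(b.stride())           # (0, 1)
    print(b.dim_order())        # (0, 1)
    print(naive_dim_order(b))   # (1, 0)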
2 changes: 2 additions & 0 deletions torch/_torch_docs.py
@@ -12437,6 +12437,8 @@ def merge_dicts(*dicts):
(105, 1, 21, 3)
>>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).stride()
(105, 1, 21, 3)
>>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).dim_order()
(0, 2, 3, 1)
""".format(
**factory_common_args
),
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -320,6 +320,7 @@ def get_ignored_functions() -> Set[Callable]:
    Tensor._is_any_true,
    Tensor._addmm_activation,
    Tensor.to_padded_tensor,
    Tensor.dim_order,
}

