Skip to content

Commit

Permalink
Improve torch.flatten docs and add tests to test_view_ops (pytorch#…
Browse files Browse the repository at this point in the history
…49501)

Summary:
Addresses pytorch#39474

Pull Request resolved: pytorch#49501

Reviewed By: mruberry

Differential Revision: D25734450

Pulled By: soulitzer

fbshipit-source-id: 993667dd07acd81a4616465e0a3b94bde449193e
  • Loading branch information
soulitzer authored and Brandon Lin committed Jan 4, 2021
1 parent 0b1c50a commit 5b915f2
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 1 deletion.
63 changes: 63 additions & 0 deletions test/test_view_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ def is_view_of(self, base, other):

return True

# Returns true if v1 and v2 are views of the same base
def is_view_of_same_base(self, v1, v2):
if (not v1._is_view() or v1 is v2):
return False
return self.is_view_of(v1._base, v2)

# Performs transpose if contiguous=True, else returns the input tensor as is
def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
if contiguous:
Expand Down Expand Up @@ -457,6 +463,63 @@ def test_reshape_nonview(self, device):
nv[6] = 0
self.assertNotEqual(t[1, 1], nv[6])

def test_flatten_view(self, device):
def test_writes_propagate(t, v):
idx_t = (0,) * t.ndim
idx_v = (0,) * v.ndim
v[idx_v] = 0
self.assertEqual(t[idx_t], v[idx_v])

t = torch.ones(1, 2, 3, 4, device=device)
v = t.flatten()
self.assertTrue(self.is_view_of(t, v))
test_writes_propagate(t, v)

# zero-dimensional tensor
t = torch.tensor(1, device=device)
v = t.flatten()
test_writes_propagate(t, v)
self.assertTrue(self.is_view_of(t, v))

t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
v = t.flatten(0, 1)
test_writes_propagate(t, v)
self.assertTrue(self.is_view_of_same_base(t, v))

# stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
t = torch.ones(720, device=device) \
.as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0))
# [--1--|---2---|-3-] [--1--|----2---|-3-]
v1 = t.flatten(0, 1)
v2 = v1.flatten(1, 3)
v3 = v2.flatten(2, 2)
test_writes_propagate(t, v1)
self.assertTrue(self.is_view_of_same_base(t, v1))
test_writes_propagate(t, v2)
self.assertTrue(self.is_view_of_same_base(t, v2))
test_writes_propagate(t, v3)
self.assertTrue(self.is_view_of_same_base(t, v3))

def test_flatten_nonview(self, device):
def assert_is_nonview(t, nv):
idx_t = (0,) * t.ndim
idx_nv = (0,) * nv.ndim
self.assertTrue(not nv._is_view())
nv[idx_nv] = 0
self.assertNotEqual(t[idx_t], nv[idx_nv])
t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
nv = t.flatten(1, 3)
assert_is_nonview(t, nv)

t = torch.ones(2, 2, device=device).T
nv = t.flatten()
assert_is_nonview(t, nv)

# flatten returns the original object if start_dim=end_dim
t = t = torch.ones(2, 2, device=device)
nv = t.flatten(1, 1)
self.assertTrue(t is nv)

def test_basic_indexing_slice_view(self, device):
t = torch.ones(5, 5, device=device)
v = t[:2, :3]
Expand Down
12 changes: 11 additions & 1 deletion torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3095,7 +3095,17 @@ def merge_dicts(*dicts):
r"""
flatten(input, start_dim=0, end_dim=-1) -> Tensor
Flattens a contiguous range of dims in a tensor.
Flattens :attr:`input` by reshaping it into a one-dimensional tensor. If :attr:`start_dim` or :attr:`end_dim`
are passed, only dimensions starting with :attr:`start_dim` and ending with :attr:`end_dim` are flattened.
The order of elements in :attr:`input` is unchanged.
Unlike NumPy's flatten, which always copies input's data, this function may return the original object, a view,
or copy. If no dimensions are flattened, then the original object :attr:`input` is returned. Otherwise, if input can
be viewed as the flattened shape, then that view is returned. Finally, only if the input cannot be viewed as the
flattened shape is input's data copied. See :meth:`torch.Tensor.view` for details on when a view will be returned.
.. note::
Flattening a zero-dimensional tensor will return a one-dimensional view.
Args:
{input}
Expand Down

0 comments on commit 5b915f2

Please sign in to comment.