Reshaping, viewing, stacking, squeezing, unsqueezing and permuting tensors

- Reshaping - reshapes an input tensor to a defined shape (returns a view when possible, otherwise a copy)
- View - returns a view of an input tensor with a given shape that shares the same memory as the original tensor
- Stacking - combines multiple tensors along a new dimension (`torch.stack`), on top of each other (`torch.vstack`) or side by side (`torch.hstack`)
- Squeeze - removes all dimensions of size `1` from a tensor
- Unsqueeze - adds a dimension of size `1` to a target tensor
- Permute - returns a view of the input with its dimensions permuted (reordered) in a certain way

In [1]:
import torch

# 1-D float tensor holding 1.0 ... 9.0 (the end value 10.0 is exclusive)
x = torch.arange(start=1.0, end=10.0)
x, x.shape

(tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.]), torch.Size([9]))

In [2]:
# Add a leading dimension: the 9 elements of x become a single row of shape (1, 9)
x_reshaped = torch.reshape(x, (1, 9))
# Show the reshaped tensor together with its shape
print(x_reshaped, x_reshaped.shape)

tensor([[1., 2., 3., 4., 5., 6., 7., 8., 9.]]) torch.Size([1, 9])


In [3]:
# Add a trailing dimension: each element of x becomes its own row, shape (9, 1)
x_reshaped = x.reshape((9, 1))
# Show the column tensor together with its shape
print(x_reshaped, x_reshaped.shape)

tensor([[1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.],
        [8.],
        [9.]]) torch.Size([9, 1])


In [4]:
# Reinterpret x as a (1, 9) tensor without copying: z is a view sharing x's memory
z = x.view((1, 9))
z, z.shape

(tensor([[1., 2., 3., 4., 5., 6., 7., 8., 9.]]), torch.Size([1, 9]))

In [5]:
# Writing through the view z also changes x, because a view shares the
# same underlying memory as the tensor it was created from.
# z[:, 0] is itself a view into z, so filling it in place updates z (and x).
z[:, 0].fill_(100)
z, x

(tensor([[100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.]]),
 tensor([100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.]))

In [6]:
# Stack three copies of x along a new leading dimension (dim=0),
# giving one row per copy
x_stacked = torch.stack((x,) * 3, dim=0)
x_stacked

tensor([[100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.],
        [100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.],
        [100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.]])

In [9]:
# Show x before the squeeze demo in the following cells.
# torch.squeeze() removes dimensions of size 1 from a tensor; x itself is
# 1-D (shape [9]), so it has none to remove. Its first element is now 100
# because of the earlier write through the view z.
x

tensor([100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.])

In [10]:
# Give x a leading size-1 dimension so there is something for squeeze() to remove
x_reshaped = torch.reshape(x, (1, 9))

In [11]:
# Display the (1, 9) tensor; the extra pair of brackets comes from the size-1 dimension
x_reshaped

tensor([[100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.]])

In [13]:
# Tensor.size() with no arguments returns the same torch.Size as .shape
x_reshaped.size()

torch.Size([1, 9])

In [14]:
# Functional form of Tensor.squeeze(): drops every size-1 dimension
torch.squeeze(x_reshaped)

tensor([100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.])

In [15]:
# After squeezing, only the length-9 dimension remains
torch.squeeze(x_reshaped).shape

torch.Size([9])

In [16]:
# Squeeze up front: squeeze() returns a new tensor and leaves x_reshaped
# untouched (it still has shape [1, 9] after the earlier squeeze calls),
# so the "previous" values printed below are unaffected.
x_squeezed = torch.squeeze(x_reshaped)

# Report the tensor and its shape before and after removing size-1 dimensions
print(f"Previous tensor: {x_reshaped}")
print(f"Previous shape: {x_reshaped.shape}")
print(f"New tensor: {x_squeezed}")
print(f"New shape: {x_squeezed.shape}")

Previous tensor: tensor([[100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.]])
Previous shape: torch.Size([1, 9])
New tensor: tensor([100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.])
New shape: torch.Size([9])
