Reshaping, stacking, squeezing and unsqueezing tensors

- Reshaping - reshapes an input tensor to a defined shape
- View - returns a view of an input tensor with a given shape that shares the same memory as the original tensor
- Stacking - combines multiple tensors on top of each other (`torch.vstack`) or side by side (`torch.hstack`); `torch.stack` joins them along a new dimension
- Squeeze - removes all dimensions of size `1` from a tensor
- Unsqueeze - adds a dimension of size `1` to a target tensor at a given position
- Permute - returns a view of the input with its dimensions permuted (reordered) in a specified order

In [1]:
import torch

# 1-D float tensor holding 1.0 .. 9.0 (the end value 10. is excluded)
x = torch.arange(start=1., end=10.)
x, x.shape

(tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.]), torch.Size([9]))

In [2]:
# Reshape the 9-element vector into a single row, i.e. shape (1, 9)
x_reshaped = x.reshape((1, 9))
print(x_reshaped, x_reshaped.shape)

tensor([[1., 2., 3., 4., 5., 6., 7., 8., 9.]]) torch.Size([1, 9])


In [3]:
# Reshape the 9-element vector into a single column, i.e. shape (9, 1)
x_reshaped = x.reshape((9, 1))
print(x_reshaped, x_reshaped.shape)

tensor([[1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.],
        [8.],
        [9.]]) torch.Size([9, 1])


In [4]:
# view() reinterprets x as shape (1, 9) without copying: z shares x's storage
z = x.view((1, 9))
z, z.size()

(tensor([[1., 2., 3., 4., 5., 6., 7., 8., 9.]]), torch.Size([1, 9]))

In [5]:
# Writing through the view also changes x, because z and x share the same memory
z[:, 0].fill_(100)
z, x

(tensor([[100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.]]),
 tensor([100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.]))

In [6]:
# Stack three copies of x along a new leading dimension (dim=0) -> shape (3, 9)
x_stacked = torch.stack([x] * 3, dim=0)
x_stacked

tensor([[100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.],
        [100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.],
        [100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.]])

In [9]:
# Inspect x before the squeeze demo below; torch.squeeze() removes dimensions of size 1
x

tensor([100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.])

In [10]:
# Give x a leading size-1 dimension (shape (1, 9)) so there is something to squeeze
x_reshaped = x.reshape((1, 9))

In [11]:
# Display the (1, 9) tensor created in the previous cell
x_reshaped

tensor([[100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.]])

In [13]:
# size() returns the same torch.Size as the .shape attribute
x_reshaped.size()

torch.Size([1, 9])

In [14]:
# Functional form of Tensor.squeeze(): drop every size-1 dimension
torch.squeeze(x_reshaped)

tensor([100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.])

In [15]:
# After squeezing, only the size-9 dimension remains
x_reshaped.squeeze().size()

torch.Size([9])

In [16]:
print(
    f"Previous tensor: {x_reshaped}",
    f"Previous shape: {x_reshaped.shape}",
    sep="\n",
)

# squeeze() drops every dimension of size 1, so (1, 9) becomes (9,)
x_squeezed = x_reshaped.squeeze()
print(
    f"New tensor: {x_squeezed}",
    f"New shape: {x_squeezed.shape}",
    sep="\n",
)

Previous tensor: tensor([[100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.]])
Previous shape: torch.Size([1, 9])
New tensor: tensor([100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.])
New shape: torch.Size([9])


In [17]:
print(
    f"Previous target: {x_reshaped}",
    f"Previous target shape: {x_reshaped.shape}",
    sep="\n",
)

# unsqueeze(dim=0) inserts a new size-1 dimension at the front; x_reshaped is
# already (1, 9) here, so the result is (1, 1, 9)
x_unsqueezed = x_reshaped.unsqueeze(dim=0)
print(
    f"New target: {x_unsqueezed}",
    f"New target shape: {x_unsqueezed.shape}",
    sep="\n",
)

Previous target: tensor([[100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.]])
Previous target shape: torch.Size([1, 9])
New target: tensor([[[100.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.]]])
New target shape: torch.Size([1, 1, 9])


In [18]:
# torch.permute reorders a tensor's dimensions; the result is a view, not a copy
x_original = torch.rand(224, 224, 3)

# Move the last axis to the front: (224, 224, 3) -> (3, 224, 224)
x_permuted = x_original.permute((2, 0, 1))
print(f"Original shape: {x_original.shape}")
print(f"Permuted shape: {x_permuted.shape}")

Original shape: torch.Size([224, 224, 3])
Permuted shape: torch.Size([3, 224, 224])


In [19]:
# Display x_original; torch abbreviates large tensors with "..." in the printed output
x_original

tensor([[[0.8275, 0.3739, 0.8007],
         [0.3216, 0.9339, 0.2541],
         [0.7217, 0.7620, 0.7132],
         ...,
         [0.1767, 0.0775, 0.0247],
         [0.4301, 0.2354, 0.1562],
         [0.4917, 0.3802, 0.6971]],

        [[0.3952, 0.9722, 0.9669],
         [0.6012, 0.4404, 0.3503],
         [0.5522, 0.5255, 0.1056],
         ...,
         [0.2792, 0.5348, 0.5240],
         [0.4260, 0.6974, 0.1722],
         [0.9120, 0.2957, 0.3107]],

        [[0.2575, 0.3923, 0.8202],
         [0.3176, 0.9830, 0.0719],
         [0.3396, 0.0985, 0.6171],
         ...,
         [0.3427, 0.2934, 0.0212],
         [0.0899, 0.6357, 0.0689],
         [0.6106, 0.7341, 0.5091]],

        ...,

        [[0.5679, 0.8764, 0.1135],
         [0.4450, 0.9396, 0.4361],
         [0.1673, 0.2583, 0.2281],
         ...,
         [0.2109, 0.0964, 0.5060],
         [0.8314, 0.5859, 0.9143],
         [0.4767, 0.8473, 0.8143]],

        [[0.9193, 0.2179, 0.2190],
         [0.8066, 0.2152, 0.3740],
         [0.

In [20]:
# Read the very first element; chained indexing yields the same 0-dim tensor as [0, 0, 0]
x_original[0][0][0]

tensor(0.8275)

In [21]:
# Overwrite the first element in place through chained views, then read it back
x_original[0][0][0] = 100
x_original[0][0][0]

tensor(100.)