# Reshaping tensors

There are several ways of reshaping tensors in PyTorch:

- Reshaping - reshapes an input tensor to a defined shape (returning a view when possible, otherwise a copy).
- View - returns a view of an input tensor with a new shape that shares the same memory as the original tensor.
- Stacking - combines multiple tensors on top of each other (vstack) or side by side (hstack).
- Squeeze - removes all dimensions of size `1` from a tensor.
- Unsqueeze - adds a dimension of size `1` to a target tensor at a chosen position.
- Permute - returns a view of the input with its dimensions permuted (reordered) in a given order.


In [1]:
# Importing torch
import torch

In [2]:
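# Use Apple's Metal (MPS) backend if available, otherwise fall back to the CPU
torch.set_default_device("mps" if torch.backends.mps.is_available() else "cpu")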

In [3]:
# Creating tensor
x = torch.arange(1, 11, dtype=torch.float32)
x, x.shape

(tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.], device='mps:0'),
 torch.Size([10]))

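The new shape must hold exactly the same number of elements as the original tensor. `x` has 10 elements, so shapes such as `(1, 10)`, `(10, 1)` or `(2, 5)` are valid, while `(3, 3)` is not.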
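As a quick check of this rule, the minimal sketch below tries a few shapes on a 10-element tensor. It also uses `-1`, which asks PyTorch to infer one dimension from the element count; that shorthand is an addition not used elsewhere in this notebook.

```python
import torch

x = torch.arange(1, 11, dtype=torch.float32)  # 10 elements

# Any shape whose dimensions multiply to 10 is valid
print(x.reshape(2, 5).shape)   # torch.Size([2, 5])
print(x.reshape(5, -1).shape)  # -1 is inferred as 2: torch.Size([5, 2])

# 3 * 3 = 9 != 10, so this shape cannot hold the data
try:
    x.reshape(3, 3)
    reshape_failed = False
except RuntimeError as err:
    reshape_failed = True
    print("reshape failed:", err)
```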


In [4]:
# Adding an extra dimension
x_reshaped = x.reshape(1, 10)
x_reshaped, x_reshaped.shape

(tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]], device='mps:0'),
 torch.Size([1, 10]))

In [5]:
# Different reshape
x_reshaped.reshape(10, 1)

tensor([[ 1.],
        [ 2.],
        [ 3.],
        [ 4.],
        [ 5.],
        [ 6.],
        [ 7.],
        [ 8.],
        [ 9.],
        [10.]], device='mps:0')

In [6]:
# Creating a view of x with a different shape
z = x.view(10, 1)
z, z.shape

(tensor([[ 1.],
         [ 2.],
         [ 3.],
         [ 4.],
         [ 5.],
         [ 6.],
         [ 7.],
         [ 8.],
         [ 9.],
         [10.]], device='mps:0'),
 torch.Size([10, 1]))

Changing `z` also changes `x`, because a view shares the same underlying memory as the original tensor.


In [7]:
# Changing z
z[0, :] = 5
z, x

(tensor([[ 5.],
         [ 2.],
         [ 3.],
         [ 4.],
         [ 5.],
         [ 6.],
         [ 7.],
         [ 8.],
         [ 9.],
         [10.]], device='mps:0'),
 tensor([ 5.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.], device='mps:0'))
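The sketch below makes the shared memory visible with `data_ptr()` (the address of a tensor's first element) and contrasts it with `clone()`, which copies the data first. It also shows one consequence, as an illustration beyond the cells above: because a view must reuse the existing memory layout, `view()` fails on a non-contiguous tensor such as a transpose, while `reshape()` falls back to copying.

```python
import torch

x = torch.arange(1, 11, dtype=torch.float32)
v = x.view(2, 5)          # a view: shares x's memory
c = x.clone().view(2, 5)  # clone() copies the data first, so c is independent

print(v.data_ptr() == x.data_ptr())  # True: same underlying memory
print(c.data_ptr() == x.data_ptr())  # False: separate memory

v[0, 0] = 100.0
print(x[0].item())     # 100.0: the write through the view is visible in x
print(c[0, 0].item())  # 1.0: the clone is unaffected

# A transpose reorders strides, so its memory is no longer contiguous
t = torch.arange(6).reshape(2, 3).t()  # shape [3, 2]
try:
    t.view(6)
    view_failed = False
except RuntimeError:
    view_failed = True
flat = t.reshape(6)  # reshape copies when a view is impossible
print(view_failed, flat)  # True tensor([0, 3, 1, 4, 2, 5])
```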

Stacking tensors is also a simple concept: `torch.stack` joins a sequence of tensors of the same shape along a new dimension. The `dim` parameter chooses where that new dimension is inserted. For 1D inputs like `x`, `dim=0` places each tensor as a row (stacked on top of each other, giving shape `[4, 10]` below), while `dim=1` places each tensor as a column (side by side, giving shape `[10, 4]`).


In [8]:
# Stack tensors on top of each other
x_stacked = torch.stack([x, x, x, x], dim=0)
x_stacked, x_stacked.shape

(tensor([[ 5.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
         [ 5.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
         [ 5.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
         [ 5.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]], device='mps:0'),
 torch.Size([4, 10]))
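To compare the two `dim` values and the related helpers, here is a small sketch on a 3-element tensor. `torch.vstack`, `torch.hstack` and `torch.cat` are standard PyTorch functions; the point of including `torch.cat` is that it joins along an existing dimension rather than creating a new one.

```python
import torch

x = torch.arange(1, 4, dtype=torch.float32)  # shape [3]

rows = torch.stack([x, x], dim=0)  # each input becomes a row -> shape [2, 3]
cols = torch.stack([x, x], dim=1)  # each input becomes a column -> shape [3, 2]
print(rows.shape, cols.shape)  # torch.Size([2, 3]) torch.Size([3, 2])

# vstack matches stack(dim=0) for 1D inputs; hstack joins 1D inputs end to end
print(torch.vstack([x, x]).shape)  # torch.Size([2, 3])
print(torch.hstack([x, x]).shape)  # torch.Size([6])

# torch.cat joins along an existing dimension instead of creating a new one
print(torch.cat([x, x], dim=0).shape)  # torch.Size([6])
```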

In [9]:
# Resetting the first element back to 1 (this also updates the views z and x_reshaped)
x[0] = 1

In [10]:
# Squeezing
x_squeezed = x_reshaped.squeeze()
x_squeezed, x_reshaped

(tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.], device='mps:0'),
 tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]], device='mps:0'))

Squeezing removes all dimensions of size 1, so `x_reshaped` with shape `[1, 10]` becomes `x_squeezed` with shape `[10]`.
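`squeeze()` also accepts a dimension index, in which case only that dimension is removed, and only if its size is 1. A short sketch on a tensor with two size-1 dimensions:

```python
import torch

t = torch.zeros(1, 3, 1, 2)

print(t.squeeze().shape)   # torch.Size([3, 2]): every size-1 dimension removed
print(t.squeeze(0).shape)  # torch.Size([3, 1, 2]): only dimension 0 removed
print(t.squeeze(1).shape)  # torch.Size([1, 3, 1, 2]): dimension 1 has size 3, so nothing changes
```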


In [11]:
# Unsqueezing tensor
x_unsqueezed = x_squeezed.unsqueeze(-1)  # Adding dimension to the last possible index
x_unsqueezed, x_unsqueezed.shape

(tensor([[ 1.],
         [ 2.],
         [ 3.],
         [ 4.],
         [ 5.],
         [ 6.],
         [ 7.],
         [ 8.],
         [ 9.],
         [10.]], device='mps:0'),
 torch.Size([10, 1]))

Unsqueeze does the opposite: instead of removing a dimension, it inserts a new dimension of size 1 at the index you choose. Here `-1` adds it after the last dimension, turning shape `[10]` into `[10, 1]`.
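The position passed to `unsqueeze` decides where the new size-1 dimension lands. The sketch below tries a few positions and also shows indexing with `None`, a NumPy-style shorthand that PyTorch supports for inserting a size-1 dimension.

```python
import torch

v = torch.arange(1, 4, dtype=torch.float32)  # shape [3]

print(v.unsqueeze(0).shape)   # torch.Size([1, 3]): new leading dimension
print(v.unsqueeze(1).shape)   # torch.Size([3, 1])
print(v.unsqueeze(-1).shape)  # torch.Size([3, 1]): -1 inserts after the last dimension

# Indexing with None inserts a size-1 dimension in the same way
print(v[None, :].shape)       # torch.Size([1, 3])

# squeeze undoes unsqueeze
print(v.unsqueeze(0).squeeze().shape)  # torch.Size([3])
```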


In [12]:
# Creating new tensor
simulated_image = torch.rand(size=(224, 224, 3))

# Permuting tensor
permuted_simulated_image = simulated_image.permute(
    2, 0, 1
)  # Making the third dimension the first

# Checking the tensors' shapes
simulated_image.shape, permuted_simulated_image.shape

(torch.Size([224, 224, 3]), torch.Size([3, 224, 224]))

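Permuting a tensor lets you reorder its dimensions: each argument to `permute` names the original dimension that goes in that position, so `permute(2, 0, 1)` turns shape `[224, 224, 3]` into `[3, 224, 224]`. The result is a view of the original tensor, meaning the new tensor shares memory with the original one.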
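To confirm that `permute` does not copy data, the sketch below uses a small 4 × 5 × 3 tensor in place of the 224 × 224 × 3 image and inspects it with `data_ptr()` and `is_contiguous()`. Only the strides change, which is also why the permuted tensor is no longer contiguous.

```python
import torch

image = torch.rand(4, 5, 3)   # height x width x channels
chw = image.permute(2, 0, 1)  # channels x height x width

print(chw.shape)                           # torch.Size([3, 4, 5])
print(chw.data_ptr() == image.data_ptr())  # True: same memory, no copy
print(chw.is_contiguous())                 # False: only the strides were reordered

# chw[c, h, w] is image[h, w, c], so writes to image show up in chw
image[0, 0, 0] = -1.0
print(chw[0, 0, 0].item())                 # -1.0
```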


## Indexing

Another way of reshaping tensors is indexing. Indexing is a basic operation in Python and NumPy that is also available in PyTorch, and it works the same way: square brackets containing one index per dimension. Selecting a single index along a dimension removes that dimension from the result.


In [13]:
# Creating a tensor
indexing_tensor = torch.arange(1, 10, dtype=torch.float32).reshape(1, 3, 3)
indexing_tensor

tensor([[[1., 2., 3.],
         [4., 5., 6.],
         [7., 8., 9.]]], device='mps:0')

In [14]:
# Accessing the first dimension
indexing_tensor[0]

tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]], device='mps:0')

In [15]:
# Indexing the first two dimensions returns the first row
indexing_tensor[0, 0]

tensor([1., 2., 3.], device='mps:0')

In [16]:
# Both ways produce equal outcomes
indexing_tensor[0][0]

tensor([1., 2., 3.], device='mps:0')

In [17]:
# Indexing all three dimensions returns a single value
indexing_tensor[0, 0, 0]

tensor(1., device='mps:0')

In [18]:
# Using ':' to get the entire first column (values 1, 4, 7)
indexing_tensor[:, :, 0]

tensor([[1., 4., 7.]], device='mps:0')

The `:` slice is one of the most common things you'll see when indexing: it selects all values along a dimension. In `indexing_tensor[:, :, 0]` we keep every entry of the first and second dimensions and take only index 0 of the last one, which is the first element of each row, i.e. the first column. Because the first dimension is kept with `:`, the result has shape `[1, 3]`.
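An integer index drops the dimension it selects, while a slice such as `0:1` keeps it with size 1. Slices also accept a start, stop and step. The sketch below rebuilds the same 1 × 3 × 3 tensor so it runs on its own.

```python
import torch

t = torch.arange(1, 10, dtype=torch.float32).reshape(1, 3, 3)

col_int = t[:, :, 0]      # integer index drops the last dimension
col_slice = t[:, :, 0:1]  # a slice keeps it, with size 1
print(col_int.shape, col_slice.shape)  # torch.Size([1, 3]) torch.Size([1, 3, 1])

print(t[0, 1:, 1:])    # bottom-right 2x2 block: values 5, 6 / 8, 9
print(t[0, ::2, ::2])  # every second row and column: values 1, 3 / 7, 9
```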


In [19]:
# Getting the value 9 from tensor
indexing_tensor[-1, -1, -1]

tensor(9., device='mps:0')

In [21]:
# Getting the values 3, 6, 9 (the last column)
indexing_tensor[:, :, 2]

tensor([[3., 6., 9.]], device='mps:0')