# Flatten/unflatten

The `torch.nn.Flatten` and `torch.nn.Unflatten` layers let you change the shape of a tensor directly within a network's forward pass, for example as steps of an `nn.Sequential` model.

```python
import torch
```

## Operation deep dive

Both layers are essentially a `reshape`: they change the shape of a tensor but never reorder its elements. It is still worth looking closely at which elements end up grouped together.
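
A quick way to check the `reshape` analogy is to compare both layers against `Tensor.reshape`. The sketch below uses a random tensor of shape `(4, 3, 5, 5)`; that shape is an arbitrary choice for illustration.

```python
import torch

x = torch.randn(4, 3, 5, 5)  # illustrative shape: batch of 4, 3 channels, 5x5

# nn.Flatten() merges dims 1..-1, the same as reshaping to (batch, -1)
flat = torch.nn.Flatten()(x)
print(flat.shape)                           # torch.Size([4, 75])
print(torch.equal(flat, x.reshape(4, -1)))  # True

# nn.Unflatten(1, (3, 5, 5)) splits dim 1 back into the original sizes
restored = torch.nn.Unflatten(1, (3, 5, 5))(flat)
print(torch.equal(restored, x))  # True: no element changed position
```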

### Flatten

`Flatten` merges the dimensions from `start_dim` to `end_dim` (inclusive) into a single dimension, concatenating their elements in row-major order.
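
The size of the merged dimension is the product of the sizes it replaces. The helper `flattened_shape` below is hypothetical (not part of PyTorch); it assumes nonnegative `start_dim` and `end_dim` for simplicity and checks its prediction against `nn.Flatten` with explicit arguments.

```python
import math
import torch

def flattened_shape(shape, start_dim, end_dim):
    """Hypothetical helper: expected shape after merging dims start_dim..end_dim (inclusive)."""
    shape = list(shape)
    merged = math.prod(shape[start_dim:end_dim + 1])
    return tuple(shape[:start_dim] + [merged] + shape[end_dim + 1:])

x = torch.zeros(2, 3, 4, 5)
for start, end in [(0, 1), (1, 2), (1, 3), (0, 3)]:
    expected = flattened_shape(x.shape, start, end)
    actual = tuple(torch.nn.Flatten(start, end)(x).shape)
    print(start, end, expected)
    assert expected == actual
```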

---

For example, consider a 3-dimensional tensor; it's easy to think of it as a cube of numbers.

```python
input = torch.arange(27).reshape([3, 3, 3])
input
```

```
tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]],

        [[18, 19, 20],
         [21, 22, 23],
         [24, 25, 26]]])
```

By default, `Flatten` uses `start_dim=1` and `end_dim=-1`, so it merges every dimension except the first one. Instead of a stack of three 3-by-3 matrices, we end up with a stack of three vectors of length 9.

```python
torch.nn.Flatten()(input)
```

```
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
        [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23, 24, 25, 26]])
```
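
Keeping dimension 0 intact is what makes the default convenient inside a model: the batch dimension survives and everything else becomes one feature vector for `nn.Linear`. A minimal sketch, assuming single-channel 28-by-28 inputs and an arbitrary layer layout chosen only for illustration:

```python
import torch
from torch import nn

# Assumed layout: 1x28x28 inputs; a 3x3 conv without padding yields 4 maps of 26x26
model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3),  # -> (N, 4, 26, 26)
    nn.Flatten(),                                              # -> (N, 4 * 26 * 26)
    nn.Linear(4 * 26 * 26, 10),                                # -> (N, 10)
)

images = torch.randn(8, 1, 28, 28)  # a batch of 8 random "images"
logits = model(images)
print(logits.shape)  # torch.Size([8, 10])
```

Note that the functional `torch.flatten` defaults to `start_dim=0` and would merge the batch dimension as well, whereas the `nn.Flatten` layer defaults to `start_dim=1`.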

Now let's pass non-default arguments. With `start_dim=0` and `end_dim=1`, `Flatten` merges the two outer dimensions instead, stacking the layers of the cube on top of each other into a tall 9-by-3 matrix.

```python
torch.nn.Flatten(start_dim=0, end_dim=1)(input)
```

```
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17],
        [18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]])
```
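
Because flattening follows row-major order, element `[i, j, k]` of the cube ends up at row `i * 3 + j` of the result. The sketch below rebuilds the same cube (named `cube` here to avoid shadowing Python's built-in `input`) and checks this index formula for every element.

```python
import torch

cube = torch.arange(27).reshape(3, 3, 3)
flat = torch.nn.Flatten(start_dim=0, end_dim=1)(cube)  # shape (9, 3)

for i in range(3):
    for j in range(3):
        for k in range(3):
            # row j of layer i becomes row i * 3 + j of the matrix
            assert flat[i * 3 + j, k] == cube[i, j, k]

# For a 3-D tensor, end_dim=-2 names the same dimension as end_dim=1
print(torch.equal(flat, torch.nn.Flatten(0, -2)(cube)))  # True
```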

### Unflatten

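`Unflatten` does the opposite: it splits a single dimension `dim` into several dimensions whose sizes are given by `unflattened_size`. The product of those sizes must equal the size of `dim`, and the elements keep their original order.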
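
Since `Unflatten` only splits a dimension, it exactly undoes a `Flatten` over the same dimensions when `unflattened_size` matches the original sizes. A small round-trip check on an arbitrary example shape:

```python
import torch

x = torch.arange(2 * 3 * 4).reshape(2, 3, 4)

flat = torch.nn.Flatten(start_dim=1, end_dim=2)(x)               # (2, 12)
back = torch.nn.Unflatten(dim=1, unflattened_size=(3, 4))(flat)  # (2, 3, 4)

print(flat.shape, back.shape)  # torch.Size([2, 12]) torch.Size([2, 3, 4])
print(torch.equal(back, x))    # True

# Any split whose sizes multiply to 12 is valid, e.g. 2 * 6
print(torch.nn.Unflatten(1, (2, 6))(flat).shape)  # torch.Size([2, 2, 6])
```

A size whose product differs from the length of `dim` (for example `(5, 3)` here) is rejected with an error when the layer is called.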

---

As an example, consider a 9-by-9 matrix. Its outer dimension (`dim=0`) indexes the rows, and its inner dimension (`dim=1`) indexes the columns.

```python
input = torch.arange(81).reshape([9, 9])
input
```

```
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
        [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23, 24, 25, 26],
        [27, 28, 29, 30, 31, 32, 33, 34, 35],
        [36, 37, 38, 39, 40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49, 50, 51, 52, 53],
        [54, 55, 56, 57, 58, 59, 60, 61, 62],
        [63, 64, 65, 66, 67, 68, 69, 70, 71],
        [72, 73, 74, 75, 76, 77, 78, 79, 80]])
```

The following cell applies `Unflatten` to the outer dimension of the matrix. The nine rows are split into three groups of three consecutive rows, and each group becomes one layer of a 3-by-3-by-9 tensor.

```python
torch.nn.Unflatten(dim=0, unflattened_size=(3, 3))(input)
```

```
tensor([[[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
         [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
         [18, 19, 20, 21, 22, 23, 24, 25, 26]],

        [[27, 28, 29, 30, 31, 32, 33, 34, 35],
         [36, 37, 38, 39, 40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49, 50, 51, 52, 53]],

        [[54, 55, 56, 57, 58, 59, 60, 61, 62],
         [63, 64, 65, 66, 67, 68, 69, 70, 71],
         [72, 73, 74, 75, 76, 77, 78, 79, 80]]])
```
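
In this result, row `b` of layer `a` is row `a * 3 + b` of the input. The check below recreates the 9-by-9 matrix (named `matrix` to avoid shadowing Python's built-in `input`), verifies that mapping, and confirms the result equals a plain `reshape(3, 3, 9)`.

```python
import torch

matrix = torch.arange(81).reshape(9, 9)
out = torch.nn.Unflatten(dim=0, unflattened_size=(3, 3))(matrix)  # (3, 3, 9)

for a in range(3):
    for b in range(3):
        # layer a, row b of the result is original row a * 3 + b
        assert torch.equal(out[a, b], matrix[a * 3 + b])

print(torch.equal(out, matrix.reshape(3, 3, 9)))  # True
```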

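Applying `Unflatten` to the inner dimension instead splits each nine-element row into three chunks of three, which become the rows of a small 3-by-3 matrix. The result has shape 9-by-3-by-3, and element `[r, c]` of the input moves to position `[r, c // 3, c % 3]`.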

```python
torch.nn.Unflatten(dim=1, unflattened_size=(3, 3))(input)
```

```
tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]],

        [[18, 19, 20],
         [21, 22, 23],
         [24, 25, 26]],

        [[27, 28, 29],
         [30, 31, 32],
         [33, 34, 35]],

        [[36, 37, 38],
         [39, 40, 41],
         [42, 43, 44]],

        [[45, 46, 47],
         [48, 49, 50],
         [51, 52, 53]],

        [[54, 55, 56],
         [57, 58, 59],
         [60, 61, 62]],

        [[63, 64, 65],
         [66, 67, 68],
         [69, 70, 71]],

        [[72, 73, 74],
         [75, 76, 77],
         [78, 79, 80]]])
```
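
Inside a model, `Unflatten` typically appears where a flat feature vector has to become an image-like tensor again, for instance at the start of a decoder. A minimal sketch, assuming a 16-dimensional latent code and a 4-channel 7-by-7 intermediate map; these sizes are illustrative, not prescribed:

```python
import torch
from torch import nn

# Assumed sizes: 16-dim latent code -> 4 feature maps of 7x7 -> one 14x14 image
decoder = nn.Sequential(
    nn.Linear(16, 4 * 7 * 7),                           # -> (N, 196)
    nn.Unflatten(dim=1, unflattened_size=(4, 7, 7)),    # -> (N, 4, 7, 7)
    nn.ConvTranspose2d(4, 1, kernel_size=2, stride=2),  # -> (N, 1, 14, 14)
)

z = torch.randn(5, 16)  # a batch of 5 latent codes
images = decoder(z)
print(images.shape)  # torch.Size([5, 1, 14, 14])
```

Here `Unflatten` plays the mirror role of `Flatten` in an encoder: the batch dimension is left alone and only the feature dimension is reshaped.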