# Quick: `torch.flatten`

## `torch.flatten`

`torch.flatten` **reshapes a tensor into a 1-D vector**, or flattens
only a slice of its dimensions.

Basic use:

``` python
torch.flatten(tensor) # flatten all dimensions
torch.flatten(tensor, start_dim, end_dim)
```

-   `start_dim`: first dimension to begin flattening (default: `0`)
-   `end_dim`: last dimension to flatten (default: `-1`)
-   All dimensions in `[start_dim, end_dim]` collapse into a single one.
-   Dimensions outside that range remain unchanged.
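The rules above can be checked on a 4-D tensor. The shape `(2, 3, 4, 5)` is an arbitrary choice for illustration, such as a `(batch, channels, height, width)` input. The sketch collapses an inner range, uses a negative `start_dim`, and relies on the defaults:

``` python
import torch

# Arbitrary 4-D tensor used only for illustration, e.g. (batch, channels, height, width)
t = torch.zeros(2, 3, 4, 5)

# Collapse dims 1..2 into one (3 * 4 = 12); dims 0 and 3 are left as they are
a = torch.flatten(t, start_dim=1, end_dim=2)
print(a.shape)  # torch.Size([2, 12, 5])

# Negative indices count from the end: dims -3..-1 collapse to 3 * 4 * 5 = 60
b = torch.flatten(t, start_dim=-3)
print(b.shape)  # torch.Size([2, 60])

# Defaults (start_dim=0, end_dim=-1) collapse every dimension
c = torch.flatten(t)
print(c.shape)  # torch.Size([120])
```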

This is commonly used before feeding data into fully-connected layers, usually with `start_dim=1` so that the batch dimension (dim 0) is preserved.
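Here is a minimal sketch of that pattern. The feature-map shape `(8, 16, 5, 5)` and the output size of 10 are hypothetical values chosen for illustration. `nn.Flatten` is the module form of the same operation, and its default `start_dim` is `1`:

``` python
import torch
import torch.nn as nn

# Hypothetical batch of 8 feature maps, 16 channels, 5x5 spatial size
features = torch.randn(8, 16, 5, 5)

# Keep the batch dim (0) and collapse the rest: 16 * 5 * 5 = 400
flat = torch.flatten(features, start_dim=1)

# A fully-connected layer expects input of shape (batch, in_features)
fc = nn.Linear(in_features=16 * 5 * 5, out_features=10)  # 10 outputs is arbitrary
out = fc(flat)
print(flat.shape)  # torch.Size([8, 400])
print(out.shape)   # torch.Size([8, 10])

# Module form, convenient inside nn.Sequential; defaults to start_dim=1
flat_mod = nn.Flatten()(features)
print(torch.equal(flat, flat_mod))  # True
```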

## Example

``` python
import torch

x = torch.tensor([[[1, 2, 3], [4, 5, 6]],
                  [[7, 8, 9], [10, 11, 12]]])  # shape (2, 2, 3)

# Flatten all dimensions → shape (12,)
y1 = torch.flatten(x)

# Flatten from dim 1 to dim 2 → shape (2, 6)
y2 = torch.flatten(x, start_dim=1)

print("Original shape:", x.shape)
print("Flattened (all dims):", y1.shape)
print("Flattened (start_dim=1):", y2.shape)
print(y2)
```
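Flattening reads elements in row-major (C) order, so the last dimension varies fastest. The result matches an equivalent `reshape`, and, like `reshape`, `torch.flatten` may return a view of the input instead of a copy. This sketch rebuilds the same `(2, 2, 3)` values with `torch.arange` and checks both points:

``` python
import torch

# Same values as the (2, 2, 3) example: 1..12 in row-major order
x = torch.arange(1, 13).reshape(2, 2, 3)

# Full flatten walks the last dimension fastest
y1 = torch.flatten(x)
print(y1.tolist())  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

# flatten(x, start_dim=1) equals reshaping while keeping dim 0
y2 = torch.flatten(x, start_dim=1)
same = torch.equal(y2, x.reshape(x.shape[0], -1))
print(same)  # True
```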