# Tensor Manipulation in PyTorch

In [2]:
import torch

## Usage of `torch.cat()`

In [3]:
# Two equal 2x3 tensors holding 0..5 in row-major order.
f = torch.arange(6).reshape(2, 3)  # shape: [2, 3]
g = f.clone()  # same values as f, separate storage; shape: [2, 3]
print("g: ", g)

# dim=0 stacks the inputs vertically: row counts add, columns must agree.
h = torch.cat([f, g], 0)  # [2+2, 3]
print("h: ", h)
print(h.shape)  # shape: [4, 3]

# dim=1 places the inputs side by side: column counts add, rows must agree.
i = torch.cat([f, g], 1)  # [2, 3+3]
print("i: ", i)
print(i.shape)  # shape: [2, 6]

g:  tensor([[0, 1, 2],
        [3, 4, 5]])
h:  tensor([[0, 1, 2],
        [3, 4, 5],
        [0, 1, 2],
        [3, 4, 5]])
torch.Size([4, 3])
i:  tensor([[0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5]])
torch.Size([2, 6])


In [6]:
# torch.cat requires every dimension except `dim` to match.
# a is [2, 3] and b is [2, 4]: joining along dim=0 needs the column
# counts (3 vs 4) to agree, so torch.cat raises a RuntimeError.
a = torch.arange(6).reshape(2, 3)  # shape: [2, 3]
b = torch.arange(8).reshape(2, 4)  # shape: [2, 4]

cat_error = None
try:
    torch.cat((a, b), dim=0)
except RuntimeError as err:
    # Display the error instead of hiding the demonstration in a comment.
    cat_error = err
    print("torch.cat(dim=0) failed:", err)

# Along dim=1 only the row counts (2 vs 2) must match, so this succeeds.
c = torch.cat((a, b), dim=1)  # [2, 3+4]
print(c.shape)  # shape: [2, 7]

## Usage of `torch.flatten()`

In [11]:
# torch.flatten merges every dimension in the inclusive range
# [start_dim, end_dim] into one (defaults: start_dim=0, end_dim=-1).

t = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]]])  # shape: [2, 2, 2]

t1 = torch.flatten(t)  # all dims merged: 2*2*2 = 8
print(t1.shape)  # [8]

t2 = torch.flatten(t, start_dim=1)  # dim 0 kept, dims 1..2 merged: 2*2 = 4
print(t2.shape)  # [2, 4]

torch.Size([8])
torch.Size([2, 4])


In [16]:
# 5-D tensor of ones: 2*3*4*5*6 = 720 elements in total.
u = torch.ones(2, 3, 4, 5, 6)  # shape: [2, 3, 4, 5, 6]

# Method form u.flatten(...) is equivalent to torch.flatten(u, ...).
u1 = u.flatten()  # dims 0..4 merged -> [720]
print(u1.shape)

u2 = u.flatten(start_dim=2)  # dims 2..4 merged: 4*5*6 -> [2, 3, 120]
print(u2.shape)

u3 = u.flatten(end_dim=3)  # dims 0..3 merged: 2*3*4*5 -> [120, 6]
print(u3.shape)

u4 = u.flatten(start_dim=2, end_dim=3)  # dims 2..3 merged: 4*5 -> [2, 3, 20, 6]
print(u4.shape)

torch.Size([720])
torch.Size([2, 3, 120])
torch.Size([120, 6])
torch.Size([2, 3, 20, 6])


## Usage of `torch.permute()`

In [3]:
# permute reorders dimensions: output dim k is input dim order[k].
v = torch.ones(2, 3, 4, 5, 6)  # shape: [2, 3, 4, 5, 6]

v1 = v.permute(0, 1, 2, 3, 4)  # identity order -> [2, 3, 4, 5, 6]
print(v1.shape)

v2 = v.permute(4, 3, 2, 1, 0)  # fully reversed order -> [6, 5, 4, 3, 2]
print(v2.shape)

torch.Size([2, 3, 4, 5, 6])
torch.Size([6, 5, 4, 3, 2])


## Usage of `Tensor.view()`

In [10]:
# view returns a tensor with a new shape over the same data; the total
# element count (here 720) must stay the same.
w = torch.ones(2, 3, 4, 5, 6)  # shape: [2, 3, 4, 5, 6]

# A -1 entry is inferred from the remaining element count.
w1 = w.view((2, 3, -1))  # 720 / (2*3) = 120 -> [2, 3, 120]
print(w1.shape)

w2 = w.view((-1, 1))  # 720 / 1 = 720 -> [720, 1]
print(w2.shape)

torch.Size([2, 3, 120])
torch.Size([720, 1])


## Usage of `torch.transpose()`

In [5]:
# transpose swaps exactly two dimensions; all others keep their place.
x = torch.ones(2, 3, 4, 5, 6)  # shape: [2, 3, 4, 5, 6]

x1 = torch.transpose(x, 1, 3)  # swap dims 1 and 3 -> [2, 5, 4, 3, 6]
print(x1.shape)

x2 = torch.transpose(x, 0, -1)  # negative index -1 is the last dim -> [6, 3, 4, 5, 2]
print(x2.shape)

torch.Size([2, 5, 4, 3, 6])
torch.Size([6, 3, 4, 5, 2])


## Usage of `torch.squeeze()`

In [13]:
# Size-1 dimensions sit at positions 0, 2 and 4.
y = torch.ones(1, 2, 1, 3, 1, 4)  # shape: [1, 2, 1, 3, 1, 4]

y1 = torch.squeeze(y)  # no dim given: every size-1 dim is removed
print(y1.shape)  # [2, 3, 4]

y2 = torch.squeeze(y, 0)  # only the given dim (0 here) is squeezed; its size is 1
print(y2.shape)  # [2, 1, 3, 1, 4]

y3 = torch.squeeze(y, 1)  # left unchanged: dim 1 has size 2, not 1
print(y3.shape)  # [1, 2, 1, 3, 1, 4]

y4 = torch.squeeze(y, 2)  # only dim 2 (size 1) is squeezed
print(y4.shape)  # [1, 2, 3, 1, 4]

torch.Size([2, 3, 4])
torch.Size([2, 1, 3, 1, 4])
torch.Size([1, 2, 1, 3, 1, 4])
torch.Size([1, 2, 3, 1, 4])


## Usage of `torch.unsqueeze()`

In [17]:
# unsqueeze inserts a new size-1 dimension at the given position.
z = torch.ones(2, 3, 4)  # shape: [2, 3, 4]

z1 = z.unsqueeze(0)  # new dim in front -> [1, 2, 3, 4]
print(z1.shape)

z2 = z.unsqueeze(1)  # new dim after dim 0 -> [2, 1, 3, 4]
print(z2.shape)

z3 = z.unsqueeze(2)  # new dim after dim 1 -> [2, 3, 1, 4]
print(z3.shape)

torch.Size([1, 2, 3, 4])
torch.Size([2, 1, 3, 4])
torch.Size([2, 3, 1, 4])


## Usage of `torch.meshgrid()`

![meshgrid](https://www.dropbox.com/s/gk9jr9y1mdiwneu/meshgrid.png?raw=1)