## PyTorch Tensor Slicing and Concatenation

In [1]:
import torch

### Slicing and Concatenation

#### Indexing and Slicing

Prepare the target tensor.

In [2]:
# Build the (3, 2, 2) demo tensor from nested lists. torch.tensor(data, dtype=...)
# is the recommended factory for existing data; float32 matches the legacy
# torch.FloatTensor constructor this replaces.
x = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]],
                  [[9, 10],
                   [11, 12]]], dtype=torch.float32)
print(x.size())

torch.Size([3, 2, 2])


In [3]:
# Indexing dim 0 with a single integer drops that dimension and returns
# one 2x2 slice of x per call.
print(x[0])
print(x[1])
print(x[2])

tensor([[1., 2.],
        [3., 4.]])
tensor([[5., 6.],
        [7., 8.]])
tensor([[ 9., 10.],
        [11., 12.]])


In [5]:
# A trailing ':' keeps the whole remaining dimension, so x[0, :] and
# x[0, :, :] select the same 2x2 slice as x[0].
print(x[0, :])
print(x[0, :, :])

tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]])


In [6]:
# Negative indices count from the end: -1 selects the last slice along
# dim 0. All three forms print the same 2x2 tensor.
print(x[-1])
print(x[-1, :])
print(x[-1, :, :])

tensor([[ 9., 10.],
        [11., 12.]])
tensor([[ 9., 10.],
        [11., 12.]])
tensor([[ 9., 10.],
        [11., 12.]])


In [8]:
# Fix index 1 on dim 1 and keep every entry of dims 0 and 2. The indexed
# dimension is dropped, so the result is (3, 2): the second row of each
# 2x2 slice.
print(x[:, 1, :])

tensor([[ 3.,  4.],
        [ 7.,  8.],
        [11., 12.]])


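### split: Split a tensor into pieces of a given size

The tensor below is allocated without being initialized, so its printed values are arbitrary leftovers from memory. Only its shape matters for this example.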

In [12]:
# Allocate a float32 (10, 4) tensor WITHOUT initializing it. This is the
# explicit modern form of the legacy torch.FloatTensor(10, 4). The displayed
# values are arbitrary leftovers from memory; only the shape matters for split().
x = torch.empty(10, 4, dtype=torch.float32)
x

tensor([[ 0.0000e+00, -3.6893e+19,  8.1493e-32, -1.0845e-19],
        [-4.0527e-02,  4.5738e-41, -2.4645e+01,  4.5738e-41],
        [-3.2870e-02,  4.5738e-41, -3.2875e-02,  4.5738e-41],
        [-3.2068e-02,  4.5738e-41, -3.3393e-02,  4.5738e-41],
        [-4.6177e+00,  4.5738e-41, -2.4523e+01,  4.5738e-41],
        [-3.3400e-02,  4.5738e-41, -3.2880e-02,  4.5738e-41],
        [-3.2882e-02,  4.5738e-41, -3.1810e-02,  4.5738e-41],
        [-3.2887e-02,  4.5738e-41, -2.4645e+01,  4.5738e-41],
        [-3.3473e-02,  4.5738e-41, -2.4645e+01,  4.5738e-41],
        [-3.2939e-02,  4.5738e-41, -6.2138e+00,  4.5738e-41]])

In [13]:
# Cut x into consecutive 4-row pieces along dim 0; the final piece holds
# whatever rows remain.
splits = torch.split(x, 4, dim=0)

In [15]:
for elem in splits:
    # Tensor.shape gives the same torch.Size value as Tensor.size().
    print(elem.shape)

torch.Size([4, 4])
torch.Size([4, 4])
torch.Size([2, 4])


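### chunk: Split a tensor into a given number of chunks

If the size along `dim` is not divisible by the number of chunks, the last chunk is smaller. Here, 8 rows become chunks of 3, 3 and 2 rows. The tensor is allocated without initialization, so its values are arbitrary; only its shape matters.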

In [16]:
# Allocate an uninitialized float32 (8, 4) tensor. This is the explicit
# modern form of the legacy torch.FloatTensor(8, 4). Its values are
# arbitrary; only the shape matters for chunk().
x = torch.empty(8, 4, dtype=torch.float32)

In [19]:
# Display x. It was allocated without initialization, so the values are
# arbitrary leftovers from memory; only its (8, 4) shape is meaningful.
x

tensor([[ 0.0000e+00, -3.6893e+19,  8.2289e-32,  8.5920e+09],
        [ 1.6816e-44,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  4.7089e-33,  1.4013e-45]])

In [17]:
# Request 3 chunks along dim 0. Each chunk gets ceil(8 / 3) = 3 rows except
# the last, which receives the remaining rows.
chunks = torch.chunk(x, 3, dim=0)


In [18]:
# Display the tuple of chunks: splitting 8 rows into 3 chunks gives
# pieces of 3, 3 and 2 rows.
chunks

(tensor([[ 0.0000e+00, -3.6893e+19,  8.2289e-32,  8.5920e+09],
         [ 1.6816e-44,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]),
 tensor([[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]),
 tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 4.7089e-33, 1.4013e-45]]))

In [20]:
for chunk in chunks:
    # Tensor.shape gives the same torch.Size value as Tensor.size().
    print(chunk.shape)

torch.Size([3, 4])
torch.Size([3, 4])
torch.Size([2, 4])


### index_select: Select slices along a dimension using an index tensor

In [23]:
# (3, 2, 2) float32 demo tensor, built with the recommended torch.tensor
# factory instead of the legacy torch.FloatTensor constructor.
x = torch.tensor([[[1, 1],
                   [2, 2]],
                  [[3, 3],
                   [4, 4]],
                  [[5, 5],
                   [6, 6]]], dtype=torch.float32)

# index_select needs an integer index tensor; int64 (long) matches the
# legacy torch.LongTensor it replaces.
indice = torch.tensor([2, 1], dtype=torch.long)

print(x.size())

torch.Size([3, 2, 2])


In [25]:
# Gather the dim-0 slices listed in `indice`, in that order; all other
# dimensions are kept unchanged.
y = torch.index_select(x, dim=0, index=indice)

print(y)
print(y.size())

tensor([[[5., 5.],
         [6., 6.]],

        [[3., 3.],
         [4., 4.]]])
torch.Size([2, 2, 2])


### cat: Concatenate a list of tensors along an existing dimension

In [26]:
# Two (3, 3) float32 operands for the cat/stack demos, built with the
# recommended torch.tensor factory instead of the legacy torch.FloatTensor.
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]], dtype=torch.float32)
y = torch.tensor([[10, 11, 12],
                  [13, 14, 15],
                  [16, 17, 18]], dtype=torch.float32)
print(x.size(), y.size())

torch.Size([3, 3]) torch.Size([3, 3])


In [28]:
# dim=-1 is the last dimension (columns here), so concatenating two (3, 3)
# tensors side by side yields shape (3, 6).
z = torch.cat([x, y], dim=-1)
print(z)
print(z.size())

tensor([[ 1.,  2.,  3., 10., 11., 12.],
        [ 4.,  5.,  6., 13., 14., 15.],
        [ 7.,  8.,  9., 16., 17., 18.]])
torch.Size([3, 6])


In [29]:
# dim=0 concatenates along rows, placing y below x: (3, 3) + (3, 3) -> (6, 3).
z = torch.cat([x, y], dim=0)
print(z)
print(z.size())

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.],
        [13., 14., 15.],
        [16., 17., 18.]])
torch.Size([6, 3])


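### stack: Stack a list of tensors along a new dimension

`torch.stack` gives the same result as unsqueezing each tensor at `dim` and then concatenating them with `torch.cat`, as the last example in this section shows.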

In [31]:
# stack inserts a new leading dimension (dim=0 is the default, spelled out
# here), so two (3, 3) tensors become one (2, 3, 3) tensor.
z = torch.stack([x, y], dim=0)
print(z)
print(z.size())

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])


In [32]:
# stack inserts a NEW dimension at `dim`. With dim=-1 it becomes the last
# axis, pairing x[i, j] with y[i, j], so the result shape is (3, 3, 2).
z = torch.stack([x, y], dim=-1)
print(z)
print(z.size())

tensor([[[ 1., 10.],
         [ 2., 11.],
         [ 3., 12.]],

        [[ 4., 13.],
         [ 5., 14.],
         [ 6., 15.]],

        [[ 7., 16.],
         [ 8., 17.],
         [ 9., 18.]]])
torch.Size([3, 3, 2])


In [35]:
# unsqueeze inserts a size-1 dimension at position 0: (3, 3) -> (1, 3, 3).
x.unsqueeze(dim=0).shape

torch.Size([1, 3, 3])

In [36]:
# Same for y: a leading size-1 dimension turns (3, 3) into (1, 3, 3).
y.unsqueeze(dim=0).shape

torch.Size([1, 3, 3])

In [38]:
# Give each tensor a leading size-1 dimension, then concatenate along it;
# the result matches torch.stack([x, y]).
z = torch.cat([t.unsqueeze(0) for t in (x, y)], dim=0)
print(z)
print(z.size())

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])


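### Useful trick: Merge results from loop iterations

Append each iteration's tensor to a Python list, then call `torch.stack` once after the loop. This combines them into a single tensor with a new leading dimension.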

In [47]:
# Collect one tensor per iteration in a plain list, then stack them once
# after the loop; stacking adds a leading dimension of size len(result).
# torch.empty (the explicit form of the legacy torch.FloatTensor(2, 2))
# leaves the values uninitialized, so only the shapes are meaningful here.
result = []
for i in range(5):
    x = torch.empty(2, 2, dtype=torch.float32)
    result.append(x)

torch.stack(result).size()

torch.Size([5, 2, 2])