# PyTorch Tensor Slicing and Concatenation

In [1]:
import torch

## Slicing and Concatenation

### Indexing and Slicing

Prepare target tensor.

In [2]:
x = torch.FloatTensor([[[1, 2],
                        [3, 4]],
                       [[5, 6],
                        [7, 8]],
                       [[9, 10],
                        [11, 12]]])
print(x.size())

torch.Size([3, 2, 2])


In [7]:
## Access to certain dimension
print(x[0])
print(x[0, :])
print(x[0, :, :])

tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]])


In [11]:
print(x[0, 1, :])
print(x[0, :, 1])

tensor([3., 4.])
tensor([2., 4.])


In [12]:
print(x[:, 0, :])

tensor([[ 1.,  2.],
        [ 5.,  6.],
        [ 9., 10.]])


In [15]:
print(x[:, :, 0])

tensor([[ 1.,  3.],
        [ 5.,  7.],
        [ 9., 11.]])


In [26]:
## Access by range. Note that the number of dimensions would not be changed.

print(x[1:3, :, :].size())
print(x[:, :1, :].size())
print(x[:, :-1, :].size())

torch.Size([2, 2, 2])
torch.Size([3, 1, 2])
torch.Size([3, 1, 2])


---
## Split: Split tensor to desirable shapes.

In [27]:
x = torch.FloatTensor(10, 4)

In [28]:
x

tensor([[ 0.0000e+00,  3.6893e+19,  2.6717e-14, -2.8615e-42],
        [-2.0313e-29,  4.5786e-41, -2.0385e-29,  4.5786e-41],
        [ 1.7125e+22, -1.2097e+05, -2.7205e-32,  4.5786e-41],
        [-3.8749e-32,  4.5786e-41, -5.2378e+26,  6.9190e-31],
        [-1.0264e-29,  4.5786e-41, -1.0281e-29,  4.5786e-41],
        [ 1.5852e-38, -3.5502e-30, -7.1089e-30,  4.5786e-41],
        [-7.4702e-30,  4.5786e-41,  5.2871e-24, -1.2590e-30],
        [-7.8829e-30,  4.5786e-41, -9.5723e-30,  4.5786e-41],
        [-8.4507e-16,  3.6548e+11, -7.1754e-30,  4.5786e-41],
        [-7.2930e-30,  4.5786e-41,  5.0865e+17, -2.2052e+24]])

In [29]:
splits = x.split(4, dim=0)

for s in splits:
    print(s.size())

torch.Size([4, 4])
torch.Size([4, 4])
torch.Size([2, 4])


---
## Chunk: Split tensor to number of chunks.

In [30]:
x = torch.FloatTensor(10, 4)

In [31]:
x

tensor([[-7.1813e-28,  4.5786e-41, -7.1813e-28,  4.5786e-41],
        [-7.1814e-28,  4.5786e-41, -7.1469e-28,  4.5786e-41],
        [-7.1831e-28,  4.5786e-41, -7.1831e-28,  4.5786e-41],
        [-7.1832e-28,  4.5786e-41, -7.1814e-28,  4.5786e-41],
        [-7.1832e-28,  4.5786e-41, -7.5843e-29,  4.5786e-41],
        [ 0.0000e+00,  3.6893e+19,  2.0172e-14, -2.5250e-29],
        [-7.5731e-29,  4.5786e-41, -7.5845e-29,  4.5786e-41],
        [-7.5822e-29,  4.5786e-41, -7.5848e-29,  4.5786e-41],
        [-7.5823e-29,  4.5786e-41, -7.5786e-29,  4.5786e-41],
        [-7.5786e-29,  4.5786e-41, -6.8993e-29,  4.5786e-41]])

In [33]:
chunks = x.chunk(5, dim=0)

for c in chunks:
    print(c.size())

torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])


---
## index_select: Select elements by using dimension index.

In [52]:
x = torch.FloatTensor([[[1, 1],
                        [2, 2]],
                       [[3, 3],
                        [4, 4]],
                       [[5, 5],
                        [6, 6]]])
indices = torch.LongTensor([2, 1, 0])

print(x.size())

torch.Size([3, 2, 2])


In [53]:
y = x.index_select(dim=0, index=indices)

print(y)
print(y.size())

tensor([[[5., 5.],
         [6., 6.]],

        [[3., 3.],
         [4., 4.]],

        [[1., 1.],
         [2., 2.]]])
torch.Size([3, 2, 2])


---
## cat: Concatenation of multiple tensors in the list.

In [54]:
x = torch.FloatTensor([[1, 2, 3],
                       [4, 5, 6], 
                       [7, 8, 9]])
y = torch.FloatTensor([[10, 11, 12],
                       [13, 14, 15],
                       [16, 17, 18]])
print(x.size(), y.size())

torch.Size([3, 3]) torch.Size([3, 3])


In [59]:
z = torch.cat([x, y], dim=0)
print(z)
print(z.size())

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.],
        [13., 14., 15.],
        [16., 17., 18.]])
torch.Size([6, 3])


In [60]:
z = torch.cat([y,x], dim=-1)
print(z)
print(z.size())

tensor([[10., 11., 12.,  1.,  2.,  3.],
        [13., 14., 15.,  4.,  5.,  6.],
        [16., 17., 18.,  7.,  8.,  9.]])
torch.Size([3, 6])


---
## Stack: Stacking of multiple tensors in the list.

In [71]:
for n in range(3):
    z = torch.stack([x, y], dim=n)
    print(z)
    print(z.size())
    print('')

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])

tensor([[[ 1.,  2.,  3.],
         [10., 11., 12.]],

        [[ 4.,  5.,  6.],
         [13., 14., 15.]],

        [[ 7.,  8.,  9.],
         [16., 17., 18.]]])
torch.Size([3, 2, 3])

tensor([[[ 1., 10.],
         [ 2., 11.],
         [ 3., 12.]],

        [[ 4., 13.],
         [ 5., 14.],
         [ 6., 15.]],

        [[ 7., 16.],
         [ 8., 17.],
         [ 9., 18.]]])
torch.Size([3, 3, 2])



---
## Implement 'stack' function by using 'cat'

In [72]:
for n in range(3):
    z = torch.cat([x.unsqueeze(n), y.unsqueeze(n)], dim=n)
    print(z)
    print(z.size())
    print('')

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])

tensor([[[ 1.,  2.,  3.],
         [10., 11., 12.]],

        [[ 4.,  5.,  6.],
         [13., 14., 15.]],

        [[ 7.,  8.,  9.],
         [16., 17., 18.]]])
torch.Size([3, 2, 3])

tensor([[[ 1., 10.],
         [ 2., 11.],
         [ 3., 12.]],

        [[ 4., 13.],
         [ 5., 14.],
         [ 6., 15.]],

        [[ 7., 16.],
         [ 8., 17.],
         [ 9., 18.]]])
torch.Size([3, 3, 2])



---
## Useful Trick: Merge results from iterations

In [87]:
result = []
for i in range(5):
    x = torch.FloatTensor(2, 2)
    result += [x]

result = torch.stack(result)
result

tensor([[[ 0.0000e+00,  3.6893e+19],
         [ 1.9786e-14, -1.5849e+29]],

        [[-1.0842e-19,  1.0737e+01],
         [-1.3146e-28,  4.5786e-41]],

        [[ 1.0737e+08,  3.5101e+06],
         [ 1.3146e-28,  4.5786e-41]],

        [[ 0.0000e+00,  4.4766e+00],
         [ 1.3146e-28,  4.5786e-41]],

        [[ 0.0000e+00,  3.6893e+19],
         [ 1.9811e-14,  2.5250e-29]]])

In [89]:
result = []
for i in range(5):
    x = torch.FloatTensor(2, 2)
    result += [x]
    
for i in range(len(result)):
    result[i] = result[i].unsqueeze(0)
    
result = torch.cat(result, dim=0)
result

tensor([[[ 0.0000e+00,  3.6893e+19],
         [ 1.9865e-14, -2.0005e+00]],

        [[ 0.0000e+00,         nan],
         [-1.3146e-28,  4.5786e-41]],

        [[        inf,  3.5101e+06],
         [ 1.3146e-28,  4.5786e-41]],

        [[ 0.0000e+00,  1.7155e-05],
         [ 2.3694e-38,  4.5786e-41]],

        [[ 2.3694e-38,  3.8177e-05],
         [ 8.0000e+00,  6.0000e+00]]])