## Pythorch Tensor Slicing and Concatenation

### Slicing and Concatenation

#### indexing and Slicing

In [1]:
import torch

In [2]:
dim = 3
row = 4
col = 5 
data = [[[i*(row*col) + j*col + k for k in range(col)] for j in range(row)] for i in range(dim)]
x = torch.FloatTensor(data)
print(f"x:{x}, size:{x.size()}")

x:tensor([[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.],
         [10., 11., 12., 13., 14.],
         [15., 16., 17., 18., 19.]],

        [[20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34.],
         [35., 36., 37., 38., 39.]],

        [[40., 41., 42., 43., 44.],
         [45., 46., 47., 48., 49.],
         [50., 51., 52., 53., 54.],
         [55., 56., 57., 58., 59.]]]), size:torch.Size([3, 4, 5])


In [6]:
## Slicing 
print(f"x[0]: {x[0]}, size:{x[0].size()}")
print(f"x[0,:]: {x[0,:]}, size:{x[0,:].size()}")
print(f"x[0, :, :]: {x[0, :, :]}" )

x[0]: tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.],
        [15., 16., 17., 18., 19.]]), size:torch.Size([4, 5])
x[0,:]: tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.],
        [15., 16., 17., 18., 19.]]), size:torch.Size([4, 5])
x[0, :, :]: tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.],
        [15., 16., 17., 18., 19.]])


In [11]:
## Slicing 
print(f"x[-1]:  {x[-1]}")
print(f"x[-1, :] = \n{x[-1, :]}")
print(f"x[-1, :, :] = \n{x[-1, :, :]}")

x[-1]:  tensor([[40., 41., 42., 43., 44.],
        [45., 46., 47., 48., 49.],
        [50., 51., 52., 53., 54.],
        [55., 56., 57., 58., 59.]])
x[-1, :] = 
tensor([[40., 41., 42., 43., 44.],
        [45., 46., 47., 48., 49.],
        [50., 51., 52., 53., 54.],
        [55., 56., 57., 58., 59.]])
x[-1, :, :] = 
tensor([[40., 41., 42., 43., 44.],
        [45., 46., 47., 48., 49.],
        [50., 51., 52., 53., 54.],
        [55., 56., 57., 58., 59.]])


In [None]:
## Slicing - 2번째 dimension의 0번째 index 모두 가져와라
print(f"x[:, 0, :] = \n{x[:, 0, :]}\nsize: {x[:, 0, :].size()}")

x[:, 0, :] = 
tensor([[ 0.,  1.,  2.,  3.,  4.],
        [20., 21., 22., 23., 24.],
        [40., 41., 42., 43., 44.]])
size: torch.Size([3, 5])


In [16]:
## Range Access
## x[1:3, :, :] 2번째 index부터 3번째 index까지 가져와라
print(f"x[1:3, :, :] = \n{x[1:3, :, :]}\nsize: {x[1:3, :, :].size()}")
##Range의 특징은 1개일지라도 dimension이 없어지지 않는다.
print(f"x[:, :1, :] = \n{x[:, :1, :]}\nsize: {x[:, :1, :].size()}") 
print(f"x[:, :-1, :] = \n{x[:, :-1, :]}\nsize: {x[:, :-1, :].size()}")

x[1:3, :, :] = 
tensor([[[20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34.],
         [35., 36., 37., 38., 39.]],

        [[40., 41., 42., 43., 44.],
         [45., 46., 47., 48., 49.],
         [50., 51., 52., 53., 54.],
         [55., 56., 57., 58., 59.]]])
size: torch.Size([2, 4, 5])
x[:, :1, :] = 
tensor([[[ 0.,  1.,  2.,  3.,  4.]],

        [[20., 21., 22., 23., 24.]],

        [[40., 41., 42., 43., 44.]]])
size: torch.Size([3, 1, 5])
x[:, :-1, :] = 
tensor([[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.],
         [10., 11., 12., 13., 14.]],

        [[20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34.]],

        [[40., 41., 42., 43., 44.],
         [45., 46., 47., 48., 49.],
         [50., 51., 52., 53., 54.]]])
size: torch.Size([3, 3, 5])


### Split: Split Tensor to desirable shape.

In [17]:
x = torch.FloatTensor(10, 4) ## 10x4 Tensor
splits = x.split(4, dim=0)

for s in splits:
    print(f"s:{s}, size:{s.size()}")

s:tensor([[1.0124e-20, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.4000e+01, 2.5000e+01, 2.6000e+01, 2.7000e+01],
        [2.8000e+01, 2.9000e+01, 3.0000e+01, 3.1000e+01],
        [3.2000e+01, 3.3000e+01, 3.4000e+01, 3.5000e+01]]), size:torch.Size([4, 4])
s:tensor([[36., 37., 38., 39.],
        [40., 41., 42., 43.],
        [44., 45., 46., 47.],
        [48., 49., 50., 51.]]), size:torch.Size([4, 4])
s:tensor([[52., 53., 54., 55.],
        [56., 57., 58., 59.]]), size:torch.Size([2, 4])


### Chunk: Split tensor to number of Chunks

In [18]:
x = torch.LongTensor(8, 4)
chunks = x.chunk(3, dim=0)
for c in chunks:
    print(f"c:{c}, size:{c.size()}")

c:tensor([[             125482,                   0,                   0,
                           0],
        [                  0, 7235419174270214779, 3689399400580266530,
         3618985543144124774],
        [3487302567453931875, 3617854187786811494, 3832949646228141872,
         3907210438864024369]]), size:torch.Size([3, 4])
c:tensor([[6874590268765514786, 2459029315949459828, 6874028453778520165,
         2318265813094655346],
        [7881702260482471202, 7739546057368543845, 8296228837497728609,
         4189032028197581669],
        [3559588851250176544, 4135204104193270369, 4063988925762725222,
         3976733675042058289]]), size:torch.Size([3, 4])
c:tensor([[7215364926928664931, 3612485285161825377, 3615606755904795184,
         4193750903543714864],
        [4121128130547953973, 7311068566657325624, 2322206415573840754,
         7233200062374753570]]), size:torch.Size([2, 4])


### index_select: Select elements by using dimension index.

In [22]:
x = torch.FloatTensor(data)
indice = torch.LongTensor([2,1])
print(f"x.size(): {x.size()}")


x.size(): torch.Size([3, 4, 5])


In [23]:
y = x.index_select(dim=0, index=indice)
print(f"y:{y}, size: {y.size()}")

y:tensor([[[40., 41., 42., 43., 44.],
         [45., 46., 47., 48., 49.],
         [50., 51., 52., 53., 54.],
         [55., 56., 57., 58., 59.]],

        [[20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34.],
         [35., 36., 37., 38., 39.]]]), size: torch.Size([2, 4, 5])


### CAT: Concatenation of Multiple Tensors in the list.

In [24]:
x = torch.FloatTensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])
y = torch.FloatTensor([[10, 11, 12],
                    [13, 14, 15],
                    [16, 17, 18]])
print(f"x.size(): {x.size()}, y.size(): {y.size()}")

x.size(): torch.Size([3, 3]), y.size(): torch.Size([3, 3])


In [25]:
z = torch.cat([x, y], dim=0)
print(f"z:{z}, size: {z.size()}")

z:tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.],
        [13., 14., 15.],
        [16., 17., 18.]]), size: torch.Size([6, 3])


In [26]:
z = torch.cat([x, y], dim=-1)
print(f"z:{z}, size: {z.size()}")

z:tensor([[ 1.,  2.,  3., 10., 11., 12.],
        [ 4.,  5.,  6., 13., 14., 15.],
        [ 7.,  8.,  9., 16., 17., 18.]]), size: torch.Size([3, 6])


### Stack: Stacking of multiple tensors in the list

In [None]:
## 쌓는다?! = 차원이 하나 더 생긴다.
z = torch.stack([x, y], dim=0)
## step1) 3x3 -> 1x3x3
## step2) 1x3x3 matrix -> 2x3x3 matrix
print(f"z:{z}, size: {z.size()}")


z:tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]]), size: torch.Size([2, 3, 3])


In [30]:
z = torch.stack([x, y], dim=-1)
## step1) 3x3 -> 3x3x1
## step2) 3x3x1 matrix -> 3x3x2 matrix
print(f"z:{z}, size: {z.size()}")

z:tensor([[[ 1., 10.],
         [ 2., 11.],
         [ 3., 12.]],

        [[ 4., 13.],
         [ 5., 14.],
         [ 6., 15.]],

        [[ 7., 16.],
         [ 8., 17.],
         [ 9., 18.]]]), size: torch.Size([3, 3, 2])


#### Implement 'stack' function by using cat.

In [32]:
z = torch.cat([x.unsqueeze(0), y.unsqueeze(0)], dim=0)
print(f"z:{z}, size: {z.size()}")

z:tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]]), size: torch.Size([2, 3, 3])


### Useful Trick: Merge Results from iterations

In [None]:
result = []
for i in range(5):
    x = torch.FloatTensor(2,2)
    result += [x]
    
result = torch.stack(result)
result.size()  # torch.Size([5, 2, 2])