In [1]:
import numpy as np
import torch

## Tensor 的连接操作

### cat
torch.cat(tensors, dim = 0, out = None)

In [2]:
A = torch.ones(3,3)
B = 2 * torch.ones(3,3)
A, B

(tensor([[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]),
 tensor([[2., 2., 2.],
         [2., 2., 2.],
         [2., 2., 2.]]))

In [3]:
# dim=0
C = torch.cat((A,B),0)
C

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])

In [4]:
# dim=1
D = torch.cat((A,B),1)
D

tensor([[1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 2., 2., 2.]])

### stack
torch.stack(inputs, dim=0)<br>
inputs 表示需要拼接的 Tensor，dim 表示新建立维度的方向。

In [5]:
A = torch.arange(0,4)
A

tensor([0, 1, 2, 3])

In [6]:
B = torch.arange(5,9)
B

tensor([5, 6, 7, 8])

In [8]:
C = torch.stack((A,B),0)
C

tensor([[0, 1, 2, 3],
        [5, 6, 7, 8]])

In [9]:
D = torch.stack((A,B),1)
D

tensor([[0, 5],
        [1, 6],
        [2, 7],
        [3, 8]])

## Tensor 的切分操作
切分的操作主要分为三种类型：chunk、split、unbind。

### chunk
chunk 的作用就是将 Tensor 按照声明的 dim，进行尽可能平均的划分。<br>
torch.chunk(input, chunks, dim=0)

In [12]:
A = torch.tensor([1,2,3,4,5,6,7,8,9,10])
B = torch.chunk(A, 2, 0)
B

(tensor([1, 2, 3, 4, 5]), tensor([ 6,  7,  8,  9, 10]))

In [14]:
B = torch.chunk(A, 3, 0)
B

(tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([ 9, 10]))

In [15]:
A = torch.tensor([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17])
B = torch.chunk(A, 4, 0)
B

(tensor([1, 2, 3, 4, 5]),
 tensor([ 6,  7,  8,  9, 10]),
 tensor([11, 12, 13, 14, 15]),
 tensor([16, 17]))

In [16]:
A = torch.tensor([1,2,3])
B = torch.chunk(A, 5, 0)
B

(tensor([1]), tensor([2]), tensor([3]))

In [17]:
A = torch.ones(4,4)
A

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])

In [21]:
B = torch.chunk(A, 2, 0)
B

(tensor([[1., 1., 1., 1.],
         [1., 1., 1., 1.]]),
 tensor([[1., 1., 1., 1.],
         [1., 1., 1., 1.]]))

In [20]:
B = torch.chunk(A, 2, 1)
B

(tensor([[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]]),
 tensor([[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]]))

### split
torch.split(tensor, split_size_or_sections, dim=0)

In [22]:
A = torch.rand(4,4)
A

tensor([[0.8614, 0.3151, 0.6770, 0.2734],
        [0.1798, 0.6282, 0.9865, 0.3871],
        [0.6097, 0.9024, 0.0637, 0.4632],
        [0.6161, 0.0773, 0.2672, 0.4330]])

In [27]:
B = torch.split(A, 2, 0)
B

(tensor([[0.8614, 0.3151, 0.6770, 0.2734],
         [0.1798, 0.6282, 0.9865, 0.3871]]),
 tensor([[0.6097, 0.9024, 0.0637, 0.4632],
         [0.6161, 0.0773, 0.2672, 0.4330]]))

In [25]:
B = torch.split(A, 3, 0)
B

(tensor([[0.8614, 0.3151, 0.6770, 0.2734],
         [0.1798, 0.6282, 0.9865, 0.3871],
         [0.6097, 0.9024, 0.0637, 0.4632]]),
 tensor([[0.6161, 0.0773, 0.2672, 0.4330]]))

In [28]:
A = torch.rand(5,4)
A

tensor([[0.4634, 0.9227, 0.9395, 0.0556],
        [0.1426, 0.9531, 0.6741, 0.0741],
        [0.3275, 0.8230, 0.3568, 0.9161],
        [0.7750, 0.8154, 0.9317, 0.2937],
        [0.5073, 0.3438, 0.0835, 0.4672]])

In [33]:
B = torch.split(A, (2,3), 0)
B

(tensor([[0.4634, 0.9227, 0.9395, 0.0556],
         [0.1426, 0.9531, 0.6741, 0.0741]]),
 tensor([[0.3275, 0.8230, 0.3568, 0.9161],
         [0.7750, 0.8154, 0.9317, 0.2937],
         [0.5073, 0.3438, 0.0835, 0.4672]]))

In [31]:
B = torch.split(A, (1,3), 1)
B

(tensor([[0.4634],
         [0.1426],
         [0.3275],
         [0.7750],
         [0.5073]]),
 tensor([[0.9227, 0.9395, 0.0556],
         [0.9531, 0.6741, 0.0741],
         [0.8230, 0.3568, 0.9161],
         [0.8154, 0.9317, 0.2937],
         [0.3438, 0.0835, 0.4672]]))

### unbind
torch.unbind(input, dim=0)

In [34]:
A = torch.arange(0,16).view(4,4)
A

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

In [38]:
B = torch.unbind(A, 0)
B

(tensor([0, 1, 2, 3]),
 tensor([4, 5, 6, 7]),
 tensor([ 8,  9, 10, 11]),
 tensor([12, 13, 14, 15]))

In [39]:
B = torch.unbind(A, 1)
B

(tensor([ 0,  4,  8, 12]),
 tensor([ 1,  5,  9, 13]),
 tensor([ 2,  6, 10, 14]),
 tensor([ 3,  7, 11, 15]))

## Tensor 的索引操作
最常用的两个操作就是 index_select 和 masked_select

### index_select
torch.index_select(tensor, dim, index)index_select

In [40]:
A = torch.arange(0,16).view(4,4)
A

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

In [41]:
B = torch.index_select(A, 0, torch.tensor([1,3]))
B

tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]])

In [45]:
C = torch.index_select(A, 1, torch.tensor([0,3]))
C

tensor([[ 0,  3],
        [ 4,  7],
        [ 8, 11],
        [12, 15]])

### masked_select
torch.masked_select(input, mask, out=None) 

In [46]:
A = torch.rand(5)
A

tensor([0.0280, 0.3388, 0.0150, 0.5246, 0.3163])

In [47]:
B = A > 0.3
B

tensor([False,  True, False,  True,  True])

In [48]:
C = torch.masked_select(A, B)
C

tensor([0.3388, 0.5246, 0.3163])

In [49]:
A = torch.rand(5)
C = torch.masked_select(A, A>0.3)
C

tensor([0.7191, 0.6315, 0.9283])