In [1]:
# -*- coding: utf-8 -*-

'''
@Author   :   Corley Tang
@contact  :   cutercorleytd@gmail.com
@Github   :   https://github.com/corleytd
@Time     :   2023-01-10 15:03
@Project  :   Hands-on Deep Learning with PyTorch-tensor_index_slice_merge
张量的索引、分片、合并以及维度调整
'''

# 导入所需的库
import torch

## 1.张量的符号索引

In [2]:
# 1. Indexing a 1-D tensor
# Values 0, 3, 6, ..., 27 -- `end` is exclusive, so 30 itself is not included
t1 = torch.arange(start=0, end=30, step=3)
t1

tensor([ 0,  3,  6,  9, 12, 15, 18, 21, 24, 27])

In [3]:
(t1[0], t1[3])  # integer indexing yields 0-dim tensors, not Python numbers

(tensor(0), tensor(9))

In [4]:
# Slicing: the start index is included, the stop index is excluded
t1[slice(2, 5)]

tensor([ 6,  9, 12])

In [5]:
# Slicing with a step
# Tensor indexing does not accept negative steps: t1[8:1:-1] raises
# ValueError (step must be greater than zero).
t1[slice(1, 8, 2)]

tensor([ 3,  9, 15, 21])

In [6]:
(t1[1::2], t1[:8:2])  # an omitted start/stop defaults to the beginning/end of the dimension

(tensor([ 3,  9, 15, 21, 27]), tensor([ 0,  6, 12, 18]))

In [7]:
# 2. Indexing a 2-D tensor
# The integers 1..24 laid out as 4 rows x 6 columns
t2 = torch.reshape(torch.arange(1, 25), (4, 6))
t2

tensor([[ 1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12],
        [13, 14, 15, 16, 17, 18],
        [19, 20, 21, 22, 23, 24]])

In [8]:
(
    t2[1, 2],          # a single element -> 0-dim tensor
    t2[1, ::2],        # row 1, every second column (slice)
    t2[1, [0, 2, 4]],  # row 1, columns chosen by an index list
    t2[[0, 2], 1],     # rows 0 and 2, column 1
    t2[::2, ::2],      # every second row and every second column
)

(tensor(9),
 tensor([ 7,  9, 11]),
 tensor([ 7,  9, 11]),
 tensor([ 2, 14]),
 tensor([[ 1,  3,  5],
         [13, 15, 17]]))

In [9]:
# 3. Indexing a 3-D tensor
# The integers 1..120 arranged as 4 blocks x 5 rows x 6 columns
t3 = torch.reshape(torch.arange(1, 121), (4, 5, 6))
t3

tensor([[[  1,   2,   3,   4,   5,   6],
         [  7,   8,   9,  10,  11,  12],
         [ 13,  14,  15,  16,  17,  18],
         [ 19,  20,  21,  22,  23,  24],
         [ 25,  26,  27,  28,  29,  30]],

        [[ 31,  32,  33,  34,  35,  36],
         [ 37,  38,  39,  40,  41,  42],
         [ 43,  44,  45,  46,  47,  48],
         [ 49,  50,  51,  52,  53,  54],
         [ 55,  56,  57,  58,  59,  60]],

        [[ 61,  62,  63,  64,  65,  66],
         [ 67,  68,  69,  70,  71,  72],
         [ 73,  74,  75,  76,  77,  78],
         [ 79,  80,  81,  82,  83,  84],
         [ 85,  86,  87,  88,  89,  90]],

        [[ 91,  92,  93,  94,  95,  96],
         [ 97,  98,  99, 100, 101, 102],
         [103, 104, 105, 106, 107, 108],
         [109, 110, 111, 112, 113, 114],
         [115, 116, 117, 118, 119, 120]]])

In [10]:
(
    t3[2, 3, 4],          # a single element
    t3[2, ::2, 1::2],     # block 2, rows 0/2/4, columns 1/3/5
    t3[::2, 1::2, 2::2],  # strided along all three dimensions
)

(tensor(83),
 tensor([[62, 64, 66],
         [74, 76, 78],
         [86, 88, 90]]),
 tensor([[[ 9, 11],
          [21, 23]],
 
         [[69, 71],
          [81, 83]]]))

## 2.张量的函数索引

In [11]:
(t1.ndim, t1)  # t1 is 1-D, so dim 0 (equivalently -1) is its only axis

(1, tensor([ 0,  3,  6,  9, 12, 15, 18, 21, 24, 27]))

In [12]:
# index_select: pick the entries of t1 along dim 0 at the positions in `indices`
indices = torch.tensor([1, 3])
# Equivalent to torch.index_select(t1, -1, indices) (t1 is 1-D) and to t1[[1, 3]]
torch.index_select(t1, 0, indices)

tensor([3, 9])

In [13]:
(t2.shape, t2)  # inspect the 2-D tensor before selecting along each dimension

(torch.Size([4, 6]),
 tensor([[ 1,  2,  3,  4,  5,  6],
         [ 7,  8,  9, 10, 11, 12],
         [13, 14, 15, 16, 17, 18],
         [19, 20, 21, 22, 23, 24]]))

In [14]:
# dim 0 picks rows 1 and 3; dim 1 picks columns 1 and 3
(t2.index_select(0, indices), t2.index_select(1, indices))

(tensor([[ 7,  8,  9, 10, 11, 12],
         [19, 20, 21, 22, 23, 24]]),
 tensor([[ 2,  4],
         [ 8, 10],
         [14, 16],
         [20, 22]]))

## 3.torch.view()方法
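PyTorch中的view()方法会返回一个类似视图的结果，该结果和原张量对象共享一块数据存储空间，并且通过view()方法，还可以改变对象结构，生成一个不同结构，但共享一个存储空间的张量。当然，共享一个存储空间，也就代表着是**浅拷贝**的关系，修改其中一个，另一个也会同步进行更改。视图的核心作用就是节省空间。后文中张量的切分（chunk、split）得到的结果也都是视图；而张量的合并（cat、stack）会分配新的存储空间，得到的是新的张量，而不是视图。注意，view()要求张量在内存中是连续的（contiguous），否则会报错，此时可以改用reshape()。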

In [15]:
# view() reinterprets t2's storage with a 3 x 8 shape; no data is copied
t2_view_2 = t2.view((3, 8))
(t2, t2_view_2)

(tensor([[ 1,  2,  3,  4,  5,  6],
         [ 7,  8,  9, 10, 11, 12],
         [13, 14, 15, 16, 17, 18],
         [19, 20, 21, 22, 23, 24]]),
 tensor([[ 1,  2,  3,  4,  5,  6,  7,  8],
         [ 9, 10, 11, 12, 13, 14, 15, 16],
         [17, 18, 19, 20, 21, 22, 23, 24]]))

In [16]:
# Write 20 into every second column of row 1; the view sees the change as well
t2[1, 0::2] = 20
(t2, t2_view_2)

(tensor([[ 1,  2,  3,  4,  5,  6],
         [20,  8, 20, 10, 20, 12],
         [13, 14, 15, 16, 17, 18],
         [19, 20, 21, 22, 23, 24]]),
 tensor([[ 1,  2,  3,  4,  5,  6, 20,  8],
         [20, 10, 20, 12, 13, 14, 15, 16],
         [17, 18, 19, 20, 21, 22, 23, 24]]))

In [17]:
# A 3-D view (3 x 2 x 4) over the same storage as t2
t2_view_3 = t2.view((3, 2, 4))
(t2, t2_view_3)

(tensor([[ 1,  2,  3,  4,  5,  6],
         [20,  8, 20, 10, 20, 12],
         [13, 14, 15, 16, 17, 18],
         [19, 20, 21, 22, 23, 24]]),
 tensor([[[ 1,  2,  3,  4],
          [ 5,  6, 20,  8]],
 
         [[20, 10, 20, 12],
          [13, 14, 15, 16]],
 
         [[17, 18, 19, 20],
          [21, 22, 23, 24]]]))

In [18]:
# Writing through the 3-D view also changes the underlying t2
t2_view_3[0::2, 1, 2:] = 30
(t2, t2_view_3)

(tensor([[ 1,  2,  3,  4,  5,  6],
         [30, 30, 20, 10, 20, 12],
         [13, 14, 15, 16, 17, 18],
         [19, 20, 21, 22, 30, 30]]),
 tensor([[[ 1,  2,  3,  4],
          [ 5,  6, 30, 30]],
 
         [[20, 10, 20, 12],
          [13, 14, 15, 16]],
 
         [[17, 18, 19, 20],
          [21, 22, 30, 30]]]))

## 4.张量的分片函数

In [19]:
# 1. chunk: split a tensor into a given number of pieces along one dimension.
# Each piece is a view of the original and keeps its number of dimensions.
t2_chunked = t2.chunk(4, dim=0)
(t2, t2_chunked, t2_chunked[0][0])

(tensor([[ 1,  2,  3,  4,  5,  6],
         [30, 30, 20, 10, 20, 12],
         [13, 14, 15, 16, 17, 18],
         [19, 20, 21, 22, 30, 30]]),
 (tensor([[1, 2, 3, 4, 5, 6]]),
  tensor([[30, 30, 20, 10, 20, 12]]),
  tensor([[13, 14, 15, 16, 17, 18]]),
  tensor([[19, 20, 21, 22, 30, 30]])),
 tensor([1, 2, 3, 4, 5, 6]))

In [20]:
# A chunk is a view, so writing into it modifies t2 too
t2_chunked[0][0][0::2] = 10
(t2, t2_chunked, t2_chunked[0][0])

(tensor([[10,  2, 10,  4, 10,  6],
         [30, 30, 20, 10, 20, 12],
         [13, 14, 15, 16, 17, 18],
         [19, 20, 21, 22, 30, 30]]),
 (tensor([[10,  2, 10,  4, 10,  6]]),
  tensor([[30, 30, 20, 10, 20, 12]]),
  tensor([[13, 14, 15, 16, 17, 18]]),
  tensor([[19, 20, 21, 22, 30, 30]])),
 tensor([10,  2, 10,  4, 10,  6]))

In [21]:
# When the dimension does not divide evenly, chunk() does not raise. The last
# piece may be smaller, and fewer pieces than requested can come back. Here,
# 4 rows into 3 requested chunks gives 2 chunks of 2 rows, and 6 columns into
# 7 requested chunks gives 6 chunks of 1 column.
(t2.chunk(3, dim=0), t2.chunk(7, dim=1))

((tensor([[10,  2, 10,  4, 10,  6],
          [30, 30, 20, 10, 20, 12]]),
  tensor([[13, 14, 15, 16, 17, 18],
          [19, 20, 21, 22, 30, 30]])),
 (tensor([[10],
          [30],
          [13],
          [19]]),
  tensor([[ 2],
          [30],
          [14],
          [20]]),
  tensor([[10],
          [20],
          [15],
          [21]]),
  tensor([[ 4],
          [10],
          [16],
          [22]]),
  tensor([[10],
          [20],
          [17],
          [30]]),
  tensor([[ 6],
          [12],
          [18],
          [30]])))

In [22]:
# 2. split: like chunk it returns views, but it also supports custom piece sizes.
# With an int, the argument is the size of each piece (not the number of pieces):
# 4 rows split by 2 gives two pieces of 2 rows each.
t2_splitted = t2.split(2, dim=0)
t2_splitted

(tensor([[10,  2, 10,  4, 10,  6],
         [30, 30, 20, 10, 20, 12]]),
 tensor([[13, 14, 15, 16, 17, 18],
         [19, 20, 21, 22, 30, 30]]))

In [23]:
# With a list, each entry is the size of one piece. The sizes must add up to the
# length of that dimension; torch.split(t2, [3, 2], 0) raises RuntimeError.
(
    t2.split([3, 1], dim=0),
    t2.split([1, 2, 1], dim=0),
    t2.split([1, 2, 1, 2], dim=1),
)

((tensor([[10,  2, 10,  4, 10,  6],
          [30, 30, 20, 10, 20, 12],
          [13, 14, 15, 16, 17, 18]]),
  tensor([[19, 20, 21, 22, 30, 30]])),
 (tensor([[10,  2, 10,  4, 10,  6]]),
  tensor([[30, 30, 20, 10, 20, 12],
          [13, 14, 15, 16, 17, 18]]),
  tensor([[19, 20, 21, 22, 30, 30]])),
 (tensor([[10],
          [30],
          [13],
          [19]]),
  tensor([[ 2, 10],
          [30, 20],
          [14, 15],
          [20, 21]]),
  tensor([[ 4],
          [10],
          [16],
          [22]]),
  tensor([[10,  6],
          [20, 12],
          [17, 18],
          [30, 30]])))

In [24]:
# Split pieces are views: writing into one also updates t2
t2_splitted[0][1:, 2::2] = 40
(t2, t2_splitted)

(tensor([[10,  2, 10,  4, 10,  6],
         [30, 30, 40, 10, 40, 12],
         [13, 14, 15, 16, 17, 18],
         [19, 20, 21, 22, 30, 30]]),
 (tensor([[10,  2, 10,  4, 10,  6],
          [30, 30, 40, 10, 40, 12]]),
  tensor([[13, 14, 15, 16, 17, 18],
          [19, 20, 21, 22, 30, 30]])))

## 5.张量的合并操作
张量的合并操作类似于列表的追加元素，可以拼接、也可以堆叠。

In [25]:
# 1. cat: join tensors end-to-end along an existing dimension
a = torch.zeros((2, 3))
b = torch.ones((2, 3))
c = torch.zeros((3, 3))
(a, b, c)

(tensor([[0., 0., 0.],
         [0., 0., 0.]]),
 tensor([[1., 1., 1.],
         [1., 1., 1.]]),
 tensor([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]))

In [26]:
torch.cat((a, b))  # dim defaults to 0, so rows are appended: (2, 3) + (2, 3) -> (4, 3)

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.],
        [1., 1., 1.]])

In [27]:
# dim=1 joins side by side: (2, 3) + (2, 3) -> (2, 6).
# Every other dimension must match, so torch.cat([a, c], dim=1) raises
# RuntimeError (2 rows vs 3 rows).
torch.cat((a, b), dim=1)

tensor([[0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.]])

In [28]:
# 2. stack: instead of joining along an existing dimension, stack inserts a new
# dimension and places each input tensor along it
(a.shape, b.shape, c.shape)

(torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([3, 3]))

In [29]:
# All inputs must have exactly the same shape; torch.stack([a, c]) raises RuntimeError.
# Stacking two (2, 3) tensors produces a single (2, 2, 3) tensor.
ab_stacked = torch.stack((a, b), dim=0)
(ab_stacked.shape, ab_stacked)

(torch.Size([2, 2, 3]),
 tensor([[[0., 0., 0.],
          [0., 0., 0.]],
 
         [[1., 1., 1.],
          [1., 1., 1.]]]))

二者区别：
- 拼接之后维度不变，堆叠之后维度升高
- 拼接（cat）是沿着某个已有维度把张量首尾相接，结果的维度数不变（两个二维张量拼接后仍是二维张量）；而堆叠（stack）则是新增一个维度，直接将两个二维张量封装到一个三维张量中
- 堆叠的要求更高，参与堆叠的张量必须形状完全相同；而拼接只要求除拼接维度以外的其他维度形状相同
## 6.张量维度变换
通过reshape方法，能够灵活调整张量的形状；而在实际操作张量进行计算时，往往还需要另外进行降维和升维的操作：
- 当我们需要除去不必要的维度时，可以使用squeeze函数
- 需要手动升维时，则可采用unsqueeze函数

In [30]:
# A 4-D tensor whose only dimension larger than 1 holds 3 elements
t4 = torch.ones((1, 1, 3, 1))
(t4.ndim, t4.shape, t4)

(4,
 torch.Size([1, 1, 3, 1]),
 tensor([[[[1.],
           [1.],
           [1.]]]]))

In [31]:
# 1. squeeze: drop every dimension of size 1 (same as torch.squeeze(t4))
t4_squeezed = t4.squeeze()
(t4_squeezed.shape, t4_squeezed)

(torch.Size([3]), tensor([1., 1., 1.]))

In [32]:
# A 6-D tensor with size-1 dimensions scattered between the real ones
t5 = torch.ones((1, 1, 3, 2, 1, 2))
(t5.ndim, t5.shape, t5)

(6,
 torch.Size([1, 1, 3, 2, 1, 2]),
 tensor([[[[[[1., 1.]],
 
            [[1., 1.]]],
 
 
           [[[1., 1.]],
 
            [[1., 1.]]],
 
 
           [[[1., 1.]],
 
            [[1., 1.]]]]]]))

In [33]:
# Every size-1 dimension is removed, wherever it sits
t5_squeezed = t5.squeeze()
(t5_squeezed.shape, t5_squeezed)

(torch.Size([3, 2, 2]),
 tensor([[[1., 1.],
          [1., 1.]],
 
         [[1., 1.],
          [1., 1.]],
 
         [[1., 1.],
          [1., 1.]]]))

In [34]:
# 2. unsqueeze: manually insert a new dimension of size 1
t6 = torch.full(size=(1, 2, 1, 3), fill_value=5)
(t6.shape, t6)

(torch.Size([1, 2, 1, 3]),
 tensor([[[[5, 5, 5]],
 
          [[5, 5, 5]]]]))

In [35]:
# Insert the new size-1 dimension at position 0 (same as torch.unsqueeze(t6, dim=0))
t6_unsqueezed = t6.unsqueeze(0)
(t6_unsqueezed.shape, t6_unsqueezed)

(torch.Size([1, 1, 2, 1, 3]),
 tensor([[[[[5, 5, 5]],
 
           [[5, 5, 5]]]]]))

In [36]:
# dim=2 inserts before the original dim 2; dim=4 (== t6.ndim) appends a trailing size-1 dimension
(t6.unsqueeze(dim=2).shape, t6.unsqueeze(dim=4).shape)

(torch.Size([1, 2, 1, 1, 3]), torch.Size([1, 2, 1, 3, 1]))