##### 设置随机数种子

In [1]:
import torch
import numpy as np
import random
import torch.nn.functional as fn

def setup_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), then configures cuDNN for reproducible runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Only pick deterministic cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner can select a different algorithm on each run, so it
    # has to be switched off as well for results to be repeatable.
    torch.backends.cudnn.benchmark = False


setup_seed(0)

### 实现 Attention-Based Conv2d Pruning

In [2]:
# x: random stand-in for a 4-D weight tensor; z: all-ones tensor of the same
# shape, used alongside x in every following step.
x = torch.randn(64, 4, 3, 3)
z = torch.ones_like(x)
x, z, x.shape, z.shape

(tensor([[[[-1.1258e+00, -1.1524e+00, -2.5058e-01],
           [-4.3388e-01,  8.4871e-01,  6.9201e-01],
           [-3.1601e-01, -2.1152e+00,  3.2227e-01]],
 
          [[-1.2633e+00,  3.4998e-01,  3.0813e-01],
           [ 1.1984e-01,  1.2377e+00,  1.1168e+00],
           [-2.4728e-01, -1.3527e+00, -1.6959e+00]],
 
          [[ 5.6665e-01,  7.9351e-01,  5.9884e-01],
           [-1.5551e+00, -3.4136e-01,  1.8530e+00],
           [ 7.5019e-01, -5.8550e-01, -1.7340e-01]],
 
          [[ 1.8348e-01,  1.3894e+00,  1.5863e+00],
           [ 9.4630e-01, -8.4368e-01, -6.1358e-01],
           [ 3.1593e-02, -4.9268e-01,  2.4841e-01]]],
 
 
         [[[ 4.3970e-01,  1.1241e-01,  6.4079e-01],
           [ 4.4116e-01, -1.0231e-01,  7.9244e-01],
           [-2.8967e-01,  5.2507e-02,  5.2286e-01]],
 
          [[ 2.3022e+00, -1.4689e+00, -1.5867e+00],
           [-6.7309e-01,  8.7283e-01,  1.0554e+00],
           [ 1.7784e-01, -2.3034e-01, -3.9175e-01]],
 
          [[ 5.4329e-01, -3.9516e-01, -4.46

#### 1. 对 Weight data 做形状变换（展平卷积核）并求绝对值
- x [N<sub>out</sub>, N<sub>in</sub>, kernel_size[0], kernel_size[1]]（PyTorch `nn.Conv2d` 权重的维度顺序）
- A [N<sub>out</sub>, N<sub>in</sub>, kernel_size[0] * kernel_size[1]]
- C = N<sub>out</sub>, H =  N<sub>in</sub>, W = kernel_size[0] * kernel_size[1]
- A [C, H, W]

In [3]:
# Keep the two leading sizes, collapse the two trailing (kernel) axes into a
# single axis, and work with absolute values from here on.
d11, d12 = x.shape[:2]
d21, d22 = z.shape[:2]
Ax = torch.abs(x).flatten(start_dim=2)
Az = torch.abs(z).flatten(start_dim=2)
Ax, Az, Ax.shape, Az.shape

(tensor([[[1.1258e+00, 1.1524e+00, 2.5058e-01,  ..., 3.1601e-01,
           2.1152e+00, 3.2227e-01],
          [1.2633e+00, 3.4998e-01, 3.0813e-01,  ..., 2.4728e-01,
           1.3527e+00, 1.6959e+00],
          [5.6665e-01, 7.9351e-01, 5.9884e-01,  ..., 7.5019e-01,
           5.8550e-01, 1.7340e-01],
          [1.8348e-01, 1.3894e+00, 1.5863e+00,  ..., 3.1593e-02,
           4.9268e-01, 2.4841e-01]],
 
         [[4.3970e-01, 1.1241e-01, 6.4079e-01,  ..., 2.8967e-01,
           5.2507e-02, 5.2286e-01],
          [2.3022e+00, 1.4689e+00, 1.5867e+00,  ..., 1.7784e-01,
           2.3034e-01, 3.9175e-01],
          [5.4329e-01, 3.9516e-01, 4.4622e-01,  ..., 1.5312e+00,
           1.2341e+00, 1.8197e+00],
          [5.5153e-01, 5.6925e-01, 9.1997e-01,  ..., 2.5672e+00,
           4.7312e-01, 3.3555e-01]],
 
         [[1.6293e+00, 5.4974e-01, 4.7983e-01,  ..., 1.4067e-01,
           8.0575e-01, 9.3348e-02],
          [6.8705e-01, 8.3832e-01, 8.9182e-04,  ..., 3.5815e-01,
           2.4600e-0

#### 2. 沿通道方向绝对值之和 
F(A)=∑<sub>i=1</sub><sup>C</sup> |A<sub>i</sub>|

In [5]:
# F(A): sum of the absolute weights over the leading axis, giving an (h, w) map.
# Ax / Az already hold absolute values, so a plain reduction is enough. Summing
# on the tensor itself uses each tensor's own length (not x's d11 for both) and
# keeps its dtype and device.
c, h, w = Ax.shape
FAx = Ax.sum(dim=0)
FAz = Az.sum(dim=0)
FAx, FAz

(tensor([[56.6199, 40.9421, 53.9351, 48.8748, 46.8690, 59.5128, 39.5230, 53.0775,
          52.0177],
         [57.4621, 50.7123, 49.4178, 51.8026, 54.0485, 59.0183, 42.7864, 40.8320,
          47.9534],
         [51.3775, 62.0132, 44.3578, 48.8647, 51.0990, 50.5940, 46.7558, 52.7277,
          57.0936],
         [58.7794, 51.1150, 49.7080, 53.7365, 50.6284, 46.9924, 57.3430, 52.9584,
          49.6174]]),
 tensor([[64., 64., 64., 64., 64., 64., 64., 64., 64.],
         [64., 64., 64., 64., 64., 64., 64., 64., 64.],
         [64., 64., 64., 64., 64., 64., 64., 64., 64.],
         [64., 64., 64., 64., 64., 64., 64., 64., 64.]]))

#### 3. 计算 ||F(A)||<sub>2</sub>  二范数（`torch.linalg.norm` 的结果已开平方，不是二范数的平方）

In [7]:
# With default arguments torch.linalg.norm flattens its input and returns the
# 2-norm of all elements (square root included) as a 0-dim tensor.
FAx_s = torch.linalg.norm(FAx)
FAz_s = torch.linalg.norm(FAz)
FAx_s, FAz_s

(tensor(308.5361), tensor(384.))

#### 4. 计算 F(A) / ||F(A)||<sub>2</sub> 

In [8]:
# Scale each summed map by its own norm.
Fx = FAx.div(FAx_s)
Fz = FAz.div(FAz_s)
Fx, Fz

(tensor([[0.1835, 0.1327, 0.1748, 0.1584, 0.1519, 0.1929, 0.1281, 0.1720, 0.1686],
         [0.1862, 0.1644, 0.1602, 0.1679, 0.1752, 0.1913, 0.1387, 0.1323, 0.1554],
         [0.1665, 0.2010, 0.1438, 0.1584, 0.1656, 0.1640, 0.1515, 0.1709, 0.1850],
         [0.1905, 0.1657, 0.1611, 0.1742, 0.1641, 0.1523, 0.1859, 0.1716, 0.1608]]),
 tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]]))

#### 5. 计算 F(A<sub>j</sub>) / ||F(A<sub>j</sub>)||<sub>2</sub>  和 gamma<sub>j</sub> = ∑ | F(A) / ||F(A)||<sub>2</sub> - F(A<sub>j</sub>) / ||F(A<sub>j</sub>)||<sub>2</sub> |，其中 F(A<sub>j</sub>) = F(A) - |A<sub>j</sub>|（去掉第 j 个卷积核后的和）

In [19]:
# Worked example for index 0: remove slice 0 from the summed map, re-normalise
# the remainder, and measure how far it moved from the full normalised map.
Ax0 = Ax[0]
Az0 = Az[0]
print(Ax0, Az0)

# Summed map without slice 0.
FAx0 = FAx.sub(Ax0)
FAz0 = FAz.sub(Az0)
print(FAx0, FAz0)

# Norm of the remainder.
FAx0_s = torch.linalg.norm(FAx0)
FAz0_s = torch.linalg.norm(FAz0)
print(FAx0_s, FAz0_s)

# Remainder scaled by its norm.
Fx0 = FAx0.div(FAx0_s)
Fz0 = FAz0.div(FAz0_s)
print(Fx0, Fz0)

# gamma for index 0: summed absolute difference between the two scaled maps.
gammax0 = torch.sum(torch.abs(Fx - Fx0))
gammaz0 = torch.sum(torch.abs(Fz - Fz0))
gammax0, gammaz0

tensor([[1.1258, 1.1524, 0.2506, 0.4339, 0.8487, 0.6920, 0.3160, 2.1152, 0.3223],
        [1.2633, 0.3500, 0.3081, 0.1198, 1.2377, 1.1168, 0.2473, 1.3527, 1.6959],
        [0.5667, 0.7935, 0.5988, 1.5551, 0.3414, 1.8530, 0.7502, 0.5855, 0.1734],
        [0.1835, 1.3894, 1.5863, 0.9463, 0.8437, 0.6136, 0.0316, 0.4927, 0.2484]]) tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[55.4941, 39.7897, 53.6846, 48.4409, 46.0203, 58.8208, 39.2070, 50.9622,
         51.6954],
        [56.1988, 50.3623, 49.1097, 51.6828, 52.8109, 57.9015, 42.5391, 39.4794,
         46.2575],
        [50.8109, 61.2197, 43.7589, 47.3096, 50.7576, 48.7410, 46.0056, 52.1422,
         56.9202],
        [58.5959, 49.7256, 48.1216, 52.7902, 49.7847, 46.3789, 57.3114, 52.4657,
         49.3689]]) tensor([[63., 63., 63., 63., 63., 63., 63., 63., 63.],
        [63., 63., 63., 63., 6

(tensor(0.0546), tensor(0.))

In [22]:
def _leave_one_out_gamma(A, FA, F, count):
    """Score the first ``count`` slices of ``A`` by their leave-one-out effect.

    For each index the slice ``A[idx]`` is subtracted from ``FA``, the
    remainder is divided by its ``torch.linalg.norm``, and the summed absolute
    difference to ``F`` is stored.

    Args:
        A: Tensor indexed along its first axis.
        FA: Summed map that each slice is subtracted from.
        F: Reference map the scaled remainder is compared against.
        count: Number of leading slices to score.

    Returns:
        1-D tensor of length ``count`` with one score per index.
    """
    scores = torch.zeros(count)
    for idx in range(count):
        remainder = FA - A[idx]
        remainder_scaled = remainder / torch.linalg.norm(remainder)
        scores[idx] = (F - remainder_scaled).abs().sum()
    return scores


gammax_list = _leave_one_out_gamma(Ax, FAx, Fx, c)
gammaz_list = _leave_one_out_gamma(Az, FAz, Fz, c)

gammax_list, gammaz_list

(tensor([0.0546, 0.0671, 0.0609, 0.0496, 0.0437, 0.0406, 0.0526, 0.0524, 0.0428,
         0.0674, 0.0513, 0.0571, 0.0700, 0.0606, 0.0682, 0.0583, 0.0720, 0.0541,
         0.0582, 0.0520, 0.0550, 0.0501, 0.0600, 0.0514, 0.0582, 0.0712, 0.0583,
         0.0555, 0.0482, 0.0614, 0.0545, 0.0447, 0.0492, 0.0716, 0.0651, 0.0625,
         0.0574, 0.0620, 0.0560, 0.0640, 0.0508, 0.0675, 0.0500, 0.0625, 0.0535,
         0.0604, 0.0685, 0.0520, 0.0569, 0.0577, 0.0527, 0.0724, 0.0488, 0.0514,
         0.0493, 0.0551, 0.0405, 0.0504, 0.0475, 0.0516, 0.0435, 0.0553, 0.0502,
         0.0469]),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))

In [26]:
# Rank the gamma scores from largest to smallest.
sortx, sortx_index = torch.sort(gammax_list, descending=True)
sortz, sortz_index = torch.sort(gammaz_list, descending=True)
# The z indices were previously stored under the misspelled name `sorty_index`;
# keep that name as an alias in case a later cell still reads it.
sorty_index = sortz_index
# Largest gamma for x, its index, and the corresponding slice of Ax.
maxx, maxx_index = sortx[0], sortx_index[0]
maxx, maxx_index, Ax[maxx_index]

(tensor(0.0724),
 tensor(51),
 tensor([[0.2082, 0.4852, 0.0437, 0.1160, 0.7978, 0.2425, 0.4761, 0.3796, 0.0114],
         [0.4512, 1.0632, 0.2497, 0.0545, 0.8129, 0.9576, 1.2139, 0.5725, 0.0793],
         [1.1229, 1.4157, 0.2296, 0.1575, 1.1857, 1.2910, 0.0035, 0.7711, 2.6715],
         [1.9461, 2.2148, 1.6746, 0.2388, 1.9177, 0.6798, 2.5549, 1.3455, 1.9527]]))

### 1. 测试 `linalg.norm` 与 `nn.functional.normalize` 的区别
- 二范数：( |x<sub>1</sub>|<sup>2</sup> + |x<sub>2</sub>|<sup>2</sup> +...+|x<sub>n</sub>|<sup>2</sup>)<sup>1/2</sup>
- `torch.norm` 已废弃，不建议使用
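- [linalg.norm](https://pytorch.org/docs/stable/linalg.html?highlight=norm#torch.linalg.norm) 在默认参数下把矩阵展平后对**全部元素**计算二范数（即 Frobenius 范数），结果已经开平方根，返回一个标量
- [nn.functional.normalize](https://pytorch.org/docs/stable/nn.functional.html?highlight=normalize#torch.nn.functional.normalize) 默认 `p=2, dim=1`：沿 `dim=1` 对每一行分别计算二范数（同样开平方根），再用该行除以自己的范数（范数下限为 `eps=1e-12`）
- 因此 `data / torch.linalg.norm(data)` 一般不等于 `fn.normalize(data)`；对 2×2 全 1 矩阵，`torch.linalg.norm(data) = 2`，其平方根 √2 恰好等于每一行的二范数，所以 `data / torch.sqrt(torch.linalg.norm(data))` 与 `fn.normalize(data)` 的结果只是巧合相同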

In [None]:
# 2x2 all-ones matrix used as the input for every norm comparison below.
data = torch.ones((2, 2))
data

####  1.1 linalg.norm: Frobenius norm over all elements

In [None]:
# Divide the whole matrix by the single norm computed over all of its elements.
norm_data = torch.linalg.norm(data)
result1 = torch.div(data, norm_data)
norm_data, result1

#### 1.2  square root of linalg.norm

In [None]:
# Same as the previous step, but with an extra square root applied to the norm.
norm_data = torch.linalg.norm(data).sqrt()
result2 = torch.div(data, norm_data)
norm_data, result2

#### 1.3 manual functional.normalize

In [None]:
# Hand-written equivalent of fn.normalize(data, p=2, dim=1, eps=1e-12).
# keepdim=True keeps the per-row norms as a (rows, 1) column so that each row
# is divided by its own norm. With keepdim=False the (rows,) result would
# broadcast across columns instead, which is only correct for special inputs
# such as an all-ones square matrix.
denom = data.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
ret = data / denom
denom, ret

#### 1.4 auto functional.normalize

In [None]:
# Library version of the manual normalisation: row-wise L2 normalisation of
# `data`. (The previous argument `result_z` is not defined anywhere in this
# notebook, so the cell raised NameError on a fresh kernel.)
ret2 = fn.normalize(data, p=2, dim=1)
ret2