### 实现 Attention-Based Conv2d Pruning

设置随机数种子

In [2]:
import torch
import numpy as np
import random
import torch.nn.functional as fn

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True

setup_seed(0)

定义随机数据

In [3]:
d1, d2, d3 = 64, 32, 32

x, y, z = torch.randn(d1, d2, d3), torch.randn(d1, d2, d3), torch.ones(d1, d2, d3)

y[1] = torch.ones(1, d2, d3)

x[0], y[:2], z[0]

(tensor([[-1.1258, -1.1524, -0.2506,  ...,  1.5863,  0.9463, -0.8437],
         [-0.6136,  0.0316, -0.4927,  ..., -1.2341,  1.8197, -0.5515],
         [-0.5692,  0.9200,  1.1108,  ..., -0.9565,  0.0335,  0.7101],
         ...,
         [ 1.0166,  1.2868,  2.0820,  ...,  0.8161, -0.5711, -0.1195],
         [-0.4274,  0.8143, -1.4121,  ..., -0.1394, -0.3677, -0.4574],
         [-1.2945,  0.7012, -1.9098,  ...,  0.5374,  1.0826, -1.7105]]),
 tensor([[[ 0.0193,  0.4089,  0.1344,  ..., -0.6761, -1.3392,  1.8296],
          [ 0.7607, -0.3626, -0.8501,  ..., -1.1197, -0.9004,  1.3018],
          [-1.2728,  0.3214,  0.0853,  ..., -1.2682, -1.2450, -1.5951],
          ...,
          [-0.3060,  0.4043,  2.3663,  ...,  0.7321, -0.9249, -2.1863],
          [-0.3336,  2.4964,  1.0345,  ..., -1.8370,  0.1747,  0.3298],
          [-0.6356,  1.6734, -0.0258,  ...,  0.3021,  0.7552,  0.1049]],
 
         [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000

#### 1. 将 Weight data 类型转换和求绝对值
- x [N<sub>in</sub>, N<sub>out</sub>, kernel_size[0], kernel_size[1]]
- A [N<sub>in</sub>, N<sub>out</sub>, kernel_size[0] * kernel_size[1]]
- C = N<sub>in</sub>, H =  N<sub>out</sub>, W = kernel_size[0] * kernel_size[1]
- A [C, H, W]

In [8]:
Ax = x.view(d1, d2, -1).abs()
Ay = y.view(d1, d2, -1).abs()
Az = z.view(d1, d2, -1).abs()
Ax.shape, Ay.shape, Az.shape

(torch.Size([64, 32, 32]), torch.Size([64, 32, 32]), torch.Size([64, 32, 32]))

#### 2. 计算 F<sub>sum</sub>(A) = ∑<sub>i=1</sub><sup>C</sup> |A<sub>i</sub>| 沿通道方向绝对值之和 


In [9]:
c, h, w = Ax.shape
FsumAx = torch.zeros(h, w)
FsumAy = torch.zeros(h, w)
FsumAz = torch.zeros(h, w)
for i in range(c):
    FsumAx.add_(torch.abs(Ax[i]))
    FsumAy.add_(torch.abs(Ay[i]))
    FsumAz.add_(torch.abs(Az[i]))
FsumAx, FsumAy, FsumAz

(tensor([[48.1977, 51.6474, 46.4231,  ..., 61.6596, 50.9151, 43.6817],
         [42.1180, 54.0695, 48.3262,  ..., 46.9935, 50.8309, 47.9499],
         [56.6747, 45.5154, 45.9777,  ..., 58.7579, 48.5513, 42.0687],
         ...,
         [50.3370, 47.9057, 46.3192,  ..., 53.0649, 45.7578, 56.4153],
         [54.6961, 47.7189, 50.2934,  ..., 43.5095, 51.5398, 44.3939],
         [51.5721, 50.8895, 54.3762,  ..., 55.9808, 56.1004, 62.8565]]),
 tensor([[48.9983, 56.7777, 56.5189,  ..., 41.2521, 58.0901, 53.5648],
         [50.3488, 61.2142, 57.5986,  ..., 56.4483, 50.8274, 48.7947],
         [48.6556, 56.6372, 53.2805,  ..., 49.5691, 60.1352, 54.4315],
         ...,
         [50.0536, 47.7087, 56.0015,  ..., 59.2379, 57.3356, 46.5085],
         [44.0457, 58.9825, 52.7712,  ..., 53.6677, 46.2976, 51.7568],
         [51.2917, 49.5835, 48.6539,  ..., 57.3652, 54.6319, 49.1777]]),
 tensor([[64., 64., 64.,  ..., 64., 64., 64.],
         [64., 64., 64.,  ..., 64., 64., 64.],
         [64., 64., 64

#### 3. 计算 ||F(A)||<sub>2</sub>  二范式

In [11]:
FAx_s, FAy_s, FAz_s = torch.linalg.norm(FsumAx), torch.linalg.norm(FsumAy), torch.linalg.norm(FsumAz)
FAx_s, FAy_s, FAz_s

(tensor(1640.5884), tensor(1654.4967), tensor(2048.))

#### 4. 计算 F(A) / ||F(A)||<sub>2</sub> 二范式规则化的矩阵 

In [13]:
Fx, Fy, Fz = FsumAx / FAx_s, FsumAy / FAy_s,  FsumAz / FAz_s
Fx, Fy, Fz

(tensor([[0.0294, 0.0315, 0.0283,  ..., 0.0376, 0.0310, 0.0266],
         [0.0257, 0.0330, 0.0295,  ..., 0.0286, 0.0310, 0.0292],
         [0.0345, 0.0277, 0.0280,  ..., 0.0358, 0.0296, 0.0256],
         ...,
         [0.0307, 0.0292, 0.0282,  ..., 0.0323, 0.0279, 0.0344],
         [0.0333, 0.0291, 0.0307,  ..., 0.0265, 0.0314, 0.0271],
         [0.0314, 0.0310, 0.0331,  ..., 0.0341, 0.0342, 0.0383]]),
 tensor([[0.0296, 0.0343, 0.0342,  ..., 0.0249, 0.0351, 0.0324],
         [0.0304, 0.0370, 0.0348,  ..., 0.0341, 0.0307, 0.0295],
         [0.0294, 0.0342, 0.0322,  ..., 0.0300, 0.0363, 0.0329],
         ...,
         [0.0303, 0.0288, 0.0338,  ..., 0.0358, 0.0347, 0.0281],
         [0.0266, 0.0356, 0.0319,  ..., 0.0324, 0.0280, 0.0313],
         [0.0310, 0.0300, 0.0294,  ..., 0.0347, 0.0330, 0.0297]]),
 tensor([[0.0312, 0.0312, 0.0312,  ..., 0.0312, 0.0312, 0.0312],
         [0.0312, 0.0312, 0.0312,  ..., 0.0312, 0.0312, 0.0312],
         [0.0312, 0.0312, 0.0312,  ..., 0.0312, 0.0312, 0.

In [16]:
torch.sum(Fx), torch.sum(Fy), torch.sum(Fz)

(tensor(31.8582), tensor(31.8586), tensor(32.))

#### 5. 计算 F(A<sub>j</sub>) / ||F(A<sub>j</sub>)||<sub>2</sub>  和 gamma = ∑ | F(A) / ||F(A)||<sub>2</sub> - F(A<sub>j</sub>) / ||F(A<sub>j</sub>)||<sub>2</sub> |

In [18]:
Ax0, Az0 = Ax[0], Az[0]
print(Ax0, Az0)

FAx0, FAz0 = FsumAx - Ax0, FsumAz - Az0
print(FAx0, FAz0)

FAx0_s, FAz0_s = torch.linalg.norm(FAx0), torch.linalg.norm(FAz0)
print(FAx0_s, FAz0_s)

Fx0, Fz0 = FAx0 / FAx0_s, FAz0 / FAz0_s
print(Fx0, Fz0)

gammax0, gammaz0 = (Fx - Fx0).abs().sum(), (Fz - Fz0).abs().sum()
gammax0, gammaz0

tensor([[1.1258, 1.1524, 0.2506,  ..., 1.5863, 0.9463, 0.8437],
        [0.6136, 0.0316, 0.4927,  ..., 1.2341, 1.8197, 0.5515],
        [0.5692, 0.9200, 1.1108,  ..., 0.9565, 0.0335, 0.7101],
        ...,
        [1.0166, 1.2868, 2.0820,  ..., 0.8161, 0.5711, 0.1195],
        [0.4274, 0.8143, 1.4121,  ..., 0.1394, 0.3677, 0.4574],
        [1.2945, 0.7012, 1.9098,  ..., 0.5374, 1.0826, 1.7105]]) tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])
tensor([[47.0719, 50.4951, 46.1725,  ..., 60.0732, 49.9688, 42.8380],
        [41.5044, 54.0379, 47.8335,  ..., 45.7593, 49.0112, 47.3984],
        [56.1055, 44.5954, 44.8669,  ..., 57.8014, 48.5177, 41.3586],
        ...,
        [49.3204, 46.6189, 44.2372,  ..., 52.2488, 45.1867, 56.2959],
        [54.2687, 46.9045, 48.8813,  ..., 43.3701, 51.1

(tensor(0.3112), tensor(0.))

In [20]:
gammax_list, gammay_list, gammaz_list = torch.zeros(c), torch.zeros(c), torch.zeros(c)

for j in range(c):
    Axj, Ayj, Azj = Ax[j], Ay[j], Az[j]
    FAxj, FAyj, FAzj = FsumAx - Axj, FsumAy - Ayj, FsumAz - Azj
    FAxj_s, FAyj_s, FAzj_s = torch.linalg.norm(FAxj), torch.linalg.norm(FAyj), torch.linalg.norm(FAzj)
    Fxj, Fyj, Fzj = FAxj / FAxj_s, FAyj / FAyj_s, FAzj / FAzj_s
    gammmax, gammay, gammaz = (Fx - Fxj).abs().sum(), (Fy - Fyj).abs().sum(), (Fz - Fzj).abs().sum()
    gammax_list[j], gammay_list[j], gammaz_list[j] = gammmax, gammay, gammaz

gammax_list, gammay_list, gammaz_list

(tensor([0.3112, 0.3127, 0.2880, 0.2943, 0.3002, 0.3134, 0.2978, 0.3092, 0.3024,
         0.2959, 0.3060, 0.3157, 0.3089, 0.3090, 0.3002, 0.2985, 0.3007, 0.3087,
         0.3073, 0.2993, 0.3005, 0.2949, 0.2980, 0.3024, 0.2956, 0.3008, 0.3029,
         0.3097, 0.2956, 0.3042, 0.3034, 0.3053, 0.2959, 0.3067, 0.3014, 0.3037,
         0.3067, 0.2991, 0.3135, 0.3076, 0.2993, 0.2992, 0.2921, 0.3010, 0.2949,
         0.3049, 0.3029, 0.3034, 0.3047, 0.2921, 0.2968, 0.3031, 0.2911, 0.3072,
         0.2983, 0.2948, 0.2939, 0.3036, 0.2994, 0.3078, 0.3036, 0.2996, 0.3017,
         0.3022]),
 tensor([0.3039, 0.0479, 0.3118, 0.2895, 0.3143, 0.3079, 0.2990, 0.3046, 0.2968,
         0.3132, 0.2984, 0.3003, 0.2869, 0.3050, 0.3043, 0.3022, 0.3027, 0.2905,
         0.2951, 0.3082, 0.3039, 0.2981, 0.2921, 0.3175, 0.3005, 0.2916, 0.3070,
         0.3168, 0.2924, 0.3050, 0.2972, 0.3020, 0.3023, 0.3003, 0.3001, 0.2936,
         0.3072, 0.3047, 0.2877, 0.3013, 0.2996, 0.3141, 0.3095, 0.3184, 0.3076,
         

In [24]:
sortx, sortx_index = torch.sort(gammax_list, descending=True)
sorty, sorty_index = torch.sort(gammay_list, descending=True)
sortz, sortz_index = torch.sort(gammaz_list, descending=True)
maxx, maxx_idx = sortx[0], sortx_index[0]
maxy, maxy_idx = sorty[0], sorty_index[0]
maxz, maxz_idx = sortz[0], sortz_index[0]
print(maxx, maxx_idx, Ax[maxx_idx])
print(maxy, maxy_idx, Ax[maxy_idx])
print(maxz, maxz_idx, Ax[maxz_idx])

tensor(0.3157) tensor(11) tensor([[1.2072, 1.3946, 0.2554,  ..., 0.7349, 0.1060, 0.0482],
        [1.8015, 0.7515, 0.8094,  ..., 0.7470, 2.3901, 0.4567],
        [1.1487, 0.3122, 0.9012,  ..., 0.5979, 1.0858, 0.2447],
        ...,
        [0.5948, 0.4020, 0.0962,  ..., 1.7930, 0.8511, 1.4528],
        [0.1486, 0.4138, 0.5233,  ..., 0.6924, 0.4695, 0.6201],
        [0.9624, 0.3194, 0.4420,  ..., 1.2246, 0.9603, 0.6734]])
tensor(0.3184) tensor(43) tensor([[0.5971, 0.9966, 1.5495,  ..., 0.8722, 1.1792, 1.2553],
        [1.3510, 0.3032, 0.8867,  ..., 0.8558, 0.8534, 0.7020],
        [0.3258, 1.1347, 0.5089,  ..., 0.8760, 0.6129, 0.3393],
        ...,
        [0.0091, 0.4686, 0.8840,  ..., 2.1571, 1.1152, 0.8608],
        [0.4225, 0.0963, 1.4006,  ..., 0.5018, 0.5187, 0.6009],
        [1.6907, 1.5409, 0.3857,  ..., 0.6257, 0.9216, 0.2888]])
tensor(0.) tensor(0) tensor([[1.1258, 1.1524, 0.2506,  ..., 1.5863, 0.9463, 0.8437],
        [0.6136, 0.0316, 0.4927,  ..., 1.2341, 1.8197, 0.5515],
   

In [29]:
pruning_rate = 0.7
num_total = len(gammax_list)
thre_idx = int(num_total * pruning_rate)
thresholdx, thresholdy, thresholdz = sortx[thre_idx], sorty[thre_idx], sortz[thre_idx]
maskx, masky, maskz = gammax_list.lt(thresholdx).float(), gammay_list.lt(thresholdy).float(), gammaz_list.lt(thresholdz).float()

num_remainx, num_remainy, num_remainz = torch.sum(maskx), torch.sum(masky), torch.sum(maskz)

num_remainx / num_total, num_remainy / num_total, num_remainz /num_total

(tensor(0.2969), tensor(0.2969), tensor(0.))