## 实现 Attention-Based Conv2d Pruning

$$
\gamma= \| \frac{F(A)}{\|F(A)\|_2} - \frac{F(A_j)}{\|F(A_j)\|_2} \|_2
$$

### 1. 按步骤实现（可能存在问题）

设置随机数种子

In [2]:
import torch
import numpy as np
import random
import torch.nn.functional as fn

def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True

setup_seed(0)

定义随机数据

In [3]:
d1, d2, d3 = 64, 32, 32

x, y, z = torch.randn(d1, d2, d3), torch.randn(d1, d2, d3), torch.ones(d1, d2, d3)

y[1] = torch.ones(1, d2, d3)

x[0], y[:2], z[0]

(tensor([[-1.1258, -1.1524, -0.2506,  ...,  1.5863,  0.9463, -0.8437],
         [-0.6136,  0.0316, -0.4927,  ..., -1.2341,  1.8197, -0.5515],
         [-0.5692,  0.9200,  1.1108,  ..., -0.9565,  0.0335,  0.7101],
         ...,
         [ 1.0166,  1.2868,  2.0820,  ...,  0.8161, -0.5711, -0.1195],
         [-0.4274,  0.8143, -1.4121,  ..., -0.1394, -0.3677, -0.4574],
         [-1.2945,  0.7012, -1.9098,  ...,  0.5374,  1.0826, -1.7105]]),
 tensor([[[ 0.0193,  0.4089,  0.1344,  ..., -0.6761, -1.3392,  1.8296],
          [ 0.7607, -0.3626, -0.8501,  ..., -1.1197, -0.9004,  1.3018],
          [-1.2728,  0.3214,  0.0853,  ..., -1.2682, -1.2450, -1.5951],
          ...,
          [-0.3060,  0.4043,  2.3663,  ...,  0.7321, -0.9249, -2.1863],
          [-0.3336,  2.4964,  1.0345,  ..., -1.8370,  0.1747,  0.3298],
          [-0.6356,  1.6734, -0.0258,  ...,  0.3021,  0.7552,  0.1049]],
 
         [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000

#### 1. 将 Weight data 类型转换和求绝对值
- x [N<sub>in</sub>, N<sub>out</sub>, kernel_size[0], kernel_size[1]]
- A [N<sub>in</sub>, N<sub>out</sub>, kernel_size[0] * kernel_size[1]]
- C = N<sub>in</sub>, H =  N<sub>out</sub>, W = kernel_size[0] * kernel_size[1]
- A [C, H, W]

In [8]:
Ax = x.view(d1, d2, -1).abs()
Ay = y.view(d1, d2, -1).abs()
Az = z.view(d1, d2, -1).abs()
Ax.shape, Ay.shape, Az.shape

(torch.Size([64, 32, 32]), torch.Size([64, 32, 32]), torch.Size([64, 32, 32]))

#### 2. 计算 $F_{sum}(A)=\sum_{i=1}^C|A_i|$ 沿通道方向绝对值之和 


In [9]:
c, h, w = Ax.shape
FsumAx = torch.zeros(h, w)
FsumAy = torch.zeros(h, w)
FsumAz = torch.zeros(h, w)
for i in range(c):
    FsumAx.add_(torch.abs(Ax[i]))
    FsumAy.add_(torch.abs(Ay[i]))
    FsumAz.add_(torch.abs(Az[i]))
FsumAx, FsumAy, FsumAz

(tensor([[48.1977, 51.6474, 46.4231,  ..., 61.6596, 50.9151, 43.6817],
         [42.1180, 54.0695, 48.3262,  ..., 46.9935, 50.8309, 47.9499],
         [56.6747, 45.5154, 45.9777,  ..., 58.7579, 48.5513, 42.0687],
         ...,
         [50.3370, 47.9057, 46.3192,  ..., 53.0649, 45.7578, 56.4153],
         [54.6961, 47.7189, 50.2934,  ..., 43.5095, 51.5398, 44.3939],
         [51.5721, 50.8895, 54.3762,  ..., 55.9808, 56.1004, 62.8565]]),
 tensor([[48.9983, 56.7777, 56.5189,  ..., 41.2521, 58.0901, 53.5648],
         [50.3488, 61.2142, 57.5986,  ..., 56.4483, 50.8274, 48.7947],
         [48.6556, 56.6372, 53.2805,  ..., 49.5691, 60.1352, 54.4315],
         ...,
         [50.0536, 47.7087, 56.0015,  ..., 59.2379, 57.3356, 46.5085],
         [44.0457, 58.9825, 52.7712,  ..., 53.6677, 46.2976, 51.7568],
         [51.2917, 49.5835, 48.6539,  ..., 57.3652, 54.6319, 49.1777]]),
 tensor([[64., 64., 64.,  ..., 64., 64., 64.],
         [64., 64., 64.,  ..., 64., 64., 64.],
         [64., 64., 64

#### 3. 计算 $\|F(A)\|_2$ 二范式

In [11]:
FAx_s, FAy_s, FAz_s = torch.linalg.norm(FsumAx), torch.linalg.norm(FsumAy), torch.linalg.norm(FsumAz)
FAx_s, FAy_s, FAz_s

(tensor(1640.5884), tensor(1654.4967), tensor(2048.))

#### 4. 计算 $\frac{F(A)}{\|F(A)\|_2}$ 二范式规则化的矩阵 

In [13]:
Fx, Fy, Fz = FsumAx / FAx_s, FsumAy / FAy_s,  FsumAz / FAz_s
Fx, Fy, Fz

(tensor([[0.0294, 0.0315, 0.0283,  ..., 0.0376, 0.0310, 0.0266],
         [0.0257, 0.0330, 0.0295,  ..., 0.0286, 0.0310, 0.0292],
         [0.0345, 0.0277, 0.0280,  ..., 0.0358, 0.0296, 0.0256],
         ...,
         [0.0307, 0.0292, 0.0282,  ..., 0.0323, 0.0279, 0.0344],
         [0.0333, 0.0291, 0.0307,  ..., 0.0265, 0.0314, 0.0271],
         [0.0314, 0.0310, 0.0331,  ..., 0.0341, 0.0342, 0.0383]]),
 tensor([[0.0296, 0.0343, 0.0342,  ..., 0.0249, 0.0351, 0.0324],
         [0.0304, 0.0370, 0.0348,  ..., 0.0341, 0.0307, 0.0295],
         [0.0294, 0.0342, 0.0322,  ..., 0.0300, 0.0363, 0.0329],
         ...,
         [0.0303, 0.0288, 0.0338,  ..., 0.0358, 0.0347, 0.0281],
         [0.0266, 0.0356, 0.0319,  ..., 0.0324, 0.0280, 0.0313],
         [0.0310, 0.0300, 0.0294,  ..., 0.0347, 0.0330, 0.0297]]),
 tensor([[0.0312, 0.0312, 0.0312,  ..., 0.0312, 0.0312, 0.0312],
         [0.0312, 0.0312, 0.0312,  ..., 0.0312, 0.0312, 0.0312],
         [0.0312, 0.0312, 0.0312,  ..., 0.0312, 0.0312, 0.

In [16]:
torch.sum(Fx), torch.sum(Fy), torch.sum(Fz)

(tensor(31.8582), tensor(31.8586), tensor(32.))

#### 5. 计算 $\frac{F(A)}{\|F(A)\|_2}$ 和 $\gamma= \| \frac{F(A)}{\|F(A)\|_2} - \frac{F(A_j)}{\|F(A_j)\|_2} \|_2$

In [18]:
Ax0, Az0 = Ax[0], Az[0]
print(Ax0, Az0)

FAx0, FAz0 = FsumAx - Ax0, FsumAz - Az0
print(FAx0, FAz0)

FAx0_s, FAz0_s = torch.linalg.norm(FAx0), torch.linalg.norm(FAz0)
print(FAx0_s, FAz0_s)

Fx0, Fz0 = FAx0 / FAx0_s, FAz0 / FAz0_s
print(Fx0, Fz0)

gammax0, gammaz0 = (Fx - Fx0).abs().sum(), (Fz - Fz0).abs().sum()
gammax0, gammaz0

tensor([[1.1258, 1.1524, 0.2506,  ..., 1.5863, 0.9463, 0.8437],
        [0.6136, 0.0316, 0.4927,  ..., 1.2341, 1.8197, 0.5515],
        [0.5692, 0.9200, 1.1108,  ..., 0.9565, 0.0335, 0.7101],
        ...,
        [1.0166, 1.2868, 2.0820,  ..., 0.8161, 0.5711, 0.1195],
        [0.4274, 0.8143, 1.4121,  ..., 0.1394, 0.3677, 0.4574],
        [1.2945, 0.7012, 1.9098,  ..., 0.5374, 1.0826, 1.7105]]) tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])
tensor([[47.0719, 50.4951, 46.1725,  ..., 60.0732, 49.9688, 42.8380],
        [41.5044, 54.0379, 47.8335,  ..., 45.7593, 49.0112, 47.3984],
        [56.1055, 44.5954, 44.8669,  ..., 57.8014, 48.5177, 41.3586],
        ...,
        [49.3204, 46.6189, 44.2372,  ..., 52.2488, 45.1867, 56.2959],
        [54.2687, 46.9045, 48.8813,  ..., 43.3701, 51.1

(tensor(0.3112), tensor(0.))

In [20]:
gammax_list, gammay_list, gammaz_list = torch.zeros(c), torch.zeros(c), torch.zeros(c)

for j in range(c):
    Axj, Ayj, Azj = Ax[j], Ay[j], Az[j]
    FAxj, FAyj, FAzj = FsumAx - Axj, FsumAy - Ayj, FsumAz - Azj
    FAxj_s, FAyj_s, FAzj_s = torch.linalg.norm(FAxj), torch.linalg.norm(FAyj), torch.linalg.norm(FAzj)
    Fxj, Fyj, Fzj = FAxj / FAxj_s, FAyj / FAyj_s, FAzj / FAzj_s
    gammmax, gammay, gammaz = (Fx - Fxj).abs().sum(), (Fy - Fyj).abs().sum(), (Fz - Fzj).abs().sum()
    gammax_list[j], gammay_list[j], gammaz_list[j] = gammmax, gammay, gammaz

gammax_list, gammay_list, gammaz_list

(tensor([0.3112, 0.3127, 0.2880, 0.2943, 0.3002, 0.3134, 0.2978, 0.3092, 0.3024,
         0.2959, 0.3060, 0.3157, 0.3089, 0.3090, 0.3002, 0.2985, 0.3007, 0.3087,
         0.3073, 0.2993, 0.3005, 0.2949, 0.2980, 0.3024, 0.2956, 0.3008, 0.3029,
         0.3097, 0.2956, 0.3042, 0.3034, 0.3053, 0.2959, 0.3067, 0.3014, 0.3037,
         0.3067, 0.2991, 0.3135, 0.3076, 0.2993, 0.2992, 0.2921, 0.3010, 0.2949,
         0.3049, 0.3029, 0.3034, 0.3047, 0.2921, 0.2968, 0.3031, 0.2911, 0.3072,
         0.2983, 0.2948, 0.2939, 0.3036, 0.2994, 0.3078, 0.3036, 0.2996, 0.3017,
         0.3022]),
 tensor([0.3039, 0.0479, 0.3118, 0.2895, 0.3143, 0.3079, 0.2990, 0.3046, 0.2968,
         0.3132, 0.2984, 0.3003, 0.2869, 0.3050, 0.3043, 0.3022, 0.3027, 0.2905,
         0.2951, 0.3082, 0.3039, 0.2981, 0.2921, 0.3175, 0.3005, 0.2916, 0.3070,
         0.3168, 0.2924, 0.3050, 0.2972, 0.3020, 0.3023, 0.3003, 0.3001, 0.2936,
         0.3072, 0.3047, 0.2877, 0.3013, 0.2996, 0.3141, 0.3095, 0.3184, 0.3076,
         

In [24]:
sortx, sortx_index = torch.sort(gammax_list, descending=True)
sorty, sorty_index = torch.sort(gammay_list, descending=True)
sortz, sortz_index = torch.sort(gammaz_list, descending=True)
maxx, maxx_idx = sortx[0], sortx_index[0]
maxy, maxy_idx = sorty[0], sorty_index[0]
maxz, maxz_idx = sortz[0], sortz_index[0]
print(maxx, maxx_idx, Ax[maxx_idx])
print(maxy, maxy_idx, Ax[maxy_idx])
print(maxz, maxz_idx, Ax[maxz_idx])

tensor(0.3157) tensor(11) tensor([[1.2072, 1.3946, 0.2554,  ..., 0.7349, 0.1060, 0.0482],
        [1.8015, 0.7515, 0.8094,  ..., 0.7470, 2.3901, 0.4567],
        [1.1487, 0.3122, 0.9012,  ..., 0.5979, 1.0858, 0.2447],
        ...,
        [0.5948, 0.4020, 0.0962,  ..., 1.7930, 0.8511, 1.4528],
        [0.1486, 0.4138, 0.5233,  ..., 0.6924, 0.4695, 0.6201],
        [0.9624, 0.3194, 0.4420,  ..., 1.2246, 0.9603, 0.6734]])
tensor(0.3184) tensor(43) tensor([[0.5971, 0.9966, 1.5495,  ..., 0.8722, 1.1792, 1.2553],
        [1.3510, 0.3032, 0.8867,  ..., 0.8558, 0.8534, 0.7020],
        [0.3258, 1.1347, 0.5089,  ..., 0.8760, 0.6129, 0.3393],
        ...,
        [0.0091, 0.4686, 0.8840,  ..., 2.1571, 1.1152, 0.8608],
        [0.4225, 0.0963, 1.4006,  ..., 0.5018, 0.5187, 0.6009],
        [1.6907, 1.5409, 0.3857,  ..., 0.6257, 0.9216, 0.2888]])
tensor(0.) tensor(0) tensor([[1.1258, 1.1524, 0.2506,  ..., 1.5863, 0.9463, 0.8437],
        [0.6136, 0.0316, 0.4927,  ..., 1.2341, 1.8197, 0.5515],
   

In [29]:
pruning_rate = 0.7
num_total = len(gammax_list)
thre_idx = int(num_total * pruning_rate)
thresholdx, thresholdy, thresholdz = sortx[thre_idx], sorty[thre_idx], sortz[thre_idx]
maskx, masky, maskz = gammax_list.lt(thresholdx).float(), gammay_list.lt(thresholdy).float(), gammaz_list.lt(thresholdz).float()

num_remainx, num_remainy, num_remainz = torch.sum(maskx), torch.sum(masky), torch.sum(maskz)

num_remainx / num_total, num_remainy / num_total, num_remainz /num_total

(tensor(0.2969), tensor(0.2969), tensor(0.))

### 2. 按维度实现 

#### 2.1 数据准备

In [1]:
import torch
import numpy as np
import torch.nn.functional as F

B, C, H, W = 4, 8, 4, 4
data = torch.randn(B, C, H, W)
datan = data.numpy()

data.shape, datan.shape

(torch.Size([4, 8, 4, 4]), (4, 8, 4, 4))

#### 2.2 pytorch实现 

1. （B,C,H,W）→（B,H,W）沿着通道方向计算平方平均数（平方和的均值）
2. （B,H,W）→（B,H×W）将数据以Batch作为单独的维度，其他维度合并成一个维度
3. 正态分布化数据

In [68]:
x = data.pow(2).mean(1)
x = x.view(x.size(0), -1)
x.shape, x

(torch.Size([4, 16]),
 tensor([[1.5062, 0.6776, 0.4537, 1.1660, 0.2690, 1.0880, 1.1333, 1.5674, 1.0934,
          1.1475, 1.0724, 1.5206, 0.8944, 0.6874, 1.8965, 1.1948],
         [1.2713, 0.9307, 0.4234, 0.6191, 2.1409, 1.1193, 1.0584, 1.4638, 0.5582,
          1.1009, 1.5972, 1.5151, 1.3469, 1.4395, 0.8958, 0.8490],
         [1.9093, 1.2639, 1.0181, 3.0649, 0.3986, 1.0464, 1.2088, 1.0529, 1.6034,
          0.4639, 0.9645, 1.1715, 0.6276, 0.5704, 1.0258, 0.2900],
         [1.6222, 0.6057, 1.3385, 1.2220, 1.1952, 1.4389, 0.3287, 1.4243, 0.8753,
          0.9462, 0.8821, 0.9698, 1.7016, 1.2486, 0.4495, 0.6128]]))

##### 2.2.1 手动实现正态分布化数据 

In [79]:
x1 = x.clone()
x2 = []
for data in x1:
    x2.append(data / torch.sum(data ** 2) ** 0.5)
x2

[tensor([0.3244, 0.1459, 0.0977, 0.2511, 0.0579, 0.2343, 0.2441, 0.3376, 0.2355,
         0.2471, 0.2309, 0.3275, 0.1926, 0.1480, 0.4084, 0.2573]),
 tensor([0.2599, 0.1903, 0.0866, 0.1266, 0.4377, 0.2288, 0.2164, 0.2993, 0.1141,
         0.2251, 0.3265, 0.3098, 0.2754, 0.2943, 0.1831, 0.1736]),
 tensor([0.3713, 0.2458, 0.1980, 0.5960, 0.0775, 0.2035, 0.2351, 0.2048, 0.3118,
         0.0902, 0.1876, 0.2278, 0.1220, 0.1109, 0.1995, 0.0564]),
 tensor([0.3598, 0.1343, 0.2968, 0.2710, 0.2651, 0.3191, 0.0729, 0.3159, 0.1941,
         0.2098, 0.1956, 0.2151, 0.3774, 0.2769, 0.0997, 0.1359])]

##### 2.2.2 pytorch函数调用 

In [55]:
x = F.normalize(x)
x

tensor([[0.3244, 0.1459, 0.0977, 0.2511, 0.0579, 0.2343, 0.2441, 0.3376, 0.2355,
         0.2471, 0.2309, 0.3275, 0.1926, 0.1480, 0.4084, 0.2573],
        [0.2599, 0.1903, 0.0866, 0.1266, 0.4377, 0.2288, 0.2164, 0.2993, 0.1141,
         0.2251, 0.3265, 0.3098, 0.2754, 0.2943, 0.1831, 0.1736],
        [0.3713, 0.2458, 0.1980, 0.5960, 0.0775, 0.2035, 0.2351, 0.2048, 0.3118,
         0.0902, 0.1876, 0.2278, 0.1220, 0.1109, 0.1995, 0.0564],
        [0.3598, 0.1343, 0.2968, 0.2710, 0.2651, 0.3191, 0.0729, 0.3159, 0.1941,
         0.2098, 0.1956, 0.2151, 0.3774, 0.2769, 0.0997, 0.1359]])

#### 2.3 numpy实现

In [47]:
xn = np.square(datan)
xn = np.mean(xn, axis=1)
xn = xn.reshape(xn.shape[0], -1)
xn.shape, xn

((4, 16),
 array([[1.5062251 , 0.67760205, 0.45365313, 1.1659613 , 0.26901874,
         1.0879692 , 1.133285  , 1.5674258 , 1.093431  , 1.1474781 ,
         1.0723621 , 1.520561  , 0.8944081 , 0.68743837, 1.8964972 ,
         1.1947949 ],
        [1.2713171 , 0.93072975, 0.42342928, 0.61907315, 2.140855  ,
         1.1192968 , 1.058415  , 1.4638273 , 0.5582265 , 1.1008934 ,
         1.5971556 , 1.5150522 , 1.3469063 , 1.4394847 , 0.89578104,
         0.84897494],
        [1.9092603 , 1.2638607 , 1.0180924 , 3.0648754 , 0.39857435,
         1.0464232 , 1.2087557 , 1.0529248 , 1.6034381 , 0.46393275,
         0.96448696, 1.1714634 , 0.627574  , 0.5703842 , 1.0258193 ,
         0.2900106 ],
        [1.6222157 , 0.6056761 , 1.3384904 , 1.2220182 , 1.1952196 ,
         1.4388841 , 0.32865882, 1.4243088 , 0.8753281 , 0.9461654 ,
         0.88214266, 0.9697903 , 1.7016097 , 1.2485752 , 0.44951284,
         0.61275357]], dtype=float32))

##### 2.3.1 np.linalg.norm计算平方平均数 

In [56]:
v = np.linalg.norm(xn, axis=1, keepdims=True)
xn = xn / v
v, xn

(array([[0.99999994],
        [1.        ],
        [1.        ],
        [0.99999994]], dtype=float32),
 array([[0.3243775 , 0.14592697, 0.09769779, 0.251099  , 0.05793532,
         0.23430277, 0.2440619 , 0.33755755, 0.23547903, 0.2471185 ,
         0.23094165, 0.32746485, 0.19261786, 0.1480453 , 0.4084257 ,
         0.25730854],
        [0.25992113, 0.19028796, 0.08657023, 0.12656967, 0.4376984 ,
         0.22884053, 0.21639325, 0.29927987, 0.11412957, 0.22507796,
         0.3265389 , 0.30975282, 0.27537537, 0.29430306, 0.18314268,
         0.17357317],
        [0.37130386, 0.24578962, 0.19799377, 0.5960424 , 0.07751285,
         0.20350341, 0.23507307, 0.20476781, 0.31182903, 0.09022344,
         0.18756884, 0.22782063, 0.12204761, 0.11092561, 0.19949646,
         0.05639989],
        [0.3597739 , 0.13432644, 0.2968495 , 0.27101836, 0.265075  ,
         0.31911477, 0.07288973, 0.31588224, 0.19412968, 0.20983993,
         0.195641  , 0.21507943, 0.37738186, 0.27690816, 0.09969265,
 

#### 2.4 函数封装 

In [32]:
ones = torch.ones(B, H * W)
onesn = ones.numpy()

ones.shape, onesn.shape

(torch.Size([4, 16]), (4, 16))

In [3]:
def at(x):
    return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))


def at_loss(x, y):
    return (at(x) - at(y))



value = at_loss(datax, datay)
value.shape

torch.Size([64, 16])

In [16]:
F_A = datax.pow(2).mean(1)
F_A.shape

torch.Size([64, 32, 32])

In [22]:
F_A.view(F_A.size(0), -1).shape, F_A

(torch.Size([64, 1024]),
 tensor([[[0.4533, 0.5819, 3.4659,  ..., 1.3390, 0.6274, 0.2621],
          [0.2669, 0.5295, 0.4900,  ..., 0.4292, 0.2936, 0.7929],
          [0.1183, 1.6187, 0.4253,  ..., 0.2029, 0.4095, 0.7340],
          ...,
          [1.7430, 0.5819, 0.4700,  ..., 0.1823, 0.0718, 0.6682],
          [1.8804, 1.2755, 1.8785,  ..., 0.5381, 0.6046, 1.6938],
          [3.1485, 0.1752, 1.3064,  ..., 4.5727, 0.5579, 2.7471]],
 
         [[0.3761, 2.2716, 1.5009,  ..., 0.8825, 0.5006, 3.1626],
          [0.5042, 0.0787, 1.7971,  ..., 1.4207, 1.0284, 0.0240],
          [0.3022, 0.2683, 2.1142,  ..., 1.4540, 0.6503, 1.0044],
          ...,
          [0.3743, 0.1511, 0.5097,  ..., 0.9435, 0.5137, 1.2636],
          [1.4879, 2.0318, 0.5538,  ..., 0.6840, 0.3458, 0.4118],
          [1.9066, 0.7246, 1.7175,  ..., 0.3644, 0.2178, 1.1404]],
 
         [[2.5538, 1.1105, 0.2389,  ..., 0.7934, 0.5914, 1.0298],
          [0.3597, 1.8543, 0.6631,  ..., 0.1782, 0.2804, 0.0903],
          [0.66

##### 2.1 数据(1, 2)正态分布后

In [40]:
datat = torch.tensor([[1., 2.], [1., 2.]])
datat_norm = F.normalize(datat)

datat, datat_norm, datat_norm[0].sum()

(tensor([[1., 2.],
         [1., 2.]]),
 tensor([[0.4472, 0.8944],
         [0.4472, 0.8944]]),
 tensor(1.3416))

##### 2.2 数据(1, 1)和(2, 2)正态分布后

In [41]:
datat = torch.tensor([[1., 1.], [2., 2.]])
datat_norm = F.normalize(datat)

datat, datat_norm, datat_norm[0].sum()

(tensor([[1., 1.],
         [2., 2.]]),
 tensor([[0.7071, 0.7071],
         [0.7071, 0.7071]]),
 tensor(1.4142))

In [17]:
datat2 = torch.randn(2, 4, 1, 1)
b, c, h, w = datat2.shape

datat2

tensor([[[[-1.2522]],

         [[ 0.1777]],

         [[ 0.3770]],

         [[-0.6339]]],


        [[[ 1.0001]],

         [[ 0.1608]],

         [[ 1.0589]],

         [[-0.0019]]]])

In [20]:
for i in range(c):
    itemj = datat2[:, i, :, :].unsqueeze(1)
#     itemj = torch.unsqueeze(itemj, 1)
#     print(itemj)
    dataj = datat2 - itemj
    print(dataj)
    
#     dataj = datat2[:, torch.arange(c) != i, :, :]
#     print(dataj)

# datat2.shape, dataj.shape, datat2, dataj

tensor([[[[ 0.0000]],

         [[ 1.4299]],

         [[ 1.6292]],

         [[ 0.6183]]],


        [[[ 0.0000]],

         [[-0.8393]],

         [[ 0.0587]],

         [[-1.0021]]]])
tensor([[[[-1.4299]],

         [[ 0.0000]],

         [[ 0.1993]],

         [[-0.8116]]],


        [[[ 0.8393]],

         [[ 0.0000]],

         [[ 0.8981]],

         [[-0.1627]]]])
tensor([[[[-1.6292]],

         [[-0.1993]],

         [[ 0.0000]],

         [[-1.0109]]],


        [[[-0.0587]],

         [[-0.8981]],

         [[ 0.0000]],

         [[-1.0608]]]])
tensor([[[[-0.6183]],

         [[ 0.8116]],

         [[ 1.0109]],

         [[ 0.0000]]],


        [[[ 1.0021]],

         [[ 0.1627]],

         [[ 1.0608]],

         [[ 0.0000]]]])


In [7]:
itemj = datat2[-1, 0, -1, -1]
itemj

tensor(1.0083)