## 计算局部相关性

对于给定的数据，其尺寸为`N,C,H,W`，现在想要计算其局部的相关性，也就是说，在特定尺寸的窗口内（例如2x2大小的区域），计算任意两点的特征向量（沿通道维度`C`）之间的点积。

试写出相关的代码。

In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [56]:
# Two random inputs laid out as N, C, H, W = 1, 2, 3, 4.
sample_shape = (1, 2, 3, 4)
a = torch.rand(sample_shape)
b = torch.rand(sample_shape)

print("a=>\n", a)
print("b=>\n", b)

a=>
 tensor([[[[0.4818, 0.9888, 0.8039, 0.7089],
          [0.7667, 0.2273, 0.9956, 0.4739],
          [0.9515, 0.1896, 0.7928, 0.0173]],

         [[0.1723, 0.8767, 0.4832, 0.6515],
          [0.9487, 0.6301, 0.5711, 0.7781],
          [0.2017, 0.9220, 0.2793, 0.2675]]]])
b=>
 tensor([[[[0.1417, 0.3510, 0.1170, 0.1698],
          [0.4311, 0.1535, 0.6087, 0.6646],
          [0.1880, 0.4103, 0.0289, 0.1094]],

         [[0.3398, 0.8751, 0.8299, 0.3514],
          [0.0333, 0.2831, 0.8086, 0.0514],
          [0.3168, 0.2895, 0.5107, 0.4949]]]])


In [57]:
# Sliding-window extractor: 2x2 kernel, dilation 1, no padding, stride 1.
unfold_func = nn.Unfold(kernel_size=2, dilation=1, padding=0, stride=1)

# Every column of the result is one flattened window; its rows are ordered
# channel first, then position inside the window.
unfold_a = unfold_func(a)
unfold_b = unfold_func(b)

print("unfold_a=>\n", unfold_a)
print("unfold_b=>\n", unfold_b)

unfold_a=>
 tensor([[[0.4818, 0.9888, 0.8039, 0.7667, 0.2273, 0.9956],
         [0.9888, 0.8039, 0.7089, 0.2273, 0.9956, 0.4739],
         [0.7667, 0.2273, 0.9956, 0.9515, 0.1896, 0.7928],
         [0.2273, 0.9956, 0.4739, 0.1896, 0.7928, 0.0173],
         [0.1723, 0.8767, 0.4832, 0.9487, 0.6301, 0.5711],
         [0.8767, 0.4832, 0.6515, 0.6301, 0.5711, 0.7781],
         [0.9487, 0.6301, 0.5711, 0.2017, 0.9220, 0.2793],
         [0.6301, 0.5711, 0.7781, 0.9220, 0.2793, 0.2675]]])
unfold_b=>
 tensor([[[0.1417, 0.3510, 0.1170, 0.4311, 0.1535, 0.6087],
         [0.3510, 0.1170, 0.1698, 0.1535, 0.6087, 0.6646],
         [0.4311, 0.1535, 0.6087, 0.1880, 0.4103, 0.0289],
         [0.1535, 0.6087, 0.6646, 0.4103, 0.0289, 0.1094],
         [0.3398, 0.8751, 0.8299, 0.0333, 0.2831, 0.8086],
         [0.8751, 0.8299, 0.3514, 0.2831, 0.8086, 0.0514],
         [0.0333, 0.2831, 0.8086, 0.3168, 0.2895, 0.5107],
         [0.2831, 0.8086, 0.0514, 0.2895, 0.5107, 0.4949]]])


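这里使用unfold操作之后可以看出来，外侧的括号从原来的四层变为了现在的三层，实际上表示的就是从原来的`N,C,H,W`变成了现在的`N,C*2*2,L`的样子，其中`L`为滑窗的数量。这里窗口大小为2、步长为1、没有padding，因此`L=(H-1)*(W-1)=2*3=6`，对应上面输出的形状`1,8,6`；只有当步长等于窗口大小且`H`、`W`都能被窗口大小整除时，才有`L=(H/2)*(W/2)`。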

而对于表示滑窗数量的最后一个维度，在滑窗处理时，各个窗口也是按照行主序（先行后列）依次排成一行的。

In [58]:
# Target layout: N, H'*W', C, 2*2. With stride 1 and no padding, the 3x4
# input yields (3-1)*(4-1) windows.
windows_shape = (1, (3 - 1) * (4 - 1), 2, 4)

# Bring the window axis in front of the C*2*2 axis, then split C*2*2 into
# (C, 2*2) so that each window becomes a C x 4 matrix.
unfold_a_reshape = unfold_a.permute(0, 2, 1).view(windows_shape)
unfold_b_reshape = unfold_b.permute(0, 2, 1).view(windows_shape)

print("unfold_a_reshape=>\n", unfold_a_reshape)
print("unfold_b_reshape=>\n", unfold_b_reshape)

unfold_a_reshape=>
 tensor([[[[0.4818, 0.9888, 0.7667, 0.2273],
          [0.1723, 0.8767, 0.9487, 0.6301]],

         [[0.9888, 0.8039, 0.2273, 0.9956],
          [0.8767, 0.4832, 0.6301, 0.5711]],

         [[0.8039, 0.7089, 0.9956, 0.4739],
          [0.4832, 0.6515, 0.5711, 0.7781]],

         [[0.7667, 0.2273, 0.9515, 0.1896],
          [0.9487, 0.6301, 0.2017, 0.9220]],

         [[0.2273, 0.9956, 0.1896, 0.7928],
          [0.6301, 0.5711, 0.9220, 0.2793]],

         [[0.9956, 0.4739, 0.7928, 0.0173],
          [0.5711, 0.7781, 0.2793, 0.2675]]]])
unfold_b_reshape=>
 tensor([[[[0.1417, 0.3510, 0.4311, 0.1535],
          [0.3398, 0.8751, 0.0333, 0.2831]],

         [[0.3510, 0.1170, 0.1535, 0.6087],
          [0.8751, 0.8299, 0.2831, 0.8086]],

         [[0.1170, 0.1698, 0.6087, 0.6646],
          [0.8299, 0.3514, 0.8086, 0.0514]],

         [[0.4311, 0.1535, 0.1880, 0.4103],
          [0.0333, 0.2831, 0.3168, 0.2895]],

         [[0.1535, 0.6087, 0.4103, 0.0289],
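          [0.2831, 0.8086, 0.2895, 0.5107]],

         [[0.6087, 0.6646, 0.0289, 0.1094],
          [0.8086, 0.0514, 0.5107, 0.4949]]]])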

In [59]:
# Per-window Gram matrix X^T X: entry (i, j) is the dot product, taken over
# the channel axis, between positions i and j of the same 2x2 window.
# Resulting layout: N, H'W', 2*2, 2*2.
mm_unfold_a = unfold_a_reshape.transpose(-1, -2) @ unfold_a_reshape
mm_unfold_b = unfold_b_reshape.transpose(-1, -2) @ unfold_b_reshape

print("mm_unfold_a=>\n", mm_unfold_a)
print("mm_unfold_b=>\n", mm_unfold_b)

mm_unfold_a=>
 tensor([[[[0.2619, 0.6275, 0.5329, 0.2181],
          [0.6275, 1.7462, 1.5898, 0.7771],
          [0.5329, 1.5898, 1.4878, 0.7720],
          [0.2181, 0.7771, 0.7720, 0.4487]],

         [[1.7462, 1.2184, 0.7771, 1.4851],
          [1.2184, 0.8796, 0.4871, 1.0763],
          [0.7771, 0.4871, 0.4487, 0.5862],
          [1.4851, 1.0763, 0.5862, 1.3174]],

         [[0.8796, 0.8847, 1.0763, 0.7569],
          [0.8847, 0.9270, 1.0779, 0.8429],
          [1.0763, 1.0779, 1.3174, 0.9163],
          [0.7569, 0.8429, 0.9163, 0.8301]],

         [[1.4878, 0.7720, 0.9209, 1.0200],
          [0.7720, 0.4487, 0.3433, 0.6240],
          [0.9209, 0.3433, 0.9459, 0.3664],
          [1.0200, 0.6240, 0.3664, 0.8860]],

         [[0.4487, 0.5862, 0.6240, 0.3562],
          [0.5862, 1.3174, 0.7153, 0.9488],
          [0.6240, 0.7153, 0.8860, 0.4078],
          [0.3562, 0.9488, 0.4078, 0.7065]],

         [[1.3174, 0.9163, 0.9488, 0.1700],
          [0.9163, 0.8301, 0.5930, 0.2164],
       

In [60]:
# Reference check: compute the Gram matrix of the top-left 2x2 window
# directly and compare it with the first window of the unfold-based result.
a_ = a[0, :2, :2, :2]
b_ = b[0, :2, :2, :2]

print(a_.shape, b_.shape)

a_ = a_.reshape(1, 2, 2*2)  # N,C,2*2
b_ = b_.reshape(1, 2, 2*2)

# Compute each reference product once and reuse it for printing and checking.
gram_a = torch.matmul(a_.transpose(1, 2), a_)
gram_b = torch.matmul(b_.transpose(1, 2), b_)

print("torch.matmul(a_.t, a_)=>\n", gram_a)
print("torch.matmul(b_.t, b_)=>\n", gram_b)

# The two results come from differently shaped matmuls, so their floating-point
# rounding is not guaranteed to be bit-identical. Compare with a tolerance
# (still elementwise, so the printed boolean matrix keeps its shape).
print(torch.isclose(gram_a[0], mm_unfold_a[0, 0]))
print(torch.isclose(gram_b[0], mm_unfold_b[0, 0]))

torch.Size([2, 2, 2]) torch.Size([2, 2, 2])
torch.matmul(a_.t, a_)=>
 tensor([[[0.2619, 0.6275, 0.5329, 0.2181],
         [0.6275, 1.7462, 1.5898, 0.7771],
         [0.5329, 1.5898, 1.4878, 0.7720],
         [0.2181, 0.7771, 0.7720, 0.4487]]])
torch.matmul(b_.t, b_)=>
 tensor([[[0.1355, 0.3471, 0.0724, 0.1180],
         [0.3471, 0.8891, 0.1805, 0.3017],
         [0.0724, 0.1805, 0.1869, 0.0756],
         [0.1180, 0.3017, 0.0756, 0.1037]]])
tensor([[True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True]])
tensor([[True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True]])


从这里可以看出来，通过unfold、reshape（view）、matmul实现了对于`N,C,H,W`形状的数据的局部（这里对应为滑窗操作的`kernel_size`）关联矩阵的计算，而且速度又快（相较于最原始朴素的“滑窗式”计算方法）。

对于运算过程代码的书写，这里不多提，这里只是对于这些东西的一个验证，结果来看，是验证了运算的合理性的。

这里给了一个启示，简单的按照矩阵的维度匹配的原则，是可以直接写出来这个变化过程的：

```
N,C,H,W --unfold(Ws*Ws)--> 
N,C*Ws*Ws,(H/Ws)*(W/Ws) --transpose--> 
N,(H/Ws)*(W/Ws),C*Ws*Ws --view--> 
N,(H/Ws)*(W/Ws),C,Ws*Ws --matmul(X^T, X)--> 
N,(H/Ws)*(W/Ws),Ws*Ws,Ws*Ws
```

这里的`H/Ws*W/Ws`（即`(H/Ws)*(W/Ws)`）实际上反映出来的是分块的数量，这里直接使用除法对应的是数据的长宽正好可以被滑窗大小整除，同时步长等于滑窗大小，没有padding的情况。

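前面给出的代码中可以看出来，这里的值对于步长为1的时候，是需要进行调整的：在没有padding和dilation时，分块数量一般为`((H-Ws)/s+1)*((W-Ws)/s+1)`（除法向下取整，`s`为步长），步长为1时即为`(H-Ws+1)*(W-Ws+1)`，对应下面代码中的`(3-1)*(4-1)`。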

```python
# Positional arguments are kernel_size=2, dilation=1, padding=0, stride=1.
unfold_func = nn.Unfold(2, 1, 0, 1)
...
# With stride 1 and no padding, a 3x4 input has (3-2+1)*(4-2+1) = (3-1)*(4-1)
# windows, so the second dimension is not (H/Ws)*(W/Ws) here.
unfold_a_reshape = unfold_a.transpose(1, 2).view(1, (3-1)*(4-1), 2, 4)
```