# Convolution

A convolution is not, at its core, a matrix multiplication, which makes it somewhat tricky to implement. Three approaches:

1. Naive method: `for i, j in each spatial position`
2. img2col (unfold)
3. weight expanding

Before getting to conv, let's first look at unfold, a building block for the implementation.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# comparing methods
def diff(a, b):
    return (a-b).abs().max()

def compare(a, b):
    dif = diff(a, b)
    allclose = torch.allclose(a, b)
    print("Allclose = {}, max diff = {}".format(allclose, dif))
    
    if not allclose:
        print("!!! Not close !!!")

# Unfold

More commonly called img2col, this operation lays out the input x so that conv can be computed as a matrix multiplication.
Because conv itself reads overlapping regions of x (e.g. with stride=1, K=3), the unfolded result also contains duplicated data.
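
To make the duplication concrete, here is a minimal sketch on a toy 1d signal. The values and the names `windows`, `n_in`, `n_unf` are illustrative only and not part of the notebook.

```python
import torch

# Toy 1d signal with distinct values so duplicates are easy to spot.
x = torch.arange(5.0)              # [0., 1., 2., 3., 4.]
K, stride = 3, 1

windows = x.unfold(0, K, stride)   # [T', K] with T' = 5 - 3 + 1 = 3
print(windows)
# tensor([[0., 1., 2.],
#         [1., 2., 3.],
#         [2., 3., 4.]])

# The unfolded result exposes more elements than the input because
# overlapping windows repeat values.
n_in, n_unf = x.numel(), windows.numel()
print(n_in, n_unf)  # 5 9
```

With stride equal to K the windows would not overlap and no value would be repeated.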

## Methods

Several operations can perform img2col. Don't confuse `Tensor.unfold` with `F.unfold`.

- `Tensor.unfold`: 1d unfold (slides a window along a single dimension).
- `F.unfold`: 2d unfold. img2col is originally a 2d notion, so this one is img2col proper.
- `Tensor.as_strided`: can implement both the 1d and the 2d unfold.
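
As a quick way to tell the first two apart, the sketch below only compares output shapes. The tensor sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

K = 2

# Tensor.unfold: slides a window along ONE chosen dimension and appends
# a new trailing dimension of size K.
x1d = torch.rand(1, 3, 5)          # [B, C, T]
u1 = x1d.unfold(-1, K, 1)          # [B, C, T', K]
print(u1.shape)                    # torch.Size([1, 3, 4, 2])

# F.unfold: extracts KxK patches from a 4d input and flattens each patch
# (across channels) into one column.
x2d = torch.rand(1, 3, 5, 5)       # [B, C, H, W]
u2 = F.unfold(x2d, K)              # [B, C*K*K, L]
print(u2.shape)                    # torch.Size([1, 12, 16])
```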

### 1d unfold

Oddly, there is no `torch.unfold`.

In [3]:
B = 1
K = 2
C = 3
T = 5
stride = 1
padding = 0

In [4]:
x = torch.rand(B, C, T)
x

tensor([[[0.1221, 0.1485, 0.8846, 0.2949, 0.6597],
         [0.8651, 0.3928, 0.7882, 0.6540, 0.7690],
         [0.1089, 0.0725, 0.2225, 0.2966, 0.4136]]])

In [5]:
# args: unfold last dim, kernel size 2, stride 1. no padding.
unf = x.unfold(-1, K, stride)
print(unf.shape)  # [B, C, T', K]; K=kernel, T'=T-(K-1). With padding=0 and stride=1, T shrinks by K-1.

# for visualization, swap channel dim (1) and time dim (2).
unf.transpose(1, 2)  # [B, T', C, K]

torch.Size([1, 3, 4, 2])


tensor([[[[0.1221, 0.1485],
          [0.8651, 0.3928],
          [0.1089, 0.0725]],

         [[0.1485, 0.8846],
          [0.3928, 0.7882],
          [0.0725, 0.2225]],

         [[0.8846, 0.2949],
          [0.7882, 0.6540],
          [0.2225, 0.2966]],

         [[0.2949, 0.6597],
          [0.6540, 0.7690],
          [0.2966, 0.4136]]]])

### 1d unfold by as_strided

Most PyTorch functions that return a view are implemented internally on top of this function. It is correspondingly powerful, but also a bit awkward to use.

```
torch.as_strided(input, size, stride, storage_offset=0) → Tensor
```

- torch.as_strided(input, ...) can also be written as input.as_strided(...).
- size is the output size.
- stride is documented as "the stride of the output tensor", which is the slightly tricky part.
  - It is how far the index into the input tensor moves when the index along that dimension of the output tensor changes by one.
  - In the example below, `stride = (C*T, T, 1, 1)`, so advancing the first output index (batch index) moves `C*T` elements in the input. Since the input is `[B, C, T]`, this naturally lands on the next batch.
  - It helps to think of the input tensor as flattened (this picture holds when the input is contiguous).
- storage_offset is the starting offset.
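
The "think of the input as flattened" rule can be checked directly. This is a minimal sketch assuming a contiguous input. The sizes and the picked index `(b, c, t, k)` are arbitrary.

```python
import torch

B, C, T, K = 2, 3, 5, 2
Tp = T - (K - 1)

# arange makes every element equal to its own flat offset.
x = torch.arange(B * C * T, dtype=torch.float32).view(B, C, T)
size = (B, C, Tp, K)
stride = (C * T, T, 1, 1)
out = x.as_strided(size, stride)

# Element [b, c, t, k] of the output is read from the flattened input at
# offset b*stride[0] + c*stride[1] + t*stride[2] + k*stride[3].
flat = x.flatten()
b, c, t, k = 1, 2, 3, 1
offset = b * stride[0] + c * stride[1] + t * stride[2] + k * stride[3]
print(offset)                                        # 29
print(out[b, c, t, k].item(), flat[offset].item())   # 29.0 29.0
```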

In [6]:
Tp = T-(K-1)
unf_by_as = x.as_strided((B, C, Tp, K), (C*T, T, 1, 1))
unf_by_as

tensor([[[[0.1221, 0.1485],
          [0.1485, 0.8846],
          [0.8846, 0.2949],
          [0.2949, 0.6597]],

         [[0.8651, 0.3928],
          [0.3928, 0.7882],
          [0.7882, 0.6540],
          [0.6540, 0.7690]],

         [[0.1089, 0.0725],
          [0.0725, 0.2225],
          [0.2225, 0.2966],
          [0.2966, 0.4136]]]])

In [7]:
compare(unf, unf_by_as)

Allclose = True, max diff = 0.0
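
The hard-coded strides `(C*T, T, 1, 1)` assume a contiguous input and a window step of 1. A slightly more general sketch reads the strides from the tensor itself. `unfold1d_as_strided` and its `step` parameter are hypothetical names introduced here, not part of PyTorch.

```python
import torch

def unfold1d_as_strided(x, K, step=1):
    """Hypothetical helper: sliding windows over the last dim of [B, C, T]."""
    B, C, T = x.shape
    Tp = (T - K) // step + 1
    sb, sc, st = x.stride()   # element strides of the input itself
    # Window origin advances by `step` elements, taps inside a window by 1.
    return x.as_strided((B, C, Tp, K), (sb, sc, st * step, st))

# Non-contiguous input: [B, C, T] = [2, 3, 5] obtained through a transpose.
x = torch.rand(2, 5, 3).transpose(1, 2)
for K, step in [(2, 1), (3, 2)]:
    same = torch.equal(unfold1d_as_strided(x, K, step), x.unfold(-1, K, step))
    print(K, step, same)
```

Using `x.stride()` instead of computing strides from the shape is what keeps the helper correct for transposed or sliced inputs.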


### 2d unfold

`[B, C, H, W]` tensor -> `[B, C*K*K, L]` tensor; `L = H_out * W_out`.

- Each patch is flattened into `C*K*K` values so that conv can be computed with a matmul.
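
For stride and padding other than the defaults, `L` follows the usual output-size formula per spatial axis. The sketch below checks that formula against `F.unfold`. `out_size` is a hypothetical helper and the input sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

def out_size(size, K, stride=1, padding=0, dilation=1):
    # Number of window positions along one spatial axis.
    return (size + 2 * padding - dilation * (K - 1) - 1) // stride + 1

B, C, H, W, K = 1, 4, 7, 9, 3
x = torch.rand(B, C, H, W)
for stride, padding in [(1, 0), (2, 0), (2, 1)]:
    unf = F.unfold(x, K, padding=padding, stride=stride)
    L = out_size(H, K, stride, padding) * out_size(W, K, stride, padding)
    print(stride, padding, tuple(unf.shape), L)
    assert unf.shape == (B, C * K * K, L)
```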

In [8]:
B = 1
K = 2
C = 4
size = 3  # H, W
stride = 1
padding = 0

In [9]:
x = torch.rand(B, C, size, size)
x

tensor([[[[0.3915, 0.0511, 0.7724],
          [0.0454, 0.9542, 0.4523],
          [0.7951, 0.5156, 0.0931]],

         [[0.3165, 0.5458, 0.0030],
          [0.9388, 0.8726, 0.5483],
          [0.0543, 0.2777, 0.2357]],

         [[0.5484, 0.5781, 0.1654],
          [0.7392, 0.6313, 0.4918],
          [0.9587, 0.9745, 0.0871]],

         [[0.8928, 0.4683, 0.2547],
          [0.9354, 0.9179, 0.2860],
          [0.4460, 0.9646, 0.0382]]]])

In [10]:
unf = F.unfold(x, K, padding=padding, stride=stride)
print(unf.shape)  # [B, C*K*K, L]; L = H_out * W_out
# With stride=1, padding=0 and H=W=size, L = (size - K + 1) ** 2
unf

torch.Size([1, 16, 4])


tensor([[[0.3915, 0.0511, 0.0454, 0.9542],
         [0.0511, 0.7724, 0.9542, 0.4523],
         [0.0454, 0.9542, 0.7951, 0.5156],
         [0.9542, 0.4523, 0.5156, 0.0931],
         [0.3165, 0.5458, 0.9388, 0.8726],
         [0.5458, 0.0030, 0.8726, 0.5483],
         [0.9388, 0.8726, 0.0543, 0.2777],
         [0.8726, 0.5483, 0.2777, 0.2357],
         [0.5484, 0.5781, 0.7392, 0.6313],
         [0.5781, 0.1654, 0.6313, 0.4918],
         [0.7392, 0.6313, 0.9587, 0.9745],
         [0.6313, 0.4918, 0.9745, 0.0871],
         [0.8928, 0.4683, 0.9354, 0.9179],
         [0.4683, 0.2547, 0.9179, 0.2860],
         [0.9354, 0.9179, 0.4460, 0.9646],
         [0.9179, 0.2860, 0.9646, 0.0382]]])

### 2d unfold by as_strided

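* Unlike the 1d case, a single as_strided call cannot reproduce the unfold output directly.
* Go `[B, C, H, W]` -> `[B, C, H', W', K, K]` first, then merge the dims into `C*K*K` and `H'*W'`.
* In practice, when implementing conv there is probably no need to match the unfold shape exactly.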

In [11]:
# n_flat = C*K*K
# L = (size - K + 1) ** 2
sizep = size - K + 1

strided = x.as_strided((B, C, sizep, sizep, K, K), (C*size*size, size*size, size, 1, size, 1))

# match shape
strided = strided.permute(0, 1, 4, 5, 2, 3)
strided = strided.reshape(B, C*K*K, sizep*sizep)
print(strided.shape)
strided

torch.Size([1, 16, 4])


tensor([[[0.3915, 0.0511, 0.0454, 0.9542],
         [0.0511, 0.7724, 0.9542, 0.4523],
         [0.0454, 0.9542, 0.7951, 0.5156],
         [0.9542, 0.4523, 0.5156, 0.0931],
         [0.3165, 0.5458, 0.9388, 0.8726],
         [0.5458, 0.0030, 0.8726, 0.5483],
         [0.9388, 0.8726, 0.0543, 0.2777],
         [0.8726, 0.5483, 0.2777, 0.2357],
         [0.5484, 0.5781, 0.7392, 0.6313],
         [0.5781, 0.1654, 0.6313, 0.4918],
         [0.7392, 0.6313, 0.9587, 0.9745],
         [0.6313, 0.4918, 0.9745, 0.0871],
         [0.8928, 0.4683, 0.9354, 0.9179],
         [0.4683, 0.2547, 0.9179, 0.2860],
         [0.9354, 0.9179, 0.4460, 0.9646],
         [0.9179, 0.2860, 0.9646, 0.0382]]])

In [12]:
compare(unf, strided)

Allclose = True, max diff = 0.0
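
The same idea extends to a window step larger than 1 and to non-square inputs. This is a sketch for the padding=0 case only. `unfold2d_as_strided` and `step` are hypothetical names.

```python
import torch
import torch.nn.functional as F

def unfold2d_as_strided(x, K, step=1):
    """Hypothetical helper mirroring F.unfold(x, K, stride=step) with padding=0."""
    B, C, H, W = x.shape
    Hp = (H - K) // step + 1
    Wp = (W - K) // step + 1
    sb, sc, sh, sw = x.stride()
    # [B, C, H', W', K, K]: window origin moves by `step`, taps move by 1.
    win = x.as_strided((B, C, Hp, Wp, K, K),
                       (sb, sc, sh * step, sw * step, sh, sw))
    # Reorder to [B, C, K, K, H', W'] and merge dims (reshape copies here).
    return win.permute(0, 1, 4, 5, 2, 3).reshape(B, C * K * K, Hp * Wp)

x = torch.rand(2, 4, 7, 9)
for K, step in [(2, 1), (3, 2)]:
    same = torch.equal(unfold2d_as_strided(x, K, step), F.unfold(x, K, stride=step))
    print(K, step, same)
```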


# Back to Convolution

Now that we have looked at the building blocks, let's implement Conv.

In [13]:
B = 1
C_in = 4
C_out = 8
K = 3  # conv_naive below asserts K == 3
size = 8
stride = 1
padding = 0

In [14]:
x = torch.rand(B, C_in, size, size)
W = torch.rand(C_out, C_in, K, K)
b = torch.rand(C_out)

In [15]:
def conv(x):
    return F.conv2d(x, W, b)

## Naive conv

In [16]:
def conv_naive(x):
    assert stride == 1 and padding == 0 and K == 3
    Wt = W.permute(1, 2, 3, 0)  # [C_in, K, K, C_out]
    Wt_flat = Wt.flatten(0, 2) # [C_in*K*K, C_out]
    out = torch.empty(B, C_out, size-(K-1), size-(K-1))
    s = K // 2
    for i in range(s, size-s):
        for j in range(s, size-s):
            si = i-s
            sj = j-s
            px = x[:, :, si:si+K, sj:sj+K] # [B, C_in, K, K]
            r = (px.flatten(1) @ Wt_flat) + b
            out[:, :, si, sj] = r
            
    return out

In [17]:
compare(conv(x), conv_naive(x))

Allclose = True, max diff = 3.814697265625e-06


## Unfold convs

In [18]:
x.shape, W.shape

(torch.Size([1, 4, 8, 8]), torch.Size([8, 4, 3, 3]))

In [19]:
def conv_unfold(x):
    unf = F.unfold(x, K)  # [B, C_in*K*K, L]; L = H_out * W_out
    weight = W.flatten(1)  # [C_out, C_in*K*K]
    
    out = weight @ unf + b[None, ..., None]  # [B, C_out, L]
    sizep = size - K + 1
    return out.view(B, C_out, sizep, sizep)  # [B, C_out, H', W']

In [20]:
compare(conv(x), conv_unfold(x))

Allclose = True, max diff = 2.86102294921875e-06
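
`conv_unfold` above relies on the notebook's globals and on stride=1, padding=0. A version that takes everything as arguments might look like the sketch below. `conv2d_unfold` is a hypothetical name, and square kernels without dilation or groups are assumed.

```python
import torch
import torch.nn.functional as F

def conv2d_unfold(x, weight, bias=None, stride=1, padding=0):
    """Sketch: conv2d as unfold + matmul, square kernels only."""
    B, _, H, W = x.shape
    C_out, _, K, _ = weight.shape
    Hp = (H + 2 * padding - K) // stride + 1
    Wp = (W + 2 * padding - K) // stride + 1
    cols = F.unfold(x, K, padding=padding, stride=stride)  # [B, C_in*K*K, L]
    out = weight.flatten(1) @ cols                          # [B, C_out, L]
    if bias is not None:
        out = out + bias[None, :, None]
    return out.view(B, C_out, Hp, Wp)

x = torch.rand(2, 4, 8, 8)
w = torch.rand(8, 4, 3, 3)
b = torch.rand(8)
for stride, padding in [(1, 0), (2, 1)]:
    ref = F.conv2d(x, w, b, stride=stride, padding=padding)
    got = conv2d_unfold(x, w, b, stride, padding)
    print(stride, padding, tuple(got.shape), torch.allclose(ref, got, atol=1e-4))
```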


### Unfold by as_strided

- Using as_strided, without the reshape.

In [21]:
x.shape, size, K, C_in, W.shape

(torch.Size([1, 4, 8, 8]), 8, 3, 4, torch.Size([8, 4, 3, 3]))

In [22]:
def conv_unfold_as(x):
    sizep = size - K + 1
    strided = x.as_strided((B, C_in, sizep, sizep, K, K), (C_in*size*size, size*size, size, 1, size, 1))
    strided = strided.permute(0, 2, 3, 1, 4, 5)  # [B, H', W', C, K, K]
    
    # According to https://discuss.pytorch.org/t/does-as-strided-copy-data/24503/5,
    # gemm should be preferred over einsum. But wouldn't using gemm make this much more complicated?
    out = torch.einsum('bhwckl,ockl->bohw', strided, W)
    out = out + b[None, ..., None, None]
    
    return out

In [23]:
compare(conv(x), conv_unfold_as(x))

Allclose = True, max diff = 2.86102294921875e-06


In [24]:
compare(conv_unfold(x), conv_unfold_as(x))

Allclose = True, max diff = 0.0
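
One practical difference between the two unfold variants is when data gets copied. The sketch below is an illustration using `data_ptr()` and does not reproduce the linked thread. It shows that the as_strided window view shares storage with `x`, while flattening it into the `[B, C*K*K, L]` layout forces a copy.

```python
import torch

x = torch.rand(1, 4, 8, 8)
K = 3
sizep = x.shape[-1] - K + 1
sb, sc, sh, sw = x.stride()

win = x.as_strided((1, 4, sizep, sizep, K, K), (sb, sc, sh, sw, sh, sw))

# The strided view reuses x's storage: same base pointer, yet it exposes
# more logical elements than x actually stores.
print(win.data_ptr() == x.data_ptr())   # True
print(win.numel(), x.numel())           # 1296 256

# The [B, C*K*K, L] layout cannot be expressed as a view of x,
# so reshape has to materialize a copy.
cols = win.permute(0, 1, 4, 5, 2, 3).reshape(1, 4 * K * K, sizep * sizep)
print(cols.data_ptr() == x.data_ptr())  # False
```

This is one reason to contract the 6d view directly, as `conv_unfold_as` does with einsum. Whether that ends up faster than an explicit copy followed by a matmul depends on the backend and is not measured here.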


## Expand conv

Ref: https://github.com/pytorch/fairseq/blob/v0.9.0/fairseq/modules/dynamic_convolution.py#L173

Skipped for now ...
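
As a placeholder, here is a rough 1d sketch of one possible reading of "weight expanding": the kernel is expanded into a banded matrix so that the whole convolution becomes a single matmul. This is an assumption made for illustration (single channel, stride=1, padding=0) and is not taken from the referenced fairseq code. The name `W_exp` is made up.

```python
import torch
import torch.nn.functional as F

T, K = 6, 3
x = torch.rand(T)
w = torch.rand(K)
Tp = T - K + 1

# Expand the K-tap kernel into a banded [T', T] matrix:
# row t holds w at columns t .. t+K-1 and zeros elsewhere.
W_exp = torch.zeros(Tp, T)
for t in range(Tp):
    W_exp[t, t:t + K] = w

out = W_exp @ x                                            # [T']
ref = F.conv1d(x.view(1, 1, T), w.view(1, 1, K)).view(Tp)  # reference result
print(torch.allclose(out, ref, atol=1e-6))
```

Here the input is left untouched and the duplication moves into the weight matrix, which is the mirror image of what unfold does.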