# Multiple Input and Output Channels

Input $X$: $c_i \times n_h \times n_w$

Kernel $W$: $c_o \times c_i \times k_h \times k_w$

Bias $B$: $c_o$ (one scalar per output channel, broadcast over every spatial position)

Output $Y$: $c_o \times m_h \times m_w$

$$
Y = X \star W + B
$$
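Written out element by element, the compact formula can be read as follows. This is a sketch that assumes stride 1 and no padding, with one bias scalar per output channel; the index names $o, c, u, v, a, b$ are introduced here for illustration only.

$$
Y_{o,u,v} = \sum_{c=0}^{c_i-1} \sum_{a=0}^{k_h-1} \sum_{b=0}^{k_w-1} X_{c,\,u+a,\,v+b} \, W_{o,c,a,b} + B_o,
\qquad m_h = n_h - k_h + 1, \quad m_w = n_w - k_w + 1
$$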

Time complexity: $O(c_i c_o \cdot k_h k_w \cdot m_h m_w)$
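To make the operation count concrete, the sketch below counts multiply-accumulate operations for one forward pass on one sample. It assumes stride 1 and no padding; the helper names `conv2d_output_size` and `conv2d_mac_count` are hypothetical, and the layer sizes match the small random example used later in these notes.

```python
def conv2d_output_size(n_h, n_w, k_h, k_w):
    """Output height and width, assuming stride 1 and no padding."""
    return n_h - k_h + 1, n_w - k_w + 1


def conv2d_mac_count(c_i, c_o, k_h, k_w, m_h, m_w):
    """Multiply-accumulate operations for one sample: c_i * c_o * k_h * k_w * m_h * m_w."""
    return c_i * c_o * k_h * k_w * m_h * m_w


# 3 input channels, 2 output channels, 2x2 kernel, 3x3 input
m_h, m_w = conv2d_output_size(3, 3, 2, 2)
macs = conv2d_mac_count(c_i=3, c_o=2, k_h=2, k_w=2, m_h=m_h, m_w=m_w)
print(m_h, m_w, macs)  # 2 2 96
```

Each of the $c_o \cdot m_h \cdot m_w$ output elements needs $c_i \cdot k_h \cdot k_w$ multiplications, which is where the product in the complexity formula comes from.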

FLOPS stands for floating-point operations per second.

| Unit | Floating-point operations per second |
| --- | --- |
| MFLOPS | $10^6$ |
| GFLOPS | $10^9$ |
| TFLOPS | $10^{12}$ |
| PFLOPS | $10^{15}$ |
| EFLOPS | $10^{18}$ |

CPU: 0.15 TFLOPS

GPU: no figure found
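As a rough illustration of how these units relate to runtime, the sketch below converts a throughput into plain FLOPS and derives an idealized lower bound on execution time. The helper names and the workload size `ops` are hypothetical, and the estimate assumes peak throughput is sustained and ignores memory traffic, so real runtimes will be longer.

```python
# Unit prefixes from the table above, expressed as operations per second
FLOPS_PER_UNIT = {
    "MFLOPS": 1e6,
    "GFLOPS": 1e9,
    "TFLOPS": 1e12,
    "PFLOPS": 1e15,
    "EFLOPS": 1e18,
}


def to_flops(value, unit):
    """Convert a throughput such as (0.15, 'TFLOPS') into operations per second."""
    return value * FLOPS_PER_UNIT[unit]


def ideal_seconds(num_operations, value, unit):
    """Idealized runtime: number of operations divided by peak throughput."""
    return num_operations / to_flops(value, unit)


ops = 3e9  # hypothetical workload of 3 billion floating-point operations
print(ideal_seconds(ops, 0.15, "TFLOPS"))  # roughly 0.02 seconds on the CPU figure above
```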

In [10]:
import torch
from torch import nn
import d2l.torch as d2l

#### Implementing cross-correlation with multiple input channels

In [15]:
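def corr2d(X, K):
    """2-D cross-correlation of a single-channel input X with kernel K (stride 1, no padding)."""
    h, w = K.shape
    Y = torch.zeros(X.shape[0] - h + 1, X.shape[1] - w + 1)
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # Elementwise product of the current window with the kernel, then sum
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y

def multi_corr2d_in(X, K):  # note: X and K are both 3-D tensors
    # Cross-correlate each input channel with its kernel slice, then sum over channels
    return sum(corr2d(x, k) for x, k in zip(X, K))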

In [19]:
X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[1.0, 2.0], [3.0, 4.0]]])
multi_corr2d_in(X, K)

tensor([[ 56.,  72.],
        [104., 120.]])
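The hand-written result can be cross-checked against PyTorch's built-in operator. This is a minimal sketch: `torch.nn.functional.conv2d` computes cross-correlation and expects a batched input of shape `(batch, c_i, h, w)` and a weight of shape `(c_o, c_i, k_h, k_w)`, so a leading dimension of size 1 is added to both tensors here.

```python
import torch
import torch.nn.functional as F

X = torch.tensor([[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]],
                  [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
K = torch.tensor([[[0.0, 1.0], [2.0, 3.0]],
                  [[1.0, 2.0], [3.0, 4.0]]])

# X: (2, 3, 3) -> (1, 2, 3, 3); K: (2, 2, 2) -> (1, 2, 2, 2), i.e. one output channel
Y = F.conv2d(X.unsqueeze(0), K.unsqueeze(0))

# Drop the batch and output-channel dimensions to get a 2x2 result
Y = Y.squeeze(0).squeeze(0)
print(Y)
```

For the top-left element, channel 0 contributes $0\cdot0 + 1\cdot1 + 3\cdot2 + 4\cdot3 = 19$ and channel 1 contributes $1\cdot1 + 2\cdot2 + 4\cdot3 + 5\cdot4 = 37$, giving $56$.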

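#### Implementing multiple input and multiple output channels

`torch.stack` joins a sequence of tensors along a new dimension; with `dim=0`, the per-output-channel results are stacked along the first axis.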

In [20]:
def multi_corr2d_in_out(X, K):  # note: X is a 3-D tensor, K is a 4-D tensor
    return torch.stack([multi_corr2d_in(X, k) for k in K], 0)

Input $X$: $c_i \times n_h \times n_w$

Kernel $W$: $c_o \times c_i \times k_h \times k_w$

In [25]:
X = torch.normal(mean=0, std=1, size=(3, 3, 3))
K = torch.normal(mean=0, std=1, size=(2, 3, 2, 2))
multi_corr2d_in_out(X, K).shape

torch.Size([2, 2, 2])
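The same shapes can be inspected on PyTorch's built-in layer, which also shows how the bias enters. This is a sketch, assuming the default `nn.Conv2d` settings (stride 1, no padding, bias enabled); the manual computation with `F.conv2d` and broadcasting is included only for comparison.

```python
import torch
import torch.nn.functional as F
from torch import nn

# c_i = 3 input channels, c_o = 2 output channels, 2x2 kernel
conv = nn.Conv2d(in_channels=3, out_channels=2, kernel_size=2)
print(conv.weight.shape)  # torch.Size([2, 3, 2, 2]), i.e. (c_o, c_i, k_h, k_w)
print(conv.bias.shape)    # torch.Size([2]), one scalar per output channel

X = torch.normal(mean=0, std=1, size=(3, 3, 3))
Y = conv(X.unsqueeze(0)).squeeze(0)  # add, then remove, the batch dimension
print(Y.shape)            # torch.Size([2, 2, 2])

# Manual version: cross-correlate without bias, then broadcast the bias
# from shape (c_o,) to (c_o, 1, 1) over all spatial positions
Y_manual = F.conv2d(X.unsqueeze(0), conv.weight).squeeze(0) + conv.bias.reshape(-1, 1, 1)
print(torch.allclose(Y, Y_manual, atol=1e-6))
```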

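#### $1 \times 1$ convolutional layer

A $1 \times 1$ kernel cannot capture patterns between neighbouring pixels; instead it mixes information across channels at each pixel position. A $1 \times 1$ convolutional layer is therefore equivalent to a fully connected layer applied independently at every pixel, which is verified below.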

In [26]:
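def multi_corr2d_in_out_1x1(X, K):
    co, ci, h, w = K.shape[0], X.shape[0], X.shape[1], X.shape[2]
    X = X.reshape((ci, -1))  # flatten pixels: (c_i, h * w)
    K = K.reshape((co, -1))  # drop the 1x1 spatial dims: (c_o, c_i)
    Y = torch.matmul(K, X)   # fully connected layer applied at every pixel: (c_o, h * w)
    return Y.reshape((co, h, w))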

Verify the result:

In [27]:
X = torch.normal(mean=0, std=1, size=(3, 3, 3))
K = torch.normal(mean=0, std=1, size=(2, 3, 1, 1))
Y1 = multi_corr2d_in_out(X, K)
Y2 = multi_corr2d_in_out_1x1(X, K)
print(torch.abs(Y1-Y2).sum())

tensor(3.3528e-07)
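The small nonzero difference above is floating-point rounding, since the two implementations sum the same products in a different order. The sketch below repeats the check with built-in operators and a tolerance-based comparison: it treats every pixel as one sample with $c_i$ features and applies `F.linear`, then compares against `F.conv2d`. The seed and the tolerance `atol=1e-5` are arbitrary choices made for this illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.normal(mean=0, std=1, size=(3, 3, 3))     # (c_i, h, w)
K = torch.normal(mean=0, std=1, size=(2, 3, 1, 1))  # (c_o, c_i, 1, 1)

# 1x1 convolution via the built-in operator
Y_conv = F.conv2d(X.unsqueeze(0), K).squeeze(0)     # (c_o, h, w)

# Fully connected view: rows are pixels, columns are input channels
pixels = X.reshape(3, -1).T                         # (h * w, c_i)
W = K.reshape(2, 3)                                 # (c_o, c_i)
Y_fc = F.linear(pixels, W).T.reshape(2, 3, 3)       # back to (c_o, h, w)

print(torch.allclose(Y_conv, Y_fc, atol=1e-5))
```

Comparing with `torch.allclose` avoids having to judge by eye whether a summed absolute error is small enough.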
