<a href="https://colab.research.google.com/github/rickiepark/MLQandAI/blob/main/supplementary/q12-fc-cnn-equivalence/q12-fc-cnn-equivalence.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Chapter 12 example code

When are fully connected layers and convolutional layers equivalent?

In [None]:
import torch
print(f"PyTorch version: {torch.__version__}")

PyTorch version: 2.5.1+cu121


## 1) Fully connected layer

<img src="https://github.com/rickiepark/MLQandAI/blob/main/supplementary/q12-fc-cnn-equivalence/img/fc-cnn-equivalent-1.png?raw=1" width="400px">

In [None]:
torch.manual_seed(123)

fc = torch.nn.Linear(4, 2)

inputs = torch.tensor([[1., 2., 3., 4.]])

with torch.no_grad():
    out1 = fc(inputs)

print(out1)

tensor([[-0.4775, -2.1469]])


In [None]:
fc.weight

Parameter containing:
tensor([[-0.2039,  0.0166, -0.2483,  0.1886],
        [-0.4260,  0.3665, -0.3634, -0.3975]], requires_grad=True)
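`torch.nn.Linear` computes $y = xW^\top + b$, where `fc.weight` has shape `(out_features, in_features) = (2, 4)`: each output is the dot product of the input with one weight row, plus that row's bias. A minimal sketch that reproduces the layer output by hand, assuming the same seed and layer setup as above (it uses new variable names so the notebook's `fc` is left untouched):

```python
import torch

torch.manual_seed(123)
fc_demo = torch.nn.Linear(4, 2)             # weight: (2, 4), bias: (2,)
x_demo = torch.tensor([[1., 2., 3., 4.]])   # one sample with four features

with torch.no_grad():
    out_layer = fc_demo(x_demo)
    # One dot product per weight row, plus the matching bias term
    out_manual = x_demo @ fc_demo.weight.T + fc_demo.bias

print(out_layer)
print(out_manual)
print(torch.allclose(out_layer, out_manual))
```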

## 2) Scenario 1: the kernel size equals the input size

<img src="https://github.com/rickiepark/MLQandAI/blob/main/supplementary/q12-fc-cnn-equivalence/img/fc-cnn-equivalent-2.png?raw=1" width="500px">

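By default, PyTorch convolutional layers expect inputs in NCHW format:

- N = batch size
- C = channels
- H = height
- W = width

Here the four input features are reshaped into one sample with one channel and a 2×2 spatial grid (N=1, C=1, H=2, W=2).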
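`reshape` fills the target shape in row-major order, so flat index `i` lands at row `i // W` and column `i % W` of the image. A small sketch of this mapping for the 4-feature input used here (the names `flat` and `img` are illustrative only):

```python
import torch

flat = torch.tensor([[1., 2., 3., 4.]])  # shape (1, 4): one sample, four features

# NCHW: N=1 sample, C=1 channel, H=2 rows, W=2 columns
img = flat.reshape(-1, 1, 2, 2)
n, c, h, w = img.shape
print(n, c, h, w)  # 1 1 2 2
print(img[0, 0])   # the 2x2 "image" of the first sample, first channel

# Row-major order: flat index i is stored at row i // W, column i % W
for i in range(4):
    assert img[0, 0, i // w, i % w] == flat[0, i]
```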

In [None]:
reshaped = inputs.reshape(-1, 1, 2, 2)
reshaped

tensor([[[[1., 2.],
          [3., 4.]]]])

In [None]:
conv = torch.nn.Conv2d(
    in_channels=1,
    out_channels=2,
    kernel_size=2
)

conv.weight.shape

torch.Size([2, 1, 2, 2])

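Because the weights of `Conv2d` are initialized randomly, we must overwrite the convolutional layer's random weights and biases with those of the fully connected layer to get exactly the same results.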
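The overwrite described above can be packaged as a reusable function. The helper below, `linear_to_conv2d_full`, is a hypothetical name (not a PyTorch API); it assumes `in_features == C*H*W` and that the flat input is reshaped to `(N, C, H, W)` in row-major order, exactly as in this notebook. It also checks the equivalence on a random batch with more than one channel, using new variable names so the notebook's `fc` and `conv` are not overwritten.

```python
import torch

def linear_to_conv2d_full(fc: torch.nn.Linear, c: int, h: int, w: int) -> torch.nn.Conv2d:
    """Hypothetical helper: build a Conv2d equivalent to `fc` whose kernel spans the whole input."""
    assert fc.in_features == c * h * w, "in_features must equal C*H*W"
    conv = torch.nn.Conv2d(c, fc.out_features, kernel_size=(h, w), bias=fc.bias is not None)
    with torch.no_grad():
        # (out, in) -> (out, C, H, W): the same row-major layout used to reshape the input
        conv.weight.copy_(fc.weight.reshape(fc.out_features, c, h, w))
        if fc.bias is not None:
            conv.bias.copy_(fc.bias)
    return conv

torch.manual_seed(0)
fc_big = torch.nn.Linear(12, 5)
conv_big = linear_to_conv2d_full(fc_big, c=3, h=2, w=2)

x_batch = torch.randn(8, 12)  # batch of 8 flat inputs
with torch.no_grad():
    y_fc = fc_big(x_batch)                                  # (8, 5)
    y_conv = conv_big(x_batch.reshape(-1, 3, 2, 2)).flatten(1)  # (8, 5, 1, 1) -> (8, 5)

print(torch.allclose(y_fc, y_conv, atol=1e-6))
```

Comparing with `torch.allclose` rather than `==` is a deliberate choice here: the two layers may sum the products in a different order, so tiny floating-point differences are possible in general.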

In [None]:
with torch.no_grad():
    # fc.weight: (2, 4) -> conv.weight: (out_channels, in_channels, kH, kW) = (2, 1, 2, 2)
    conv.weight.copy_(fc.weight.reshape(2, 1, 2, 2))
    conv.bias.copy_(fc.bias)

    out2 = conv(reshaped)

print(out2)

tensor([[[[-0.4775]],

         [[-2.1469]]]])


In [None]:
out1.flatten() == out2.flatten()

tensor([True, True])
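Why the outputs match: when the kernel covers the whole input, there is exactly one output position, so output channel $j$ of the convolution sums the elementwise products over every input element. Writing $A$ for `fc.weight`, $K$ for `conv.weight`, and using the row-major flat index $i = cHW + pW + q$ (the layout produced by `reshape`):

$$
y_j^{\text{conv}}
= \sum_{c=0}^{C-1}\sum_{p=0}^{H-1}\sum_{q=0}^{W-1} K_{j,c,p,q}\,x_{c,p,q} + b_j
= \sum_{i=0}^{CHW-1} A_{j,i}\,x_i + b_j
= y_j^{\text{fc}},
\qquad K_{j,c,p,q} = A_{j,\,cHW + pW + q}.
$$

In this notebook $C = 1$ and $H = W = 2$, so $K_{j,0,p,q} = A_{j,\,2p+q}$, which is exactly what reshaping `fc.weight` to `(2, 1, 2, 2)` does.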

## 3) Scenario 2: the kernel size is 1

<img src="https://github.com/rickiepark/MLQandAI/blob/main/supplementary/q12-fc-cnn-equivalence/img/fc-cnn-equivalent-3.png?raw=1" width="500px">

In [None]:
reshaped2 = inputs.reshape(-1, 4, 1, 1)
reshaped2

tensor([[[[1.]],

         [[2.]],

         [[3.]],

         [[4.]]]])

In [None]:
conv = torch.nn.Conv2d(
    in_channels=4,
    out_channels=2,
    kernel_size=1
)

conv.weight.shape

torch.Size([2, 4, 1, 1])
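The kernel shape `(out_channels, in_channels, 1, 1) = (2, 4, 1, 1)` means each output channel is a weighted sum of the 4 input channels at a single position, so nothing in this layer depends on the input being 1×1. As a sketch, with an arbitrary random input of shape `(3, 4, 5, 6)` chosen for illustration, a 1×1 convolution on a larger feature map matches applying the same `Linear` layer to the channel vector at every pixel; `permute` moves the channels to the last axis, which is where `nn.Linear` operates.

```python
import torch

torch.manual_seed(0)
fc_1x1 = torch.nn.Linear(4, 2)
conv_1x1 = torch.nn.Conv2d(4, 2, kernel_size=1)
with torch.no_grad():
    # (out, in) -> (out, in, 1, 1): each input feature becomes one input channel
    conv_1x1.weight.copy_(fc_1x1.weight.reshape(2, 4, 1, 1))
    conv_1x1.bias.copy_(fc_1x1.bias)

x_map = torch.randn(3, 4, 5, 6)  # N=3, C=4, H=5, W=6
with torch.no_grad():
    y_conv_map = conv_1x1(x_map)  # (3, 2, 5, 6)
    # NCHW -> NHWC, apply the linear layer to the channel axis, then back to NCHW
    y_fc_map = fc_1x1(x_map.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # (3, 2, 5, 6)

print(y_conv_map.shape, torch.allclose(y_conv_map, y_fc_map, atol=1e-6))
```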

In [None]:
with torch.no_grad():
    # fc.weight: (2, 4) -> conv.weight: (out_channels, in_channels, 1, 1) = (2, 4, 1, 1)
    conv.weight.copy_(fc.weight.reshape(2, 4, 1, 1))
    conv.bias.copy_(fc.bias)

    out3 = conv(reshaped2)

print(out3)

tensor([[[[-0.4775]],

         [[-2.1469]]]])


In [None]:
out1.flatten() == out3.flatten()

tensor([True, True])
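Both scenarios can also be checked without building a `Conv2d` module and overwriting its weights. The sketch below passes the reshaped fully connected parameters straight to `torch.nn.functional.conv2d`; it recreates the fully connected layer with the same seed (123) and uses new variable names so the notebook's `fc` and `conv` stay unchanged.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
fc_f = torch.nn.Linear(4, 2)
x_f = torch.tensor([[1., 2., 3., 4.]])

with torch.no_grad():
    out_fc = fc_f(x_f)  # (1, 2)
    # Scenario 1: a 2x2 kernel over a 1-channel 2x2 image
    out_k2 = F.conv2d(x_f.reshape(-1, 1, 2, 2), fc_f.weight.reshape(2, 1, 2, 2), fc_f.bias)
    # Scenario 2: a 1x1 kernel over a 4-channel 1x1 image
    out_k1 = F.conv2d(x_f.reshape(-1, 4, 1, 1), fc_f.weight.reshape(2, 4, 1, 1), fc_f.bias)

# Both convolution outputs have shape (1, 2, 1, 1); flatten(1) gives (1, 2)
print(torch.allclose(out_fc, out_k2.flatten(1)), torch.allclose(out_fc, out_k1.flatten(1)))
```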