In [1]:
import torch
import torch.nn as nn

In [2]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(BasicBlock, self).__init__()

        # 합성곱 층 정의
        self.c1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, padding=1
        )
        self.c2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=kernel_size, padding=1
        )

        self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # 배치 정규화 층 정의
        self.bn1 = nn.BatchNorm2d(num_features=out_channels)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)

        self.relu = nn.ReLU()

    def forward(self, x):
        # 스킵 커넥션을 위해 초기 입력 저장
        x_ = x

        # ResNet 기본 블록에서 F(x) 부분
        x = self.c1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.c2(x)
        x = self.bn2(x)

        # 합성곱 결과와 입력 채널 수 맞춤
        x_ = self.downsample(x_)

        # 스킵 커넥션
        x += x_
        x = self.relu(x)

        return x

In [3]:
class ResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet, self).__init__()

        self.b1 = BasicBlock(in_channels=3, out_channels=64)
        self.b2 = BasicBlock(in_channels=64, out_channels=128)
        self.b3 = BasicBlock(in_channels=128, out_channels=256)

        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(in_features=4096, out_features=2048)
        self.fc2 = nn.Linear(in_features=2048, out_features=512)
        self.fc3 = nn.Linear(in_features=512, out_features=num_classes)

        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.b1(x)
        x = self.pool(x)
        x = self.b2(x)
        x = self.pool(x)
        x = self.b3(x)
        x = self.pool(x)

        x = torch.flatten(x, start_dim=1)

        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)

        return x

In [4]:
import tqdm
from torch.optim.adam import Adam
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import (
    Compose,
    Normalize,
    RandomCrop,
    RandomHorizontalFlip,
    ToTensor,
)

In [5]:
transforms = Compose(
    [
        RandomCrop((32, 32), padding=4),
        RandomHorizontalFlip(p=0.5),
        ToTensor(),
        Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
    ]
)

In [6]:
training_data = CIFAR10(root="./", train=True, download=True, transform=transforms)
test_data = CIFAR10(root="./", train=False, download=True, transform=transforms)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|███████████████████████████████████████████████████████████████████████████████| 170M/170M [00:11<00:00, 14.6MB/s]


Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified


In [7]:
train_loader = DataLoader(training_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [9]:
model = ResNet(num_classes=10)
model.to(device)

ResNet(
  (b1): BasicBlock(
    (c1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b2): BasicBlock(
    (c1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b3): BasicBlock(
    (c1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


In [10]:
lr = 1e-4
optim = Adam(model.parameters(), lr=lr)

In [11]:
for epoch in range(30):
    iterator = tqdm.tqdm(train_loader)
    for data, label in iterator:
        optim.zero_grad()

        preds = model(data.to(device))

        loss = nn.CrossEntropyLoss()(preds, label.to(device))
        loss.backward()
        optim.step()

        iterator.set_description(f"epoch: {epoch + 1} loss: {loss.item()}")

epoch: 1 loss: 0.9666403532028198: 100%|███████████████████████████████████████████| 1563/1563 [00:33<00:00, 47.12it/s]
epoch: 2 loss: 1.1442512273788452: 100%|███████████████████████████████████████████| 1563/1563 [00:31<00:00, 48.95it/s]
epoch: 3 loss: 0.6006902456283569: 100%|███████████████████████████████████████████| 1563/1563 [00:32<00:00, 48.79it/s]
epoch: 4 loss: 0.8722031116485596: 100%|███████████████████████████████████████████| 1563/1563 [00:31<00:00, 48.89it/s]
epoch: 5 loss: 0.43194618821144104: 100%|██████████████████████████████████████████| 1563/1563 [00:32<00:00, 48.74it/s]
epoch: 6 loss: 0.3194368779659271: 100%|███████████████████████████████████████████| 1563/1563 [00:31<00:00, 48.97it/s]
epoch: 7 loss: 0.5391959547996521: 100%|███████████████████████████████████████████| 1563/1563 [00:31<00:00, 48.89it/s]
epoch: 8 loss: 0.3684583008289337: 100%|███████████████████████████████████████████| 1563/1563 [00:31<00:00, 49.16it/s]
epoch: 9 loss: 0.5306670665740967: 100%|

In [12]:
torch.save(model.state_dict(), "ResNet.pth")

In [13]:
model.load_state_dict(torch.load("ResNet.pth", weights_only=True, map_location=device))

<All keys matched successfully>

In [14]:
num_corr = 0

with torch.no_grad():
    for data, label in test_loader:
        output = model(data.to(device))
        preds = output.data.max(1)[1]
        corr = preds.eq(label.to(device).data).sum().item()
        num_corr += corr

    print(f"Accuracy: {num_corr / len(test_data)}")

Accuracy: 0.8833
