In [1]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [17]:
!nvidia-smi

Wed Jul  3 04:37:17 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   66C    P0              30W /  70W |    199MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
# import

import numpy as np
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt
import torchinfo

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
from pytz import timezone

In [4]:
# 하이퍼파라미터 설정
RANDOM_SEED = 4242
LEARNING_RATE = 0.01
BATCH_SIZE = 32
EPOCHS = 90
IMG_SIZE = 227
NUM_CLASSES = 1000

In [5]:
# 모델의 정확도를 계산하는 함수
def get_accuracy(model, data_loader, device):
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        model.eval()
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)
            probabilities = F.softmax(model(images), dim=-1)
            _, predicted_labels = torch.max(probabilities, 1)

            total_predictions += labels.size(0)
            correct_predictions += (predicted_labels == labels).sum()
    return correct_predictions.float() / total_predictions

In [6]:
# 학습 손실과 검증 손실을 시각화
def plot_loss(train_loss, val_loss):
    plt.style.use("grayscale")
    train_loss = np.array(train_loss)
    val_loss = np.array(val_loss)
    fig, ax = plt.subplots(1, 1, figsize=(8, 4.5))
    ax.plot(train_loss, color="green", label="Training Loss")
    ax.plot(val_loss, color="red", label="Validation Loss")
    ax.set(title="Loss Over Epochs", xlabel="EPOCH", ylabel="LOSS")
    ax.legend()
    fig.show()
    plt.style.use("default")

In [7]:
# 모델 학습 함수
def train(train_loader, model, criterion, optimizer, device):
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        loss = criterion(logits, labels)
        total_loss += loss.item() * images.size(0)
        loss.backward()
        optimizer.step()
    epoch_loss = total_loss / len(train_loader.dataset)
    return model, optimizer, epoch_loss

In [8]:
# 검증 데이터셋을 사용하여 모델의 성능을 평가
def validate(valid_loader, model, criterion, device):
    model.eval()
    total_loss = 0

    for images, labels in valid_loader:
        images = images.to(device)
        labels = labels.to(device)

        # 순전파와 손실 기록하기
        logits = model(images)
        loss = criterion(logits, labels)
        total_loss += loss.item() * images.size(0)

    epoch_loss = total_loss / len(valid_loader.dataset)
    return model, epoch_loss

In [9]:
# 전체 학습 루프
def training_loop(
    model,
    criterion,
    optimizer,
    train_loader,
    valid_loader,
    epochs,
    device,
    print_every=1,
):
    train_losses = []
    valid_losses = []

    for epoch in range(epochs):
        # training
        model, optimizer, train_loss = train(
            train_loader, model, criterion, optimizer, device
        )
        train_losses.append(train_loss)

        # validation
        with torch.no_grad():
            model, valid_loss = validate(valid_loader, model, criterion, device)
            valid_losses.append(valid_loss)

        if epoch % print_every == (print_every - 1):

            train_acc = get_accuracy(model, train_loader, device=device)
            valid_acc = get_accuracy(model, valid_loader, device=device)

            print(
                datetime.now(timezone("Asia/Seoul")).time().replace(microsecond=0),
                "--- ",
                f"Epoch: {epoch}\t"
                f"Train loss: {train_loss:.4f}\t"
                f"Valid loss: {valid_loss:.4f}\t"
                f"Train accuracy: {100 * train_acc:.2f}\t"
                f"Valid accuracy: {100 * valid_acc:.2f}",
            )

    plot_loss(train_losses, valid_losses)

    return model, optimizer, (train_losses, valid_losses)

In [18]:
# AlexNet 논문과 최대한 유사하게 구현
class AlexNet(nn.Module):
    def __init__(self, num_classes):
        super(AlexNet, self).__init__()
        self.conv1_u = nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=0)
        self.conv1_d = nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)

        self.conv2_u = nn.Conv2d(48, 128, kernel_size=5, stride=1, padding=2)
        self.conv2_d = nn.Conv2d(48, 128, kernel_size=5, stride=1, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)

        self.conv3_u = nn.Conv2d(128 * 2, 192, kernel_size=3, stride=1, padding=1)
        self.conv3_d = nn.Conv2d(128 * 2, 192, kernel_size=3, stride=1, padding=1)

        self.conv4_u = nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1)
        self.conv4_d = nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1)

        self.conv5_u = nn.Conv2d(192, 128, kernel_size=3, stride=1, padding=1)
        self.conv5_d = nn.Conv2d(192, 128, kernel_size=3, stride=1, padding=1)
        self.pool5 = nn.MaxPool2d(kernel_size=3, stride=2)

        self.fc6_u = nn.Linear(128 * 6 * 6 * 2, 2048)
        self.fc6_d = nn.Linear(128 * 6 * 6 * 2, 2048)

        self.fc7_u = nn.Linear(2048 * 2, 2048)
        self.fc7_d = nn.Linear(2048 * 2, 2048)

        self.fc8 = nn.Linear(2048 * 2, NUM_CLASSES)

    def forward(self, x):
        # Conv 1
        x_u = F.relu(self.conv1_u(x))
        x_d = F.relu(self.conv1_d(x))
        x_u = self.pool1(x_u)
        x_d = self.pool1(x_d)

        # Conv 2
        x_u = F.relu(self.conv2_u(x_u))
        x_d = F.relu(self.conv2_d(x_d))
        x_u = self.pool2(x_u)
        x_d = self.pool2(x_d)

        # Conv 3, GPU 데이터 합치고 각각 연산 수행
        x = torch.cat((x_u, x_d), dim=1)
        x_u = F.relu(self.conv3_u(x))
        x_d = F.relu(self.conv3_d(x))

        # Conv 4
        x_u = F.relu(self.conv4_u(x_u))
        x_d = F.relu(self.conv4_d(x_d))

        # Conv 5
        x_u = F.relu(self.conv5_u(x_u))
        x_d = F.relu(self.conv5_d(x_d))
        x_u = self.pool5(x_u)
        x_d = self.pool5(x_d)

        # FC 6, GPU 데이터 합치고 각각 연산 수행
        x = torch.cat((x_u, x_d), dim=1)
        x = x.view(x.size(0), -1)
        x_u = F.relu(self.fc6_u(x))
        x_d = F.relu(self.fc6_d(x))

        # FC 7, GPU 데이터 합치고 각각 연산 수행
        x = torch.cat((x_u, x_d), dim=1)
        x_u = F.relu(self.fc7_u(x))
        x_d = F.relu(self.fc7_d(x))

        # FC 8, GPU 데이터 합쳐서 최종 연산
        x = torch.cat((x_u, x_d), dim=1)
        logits = self.fc8(x)

        return logits


torchinfo.summary(
    AlexNet(NUM_CLASSES),
    input_size=(1, 3, IMG_SIZE, IMG_SIZE),
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
    row_settings=["depth", "var_names"],
)

Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
AlexNet (AlexNet)                        [1, 3, 227, 227]          [1, 1000]                 --                        --
├─Conv2d (conv1_u): 1-1                  [1, 3, 227, 227]          [1, 48, 55, 55]           17,472                    [11, 11]
├─Conv2d (conv1_d): 1-2                  [1, 3, 227, 227]          [1, 48, 55, 55]           17,472                    [11, 11]
├─MaxPool2d (pool1): 1-3                 [1, 48, 55, 55]           [1, 48, 27, 27]           --                        3
├─MaxPool2d (pool1): 1-4                 [1, 48, 55, 55]           [1, 48, 27, 27]           --                        3
├─Conv2d (conv2_u): 1-5                  [1, 48, 27, 27]           [1, 128, 27, 27]          153,728                   [5, 5]
├─Conv2d (conv2_d): 1-6                  [1, 48, 27, 27]           [1, 128, 27, 27]          153,728                   [5,