<a href="https://colab.research.google.com/github/cbh4635/DL_studty/blob/main/multiclass_classification_CIFAR10_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 학습/검증 관련 함수 사전 정의

In [1]:
import os
import torch
from torch import nn, optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(DEVICE)  # show which device was selected

def Train(model, train_DL, criterion, optimizer, EPOCH):
    """Train `model` for `EPOCH` passes over `train_DL`.

    Every batch is moved to the module-level `DEVICE`, pushed through the
    model, scored with `criterion`, and used for one optimizer step.

    The per-epoch loss is the sum of (batch loss * batch size) divided by
    `len(train_DL.dataset)`. That is the per-sample mean when `criterion`
    returns a batch mean -- presumably the case here; verify the
    criterion's reduction setting.

    Args:
        model: network to optimize; it is switched to train mode and left
            in that mode.
        train_DL: iterable of (inputs, targets) batches exposing `.dataset`.
        criterion: callable mapping (outputs, targets) to a scalar loss.
        optimizer: optimizer already bound to the model's parameters.
        EPOCH: number of passes over `train_DL`.

    Returns:
        List with one averaged loss value per epoch.
    """
    sample_count = len(train_DL.dataset)
    epoch_losses = []

    model.train()  # switch to train mode
    for ep in range(EPOCH):
        weighted_loss_sum = 0  # running sum of batch loss * batch size
        for inputs, targets in train_DL:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            # forward pass and loss
            batch_loss = criterion(model(inputs), targets)

            # parameter update
            optimizer.zero_grad()  # clear gradients so they do not accumulate
            batch_loss.backward()  # backpropagation
            optimizer.step()  # weight update

            # weight the batch loss by the number of samples in the batch
            weighted_loss_sum += batch_loss.item() * inputs.shape[0]

        loss_e = weighted_loss_sum / sample_count
        epoch_losses.append(loss_e)
        print(f"Epoch: {ep+1}, train loss: {round(loss_e,3)}")
        print("-"*20)

    return epoch_losses

def Test(model, test_DL):
    """Print the top-1 accuracy of `model` over every batch in `test_DL`.

    The model is switched to eval mode and run without gradient tracking.
    A sample counts as correct when the argmax over dim 1 of the model
    output equals its label. Nothing is returned; the result is printed.

    Args:
        model: network to evaluate (left in eval mode).
        test_DL: iterable of (inputs, targets) batches exposing `.dataset`.
    """
    model.eval()  # switch to eval mode
    rcorrect = 0  # running count of correct predictions
    with torch.no_grad():
        for inputs, targets in test_DL:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            # predicted class = index of the largest output along dim 1
            predicted = model(inputs).argmax(dim=1)
            rcorrect += torch.sum(predicted == targets).item()
    accuracy_e = rcorrect/len(test_DL.dataset)*100
    print(f"Test accuracy: {rcorrect}/{len(test_DL.dataset)} ({round(accuracy_e,1)} %)")

def Test_plot(model, test_DL):
    """Plot up to six images from the first batch of `test_DL` with labels.

    Each panel is titled "<predicted class> (<true class>)", drawn in green
    when the two names match and in red otherwise. Class names come from
    `test_DL.dataset.classes`.

    Args:
        model: network used for prediction (switched to eval mode).
        test_DL: iterable of (images, labels) batches whose `.dataset` has a
            `classes` sequence.
    """
    model.eval()
    with torch.no_grad():
        x_batch, y_batch = next(iter(test_DL))
        # Predict on DEVICE, then bring the class indices back to the CPU
        # once so the per-image lookups below need no device transfers.
        pred = model(x_batch.to(DEVICE)).argmax(dim=1).cpu()

    x_batch = x_batch.to("cpu")
    class_names = test_DL.dataset.classes

    # The 2x3 grid holds six panels, but the batch may be smaller than that
    # (small batch_size or a tiny dataset); never index past its end.
    n_show = min(6, x_batch.shape[0])

    plt.figure(figsize=(8,4))
    for idx in range(n_show):
        plt.subplot(2,3, idx+1, xticks=[], yticks=[])
        # (C,H,W) -> (H,W,C); squeeze() drops a size-1 channel axis so that
        # single-channel images become 2-D (where cmap="gray" is applied).
        plt.imshow(x_batch[idx].permute(1,2,0).squeeze(), cmap="gray")
        pred_class = class_names[pred[idx].item()]
        true_class = class_names[y_batch[idx].item()]
        plt.title(f"{pred_class} ({true_class})", color = "g" if pred_class==true_class else "r")

def count_params(model):
    """Return the total element count of `model`'s trainable parameters.

    Only parameters with `requires_grad` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

cuda


## 학습 파라미터 정의

In [2]:
# Mini-batch size, learning rate, and number of training epochs.
BATCH_SIZE, LR, EPOCH = 32, 1e-3, 5
# Loss function for the 10-class classification task.
criterion = nn.CrossEntropyLoss()

## 데이터셋 로드

In [None]:
# Convert every image to a tensor when it is read from the dataset.
transform = transforms.ToTensor()

# CIFAR-10 train and test splits, downloaded on first use.
train_DS, test_DS = (
    datasets.CIFAR10(root = '/content/ data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Shuffled mini-batch loaders over both splits.
train_DL, test_DL = (
    torch.utils.data.DataLoader(split, batch_size=BATCH_SIZE, shuffle=True)
    for split in (train_DS, test_DS)
)

In [None]:
# Summary of each split first, then the number of samples in each split.
print(train_DS, test_DS, sep="\n")
print(len(train_DS), len(test_DS), sep="\n")

In [None]:
# Class names, followed by the class-name -> index mapping.
print(test_DS.classes, test_DS.class_to_idx, sep="\n")

# Pull a single batch out of the DataLoader.
x_batch, y_batch = next(iter(test_DL))
print(x_batch.size())

# Show the first image of the batch: (C,H,W) -> (H,W,C) for matplotlib.
plt.imshow(x_batch[0].permute(1,2,0))
print(test_DS.classes[y_batch[0]])

## MLP 모델 정의

In [6]:
class MLP(nn.Module):
    """Fully connected classifier: Linear -> ReLU -> Linear.

    The input is flattened to 3*32*32 features per sample and mapped to
    10 output scores (one per class).
    """

    def __init__(self):
        super().__init__()

        in_features = 3*32*32  # input nodes: 3*32*32 (28*28 for MNIST)
        hidden_dim = 100  # hidden-layer width -- a hyperparameter
        num_classes = 10  # output nodes: 10 (total number of classes)

        layers = [
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),  # activation function
            nn.Linear(hidden_dim, num_classes),
        ]
        self.linear = nn.Sequential(*layers)

    def forward(self, x):
        # Unlike a CNN, the MLP cannot consume 2-D images directly, so every
        # sample is flattened to a vector first (the batch axis is kept).
        flat = torch.flatten(x, start_dim=1)
        return self.linear(flat)

## CNN 모델 정의

In [7]:

class CNN(nn.Module):
    """Three conv stages (Conv -> BatchNorm -> ReLU), each followed by 2x2
    max pooling, then a single linear layer producing 10 class scores.

    The linear layer expects 32*4*4 features, which matches a 32x32 input
    halved by each of the three pooling layers.
    """

    @staticmethod
    def _conv_unit(in_channels, out_channels):
        """Build one 3x3 Conv(padding=1) -> BatchNorm -> ReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),  # (in channels, out channels, kernel size)
            nn.BatchNorm2d(out_channels),  # (number of channels)
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()

        self.conv1 = self._conv_unit(3, 8)
        self.Maxpool1 = nn.MaxPool2d(2)  # (kernel size) - downscaling
        self.conv2 = self._conv_unit(8, 16)
        self.Maxpool2 = nn.MaxPool2d(2)
        self.conv3 = self._conv_unit(16, 32)
        self.Maxpool3 = nn.MaxPool2d(2)
        # (final channel count * feature-map size, number of classes)
        self.fc = nn.Linear(32*4*4, 10)

    def forward(self, x):
        stages = (
            (self.conv1, self.Maxpool1),
            (self.conv2, self.Maxpool2),
            (self.conv3, self.Maxpool3),
        )
        for conv, pool in stages:
            x = pool(conv(x))
        # Flatten everything except the batch axis before the linear layer.
        return self.fc(torch.flatten(x, start_dim=1))

class CNN_deep(nn.Module):
    """Deeper CNN: three blocks of stacked Conv -> BatchNorm -> ReLU layers
    (2, 3 and 3 convolutions with 32, 64 and 128 channels), each followed by
    2x2 max pooling, then a two-layer classifier producing 10 class scores.

    The classifier expects 128*4*4 features, which matches a 32x32 input
    halved by each of the three pooling layers.
    """

    @staticmethod
    def _conv_block(in_channels, out_channels, n_convs):
        """Stack `n_convs` 3x3 Conv(padding=1) -> BatchNorm -> ReLU groups.

        Only the first convolution changes the channel count; the remaining
        ones map `out_channels` to `out_channels`.
        """
        layers = []
        current = in_channels
        for _ in range(n_convs):
            layers.append(nn.Conv2d(current, out_channels, 3, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
            current = out_channels
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()

        self.conv_block1 = self._conv_block(3, 32, 2)
        self.Maxpool1 = nn.MaxPool2d(2)

        self.conv_block2 = self._conv_block(32, 64, 3)
        self.Maxpool2 = nn.MaxPool2d(2)

        self.conv_block3 = self._conv_block(64, 128, 3)
        self.Maxpool3 = nn.MaxPool2d(2)

        self.classifier = nn.Sequential(
            nn.Linear(128*4*4, 512),  # (final channel count * feature-map size, output features)
            nn.ReLU(),  # activation function
            nn.Linear(512, 10),  # output nodes: 10 (total number of classes)
        )

    def forward(self, x):
        stages = (
            (self.conv_block1, self.Maxpool1),
            (self.conv_block2, self.Maxpool2),
            (self.conv_block3, self.Maxpool3),
        )
        for block, pool in stages:
            x = pool(block(x))
        # Flatten everything except the batch axis before the classifier.
        return self.classifier(torch.flatten(x, start_dim=1))

## 모델 설정

In [None]:
# Directory where trained models are written.
os.makedirs("/content/test_result", exist_ok=True)

new_model_train = True  # True: train a fresh model; False: load the saved one
model_type = "MLP"  # must be a key of model_classes below
dataset = "CIFAR10"
save_model_path = f"/content/test_result/{model_type}_{dataset}.pt"

# Resolve the class through an explicit table rather than exec()-ing a
# formatted string: a typo in model_type now gives a clear error, and no
# arbitrary text is ever executed as code.
model_classes = {"MLP": MLP, "CNN": CNN, "CNN_deep": CNN_deep}
if model_type not in model_classes:
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected one of {sorted(model_classes)}"
    )
model = model_classes[model_type]().to(DEVICE)  # move the model to DEVICE
print(model)  # inspect the model structure

# Sanity-check the output shape on one training batch. No gradients are
# needed for this check, so skip building the autograd graph.
x_batch, _ = next(iter(train_DL))
with torch.no_grad():
    print('모델 최종출력 Tensor: ',model(x_batch.to(DEVICE)).shape)

## 모델 훈련

In [34]:
if not new_model_train:
    # Reuse the model that an earlier run saved to save_model_path.
    # NOTE(review): weights_only=False unpickles the whole model object;
    # only load files that this notebook itself produced.
    print('Model load')
    load_model = torch.load(save_model_path, map_location=DEVICE, weights_only=False)
else:
    # Train from scratch with Adam, then persist the whole model object.
    optimizer = optim.Adam(model.parameters(), lr=LR)
    loss_history = Train(model, train_DL, criterion, optimizer, EPOCH)

    torch.save(model, save_model_path)

    # Training-loss curve, one point per epoch.
    plt.plot(range(1,EPOCH+1),loss_history)
    plt.title("Train Loss")
    plt.xlabel('Epoch')
    plt.ylabel('loss')
    plt.grid()

    # Evaluate the freshly trained model; re-running this cell will load
    # the saved file instead of training again.
    load_model = model
    new_model_train = False

Load Model


## 모델 결과 확인

In [None]:
# Print the test-set accuracy, then the number of trainable parameters.
Test(load_model, test_DL)
print('모델 파라미터 수: ', count_params(load_model))

In [None]:
# Show six test images titled "<predicted class> (<true class>)".
Test_plot(load_model, test_DL)

## CNN 모델의 Feature map 시각화

In [None]:
# Load the saved CNN_deep model and build a loader that yields one random
# test image at a time.
# NOTE(review): weights_only=False unpickles the whole model object; only
# load files that this notebook itself produced.
load_model=torch.load("/content/test_result/CNN_deep_CIFAR10.pt", map_location=DEVICE, weights_only=False)
tmp_DL = torch.utils.data.DataLoader(test_DS, batch_size=1, shuffle=True)

load_model.eval()
with torch.no_grad():
    x_batch, y_batch = next(iter(tmp_DL))
    x_batch, y_batch = x_batch.to(DEVICE), y_batch.to(DEVICE)
    y_hat = load_model(x_batch)
    pred = y_hat.argmax(dim=1)

    # Re-run the conv blocks one at a time to capture each block's output
    # (taken before that block's pooling layer).
    feature_map1 = load_model.conv_block1(x_batch)
    pooled1 = load_model.Maxpool1(feature_map1)
    feature_map2 = load_model.conv_block2(pooled1)
    pooled2 = load_model.Maxpool2(feature_map2)
    feature_map3 = load_model.conv_block3(pooled2)

# Bring everything that gets plotted back to the CPU.
x_batch, feature_map1, feature_map2, feature_map3 = (
    t.cpu() for t in (x_batch, feature_map1, feature_map2, feature_map3)
)

# True class name and the input image itself.
print(test_DS.classes[y_batch])
plt.figure(figsize=(8,8))
plt.xticks([]); plt.yticks([])
plt.imshow(x_batch[0,...].permute(1,2,0))

# One figure per conv block, one grayscale panel per channel:
# (feature map, figure size, grid rows, grid columns)
panels = (
    (feature_map1, (32,16), 4, 8),
    (feature_map2, (16,16), 8, 8),
    (feature_map3, (16,8), 8, 16),
)
for fmap, fig_size, n_rows, n_cols in panels:
    print(fmap.shape)
    plt.figure(figsize=fig_size)
    for idx in range(n_rows*n_cols):
        plt.subplot(n_rows,n_cols,idx+1, xticks=[], yticks=[])
        plt.imshow(fmap[0,idx,...], cmap="gray")

In [None]:
# Collapse the last conv block's output to one 2-D map by summing the
# absolute activations over the channel axis.
# viridis colormap: purple (minimum) -> indigo -> green -> yellow (maximum)
summed_map = feature_map3.abs().sum(dim=1)
plt.figure(figsize=(8,8))
plt.xticks([]); plt.yticks([])
plt.imshow(summed_map[0,...], cmap='viridis')

# Overlay the activation map on the input image.
plt.figure(figsize=(8,8))
plt.xticks([]); plt.yticks([])
plt.imshow(x_batch[0,...].permute(1,2,0))
# imshow places pixel centers on integer coordinates, so an H x W image
# spans (-0.5, W-0.5) horizontally and (H-0.5, -0.5) vertically. Stretch
# the coarser map over exactly that extent; a fixed [0,32,32,0] would shift
# the overlay (and the axes limits) by half a pixel.
img_h, img_w = x_batch.shape[-2:]
plt.imshow(summed_map[0,...], extent=(-0.5, img_w-0.5, img_h-0.5, -0.5), alpha=0.4, cmap='viridis')
pred_class = test_DS.classes[pred]
true_class = test_DS.classes[y_batch]
# Green title when the prediction matches the true class, red otherwise.
plt.title(f"{pred_class} ({true_class})", color="g" if pred_class==true_class else "r")