In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Select the compute device: use the CUDA GPU when PyTorch reports that one is
# available on this machine, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class CNN(nn.Module):
    """Small two-stage convolutional image classifier.

    Architecture: (Conv 3x3 -> ReLU -> MaxPool 2x2) applied twice, followed by
    a single fully connected layer that emits one raw score (logit) per class.

    The fully connected layer expects 7 x 7 feature maps, which is what the two
    2x2 poolings produce from e.g. a 28 x 28 input (28 -> 14 -> 7). Inputs that
    pool down to a different spatial size raise a shape error in ``self.fc``.

    Args:
        in_channels: Number of channels of the input images. Defaults to 1
            (grayscale).
        num_classes: Number of scores produced per image. Defaults to 10.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        # First stage: in_channels -> 16 feature maps. The 16 is a design
        # choice. kernel 3 / stride 1 / padding 1 keeps the spatial size.
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        # 2x2 max pooling halves height and width.
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        # Second stage: 16 (output channels of conv1) -> 32 feature maps.
        # The 32 is again a design choice.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        # Classifier input size: 32 channels from conv2 times the 7 x 7
        # feature map left after the two pooling steps.
        self.fc = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Tensor of shape (batch, in_channels, height, width).

        Returns:
            Tensor of shape (batch, num_classes) holding raw scores (logits).
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike .view(),
        # torch.flatten also works when the activations are not contiguous
        # (for example with channels_last inputs).
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

In [None]:
def _show_kernel(ax, kernel, title):
    """Draw one 2-D convolution kernel on ``ax`` as a heat map with a colorbar.

    Args:
        ax: Matplotlib axes to draw on.
        kernel: 2-D array of kernel weights.
        title: Title shown above the panel.
    """
    # Draw the image once and reuse the returned handle for the colorbar, so
    # the same picture is not stacked twice on the axes.
    image = ax.imshow(kernel, cmap='coolwarm')
    ax.set_title(title)
    # Carve a thin strip off the right-hand side of the panel for the colorbar.
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    plt.colorbar(image, cax=cax)


if __name__ == "__main__":
    # download=True on the training split fetches the dataset into ./train;
    # the test split is then built from that same directory with train=False.
    # The only transform applied is conversion to a tensor.
    train_dataset = MNIST('./train', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset = MNIST('./train', train=False, transform=transforms.ToTensor())

    batch_size = 64
    # Shuffle the training data; the test loader does not need shuffling.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Move the model to the selected device.
    model = CNN().to(device)

    # Multi-class classification, so cross-entropy is used as the loss.
    criterion = nn.CrossEntropyLoss()
    # The initial learning rate is a design choice.
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    num_epochs = 10

    # 2 x 2 grid of panels for the visualisations, flattened to a 1-D array.
    fig, axs = plt.subplots(2, 2, figsize=(10, 8))
    fig.tight_layout(pad=4.0)
    axs = axs.flatten()
    # Panel 2 is not used by any plot; hide it instead of showing an empty frame.
    axs[2].axis('off')

    # Mean training loss of every epoch, used for the loss curve.
    epoch_losses = []

    for epoch in range(num_epochs):
        # Put the model in training mode.
        model.train()
        # Sum of per-sample losses over one pass through the training loader.
        running_loss = 0.0

        for images, labels in train_loader:
            # Move the batch to the same device as the model.
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()  # reset gradients from the previous step

            # Forward pass: predictions for this batch.
            outputs = model(images)

            # Compare the predictions with the true labels.
            loss = criterion(outputs, labels)

            # Backpropagate, then let the optimizer update the weights.
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean, so weight it by the batch size
            # to get a correct per-sample average below (the last batch may be
            # smaller than batch_size).
            running_loss += loss.item() * images.size(0)

        epoch_loss = running_loss / len(train_dataset)
        epoch_losses.append(epoch_loss)

        print(f"Epoch {epoch+1} / {num_epochs}, Loss : {epoch_loss:.4f}")

        if epoch == 0:
            # After the first epoch, visualise the kernel connecting the first
            # input channel to the first output channel of each conv layer.
            conv1_weights = model.conv1.weight.detach().cpu().numpy()
            _show_kernel(axs[0], conv1_weights[0, 0], "Conv1 Weights")

            conv2_weights = model.conv2.weight.detach().cpu().numpy()
            _show_kernel(axs[1], conv2_weights[0, 0], "Conv2 Weights")

    # Plot the loss curve once, after training. Plotting inside the epoch loop
    # would add one more overlapping line per epoch. The x axis uses 1-based
    # epoch numbers to match the printed log.
    axs[3].plot(range(1, len(epoch_losses) + 1), epoch_losses)
    axs[3].set_title('training_loss')
    axs[3].set_xlabel('Epoch')
    axs[3].set_ylabel('Loss')

    plt.show()

    # ---- Model evaluation ----
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            # outputs holds 10 raw scores (logits) per image, not
            # probabilities. The index of the largest score is the digit the
            # model predicts.
            outputs = model(images)
            # torch.max over dim 1 returns (values, indices); only the indices
            # (the predicted labels) are needed.
            _, predicted = torch.max(outputs, 1)

            # Number of samples seen so far.
            total += labels.size(0)
            # (predicted == labels) is a boolean tensor; its sum is the number
            # of correct predictions in this batch.
            correct += (predicted == labels).sum().item()

    accuracy = (correct / total) * 100
    print(f"Test Accuracy: {accuracy:.2f} %")

temp is following value : tensor([ 7.1473,  7.2212, 14.5870, 10.0435, 13.3586, 18.4097, 18.9483, 14.5658,
         6.2995, 13.5217,  6.1366, 13.9171, 10.7465, 12.7641, 12.8424, 12.5880,
         2.6458, 13.5403, 13.8265, 13.9745,  9.4281, 13.6141, 11.1431, 15.2899,
         8.0847, 10.8256,  5.6650, 12.9282, 12.4985, 16.6367, 11.6540,  3.0474,
         7.7717, 17.3721, 10.0915,  5.3383, 10.5582, 10.9050, 11.2753, 15.0335,
        13.0704, 12.1770, 12.5113, 10.1915,  6.7587, 17.2519, 13.2630,  8.4422,
        12.5989, 16.1288, 13.9102, 13.0300, 16.3928, 11.3244, 10.7902, 16.0803,
        11.1691, 10.9699,  5.1540, 11.5113, 15.6586, 11.0888, 21.9927,  3.7110])
temp is following value : tensor([10.9124,  9.8583,  8.3787, 13.2406, 21.8021, 15.8961, 11.9709, 10.8052,
        11.8703,  9.8594,  7.1487, 15.8990, 10.0275, 17.3461, 17.0896, 11.1170,
         8.5363, 16.8178, 13.8906,  6.5141,  9.7290,  9.8770,  9.5770, 11.4189,
        12.7338, 12.1432, 18.8563, 16.8300, 10.0373, 13.3367, 11.17

temp is following value : tensor([13.8405,  8.5998, 11.5294, 14.3270, 15.7950,  6.1962, 11.4055,  8.8370,
        14.0353, 10.2752, 14.5697, 13.8926,  9.5077, 21.4309, 15.5217,  6.6638,
         8.2892,  8.6022,  8.8269, 13.6680, 21.0087, 15.0522,  5.8432,  9.2404,
        11.8354,  6.9782, 18.9317,  7.9930, 11.3590, 15.0641,  3.2426, 10.2033,
        13.2011, 18.8738, 11.7663, 14.1893, 18.9075, 11.0872, 13.3235, 13.8920,
         9.7302,  5.9895, 10.3607, 12.8754,  5.5161, 12.8630, 11.3530, 10.2325,
         8.1487, 12.7693, 15.6866, 16.1802, 13.3776, 14.1348, 18.6291, 18.8567,
         1.2915,  8.5999,  9.7766, 15.1090, 11.6341, 13.2692, 10.4648, 15.9562])
temp is following value : tensor([10.9587, 10.0875, 15.0868, 15.8964, 14.4797, 18.4745, 18.5883, 12.6285,
         7.6100, 18.8755, 17.1055, 11.6675,  7.9833, 12.1891, 10.8988, 17.7163,
         7.3039, 12.8088,  3.7242, 11.2137,  8.5289,  7.4316, 22.4639, 10.1664,
        11.2301,  7.1249, 13.7855, 10.6004,  9.9728, 11.5870, 10.37

temp is following value : tensor([12.2894, 11.3146, 12.7459, 11.7410,  9.3969,  9.0568, 11.5623, 18.1151,
        19.1446, 21.4456,  9.3314, 24.2524, 18.6952, 13.1431,  3.8091, 13.2488,
         9.3877, 16.0763, 10.9408, 13.4304, 14.4527,  9.8519, 20.6301, 17.5967,
         6.9219, 10.1972, 11.2310, 16.3736, 16.7364, 15.3162, 15.6122, 10.8792,
        12.0625,  3.0123, 13.8972, 14.2281, 15.0454,  9.2697, 14.3269, 12.3239,
         9.0850, 11.2688, 12.2331, 17.8231, 14.7604, 14.6362, 10.1773, 14.0092,
        13.0431, 15.7231, 12.1121,  9.9249, 13.3302, 18.9823,  7.7106, 19.9709,
        13.9709,  9.4454, 12.6141, 13.8166, 12.7472, 12.5258,  8.9490, 13.8546])
temp is following value : tensor([18.3077, 13.3223, 18.1595, 19.1211,  8.3858, 18.0270,  7.5633, 14.8941,
        20.9490, 13.3588, 14.8058, 11.2377,  9.7413, 13.4485,  7.4301,  3.2923,
        10.1747,  7.7744, 12.4769, 12.4442, 10.4503, 11.7537, 14.8330, 15.8147,
        10.9138,  9.8531, 12.6137,  9.9701, 18.4666, 11.6271,  6.22

temp is following value : tensor([ 7.2399, 20.0521, 15.8640, 16.4327, 11.0043, 10.5741, 19.6980,  8.0867,
         9.8800, 15.4089,  5.6462, 16.2274,  8.2531, 10.4725,  9.3297,  4.5046,
         7.5158, 13.7729, 13.3899, 14.0336, 13.4070, 18.9826,  7.1503, 12.3599,
         4.9820, 10.8131,  9.6238, 17.3169, 23.4192, 12.4314, 14.9397, 12.7690,
        11.3443, 11.5204, 12.9241, 10.8093,  8.6131, 13.6080, 13.3389, 10.3383,
        21.0192, 15.0969, 12.5333, 10.2749, 17.3861, 10.8059, 15.2626,  5.1482,
        12.1136, 14.9015, 19.8355,  9.3632, 19.7038, 10.2094,  6.7960, 14.6180,
        17.9210,  7.2344,  9.1249, 17.2442,  7.1648, 14.1257, 16.0724, 11.8580])
temp is following value : tensor([18.4459,  4.6578, 11.3319, 10.4838, 12.0497, 17.3274, 10.0884, 15.8728,
        11.8575, 12.6106, 18.8568, 10.1871, 20.3165,  4.1829,  7.2260,  8.4882,
         9.7650, 12.5580, 12.2078, 13.1181, 10.4644, 12.4563,  5.2592,  8.7069,
        13.3398, 11.9999, 16.7439, 20.4979,  4.0197, 18.1166, 18.45

temp is following value : tensor([12.3979, 13.2025, 20.2004, 17.9601, 20.7669, 12.2059, 12.8371, 18.7981,
        13.9603, 14.6950, 13.1485, 20.1100, 12.6075,  1.6208, 14.5839, 10.1618,
        18.6173,  3.8279,  9.6132, 13.4136, 19.5773, 14.6175,  1.0224, 19.7467,
         5.8256, 15.9306, 13.1343, 14.7577,  3.6182,  0.9789, 14.0528, 11.7117,
        13.6275, 18.4162,  9.2639, 11.9018,  9.6691,  9.6249, 10.8105,  9.7713,
        17.5867, 16.9153, 20.8030, 19.4458, 10.8399, 11.0459, 13.9190, 13.9433,
        11.0743,  9.7009, 16.8295, 17.8274, 19.2161, 17.2996, 13.0838, 13.0946,
        12.7376, 17.6002, 11.6348, 13.0160, 17.0316, 15.9347, 22.7612, 16.3353])
temp is following value : tensor([12.5536, 12.8406, 13.5581, 18.0417, 11.6467, 16.4896, 13.3278, 18.2295,
        20.3558, 14.8407, 18.7563, 12.7352, 16.4932, 12.7018, 10.2135, 13.3777,
        11.5401,  9.7418, 20.1737, 15.1415, 14.2413, 15.9510, 13.5091, 10.8782,
        10.5277, 10.7855,  6.6828, 11.5231, 22.6749, 14.4108, 12.28

temp is following value : tensor([13.3792, 11.5523,  8.2248, 11.2061, 12.3380, 13.6657,  7.3528, 15.6601,
         8.1146,  8.3413,  4.4079, 23.3973, 18.7294, 15.0751, 14.9225, 14.7005,
        13.9682,  8.1062, 13.1014, 13.1910, 10.6233,  9.2012, 14.3761, 10.8588,
        17.1435, 12.9715, 12.4532,  6.6616, 15.3165, 12.5374, 12.3619, 12.3342,
         9.9196, 10.3404, 18.4214,  9.7321, 14.3888, 10.1412, 10.1367, 11.7769,
        12.2265, 11.3611, 12.0246, 16.7446, 12.1260,  8.1175, 19.0312,  7.9760,
         8.8992,  6.1868, 16.2499, 20.6274,  8.4454, 14.7013, 15.1127, 12.7943,
        14.2592,  9.2032,  6.7686, 12.8154, 18.4625, 16.3875,  7.1849,  9.7891])
temp is following value : tensor([12.3032,  9.5565, 12.2331,  6.6244,  9.9566, 12.6494, 14.0450,  9.9871,
        11.0003, 12.2454, 10.5883, 10.4297, 12.1289, 19.0217, 11.7097, 11.7167,
        12.2390,  9.6493, 16.4856, 10.5229, 10.5516, 15.6654, 12.6773,  9.9256,
        13.9333, 13.4854,  7.9151, 10.5794, 16.2364, 11.6028,  8.12

temp is following value : tensor([12.5707, 18.5100, 15.2671, 20.5116, 17.9167, 11.1949, 14.7255, 12.5442,
        12.4513, 17.3983,  9.7948, 18.4453, 17.5056, 11.0777, 14.4146, 15.7183,
        13.2589, 16.6964, 12.4263, 11.8749, 22.2517, 13.3800, 18.7188, 12.8941,
        13.0641, 17.1978, 11.6695, 15.3350, 21.7220, 16.6044, 13.9230, 17.0524,
        11.7557, 12.5977, 14.9361, 20.3847, 19.0472, 16.3872, 10.3526,  9.6702,
        14.8708, 12.4193, 11.6375, 14.2964, 14.2784, 16.9833, 14.8980, 15.3114,
        14.7686, 14.0599, 12.6337, 16.3258, 12.4106, 11.1620, 18.2361, 16.6438,
        21.3515, 20.5467, 12.2938, 10.8824, 11.7547, 19.6060, 17.3554, 15.4745])
temp is following value : tensor([17.6920, 13.2788, 20.0847, 12.2268, 12.0367, 19.7739, 18.1382, 17.3318,
        14.7255, 16.9668, 10.8121, 11.6527, 17.1250, 12.8963, 11.8044, 13.3066,
        12.8575, 13.6936, 10.5888, 11.8660, 12.6675, 14.9878, 15.8453, 13.3064,
        12.0958, 12.7952, 17.1403, 15.0200, 17.8784, 14.8275, 20.80

temp is following value : tensor([14.0819, 13.8169,  7.1608, 11.6278,  8.3836,  7.4075,  9.6742, 10.8189,
        15.9692, 14.4589, 16.9125, 12.0014, 12.0630, 12.7199, 14.1464, 11.5351,
        12.2168,  8.4616, 15.3925, 10.3887,  9.5609,  8.0598, 20.9750, 16.0463,
        13.8366, 12.0956, 13.0423, 10.0676, 13.4008, 14.4558, 13.4409, 11.3758,
        15.5095, 11.1244,  7.9282, 12.1279, 12.3610, 12.0027, 12.1446,  8.3278,
        16.1318, 17.9266, 15.0547, 10.8563, 10.8725, 15.2378, 15.3470, 13.1065,
        13.4867, 12.3277, 11.0896, 10.4159, 14.5344, 12.8790, 10.3770,  9.0766,
         8.1641, 11.1953, 11.4219, 12.4049, 10.0366,  6.1222,  3.8607, 14.1872])
temp is following value : tensor([11.8347, 11.5138, 14.3397, 15.7146, 22.9846, 16.5639, 15.5905,  8.8733,
        12.5382,  8.6070, 10.7995, 15.2773, 14.1682, 14.6041, 16.8745,  9.4691])
Test Accuracy: 98.87 %
