# MNIST CNN 딥러닝 (Pytorch)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

In [None]:
# sklearn digits 데이터셋 로드
digits = datasets.load_digits()
X = torch.tensor(digits.images.reshape((len(digits.images), -1)), dtype=torch.float32)
y = torch.tensor(digits.target, dtype=torch.long)

# 학습/테스트 분할
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# CNN 입력을 위해 데이터 재구성 (채널 차원 추가)
X_train = X_train.reshape(-1, 1, 8, 8)  # 8x8 이미지를 (배치크기, 채널, 높이, 너비) 형태로 변환
X_test = X_test.reshape(-1, 1, 8, 8)

# DataLoader 생성
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [8]:
# CNN 모델 정의
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 첫 번째 컨볼루션 블록
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # 8x8 -> 8x8
        self.pool1 = nn.MaxPool2d(2, 2)  # 8x8 -> 4x4
        
        # 두 번째 컨볼루션 블록
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 4x4 -> 4x4
        self.pool2 = nn.MaxPool2d(2, 2)  # 4x4 -> 2x2
        
        # 완전 연결 레이어
        self.fc1 = nn.Linear(64 * 2 * 2, 64)
        self.fc2 = nn.Linear(64, 10)  # 10개 클래스 (0-9)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = self.pool1(x)
        x = nn.functional.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 64 * 2 * 2)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [10]:
# 모델, 손실 함수, 옵티마이저 인스턴스 생성
model = CNN()
criterion = nn.CrossEntropyLoss()  # PyTorch는 원-핫 인코딩 없이 바로 사용 가능
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [12]:
# 학습 루프
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        # 그래디언트 초기화
        optimizer.zero_grad()
        
        # 순전파
        outputs = model(images)
        
        # 손실 계산
        loss = criterion(outputs, labels)
        
        # 역전파
        loss.backward()
        
        # 가중치 업데이트
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}')

Epoch 1/5, Loss: 1.5650
Epoch 2/5, Loss: 0.4136
Epoch 3/5, Loss: 0.1835
Epoch 4/5, Loss: 0.1039
Epoch 5/5, Loss: 0.0903


In [28]:
# 평가
model.eval()
correct = 0
total = 0
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        # 예측값과 실제 라벨 저장
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

accuracy = correct / total
print(f'Accuracy of the CNN model on the test set: {accuracy:.4f}')

Accuracy of the CNN model on the test set: 0.9750


In [30]:
# 성능 지표 계산
print("\n=== CNN 성능 지표 ===")
print(f"F1 Score: {f1_score(all_labels, all_preds, average='macro'):.4f}")
print(f"Precision: {precision_score(all_labels, all_preds, average='macro'):.4f}")
print(f"Recall: {recall_score(all_labels, all_preds, average='macro'):.4f}")
print("\n=== Confusion Matrix ===")
print(confusion_matrix(all_labels, all_preds))


=== CNN 성능 지표 ===
F1 Score: 0.9742
Precision: 0.9757
Recall: 0.9734

=== Confusion Matrix ===
[[33  0  0  0  0  0  0  0  0  0]
 [ 0 27  1  0  0  0  0  0  0  0]
 [ 0  0 33  0  0  0  0  0  0  0]
 [ 0  0  0 33  0  1  0  0  0  0]
 [ 0  0  0  0 46  0  0  0  0  0]
 [ 0  0  0  0  0 46  1  0  0  0]
 [ 0  0  0  0  0  0 35  0  0  0]
 [ 0  0  0  0  0  0  0 33  0  1]
 [ 0  1  1  0  0  1  0  0 27  0]
 [ 0  0  0  0  0  1  0  0  1 38]]
