In [2]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# MNIST 데이터셋 로딩
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

# DataLoader 사용
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False
)


In [3]:
import torch.nn as nn
import torch.optim as optim

# 신경망 모델 정의
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)   # 입력 28×28 이미지 -> 128 뉴런
        self.fc2 = nn.Linear(128, 10)        # 128 뉴런 -> 10개의 클래스

    def forward(self, x):
        x = x.view(-1, 28 * 28)              # 28×28 이미지를 1D 벡터로 펼침
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [4]:
# 모델, 손실 함수, 옵티마이저 설정
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 학습
for epoch in range(5):   # 5번 학습
    for images, labels in train_loader:
        optimizer.zero_grad()          # 기울기 초기화
        outputs = model(images)        # 예측
        loss = criterion(outputs, labels)  # 손실 계산
        loss.backward()                # 역전파
        optimizer.step()               # 파라미터 업데이트

    print(f"Epoch [{epoch+1}/5], Loss: {loss.item():.4f}")


Epoch [1/5], Loss: 0.3234
Epoch [2/5], Loss: 0.3479
Epoch [3/5], Loss: 0.1286
Epoch [4/5], Loss: 0.1361
Epoch [5/5], Loss: 0.3231


In [5]:
# 모델 평가
correct = 0
total = 0

with torch.no_grad():   # 평가 시에는 기울기 계산을 하지 않음
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)  # 예측값 추출
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy: {accuracy:.2f}%')


Accuracy: 96.72%
