In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

NUM_EPOCHS = 100

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5))
])

# MNIST 데이터셋 불러오기
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
valid_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# DataLoader 설정
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=1000, shuffle=False)

# 학습모델 정의
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(20 * 12 * 12, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 20 * 12 * 12)
        x = F.relu(self.fc1(x))

        return x
    
model = SimpleNN()
if torch.cuda.is_available():
    model = model.to('cuda')

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)


for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0

    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        # cuda 세팅 안돼있어서 일단 주석
        # if torch.cuda.is_available():
        #     inputs, labels = inputs.to('cuda'), labels.to('cuda')

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader)}")

    # 평가
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data in valid_loader:
            images, labels = data

            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"Accuracy of the network on the validation images: {100 * correct / total}%")

# 모델의 state_dict 저장
torch.save(model.state_dict(), 'model_weights.pth')

# 전체 모델 저장
torch.save(model, 'complete_model.pth')


Epoch 1, Loss: 1.3705984130342885
Accuracy of the network on the validation images: 58.26%
Epoch 2, Loss: 1.113471129714553
Accuracy of the network on the validation images: 59.1%
Epoch 3, Loss: 1.0593496648741683
Accuracy of the network on the validation images: 59.37%
Epoch 4, Loss: 1.0289149052425743
Accuracy of the network on the validation images: 59.6%
Epoch 5, Loss: 1.009620405654155
Accuracy of the network on the validation images: 59.81%
Epoch 6, Loss: 0.9959461152680648
Accuracy of the network on the validation images: 59.91%
Epoch 7, Loss: 0.9863413908461264
Accuracy of the network on the validation images: 60.04%
Epoch 8, Loss: 0.9786674229702207
Accuracy of the network on the validation images: 59.94%
Epoch 9, Loss: 0.9731361311254725
Accuracy of the network on the validation images: 60.08%
Epoch 10, Loss: 0.9688391972071072
Accuracy of the network on the validation images: 60.18%
