Basic definitions & hyperparameter values

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

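batch_size = 32          # training mini-batch size
test_batch_size = 1000   # evaluation batch size

epochs = 50
lr = 0.01                # learning rate
momentum = 0.5           # SGD momentum factor
seed = 1                 # random seed for reproducibility

log_interval = 200       # print the training loss every 200 batches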

Whether to use the GPU

In [5]:
no_cuda = False
use_cuda = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

torch.manual_seed(seed)
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}



Data loading & preprocessing: split the training data into a training set (80%) and a validation set (20%)

In [14]:
def load_data(train_file_path, test_file_path):
    # ToTensor() converts PIL images to float tensors scaled to [0, 1]
    transform = transforms.Compose([transforms.ToTensor()])

    # torchvision uses train_file_path / test_file_path as root directories
    train_dataset = datasets.MNIST(train_file_path, train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(test_file_path, train=False, download=True, transform=transform)

    # Split the 60,000 training images into 80% training / 20% validation
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

    # Use the batch sizes and loader options defined above instead of hard-coding 32
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=test_batch_size, shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, **kwargs)

    return train_loader, val_loader, test_loader
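The 80/20 split relies on `torch.utils.data.random_split`. A minimal sketch, assuming a toy `TensorDataset` of 100 samples in place of MNIST so it runs without a download, shows that passing an explicitly seeded `torch.Generator` makes the split reproducible independently of the global seed, and that the two parts are disjoint:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy stand-in for MNIST (assumption): 100 blank 1x28x28 images, labels 0..9
images = torch.zeros(100, 1, 28, 28)
labels = torch.arange(100) % 10
dataset = TensorDataset(images, labels)

train_size = int(0.8 * len(dataset))    # 80 samples
val_size = len(dataset) - train_size    # 20 samples

def split(seed):
    # A dedicated generator makes the split depend only on `seed`
    gen = torch.Generator().manual_seed(seed)
    return random_split(dataset, [train_size, val_size], generator=gen)

train_a, val_a = split(1)
train_b, val_b = split(1)

print(len(train_a), len(val_a))                      # 80 20
print(list(val_a.indices) == list(val_b.indices))    # True: same seed, same split
```

In the notebook, the same effect comes from `torch.manual_seed(seed)` set earlier; a dedicated generator simply keeps the split stable even if other random calls are added before it.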

Neural network model definition

In [15]:
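class NeuralNetwork(nn.Module):
    # Three fully connected layers: 784 (28x28 pixels) -> m1 -> m2 -> m3 classes
    def __init__(self, m1, m2, m3):
        super(NeuralNetwork, self).__init__()
        self.fc1 = nn.Linear(784, m1)
        self.fc2 = nn.Linear(m1, m2)
        self.fc3 = nn.Linear(m2, m3)

    def forward(self, x):
        x = x.float()
        h1 = torch.relu(self.fc1(x.view(-1, 784)))  # flatten each image to a 784-vector
        h2 = torch.relu(self.fc2(h1))
        h3 = self.fc3(h2)                            # raw class scores (logits)
        return torch.log_softmax(h3, dim=1)          # log-probabilities, paired with nll_loss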
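A quick sanity check of the architecture: the sketch below re-declares the same three-layer network (without the `x.float()` cast, which is redundant because `ToTensor()` already yields float tensors) and pushes a random batch through it. The random input is an assumption standing in for real MNIST batches; the check confirms the output shape, the parameter count for the 256/128/10 configuration, and that the outputs are valid log-probabilities.

```python
import torch
import torch.nn as nn

class NeuralNetwork(nn.Module):
    def __init__(self, m1, m2, m3):
        super().__init__()
        self.fc1 = nn.Linear(784, m1)
        self.fc2 = nn.Linear(m1, m2)
        self.fc3 = nn.Linear(m2, m3)

    def forward(self, x):
        h1 = torch.relu(self.fc1(x.view(-1, 784)))  # flatten 1x28x28 -> 784
        h2 = torch.relu(self.fc2(h1))
        return torch.log_softmax(self.fc3(h2), dim=1)

model = NeuralNetwork(m1=256, m2=128, m3=10)
x = torch.rand(32, 1, 28, 28)    # fake batch shaped like MNIST images
out = model(x)

print(out.shape)                 # torch.Size([32, 10])

# weights + biases: (784*256 + 256) + (256*128 + 128) + (128*10 + 10)
n_params = sum(p.numel() for p in model.parameters())
print(n_params)                  # 235146

# exp(log_softmax) recovers probabilities, so each row sums to 1
print(torch.allclose(out.exp().sum(dim=1), torch.ones(32), atol=1e-6))  # True
```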

train, validate, and test functions

In [16]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = nn.functional.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
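`train` pairs the model's `log_softmax` output with `nll_loss`. That combination computes the same value as `nn.functional.cross_entropy` applied to raw logits. A minimal sketch with random logits and labels (an assumption, not MNIST data) checks the equivalence and spells out what `nll_loss` does:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)            # raw scores for 8 samples, 10 classes
target = torch.randint(0, 10, (8,))    # integer class labels

loss_nll = F.nll_loss(F.log_softmax(logits, dim=1), target)
loss_ce = F.cross_entropy(logits, target)

# nll_loss picks -log p(target) for each sample and averages over the batch
manual = -F.log_softmax(logits, dim=1)[torch.arange(8), target].mean()

print(torch.allclose(loss_nll, loss_ce), torch.allclose(loss_nll, manual))  # True True
```

Either form works; keeping `log_softmax` inside the model, as this notebook does, means the model's outputs are already log-probabilities when `validate` and `test` reuse `nll_loss`.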

In [19]:
def validate(model, device, val_loader):
    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += nn.functional.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    val_loss /= len(val_loader.dataset)
    val_accuracy = 100. * correct / len(val_loader.dataset)
    return val_loss, val_accuracy

def test(model, device, test_loader):

    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.functional.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    return test_loss, test_accuracy
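`validate` and `test` share two details worth checking on tiny inputs: the accuracy count (`argmax` over classes, then `eq` against the labels reshaped with `view_as`) and the loss average (sum with `reduction='sum'`, then divide by the dataset size). Dividing the total by the number of samples gives the exact per-sample mean even when the last batch is smaller, whereas averaging per-batch means would overweight that batch. The tensors below are made-up values for illustration, not MNIST results:

```python
import torch

# --- 1. Accuracy count, as in validate()/test() ---
# log-probabilities for 3 samples over 2 classes (illustrative values)
output = torch.log(torch.tensor([[0.1, 0.9],
                                 [0.8, 0.2],
                                 [0.3, 0.7]]))
target = torch.tensor([1, 0, 0])

pred = output.argmax(dim=1, keepdim=True)               # shape (3, 1)
correct = pred.eq(target.view_as(pred)).sum().item()    # target reshaped to (3, 1) first
accuracy = 100. * correct / len(target)
print(pred.squeeze(1).tolist(), correct)                # [1, 0, 1] 2

# --- 2. Summed loss / dataset size vs. mean of batch means ---
losses = torch.tensor([1., 1., 1., 1., 1., 1., 1., 1., 4., 4.])  # hypothetical per-sample losses
batches = torch.split(losses, 4)                        # batch sizes 4, 4, 2

sum_then_divide = sum(b.sum().item() for b in batches) / len(losses)
mean_of_means = sum(b.mean().item() for b in batches) / len(batches)
print(sum_then_divide, mean_of_means)                   # 1.6 2.0
```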


Load the data and set up the model and optimizer

In [20]:
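# torchvision treats these paths as root directories (despite the .txt names); MNIST is downloaded under each
train_file_path = 'mnist_train.txt'
test_file_path = 'mnist_test.txt'
train_loader, val_loader, test_loader = load_data(train_file_path, test_file_path)

# 784 -> 256 -> 128 -> 10 fully connected network
model2 = NeuralNetwork(m1=256, m2=128, m3=10).to(device)
# Use the learning rate and momentum defined at the top
optimizer2 = optim.SGD(model2.parameters(), lr=lr, momentum=momentum)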

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist_train.txt/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 127745242.84it/s]
Extracting mnist_train.txt/MNIST/raw/train-images-idx3-ubyte.gz to mnist_train.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist_train.txt/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 8988994.79it/s]
Extracting mnist_train.txt/MNIST/raw/train-labels-idx1-ubyte.gz to mnist_train.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist_train.txt/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 30073277.14it/s]
Extracting mnist_train.txt/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist_train.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist_train.txt/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 2468004.76it/s]
Extracting mnist_train.txt/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist_train.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist_test.txt/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 234449940.19it/s]
Extracting mnist_test.txt/MNIST/raw/train-images-idx3-ubyte.gz to mnist_test.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist_test.txt/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 12867611.41it/s]
Extracting mnist_test.txt/MNIST/raw/train-labels-idx1-ubyte.gz to mnist_test.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist_test.txt/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 128024646.36it/s]
Extracting mnist_test.txt/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist_test.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist_test.txt/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 2208756.96it/s]
Extracting mnist_test.txt/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist_test.txt/MNIST/raw
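The optimizer set up above (`optim.SGD` with `lr=0.01`, `momentum=0.5`) keeps a velocity buffer per parameter. Following PyTorch's documented update with the default `dampening=0` and no Nesterov, each step computes v ← μ·v + g and p ← p − lr·v, with v initialised to the first gradient. A minimal sketch, assuming a single scalar parameter and the toy loss p², reproduces two steps by hand and compares them with `torch.optim.SGD`:

```python
import torch
import torch.optim as optim

lr, momentum = 0.01, 0.5

# PyTorch optimizer on a single scalar parameter, toy loss = p^2 (assumption)
p = torch.tensor(1.0, requires_grad=True)
opt = optim.SGD([p], lr=lr, momentum=momentum)
for _ in range(2):
    opt.zero_grad()
    loss = p ** 2
    loss.backward()
    opt.step()

# The same two steps by hand; the gradient of p^2 is 2p
q, v = 1.0, None
for _ in range(2):
    g = 2 * q
    v = g if v is None else momentum * v + g   # velocity buffer
    q = q - lr * v

print(round(p.item(), 4), round(q, 4))   # 0.9504 0.9504
```

With `momentum=0.5`, each step adds half of the previous velocity to the current gradient, which smooths the updates across mini-batches.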

Training, performance evaluation & test

In [21]:
for epoch in range(1, epochs + 1):
    train(model2, device, train_loader, optimizer2, epoch)
    val_loss2, val_accuracy2 = validate(model2, device, val_loader)
    print(f'Validation set: Average loss: {val_loss2:.4f}, Accuracy: {val_accuracy2:.2f}%')
test_loss2, test_accuracy2 = test(model2, device, test_loader)
print(f'Model 2 Test set: Average loss: {test_loss2:.4f}, Accuracy: {test_accuracy2:.2f}%')


Validation set: Average loss: 0.3505, Accuracy: 90.03%
Validation set: Average loss: 0.2819, Accuracy: 91.70%
Validation set: Average loss: 0.2393, Accuracy: 92.92%
Validation set: Average loss: 0.1957, Accuracy: 94.29%
Validation set: Average loss: 0.1719, Accuracy: 95.01%
Validation set: Average loss: 0.1473, Accuracy: 95.64%
Validation set: Average loss: 0.1362, Accuracy: 96.07%
Validation set: Average loss: 0.1257, Accuracy: 96.29%
Validation set: Average loss: 0.1180, Accuracy: 96.39%
Validation set: Average loss: 0.1092, Accuracy: 96.82%
Validation set: Average loss: 0.1009, Accuracy: 97.00%
Validation set: Average loss: 0.0995, Accuracy: 97.02%
Validation set: Average loss: 0.0978, Accuracy: 97.05%
Validation set: Average loss: 0.0913, Accuracy: 97.22%
Validation set: Average loss: 0.0912, Accuracy: 97.37%
Validation set: Average loss: 0.0883, Accuracy: 97.35%
Validation set: Average loss: 0.0817, Accuracy: 97.49%
Validation set: Average loss: 0.0849, Accuracy: 97.50%
Validation