Basic definitions & hyperparameter values

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

batch_size = 32
test_batch_size = 1000

epochs = 50
lr = 0.01           # learning rate for SGD
momentum = 0.5      # SGD momentum factor
seed = 1

log_interval = 200
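This is a minimal sketch of what `lr` and `momentum` do inside `torch.optim.SGD`, assuming its defaults `dampening=0` and `nesterov=False`. A hand-written update on the toy loss f(w) = w² is compared with the real optimizer. `sgd_momentum_step` is a hypothetical helper written only for this illustration.

```python
import torch

# Same values as the hyperparameter cell above
lr, momentum = 0.01, 0.5

def sgd_momentum_step(param, grad, buf, lr, momentum):
    """Hypothetical helper: one SGD-with-momentum step (dampening=0, nesterov=False).
    buf <- grad on the first step, else momentum * buf + grad; param <- param - lr * buf."""
    buf = grad if buf is None else momentum * buf + grad
    return param - lr * buf, buf

# Three manual steps on f(w) = w**2, whose gradient is 2w
w, buf = 1.0, None
for _ in range(3):
    w, buf = sgd_momentum_step(w, 2 * w, buf, lr, momentum)

# The same three steps with torch.optim.SGD
p = torch.tensor([1.0], requires_grad=True)
opt = torch.optim.SGD([p], lr=lr, momentum=momentum)
for _ in range(3):
    opt.zero_grad()
    (p ** 2).sum().backward()
    opt.step()

print(w, p.item())
```

With momentum 0.5, each step adds half of the previous update direction. This smooths the noisy mini-batch gradients that a batch size of 32 produces.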

Whether to use the GPU

In [3]:
no_cuda = False
use_cuda = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

torch.manual_seed(seed)
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
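Below is a small sketch of what this cell sets up. Re-seeding with `torch.manual_seed` replays the same random stream, and that is also what makes the later `random_split` reproducible. `.to(device)` moves a tensor to the chosen device. Note that the `kwargs` dict only takes effect when it is unpacked into a `DataLoader(..., **kwargs)` call.

```python
import torch

seed = 1  # same value as in the notebook

# Re-seeding replays the same random numbers
torch.manual_seed(seed)
a = torch.randn(3)
torch.manual_seed(seed)
b = torch.randn(3)
print(torch.equal(a, b))

# Device selection as in the cell above (no_cuda omitted for brevity)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

x = torch.ones(2, 2).to(device)   # tensors must live on the same device as the model
print(x.device.type, kwargs)
```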


Load & preprocess the data, then split the training data into a training set (80%) and a validation set (20%)

In [4]:
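def load_data(train_file_path, test_file_path):
    # ToTensor() converts each 28x28 image to a float tensor with values in [0, 1]
    transform = transforms.Compose([transforms.ToTensor()])

    # The path arguments are root directories; torchvision creates <root>/MNIST/raw under each
    train_dataset = datasets.MNIST(train_file_path, train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(test_file_path, train=False, download=True, transform=transform)

    # 80/20 split of the 60,000 training images (reproducible thanks to torch.manual_seed(seed))
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

    # Use the batch-size hyperparameters defined above instead of hard-coding 32, and pass the CUDA kwargs
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=test_batch_size, shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, **kwargs)

    return train_loader, val_loader, test_loader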
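You can check the split arithmetic in `load_data` without downloading MNIST. This minimal sketch uses a synthetic `TensorDataset` of random images and labels as a hypothetical stand-in for MNIST. `split_sizes` is a hypothetical helper. Passing an explicitly seeded `generator` to `random_split` is an alternative to relying on the global seed; the notebook itself does not do this.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split

def split_sizes(n, train_frac=0.8):
    """Hypothetical helper: same arithmetic as load_data (int() truncates, remainder goes to validation)."""
    train_size = int(train_frac * n)
    return train_size, n - train_size

print(split_sizes(60000))  # size of the MNIST training set

# Small synthetic stand-in (random data, not MNIST) with the same image shape
n = 1000
dataset = TensorDataset(torch.rand(n, 1, 28, 28), torch.randint(0, 10, (n,)))
train_ds, val_ds = random_split(dataset, list(split_sizes(n)),
                                generator=torch.Generator().manual_seed(1))

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=1000, shuffle=False)

images, labels = next(iter(train_loader))
print(len(train_ds), len(val_ds), images.shape, len(train_loader))
```

For real MNIST this gives 48,000 training and 12,000 validation images, which is 1,500 batches of 32 per epoch.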

Define the neural network model

In [5]:
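class NeuralNetwork(nn.Module):
    # Two-layer fully connected network: 784 inputs -> m1 hidden units (ReLU) -> m2 outputs
    def __init__(self, m1, m2):
        super(NeuralNetwork, self).__init__()
        self.fc1 = nn.Linear(784, m1)   # 28 * 28 = 784 flattened pixels
        self.fc2 = nn.Linear(m1, m2)    # m2 = 10 in this notebook (digits 0-9)

    def forward(self, x):
        x = x.float()
        h1 = torch.relu(self.fc1(x.view(-1, 784)))   # flatten (N, 1, 28, 28) -> (N, 784)
        h2 = self.fc2(h1)
        # Log-probabilities, paired with nll_loss in train/validate/test
        return torch.log_softmax(h2, dim=1)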
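A quick sanity check of the model's shapes and size, as a sketch. It uses a copy of the class above (without the redundant `x.float()`) and a random batch shaped like the `ToTensor()` output. With m1 = 16 and m2 = 10, the parameter count is 784·16 + 16 + 16·10 + 10 = 12,730.

```python
import torch
import torch.nn as nn

class NeuralNetwork(nn.Module):
    # Copy of the notebook's model: 784 -> m1 (ReLU) -> m2, log-softmax output
    def __init__(self, m1, m2):
        super().__init__()
        self.fc1 = nn.Linear(784, m1)
        self.fc2 = nn.Linear(m1, m2)

    def forward(self, x):
        h1 = torch.relu(self.fc1(x.view(-1, 784)))
        return torch.log_softmax(self.fc2(h1), dim=1)

model = NeuralNetwork(m1=16, m2=10)
x = torch.rand(32, 1, 28, 28)         # random batch with MNIST's shape (not real data)
out = model(x)

n_params = sum(p.numel() for p in model.parameters())
row_sums = out.exp().sum(dim=1)       # exp(log_softmax) is a probability distribution per row
print(out.shape, n_params, torch.allclose(row_sums, torch.ones(32), atol=1e-5))
```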

train, validate, and test functions

In [6]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = nn.functional.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

def validate(model, device, val_loader):
    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += nn.functional.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    val_loss /= len(val_loader.dataset)
    val_accuracy = 100. * correct / len(val_loader.dataset)
    return val_loss, val_accuracy

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.functional.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    return test_loss, test_accuracy
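A sketch of the loss and accuracy bookkeeping used in these functions, on random logits (hypothetical data). It shows two things. First, `nll_loss` applied to `log_softmax` output equals `cross_entropy` on the raw scores. Second, summing per-sample losses with `reduction='sum'` and dividing by the dataset size gives the true per-sample mean. Averaging per-batch means would instead overweight a smaller final batch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)                    # hypothetical raw scores: 8 samples, 10 classes
target = torch.randint(0, 10, (8,))
log_probs = torch.log_softmax(logits, dim=1)   # what NeuralNetwork.forward() returns

# nll_loss on log-probabilities == cross_entropy on raw logits
mean_nll = F.nll_loss(log_probs, target)
ce = F.cross_entropy(logits, target)

# validate()/test(): sum the losses, then divide by the number of samples
sum_nll = F.nll_loss(log_probs, target, reduction='sum').item()
avg_from_sum = sum_nll / len(target)

# Accuracy: index of the largest log-probability vs. the label
pred = log_probs.argmax(dim=1, keepdim=True)
correct = pred.eq(target.view_as(pred)).sum().item()
accuracy = 100. * correct / len(target)
print(mean_nll.item(), ce.item(), avg_from_sum, accuracy)
```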

Load the data, then set up the model and optimizer

In [7]:
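# Despite the .txt names, these are root directories for torchvision's MNIST download.
# Because they differ, every MNIST file is downloaded twice (see the log below);
# passing the same root for both would avoid that.
train_file_path = 'mnist_train.txt'
test_file_path = 'mnist_test.txt'
train_loader, val_loader, test_loader = load_data(train_file_path, test_file_path)

model = NeuralNetwork(m1=16, m2=10).to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)   # use the hyperparameters defined above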

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist_train.txt/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 119411415.17it/s]
Extracting mnist_train.txt/MNIST/raw/train-images-idx3-ubyte.gz to mnist_train.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist_train.txt/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 91630630.73it/s]
Extracting mnist_train.txt/MNIST/raw/train-labels-idx1-ubyte.gz to mnist_train.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist_train.txt/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 32998498.90it/s]
Extracting mnist_train.txt/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist_train.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist_train.txt/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 14142931.53it/s]
Extracting mnist_train.txt/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist_train.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist_test.txt/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 147469774.96it/s]
Extracting mnist_test.txt/MNIST/raw/train-images-idx3-ubyte.gz to mnist_test.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist_test.txt/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 137186516.22it/s]
Extracting mnist_test.txt/MNIST/raw/train-labels-idx1-ubyte.gz to mnist_test.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist_test.txt/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 126336111.15it/s]
Extracting mnist_test.txt/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist_test.txt/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist_test.txt/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 25232488.43it/s]
Extracting mnist_test.txt/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist_test.txt/MNIST/raw


Training, validation performance, and test

In [8]:
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    val_loss, val_accuracy = validate(model, device, val_loader)
    print(f'Validation set: Average loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%')

test_loss, test_accuracy = test(model, device, test_loader)
print(f'Model Test set: Average loss: {test_loss:.4f}, Accuracy: {test_accuracy:.2f}%')
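The loop above runs `test()` on the weights from the last epoch. One common option, which is not part of the original notebook, is to keep a snapshot of the weights from the epoch with the best validation accuracy and test those instead. This loop prints a dip for epochs 10 to 12 (93.58%, 93.66%, 93.61%), so the best epoch is not always the latest. The sketch below replays those three values. A tiny `nn.Linear` and a weight shift stand in for the real model and for `train(...)`; both are hypothetical placeholders.

```python
import copy
import torch
import torch.nn as nn

# Validation accuracies (%) printed for epochs 10, 11 and 12
val_accuracies = [93.58, 93.66, 93.61]

model = nn.Linear(4, 2)                 # tiny stand-in for NeuralNetwork
nn.init.zeros_(model.weight)
nn.init.zeros_(model.bias)

best_acc, best_epoch, best_state = float('-inf'), None, None
for epoch, val_accuracy in enumerate(val_accuracies, start=10):
    # Stand-in for train(...): shift every weight so each epoch's state differs
    with torch.no_grad():
        for p in model.parameters():
            p.add_(1.0)
    # In the notebook, val_accuracy would come from validate(model, device, val_loader)
    if val_accuracy > best_acc:
        best_acc, best_epoch = val_accuracy, epoch
        # deepcopy: state_dict() tensors share storage with the live parameters
        best_state = copy.deepcopy(model.state_dict())

model.load_state_dict(best_state)       # these are the weights to pass to test(...)
print(best_epoch, best_acc, model.weight[0, 0].item())
```

Without the `deepcopy`, later epochs would silently overwrite the "best" snapshot, because `state_dict()` returns references to the parameter tensors.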


Validation set: Average loss: 0.3940, Accuracy: 89.06%
Validation set: Average loss: 0.3241, Accuracy: 90.81%
Validation set: Average loss: 0.3015, Accuracy: 91.41%
Validation set: Average loss: 0.2841, Accuracy: 91.83%
Validation set: Average loss: 0.2714, Accuracy: 92.20%
Validation set: Average loss: 0.2607, Accuracy: 92.47%
Validation set: Average loss: 0.2460, Accuracy: 92.99%
Validation set: Average loss: 0.2422, Accuracy: 93.04%
Validation set: Average loss: 0.2337, Accuracy: 93.41%
Validation set: Average loss: 0.2292, Accuracy: 93.58%
Validation set: Average loss: 0.2233, Accuracy: 93.66%
Validation set: Average loss: 0.2218, Accuracy: 93.61%
Validation set: Average loss: 0.2164, Accuracy: 93.89%
Validation set: Average loss: 0.2128, Accuracy: 93.90%
Validation set: Average loss: 0.2066, Accuracy: 94.09%
Validation set: Average loss: 0.2055, Accuracy: 94.13%
Validation set: Average loss: 0.2026, Accuracy: 94.14%
Validation set: Average loss: 0.2011, Accuracy: 94.22%
Validation