In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.init

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'available device: {device}')

# Seed PyTorch's RNG with a fixed value so repeated runs start from the same state.
torch.manual_seed(777)
if device == 'cuda':
    # Seed the CUDA generators as well when training on the GPU.
    torch.cuda.manual_seed_all(777)

available device: cuda


# Prepare datasets & data loaders

In [2]:
# Both MNIST splits share the same storage directory, tensor conversion and
# download flag; only the `train` flag differs between them.
mnist_common = dict(root='MNIST_data/',
                    transform=transforms.ToTensor(),
                    download=True)
mnist_train = torchvision.datasets.MNIST(train=True, **mnist_common)
mnist_test = torchvision.datasets.MNIST(train=False, **mnist_common)

In [3]:
batch_size = 100

# Training batches are reshuffled on every pass; drop_last=True discards a
# trailing partial batch so every training step sees exactly `batch_size` samples.
train_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           drop_last=True)

# Evaluation only counts correct predictions, so sample order is irrelevant:
# keep the test split unshuffled for a deterministic pass, and keep every
# sample (no drop_last) so the accuracy covers the whole split.
test_loader = torch.utils.data.DataLoader(dataset=mnist_test,
                                          batch_size=batch_size,
                                          shuffle=False)

# Define my Neural Network

In [4]:
class myNN(torch.nn.Module):
    """Small CNN classifier: two conv/ReLU/max-pool stages and a linear head.

    Each 3x3 convolution uses stride 1 and padding 1, and each pooling layer
    is a 2x2 max-pool with stride 2.  The linear layer expects ``7 * 7 * 64``
    flattened features, which is what a 1-channel 28x28 input yields after
    the two pooling stages.  The 10 linear outputs are passed through
    ``BatchNorm1d`` before being returned.
    """

    def __init__(self):
        """Build the layers and Xavier-initialise the linear layer's weight."""
        super().__init__()
        nn = torch.nn

        # Stage 1: 1 -> 32 channels, followed by a 2x2 max-pool.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32,
                               kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 32 -> 64 channels, followed by a 2x2 max-pool.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,
                               kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Head: flattened features -> 10 outputs, then batch normalisation.
        self.fc = nn.Linear(in_features=7 * 7 * 64, out_features=10, bias=True)
        self.fc_bn = nn.BatchNorm1d(num_features=10)
        # Only the linear weight is re-initialised; the other layers keep
        # their PyTorch default initialisation.
        nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, x):
        """Return the batch-normalised 10-way outputs for the input batch ``x``."""
        stage1 = self.pool1(self.relu1(self.conv1(x)))
        stage2 = self.pool2(self.relu2(self.conv2(stage1)))

        # Collapse everything except the batch dimension before the linear layer.
        flat = stage2.view(stage2.size(0), -1)
        return self.fc_bn(self.fc(flat))

# Set up training: hyperparameters, loss function, and optimizer

In [5]:
# Optimisation hyperparameters.
learning_rate = 0.001
training_epochs = 5

# Build the network and place it on the selected device.
model = myNN()
model = model.to(device=device)

# Cross-entropy loss on the network outputs, minimised with SGD at the
# learning rate above.
criterion = torch.nn.CrossEntropyLoss().to(device=device)
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

# Train myNN

In [6]:
total_batch = len(train_loader)  # number of mini-batches per epoch
print('Learning Started!')

# Put the model in training mode explicitly. The evaluation cell calls
# model.eval(), so re-running this cell afterwards would otherwise train with
# BatchNorm using its running statistics instead of per-batch statistics.
model.train()

for epoch in range(training_epochs):
    running_loss = 0.0  # becomes the mean mini-batch loss of this epoch
    for data, labels in train_loader:
        data = data.to(device)
        labels = labels.to(device)

        # Standard step: clear old gradients, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Adding loss / total_batch each step makes the sum the epoch average.
        running_loss += loss.item() / total_batch
    print('[Epoch: {:>4}] loss = {:>.9}'.format(epoch+1, running_loss))

print('Learning Finished!')

Learning Started!
[Epoch:    1] loss = 0.922220454
[Epoch:    2] loss = 0.6599126
[Epoch:    3] loss = 0.573877717
[Epoch:    4] loss = 0.51543976
[Epoch:    5] loss = 0.472939699
Learning Finished!


# Evaluate myNN on the test set

In [7]:
correct = 0  # number of test samples whose predicted class matches the label
total = 0    # number of test samples seen

# Evaluation mode: BatchNorm switches from per-batch to running statistics.
model.eval()

# Gradients are not needed for inference, so skip autograd bookkeeping.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest output along the class axis.
        # argmax replaces torch.max(outputs.data, 1): the legacy `.data`
        # access is unnecessary under no_grad, and the throwaway `_` binding
        # is no longer created at notebook top level.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of myNN on the test set: %.4f %%' % (100 * correct/total))

Accuracy of myNN on the test set: 95.7900 %


# Save the model

In [8]:
# Persist only the model's state_dict (not the Module object itself).
trained_state = model.state_dict()
torch.save(trained_state, 'model.ckpt')

In [9]:
# List the working directory (long format, human-readable sizes).
!ls -lh

total 264K
-rw-r--r-- 1 etriai02 etriai02 7.6K 10월 13 16:03 HelloMnist.ipynb
-rw-r--r-- 1 etriai02 etriai02  33K 10월 13 14:27 HelloTensor.ipynb
-rw-r--r-- 1 etriai02 etriai02 9.7K 10월 13 15:18 HelloTorch.ipynb
-rw-rw-r-- 1 etriai02 etriai02  535 10월 13 10:26 main.py
drwxr-xr-x 3 etriai02 etriai02 4.0K 10월 13 15:24 MNIST_data
-rw-r--r-- 1 etriai02 etriai02 200K 10월 13 16:04 model.ckpt
