In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [3]:
# Data Loading
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.0,), (1.0,))])
train_set = datasets.MNIST(root='./MNIST_data', train=True, transform=trans, download=True)
test_set = datasets.MNIST(root='./MNIST_data', train=False, transform=trans, download=True)

batch_size = 32
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_set,batch_size=batch_size, shuffle=False)

print('==>>> total trainning batch number: {}'.format(len(train_loader)))
print('==>>> total testing batch number: {}'.format(len(test_loader)))


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST_data/MNIST/raw

==>>> total trainning batch number: 1875
==>>> total testing batch number: 313


In [4]:
# Network
class MLPNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 500)
        self.fc2 = nn.Linear(500, 256)
        self.fc3 = nn.Linear(256, 10)


    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = MLPNet().to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()

In [5]:
# Training
for epoch in range(5):

    # training
    for batch_idx, (x, y) in enumerate(train_loader):

        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = loss_fn(out, y)        
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('epoch: ', epoch, '  batch_idx: ', batch_idx, loss.item())

    # testing
    total_cnt = 0
    correct_cnt = 0
    for batch_idx, (x, y) in enumerate(test_loader):
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = loss_fn(out, y)
        _, pred_label = torch.max(out, 1)
        total_cnt += len(x)
        correct_cnt += (pred_label == y).sum().item()

    print('=======> epoch: ', epoch,  '  accuracy: ', correct_cnt*1.0/total_cnt)


epoch:  0   batch_idx:  0 2.306971311569214
epoch:  0   batch_idx:  100 0.9001249074935913
epoch:  0   batch_idx:  200 0.3615915775299072
epoch:  0   batch_idx:  300 0.6210492253303528
epoch:  0   batch_idx:  400 0.3176538348197937
epoch:  0   batch_idx:  500 0.1373610943555832
epoch:  0   batch_idx:  600 0.2564818561077118
epoch:  0   batch_idx:  700 0.5965662598609924
epoch:  0   batch_idx:  800 0.25786879658699036
epoch:  0   batch_idx:  900 0.10482250154018402
epoch:  0   batch_idx:  1000 0.39101529121398926
epoch:  0   batch_idx:  1100 0.29534420371055603
epoch:  0   batch_idx:  1200 0.12342219799757004
epoch:  0   batch_idx:  1300 0.1853213608264923
epoch:  0   batch_idx:  1400 0.09140060842037201
epoch:  0   batch_idx:  1500 0.3391956090927124
epoch:  0   batch_idx:  1600 0.19609342515468597
epoch:  0   batch_idx:  1700 0.07297945767641068
epoch:  0   batch_idx:  1800 0.180165097117424
epoch:  1   batch_idx:  0 0.022355755791068077
epoch:  1   batch_idx:  100 0.09146755933761597