In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Hyper-parameter
input_size = 784
hidden_size = 500
num_classes = 10
num_epochs =10
batch_size=256
learning_rate =0.001

In [13]:
# MINST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='../../data', train=False, transform=transforms.ToTensor(), download=False)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


In [5]:
# Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        assert (x.shape[1]==input_size)
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

In [6]:
# Model & Loss Func & Optimizer
model = NeuralNet(input_size=input_size, hidden_size=hidden_size, num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer =torch.optim.Adam(model.parameters(), lr=learning_rate)

In [7]:
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # get Input &GT (send to Device)
        images = images.reshape(-1, 28*28)
        labels = labels.to(device)
        
        # get OUTPUT  through Net and Loss
        output = model(images)
        loss = criterion(output, labels)
        
        # Zero & Backward & Step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Logging
        if (i+1) % 100 ==0:
            print ('Epoch [{}/{}], Step[{}/{}], Loss:{:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

Epoch [0/10], Step[100/235], Loss:0.2451
Epoch [0/10], Step[200/235], Loss:0.3065
Epoch [1/10], Step[100/235], Loss:0.1103
Epoch [1/10], Step[200/235], Loss:0.1594
Epoch [2/10], Step[100/235], Loss:0.1583
Epoch [2/10], Step[200/235], Loss:0.1112
Epoch [3/10], Step[100/235], Loss:0.1131
Epoch [3/10], Step[200/235], Loss:0.0639
Epoch [4/10], Step[100/235], Loss:0.0593
Epoch [4/10], Step[200/235], Loss:0.1033
Epoch [5/10], Step[100/235], Loss:0.0797
Epoch [5/10], Step[200/235], Loss:0.0655
Epoch [6/10], Step[100/235], Loss:0.0611
Epoch [6/10], Step[200/235], Loss:0.0325
Epoch [7/10], Step[100/235], Loss:0.0201
Epoch [7/10], Step[200/235], Loss:0.0417
Epoch [8/10], Step[100/235], Loss:0.0348
Epoch [8/10], Step[200/235], Loss:0.0324
Epoch [9/10], Step[100/235], Loss:0.0222
Epoch [9/10], Step[200/235], Loss:0.0214


In [15]:
# Test the model
with torch.no_grad():
    correct = 0
    total = 0
    for i, (images, labels) in enumerate(test_loader):
        # get Input &GT (send to Device)
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        # Output: forward pass throuhg model 
        outputs = model(images)
        # extract Prediction
        _, predicted = torch.max(outputs.data, 1)
        # Precision from Prediction vs GT
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        print ("[{}]Accuracy:{:5f}%".format(i+1, 100*correct/total))

torch.save(model.state_dict(),'model.ckpt')

[1]Accuracy:99.609375%
[2]Accuracy:98.828125%
[3]Accuracy:98.437500%
[4]Accuracy:98.339844%
[5]Accuracy:97.812500%
[6]Accuracy:97.656250%
[7]Accuracy:97.600446%
[8]Accuracy:97.558594%
[9]Accuracy:97.309028%
[10]Accuracy:97.343750%
[11]Accuracy:97.336648%
[12]Accuracy:97.428385%
[13]Accuracy:97.566106%
[14]Accuracy:97.572545%
[15]Accuracy:97.473958%
[16]Accuracy:97.363281%
[17]Accuracy:97.357537%
[18]Accuracy:97.395833%
[19]Accuracy:97.388980%
[20]Accuracy:97.460938%
[21]Accuracy:97.563244%
[22]Accuracy:97.638494%
[23]Accuracy:97.656250%
[24]Accuracy:97.574870%
[25]Accuracy:97.656250%
[26]Accuracy:97.596154%
[27]Accuracy:97.656250%
[28]Accuracy:97.739955%
[29]Accuracy:97.804418%
[30]Accuracy:97.877604%
[31]Accuracy:97.946069%
[32]Accuracy:97.961426%
[33]Accuracy:98.011364%
[34]Accuracy:98.046875%
[35]Accuracy:98.102679%
[36]Accuracy:98.111979%
[37]Accuracy:98.152449%
[38]Accuracy:98.139391%
[39]Accuracy:98.086939%
[40]Accuracy:98.090000%
