In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [2]:
input_size = 28 * 28
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001

In [3]:
train_dataset = torchvision.datasets.MNIST(root='../../data', 
                                           train=True, 
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data', 
                                          train=False, 
                                          transform=transforms.ToTensor())

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../../data/MNIST/raw/train-images-idx3-ubyte.gz to ../../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw



In [4]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
model = nn.Linear(input_size, num_classes)

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [7]:
total_step = len(train_loader)
for epoch in range(num_epochs):
  for i, (images, labels) in enumerate(train_loader):
    images = images.reshape(-1, input_size)

    outputs = model(images)
    loss = criterion(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (i+1) % 100 == 0:
      print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
      


Epoch [1/5], Step [100/600], Loss: 2.1898
Epoch [1/5], Step [200/600], Loss: 2.1073
Epoch [1/5], Step [300/600], Loss: 1.9894
Epoch [1/5], Step [400/600], Loss: 1.9192
Epoch [1/5], Step [500/600], Loss: 1.8701
Epoch [1/5], Step [600/600], Loss: 1.8120
Epoch [2/5], Step [100/600], Loss: 1.7258
Epoch [2/5], Step [200/600], Loss: 1.6982
Epoch [2/5], Step [300/600], Loss: 1.5565
Epoch [2/5], Step [400/600], Loss: 1.5366
Epoch [2/5], Step [500/600], Loss: 1.4595
Epoch [2/5], Step [600/600], Loss: 1.4131
Epoch [3/5], Step [100/600], Loss: 1.4087
Epoch [3/5], Step [200/600], Loss: 1.3567
Epoch [3/5], Step [300/600], Loss: 1.3421
Epoch [3/5], Step [400/600], Loss: 1.3295
Epoch [3/5], Step [500/600], Loss: 1.3077
Epoch [3/5], Step [600/600], Loss: 1.2919
Epoch [4/5], Step [100/600], Loss: 1.2867
Epoch [4/5], Step [200/600], Loss: 1.2249
Epoch [4/5], Step [300/600], Loss: 1.1578
Epoch [4/5], Step [400/600], Loss: 1.1015
Epoch [4/5], Step [500/600], Loss: 1.1178
Epoch [4/5], Step [600/600], Loss:

In [8]:
with torch.no_grad():
  correct = 0
  total = 0
  for images, labels in test_loader:
    images = images.reshape(-1, input_size)
    outputs = model(images)

    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()

  print('Accuracy of the model on the 10000 test images : {} %'.format(100 * correct / total))

torch.save(model.state_dict(), 'model.ckpt')

Accuracy of the model on the 10000 test images : 82.13999938964844 %
