In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
from torch import optim
from pathlib import Path

%matplotlib inline

In [2]:
data_path = os.path.join(str(Path.home()), 'dev', 'pytorch-data', 'mnist')
train = tv.datasets.MNIST(data_path,
                          train=True,
                          download=True,
                          transform=transforms.ToTensor())
train_set = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True)

test = tv.datasets.MNIST(data_path,
                          train=False,
                          download=True,
                          transform=transforms.ToTensor())
test_set = torch.utils.data.DataLoader(test, batch_size=128, shuffle=True)

In [3]:
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(64 * 5 * 5, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)
params = list()
for n, p in net.named_parameters():
    print(n, p.size())

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=1600, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)
conv1.weight torch.Size([32, 1, 3, 3])
conv1.bias torch.Size([32])
conv2.weight torch.Size([64, 32, 3, 3])
conv2.bias torch.Size([64])
fc1.weight torch.Size([500, 1600])
fc1.bias torch.Size([500])
fc2.weight torch.Size([10, 500])
fc2.bias torch.Size([10])


In [4]:
device = torch.device(0)
print(torch.cuda.device_count())

1


In [7]:
def eval(net):
  correct, total = 0, 0
  with torch.no_grad():
      for inputs, labels in test_set:
          inputs = inputs.to(device)
          labels = labels.to(device)
          logits = net(inputs)
          _, pred = torch.max(logits, 1)
          total += len(labels)
          correct += torch.sum(labels == pred).item()
  print("Accuracy: %.2f%%" % ((correct / total) * 100))
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=1e-5, momentum=0.9)
net.to(device)
for epoch in range(300):  
    net.train()
    running_loss, num = 0.0, 0
    t1 = time.time()
    for i, data in enumerate(train_set, 0):
        
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()   
        
        optimizer.step()
        
        running_loss += loss.item()
        num += inputs.size(0)

    t2 = time.time()
    print('Epoch %d Loss: %f Time: %f' % (epoch + 1, running_loss / num, t2 - t1))
    net.eval()
    eval(net)
print('Finished Training')

Epoch 1 Loss: 0.000266 Time: 5.980066
Accuracy: 98.87%
Epoch 2 Loss: 0.000268 Time: 5.997501
Accuracy: 98.87%
Epoch 3 Loss: 0.000258 Time: 5.947529
Accuracy: 98.87%
Epoch 4 Loss: 0.000270 Time: 5.950980
Accuracy: 98.87%
Epoch 5 Loss: 0.000262 Time: 5.954792
Accuracy: 98.86%
Epoch 6 Loss: 0.000273 Time: 5.941078
Accuracy: 98.86%
Epoch 7 Loss: 0.000266 Time: 5.959959
Accuracy: 98.86%
Epoch 8 Loss: 0.000262 Time: 5.941016
Accuracy: 98.85%
Epoch 9 Loss: 0.000267 Time: 5.966096
Accuracy: 98.86%
Epoch 10 Loss: 0.000259 Time: 5.982173
Accuracy: 98.86%
Epoch 11 Loss: 0.000277 Time: 6.006250
Accuracy: 98.87%
Epoch 12 Loss: 0.000265 Time: 5.965910
Accuracy: 98.87%


KeyboardInterrupt: 