In [89]:
import torch
from torch.nn import Module
from PIL import Image
from torch import nn
import numpy as np
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

In [90]:
class Model(Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()
        
    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)

        return y

In [91]:
model = Model()

In [94]:
def train():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Running on {device}")
    batch_size = 256
    train_dataset = mnist.MNIST(root='data/train', train=True, transform=ToTensor(), download=True)
    test_dataset = mnist.MNIST(root='data/test', train=False, transform=ToTensor(), download=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    LeNet = model.to(device)
    sgd = SGD(LeNet.parameters(), lr=1e-1)
    loss_fn = CrossEntropyLoss()
    epoch = 20
    prev_acc = 0
    
    for cur_epoch in range(epoch):
        LeNet.train()
        for idx, (train_x, train_label) in enumerate(train_loader):
            train_x = train_x.to(device)
            train_label = train_label.to(device)
            sgd.zero_grad()
            predict_y = LeNet(train_x.float())
            loss = loss_fn(predict_y, train_label.long())
            loss.backward()
            sgd.step()
        
        all_correct_num = 0
        all_sample_num = 0
        LeNet.eval()
        
        for idx, (test_x, test_label) in enumerate(test_loader):
            test_x = test_x.to(device)
            test_label = test_label.to(device)
            predict_y = LeNet(test_x.float()).detach()
            predict_y = torch.argmax(predict_y, dim=-1)
            current_correct_num = predict_y == test_label
            all_correct_num += np.sum(current_correct_num.to('cpu').numpy(), axis=-1)
            all_sample_num += current_correct_num.shape[0]
        acc = all_correct_num/all_sample_num
        print(f"Epoch: {cur_epoch}\tAccuracy: {acc}\tLoss: {loss}")
        
        if np.abs(acc - prev_acc)< 1e-4:
            break
        prev_acc = acc
        

In [95]:
train()

Running on cuda
Epoch: 0	Accuracy: 0.8591	Loss: 0.4663788378238678
Epoch: 1	Accuracy: 0.9661	Loss: 0.19467996060848236
Epoch: 2	Accuracy: 0.9742	Loss: 0.17733533680438995
Epoch: 3	Accuracy: 0.9776	Loss: 0.16273649036884308
Epoch: 4	Accuracy: 0.9789	Loss: 0.15471391379833221
Epoch: 5	Accuracy: 0.9802	Loss: 0.14941798150539398
Epoch: 6	Accuracy: 0.9774	Loss: 0.1411147564649582
Epoch: 7	Accuracy: 0.9789	Loss: 0.13949540257453918
Epoch: 8	Accuracy: 0.98	Loss: 0.1371331810951233
Epoch: 9	Accuracy: 0.9808	Loss: 0.13482840359210968
Epoch: 10	Accuracy: 0.9816	Loss: 0.13215021789073944
Epoch: 11	Accuracy: 0.9806	Loss: 0.130820631980896
Epoch: 12	Accuracy: 0.9818	Loss: 0.12964124977588654
Epoch: 13	Accuracy: 0.9818	Loss: 0.128016397356987


In [96]:
img = Image.open('data/img_3.jpg') 
img_tensor = ToTensor()(img).unsqueeze(0).to('cuda')

print(torch.argmax(model(img_tensor)))

tensor(9, device='cuda:0')
