In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, datasets

In [4]:
USE_CUDA=torch.cuda.is_available()
DEVICE=torch.device("cuda" if USE_CUDA else "cpu")

In [5]:
EPOCHS=40
BATCH_SIZE=64

In [20]:
train_loader=torch.utils.data.DataLoader(
    datasets.MNIST("D:\Datasets\MNIST_torch",
                  train=True,
                   download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()])
                  ),
    batch_size=BATCH_SIZE,
    shuffle=True
)

In [21]:
test_loader=torch.utils.data.DataLoader(
    datasets.MNIST("D:\Datasets\MNIST_torch",
                  train=False,
                   download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()])
                  ),
    batch_size=BATCH_SIZE,
    shuffle=True
)

In [22]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1=nn.Conv2d(1, 10, kernel_size=5)
        self.conv2=nn.Conv2d(10,20, kernel_size=5)
        self.drop=nn.Dropout2d()
        self.fc1=nn.Linear(320,50)
        self.fc2=nn.Linear(50,10)
    
    def forward(self,x):
        x=F.relu(F.max_pool2d(self.conv1(x),2))
        x=F.relu(F.max_pool2d(self.conv2(x),2))
        x=x.view(-1,320)
        x=F.relu(self.fc1(x))
        x=self.drop(x)
        x=self.fc2(x)
        return F.log_softmax(x, dim=1)

In [23]:
model=CNN().to(DEVICE)
optimizer=optim.SGD(model.parameters(),lr=0.01, momentum=0.5)

In [24]:
def train(model, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target=data.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        output=model(data)
        loss=F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx%200==0:
            print("TRAIN EPOCH: {} [{}/{} ({:.0f}%)]\tLOSS: {:.6f}".format(epoch, batch_idx*len(data), len(train_loader.dataset),100.*batch_idx/len(train_loader), loss.item()))

In [25]:
def evaluate(model, test_loader):
    model.eval()
    test_loss=0
    correct=0
    with torch.no_grad():
        for data, target in test_loader:
            data, target=data.to(DEVICE), target.to(DEVICE)
            output=model(data)
           
            test_loss+=F.cross_entropy(output, target, reduction='sum').item()
            pred=output.max(1,keepdim=True)[1]
            correct+=pred.eq(target.view_as(pred)).sum().item()
    test_loss/=len(test_loader.dataset)
    test_acc=100*correct/len(test_loader.dataset)
    return test_loss, test_acc

In [26]:
for epoch in range(1, EPOCHS+1):
    train(model, train_loader, optimizer, epoch)
    test_loss, test_acc=evaluate(model, test_loader)
    print("{} TEST LOSS: {:.4f}, ACCURACY: {:.2f}%".format(epoch, test_loss, test_acc))

1 TEST LOSS: 0.3166, ACCURACY: 91.18%
2 TEST LOSS: 0.1665, ACCURACY: 95.12%
3 TEST LOSS: 0.1063, ACCURACY: 96.77%
4 TEST LOSS: 0.0871, ACCURACY: 97.34%
5 TEST LOSS: 0.0773, ACCURACY: 97.60%
6 TEST LOSS: 0.0700, ACCURACY: 97.90%
7 TEST LOSS: 0.0631, ACCURACY: 97.97%
8 TEST LOSS: 0.0585, ACCURACY: 98.08%
9 TEST LOSS: 0.0606, ACCURACY: 98.04%
10 TEST LOSS: 0.0519, ACCURACY: 98.38%
11 TEST LOSS: 0.0498, ACCURACY: 98.50%
12 TEST LOSS: 0.0483, ACCURACY: 98.41%
13 TEST LOSS: 0.0476, ACCURACY: 98.52%
14 TEST LOSS: 0.0460, ACCURACY: 98.64%
15 TEST LOSS: 0.0463, ACCURACY: 98.57%
16 TEST LOSS: 0.0451, ACCURACY: 98.55%
17 TEST LOSS: 0.0419, ACCURACY: 98.70%
18 TEST LOSS: 0.0418, ACCURACY: 98.63%
19 TEST LOSS: 0.0466, ACCURACY: 98.47%
20 TEST LOSS: 0.0398, ACCURACY: 98.74%
21 TEST LOSS: 0.0429, ACCURACY: 98.62%
22 TEST LOSS: 0.0411, ACCURACY: 98.72%
23 TEST LOSS: 0.0358, ACCURACY: 98.92%
24 TEST LOSS: 0.0366, ACCURACY: 98.88%
25 TEST LOSS: 0.0351, ACCURACY: 98.89%
26 TEST LOSS: 0.0368, ACCURACY: 98

29 TEST LOSS: 0.0358, ACCURACY: 98.91%
30 TEST LOSS: 0.0360, ACCURACY: 98.83%
31 TEST LOSS: 0.0353, ACCURACY: 98.95%
32 TEST LOSS: 0.0376, ACCURACY: 98.86%
33 TEST LOSS: 0.0385, ACCURACY: 98.85%
34 TEST LOSS: 0.0367, ACCURACY: 98.92%
35 TEST LOSS: 0.0383, ACCURACY: 98.90%
36 TEST LOSS: 0.0379, ACCURACY: 98.82%
37 TEST LOSS: 0.0347, ACCURACY: 98.94%
38 TEST LOSS: 0.0369, ACCURACY: 98.89%
39 TEST LOSS: 0.0348, ACCURACY: 98.93%
40 TEST LOSS: 0.0342, ACCURACY: 98.99%


In [28]:
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-0.1028,  0.0519,  0.0935,  0.1571,  0.2466],
                        [-0.1034,  0.0506, -0.2231, -0.1117,  0.3850],
                        [ 0.1197, -0.0148, -0.2670,  0.0861,  0.2948],
                        [ 0.0109, -0.2474, -0.0864, -0.0895,  0.1981],
                        [ 0.0648, -0.1938, -0.1037, -0.1835,  0.1341]]],
              
              
                      [[[-0.1980, -0.3880, -0.4325, -0.3353, -0.1884],
                        [-0.3058, -0.3909, -0.1737, -0.0572,  0.1178],
                        [-0.4570, -0.3614,  0.1462,  0.3964,  0.3433],
                        [-0.3664,  0.2653,  0.4066,  0.2591,  0.1555],
                        [-0.0875,  0.1143,  0.2701,  0.0564, -0.0328]]],
              
              
                      [[[ 0.3599,  0.0803,  0.2776,  0.2668, -0.0025],
                        [ 0.2425,  0.8594,  0.9143,  0.7620,  0.4868],
                        [ 0.4446,  0.5722,  0.4896,  0

In [29]:
torch.save(model.state_dict(), './MNIST_CNN_model.pt')
print('state_dict format of the model: {}'.format(model.state_dict()))

state_dict format of the model: OrderedDict([('conv1.weight', tensor([[[[-0.1028,  0.0519,  0.0935,  0.1571,  0.2466],
          [-0.1034,  0.0506, -0.2231, -0.1117,  0.3850],
          [ 0.1197, -0.0148, -0.2670,  0.0861,  0.2948],
          [ 0.0109, -0.2474, -0.0864, -0.0895,  0.1981],
          [ 0.0648, -0.1938, -0.1037, -0.1835,  0.1341]]],


        [[[-0.1980, -0.3880, -0.4325, -0.3353, -0.1884],
          [-0.3058, -0.3909, -0.1737, -0.0572,  0.1178],
          [-0.4570, -0.3614,  0.1462,  0.3964,  0.3433],
          [-0.3664,  0.2653,  0.4066,  0.2591,  0.1555],
          [-0.0875,  0.1143,  0.2701,  0.0564, -0.0328]]],


        [[[ 0.3599,  0.0803,  0.2776,  0.2668, -0.0025],
          [ 0.2425,  0.8594,  0.9143,  0.7620,  0.4868],
          [ 0.4446,  0.5722,  0.4896,  0.5399,  0.5699],
          [-0.2119, -0.5641, -0.4217, -0.1991,  0.1462],
          [-0.7691, -0.9072, -0.5793, -0.2567,  0.0275]]],


        [[[ 0.3823,  0.2902,  0.0255,  0.2393,  0.3351],
          [ 0.

         0.5387, -0.3218]))])


In [30]:
load_model=CNN()
load_model.load_state_dict(torch.load('./MNIST_CNN_model.pt'))
load_model.eval()
test_loss, test_acc=evaluate(model, test_loader)
print("{} TEST LOSS: {:.4f}, ACCURACY: {:.2f}%".format(epoch, test_loss, test_acc))

40 TEST LOSS: 0.0342, ACCURACY: 98.99%
