In [9]:
import torch
import torchvision as tv
import torchsummary

In [3]:
samples0, samples1 = 60000, 10000
classes = 10

source0 = tv.datasets.MNIST("../../MNIST", train = True, download = False)
source1 = tv.datasets.MNIST("../../MNIST", train = False, download = False)
DATA0 = source0.data.unsqueeze(1).float().cuda()
DATA1 = source1.data.unsqueeze(1).float().cuda()
TARGET0 = source0.targets.cuda()
TARGET1 = source1.targets.cuda()

In [33]:
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, 5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(8, 16, 5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Flatten(),
    torch.nn.Linear(16*4*4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 10)).cuda()
variables = model.parameters()

torchsummary.summary(model, input_size=DATA0.shape[1:])

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 24, 24]             208
              ReLU-2            [-1, 8, 24, 24]               0
         MaxPool2d-3            [-1, 8, 12, 12]               0
            Conv2d-4             [-1, 16, 8, 8]           3,216
              ReLU-5             [-1, 16, 8, 8]               0
         MaxPool2d-6             [-1, 16, 4, 4]               0
           Flatten-7                  [-1, 256]               0
            Linear-8                  [-1, 128]          32,896
              ReLU-9                  [-1, 128]               0
           Linear-10                   [-1, 32]           4,128
             ReLU-11                   [-1, 32]               0
           Linear-12                   [-1, 10]             330
Total params: 40,778
Trainable params: 40,778
Non-trainable params: 0
---------------------------------

In [34]:
batch = 1000
optimizer = torch.optim.Adam(variables)
for epoch in range(100):
    LOSS0 = torch.zeros((), device = "cuda")
    ACCURACY0 = torch.zeros((), device = "cuda")
    count0 = 0
    for index in range(0, samples0, batch):
        optimizer.zero_grad()
        DATA = DATA0[index : index + batch]
        TARGET = TARGET0[index : index + batch]
        count = TARGET.size(0)
        ACTIVATION = model(DATA)
        LOSS = torch.nn.functional.cross_entropy(ACTIVATION, TARGET)
        LOSS0 += LOSS * count
        VALUE = torch.argmax(ACTIVATION, 1)
        ACCURACY0 += torch.sum(VALUE == TARGET)
        count0 += count
        LOSS.backward()
        optimizer.step()
    LOSS0 /= count0
    ACCURACY0 /= count0
    with torch.no_grad():
        LOSS1 = torch.zeros((), device = "cuda")
        ACCURACY1 = torch.zeros((), device = "cuda")
        count1 = 0
        for index in range(0, samples1, batch):
            DATA = DATA1[index : index + batch]
            TARGET = TARGET1[index : index + batch]
            ACTIVATION = model(DATA)
            LOSS1 += torch.nn.functional.cross_entropy(ACTIVATION, TARGET, reduction = "sum")
            VALUE = torch.argmax(ACTIVATION, 1)
            ACCURACY1 += torch.sum(VALUE == TARGET)
            count1 += TARGET.size(0)
        LOSS1 /= count1
        ACCURACY1 /= count1
    print("%5d %12.3f %4.3f %12.3f %4.3f" % \
          (epoch, LOSS0, ACCURACY0, LOSS1, ACCURACY1), flush = True)

#parameters: 46 090 + 208 = 46 298
#accuracy: train: 1000 test: 981

#this net with one convolutional and one dense layer performs better than two-layer dense net but has over 17 times less parameters.

    0        0.721 0.788        0.151 0.953
    1        0.135 0.959        0.092 0.972
    2        0.088 0.973        0.071 0.978
    3        0.065 0.980        0.062 0.981
    4        0.052 0.984        0.061 0.981
    5        0.043 0.987        0.057 0.983
    6        0.036 0.989        0.053 0.984
    7        0.030 0.991        0.054 0.984
    8        0.026 0.992        0.059 0.983
    9        0.025 0.992        0.067 0.982
   10        0.026 0.991        0.066 0.982
   11        0.026 0.991        0.074 0.980
   12        0.023 0.993        0.062 0.982
   13        0.020 0.993        0.050 0.986
   14        0.018 0.994        0.052 0.986
   15        0.015 0.995        0.067 0.982
   16        0.016 0.994        0.065 0.984
   17        0.016 0.995        0.069 0.984
   18        0.012 0.996        0.057 0.987
   19        0.008 0.997        0.047 0.989
   20        0.008 0.997        0.053 0.989
   21        0.008 0.997        0.054 0.988
   22        0.008 0.997        

KeyboardInterrupt: 