In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as f
from torchvision import datasets, transforms
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import tqdm

In [3]:
# Model C [table 4] trained on MNIST dataset
class ModelC(nn.Module):

    # Specify layers
    def __init__(self):
        super(ModelC, self).__init__()
        self.conv1 = nn.Conv2d(1, 128, 3) # mnist has single channel input
        self.conv2 = nn.Conv2d(128, 64, 3)
        self.drop1 = nn.Dropout(0.25)
        self.fc1   = nn.Linear(64*24*24, 128)
        self.drop2 = nn.Dropout(0.50)
        self.fc2   = nn.Linear(128, 10)

    def forward(self, x):
        x = f.relu(self.conv1(x)) # Conv(128, 3, 3) + Relu
        x = f.relu(self.conv2(x)) # Conv(64, 3, 3) + Relu
        x = self.drop1(x)         # Dropout(0.25)
        x = torch.flatten(x, 1)   # Flatten
        x = f.relu(self.fc1(x))   # FC(128) + Relu
        x = self.drop2(x)         # Dropout(0.50)
        x = f.relu(self.fc2(x))   # FC(10)
        return f.log_softmax(x, dim=1) # + Softmax

In [4]:
# Loading dataset
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # mean, std of mnist dataset
    ])
dataset1 = datasets.MNIST('../data', train=True, download=True,
                    transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                    transform=transform)

training_data = torch.utils.data.DataLoader(dataset1, batch_size=64, shuffle=True)
test_data = torch.utils.data.DataLoader(dataset2, batch_size=1000)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 91419794.37it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 27828094.15it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 29859685.75it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3153017.01it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw






In [5]:
model = ModelC()
print(model)

lr = 0.01 # Learning rate
N = 10 # epochs
crit = nn.CrossEntropyLoss() # loss criterion
optimizer = optim.SGD(model.parameters(), lr=lr)
scheduler_lr = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# For speed
if torch.cuda.is_available():
      device = torch.device('cuda:0')
else:
      device = torch.device('cpu')

print("Training...")

# Main loop
for epoch in tqdm.trange(N):
  total = 0
  correct = 0
  model.train() # training mode
  model.to(device)

  for i, (data, target) in enumerate(training_data):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad() # Reset gradient
    out = model(data) # Make prediction on data
    loss = crit(out, target) # Calculate loss
    loss.backward() # Backprop
    optimizer.step() # Optimize

    _, pred = torch.max(out.data, 1)
    total += target.size(0)
    correct += (pred == target).sum().item()

    if i % 100 == 0:
      accuracy = (correct / total) * 100
      print(f'[Epoch: {epoch}] {len(data) * i} / {len(training_data.dataset)}, loss: {loss.item()}, acc: {accuracy}')

ModelC(
  (conv1): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
  (drop1): Dropout(p=0.25, inplace=False)
  (fc1): Linear(in_features=36864, out_features=128, bias=True)
  (drop2): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
Training...


  0%|          | 0/10 [00:00<?, ?it/s]

[Epoch: 0] 0 / 60000, loss: 2.2821712493896484, acc: 14.0625
[Epoch: 0] 6400 / 60000, loss: 0.4234318435192108, acc: 67.07920792079209
[Epoch: 0] 12800 / 60000, loss: 0.4425220787525177, acc: 75.75404228855722
[Epoch: 0] 19200 / 60000, loss: 0.5253005027770996, acc: 79.671926910299
[Epoch: 0] 25600 / 60000, loss: 0.39111363887786865, acc: 81.87733790523691
[Epoch: 0] 32000 / 60000, loss: 0.34069114923477173, acc: 83.30214570858283
[Epoch: 0] 38400 / 60000, loss: 0.39430785179138184, acc: 84.32820299500831
[Epoch: 0] 44800 / 60000, loss: 0.5216675400733948, acc: 85.24875178316691
[Epoch: 0] 51200 / 60000, loss: 0.2616156339645386, acc: 85.91409176029963
[Epoch: 0] 57600 / 60000, loss: 0.4178941249847412, acc: 86.50804661487237


 10%|█         | 1/10 [00:23<03:31, 23.48s/it]

[Epoch: 1] 0 / 60000, loss: 0.2555570602416992, acc: 90.625
[Epoch: 1] 6400 / 60000, loss: 0.3053061366081238, acc: 91.92450495049505
[Epoch: 1] 12800 / 60000, loss: 0.22916056215763092, acc: 92.13308457711443
[Epoch: 1] 19200 / 60000, loss: 0.3429703712463379, acc: 92.37437707641196
[Epoch: 1] 25600 / 60000, loss: 0.1457960158586502, acc: 92.42518703241896
[Epoch: 1] 32000 / 60000, loss: 0.33888542652130127, acc: 92.48378243512974
[Epoch: 1] 38400 / 60000, loss: 0.09649951010942459, acc: 92.60607321131448
[Epoch: 1] 44800 / 60000, loss: 0.19585682451725006, acc: 92.75588445078459
[Epoch: 1] 51200 / 60000, loss: 0.23962800204753876, acc: 92.84878277153558
[Epoch: 1] 57600 / 60000, loss: 0.23286055028438568, acc: 92.92105993340732


 20%|██        | 2/10 [00:43<02:49, 21.20s/it]

[Epoch: 2] 0 / 60000, loss: 0.22826898097991943, acc: 95.3125
[Epoch: 2] 6400 / 60000, loss: 0.16089078783988953, acc: 94.69368811881188
[Epoch: 2] 12800 / 60000, loss: 0.2512710988521576, acc: 94.59732587064677
[Epoch: 2] 19200 / 60000, loss: 0.20249806344509125, acc: 94.45078903654485
[Epoch: 2] 25600 / 60000, loss: 0.16817258298397064, acc: 94.42019950124688
[Epoch: 2] 32000 / 60000, loss: 0.09645486623048782, acc: 94.48290918163673
[Epoch: 2] 38400 / 60000, loss: 0.2993870973587036, acc: 94.52995008319468
[Epoch: 2] 44800 / 60000, loss: 0.3620445728302002, acc: 94.57248573466477
[Epoch: 2] 51200 / 60000, loss: 0.2296186089515686, acc: 94.582943196005
[Epoch: 2] 57600 / 60000, loss: 0.0871594101190567, acc: 94.64657325194229


 30%|███       | 3/10 [01:03<02:25, 20.85s/it]

[Epoch: 3] 0 / 60000, loss: 0.1048174723982811, acc: 96.875
[Epoch: 3] 6400 / 60000, loss: 0.09532459080219269, acc: 95.60643564356435
[Epoch: 3] 12800 / 60000, loss: 0.16980619728565216, acc: 95.56125621890547
[Epoch: 3] 19200 / 60000, loss: 0.21912211179733276, acc: 95.62915282392026
[Epoch: 3] 25600 / 60000, loss: 0.15985719859600067, acc: 95.68266832917706
[Epoch: 3] 32000 / 60000, loss: 0.020421503111720085, acc: 95.73353293413174
[Epoch: 3] 38400 / 60000, loss: 0.1240759864449501, acc: 95.76227121464225
[Epoch: 3] 44800 / 60000, loss: 0.13596375286579132, acc: 95.77835235378032
[Epoch: 3] 51200 / 60000, loss: 0.11695452779531479, acc: 95.770911360799
[Epoch: 3] 57600 / 60000, loss: 0.2657391130924225, acc: 95.81194506104328


 40%|████      | 4/10 [01:23<02:03, 20.55s/it]

[Epoch: 4] 0 / 60000, loss: 0.10610972344875336, acc: 96.875
[Epoch: 4] 6400 / 60000, loss: 0.05447559803724289, acc: 96.03960396039604
[Epoch: 4] 12800 / 60000, loss: 0.20701764523983002, acc: 96.23756218905473
[Epoch: 4] 19200 / 60000, loss: 0.09626400470733643, acc: 96.31955980066445
[Epoch: 4] 25600 / 60000, loss: 0.1520192176103592, acc: 96.36455735660849
[Epoch: 4] 32000 / 60000, loss: 0.03420490026473999, acc: 96.42901696606786
[Epoch: 4] 38400 / 60000, loss: 0.1601506620645523, acc: 96.42782861896838
[Epoch: 4] 44800 / 60000, loss: 0.05367780476808548, acc: 96.47824536376605
[Epoch: 4] 51200 / 60000, loss: 0.0848296582698822, acc: 96.48096129837704
[Epoch: 4] 57600 / 60000, loss: 0.0718037486076355, acc: 96.58712541620422


 50%|█████     | 5/10 [01:43<01:41, 20.21s/it]

[Epoch: 5] 0 / 60000, loss: 0.05479362607002258, acc: 98.4375
[Epoch: 5] 6400 / 60000, loss: 0.20415152609348297, acc: 96.61200495049505
[Epoch: 5] 12800 / 60000, loss: 0.06474929302930832, acc: 96.88277363184079
[Epoch: 5] 19200 / 60000, loss: 0.05876540392637253, acc: 96.86461794019934
[Epoch: 5] 25600 / 60000, loss: 0.32995903491973877, acc: 96.8555174563591
[Epoch: 5] 32000 / 60000, loss: 0.08184937387704849, acc: 96.91554391217565
[Epoch: 5] 38400 / 60000, loss: 0.09227047860622406, acc: 96.97119384359401
[Epoch: 5] 44800 / 60000, loss: 0.07938012480735779, acc: 97.00205064194009
[Epoch: 5] 51200 / 60000, loss: 0.10478157550096512, acc: 97.02910424469412
[Epoch: 5] 57600 / 60000, loss: 0.06229580193758011, acc: 97.02587402885683


 60%|██████    | 6/10 [02:03<01:20, 20.17s/it]

[Epoch: 6] 0 / 60000, loss: 0.1470976322889328, acc: 93.75
[Epoch: 6] 6400 / 60000, loss: 0.02467627450823784, acc: 97.61757425742574
[Epoch: 6] 12800 / 60000, loss: 0.17705249786376953, acc: 97.55130597014924
[Epoch: 6] 19200 / 60000, loss: 0.06755498051643372, acc: 97.51349667774086
[Epoch: 6] 25600 / 60000, loss: 0.22666653990745544, acc: 97.41271820448878
[Epoch: 6] 32000 / 60000, loss: 0.041992172598838806, acc: 97.48627744510978
[Epoch: 6] 38400 / 60000, loss: 0.05864432826638222, acc: 97.47816139767055
[Epoch: 6] 44800 / 60000, loss: 0.1162153109908104, acc: 97.49019258202568
[Epoch: 6] 51200 / 60000, loss: 0.030410271137952805, acc: 97.48361423220973
[Epoch: 6] 57600 / 60000, loss: 0.02901810221374035, acc: 97.49236958934517


 70%|███████   | 7/10 [02:23<01:00, 20.15s/it]

[Epoch: 7] 0 / 60000, loss: 0.10343067348003387, acc: 96.875
[Epoch: 7] 6400 / 60000, loss: 0.059326473623514175, acc: 97.57116336633663
[Epoch: 7] 12800 / 60000, loss: 0.08785336464643478, acc: 97.43470149253731
[Epoch: 7] 19200 / 60000, loss: 0.053504884243011475, acc: 97.51349667774086
[Epoch: 7] 25600 / 60000, loss: 0.023455405607819557, acc: 97.5607855361596
[Epoch: 7] 32000 / 60000, loss: 0.10852235555648804, acc: 97.58607784431138
[Epoch: 7] 38400 / 60000, loss: 0.04188697412610054, acc: 97.61855241264558
[Epoch: 7] 44800 / 60000, loss: 0.05883811414241791, acc: 97.60164051355207
[Epoch: 7] 51200 / 60000, loss: 0.09639674425125122, acc: 97.64357053682896
[Epoch: 7] 57600 / 60000, loss: 0.033360984176397324, acc: 97.66752219755827


 80%|████████  | 8/10 [02:43<00:39, 19.97s/it]

[Epoch: 8] 0 / 60000, loss: 0.04080407693982124, acc: 98.4375
[Epoch: 8] 6400 / 60000, loss: 0.07970280945301056, acc: 97.49381188118812
[Epoch: 8] 12800 / 60000, loss: 0.03631274774670601, acc: 97.55907960199005
[Epoch: 8] 19200 / 60000, loss: 0.03714713454246521, acc: 97.74709302325581
[Epoch: 8] 25600 / 60000, loss: 0.13037285208702087, acc: 97.74781795511221
[Epoch: 8] 32000 / 60000, loss: 0.12248200178146362, acc: 97.83557884231537
[Epoch: 8] 38400 / 60000, loss: 0.028013115748763084, acc: 97.84993760399334
[Epoch: 8] 44800 / 60000, loss: 0.10951587557792664, acc: 97.93152639087019
[Epoch: 8] 51200 / 60000, loss: 0.04888442903757095, acc: 97.94787765293384
[Epoch: 8] 57600 / 60000, loss: 0.036802105605602264, acc: 97.92938401775805


 90%|█████████ | 9/10 [03:03<00:19, 19.98s/it]

[Epoch: 9] 0 / 60000, loss: 0.04221121966838837, acc: 98.4375
[Epoch: 9] 6400 / 60000, loss: 0.03043697401881218, acc: 98.26732673267327
[Epoch: 9] 12800 / 60000, loss: 0.09831144660711288, acc: 98.03327114427861
[Epoch: 9] 19200 / 60000, loss: 0.046963170170784, acc: 98.05855481727575
[Epoch: 9] 25600 / 60000, loss: 0.06391425430774689, acc: 97.99719451371571
[Epoch: 9] 32000 / 60000, loss: 0.04543492570519447, acc: 98.07572355289422
[Epoch: 9] 38400 / 60000, loss: 0.058270104229450226, acc: 98.0605241264559
[Epoch: 9] 44800 / 60000, loss: 0.11047022044658661, acc: 98.10538516405136
[Epoch: 9] 51200 / 60000, loss: 0.04282263666391373, acc: 98.08052434456928
[Epoch: 9] 57600 / 60000, loss: 0.1277025192975998, acc: 98.08199223085461


100%|██████████| 10/10 [03:23<00:00, 20.30s/it]


In [6]:
# Testing
model.eval()

print("Testing...")

total = 0
correct = 0
total_loss = 0
with torch.no_grad():
  for i, (data, target) in enumerate(test_data):
      data, target = data.to(device), target.to(device)
      optimizer.zero_grad() # Reset gradient
      out = model(data) # Make prediction on data
      loss = crit(out, target) # Calculate loss
      total_loss += loss

      _, pred = torch.max(out.data, 1)
      total += target.size(0)
      correct += (pred == target).sum().item()

  accuracy = (correct / total) * 100
  avg_loss = total_loss / len(test_data.dataset)
  print(f"Test set:\n\tAccuracy: {accuracy}%\n\tAvg. loss: {avg_loss}")


Testing...
Test set:
	Accuracy: 98.67%
	Avg. loss: 4.1530463931849226e-05


In [7]:
torch.save(model.state_dict(), "model_c.pth")