In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import time

BATCH_SIZE = 128
NUM_EPOCHS = 10

In [2]:
# preprocessing
normalize = transforms.Normalize(mean=[.5], std=[.5])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# download and load the data
train_dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./mnist/', train=False, transform=transform, download=False)

# encapsulate them into dataloader form
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

In [10]:
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__() 
        self.linear = nn.Linear(784,10)  
 
    def forward(self, x):
        outputs = self.linear(x)
        return outputs
 
 
model = SimpleNet()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

In [24]:
for epoch in range(NUM_EPOCHS):
    for images, labels in tqdm(train_loader):
        
        images = images.reshape(-1, 28 * 28)

        
        outputs = model(images)
        loss = criterion(outputs, labels)

        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    





correct = 0
total = 0
for images, labels in train_loader:
    images = images.reshape(-1, 28 * 28)
    outputs = model(images)
    _,predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
    
print('Training accuracy: %0.2f%%' % (100*correct/total))

correct = 0
total = 0
for images, labels in test_loader:
    images = images.reshape(-1, 28 * 28)
    outputs = model(images)
    _,predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
    
print('Testing accuracy: %0.2f%%' % (100*correct/total))




    


  0%|          | 0/468 [00:00<?, ?it/s]
  1%|          | 3/468 [00:00<00:20, 22.28it/s]
  1%|          | 5/468 [00:00<00:21, 21.49it/s]
  2%|▏         | 8/468 [00:00<00:20, 22.01it/s]
  2%|▏         | 11/468 [00:00<00:20, 22.39it/s]
  3%|▎         | 14/468 [00:00<00:20, 22.01it/s]
  3%|▎         | 16/468 [00:00<00:21, 20.92it/s]
  4%|▍         | 18/468 [00:00<00:21, 20.52it/s]
  4%|▍         | 21/468 [00:00<00:21, 20.85it/s]
  5%|▍         | 23/468 [00:01<00:22, 20.17it/s]
  5%|▌         | 25/468 [00:01<00:22, 19.43it/s]
  6%|▌         | 27/468 [00:01<00:23, 19.00it/s]
  6%|▋         | 30/468 [00:01<00:22, 19.88it/s]
  7%|▋         | 33/468 [00:01<00:21, 20.54it/s]
  8%|▊         | 36/468 [00:01<00:20, 20.99it/s]
  8%|▊         | 39/468 [00:01<00:19, 21.64it/s]
  9%|▉         | 42/468 [00:01<00:19, 21.50it/s]
 10%|▉         | 45/468 [00:02<00:20, 20.87it/s]
 10%|█         | 48/468 [00:02<00:20, 20.04it/s]
 11%|█         | 51/468 [00:02<00:20, 20.58it/s]
 12%|█▏        | 54/468 [00:02<

  1%|▏         | 6/468 [00:00<00:24, 18.68it/s]
  2%|▏         | 9/468 [00:00<00:23, 19.18it/s]
  3%|▎         | 12/468 [00:00<00:22, 19.86it/s]
  3%|▎         | 15/468 [00:00<00:22, 20.53it/s]
  4%|▍         | 18/468 [00:00<00:21, 20.85it/s]
  4%|▍         | 21/468 [00:01<00:21, 21.12it/s]
  5%|▌         | 24/468 [00:01<00:20, 21.55it/s]
  6%|▌         | 27/468 [00:01<00:20, 21.67it/s]
  6%|▋         | 30/468 [00:01<00:19, 22.09it/s]
  7%|▋         | 33/468 [00:01<00:19, 22.70it/s]
  8%|▊         | 36/468 [00:01<00:19, 22.57it/s]
  8%|▊         | 39/468 [00:01<00:19, 22.48it/s]
  9%|▉         | 42/468 [00:01<00:19, 22.13it/s]
 10%|▉         | 45/468 [00:02<00:18, 22.52it/s]
 10%|█         | 48/468 [00:02<00:18, 22.81it/s]
 11%|█         | 51/468 [00:02<00:18, 22.65it/s]
 12%|█▏        | 54/468 [00:02<00:18, 22.74it/s]
 12%|█▏        | 57/468 [00:02<00:18, 22.45it/s]
 13%|█▎        | 60/468 [00:02<00:17, 22.75it/s]
 13%|█▎        | 63/468 [00:02<00:18, 22.46it/s]
 14%|█▍        | 66/46

  4%|▍         | 20/468 [00:00<00:21, 21.23it/s]
  5%|▍         | 23/468 [00:01<00:20, 21.63it/s]
  6%|▌         | 26/468 [00:01<00:20, 21.73it/s]
  6%|▌         | 29/468 [00:01<00:20, 21.65it/s]
  7%|▋         | 32/468 [00:01<00:19, 22.58it/s]
  7%|▋         | 35/468 [00:01<00:19, 22.34it/s]
  8%|▊         | 38/468 [00:01<00:19, 22.42it/s]
  9%|▉         | 41/468 [00:01<00:19, 22.13it/s]
  9%|▉         | 44/468 [00:01<00:18, 22.58it/s]
 10%|█         | 47/468 [00:02<00:18, 22.95it/s]
 11%|█         | 50/468 [00:02<00:18, 22.90it/s]
 11%|█▏        | 53/468 [00:02<00:18, 22.71it/s]
 12%|█▏        | 56/468 [00:02<00:18, 22.84it/s]
 13%|█▎        | 59/468 [00:02<00:17, 22.98it/s]
 13%|█▎        | 62/468 [00:02<00:17, 23.24it/s]
 14%|█▍        | 65/468 [00:02<00:17, 22.79it/s]
 15%|█▍        | 68/468 [00:03<00:17, 22.38it/s]
 15%|█▌        | 71/468 [00:03<00:17, 22.71it/s]
 16%|█▌        | 74/468 [00:03<00:17, 22.28it/s]
 16%|█▋        | 77/468 [00:03<00:17, 22.08it/s]
 17%|█▋        | 80/

  5%|▍         | 23/468 [00:01<00:20, 21.81it/s]
  6%|▌         | 26/468 [00:01<00:20, 21.81it/s]
  6%|▌         | 29/468 [00:01<00:20, 21.90it/s]
  7%|▋         | 32/468 [00:01<00:19, 22.31it/s]
  7%|▋         | 35/468 [00:01<00:19, 22.45it/s]
  8%|▊         | 38/468 [00:01<00:18, 22.75it/s]
  9%|▉         | 41/468 [00:01<00:18, 22.71it/s]
  9%|▉         | 44/468 [00:01<00:18, 23.00it/s]
 10%|█         | 47/468 [00:02<00:18, 22.32it/s]
 11%|█         | 50/468 [00:02<00:18, 22.66it/s]
 11%|█▏        | 53/468 [00:02<00:18, 22.45it/s]
 12%|█▏        | 56/468 [00:02<00:18, 22.65it/s]
 13%|█▎        | 59/468 [00:02<00:18, 22.29it/s]
 13%|█▎        | 62/468 [00:02<00:18, 21.99it/s]
 14%|█▍        | 65/468 [00:02<00:18, 21.84it/s]
 15%|█▍        | 68/468 [00:03<00:18, 22.21it/s]
 15%|█▌        | 71/468 [00:03<00:17, 22.53it/s]
 16%|█▌        | 74/468 [00:03<00:16, 23.40it/s]
 16%|█▋        | 77/468 [00:03<00:17, 22.64it/s]
 17%|█▋        | 80/468 [00:03<00:17, 22.68it/s]
 18%|█▊        | 83/

  7%|▋         | 35/468 [00:01<00:20, 21.23it/s]
  8%|▊         | 38/468 [00:01<00:20, 21.49it/s]
  9%|▉         | 41/468 [00:01<00:19, 21.68it/s]
  9%|▉         | 44/468 [00:02<00:19, 21.71it/s]
 10%|█         | 47/468 [00:02<00:19, 21.79it/s]
 11%|█         | 50/468 [00:02<00:18, 22.08it/s]
 11%|█▏        | 53/468 [00:02<00:18, 23.00it/s]
 12%|█▏        | 56/468 [00:02<00:17, 23.26it/s]
 13%|█▎        | 59/468 [00:02<00:17, 22.90it/s]
 13%|█▎        | 62/468 [00:02<00:17, 22.76it/s]
 14%|█▍        | 65/468 [00:03<00:18, 21.50it/s]
 15%|█▍        | 68/468 [00:03<00:18, 21.50it/s]
 15%|█▌        | 71/468 [00:03<00:18, 21.68it/s]
 16%|█▌        | 74/468 [00:03<00:18, 21.76it/s]
 16%|█▋        | 77/468 [00:03<00:17, 21.87it/s]
 17%|█▋        | 80/468 [00:03<00:17, 21.75it/s]
 18%|█▊        | 83/468 [00:03<00:16, 22.70it/s]
 18%|█▊        | 86/468 [00:04<00:16, 22.68it/s]
 19%|█▉        | 89/468 [00:04<00:16, 22.71it/s]
 20%|█▉        | 92/468 [00:04<00:16, 23.10it/s]
 20%|██        | 95/

 12%|█▏        | 54/468 [00:02<00:18, 22.20it/s]
 12%|█▏        | 57/468 [00:02<00:18, 22.73it/s]
 13%|█▎        | 60/468 [00:02<00:18, 22.00it/s]
 13%|█▎        | 63/468 [00:02<00:18, 22.08it/s]
 14%|█▍        | 66/468 [00:02<00:18, 21.95it/s]
 15%|█▍        | 69/468 [00:03<00:18, 21.62it/s]
 15%|█▌        | 72/468 [00:03<00:18, 21.67it/s]
 16%|█▌        | 75/468 [00:03<00:17, 22.49it/s]
 17%|█▋        | 78/468 [00:03<00:17, 22.78it/s]
 17%|█▋        | 81/468 [00:03<00:16, 22.78it/s]
 18%|█▊        | 84/468 [00:03<00:16, 23.15it/s]
 19%|█▊        | 87/468 [00:03<00:16, 23.58it/s]
 19%|█▉        | 90/468 [00:03<00:16, 23.61it/s]
 20%|█▉        | 93/468 [00:04<00:16, 22.99it/s]
 21%|██        | 96/468 [00:04<00:16, 22.72it/s]
 21%|██        | 99/468 [00:04<00:16, 22.59it/s]
 22%|██▏       | 102/468 [00:04<00:16, 22.70it/s]
 22%|██▏       | 105/468 [00:04<00:15, 22.78it/s]
 23%|██▎       | 108/468 [00:04<00:15, 23.09it/s]
 24%|██▎       | 111/468 [00:04<00:15, 22.95it/s]
 24%|██▍       |

 15%|█▍        | 69/468 [00:03<00:17, 22.51it/s]
 15%|█▌        | 72/468 [00:03<00:17, 22.49it/s]
 16%|█▌        | 75/468 [00:03<00:17, 22.08it/s]
 17%|█▋        | 78/468 [00:03<00:17, 22.69it/s]
 17%|█▋        | 81/468 [00:03<00:17, 22.47it/s]
 18%|█▊        | 84/468 [00:03<00:17, 21.97it/s]
 19%|█▊        | 87/468 [00:03<00:17, 21.77it/s]
 19%|█▉        | 90/468 [00:04<00:17, 21.69it/s]
 20%|█▉        | 93/468 [00:04<00:17, 21.86it/s]
 21%|██        | 96/468 [00:04<00:16, 22.73it/s]
 21%|██        | 99/468 [00:04<00:16, 22.65it/s]
 22%|██▏       | 102/468 [00:04<00:16, 22.84it/s]
 22%|██▏       | 105/468 [00:04<00:16, 22.62it/s]
 23%|██▎       | 108/468 [00:04<00:15, 22.88it/s]
 24%|██▎       | 111/468 [00:04<00:15, 22.75it/s]
 24%|██▍       | 114/468 [00:05<00:15, 23.07it/s]
 25%|██▌       | 117/468 [00:05<00:15, 22.52it/s]
 26%|██▌       | 120/468 [00:05<00:16, 21.21it/s]
 26%|██▋       | 123/468 [00:05<00:16, 20.72it/s]
 27%|██▋       | 126/468 [00:05<00:16, 20.48it/s]
 28%|██▊   

 16%|█▌        | 76/468 [00:03<00:17, 22.68it/s]
 17%|█▋        | 79/468 [00:03<00:17, 22.56it/s]
 18%|█▊        | 82/468 [00:03<00:17, 22.38it/s]
 18%|█▊        | 85/468 [00:03<00:16, 22.55it/s]
 19%|█▉        | 88/468 [00:04<00:17, 22.12it/s]
 19%|█▉        | 91/468 [00:04<00:17, 22.02it/s]
 20%|██        | 94/468 [00:04<00:16, 22.05it/s]
 21%|██        | 97/468 [00:04<00:16, 22.62it/s]
 21%|██▏       | 100/468 [00:04<00:16, 22.22it/s]
 22%|██▏       | 103/468 [00:04<00:16, 22.24it/s]
 23%|██▎       | 106/468 [00:04<00:16, 22.45it/s]
 23%|██▎       | 109/468 [00:05<00:16, 22.25it/s]
 24%|██▍       | 112/468 [00:05<00:16, 21.63it/s]
 25%|██▍       | 115/468 [00:05<00:16, 21.68it/s]
 25%|██▌       | 118/468 [00:05<00:15, 22.20it/s]
 26%|██▌       | 121/468 [00:05<00:15, 22.03it/s]
 26%|██▋       | 124/468 [00:05<00:15, 22.50it/s]
 27%|██▋       | 127/468 [00:05<00:15, 21.99it/s]
 28%|██▊       | 130/468 [00:06<00:15, 22.18it/s]
 28%|██▊       | 133/468 [00:06<00:15, 22.26it/s]
 29%|██▉

 20%|█▉        | 93/468 [00:04<00:17, 21.95it/s]
 21%|██        | 96/468 [00:04<00:16, 22.09it/s]
 21%|██        | 99/468 [00:04<00:16, 21.96it/s]
 22%|██▏       | 102/468 [00:04<00:16, 22.55it/s]
 22%|██▏       | 105/468 [00:04<00:15, 22.72it/s]
 23%|██▎       | 108/468 [00:04<00:16, 22.34it/s]
 24%|██▎       | 111/468 [00:05<00:15, 22.52it/s]
 24%|██▍       | 114/468 [00:05<00:15, 22.81it/s]
 25%|██▌       | 117/468 [00:05<00:15, 22.49it/s]
 26%|██▌       | 120/468 [00:05<00:15, 22.08it/s]
 26%|██▋       | 123/468 [00:05<00:15, 21.71it/s]
 27%|██▋       | 126/468 [00:05<00:15, 21.83it/s]
 28%|██▊       | 129/468 [00:05<00:15, 21.87it/s]
 28%|██▊       | 132/468 [00:05<00:15, 21.75it/s]
 29%|██▉       | 135/468 [00:06<00:15, 22.10it/s]
 29%|██▉       | 138/468 [00:06<00:14, 22.30it/s]
 30%|███       | 141/468 [00:06<00:14, 22.30it/s]
 31%|███       | 144/468 [00:06<00:14, 22.49it/s]
 31%|███▏      | 147/468 [00:06<00:14, 22.13it/s]
 32%|███▏      | 150/468 [00:06<00:14, 21.98it/s]
 33

 23%|██▎       | 108/468 [00:04<00:15, 22.63it/s]
 24%|██▎       | 111/468 [00:04<00:16, 22.18it/s]
 24%|██▍       | 114/468 [00:05<00:16, 22.06it/s]
 25%|██▌       | 117/468 [00:05<00:15, 22.42it/s]
 26%|██▌       | 120/468 [00:05<00:15, 22.58it/s]
 26%|██▋       | 123/468 [00:05<00:14, 23.06it/s]
 27%|██▋       | 126/468 [00:05<00:14, 23.35it/s]
 28%|██▊       | 129/468 [00:05<00:14, 23.90it/s]
 28%|██▊       | 132/468 [00:05<00:15, 22.00it/s]
 29%|██▉       | 135/468 [00:06<00:15, 21.24it/s]
 29%|██▉       | 138/468 [00:06<00:15, 20.87it/s]
 30%|███       | 141/468 [00:06<00:16, 20.41it/s]
 31%|███       | 144/468 [00:06<00:16, 20.14it/s]
 31%|███▏      | 147/468 [00:06<00:16, 19.91it/s]
 32%|███▏      | 150/468 [00:06<00:15, 19.88it/s]
 32%|███▏      | 152/468 [00:06<00:16, 19.64it/s]
 33%|███▎      | 155/468 [00:07<00:15, 20.20it/s]
 34%|███▍      | 158/468 [00:07<00:15, 20.32it/s]
 34%|███▍      | 161/468 [00:07<00:14, 20.83it/s]
 35%|███▌      | 164/468 [00:07<00:14, 21.47it/s]


Training accuracy: 92.00%
Testing accuracy: 92.00%
