In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [None]:
def get_data_loader(training=True, batch_size=50):
  """Build a DataLoader over the MNIST train or test split.

  The dataset is stored under ``./data`` and downloaded on first use. Every
  image is converted to a tensor and normalized with mean 0.1307 and
  std 0.3081 (presumably the MNIST pixel statistics -- confirm if changed).

  Args:
    training: truthy selects the training split and enables shuffling;
      falsy selects the test split, served in dataset order.
    batch_size: number of samples per batch (defaults to 50).

  Returns:
    A ``torch.utils.data.DataLoader`` over the requested split.
  """
  custom_transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.1307,), (0.3081,)),
  ])
  dataset = datasets.MNIST('./data', train=training, transform=custom_transform, download=True)
  # Only the training split is shuffled; evaluation keeps a stable order so
  # sample indices are reproducible between runs.
  return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=bool(training))
  

# Training split: report the loader type and a summary of its dataset.
train_loader = get_data_loader(training=True)
print(type(train_loader))
print(train_loader.dataset)

# Test split: same summary for the held-out data.
test_loader = get_data_loader(training=False)
print(test_loader.dataset)

In [None]:
def build_model():
  """Create an untrained fully connected classifier for flattened 784-value inputs.

  The network flattens each sample and passes it through Linear layers of
  widths 784 -> 128 -> 64 -> 10, with a ReLU after every Linear layer except
  the last. The second argument of each ``nn.Linear`` is the number of units
  (outputs) in that layer, which is why it matches the first argument of the
  next one. The final layer returns 10 raw scores (no softmax applied).

  Returns:
    An ``nn.Sequential`` holding the six modules in order.
  """
  widths = (784, 128, 64, 10)

  stack = [nn.Flatten()]
  for fan_in, fan_out in zip(widths[:-1], widths[1:]):
    stack.append(nn.Linear(fan_in, fan_out))
    stack.append(nn.ReLU())
  # The output layer stays linear, so drop the activation added after it.
  stack.pop()

  return nn.Sequential(*stack)

# print(model)

In [None]:
# Loss used for both training and evaluation; it is applied to the raw
# scores returned by the model together with the integer class labels.
criterion = nn.CrossEntropyLoss()

# Fresh, untrained network.
model = build_model()

def train_model(model, train_loader, criterion, T, lr=0.001, momentum=0.9):
  """Train ``model`` with SGD for ``T`` epochs and print per-epoch metrics.

  After every epoch one line is printed with the epoch number, the number of
  correctly classified training samples, the accuracy in percent and the
  average per-sample loss.

  Args:
    model: module mapping a batch of inputs to per-class scores.
    train_loader: iterable yielding ``(inputs, labels)`` batches.
    criterion: loss called as ``criterion(outputs, labels)``. The average
      loss assumes it returns the mean over the batch (the default reduction
      of ``nn.CrossEntropyLoss``).
    T: number of epochs to run.
    lr: SGD learning rate.
    momentum: SGD momentum factor.

  Returns:
    None. The model is updated in place and left in evaluation mode.
  """
  model.train()
  opt = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for epoch in range(T):
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in train_loader:
      batch_len = labels.size(0)

      opt.zero_grad()
      outputs = model(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      opt.step()

      # Metrics only: detach so the bookkeeping never touches the graph.
      predicted = outputs.detach().argmax(dim=1)
      total += batch_len
      correct += (predicted == labels).sum().item()
      # Weight the batch mean by the real batch length; the final batch may
      # be shorter than the loader's nominal batch size.
      running_loss += loss.item() * batch_len

    print(f"Train Epoch: {epoch}   Accuracy: {correct}/{total}({100 * correct / total:.2f}%) Loss: {running_loss / total:.3f}")

  model.eval()

# Run five training epochs; one summary line is printed per epoch.
train_model(model, train_loader, criterion, T=5)

​Train Epoch: 0   Accuracy: 48244/60000(80.41%) Loss: 0.759
​Train Epoch: 1   Accuracy: 54843/60000(91.41%) Loss: 0.295
​Train Epoch: 2   Accuracy: 55887/60000(93.14%) Loss: 0.236
​Train Epoch: 3   Accuracy: 56587/60000(94.31%) Loss: 0.197
​Train Epoch: 4   Accuracy: 57081/60000(95.14%) Loss: 0.168


In [None]:
def evaluate_model(model, test_loader, criterion, show_loss = True):
  """Evaluate ``model`` on ``test_loader`` and print the results.

  Prints ``Average loss: ...`` (only when ``show_loss`` is true) followed by
  ``Accuracy: ...%``. Gradients are disabled during evaluation and the model
  is switched to evaluation mode.

  Args:
    model: module mapping a batch of inputs to per-class scores.
    test_loader: iterable yielding ``(inputs, labels)`` batches.
    criterion: loss called as ``criterion(outputs, labels)``. The average
      loss assumes it returns the mean over the batch (the default reduction
      of ``nn.CrossEntropyLoss``).
    show_loss: when false, the loss is neither computed nor printed.

  Returns:
    None.
  """
  model.eval()

  correct = 0
  total = 0
  running_loss = 0.0

  with torch.no_grad():
    for inputs, labels in test_loader:
      batch_len = labels.size(0)
      outputs = model(inputs)

      if show_loss:
        # Weight the batch mean by the real batch length; the final batch
        # may be shorter than the loader's nominal batch size.
        running_loss += criterion(outputs, labels).item() * batch_len

      predicted = outputs.argmax(dim=1)
      total += batch_len
      correct += (predicted == labels).sum().item()

  if show_loss:
    print(f"Average loss: {running_loss / total:.4f}")

  print(f"Accuracy: {100 * correct / total:.2f}%")

# Report average loss and accuracy on the held-out test split.
evaluate_model(model, test_loader, criterion, show_loss=True)


Average loss: 0.1550
Accuracy: 95.19%


In [None]:
def predict_label(model, test_images, index):
  """Print the three most probable digit classes for one image.

  ``test_images[index]`` is passed to the model as-is, so its leading
  dimension is treated as the batch dimension (assumes each image carries a
  leading size-1 channel dimension, e.g. ``(1, 28, 28)`` -- confirm against
  the caller). One line per class is printed as ``<name>: <probability>``,
  most probable first.

  Args:
    model: module returning per-class scores of shape ``(batch, 10)``.
    test_images: indexable collection of image tensors.
    index: position of the image to classify.

  Returns:
    None.
  """
  class_names = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']

  # Inference only: no autograd graph is needed for a prediction.
  with torch.no_grad():
    logits = model(test_images[index])
    probs = F.softmax(logits, dim=1)

  top_probs, top_classes = probs.topk(3)

  # Convert to plain Python numbers before indexing and formatting.
  for class_idx, prob in zip(top_classes[0].tolist(), top_probs[0].tolist()):
    print(f"{class_names[class_idx]}: {prob:.2%}")

# Take only the first test batch; its labels are printed so the prediction
# below can be checked by eye.
data = next(iter(test_loader))
inputs, labels = data
pred_set = inputs
print(labels)

# Show the top-3 predicted classes for the image at position 3 of the batch.
predict_label(model, pred_set, 3)

tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 2, 3, 5, 1, 2,
        4, 4])
zero: 99.97%
two: 0.01%
nine: 0.00%
