<a href="https://colab.research.google.com/github/dileeppj/TSAI_EVA/blob/master/Phase_2/Session5/EVAP2S5_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import Libraries

In [0]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

## Data Transformations

We start by defining our data transformations. We need to think about what our data looks like and how we can augment it to represent images the model might not otherwise see (for example, slightly rotated, brighter or darker digits).


In [0]:
# Train Phase transformations
train_transforms = transforms.Compose([
                                      #  transforms.Resize((28, 28)),
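                                       # saturation and hue have no effect on single-channel (grayscale) MNIST images
                                       transforms.ColorJitter(brightness=0.10, contrast=0.1, saturation=0.10, hue=0.1),
                                       # degrees is (min, max): rotate by a random angle in [-20, 20] degrees
                                       transforms.RandomRotation(degrees=(-20, 20), fill=(0,)),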
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,)) # The mean and std have to be sequences (e.g., tuples), therefore you should add a comma after the values. 
                                       # Note the difference between (0.1307) and (0.1307,)
                                       ])

# Test Phase transformations
test_transforms = transforms.Compose([
                                      #  transforms.Resize((28, 28)),
                                      #  transforms.ColorJitter(brightness=0.10, contrast=0.1, saturation=0.10, hue=0.1),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       ])
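`transforms.Normalize` computes `(x - mean) / std` per channel, after `ToTensor` has scaled the uint8 pixels to [0, 1]. The sketch below reproduces that arithmetic for single pixel values in plain Python (no torchvision; the helper names are hypothetical), using the MNIST constants from the cell above, and also shows why the trailing comma in `(0.1307,)` matters.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def to_tensor_scale(pixel: int) -> float:
    """Mimic ToTensor for one uint8 pixel: map 0..255 to 0.0..1.0."""
    return pixel / 255.0

def normalize(value: float, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Mimic Normalize for a single channel: (value - mean) / std."""
    return (value - mean) / std

black = normalize(to_tensor_scale(0))    # background pixel
white = normalize(to_tensor_scale(255))  # full-intensity stroke
print(f"black -> {black:.4f}, white -> {white:.4f}")

# (0.1307) is just a parenthesised float; (0.1307,) is a one-element tuple,
# which is the sequence type Normalize expects for mean and std.
print(type((0.1307)).__name__, type((0.1307,)).__name__)
```

So after normalization, MNIST pixel values lie roughly in [-0.42, 2.82] rather than [0, 1].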


# Dataset and Creating Train/Test Split

In [3]:
train = datasets.MNIST('./data', train=True, download=True, transform=train_transforms)
test = datasets.MNIST('./data', train=False, download=True, transform=test_transforms)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


# Dataloader Arguments & Test/Train Dataloaders


In [4]:
SEED = 1

# CUDA?
cuda = torch.cuda.is_available()
print("CUDA Available?", cuda)

# For reproducibility
torch.manual_seed(SEED)

if cuda:
    torch.cuda.manual_seed(SEED)

# dataloader arguments - in a script you would typically read these from the command line
dataloader_args = dict(shuffle=True, batch_size=128, num_workers=4, pin_memory=True) if cuda else dict(shuffle=True, batch_size=64)

# train dataloader
train_loader = torch.utils.data.DataLoader(train, **dataloader_args)

# test dataloader
test_loader = torch.utils.data.DataLoader(test, **dataloader_args)

CUDA Available? True
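A `DataLoader` with the default `drop_last=False` yields `ceil(len(dataset) / batch_size)` batches per epoch, with a smaller final batch when the division is not exact. The stdlib sketch below reproduces that arithmetic for the batch sizes chosen above (the 60,000/10,000 split sizes are the standard MNIST ones; the helper names are hypothetical). For batch size 128 it gives 469 training iterations per epoch, which is what the tqdm progress bar counts up to.

```python
import math

TRAIN_SIZE = 60_000  # MNIST training images
TEST_SIZE = 10_000   # MNIST test images

def batches_per_epoch(n_samples: int, batch_size: int) -> int:
    """Number of batches a DataLoader yields when drop_last=False."""
    return math.ceil(n_samples / batch_size)

def last_batch_size(n_samples: int, batch_size: int) -> int:
    """Size of the final (possibly partial) batch."""
    remainder = n_samples % batch_size
    return remainder if remainder else batch_size

for name, bs in [("GPU", 128), ("CPU", 64)]:
    print(name,
          batches_per_epoch(TRAIN_SIZE, bs), "train batches,",
          batches_per_epoch(TEST_SIZE, bs), "test batches,",
          "last train batch =", last_batch_size(TRAIN_SIZE, bs))
```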


# Data Statistics

It is important to know your data very well. Let's check some statistics of our data and what it actually looks like.

In [5]:
# train.data holds the raw uint8 images (60000 x 28 x 28) before any transform is applied.
# Scale them to [0, 1] (as ToTensor would) to compute statistics. Passing the whole
# NumPy array through train.transform fails: ColorJitter and RandomRotation expect
# a PIL image or a tensor, which is what raised the TypeError here before.
train_data = train.data.float() / 255

print('[Train]')
print(' - Numpy Shape:', train.data.numpy().shape)
print(' - Tensor Shape:', train.data.size())
print(' - min:', torch.min(train_data))
print(' - max:', torch.max(train_data))
print(' - mean:', torch.mean(train_data))
print(' - std:', torch.std(train_data))
print(' - var:', torch.var(train_data))

dataiter = iter(train_loader)
images, labels = next(dataiter)  # use built-in next(); newer PyTorch removed the .next() method

print(images.shape)
print(labels.shape)

# Let's visualize some of the images
%matplotlib inline
import matplotlib.pyplot as plt

plt.imshow(images[0].numpy().squeeze(), cmap='gray_r')
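The `Normalize((0.1307,), (0.3081,))` constants are the mean and standard deviation of all MNIST training pixels after scaling to [0, 1]. A minimal stdlib sketch of that computation, run on two made-up 2x2 "images" purely for illustration (real MNIST images are 28x28, and on 47M pixels you would use tensor operations rather than Python lists):

```python
import statistics

def pixel_stats(images):
    """Mean and population std of all pixels, after scaling uint8 values to [0, 1]."""
    scaled = [p / 255.0 for img in images for row in img for p in row]
    return statistics.fmean(scaled), statistics.pstdev(scaled)

# Two hypothetical 2x2 "images" (0 = background, 255 = stroke).
toy_images = [
    [[0, 255], [0, 0]],
    [[255, 255], [0, 0]],
]
mean, std = pixel_stats(toy_images)
print(f"mean={mean:.4f}, std={std:.4f}")
```

Because MNIST is mostly black background, its real mean is small (about 0.13) while the std is comparatively large.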

## More images

It is important to view as many images as possible: this gives us an idea of which image augmentations to try later on.

In [6]:
figure = plt.figure()
num_of_images = 60
for index in range(1, num_of_images + 1):
    plt.subplot(6, 10, index)
    plt.axis('off')
    plt.imshow(images[index].numpy().squeeze(), cmap='gray_r')


# The model
Let's start with the model we first saw

In [0]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.convblock1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(3, 3), padding=0),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True)
        )

        self.convblock2 = nn.Sequential(
            nn.BatchNorm2d(8),
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), padding=0),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1)
        )

        self.convblock3 = nn.Sequential(
            nn.BatchNorm2d(16),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=0),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1)
        )

        self.convblock4 = nn.Sequential(
            nn.BatchNorm2d(32),
            nn.Conv2d(in_channels=32, out_channels=8, kernel_size=(1, 1), padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(stride=2,kernel_size=2)
        )

        self.convblock5 = nn.Sequential(
            nn.BatchNorm2d(8),
            nn.Conv2d(in_channels=8, out_channels=8, kernel_size=(3, 3), padding=0),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1)
        )
        
        self.convblock6 = nn.Sequential(
            nn.BatchNorm2d(8),
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), padding=0),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1)
        )

        self.convblock7 = nn.Sequential(
            nn.BatchNorm2d(16),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=0),
            nn.ReLU(inplace=True)
        )

        self.convblock8 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=10, kernel_size=(1, 1), padding=0),
            # nn.ReLU() NEVER! These 10 channels are the class scores fed to log_softmax
        ) # output_size = 5 (the GAP layer below reduces 5x5 to 1x1)

        self.gap = nn.AdaptiveAvgPool2d((1,1))


    def forward(self, x):
        x = self.convblock1(x)
        x = self.convblock2(x)
        x = self.convblock3(x)
        x = self.convblock4(x)
        x = self.convblock5(x)
        x = self.convblock6(x)
        x = self.convblock7(x)
        x = self.convblock8(x)
        x = self.gap(x)
        x = x.view(-1, 10)
        return F.log_softmax(x, dim=-1)
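Because the network uses only unpadded convolutions and one max-pool, the spatial size after each layer follows `out = floor((in + 2*padding - kernel) / stride) + 1` (BatchNorm, ReLU and Dropout leave it unchanged). The sketch below walks the layer list transcribed by hand from `Net` and also tracks the receptive field with the standard recurrence `r_out = r_in + (k - 1) * j_in`, `j_out = j_in * s`, where `j` is the cumulative stride. The `LAYERS` table and `trace` helper are hypothetical, not part of PyTorch.

```python
# (name, kernel, stride, padding) for each spatial layer in Net, in forward order.
LAYERS = [
    ("convblock1", 3, 1, 0),
    ("convblock2", 3, 1, 0),
    ("convblock3", 3, 1, 0),
    ("convblock4 1x1", 1, 1, 0),
    ("convblock4 maxpool", 2, 2, 0),
    ("convblock5", 3, 1, 0),
    ("convblock6", 3, 1, 0),
    ("convblock7", 3, 1, 0),
    ("convblock8 1x1", 1, 1, 0),
]

def trace(size=28, layers=LAYERS):
    """Return (name, output size, receptive field) after each layer."""
    rf, jump = 1, 1  # receptive field and cumulative stride at the input
    rows = []
    for name, k, s, p in layers:
        size = (size + 2 * p - k) // s + 1
        rf = rf + (k - 1) * jump
        jump = jump * s
        rows.append((name, size, rf))
    return rows

for name, size, rf in trace():
    print(f"{name:<20} out={size:>2}x{size:<2} rf={rf}")
```

The last convolution therefore produces a 5x5 map per class with a 20x20 receptive field, and `AdaptiveAvgPool2d((1, 1))` averages that map down to one score per class.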

# Model Params
We can't emphasize enough how important it is to view the model summary.
Unfortunately, PyTorch has no built-in Keras-style model summary, so we use an external package (`torchsummary`).

In [12]:
!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)
model = Net().to(device)
summary(model, input_size=(1, 28, 28))

cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 26, 26]              80
       BatchNorm2d-2            [-1, 8, 26, 26]              16
              ReLU-3            [-1, 8, 26, 26]               0
       BatchNorm2d-4            [-1, 8, 26, 26]              16
            Conv2d-5           [-1, 16, 24, 24]           1,168
              ReLU-6           [-1, 16, 24, 24]               0
           Dropout-7           [-1, 16, 24, 24]               0
       BatchNorm2d-8           [-1, 16, 24, 24]              32
            Conv2d-9           [-1, 32, 22, 22]           4,640
             ReLU-10           [-1, 32, 22, 22]               0
          Dropout-11           [-1, 32, 22, 22]               0
      BatchNorm2d-12           [-1, 32, 22, 22]              64
           Conv2d-13            [-1, 8, 22, 22]             264
             ReLU-14            [-
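The `Param #` column can be reproduced by hand: a `Conv2d` has `in * out * k * k` weights plus `out` biases (bias is on by default), and a `BatchNorm2d` has two learnable vectors (weight and bias) of length `C`; its running mean and variance are buffers, not parameters. The sketch below applies that arithmetic to a layer list transcribed from `Net` (the list and helper names are hypothetical). Its first eight values agree with the rows printed by `summary` above, and the total is computed rather than read from the output, since the printed summary is cut off.

```python
def conv_params(c_in: int, c_out: int, k: int) -> int:
    """Weights plus one bias per output channel."""
    return c_in * c_out * k * k + c_out

def bn_params(channels: int) -> int:
    """Learnable affine weight and bias; running stats are not parameters."""
    return 2 * channels

# Parameterised layers of Net in forward order: ("conv", c_in, c_out, k) or ("bn", C).
NET_LAYERS = [
    ("conv", 1, 8, 3), ("bn", 8),     # convblock1
    ("bn", 8), ("conv", 8, 16, 3),    # convblock2
    ("bn", 16), ("conv", 16, 32, 3),  # convblock3
    ("bn", 32), ("conv", 32, 8, 1),   # convblock4
    ("bn", 8), ("conv", 8, 8, 3),     # convblock5
    ("bn", 8), ("conv", 8, 16, 3),    # convblock6
    ("bn", 16), ("conv", 16, 32, 3),  # convblock7
    ("conv", 32, 10, 1),              # convblock8
]

def count(layers=NET_LAYERS):
    per_layer = []
    for layer in layers:
        if layer[0] == "conv":
            per_layer.append(conv_params(*layer[1:]))
        else:
            per_layer.append(bn_params(layer[1]))
    return per_layer, sum(per_layer)

per_layer, total = count()
print(per_layer)
print("total trainable params:", total)
```

By this count the model has 13,066 trainable parameters.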

# Training and Testing

Looking at logs can be boring, so we'll introduce a **tqdm** progress bar to get nicer logs.

Let's write the train and test functions.

In [0]:
from tqdm import tqdm

train_losses = []
test_losses = []
train_acc = []
test_acc = []

def train(model, device, train_loader, optimizer, epoch):
  model.train()
  pbar = tqdm(train_loader)
  correct = 0
  processed = 0
  for batch_idx, (data, target) in enumerate(pbar):
    # get samples
    data, target = data.to(device), target.to(device)

    # Init
    optimizer.zero_grad()
    # In PyTorch, we need to set the gradients to zero before starting backpropagation because PyTorch accumulates gradients on subsequent backward passes.
    # Because of this, at the start of each training iteration you should zero out the gradients so that the parameter update is correct.

    # Predict
    y_pred = model(data)

    # Calculate loss
    loss = F.nll_loss(y_pred, target)
    train_losses.append(loss.item())  # store a Python float; keeping the tensor holds on to its autograd graph and breaks plotting later

    # Backpropagation
    loss.backward()
    optimizer.step()

    # Update pbar-tqdm
    
    pred = y_pred.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
    correct += pred.eq(target.view_as(pred)).sum().item()
    processed += len(data)

    pbar.set_description(desc= f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100*correct/processed:0.2f}')
    train_acc.append(100*correct/processed)

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    
    test_acc.append(100. * correct / len(test_loader.dataset))
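Because `Net.forward` ends in `F.log_softmax`, `F.nll_loss` simply picks out the negative log-probability of the true class; together the two amount to cross-entropy. `test()` also sums per-sample losses (`reduction='sum'`) and divides by the dataset size, so the reported average is exact even though the last test batch is smaller than the others. The pure-Python sketch below illustrates both ideas on made-up 3-class logits (no torch involved; all helper names are hypothetical):

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

def nll(log_probs, target):
    """Negative log-likelihood of the target class (per-sample nll_loss)."""
    return -log_probs[target]

def evaluate(batches):
    """Mirror test(): sum per-sample losses, count argmax hits, divide by sample count."""
    total_loss, correct, n = 0.0, 0, 0
    for logits_batch, targets in batches:
        for logits, t in zip(logits_batch, targets):
            lp = log_softmax(logits)
            total_loss += nll(lp, t)
            pred = max(range(len(lp)), key=lp.__getitem__)  # argmax
            correct += int(pred == t)
            n += 1
    return total_loss / n, 100.0 * correct / n

# Hypothetical logits split into uneven batches (sizes 2 and 1).
batches = [
    ([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]], [0, 1]),
    ([[0.0, 4.0, 0.0]], [1]),
]
avg_loss, acc = evaluate(batches)
print(f"avg loss={avg_loss:.4f}, accuracy={acc:.2f}%")

# For contrast: averaging per-batch means weights the 1-sample batch as
# heavily as the 2-sample batch, so it differs from the true per-sample mean.
batch_means = []
for logits_batch, targets in batches:
    losses = [nll(log_softmax(x), t) for x, t in zip(logits_batch, targets)]
    batch_means.append(sum(losses) / len(losses))
naive = sum(batch_means) / len(batch_means)
print(f"mean of batch means={naive:.4f}")
```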

# Let's Train and test our model

In [14]:
model =  Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
EPOCHS = 20
for epoch in range(EPOCHS):
  print("EPOCH:", epoch)
  train(model, device, train_loader, optimizer, epoch)
  test(model, device, test_loader)


EPOCH: 0

Test set: Average loss: 0.1394, Accuracy: 9597/10000 (95.97%)

EPOCH: 1

Test set: Average loss: 0.0674, Accuracy: 9801/10000 (98.01%)

EPOCH: 2

Test set: Average loss: 0.0636, Accuracy: 9797/10000 (97.97%)

EPOCH: 3

Test set: Average loss: 0.0506, Accuracy: 9841/10000 (98.41%)

EPOCH: 4

Test set: Average loss: 0.0398, Accuracy: 9867/10000 (98.67%)

EPOCH: 5

Test set: Average loss: 0.0404, Accuracy: 9878/10000 (98.78%)

EPOCH: 6

Test set: Average loss: 0.0373, Accuracy: 9880/10000 (98.80%)

EPOCH: 7

Test set: Average loss: 0.0339, Accuracy: 9892/10000 (98.92%)

EPOCH: 8

Test set: Average loss: 0.0367, Accuracy: 9887/10000 (98.87%)

EPOCH: 9

Test set: Average loss: 0.0357, Accuracy: 9873/10000 (98.73%)

EPOCH: 10

Test set: Average loss: 0.0317, Accuracy: 9903/10000 (99.03%)

EPOCH: 11

Test set: Average loss: 0.0315, Accuracy: 9899/10000 (98.99%)

EPOCH: 12

Test set: Average loss: 0.0292, Accuracy: 9909/10000 (99.09%)

EPOCH: 13

Test set: Average loss: 0.0308, Accuracy: 9909/10000 (99.09%)

EPOCH: 14

Test set: Average loss: 0.0288, Accuracy: 9906/10000 (99.06%)

EPOCH: 15

Test set: Average loss: 0.0325, Accuracy: 9905/10000 (99.05%)

EPOCH: 16

Test set: Average loss: 0.0292, Accuracy: 9902/10000 (99.02%)

EPOCH: 17

Test set: Average loss: 0.0300, Accuracy: 9899/10000 (98.99%)

EPOCH: 18

Test set: Average loss: 0.0283, Accuracy: 9909/10000 (99.09%)

EPOCH: 19

Test set: Average loss: 0.0267, Accuracy: 9913/10000 (99.13%)
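The training loop uses `optim.SGD(model.parameters(), lr=0.01, momentum=0.9)`. With PyTorch's defaults (`dampening=0`, `nesterov=False`, no weight decay), the documented update is `v = momentum * v + grad` followed by `p = p - lr * v`. The scalar sketch below runs that rule on the toy loss `f(p) = p**2`, an assumption chosen purely for illustration, to show the velocity building up compared with plain SGD:

```python
def sgd_momentum(p0: float, lr: float = 0.01, momentum: float = 0.9, steps: int = 5):
    """Run PyTorch-style SGD with momentum on f(p) = p**2 (gradient 2p)."""
    p, v = p0, 0.0
    history = [p]
    for _ in range(steps):
        grad = 2.0 * p           # derivative of p**2
        v = momentum * v + grad  # velocity buffer (dampening=0)
        p = p - lr * v           # parameter update
        history.append(p)
    return history

plain = sgd_momentum(1.0, momentum=0.0)
with_momentum = sgd_momentum(1.0, momentum=0.9)
for step, (a, b) in enumerate(zip(plain, with_momentum)):
    print(f"step {step}: plain={a:.5f} momentum={b:.5f}")
```

The first step is identical (the velocity starts at zero), after which the accumulated velocity makes the momentum run move faster toward the minimum.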



In [15]:
fig, axs = plt.subplots(2,2,figsize=(15,10))
axs[0, 0].plot([float(loss) for loss in train_losses])  # float() also handles losses stored as (GPU) tensors
axs[0, 0].set_title("Training Loss")
axs[1, 0].plot(train_acc)
axs[1, 0].set_title("Training Accuracy")
axs[0, 1].plot(test_losses)
axs[0, 1].set_title("Test Loss")
axs[1, 1].plot(test_acc)
axs[1, 1].set_title("Test Accuracy")
