<a href="https://colab.research.google.com/github/gkdivya/EVA/blob/main/6_BatchNormalization_Regularization/Experiments/GroupNorm/GroupNormalization_Experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from collections import OrderedDict
from itertools import product
from torch.optim.lr_scheduler import StepLR,OneCycleLR

In [2]:
class Net(nn.Module):
    def __init__(self, dropout_value = 0.03, groups = 2):
        super(Net, self).__init__()
        self.groups = groups

        # Input Block
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(3, 3), padding=0, bias=False),  # Input 28x28 output 26x26 RF : 3x3
            nn.ReLU(),
            nn.GroupNorm(num_groups=self.groups, num_channels=8),
            nn.Dropout(dropout_value),

            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), padding=0, bias=False), # Input 26x26 output 24x24 RF : 5x5
            nn.ReLU(),
            nn.GroupNorm(num_groups=self.groups, num_channels=16),
            nn.Dropout(dropout_value)
        ) 

        #Transition Block
        self.trans1 = nn.Sequential(
            
            nn.MaxPool2d(2, 2), #  Input 24x24 output 12x12 RF : 6x6
            nn.Conv2d(in_channels=16, out_channels=8, kernel_size=(1, 1), padding=0, bias=False)  # Input 12x12 output 12x12 RF : 6x6
        )
        

        # CONVOLUTION BLOCK 2
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), padding=0, bias=False),  # Input 12x12 output 10x10 RF : 6x6
            nn.ReLU(),            
            nn.GroupNorm(num_groups=self.groups, num_channels=16),
            nn.Dropout(dropout_value),

            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), padding=0, bias=False),  # Input 10x10 output 8x8 RF : 10x10
            nn.ReLU(),            
            nn.GroupNorm(num_groups=self.groups, num_channels=16),
            nn.Dropout(dropout_value),

            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), padding=0, bias=False), # Input 8x8 output 6x6 RF : 14x14
            nn.ReLU(),            
            nn.GroupNorm(num_groups=self.groups, num_channels=16),
            nn.Dropout(dropout_value)

        ) 
        
        # OUTPUT BLOCK
        self.avgpool2d = nn.AvgPool2d(kernel_size=6)

        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(1, 1), padding=0, bias=False), # Input 6x6 output 6x6 RF : 18x18
            nn.ReLU(),            
            nn.GroupNorm(num_groups=self.groups, num_channels=16),
            nn.Dropout(dropout_value))

        self.conv4 = nn.Conv2d(in_channels=16, out_channels=10, kernel_size=(1, 1), padding=0, bias=False)  # Input 6x6 output 6x6 RF : 18x18


    def forward(self, x):
        x = self.conv1(x)
        x = self.trans1(x)
        x = self.conv2(x)
        x = self.avgpool2d(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.view(-1, 10)
        return F.log_softmax(x, dim=-1)

In [3]:
!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)
model = Net(0.01,2).to(device)
summary(model, input_size=(1, 28, 28))

cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 26, 26]              72
              ReLU-2            [-1, 8, 26, 26]               0
         GroupNorm-3            [-1, 8, 26, 26]              16
           Dropout-4            [-1, 8, 26, 26]               0
            Conv2d-5           [-1, 16, 24, 24]           1,152
              ReLU-6           [-1, 16, 24, 24]               0
         GroupNorm-7           [-1, 16, 24, 24]              32
           Dropout-8           [-1, 16, 24, 24]               0
         MaxPool2d-9           [-1, 16, 12, 12]               0
           Conv2d-10            [-1, 8, 12, 12]             128
           Conv2d-11           [-1, 16, 10, 10]           1,152
             ReLU-12           [-1, 16, 10, 10]               0
        GroupNorm-13           [-1, 16, 10, 10]              32
          Dropout-14           [-1

In [4]:
# Train Phase transformations
train_transforms = transforms.Compose([
                                       transforms.RandomRotation((-6.0, 6.0), fill=(1,)),                
                                       transforms.RandomAffine(degrees=7, shear=10, translate=(0.1, 0.1), scale=(0.8, 1.2)),
                                       transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.40, hue=0.1),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,)) # The mean and std have to be sequences (e.g., tuples), therefore you should add a comma after the values. 
                                       # Note the difference between (0.1307) and (0.1307,)
                                       ])

# Test Phase transformations
test_transforms = transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       ])

In [5]:
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=train_transforms)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=test_transforms)

In [6]:
SEED = 1

# CUDA?
cuda = torch.cuda.is_available()
print("CUDA Available?", cuda)

# For reproducibility
torch.manual_seed(SEED)

if cuda:
    torch.cuda.manual_seed(SEED)

CUDA Available? True


In [7]:
from tqdm import tqdm

def train(model, device, train_loader, optimizer, epoch,train_acc,train_loss,lambda_l1,scheduler):

  model.train()
  pbar = tqdm(train_loader)
  
  correct = 0
  processed = 0
  
  for batch_idx, (data, target) in enumerate(pbar):
    # get samples
    data, target = data.to(device), target.to(device)

    # Init
    optimizer.zero_grad()
    # In PyTorch, we need to set the gradients to zero before starting to do backpropragation because PyTorch accumulates the gradients on subsequent backward passes. 
    # Because of this, when you start your training loop, ideally you should zero out the gradients so that you do the parameter update correctly.

    # Predict
    y_pred = model(data)

    # Calculate loss
    loss = F.nll_loss(y_pred, target)
    
    #L1 Regularization
    if lambda_l1 > 0:
      l1 = 0
      for p in model.parameters():
        l1 = l1 + p.abs().sum()
      loss = loss + lambda_l1*l1

    train_loss.append(loss.data.cpu().numpy().item())

    # Backpropagation
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Update pbar-tqdm
    
    pred = y_pred.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
    correct += pred.eq(target.view_as(pred)).sum().item()
    processed += len(data)

    pbar.set_description(desc= f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100*correct/processed:0.2f}')
    train_acc.append(100*correct/processed)

    


In [8]:

def test(model, device, test_loader,test_acc,test_losses):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    
    test_acc.append(100. * correct / len(test_loader.dataset))




In [9]:
def experiments(train_loader, test_loader, l1_factor, dropout, epochs, batchSize, groups):
    
    model = Net(dropout, groups).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.015, momentum=0.7)
    scheduler = OneCycleLR(optimizer, max_lr=0.015,epochs=epochs,steps_per_epoch=len(train_loader))
    epochs = epochs

    for epoch in range(1, epochs + 1):
      train_losses = []
      test_losses = []
      train_accuracy = []
      test_accuracy = []
      print(f'Epoch {epoch}:')
      train(model, device, train_loader, optimizer, epoch, train_accuracy, train_losses, l1_factor,scheduler)
      test(model, device, test_loader,test_accuracy,test_losses)

    return (train_accuracy[-1],test_accuracy[-1])

In [10]:
parameters = dict(
    batch_size = [32, 64],
    groups = [2, 4],
    l1 = [0, 0.01, 0.001]
)

param_values = [v for v in parameters.values()]

for batch_size, groups, l1 in product(*param_values): 
  exp_metrics = []
  dataloader_args = dict(shuffle=True, batch_size=batch_size, num_workers=4, pin_memory=True) if cuda else dict(shuffle=True, batch_size=64)
  train_loader = torch.utils.data.DataLoader(train_dataset, **dataloader_args)
  test_loader = torch.utils.data.DataLoader(test_dataset, **dataloader_args)
  print('\n' +'\033[1m' + '=====================================Model Training for batch size:  ' + str(batch_size) + ',l1:  '+ str(l1) + str(groups) + '======================================================'+'\033[0m\n')
  exp_metrics.append([str(batch_size), str(l1), str(groups), experiments(train_loader, test_loader,l1,0.03,15,batch_size, groups)])



  0%|          | 0/1875 [00:00<?, ?it/s]



Epoch 1:


Loss=0.03604322299361229 Batch_id=1874 Accuracy=82.83: 100%|██████████| 1875/1875 [00:32<00:00, 58.16it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0838, Accuracy: 9768/10000 (97.68%)

Epoch 2:


Loss=0.22630541026592255 Batch_id=1874 Accuracy=95.13: 100%|██████████| 1875/1875 [00:32<00:00, 58.33it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0778, Accuracy: 9785/10000 (97.85%)

Epoch 3:


Loss=0.14602985978126526 Batch_id=1874 Accuracy=96.34: 100%|██████████| 1875/1875 [00:32<00:00, 57.93it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0530, Accuracy: 9844/10000 (98.44%)

Epoch 4:


Loss=0.10697684437036514 Batch_id=1874 Accuracy=97.06: 100%|██████████| 1875/1875 [00:32<00:00, 57.66it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0335, Accuracy: 9895/10000 (98.95%)

Epoch 5:


Loss=0.014094357378780842 Batch_id=1874 Accuracy=97.36: 100%|██████████| 1875/1875 [00:32<00:00, 57.99it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0301, Accuracy: 9909/10000 (99.09%)

Epoch 6:


Loss=0.010314219631254673 Batch_id=1874 Accuracy=97.68: 100%|██████████| 1875/1875 [00:32<00:00, 58.24it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0427, Accuracy: 9866/10000 (98.66%)

Epoch 7:


Loss=0.154333233833313 Batch_id=1874 Accuracy=97.84: 100%|██████████| 1875/1875 [00:32<00:00, 58.02it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0253, Accuracy: 9925/10000 (99.25%)

Epoch 8:


Loss=0.04010225832462311 Batch_id=1874 Accuracy=97.98: 100%|██████████| 1875/1875 [00:32<00:00, 57.78it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0230, Accuracy: 9924/10000 (99.24%)

Epoch 9:


Loss=0.007430655416101217 Batch_id=1874 Accuracy=98.18: 100%|██████████| 1875/1875 [00:32<00:00, 57.71it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0328, Accuracy: 9894/10000 (98.94%)

Epoch 10:


Loss=0.09861667454242706 Batch_id=1874 Accuracy=98.19: 100%|██████████| 1875/1875 [00:32<00:00, 58.44it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0259, Accuracy: 9931/10000 (99.31%)

Epoch 11:


Loss=0.21602275967597961 Batch_id=1874 Accuracy=98.37: 100%|██████████| 1875/1875 [00:32<00:00, 57.65it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0245, Accuracy: 9925/10000 (99.25%)

Epoch 12:


Loss=0.01806841604411602 Batch_id=1874 Accuracy=98.48: 100%|██████████| 1875/1875 [00:32<00:00, 58.04it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0180, Accuracy: 9949/10000 (99.49%)

Epoch 13:


Loss=0.008281148970127106 Batch_id=1874 Accuracy=98.62: 100%|██████████| 1875/1875 [00:32<00:00, 57.91it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0189, Accuracy: 9945/10000 (99.45%)

Epoch 14:


Loss=0.0019438910530880094 Batch_id=1874 Accuracy=98.75: 100%|██████████| 1875/1875 [00:32<00:00, 58.01it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0180, Accuracy: 9947/10000 (99.47%)

Epoch 15:


Loss=0.007367601618170738 Batch_id=1874 Accuracy=98.75: 100%|██████████| 1875/1875 [00:32<00:00, 57.89it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0173, Accuracy: 9949/10000 (99.49%)



Epoch 1:


Loss=2.2966649532318115 Batch_id=1874 Accuracy=79.62: 100%|██████████| 1875/1875 [00:37<00:00, 50.08it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.4283, Accuracy: 8867/10000 (88.67%)

Epoch 2:


Loss=2.6748151779174805 Batch_id=1874 Accuracy=82.63: 100%|██████████| 1875/1875 [00:37<00:00, 49.59it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.5715, Accuracy: 8343/10000 (83.43%)

Epoch 3:


Loss=2.351832389831543 Batch_id=1874 Accuracy=79.16: 100%|██████████| 1875/1875 [00:37<00:00, 49.63it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.3974, Accuracy: 9010/10000 (90.10%)

Epoch 4:


Loss=2.0635266304016113 Batch_id=1874 Accuracy=72.92: 100%|██████████| 1875/1875 [00:38<00:00, 49.22it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.6216, Accuracy: 8165/10000 (81.65%)

Epoch 5:


Loss=1.6978237628936768 Batch_id=1874 Accuracy=75.21: 100%|██████████| 1875/1875 [00:37<00:00, 49.55it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.4550, Accuracy: 8787/10000 (87.87%)

Epoch 6:


Loss=1.9452731609344482 Batch_id=1874 Accuracy=71.36: 100%|██████████| 1875/1875 [00:38<00:00, 48.99it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.5442, Accuracy: 8451/10000 (84.51%)

Epoch 7:


Loss=2.2615952491760254 Batch_id=1874 Accuracy=74.67: 100%|██████████| 1875/1875 [00:37<00:00, 49.37it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.5428, Accuracy: 8440/10000 (84.40%)

Epoch 8:


Loss=1.8016935586929321 Batch_id=1874 Accuracy=74.40: 100%|██████████| 1875/1875 [00:38<00:00, 49.06it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.4818, Accuracy: 8779/10000 (87.79%)

Epoch 9:


Loss=1.912681221961975 Batch_id=1874 Accuracy=75.06: 100%|██████████| 1875/1875 [00:38<00:00, 49.14it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.7235, Accuracy: 7898/10000 (78.98%)

Epoch 10:


Loss=2.0685644149780273 Batch_id=1874 Accuracy=76.54: 100%|██████████| 1875/1875 [00:38<00:00, 49.24it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.5166, Accuracy: 8580/10000 (85.80%)

Epoch 11:


Loss=1.8049850463867188 Batch_id=1874 Accuracy=77.90: 100%|██████████| 1875/1875 [00:38<00:00, 49.01it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.3833, Accuracy: 9031/10000 (90.31%)

Epoch 12:


Loss=1.5088908672332764 Batch_id=1874 Accuracy=79.37: 100%|██████████| 1875/1875 [00:38<00:00, 49.12it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.3498, Accuracy: 9131/10000 (91.31%)

Epoch 13:


Loss=1.4131118059158325 Batch_id=1874 Accuracy=81.64: 100%|██████████| 1875/1875 [00:38<00:00, 48.59it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.3282, Accuracy: 9183/10000 (91.83%)

Epoch 14:


Loss=1.1609395742416382 Batch_id=1874 Accuracy=85.11: 100%|██████████| 1875/1875 [00:38<00:00, 48.99it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.2740, Accuracy: 9306/10000 (93.06%)

Epoch 15:


Loss=0.8967796564102173 Batch_id=1874 Accuracy=89.27: 100%|██████████| 1875/1875 [00:38<00:00, 48.69it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.1708, Accuracy: 9628/10000 (96.28%)



Epoch 1:


Loss=0.6627992987632751 Batch_id=1874 Accuracy=80.23: 100%|██████████| 1875/1875 [00:39<00:00, 47.38it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0874, Accuracy: 9785/10000 (97.85%)

Epoch 2:


Loss=0.7496763467788696 Batch_id=1874 Accuracy=93.55: 100%|██████████| 1875/1875 [00:39<00:00, 47.83it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0863, Accuracy: 9771/10000 (97.71%)

Epoch 3:


Loss=0.8076798319816589 Batch_id=1874 Accuracy=94.11: 100%|██████████| 1875/1875 [00:38<00:00, 48.48it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.1110, Accuracy: 9691/10000 (96.91%)

Epoch 4:


Loss=0.5337421298027039 Batch_id=1874 Accuracy=94.49: 100%|██████████| 1875/1875 [00:38<00:00, 48.86it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.1011, Accuracy: 9711/10000 (97.11%)

Epoch 5:


Loss=0.8064131140708923 Batch_id=1874 Accuracy=94.62: 100%|██████████| 1875/1875 [00:38<00:00, 49.24it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0703, Accuracy: 9795/10000 (97.95%)

Epoch 6:


Loss=0.4667077362537384 Batch_id=1874 Accuracy=94.88: 100%|██████████| 1875/1875 [00:39<00:00, 47.69it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0672, Accuracy: 9808/10000 (98.08%)

Epoch 7:


Loss=0.6251775026321411 Batch_id=1874 Accuracy=94.74: 100%|██████████| 1875/1875 [00:39<00:00, 46.99it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0821, Accuracy: 9788/10000 (97.88%)

Epoch 8:


Loss=0.49737757444381714 Batch_id=1874 Accuracy=95.05: 100%|██████████| 1875/1875 [00:38<00:00, 48.77it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0575, Accuracy: 9851/10000 (98.51%)

Epoch 9:


Loss=0.5090993046760559 Batch_id=1874 Accuracy=95.17: 100%|██████████| 1875/1875 [00:38<00:00, 49.09it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0607, Accuracy: 9835/10000 (98.35%)

Epoch 10:


Loss=0.4623987674713135 Batch_id=1874 Accuracy=95.23: 100%|██████████| 1875/1875 [00:38<00:00, 48.68it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0816, Accuracy: 9770/10000 (97.70%)

Epoch 11:


Loss=0.4072147607803345 Batch_id=1874 Accuracy=95.75: 100%|██████████| 1875/1875 [00:38<00:00, 49.14it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.1084, Accuracy: 9690/10000 (96.90%)

Epoch 12:


Loss=0.6140031218528748 Batch_id=1874 Accuracy=95.99: 100%|██████████| 1875/1875 [00:38<00:00, 48.50it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0518, Accuracy: 9863/10000 (98.63%)

Epoch 13:


Loss=0.33719542622566223 Batch_id=1874 Accuracy=96.69: 100%|██████████| 1875/1875 [00:38<00:00, 48.85it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0444, Accuracy: 9876/10000 (98.76%)

Epoch 14:


Loss=0.3333326280117035 Batch_id=1874 Accuracy=97.45: 100%|██████████| 1875/1875 [00:38<00:00, 48.94it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0314, Accuracy: 9920/10000 (99.20%)

Epoch 15:


Loss=0.2625085711479187 Batch_id=1874 Accuracy=98.05: 100%|██████████| 1875/1875 [00:37<00:00, 49.40it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0271, Accuracy: 9929/10000 (99.29%)



Epoch 1:


Loss=0.2131928950548172 Batch_id=1874 Accuracy=75.87: 100%|██████████| 1875/1875 [00:33<00:00, 56.16it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.1260, Accuracy: 9666/10000 (96.66%)

Epoch 2:


Loss=0.13478921353816986 Batch_id=1874 Accuracy=94.40: 100%|██████████| 1875/1875 [00:33<00:00, 55.58it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0674, Accuracy: 9814/10000 (98.14%)

Epoch 3:


Loss=0.10672961175441742 Batch_id=1874 Accuracy=96.12: 100%|██████████| 1875/1875 [00:33<00:00, 55.60it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0680, Accuracy: 9796/10000 (97.96%)

Epoch 4:


Loss=0.0719282478094101 Batch_id=1874 Accuracy=96.78: 100%|██████████| 1875/1875 [00:33<00:00, 56.21it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0402, Accuracy: 9868/10000 (98.68%)

Epoch 5:


Loss=0.030021164566278458 Batch_id=1874 Accuracy=97.19: 100%|██████████| 1875/1875 [00:33<00:00, 55.53it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0354, Accuracy: 9899/10000 (98.99%)

Epoch 6:


Loss=0.02671084925532341 Batch_id=1874 Accuracy=97.59: 100%|██████████| 1875/1875 [00:33<00:00, 56.29it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0318, Accuracy: 9898/10000 (98.98%)

Epoch 7:


Loss=0.00602272804826498 Batch_id=1874 Accuracy=97.83: 100%|██████████| 1875/1875 [00:33<00:00, 55.79it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0310, Accuracy: 9904/10000 (99.04%)

Epoch 8:


Loss=0.008957859128713608 Batch_id=1874 Accuracy=97.92: 100%|██████████| 1875/1875 [00:33<00:00, 56.10it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0275, Accuracy: 9923/10000 (99.23%)

Epoch 9:


Loss=0.014924346469342709 Batch_id=1874 Accuracy=98.12: 100%|██████████| 1875/1875 [00:33<00:00, 56.33it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0275, Accuracy: 9925/10000 (99.25%)

Epoch 10:


Loss=0.007846439257264137 Batch_id=1874 Accuracy=98.27: 100%|██████████| 1875/1875 [00:33<00:00, 55.81it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0245, Accuracy: 9922/10000 (99.22%)

Epoch 11:


Loss=0.0037553515285253525 Batch_id=1874 Accuracy=98.31: 100%|██████████| 1875/1875 [00:34<00:00, 54.26it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0249, Accuracy: 9937/10000 (99.37%)

Epoch 12:


Loss=0.03681839630007744 Batch_id=1874 Accuracy=98.43: 100%|██████████| 1875/1875 [00:35<00:00, 53.24it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0219, Accuracy: 9931/10000 (99.31%)

Epoch 13:


Loss=0.03740353882312775 Batch_id=1874 Accuracy=98.57: 100%|██████████| 1875/1875 [00:33<00:00, 56.20it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0174, Accuracy: 9945/10000 (99.45%)

Epoch 14:


Loss=0.11652367562055588 Batch_id=1874 Accuracy=98.69: 100%|██████████| 1875/1875 [00:33<00:00, 56.02it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0160, Accuracy: 9948/10000 (99.48%)

Epoch 15:


Loss=0.003795674769207835 Batch_id=1874 Accuracy=98.74: 100%|██████████| 1875/1875 [00:33<00:00, 55.77it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0160, Accuracy: 9950/10000 (99.50%)



Epoch 1:


Loss=2.5291457176208496 Batch_id=1874 Accuracy=75.54: 100%|██████████| 1875/1875 [00:38<00:00, 48.33it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.4692, Accuracy: 8915/10000 (89.15%)

Epoch 2:


Loss=2.2051033973693848 Batch_id=1874 Accuracy=80.53: 100%|██████████| 1875/1875 [00:38<00:00, 48.79it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.3710, Accuracy: 9171/10000 (91.71%)

Epoch 3:


Loss=2.7037205696105957 Batch_id=1874 Accuracy=75.39: 100%|██████████| 1875/1875 [00:38<00:00, 48.83it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.7662, Accuracy: 7591/10000 (75.91%)

Epoch 4:


Loss=2.5595598220825195 Batch_id=1874 Accuracy=73.72: 100%|██████████| 1875/1875 [00:38<00:00, 48.90it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.4465, Accuracy: 8820/10000 (88.20%)

Epoch 5:


Loss=2.561896324157715 Batch_id=1874 Accuracy=71.95: 100%|██████████| 1875/1875 [00:38<00:00, 48.37it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.5512, Accuracy: 8830/10000 (88.30%)

Epoch 6:


Loss=2.0279295444488525 Batch_id=1874 Accuracy=72.55: 100%|██████████| 1875/1875 [00:38<00:00, 49.12it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.4633, Accuracy: 8952/10000 (89.52%)

Epoch 7:


Loss=1.8859385251998901 Batch_id=1874 Accuracy=70.89: 100%|██████████| 1875/1875 [00:38<00:00, 48.79it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.7293, Accuracy: 7717/10000 (77.17%)

Epoch 8:


Loss=1.9932217597961426 Batch_id=1874 Accuracy=72.02: 100%|██████████| 1875/1875 [00:38<00:00, 48.78it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.6853, Accuracy: 7908/10000 (79.08%)

Epoch 9:


Loss=2.259164333343506 Batch_id=1874 Accuracy=73.04: 100%|██████████| 1875/1875 [00:38<00:00, 48.91it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.5755, Accuracy: 8346/10000 (83.46%)

Epoch 10:


Loss=2.1655631065368652 Batch_id=1874 Accuracy=73.09: 100%|██████████| 1875/1875 [00:38<00:00, 48.65it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.6105, Accuracy: 8486/10000 (84.86%)

Epoch 11:


Loss=2.053159713745117 Batch_id=1874 Accuracy=74.88: 100%|██████████| 1875/1875 [00:38<00:00, 48.76it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.5924, Accuracy: 8335/10000 (83.35%)

Epoch 12:


Loss=1.7517530918121338 Batch_id=1874 Accuracy=76.66: 100%|██████████| 1875/1875 [00:38<00:00, 48.45it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.4335, Accuracy: 8785/10000 (87.85%)

Epoch 13:


Loss=1.7795584201812744 Batch_id=1874 Accuracy=79.67: 100%|██████████| 1875/1875 [00:38<00:00, 48.77it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.3561, Accuracy: 9152/10000 (91.52%)

Epoch 14:


Loss=1.4714961051940918 Batch_id=1874 Accuracy=83.75: 100%|██████████| 1875/1875 [00:38<00:00, 48.70it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.2905, Accuracy: 9310/10000 (93.10%)

Epoch 15:


Loss=1.0158836841583252 Batch_id=1874 Accuracy=89.06: 100%|██████████| 1875/1875 [00:38<00:00, 48.73it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.2161, Accuracy: 9577/10000 (95.77%)



Epoch 1:


Loss=1.1037184000015259 Batch_id=1874 Accuracy=80.92: 100%|██████████| 1875/1875 [00:38<00:00, 48.94it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.2126, Accuracy: 9462/10000 (94.62%)

Epoch 2:


Loss=0.610024631023407 Batch_id=1874 Accuracy=93.56: 100%|██████████| 1875/1875 [00:38<00:00, 49.14it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0781, Accuracy: 9787/10000 (97.87%)

Epoch 3:


Loss=0.7950971126556396 Batch_id=1874 Accuracy=94.03: 100%|██████████| 1875/1875 [00:38<00:00, 48.75it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.1142, Accuracy: 9684/10000 (96.84%)

Epoch 4:


Loss=0.5383807420730591 Batch_id=1874 Accuracy=94.19: 100%|██████████| 1875/1875 [00:38<00:00, 48.38it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.1309, Accuracy: 9625/10000 (96.25%)

Epoch 5:


Loss=0.643061101436615 Batch_id=1874 Accuracy=94.21: 100%|██████████| 1875/1875 [00:38<00:00, 48.43it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0978, Accuracy: 9729/10000 (97.29%)

Epoch 6:


Loss=0.7170119285583496 Batch_id=1874 Accuracy=94.26: 100%|██████████| 1875/1875 [00:38<00:00, 48.85it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.1180, Accuracy: 9666/10000 (96.66%)

Epoch 7:


Loss=0.4727374315261841 Batch_id=1874 Accuracy=94.60: 100%|██████████| 1875/1875 [00:38<00:00, 48.46it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0679, Accuracy: 9832/10000 (98.32%)

Epoch 8:


Loss=0.5289927124977112 Batch_id=1874 Accuracy=94.17: 100%|██████████| 1875/1875 [00:39<00:00, 47.57it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.1077, Accuracy: 9680/10000 (96.80%)

Epoch 9:


Loss=0.5208366513252258 Batch_id=1874 Accuracy=95.01: 100%|██████████| 1875/1875 [00:38<00:00, 48.16it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0657, Accuracy: 9839/10000 (98.39%)

Epoch 10:


Loss=0.4022778868675232 Batch_id=1874 Accuracy=95.28: 100%|██████████| 1875/1875 [00:39<00:00, 47.72it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0729, Accuracy: 9802/10000 (98.02%)

Epoch 11:


Loss=0.6775970458984375 Batch_id=1874 Accuracy=95.52: 100%|██████████| 1875/1875 [00:38<00:00, 48.24it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0760, Accuracy: 9792/10000 (97.92%)

Epoch 12:


Loss=0.346333771944046 Batch_id=1874 Accuracy=95.87: 100%|██████████| 1875/1875 [00:39<00:00, 47.80it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0627, Accuracy: 9835/10000 (98.35%)

Epoch 13:


Loss=0.3335966467857361 Batch_id=1874 Accuracy=96.29: 100%|██████████| 1875/1875 [00:38<00:00, 48.11it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0461, Accuracy: 9871/10000 (98.71%)

Epoch 14:


Loss=0.3018738627433777 Batch_id=1874 Accuracy=97.28: 100%|██████████| 1875/1875 [00:39<00:00, 47.04it/s]
  0%|          | 0/1875 [00:00<?, ?it/s]


Test set: Average loss: 0.0353, Accuracy: 9909/10000 (99.09%)

Epoch 15:


Loss=0.2701476216316223 Batch_id=1874 Accuracy=97.76: 100%|██████████| 1875/1875 [00:39<00:00, 47.74it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0313, Accuracy: 9918/10000 (99.18%)



Epoch 1:


Loss=0.3761192262172699 Batch_id=937 Accuracy=78.75: 100%|██████████| 938/938 [00:25<00:00, 36.09it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.1043, Accuracy: 9765/10000 (97.65%)

Epoch 2:


Loss=0.026097215712070465 Batch_id=937 Accuracy=94.52: 100%|██████████| 938/938 [00:27<00:00, 34.47it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0652, Accuracy: 9807/10000 (98.07%)

Epoch 3:


Loss=0.02512061968445778 Batch_id=937 Accuracy=96.27: 100%|██████████| 938/938 [00:27<00:00, 34.15it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0405, Accuracy: 9885/10000 (98.85%)

Epoch 4:


Loss=0.0907648578286171 Batch_id=937 Accuracy=97.05: 100%|██████████| 938/938 [00:26<00:00, 34.76it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0334, Accuracy: 9905/10000 (99.05%)

Epoch 5:


Loss=0.14726179838180542 Batch_id=937 Accuracy=97.32: 100%|██████████| 938/938 [00:27<00:00, 34.63it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0319, Accuracy: 9893/10000 (98.93%)

Epoch 6:


Loss=0.012556998059153557 Batch_id=937 Accuracy=97.71: 100%|██████████| 938/938 [00:27<00:00, 33.82it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0363, Accuracy: 9893/10000 (98.93%)

Epoch 7:


Loss=0.02303885668516159 Batch_id=937 Accuracy=97.89: 100%|██████████| 938/938 [00:27<00:00, 34.12it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0319, Accuracy: 9909/10000 (99.09%)

Epoch 8:


Loss=0.01147466991096735 Batch_id=937 Accuracy=98.02: 100%|██████████| 938/938 [00:27<00:00, 33.97it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0279, Accuracy: 9919/10000 (99.19%)

Epoch 9:


Loss=0.036713868379592896 Batch_id=937 Accuracy=98.19: 100%|██████████| 938/938 [00:27<00:00, 34.50it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0216, Accuracy: 9935/10000 (99.35%)

Epoch 10:


Loss=0.024600045755505562 Batch_id=937 Accuracy=98.29: 100%|██████████| 938/938 [00:26<00:00, 35.04it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0221, Accuracy: 9939/10000 (99.39%)

Epoch 11:


Loss=0.01263779029250145 Batch_id=937 Accuracy=98.46: 100%|██████████| 938/938 [00:27<00:00, 34.39it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0179, Accuracy: 9940/10000 (99.40%)

Epoch 12:


Loss=0.0515856072306633 Batch_id=937 Accuracy=98.50: 100%|██████████| 938/938 [00:27<00:00, 34.27it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0173, Accuracy: 9945/10000 (99.45%)

Epoch 13:


Loss=0.04209083318710327 Batch_id=937 Accuracy=98.75: 100%|██████████| 938/938 [00:26<00:00, 34.96it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0160, Accuracy: 9951/10000 (99.51%)

Epoch 14:


Loss=0.05294476076960564 Batch_id=937 Accuracy=98.75: 100%|██████████| 938/938 [00:26<00:00, 35.05it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0166, Accuracy: 9949/10000 (99.49%)

Epoch 15:


Loss=0.03651567921042442 Batch_id=937 Accuracy=98.84: 100%|██████████| 938/938 [00:26<00:00, 34.82it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0161, Accuracy: 9951/10000 (99.51%)



Epoch 1:


Loss=2.3642971515655518 Batch_id=937 Accuracy=73.54: 100%|██████████| 938/938 [00:28<00:00, 33.03it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.2848, Accuracy: 9367/10000 (93.67%)

Epoch 2:


Loss=2.0284461975097656 Batch_id=937 Accuracy=85.48: 100%|██████████| 938/938 [00:29<00:00, 32.23it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.2449, Accuracy: 9435/10000 (94.35%)

Epoch 3:


Loss=2.3017866611480713 Batch_id=937 Accuracy=82.32: 100%|██████████| 938/938 [00:28<00:00, 32.83it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.5581, Accuracy: 8099/10000 (80.99%)

Epoch 4:


Loss=1.658518671989441 Batch_id=937 Accuracy=82.07: 100%|██████████| 938/938 [00:28<00:00, 32.90it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.4127, Accuracy: 9001/10000 (90.01%)

Epoch 5:


Loss=1.7342171669006348 Batch_id=937 Accuracy=81.88: 100%|██████████| 938/938 [00:28<00:00, 32.83it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.4296, Accuracy: 8882/10000 (88.82%)

Epoch 6:


Loss=1.9916768074035645 Batch_id=937 Accuracy=81.79: 100%|██████████| 938/938 [00:28<00:00, 33.35it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.3962, Accuracy: 9013/10000 (90.13%)

Epoch 7:


Loss=1.6988788843154907 Batch_id=937 Accuracy=81.88: 100%|██████████| 938/938 [00:28<00:00, 32.52it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.2687, Accuracy: 9420/10000 (94.20%)

Epoch 8:


Loss=1.4936442375183105 Batch_id=937 Accuracy=82.38: 100%|██████████| 938/938 [00:28<00:00, 32.61it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.3966, Accuracy: 8893/10000 (88.93%)

Epoch 9:


Loss=1.8405027389526367 Batch_id=937 Accuracy=82.16: 100%|██████████| 938/938 [00:28<00:00, 33.21it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.3032, Accuracy: 9295/10000 (92.95%)

Epoch 10:


Loss=1.4570266008377075 Batch_id=937 Accuracy=83.08: 100%|██████████| 938/938 [00:28<00:00, 32.98it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.3476, Accuracy: 9067/10000 (90.67%)

Epoch 11:


Loss=1.4214507341384888 Batch_id=937 Accuracy=83.85: 100%|██████████| 938/938 [00:28<00:00, 32.42it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.3478, Accuracy: 9208/10000 (92.08%)

Epoch 12:


Loss=1.2258358001708984 Batch_id=937 Accuracy=85.41: 100%|██████████| 938/938 [00:29<00:00, 31.76it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.2478, Accuracy: 9438/10000 (94.38%)

Epoch 13:


Loss=1.1838465929031372 Batch_id=937 Accuracy=87.11: 100%|██████████| 938/938 [00:29<00:00, 32.15it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.2060, Accuracy: 9484/10000 (94.84%)

Epoch 14:


Loss=0.901394248008728 Batch_id=937 Accuracy=89.61: 100%|██████████| 938/938 [00:29<00:00, 31.79it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.1477, Accuracy: 9680/10000 (96.80%)

Epoch 15:


Loss=0.7306032180786133 Batch_id=937 Accuracy=92.66: 100%|██████████| 938/938 [00:29<00:00, 31.94it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.1193, Accuracy: 9797/10000 (97.97%)



Epoch 1:


Loss=0.6067317724227905 Batch_id=937 Accuracy=79.00: 100%|██████████| 938/938 [00:28<00:00, 32.59it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.1285, Accuracy: 9691/10000 (96.91%)

Epoch 2:


Loss=0.46574005484580994 Batch_id=937 Accuracy=94.01: 100%|██████████| 938/938 [00:28<00:00, 32.36it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0896, Accuracy: 9732/10000 (97.32%)

Epoch 3:


Loss=0.44925013184547424 Batch_id=937 Accuracy=94.75: 100%|██████████| 938/938 [00:28<00:00, 32.75it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0937, Accuracy: 9719/10000 (97.19%)

Epoch 4:


Loss=0.4860391616821289 Batch_id=937 Accuracy=95.22: 100%|██████████| 938/938 [00:28<00:00, 32.97it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0810, Accuracy: 9772/10000 (97.72%)

Epoch 5:


Loss=0.5651965141296387 Batch_id=937 Accuracy=95.34: 100%|██████████| 938/938 [00:28<00:00, 32.42it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0557, Accuracy: 9850/10000 (98.50%)

Epoch 6:


Loss=0.6525204181671143 Batch_id=937 Accuracy=95.66: 100%|██████████| 938/938 [00:28<00:00, 32.82it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.1304, Accuracy: 9593/10000 (95.93%)

Epoch 7:


Loss=0.5783227682113647 Batch_id=937 Accuracy=95.54: 100%|██████████| 938/938 [00:28<00:00, 32.63it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0772, Accuracy: 9795/10000 (97.95%)

Epoch 8:


Loss=0.4419381618499756 Batch_id=937 Accuracy=95.85: 100%|██████████| 938/938 [00:28<00:00, 32.36it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0705, Accuracy: 9818/10000 (98.18%)

Epoch 9:


Loss=0.5743560194969177 Batch_id=937 Accuracy=95.84: 100%|██████████| 938/938 [00:28<00:00, 32.78it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0535, Accuracy: 9857/10000 (98.57%)

Epoch 10:


Loss=0.4540412425994873 Batch_id=937 Accuracy=96.06: 100%|██████████| 938/938 [00:28<00:00, 32.68it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0587, Accuracy: 9854/10000 (98.54%)

Epoch 11:


Loss=0.3299703299999237 Batch_id=937 Accuracy=96.39: 100%|██████████| 938/938 [00:28<00:00, 32.62it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0686, Accuracy: 9795/10000 (97.95%)

Epoch 12:


Loss=0.29861295223236084 Batch_id=937 Accuracy=96.59: 100%|██████████| 938/938 [00:28<00:00, 32.72it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0525, Accuracy: 9853/10000 (98.53%)

Epoch 13:


Loss=0.3926628828048706 Batch_id=937 Accuracy=97.11: 100%|██████████| 938/938 [00:28<00:00, 33.34it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0386, Accuracy: 9889/10000 (98.89%)

Epoch 14:


Loss=0.265083372592926 Batch_id=937 Accuracy=97.74: 100%|██████████| 938/938 [00:28<00:00, 33.15it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0282, Accuracy: 9928/10000 (99.28%)

Epoch 15:


Loss=0.2410869002342224 Batch_id=937 Accuracy=98.12: 100%|██████████| 938/938 [00:28<00:00, 32.92it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0261, Accuracy: 9932/10000 (99.32%)



Epoch 1:


Loss=0.1752077341079712 Batch_id=937 Accuracy=70.66: 100%|██████████| 938/938 [00:27<00:00, 34.50it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.1377, Accuracy: 9691/10000 (96.91%)

Epoch 2:


Loss=0.1410074234008789 Batch_id=937 Accuracy=93.62: 100%|██████████| 938/938 [00:27<00:00, 34.09it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0631, Accuracy: 9830/10000 (98.30%)

Epoch 3:


Loss=0.4736015498638153 Batch_id=937 Accuracy=95.94: 100%|██████████| 938/938 [00:27<00:00, 34.15it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0520, Accuracy: 9866/10000 (98.66%)

Epoch 4:


Loss=0.24038749933242798 Batch_id=937 Accuracy=96.74: 100%|██████████| 938/938 [00:27<00:00, 34.45it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0414, Accuracy: 9875/10000 (98.75%)

Epoch 5:


Loss=0.1506640762090683 Batch_id=937 Accuracy=97.00: 100%|██████████| 938/938 [00:26<00:00, 34.96it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0494, Accuracy: 9874/10000 (98.74%)

Epoch 6:


Loss=0.060906991362571716 Batch_id=937 Accuracy=97.35: 100%|██████████| 938/938 [00:26<00:00, 34.96it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0370, Accuracy: 9891/10000 (98.91%)

Epoch 7:


Loss=0.008969505317509174 Batch_id=937 Accuracy=97.56: 100%|██████████| 938/938 [00:27<00:00, 34.41it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0417, Accuracy: 9883/10000 (98.83%)

Epoch 8:


Loss=0.15343324840068817 Batch_id=937 Accuracy=97.78: 100%|██████████| 938/938 [00:27<00:00, 34.45it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0261, Accuracy: 9929/10000 (99.29%)

Epoch 9:


Loss=0.034854400902986526 Batch_id=937 Accuracy=97.94: 100%|██████████| 938/938 [00:27<00:00, 34.55it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0287, Accuracy: 9919/10000 (99.19%)

Epoch 10:


Loss=0.011699127033352852 Batch_id=937 Accuracy=97.96: 100%|██████████| 938/938 [00:26<00:00, 34.87it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0279, Accuracy: 9915/10000 (99.15%)

Epoch 11:


Loss=0.028722791001200676 Batch_id=937 Accuracy=98.16: 100%|██████████| 938/938 [00:27<00:00, 34.73it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0302, Accuracy: 9917/10000 (99.17%)

Epoch 12:


Loss=0.01984882354736328 Batch_id=937 Accuracy=98.43: 100%|██████████| 938/938 [00:27<00:00, 34.19it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0248, Accuracy: 9929/10000 (99.29%)

Epoch 13:


Loss=0.005749055650085211 Batch_id=937 Accuracy=98.49: 100%|██████████| 938/938 [00:26<00:00, 35.06it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0245, Accuracy: 9937/10000 (99.37%)

Epoch 14:


Loss=0.01666695810854435 Batch_id=937 Accuracy=98.64: 100%|██████████| 938/938 [00:27<00:00, 34.73it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0206, Accuracy: 9941/10000 (99.41%)

Epoch 15:


Loss=0.012343412265181541 Batch_id=937 Accuracy=98.69: 100%|██████████| 938/938 [00:27<00:00, 34.24it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0202, Accuracy: 9944/10000 (99.44%)



Epoch 1:


Loss=2.7754223346710205 Batch_id=937 Accuracy=68.18: 100%|██████████| 938/938 [00:28<00:00, 32.97it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.4989, Accuracy: 8648/10000 (86.48%)

Epoch 2:


Loss=2.1253273487091064 Batch_id=937 Accuracy=85.76: 100%|██████████| 938/938 [00:28<00:00, 32.85it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.3350, Accuracy: 9201/10000 (92.01%)

Epoch 3:


Loss=2.0544044971466064 Batch_id=937 Accuracy=82.33: 100%|██████████| 938/938 [00:29<00:00, 32.31it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.4137, Accuracy: 8915/10000 (89.15%)

Epoch 4:


Loss=2.348853826522827 Batch_id=937 Accuracy=81.02: 100%|██████████| 938/938 [00:28<00:00, 32.70it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.5367, Accuracy: 8514/10000 (85.14%)

Epoch 5:


Loss=2.297788143157959 Batch_id=937 Accuracy=79.42: 100%|██████████| 938/938 [00:28<00:00, 32.54it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.5411, Accuracy: 8539/10000 (85.39%)

Epoch 6:


Loss=2.2920172214508057 Batch_id=937 Accuracy=77.81: 100%|██████████| 938/938 [00:28<00:00, 33.01it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.6023, Accuracy: 8431/10000 (84.31%)

Epoch 7:


Loss=1.6932252645492554 Batch_id=937 Accuracy=79.18: 100%|██████████| 938/938 [00:28<00:00, 32.89it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.3463, Accuracy: 9229/10000 (92.29%)

Epoch 8:


Loss=1.69957435131073 Batch_id=937 Accuracy=80.42: 100%|██████████| 938/938 [00:28<00:00, 32.65it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.3662, Accuracy: 9107/10000 (91.07%)

Epoch 9:


Loss=1.930107831954956 Batch_id=937 Accuracy=80.80: 100%|██████████| 938/938 [00:29<00:00, 32.13it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.4265, Accuracy: 8772/10000 (87.72%)

Epoch 10:


Loss=2.257108449935913 Batch_id=937 Accuracy=81.52: 100%|██████████| 938/938 [00:29<00:00, 32.02it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.6949, Accuracy: 7667/10000 (76.67%)

Epoch 11:


Loss=1.9677538871765137 Batch_id=937 Accuracy=82.86: 100%|██████████| 938/938 [00:29<00:00, 32.30it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.2901, Accuracy: 9300/10000 (93.00%)

Epoch 12:


Loss=1.1989219188690186 Batch_id=937 Accuracy=84.14: 100%|██████████| 938/938 [00:28<00:00, 32.78it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.2737, Accuracy: 9351/10000 (93.51%)

Epoch 13:


Loss=1.2867799997329712 Batch_id=937 Accuracy=86.95: 100%|██████████| 938/938 [00:29<00:00, 32.26it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.2133, Accuracy: 9522/10000 (95.22%)

Epoch 14:


Loss=1.0687167644500732 Batch_id=937 Accuracy=89.94: 100%|██████████| 938/938 [00:28<00:00, 32.60it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.1787, Accuracy: 9611/10000 (96.11%)

Epoch 15:


Loss=0.8176392316818237 Batch_id=937 Accuracy=93.28: 100%|██████████| 938/938 [00:28<00:00, 32.56it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.1321, Accuracy: 9739/10000 (97.39%)



Epoch 1:


Loss=0.9174429178237915 Batch_id=937 Accuracy=71.52: 100%|██████████| 938/938 [00:28<00:00, 32.89it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.1848, Accuracy: 9541/10000 (95.41%)

Epoch 2:


Loss=0.6110222339630127 Batch_id=937 Accuracy=93.31: 100%|██████████| 938/938 [00:29<00:00, 32.23it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.1035, Accuracy: 9705/10000 (97.05%)

Epoch 3:


Loss=0.4553270637989044 Batch_id=937 Accuracy=94.45: 100%|██████████| 938/938 [00:28<00:00, 32.45it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0814, Accuracy: 9760/10000 (97.60%)

Epoch 4:


Loss=0.5825378894805908 Batch_id=937 Accuracy=94.35: 100%|██████████| 938/938 [00:28<00:00, 33.37it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0843, Accuracy: 9764/10000 (97.64%)

Epoch 5:


Loss=0.44892144203186035 Batch_id=937 Accuracy=94.82: 100%|██████████| 938/938 [00:28<00:00, 33.47it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0601, Accuracy: 9830/10000 (98.30%)

Epoch 6:


Loss=0.5528338551521301 Batch_id=937 Accuracy=95.15: 100%|██████████| 938/938 [00:28<00:00, 33.32it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0649, Accuracy: 9819/10000 (98.19%)

Epoch 7:


Loss=0.4132385551929474 Batch_id=937 Accuracy=95.27: 100%|██████████| 938/938 [00:28<00:00, 33.28it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0741, Accuracy: 9796/10000 (97.96%)

Epoch 8:


Loss=0.6911716461181641 Batch_id=937 Accuracy=95.44: 100%|██████████| 938/938 [00:28<00:00, 33.05it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0698, Accuracy: 9800/10000 (98.00%)

Epoch 9:


Loss=0.49637657403945923 Batch_id=937 Accuracy=95.55: 100%|██████████| 938/938 [00:28<00:00, 33.27it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0653, Accuracy: 9820/10000 (98.20%)

Epoch 10:


Loss=0.3977661728858948 Batch_id=937 Accuracy=95.82: 100%|██████████| 938/938 [00:28<00:00, 33.33it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0501, Accuracy: 9865/10000 (98.65%)

Epoch 11:


Loss=0.41506704688072205 Batch_id=937 Accuracy=96.00: 100%|██████████| 938/938 [00:28<00:00, 33.46it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0502, Accuracy: 9857/10000 (98.57%)

Epoch 12:


Loss=0.4951786994934082 Batch_id=937 Accuracy=96.52: 100%|██████████| 938/938 [00:28<00:00, 33.14it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0537, Accuracy: 9856/10000 (98.56%)

Epoch 13:


Loss=0.4081740975379944 Batch_id=937 Accuracy=96.96: 100%|██████████| 938/938 [00:28<00:00, 33.43it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0406, Accuracy: 9895/10000 (98.95%)

Epoch 14:


Loss=0.3379467725753784 Batch_id=937 Accuracy=97.64: 100%|██████████| 938/938 [00:28<00:00, 33.36it/s]
  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0304, Accuracy: 9909/10000 (99.09%)

Epoch 15:


Loss=0.3799988627433777 Batch_id=937 Accuracy=98.05: 100%|██████████| 938/938 [00:28<00:00, 33.30it/s]



Test set: Average loss: 0.0271, Accuracy: 9926/10000 (99.26%)



In [17]:
print(len(exp_metrics))

1


In [12]:
def plot_metrics(results):
  fig, axs = plt.subplots(2,2,figsize=(25,15))
  for idx, exp_name in enumerate(results.keys()):
      train_accuracy,train_losses,test_accuracy,test_losses  = results[exp_name]
      axs[0, 0].set_title("Training Loss")
      axs[1, 0].set_title("Training Accuracy")
      axs[0, 1].set_title("Test Loss")
      axs[1, 1].set_title("Test Accuracy")

      axs[0, 0].plot(train_losses, label='{} reg'.format(exp_name))
      axs[0,0].legend(loc='upper right')
      axs[0,0].set_xlabel('epochs')
      axs[0,0].set_ylabel('loss')

      axs[1, 0].plot(train_accuracy, label='{} reg'.format(exp_name))
      axs[1,0].legend(loc='lower right')
      axs[1,0].set_xlabel('epochs')
      axs[1,0].set_ylabel('loss')

      axs[0, 1].plot(test_losses, label='{} reg'.format(exp_name))
      axs[0,1].legend(loc='upper right')
      axs[0,1].set_xlabel('epochs')
      axs[0,1].set_ylabel('loss')

      axs[1, 1].plot(test_accuracy, label='{} reg'.format(exp_name))
      axs[1,1].legend(loc='lower right')
      axs[1,1].set_xlabel('epochs')
      axs[1,1].set_ylabel('loss')

In [13]:
plot_metrics(exp_metrics)

NameError: ignored

In [None]:
def wrong_predictions(test_loader,model,device):
  wrong_images=[]
  wrong_label=[]
  correct_label=[]
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)        
      pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability

      wrong_pred = (pred.eq(target.view_as(pred)) == False)
      wrong_images.append(data[wrong_pred])
      wrong_label.append(pred[wrong_pred])
      correct_label.append(target.view_as(pred)[wrong_pred])  
      
      wrong_predictions = list(zip(torch.cat(wrong_images),torch.cat(wrong_label),torch.cat(correct_label)))    
    print(f'Total wrong predictions are {len(wrong_predictions)}')
      
      
    fig = plt.figure(figsize=(8,10))
    fig.tight_layout()
    for i, (img, pred, correct) in enumerate(wrong_predictions[:10]):
          img, pred, target = img.cpu().numpy(), pred.cpu(), correct.cpu()
          ax = fig.add_subplot(5, 2, i+1)
          ax.axis('off')
          ax.set_title(f'\nactual {target.item()}\npredicted {pred.item()}',fontsize=10)  
          ax.imshow(img.squeeze(), cmap='gray_r')  
          
    plt.show()
      
  return 