In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

This step tries to reduce overfitting.

**Target**

1. Fix overfitting by adding dropout (p = 0.1) after every convolution block except the output block.


**Results**
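Overfitting appears to be gone. However, the model is still not learning enough, or fast enough: train accuracy does not move beyond 99%. (Train accuracy is accumulated batch by batch while dropout is active, so it likely understates the model's true accuracy on the training set.)

1. Params: ~9k (9,058)
2. Best train accuracy: 98.95 (epoch 15)
3. Best test accuracy: 99.05 (epoch 10)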








In [6]:
dropout_value = 0.1  # default dropout probability applied after every conv block


def _conv_block(in_channels, out_channels, kernel_size, dropout):
    """Return Conv2d (no padding) -> BatchNorm2d -> ReLU -> Dropout as one nn.Sequential.

    The layer order matches the hand-written blocks this helper replaces, so
    state_dict keys (``convblockN.0.weight``, ``convblockN.1.running_mean``, ...)
    are unchanged.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size),
        nn.BatchNorm2d(num_features=out_channels, eps=1e-05, momentum=0.1),
        nn.ReLU(),
        nn.Dropout(dropout),
    )


class Net(nn.Module):
    """Small (~9k parameter) fully convolutional MNIST classifier with dropout.

    Spatial sizes in the comments assume (N, 1, 28, 28) inputs. ``forward``
    returns (N, 10) log-probabilities, suitable for ``F.nll_loss``.

    Args:
        dropout: dropout probability for every conv block. ``None`` (default)
            reads the module-level ``dropout_value`` at construction time.
    """

    def __init__(self, dropout=None):
        super().__init__()
        if dropout is None:
            dropout = dropout_value

        ## INPUT BLOCK
        self.convblock1 = _conv_block(1, 8, 3, dropout)      # 28 -> 26, RF 3

        ## BLOCK 1
        self.convblock2 = _conv_block(8, 16, 3, dropout)     # 26 -> 24, RF 5
        self.convblock3 = _conv_block(16, 24, 3, dropout)    # 24 -> 22, RF 7

        ## TRANSITION BLOCK: halve resolution, then a 1x1 conv squeezes 24 -> 16 channels
        self.pool1 = nn.MaxPool2d(2, 2)                      # 22 -> 11, RF 8
        self.convblock4 = _conv_block(24, 16, 1, dropout)    # 11 -> 11, RF 8

        ## BLOCK 2
        self.convblock5 = _conv_block(16, 16, 3, dropout)    # 11 -> 9, RF 12
        self.pool2 = nn.MaxPool2d(2, 2)                      # 9 -> 4 (floor), RF 14

        ## OUTPUT BLOCK: 10-channel conv + global average pooling gives one score per
        ## class. No BatchNorm/ReLU/Dropout here, so the pooled values are raw logits.
        self.convblock6 = nn.Sequential(
            nn.Conv2d(16, 10, 3),                            # 4 -> 2, RF 22
            nn.AdaptiveAvgPool2d(1),                         # 2 -> 1
        )

    def forward(self, x):
        x = self.convblock1(x)
        x = self.convblock2(x)
        x = self.convblock3(x)
        x = self.pool1(x)
        x = self.convblock4(x)
        x = self.convblock5(x)
        x = self.pool2(x)
        x = self.convblock6(x)        # (N, 10, 1, 1)
        x = x.view(-1, 10)            # (N, 10)
        # Explicit dim: the implicit-dim form of log_softmax is deprecated and warns.
        return F.log_softmax(x, dim=1)

In [7]:
!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = Net().to(device)
summary(model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 26, 26]              80
       BatchNorm2d-2            [-1, 8, 26, 26]              16
              ReLU-3            [-1, 8, 26, 26]               0
           Dropout-4            [-1, 8, 26, 26]               0
            Conv2d-5           [-1, 16, 24, 24]           1,168
       BatchNorm2d-6           [-1, 16, 24, 24]              32
              ReLU-7           [-1, 16, 24, 24]               0
           Dropout-8           [-1, 16, 24, 24]               0
            Conv2d-9           [-1, 24, 22, 22]           3,480
      BatchNorm2d-10           [-1, 24, 22, 22]              48
             ReLU-11           [-1, 24, 22, 22]               0
          Dropout-12           [-1, 24, 22, 22]               0
        MaxPool2d-13           [-1, 24, 11, 11]               0
           Conv2d-14           [-1, 16,

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [8]:


torch.manual_seed(1)
batch_size = 256

# Shared preprocessing for train and test: convert to tensor, then normalise
# with the standard MNIST training-set mean / std (0.1307, 0.3081).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Worker / pinned-memory settings only help when copying batches to a GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=True, **kwargs)

# Evaluation does not need shuffling; a fixed order keeps test passes reproducible.
# download=True lets this loader work even if the train loader was not built first.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=False, **kwargs)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [9]:
from tqdm import tqdm
def train(model, device, train_loader, optimizer, epoch):
    """Run one SGD epoch over ``train_loader`` and print the epoch's train accuracy.

    The model is expected to return log-probabilities (it is paired with
    ``F.nll_loss``). Accuracy is accumulated batch by batch while the model is
    in train mode (dropout active, weights still being updated), so it is a
    running figure for the epoch rather than an end-of-epoch evaluation.
    """
    model.train()
    n_correct = 0
    progress = tqdm(train_loader)

    for step, (inputs, labels) in enumerate(progress):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        log_probs = model(inputs)
        batch_loss = F.nll_loss(log_probs, labels)
        batch_loss.backward()
        optimizer.step()

        progress.set_description(desc=f'loss={batch_loss.item()} batch_id={step}')

        # Predicted class = index of the largest log-probability.
        n_correct += (log_probs.argmax(dim=1) == labels).sum().item()

    accuracy = 100 * n_correct / len(train_loader.dataset)
    print(f'Epoch {epoch}, train accuracy {accuracy}')

def test(model, device, test_loader, epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'Epoch {epoch}, test accuracy {100. * correct / len(test_loader.dataset)}')
    # print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    #     test_loss, correct, len(test_loader.dataset),
    #     100. * correct / len(test_loader.dataset)))

In [10]:

# Training hyper-parameters: plain SGD with momentum for a fixed number of epochs.
NUM_EPOCHS = 15
LEARNING_RATE = 0.01
MOMENTUM = 0.9

# Fresh model so training does not continue from the instance built for summary().
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, NUM_EPOCHS + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader, epoch)

loss=0.1858602911233902 batch_id=234: 100%|██████████| 235/235 [00:19<00:00, 11.86it/s]

Epoch 1, train accuracy 85.27333333333333





Epoch 1, test accuracy 95.96


loss=0.049647506326436996 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.67it/s]

Epoch 2, train accuracy 96.32833333333333





Epoch 2, test accuracy 97.45


loss=0.044559378176927567 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.54it/s]

Epoch 3, train accuracy 97.335





Epoch 3, test accuracy 98.1


loss=0.06529667973518372 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.59it/s]

Epoch 4, train accuracy 97.77666666666667





Epoch 4, test accuracy 97.68


loss=0.01906120963394642 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.53it/s]

Epoch 5, train accuracy 98.01





Epoch 5, test accuracy 98.55


loss=0.06280892342329025 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.53it/s]

Epoch 6, train accuracy 98.22166666666666





Epoch 6, test accuracy 98.73


loss=0.02386438101530075 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.53it/s]

Epoch 7, train accuracy 98.47166666666666





Epoch 7, test accuracy 98.57


loss=0.059703607112169266 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.67it/s]

Epoch 8, train accuracy 98.525





Epoch 8, test accuracy 98.56


loss=0.09337475150823593 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.51it/s]

Epoch 9, train accuracy 98.61





Epoch 9, test accuracy 98.92


loss=0.020893385633826256 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.51it/s]

Epoch 10, train accuracy 98.61333333333333





Epoch 10, test accuracy 99.05


loss=0.02000034786760807 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.51it/s]

Epoch 11, train accuracy 98.785





Epoch 11, test accuracy 98.67


loss=0.024615267291665077 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.47it/s]

Epoch 12, train accuracy 98.86





Epoch 12, test accuracy 98.86


loss=0.08296527713537216 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.59it/s]

Epoch 13, train accuracy 98.795





Epoch 13, test accuracy 98.95


loss=0.040672171860933304 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.41it/s]

Epoch 14, train accuracy 98.91833333333334





Epoch 14, test accuracy 99.0


loss=0.026193417608737946 batch_id=234: 100%|██████████| 235/235 [00:20<00:00, 11.53it/s]

Epoch 15, train accuracy 98.95166666666667





Epoch 15, test accuracy 99.0
