# Adaptive dropout MNIST example

I've pulled this example directly from the [PyTorch examples](https://github.com/pytorch/examples) github.

### Notes

I return the average loss of the training across the whole batch.
A better approach might be to return the average of the last several batches to give a more accurate loss and avoid random fluctuations coming from just the last batch.
I think we can get away with the whole batch average since the first few epochs where the larger training loss changes happen we're usually not too worried about overfitting.

In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load up MNIST dataset

In [3]:
batch_size = 16
kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        '../data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
    ),
    batch_size=batch_size, shuffle=True, **kwargs,
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])),
    batch_size=batch_size, shuffle=True, **kwargs,
)

## Create model

This is where the magic happens.

In [4]:
class Net(nn.Module):
    def __init__(self, dropout=0.):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
#         self.dropout1 = nn.Dropout2d(0.25)
#         self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = dropout

        
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        
        # In the official example this is half the last dropout rate
        x = F.dropout2d(x, self.dropout/2., training=self.training)
        
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        
        x = F.dropout2d(x, self.dropout, training=self.training)
        
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
    
    def set_dropout(self, dropout):
        self.dropout = dropout

## Construction train/validation loops

Nothing special done here except returning the last loss.

In [16]:
def train(log_interval, model, device, train_loader, optimizer, epoch):
    model.train()
    avg_loss = 0.
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()
            ))
        avg_loss += loss
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        
    avg_loss /= len(train_loader.dataset)
            
    return avg_loss.item(), correct / len(train_loader.dataset)

In [17]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    
    return test_loss, correct / len(test_loader.dataset)

## Instantiate

In [21]:
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)

scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

# Run training

In [22]:
epochs = 20
log_interval = 1000

In [23]:
for epoch in range(1, epochs + 1):
    train_loss, train_acc = train(log_interval, model, device, train_loader, optimizer, epoch)
    test_loss, test_acc = test(model, device, test_loader)
    scheduler.step()

#     dropout = 1. - (train_loss / test_loss)
    dropout = 1. - (test_acc / train_acc)
    dropout = max(dropout, 0.)  # Ensure it doesn't dip below zero
    dropout = min(dropout, 0.5)  # Set upper limit on loss
    print(f'Train loss/acc: {train_loss:.5f}/{train_acc:.3f}, Test loss/acc: {test_loss:.5f}/{test_acc:.3f}, New dropout: {dropout:.5f}')
    model.set_dropout(dropout)


Test set: Average loss: 0.0505, Accuracy: 9840/10000 (98%)

Train loss/acc: 0.00607/0.97, Test loss/acc: 0.05053/0.98, New dropout: 0.00000

Test set: Average loss: 0.0381, Accuracy: 9893/10000 (99%)

Train loss/acc: 0.00182/0.99, Test loss/acc: 0.03809/0.99, New dropout: 0.00260

Test set: Average loss: 0.0346, Accuracy: 9912/10000 (99%)

Train loss/acc: 0.00083/1.00, Test loss/acc: 0.03456/0.99, New dropout: 0.00507

Test set: Average loss: 0.0363, Accuracy: 9926/10000 (99%)

Train loss/acc: 0.00044/1.00, Test loss/acc: 0.03631/0.99, New dropout: 0.00563

Test set: Average loss: 0.0371, Accuracy: 9918/10000 (99%)

Train loss/acc: 0.00025/1.00, Test loss/acc: 0.03707/0.99, New dropout: 0.00739

Test set: Average loss: 0.0396, Accuracy: 9922/10000 (99%)

Train loss/acc: 0.00016/1.00, Test loss/acc: 0.03961/0.99, New dropout: 0.00727

Test set: Average loss: 0.0401, Accuracy: 9921/10000 (99%)

Train loss/acc: 0.00013/1.00, Test loss/acc: 0.04010/0.99, New dropout: 0.00755

Test set: Av

Exception in thread Thread-113:
Traceback (most recent call last):
  File "/portal/ekpbms1/home/jkahn/miniconda3/envs/ml/lib/python3.7/threading.py", line 917, in _bootstrap_inner
    self.run()
  File "/portal/ekpbms1/home/jkahn/miniconda3/envs/ml/lib/python3.7/threading.py", line 865, in run
    self._target(*self._args, **self._kwargs)
  File "/portal/ekpbms1/home/jkahn/miniconda3/envs/ml/lib/python3.7/site-packages/torch/utils/data/_utils/pin_memory.py", line 25, in _pin_memory_loop
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/portal/ekpbms1/home/jkahn/miniconda3/envs/ml/lib/python3.7/multiprocessing/queues.py", line 113, in get
    return _ForkingPickler.loads(res)
  File "/portal/ekpbms1/home/jkahn/miniconda3/envs/ml/lib/python3.7/site-packages/torch/multiprocessing/reductions.py", line 294, in rebuild_storage_fd
    fd = df.detach()
  File "/portal/ekpbms1/home/jkahn/miniconda3/envs/ml/lib/python3.7/multiprocessing/resource_sharer.py", line 57, in detach
    w

KeyboardInterrupt: 