In [1]:
import torch
import torchvision
import numpy as np
import random

In [2]:
n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10
device = "cuda"

random_seed = 6
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

In [3]:
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./MNIST/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./MNIST/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 50)
        self.fc3 = nn.Linear(50, 10)
        self.activation = nn.ELU()

        self.cf = nn.Linear(10, 50)

    def forward(self, x, pred_prob=None):
        x = self.activation(F.max_pool2d(self.conv1(x), 2))
        x = self.activation(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = self.activation(self.fc1(x))

        if pred_prob is not None:
            pred_prob = torch.nn.functional.softmax(pred_prob, 1)
            x += self.activation(self.cf(pred_prob))

        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x)

In [5]:
network = Net().to(device)
optimizer = optim.Adam(network.parameters(), lr=learning_rate)

In [6]:
train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

In [7]:
def inference_train(x, N = 3):
    preds = []
    pred = None
    for _ in range(N):
        pred = network(x, pred)
        preds.append(pred)
    return torch.cat(preds)

def inference_test(x, N = 3):
    preds = []
    pred = None
    for _ in range(N):
        pred = torch.nn.functional.softmax(network(x, pred), 1)
        preds.append(pred)
    return sum(preds) / N

In [8]:
def train(epoch):
  network.train()
  for batch_idx, (data, target) in enumerate(train_loader):

    data = data.to(device)
    target = target.to(device)

    optimizer.zero_grad()
    output = inference_train(data)

    loss = F.nll_loss(output, target.tile(3))
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))
      train_losses.append(loss.item())
      train_counter.append(
        (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))
#       torch.save(network.state_dict(), '/vanilla_model.pth')
#       torch.save(optimizer.state_dict(), '/vanilla_optimizer.pth')

def test():
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data = data.to(device)
      target = target.to(device)
      output = inference_test(data)
      test_loss += F.nll_loss(output, target, size_average=False).item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

In [9]:
test()
for epoch in range(1, n_epochs + 1):
  train(epoch)
  test()

  return F.log_softmax(x)



Test set: Avg. loss: -0.0985, Accuracy: 984/10000 (10%)


Test set: Avg. loss: -0.9608, Accuracy: 9713/10000 (97%)


Test set: Avg. loss: -0.9696, Accuracy: 9745/10000 (97%)


Test set: Avg. loss: -0.9667, Accuracy: 9736/10000 (97%)



In [17]:
for param_group in optimizer.param_groups:
    param_group['lr'] = 1e-3

for _ in range(2):
  train(epoch)
  test()

  return F.log_softmax(x)







Test set: Avg. loss: -0.9867, Accuracy: 9891/10000 (99%)



In [18]:
train_fashion_loader = torch.utils.data.DataLoader(
  torchvision.datasets.FashionMNIST('./FMNIST/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_fashion_loader = torch.utils.data.DataLoader(
  torchvision.datasets.FashionMNIST('./FMNIST/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [19]:
# Computing OOD metrics

network.eval()

uncertainties = np.array([])
labels = np.array([])
eps = 1e-10

with torch.no_grad():
  for data, target in test_loader:

    data = data.to(device)
    prob = inference_test(data)
    uncertainty = (-prob * torch.log(prob + eps)).sum(dim=1).cpu().detach().numpy()
    label = np.zeros_like(uncertainty)

    uncertainties = np.concatenate([uncertainties, uncertainty])
    labels = np.concatenate([labels, label])

with torch.no_grad():
  for data, target in test_fashion_loader:

    data = data.to(device)

    prob = inference_test(data)
    uncertainty = (-prob * torch.log(prob + eps)).sum(dim=1).cpu().detach().numpy()
    label = np.ones_like(uncertainty)

    uncertainties = np.concatenate([uncertainties, uncertainty])
    labels = np.concatenate([labels, label])

  return F.log_softmax(x)


In [20]:
import sklearn.metrics
roc_auc = sklearn.metrics.roc_auc_score(labels, uncertainties)
precision, recall, thresholds = sklearn.metrics.precision_recall_curve(labels, uncertainties)
pr_auc = sklearn.metrics.auc(recall, precision)
roc_auc, pr_auc

(0.974304505, 0.9800322945563024)