In [1]:
import torch
import torchvision
import PIL
import numpy
import collections
import types
from scipy.spatial import distance
import os

In [2]:
class Network(torch.nn.Module):
    def __init__(self, K, O):
        super(Network, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(5,5), stride=1, padding=2)
        self.conv2 = torch.nn.Conv2d(in_channels=self.conv1.out_channels, out_channels=32, kernel_size=(5,5), stride=1, padding=2)
        self.pool = torch.nn.MaxPool2d(kernel_size=(2,2), stride=2)
        self.activation = torch.nn.Sigmoid()
        self.bn = torch.nn.BatchNorm2d(self.conv2.out_channels)
        
        self.fc1 = torch.nn.Linear(7*7*32, K, bias=True)
        self.fc2 = torch.nn.Linear(K, O)
        
    def forward(self, x):
        a = self.activation(self.pool(self.conv1(x)))
        a = self.activation(self.bn(self.pool(self.conv2(a))))
        a = torch.flatten(a,1)    
        return self.fc2(self.activation(self.fc1(a)))

In [3]:
# one hot vector for targets
class TargetVector():
    def __init__(self, known_targets = (4,5,8,9), unknown_targets = (0,2,3,7), ignored_targets = (1,6)):
        self.known_targets = known_targets
        self.unknown_targets = unknown_targets
        self.ignored_targets = ignored_targets
        
        # generate one hot representation of known targets
        self.one_hot_known = numpy.eye(len(known_targets))
        self.target_known = {k:self.one_hot_known[i] for i,k in enumerate(self.known_targets)}
        
        # generate one hot representation of unknown targets
        self.target_unknown = numpy.ones(len(known_targets)) / len(known_targets)
        
    # creates target batch for given targets    
    def __call__(self, inputs, targets):
        valid = []
        vectors = []
    
        # split off unknown unknown samples
        for i, t in enumerate(targets):
            if t in self.known_targets:
                vectors.append(self.target_known[int(t)])
                valid.append(inputs[i].numpy())
            elif t in self.unknown_targets:
                vectors.append(self.target_unknown)
                valid.append(inputs[i].numpy())
                
        # filtered original inputs and one hot vector
        return torch.tensor(valid), torch.tensor(vectors)
    
    # predicts class & its confidence
    def predict(self, logits):
        # softmax over logits in batch
        confidences = torch.nn.functional.softmax(logits, dim=1)
        
        # indexes of the prediction
        indexes = torch.argmax(logits, dim=1)
        
        # confidences for the predicted values
        max_confidences = confidences[range(len(logits)), indexes]
        
        # return tuple predicted class and confidence for a batch
        return [(self.known_targets[indexes[i]], max_confidences[i]) for i in range(len(logits))]
    
    # computes the confidence metric for given samples
    def confidences(self, logits, targets):
        # softmax over logits in batch
        confidences = torch.nn.functional.softmax(logits, dim=1).numpy()
        
        # return confidence of correct class for known samples and 1-max(confidences) + 1/0 for unknown samples
        return [
            numpy.sum(confidences[i] * self.target_known[int(targets[i])])
            if targets[i] in self.known_targets
            else 1 - numpy.max(confidences[i]) + 1./len(self.known_targets)
            for i in range(len(logits))
        ]
        
def adapted_softmax_loss(self, logits, targets):
    return - torch.mean(torch.nn.functional.log_softmax(logits, dim=1) * targets)

In [4]:
class AdaptedSoftmaxFunction(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, logits, targets):
        # compute log probabilities via log softmax
        log_y = torch.log_softmax(logits, dim=1)
        ctx.save_for_backward(log_y, targets)
        
        # return the computed loss
        return - torch.mean(log_y * targets)
    
    @staticmethod
    def backward(ctx, result):
        # get stored result from forward pass
        log_y, targets = ctx.saved_tensors
        
        # compute probablities from log probabilities
        y = torch.exp(log_y)
        
        # return y-t as Jacobian for the logits, None for the targets
        return y - targets, None

In [5]:
# get the MNIST dataset
transform = torchvision.transforms.ToTensor()
train_set = torchvision.datasets.MNIST(
    root="/tmp/MNIST",
    train=True, download=True, 
    transform=transform
)

test_set = torchvision.datasets.MNIST(
    root="/tmp/MNIST",
    train=False, download=True,
    transform=transform
)

# loaders
train_loader = torch.utils.data.DataLoader(
    train_set, shuffle=True, batch_size=32
)

test_loader = torch.utils.data.DataLoader(
    test_set, shuffle=False, batch_size=32
)

In [None]:
# training

network = Network(50, 4)

loss = AdaptedSoftmaxFunction.apply

optimizer = torch.optim.SGD(
    params=network.parameters(),
    lr=0.01, momentum=0.9
    )

targets = TargetVector()

threshold = 0.5
epochs = 100

for epoch in range(epochs):
    for x, t in train_loader:
        optimizer.zero_grad()
        x,t = targets(x,t)
        z = network(x)
        J = loss(z, t)
        J.backward()
        optimizer.step()
        
    # evaluation correctly classified and total number of samples
    k, ku, uu = 0, 0, 0
    nk, nku, nuu = 0, 0, 0
    
    # evaluation: average confidence
    conf = 0.
    with torch.no_grad():
        for x,t in test_loader:
            z = network(x)
            
            # compute predicted classes and confidences
            predictions = targets.predict(z)
            
            # add confidence metric for a batch
            conf += numpy.sum(targets.confidences(z,t))
            
            # compute accuracy
            for i in range(len(t)):
                # iterate over all samples in batch
                if t[i] in targets.known_targets:
                    # known sample: correctly classified?
                    if predictions[i][0] == int(t[1]) and predictions[i][1] >= threshold:
                        k += 1
                    nk += 1
                elif t[i] in targets.unknown_targets:
                    # known unknown sample: correctly rejected?
                    if predictions[i][1] < threshold:
                        ku += 1
                    nku += 1
                else:
                    # unknown unknown sample: correctly rejected?
                    if predictions[i][1] < threshold:
                        uu += 1
                    nuu += 1
        
        # print epochs and metrics
        print(F"Epoch {epoch}, test known: {k/nk*100.:1.2f} %, known unknown: {ku/nku*100.:1.2f} %, unknown unknown: {uu/nuu*100.:1.2f} %, average confidence: {conf/len(test_set):1.5f}")

Epoch 0, test known: 11.69 %, known unknown: 93.51 %, unknown unknown: 76.06 %, average confidence: 0.82891
Epoch 1, test known: 12.50 %, known unknown: 99.73 %, unknown unknown: 42.47 %, average confidence: 0.90159
Epoch 2, test known: 12.68 %, known unknown: 98.81 %, unknown unknown: 45.34 %, average confidence: 0.90613
Epoch 3, test known: 12.76 %, known unknown: 99.46 %, unknown unknown: 28.62 %, average confidence: 0.89500
Epoch 4, test known: 12.65 %, known unknown: 99.41 %, unknown unknown: 40.56 %, average confidence: 0.90864
Epoch 5, test known: 12.68 %, known unknown: 99.23 %, unknown unknown: 29.96 %, average confidence: 0.89688
Epoch 6, test known: 12.73 %, known unknown: 99.68 %, unknown unknown: 40.04 %, average confidence: 0.90768
Epoch 7, test known: 12.65 %, known unknown: 99.38 %, unknown unknown: 33.83 %, average confidence: 0.90036
Epoch 8, test known: 12.70 %, known unknown: 99.73 %, unknown unknown: 46.97 %, average confidence: 0.91815
Epoch 9, test known: 12.70 %