# Evidential Deep Learning with PyTorch
This example implements the paper "Evidential Deep Learning to Quantify Classification Uncertainty" (Sensoy et al., NeurIPS 2018) in PyTorch.
#### References
* [Baseline MNIST train with Pytorch](https://nextjournal.com/gkoehler/pytorch-mnist)

In [1]:
import mnist_data_pytorch as data
import evd_losses as evd_loss
import scipy.ndimage as nd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', device)
print('Pytorch version:', torch.__version__)

# Tensorboard
import shutil
from torch.utils.tensorboard import SummaryWriter
# Start every run from an empty log directory. The directory that is cleared
# must be the one the writer logs into, so both paths derive from `log_dir`.
log_dir = './runs'
shutil.rmtree(log_dir, ignore_errors=True)
writer = SummaryWriter(log_dir + '/train')

# Metaparameters
num_epochs = 25   # number of passes over the training set
num_classes = 10  # number of target classes (MNIST digits)

Device: cpu
Pytorch version: 1.4.0


#### Some helper functions

In [2]:
def one_hot(labels, num_classes=10):
    """
    Convert class indices to a one-hot encoding.

    Args:
        labels: integer class index or integer tensor of class indices,
            each in ``[0, num_classes)``.
        num_classes: length of every one-hot vector.

    Returns:
        Float tensor of shape ``labels.shape + (num_classes,)`` holding a
        single 1 per row, placed on the same device as ``labels`` when
        ``labels`` is a tensor (CPU otherwise).
    """
    # Build the identity matrix on the labels' device: indexing a CPU tensor
    # with CUDA indices raises a device-mismatch error.
    device = labels.device if torch.is_tensor(labels) else None
    identity = torch.eye(num_classes, device=device)
    # Row i of the identity matrix is the one-hot vector of class i.
    return identity[labels]

def rotate_img(x, deg):
    """
    Rotate a flattened 28x28 image (used to test uncertainty).

    `x` is reshaped to 28x28, rotated by `deg` degrees with
    `scipy.ndimage.rotate` while keeping the 28x28 frame (`reshape=False`),
    and returned flattened again.
    """
    image = x.reshape(28, 28)
    rotated = nd.rotate(image, deg, reshape=False)
    return rotated.ravel()

#### Define Model

In [3]:
class LeNet(nn.Module):
    """
    LeNet-style CNN whose outputs are interpreted as Dirichlet evidence.

    Args:
        dropout: apply dropout after the first fully connected layer
            (active only while the module is in training mode).
        num_classes: number of output classes; sets the width of the final
            layer and the numerator of the uncertainty estimate.
    """

    def __init__(self, dropout=False, num_classes=10):
        super().__init__()
        self.use_dropout = dropout
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        # 20000 = 50 channels * 20 * 20: each 5x5 convolution trims a 28x28
        # input by 4 pixels per side pair and the pooling below keeps the size.
        self.fc1 = nn.Linear(20000, 500)
        # The output width follows `num_classes` so that logits, evidence and
        # the uncertainty numerator always agree.
        self.fc2 = nn.Linear(500, num_classes)

    def forward(self, x):
        """
        Run the network on a batch of single-channel images.

        Args:
            x: tensor of shape (batch, 1, H, W); `fc1` expects the flattened
                convolution output to have 20000 features, which is the case
                for 28x28 inputs.

        Returns:
            Tuple ``(logits, evidence, alpha, uncertainty)``:
            raw scores (batch, num_classes), ``relu(logits)``,
            ``evidence + 1`` and ``num_classes / sum(alpha)`` of shape
            (batch, 1).
        """
        # NOTE(review): a pooling window of 1 leaves the feature maps
        # unchanged; a window of 2 may have been intended, but changing it
        # would also change the size of `fc1` and break saved weights.
        x = F.relu(F.max_pool2d(self.conv1(x), 1))
        x = F.relu(F.max_pool2d(self.conv2(x), 1))
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        if self.use_dropout:
            x = F.dropout(x, training=self.training)

        # logits (positive or negative numbers)
        logits = self.fc2(x)
        # Evidence must be non-negative; alpha are the Dirichlet parameters.
        evidence = F.relu(logits)
        alpha = evidence + 1
        # With zero evidence alpha sums to num_classes, giving uncertainty 1.
        uncertainty = self.num_classes / torch.sum(alpha, dim=1, keepdim=True)
        return logits, evidence, alpha, uncertainty

# Build the network with its default settings and place it on the selected
# device in one step (`Module.to` returns the module itself).
# TODO: writer.add_graph(model, X.unsqueeze(0)) once a sample input X exists.
model = LeNet().to(device)

#### Define Loss Function

In [4]:
# Training loss: the `edl_mse_loss` function of the local `evd_losses` module
# (its implementation is not shown in this notebook). The training loop calls
# it as criterion(logits, one_hot_targets, evidence, alpha, epoch,
# num_classes, 10, device).
criterion = evd_loss.edl_mse_loss

#### Define Optimizer and Learning Rate Scheduler

In [5]:
# Adam over all model parameters, with weight decay.
optimizer = optim.Adam(
    model.parameters(),
    weight_decay=0.005,
    lr=1e-3,
)
# Multiply the learning rate by 0.1 after every 7 calls to `scheduler.step()`.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    gamma=0.1,
    step_size=7,
)

#### Train Model

In [None]:
# Train the model for `num_epochs` epochs, recording the loss and accuracy of
# every epoch and logging per-batch metrics to TensorBoard.
phase = "train"  # this notebook only runs a training phase
losses = {"loss": [], "phase": [], "epoch": []}
accuracy = {"accuracy": [], "phase": [], "epoch": []}
num_train_samples = len(data.dataloaders[phase].dataset)
global_step = 0  # number of optimizer steps taken, used as TensorBoard x-axis

model.train()
for epoch in range(num_epochs):
    print("Epoch {}/{}".format(epoch, num_epochs - 1))
    print("-" * 10)
    running_loss = 0.0
    running_corrects = 0
    # Iterate over the data
    for inputs, labels in data.dataloaders[phase]:
        # Build the one-hot targets before the labels leave their original
        # device, then move everything to the training device.
        y = one_hot(labels, num_classes).to(device)
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()
        logits, evidence, alpha, _ = model(inputs)
        _, preds = torch.max(logits, 1)
        # Calculate loss
        # NOTE(review): the literal 10 is forwarded unchanged to the loss;
        # its meaning is defined in `evd_losses` - confirm there.
        loss = criterion(logits, y.float(), evidence, alpha, epoch, num_classes, 10, device)
        loss.backward()
        optimizer.step()

        # Monitoring only, so no gradients are needed.
        with torch.no_grad():
            match = torch.eq(preds, labels).float().reshape(-1, 1)
            acc = torch.mean(match)
            total_evidence = torch.sum(evidence, 1, keepdim=True)
            mean_evidence = torch.mean(total_evidence)
            # The 1e-20 terms avoid a division by zero when a batch has no
            # correct (or no incorrect) predictions.
            mean_evidence_succ = torch.sum(total_evidence * match) / torch.sum(match + 1e-20)
            mean_evidence_fail = torch.sum(total_evidence * (1 - match)) / (torch.sum(torch.abs(1 - match)) + 1e-20)

        writer.add_scalar("train/loss", loss.item(), global_step)
        writer.add_scalar("train/accuracy", acc.item(), global_step)
        writer.add_scalar("train/mean_evidence", mean_evidence.item(), global_step)
        writer.add_scalar("train/mean_evidence_succ", mean_evidence_succ.item(), global_step)
        writer.add_scalar("train/mean_evidence_fail", mean_evidence_fail.item(), global_step)
        global_step += 1

        # statistics: accumulate over every batch, weighting the batch loss by
        # the batch size so the epoch average is per sample.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += int(torch.sum(preds == labels).item())

    scheduler.step()
    epoch_loss = running_loss / num_train_samples
    epoch_acc = running_corrects / num_train_samples

    losses["loss"].append(epoch_loss)
    losses["phase"].append(phase)
    losses["epoch"].append(epoch)
    accuracy["accuracy"].append(epoch_acc)
    accuracy["epoch"].append(epoch)
    accuracy["phase"].append(phase)

    print("{} loss: {:.4f} acc: {:.4f}".format(
        phase.capitalize(), epoch_loss, epoch_acc))

Epoch 0/24
----------
