In [None]:
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from torchvision import datasets, transforms

In [None]:
# Load the MNIST dataset.
# Both splits are stored under the local "data" directory (downloaded on first
# use) and share one ToTensor transform, which converts each image to a tensor.
to_tensor = transforms.ToTensor()
train_dataset = datasets.MNIST(
    root="data", train=True, download=True, transform=to_tensor
)
test_dataset = datasets.MNIST(
    root="data", train=False, download=True, transform=to_tensor
)

In [None]:
# Define the model: a single-hidden-layer MLP classifier.
# Dropout sits before the ReLU; because dropout only zeroes activations or
# rescales them by a positive factor, this is equivalent to applying it after.
layers = [
    nn.Linear(in_features=784, out_features=64),  # flattened image -> hidden
    nn.Dropout(p=0.2),
    nn.ReLU(),
    nn.Linear(in_features=64, out_features=10),  # hidden -> 10 class logits
]
model = nn.Sequential(*layers)

In [None]:
def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring class equals the label.

    Args:
        outputs: 2-D tensor of per-class scores, one row per sample; the
            prediction for a row is the index of its maximum along dim 1.
        labels: tensor of target class indices, one per row of ``outputs``.

    Returns:
        A scalar tensor holding ``correct / number_of_rows``.

    Raises:
        ZeroDivisionError: if ``outputs`` has no rows.
    """
    predicted = outputs.argmax(dim=1)
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))

In [None]:
# Train the model
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

batch_size = 512
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

n_epochs = 1
losses = []  # per-batch training loss, accumulated over all epochs
eval_losses = []  # per-batch test loss, accumulated over all epochs
eval_accuracies = []  # per-batch test accuracy, accumulated over all epochs
for epoch in range(n_epochs):
    model.train()  # enable dropout for the optimisation pass
    progress_bar = tqdm(train_loader, total=len(train_loader))
    for imgs, labels in progress_bar:
        # Flatten each image to a vector for the first Linear layer.
        imgs = imgs.view(imgs.shape[0], -1)
        optimizer.zero_grad()
        output = model(imgs)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        # Smooth the displayed loss over the last 100 training batches.
        rolling_loss = torch.tensor(losses[-100:]).mean()
        progress_bar.set_description(
            f"Epoch {epoch+1}/{n_epochs}, loss: {rolling_loss.item():.4f}"
        )

    # Evaluate the model
    model.eval()  # disable dropout for evaluation
    # Remember where this epoch's evaluation entries start so the displayed
    # metrics are not mixed with results from earlier epochs.
    eval_start = len(eval_losses)
    progress_bar = tqdm(test_loader, total=len(test_loader))
    # No gradients are needed for evaluation; skipping autograd avoids
    # building a graph for every test batch.
    with torch.no_grad():
        for imgs, labels in progress_bar:
            imgs = imgs.view(imgs.shape[0], -1)
            output = model(imgs)
            loss = loss_fn(output, labels)
            acc = accuracy(output, labels)
            eval_losses.append(loss.item())
            eval_accuracies.append(acc.item())
            # Running mean over the batches evaluated so far in this epoch.
            rolling_loss = torch.tensor(eval_losses[eval_start:]).mean()
            rolling_acc = torch.tensor(eval_accuracies[eval_start:]).mean()
            progress_bar.set_description(
                f"Epoch {epoch+1}/{n_epochs}, loss: {rolling_loss.item():.4f}, acc: {rolling_acc.item():.4f}"
            )

# Uncertainty estimation with Monte Carlo dropout

The model is trained with dropout. To estimate uncertainty, we run the model multiple times on the same inputs with dropout still enabled; the spread of the resulting predictions indicates how uncertain the model is.

In [None]:
import numpy as np

# Flatten the raw test images to 784-dim float vectors and divide by 255 so
# they are on a [0, 1] scale (assumes the stored pixels are 0-255 — this is
# meant to match the ToTensor scaling used during training).
X = test_dataset.data.float().view(-1, 784) / 255.0
y_true = test_dataset.targets

n_mc_samples = 100  # number of stochastic forward passes

model.train()  # Set the model to training mode so that dropout is applied
# Inference only: without no_grad every one of the stochastic passes would
# keep its autograd graph alive for the whole test set.
with torch.no_grad():
    y_mc = torch.stack([model(X) for _ in range(n_mc_samples)])
model.eval()  # do not leave dropout switched on for later cells

# Per-sample, per-class mean and spread of the logits across the MC passes.
y_mean = y_mc.mean(dim=0).numpy()
y_std = y_mc.std(dim=0).numpy()

# MC-dropout predictive distribution: average the softmax probabilities of
# the individual passes (each row sums to 1).
y_prob = nn.functional.softmax(y_mc, dim=-1).mean(dim=0).numpy()

# Predictive entropy H = -sum_c p_c * log(p_c). It must be computed from
# probabilities, not from the logit standard deviations. Probabilities are
# clipped away from zero so an underflowed class contributes 0 instead of NaN.
eps = np.finfo(y_prob.dtype).tiny
entropy = (y_prob * -np.log(np.clip(y_prob, eps, None))).sum(axis=1)

In [None]:
# Number of test images shown side by side in each of the figures below.
n_examples = 5

In [None]:
import matplotlib.pyplot as plt

# The samples with the smallest entropy are the ones the model is most sure of.
entropy_order = entropy.argsort()
most_confident_indices = entropy_order[:n_examples]

fig, axs = plt.subplots(nrows=1, ncols=n_examples, figsize=(12, 4))
for i in range(len(axs)):
    ax, idx = axs[i], most_confident_indices[i]
    ax.imshow(test_dataset.data[idx], cmap="gray")
    ax.set_title(f"Entropy {entropy[idx]:.2f}, class {y_mean[idx].argmax()}")
    ax.axis("off")
fig.tight_layout()

# Write the figure to disk as an SVG file.
plt.savefig("../images/most-confident.svg")

In [None]:
import matplotlib.pyplot as plt

# The samples with the largest entropy sit at the end of the ascending sort,
# so the rightmost panel shows the most uncertain sample of all.
entropy_order = entropy.argsort()
least_confident_indices = entropy_order[-n_examples:]

fig, axs = plt.subplots(nrows=1, ncols=n_examples, figsize=(12, 4))
for i in range(len(axs)):
    ax, idx = axs[i], least_confident_indices[i]
    ax.imshow(test_dataset.data[idx], cmap="gray")
    ax.set_title(f"Entropy {entropy[idx]:.2f}, class {y_mean[idx].argmax()}")
    ax.axis("off")
fig.tight_layout()

# Write the figure to disk as an SVG file.
plt.savefig("../images/least-confident.svg")

In [None]:
# Find high-confidence incorrect predictions: misclassified test samples,
# ordered from lowest to highest entropy.
predicted_classes = y_mean.argmax(axis=1)
incorrect_indices = np.where(predicted_classes != y_true.numpy())[0]
sorted_by_entropy = incorrect_indices[np.argsort(entropy[incorrect_indices])]

# squeeze=False always returns a 2-D array of Axes, so the loop below also
# works when n_examples == 1 (otherwise a bare Axes object is returned).
fig, axs = plt.subplots(1, n_examples, figsize=(12, 4), squeeze=False)
axs = axs.ravel()

# Hide every frame up front so panels without a sample stay blank.
for ax in axs:
    ax.axis("off")

# zip stops at the shorter sequence, so having fewer misclassified samples
# than panels no longer raises an IndexError.
for ax, idx in zip(axs, sorted_by_entropy):
    ax.imshow(test_dataset.data[idx], cmap="gray")
    ax.set_title(f"Entropy {entropy[idx]:.2f}, class {y_mean[idx].argmax()}")
plt.tight_layout()

# save the plot as an svg file
plt.savefig("../images/most-confident-incorrect.svg")