In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import classification as C
from nn_ood.posteriors import SCOD
from nn_ood.distributions import CategoricalLogit
import numpy as np
import torchvision
from tqdm import tqdm

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.num_classes = 10

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layers = nn.Sequential(*[
            nn.Conv2d(1, 8, 3, 1), # (b, 8, 26, 26)
            nn.ReLU(),
            nn.Conv2d(8, 16, 5, 1), # (b, 16, 22, 22)
            nn.ReLU(),
            nn.MaxPool2d(2, 2), # (b, 16, 11, 11)
            nn.Flatten(), # (b, 16*11*11)
            nn.Linear(16*11*11, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        ])

        self.metrics = nn.ModuleDict({
            'acc': C.MulticlassAccuracy(10),
            'f1': C.MulticlassF1Score(10, average='micro'),
        })

    def forward(self, x):
        # x: (batch_size, 1, 28, 28)
        return self.layers(x)

In [3]:
# Training

def train(model, train_loader, optimizer, criterion, device):
    model.train()
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

def test(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    acc = correct / len(test_loader.dataset)
    return test_loss, acc

model = Net()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))
                               ])),
    batch_size=64, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('data', train=False, transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=1000, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(1):
    print(f'Epoch {epoch}')
    train(model, train_loader, optimizer, criterion, device)
    test_loss, acc = test(model, test_loader, criterion, device)
    print(f'Epoch {epoch}: Test loss: {test_loss:.4f}, Test accuracy: {acc:.4f}')

Epoch 0


100%|██████████| 938/938 [00:09<00:00, 98.91it/s] 


Epoch 0: Test loss: 0.0000, Test accuracy: 0.9857


In [4]:
mnist_test = torchvision.datasets.MNIST('data', train=False, transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,))
            ]))

wrapped_model = SCOD(model, CategoricalLogit())
wrapped_model.process_dataset(mnist_test)

computing basis
using T = 64


100%|██████████| 10000/10000 [01:46<00:00, 93.88it/s]
torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2192.)
  X,_ = torch.triangular_solve(U.t() @ self.W, T) # (k, N)


In [8]:
wrapped_model.eval()

for data, target in test_loader:
    data, target = data.to(device), target.to(device)
    output, unc = wrapped_model(data)
    print(output.shape, unc.shape)
    print(unc)
    break


torch.Size([1000, 10]) torch.Size([1000])
tensor([1.4033e+00, 4.4059e-01, 2.3828e-01, 1.2598e+00, 4.8967e+00, 2.6099e-01,
        1.1167e+00, 2.7186e-01, 2.2019e+00, 3.8087e-01, 2.0626e+00, 6.4375e-02,
        2.6413e-01, 1.5518e+00, 1.0360e+00, 8.2823e-01, 2.2509e+01, 1.1224e+00,
        2.0028e-01, 7.6915e-01, 2.7657e-01, 1.6222e+00, 1.8835e+01, 1.8802e-01,
        1.2063e+00, 4.4656e-01, 2.2343e+00, 1.3389e+00, 1.2773e+00, 7.2326e+00,
        1.2717e+01, 1.6522e+01, 2.8836e-01, 4.7021e-01, 2.0457e+00, 1.9605e+01,
        5.1310e-01, 5.6402e+00, 9.9181e-02, 6.3184e-01, 1.5551e+00, 4.6970e-01,
        2.3044e-01, 1.3637e+00, 1.7511e+00, 4.3291e-01, 8.4122e+00, 1.3451e+00,
        9.4520e-01, 3.1253e-01, 8.9526e-01, 3.0207e-01, 1.2391e+00, 6.0206e+00,
        6.6202e-01, 5.9056e-01, 1.0526e+01, 9.8520e-01, 2.3890e+00, 1.9355e+00,
        4.5551e-01, 3.8177e+00, 3.9123e-01, 6.5065e-01, 1.4731e+01, 1.9901e+01,
        2.8396e-01, 1.0213e+00, 9.4839e+00, 8.7186e-01, 9.2660e+00, 1.4080e+00

In [9]:
# Randomly set a batch
data = torch.rand(1000, 1, 28, 28).to(device)
output_r, unc_r = wrapped_model(data)
print(unc_r)

tensor([7.3582, 7.9131, 6.7250, 7.1332, 6.8550, 7.9214, 7.1222, 7.9870, 7.5219,
        6.9595, 7.0359, 7.3568, 7.6546, 7.3099, 7.2918, 7.6565, 6.9031, 6.7241,
        7.4198, 6.9781, 7.7382, 7.5573, 7.4055, 7.2572, 7.3463, 7.2791, 7.3249,
        7.2255, 6.6497, 6.9350, 7.0896, 8.3361, 7.5479, 7.7242, 7.9768, 6.9496,
        7.2433, 7.3700, 7.2580, 8.0888, 8.0518, 7.1969, 7.3811, 6.9334, 6.5396,
        7.5903, 7.4960, 7.6984, 7.7670, 7.2601, 6.9021, 7.0287, 7.9619, 7.8759,
        7.1458, 6.9842, 7.5583, 7.0821, 7.6115, 7.7466, 7.5403, 7.8372, 8.1009,
        7.4912, 7.5474, 6.7825, 7.4189, 7.5562, 6.6698, 7.6832, 6.9256, 7.4270,
        8.1530, 7.3856, 7.1763, 7.3068, 8.0920, 6.9831, 7.2108, 7.0776, 7.4730,
        7.5372, 7.1913, 7.5240, 8.0783, 7.1096, 7.4955, 7.2483, 7.5776, 7.6372,
        7.8779, 8.1441, 7.2814, 7.5347, 7.7365, 6.8730, 6.7651, 7.9368, 7.6651,
        7.6879, 7.4084, 7.0093, 7.6615, 7.0789, 7.2043, 7.8899, 7.3410, 7.2440,
        7.9782, 7.4676, 7.7927, 7.6790, 

In [13]:
threshold = 6.5
print((unc_r > threshold).sum())
print((unc < threshold).sum())

tensor(978)
tensor(904)
