In [1]:
import torch
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np

torch.set_default_dtype(torch.float64)

torch.manual_seed(6666)

train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor())

batch_size = 64

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

class ConvNet(torch.nn.Module):
    def __init__(self, output=10):
        super(ConvNet, self).__init__()        
        self.conv1 = torch.nn.Conv2d(1, 4, kernel_size=7, padding=0, stride=3)
        self.flatten = torch.nn.Flatten(1)
        self.fc1 = torch.nn.Linear(256, 100)
        self.fc2 = torch.nn.Linear(100, output)

    def forward(self, x):
        x = self.conv1(x)
        x = x * x
        x = self.flatten(x)
        x = self.fc1(x)
        x = x * x
        x = self.fc2(x)
        
        return x


def train(model, train_loader, criterion, optimizer, n_epochs=10):
    model.train()
    for epoch in range(1, n_epochs+1):

        train_loss = 0.0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss = train_loss / len(train_loader)

        print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
    
    model.eval()
    return model


model = ConvNet()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model = train(model, train_loader, criterion, optimizer, 10)

Epoch: 1 	Training Loss: 0.410645
Epoch: 2 	Training Loss: 0.141948
Epoch: 3 	Training Loss: 0.092934
Epoch: 4 	Training Loss: 0.070609
Epoch: 5 	Training Loss: 0.057050
Epoch: 6 	Training Loss: 0.046860
Epoch: 7 	Training Loss: 0.041463
Epoch: 8 	Training Loss: 0.035764
Epoch: 9 	Training Loss: 0.032735
Epoch: 10 	Training Loss: 0.027682


In [2]:
def test(model, test_loader, criterion):
    test_loss = 0.0
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))

    model.eval()

    for data, target in test_loader:
        output = model(data)
        loss = criterion(output, target)
        test_loss += loss.item()
        
        _, pred = torch.max(output, 1)
        correct = np.squeeze(pred.eq(target.data.view_as(pred)))
        for i in range(len(target)):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

    test_loss = test_loss/len(test_loader)
    print(f'Test Loss: {test_loss:.6f}\n')

    for label in range(10):
        print(
            f'Test Accuracy of {label}: {int(100 * class_correct[label] / class_total[label])}% '
            f'({int(np.sum(class_correct[label]))}/{int(np.sum(class_total[label]))})'
        )

    print(
        f'\nTest Accuracy (Overall): {int(100 * np.sum(class_correct) / np.sum(class_total))}% ' 
        f'({int(np.sum(class_correct))}/{int(np.sum(class_total))})'
    )
    
test(model, test_loader, criterion)

Test Loss: 0.091254

Test Accuracy of 0: 98% (969/980)
Test Accuracy of 1: 99% (1128/1135)
Test Accuracy of 2: 98% (1020/1032)
Test Accuracy of 3: 97% (982/1010)
Test Accuracy of 4: 99% (975/982)
Test Accuracy of 5: 97% (873/892)
Test Accuracy of 6: 98% (939/958)
Test Accuracy of 7: 97% (998/1028)
Test Accuracy of 8: 97% (947/974)
Test Accuracy of 9: 97% (980/1009)

Test Accuracy (Overall): 98% (9811/10000)


In [3]:
for i, f in enumerate(model.conv1.weight.detach().numpy()):
    np.savetxt(f"CW{i}.csv", f[0], delimiter=",")
np.savetxt(f"CB.csv", model.conv1.bias.detach().numpy(), delimiter=",")

np.savetxt(f"LW1.csv", model.fc1.weight.detach().numpy(), delimiter=",")
np.savetxt(f"LB1.csv", model.fc1.bias.detach().numpy(), delimiter=",")

np.savetxt(f"LW2.csv", model.fc2.weight.detach().numpy(), delimiter=",")
np.savetxt(f"LB2.csv", model.fc2.bias.detach().numpy(), delimiter=",")

In [4]:
from sklearn.covariance import LedoitWolf

class MahaDist(torch.nn.Module):
    def __init__(self, output=10):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 4, kernel_size=7, padding=0, stride=3)
        self.flatten = torch.nn.Flatten(1)
        self.fc1 = torch.nn.Linear(256, 100)
        self.fc2 = torch.nn.Linear(100, output)

    def forward_features(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        return x

maha = MahaDist()
model.conv1 = model.conv1
maha.conv1.bias = model.conv1.bias
maha.fc1.weight = model.fc1.weight
maha.fc1.bias = model.fc1.bias
maha.fc2.weight = model.fc2.weight
maha.fc2.bias = model.fc2.bias

@torch.no_grad()
def fit_mahalanobis(maha, loader, num_classes=10):
    maha.eval()
    features = []
    labels = []

    for x, y in loader:
        f = maha.forward_features(x).numpy()
        features.append(f)
        labels.append(y.numpy())

    features = np.concatenate(features, axis=0)
    labels = np.concatenate(labels, axis=0)

    # class means
    means = []
    for c in range(num_classes):
        means.append(features[labels == c].mean(axis=0))
    means = np.stack(means, axis=0)

    lw = LedoitWolf().fit(features)     # Σ estimate
    precision = lw.precision_        # Σ^{-1}
    return means, precision

def mahalanobis_dist_diag(features, means, precision_diag):
    diff = features[:, None, :] - means[None, :, :]
    dists = ((diff * diff) * precision_diag[None, None, :]).sum(axis=2)
    return -dists.min(axis=1)

@torch.no_grad()
def get_dists_diag(maha, loader, means, precision_diag):
    maha.eval()
    dists = []
    ys = []
    for x, y in loader:
        f = maha.forward_features(x).detach().cpu().numpy()
        s = mahalanobis_dist_diag(f, means, precision_diag)
        dists.append(s)
        ys.append(y.detach().cpu().numpy())
    return np.concatenate(dists), np.concatenate(ys)

In [5]:
from sklearn.metrics import roc_auc_score, average_precision_score

ood_data = datasets.ImageFolder(root="data/notMNIST_small",
                                transform=transforms.Compose([
                                        transforms.Grayscale(num_output_channels=1),
                                        transforms.Resize((28, 28)),
                                        transforms.ToTensor(),
]))
ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=batch_size, shuffle=False)

means, precision = fit_mahalanobis(maha, train_loader, num_classes=10)
precision_diag = np.diag(precision).copy()
precision_diag = np.maximum(precision_diag, 1e-12)
train_dists, _ = get_dists_diag(maha, train_loader, means, precision_diag)
test_dists,  _ = get_dists_diag(maha, test_loader,  means, precision_diag)
ood_dists,   _ = get_dists_diag(maha, ood_loader,   means, precision_diag)
tau = np.quantile(train_dists, 0.05) # threshold at 95% TPR on ID (i.e., accept 95% of ID)

# predicted as OOD if dist < tau
fpr = (ood_dists >= tau).mean()    # OOD incorrectly accepted as ID
tpr = (test_dists >= tau).mean()    # ID correctly accepted as ID
print("FPR on OOD =", fpr)
print("TPR on ID =",tpr)
print("Min OOD dist =",ood_dists.min())

y_true = np.concatenate([
    np.ones_like(test_dists),
    np.zeros_like(ood_dists)
])

y_dist = np.concatenate([
    test_dists,
    ood_dists
])

auroc = roc_auc_score(y_true, y_dist)
aupr  = average_precision_score(y_true, y_dist)

print(f"AUROC = {auroc:.4f}")
print(f"AUPR  = {aupr:.4f}")
print(f"tau = {tau:.4f}")

FPR on OOD = 0.013458662678914763
TPR on ID = 0.9556
Min OOD dist = -234068.04295735466
AUROC = 0.9971
AUPR  = 0.9935
tau = -7324.3396


In [6]:
np.savetxt(f"Means.csv", means, delimiter=",")
np.savetxt(f"Precision.csv", precision_diag, delimiter=",")

In [7]:
import math
import numpy as np
import torch

torch.set_default_dtype(torch.float64)

def g_gaussian(delta: float) -> float:
    return math.sqrt(2.0 * math.log(1.25 / delta))

@torch.no_grad()
def collect_logits(model, loader, device="cpu", max_batches=None):
    model.eval()
    outs = []
    for b, (x, _) in enumerate(loader):
        x = x.to(device)
        outs.append(model(x).detach().cpu())
        if max_batches is not None and (b + 1) >= max_batches:
            break
    return torch.cat(outs, dim=0)

def train_one(model_ctor, train_ds, batch_size=64, n_epochs=10, lr=1e-3, seed0=6666, device="cpu"):
    torch.manual_seed(seed0)
    np.random.seed(seed0)

    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=False)

    m = model_ctor().to(device)
    crit = torch.nn.CrossEntropyLoss()
    opt = torch.optim.Adam(m.parameters(), lr=lr)

    m.train()
    for _ in range(n_epochs):
        for x, y in train_loader:
            x = x.to(device); y = y.to(device)
            opt.zero_grad()
            loss = crit(m(x), y)
            loss.backward()
            opt.step()

    m.eval()
    return m

def remove_one(train_ds, ridx: int):
    idxs = list(range(len(train_ds)))
    idxs.pop(int(ridx))
    return torch.utils.data.Subset(train_ds, idxs)

def calibrate_sigma_wp(
    train_data,
    test_data,
    model_ctor,
    batch_size=64,
    n_epochs=10,
    lr=1e-3,
    device="cpu",
    seed0=6666,
    M=10,                  # number of remove-one trials
    max_batches=50,        # number of test batches to use (None = full test)
    q=0.99,                # Delta quantile
    eps=16.0,
    delta=1e-5,
):
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

    print("[calib] train baseline on FULL train set...")
    base_model = train_one(model_ctor, train_data, batch_size, n_epochs, lr, seed0, device)
    base_logits = collect_logits(base_model, test_loader, device=device, max_batches=max_batches)
    N = base_logits.shape[0]

    rng = np.random.default_rng(seed0)
    remove_idxs = rng.choice(len(train_data), size=M, replace=False)

    all_l2 = []
    print(f"[calib] {M} trials remove-one, eval N={N} test samples (max_batches={max_batches})")
    for t, ridx in enumerate(remove_idxs, start=1):
        adj_ds = remove_one(train_data, ridx)
        adj_model = train_one(model_ctor, adj_ds, batch_size, n_epochs, lr, seed0, device)
        adj_logits = collect_logits(adj_model, test_loader, device=device, max_batches=max_batches)
        diff = base_logits - adj_logits
        l2 = torch.linalg.vector_norm(diff, ord=2, dim=1).numpy()
        all_l2.append(l2)
        print(f"  [t={t}/{M}] ridx={int(ridx)} meanΔ={l2.mean():.4g}")

    all_l2 = np.concatenate(all_l2, axis=0)
    Delta_q = float(np.quantile(all_l2, q))

    factor = g_gaussian(delta)
    sigma_wp = (Delta_q / eps) * factor

    print(f"Test subset size N = {N} (max_batches={max_batches})")
    print(f"Quantile q={q}: Delta_wp(q) = {Delta_q:.6g}")
    print(f"DP params: (eps={eps}, delta={delta})  factor={factor:.6g}")
    print(f"sigma_wp = {sigma_wp:.6g}")

    return {
        "Delta_q": Delta_q,
        "sigma_wp": sigma_wp,
        "eps": eps,
        "delta": delta,
        "q": q,
        "M": M,
        "test_N": int(N),
        "test_max_batches": max_batches,
        "train_epochs": n_epochs,
        "lr": lr,
        "seed": seed0,
    }

calibrate_sigma_wp(train_data=train_data, test_data=test_data, model_ctor=lambda: ConvNet())

[calib] train baseline on FULL train set...
[calib] 10 trials remove-one, eval N=3200 test samples (max_batches=50)
  [t=1/10] ridx=38163 meanΔ=12.29
  [t=2/10] ridx=7875 meanΔ=10.75
  [t=3/10] ridx=19170 meanΔ=10.68
  [t=4/10] ridx=5466 meanΔ=11.45
  [t=5/10] ridx=12132 meanΔ=11.88
  [t=6/10] ridx=43950 meanΔ=10.73
  [t=7/10] ridx=51948 meanΔ=10.9
  [t=8/10] ridx=4263 meanΔ=9.847
  [t=9/10] ridx=57318 meanΔ=9.472
  [t=10/10] ridx=9135 meanΔ=9.704
Test subset size N = 3200 (max_batches=50)
Quantile q=0.99: Delta_wp(q) = 35.4959
DP params: (eps=16.0, delta=1e-05)  factor=4.84481
sigma_wp = 10.7482


{'Delta_q': 35.49586045160214,
 'sigma_wp': 10.748158219789284,
 'eps': 16.0,
 'delta': 1e-05,
 'q': 0.99,
 'M': 10,
 'test_N': 3200,
 'test_max_batches': 50,
 'train_epochs': 10,
 'lr': 0.001,
 'seed': 6666}