Домашка 3

Нужно взять / написать код для ReAct (https://arxiv.org/pdf/2111.12797) в небольшом пайплайне с последней пары и сравнить его по тем же метрикам против Softmax и Monte Carlo Dropout на задаче OOD детекции (CIFAR10 = ID, MNIST = OOD). ReAct относительно простой и требует по сути клиппинга активаций

Дедлайн 21 декабря, сдавать также в своей репе. По 2 таске скоро вернусь с фидбеком

ReAct также есть в OpenOOD, про который мы не говорили на паре https://arxiv.org/pdf/2210.07242

То есть по сути у вас получится сравнение база vs байесовский метод vs небайесовский метод, но более новый

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve

# Fix the PyTorch and NumPy RNG seeds so repeated runs start from the same state.
torch.manual_seed(0)
np.random.seed(0)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Experiment hyperparameters shared by the cells below.
batch_size = 128        # mini-batch size for every DataLoader
epochs = 10             # number of training epochs
learning_rate = 1e-3    # Adam learning rate
mc_samples = 20         # stochastic forward passes per input for MC Dropout


In [2]:
# In-distribution (CIFAR10) preprocessing: only conversion to a tensor.
transform_cifar = transforms.Compose([
    transforms.ToTensor()
])

# Out-of-distribution (MNIST) preprocessing: resize to 32x32 and expand the
# channel dimension to 3 so the images fit the 3-channel input of the CNN
# defined below (expand creates a view, the data is not copied).
transform_mnist = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.expand(3, -1, -1))
])

# CIFAR10 training split: used to fit the classifier (ID data).
train_id_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform_cifar
)

# CIFAR10 test split: ID samples for OOD evaluation.
test_id_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform_cifar
)

# MNIST test split: OOD samples for OOD evaluation.
test_ood_dataset = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform_mnist
)

# Only the training loader is shuffled; evaluation loaders keep dataset order.
train_id_loader = DataLoader(
    train_id_dataset,
    batch_size=batch_size,
    shuffle=True
)

test_id_loader = DataLoader(
    test_id_dataset,
    batch_size=batch_size,
    shuffle=False
)

test_ood_loader = DataLoader(
    test_ood_dataset,
    batch_size=batch_size,
    shuffle=False
)


In [4]:
class CNN(nn.Module):
    """Small convolutional classifier with optional ReAct activation clipping.

    The network has three conv -> ReLU -> 2x2 max-pool -> dropout stages
    (3 -> 32 -> 64 -> 128 channels), followed by a 256-unit hidden layer and a
    linear classification head. ``fc1`` expects ``128 * 4 * 4`` features, i.e.
    a 32x32 spatial input after the three poolings.

    Attributes:
        react_threshold: Upper bound used to clip the hidden layer when
            ``forward(..., react=True)`` is called. ``None`` (the default)
            means no threshold has been calibrated and clipping is skipped.
    """

    def __init__(self, num_classes=10, dropout_p=0.3):
        """Build the layers.

        Args:
            num_classes: Number of output logits.
            dropout_p: Drop probability of the single dropout module that is
                reused after every stage and after the hidden layer.
        """
        super().__init__()

        # Layers are created in a fixed order so that seeded initialisation
        # yields the same weights.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(p=dropout_p)

        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, num_classes)

        self.react_threshold = None

    def forward(self, x, react=False):
        """Return class logits for a batch of images.

        Args:
            x: Input batch of 3-channel images.
            react: When true and ``react_threshold`` is set, the output of
                ``fc1`` is clamped from above at ``react_threshold`` before the
                ReLU is applied.

        Returns:
            Tensor of logits with one row per input sample.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.dropout(self.pool(F.relu(conv(x))))

        hidden = self.fc1(x.view(x.size(0), -1))

        clip_hidden = react and self.react_threshold is not None
        if clip_hidden:
            # ReAct: bound abnormally large hidden-layer values.
            hidden = torch.clamp(hidden, max=self.react_threshold)

        hidden = self.dropout(F.relu(hidden))
        return self.fc2(hidden)


In [5]:
def train(model, train_loader, epochs, learning_rate):
    """Train ``model`` with Adam and cross-entropy, printing per-epoch stats.

    The model is moved to the global ``device`` and updated in place; nothing
    is returned. Loss and accuracy are accumulated over the training batches
    while the model is in training mode.

    Args:
        model: Classifier returning logits.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate passed to Adam.
    """
    model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(1, epochs + 1):
        model.train()
        loss_sum, n_correct, n_seen = 0.0, 0, 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            opt.zero_grad()
            outputs = model(inputs)
            batch_loss = loss_fn(outputs, targets)
            batch_loss.backward()
            opt.step()

            # Weight the mean batch loss by batch size so the epoch average
            # is a per-sample average.
            loss_sum += batch_loss.item() * inputs.size(0)
            n_correct += (outputs.argmax(dim=1) == targets).sum().item()
            n_seen += targets.size(0)

        avg_loss = loss_sum / n_seen
        accuracy = n_correct / n_seen

        print(f"Epoch {epoch}: Loss = {avg_loss:.4f}, Accuracy = {accuracy * 100:.2f}%")


In [7]:
# Build the classifier and fit it on the CIFAR10 (ID) training split.
# The same trained model is reused by all three OOD scoring methods below.
model = CNN(dropout_p=0.3, num_classes=10)
train(model, train_id_loader, epochs=epochs, learning_rate=learning_rate)


Epoch 1: Loss = 1.7095, Accuracy = 36.98%
Epoch 2: Loss = 1.3540, Accuracy = 50.75%
Epoch 3: Loss = 1.2203, Accuracy = 56.22%
Epoch 4: Loss = 1.1213, Accuracy = 60.06%
Epoch 5: Loss = 1.0526, Accuracy = 62.62%
Epoch 6: Loss = 0.9943, Accuracy = 64.68%
Epoch 7: Loss = 0.9505, Accuracy = 66.43%
Epoch 8: Loss = 0.9099, Accuracy = 67.65%
Epoch 9: Loss = 0.8717, Accuracy = 69.34%
Epoch 10: Loss = 0.8457, Accuracy = 70.17%


In [8]:
def comp_ood_metrics(id_scores, ood_scores, target_tpr=0.95):
    """Compute AUROC, AUPR and FPR at a target TPR for OOD detection.

    OOD samples are the positive class (label 1) and ID samples the negative
    class (label 0), so a larger score must mean "more likely OOD".

    Args:
        id_scores: 1-D array-like of scores for in-distribution samples.
        ood_scores: 1-D array-like of scores for out-of-distribution samples.
        target_tpr: True-positive rate (fraction of OOD samples detected) at
            which the false-positive rate is reported. Defaults to 0.95.

    Returns:
        Tuple ``(auroc, aupr, fpr_at_tpr)``. ``fpr_at_tpr`` is the FPR at the
        first ROC operating point whose TPR is at least ``target_tpr``, or
        ``1.0`` if no operating point reaches it.
    """
    id_scores = np.asarray(id_scores)
    ood_scores = np.asarray(ood_scores)

    # Explicit integer labels, independent of the dtype of the scores.
    y_true = np.concatenate([
        np.zeros(len(id_scores), dtype=int),   # ID = 0
        np.ones(len(ood_scores), dtype=int),   # OOD = 1
    ])
    scores = np.concatenate([id_scores, ood_scores])

    auroc = roc_auc_score(y_true, scores)
    aupr = average_precision_score(y_true, scores)

    # roc_curve returns TPR in increasing order, so the first index that
    # reaches the target is the operating point with the lowest FPR.
    fpr, tpr, _ = roc_curve(y_true, scores)
    reached = np.where(tpr >= target_tpr)[0]
    fpr_at_tpr = float(fpr[reached[0]]) if len(reached) > 0 else 1.0

    return auroc, aupr, fpr_at_tpr


In [9]:
def get_softmax_ood_scores(model, id_loader, ood_loader):
    """Score ID and OOD samples with the maximum-softmax-probability baseline.

    The model is moved to the global ``device`` and switched to eval mode.
    The score of a sample is ``1 - max softmax probability``, so larger values
    mean lower confidence.

    Args:
        model: Classifier returning logits.
        id_loader: Loader yielding ``(inputs, labels)`` for ID data.
        ood_loader: Loader yielding ``(inputs, labels)`` for OOD data.

    Returns:
        Tuple ``(id_scores, ood_scores)`` of 1-D NumPy arrays.
    """
    model.to(device)
    model.eval()

    def _score_loader(loader):
        # One array of scores per batch, joined at the end.
        per_batch = []
        for images, _ in loader:
            probabilities = F.softmax(model(images.to(device)), dim=1)
            confidence = probabilities.max(dim=1).values
            per_batch.append((1.0 - confidence).cpu().numpy())
        return np.concatenate(per_batch)

    with torch.no_grad():
        id_scores = _score_loader(id_loader)
        ood_scores = _score_loader(ood_loader)

    return id_scores, ood_scores


In [10]:
# Baseline: score the CIFAR10 test split (ID) and MNIST test split (OOD)
# with 1 - max softmax probability.
softmax_id_scores, softmax_ood_scores = get_softmax_ood_scores(
    model,
    test_id_loader,
    test_ood_loader
)

# OOD detection metrics for the baseline scores.
softmax_auroc, softmax_aupr, softmax_fpr95 = comp_ood_metrics(
    softmax_id_scores,
    softmax_ood_scores
)

print("Softmax baseline:")
print(f"AUROC: {softmax_auroc:.4f}")
print(f"AUPR:  {softmax_aupr:.4f}")
print(f"FPR@95%TPR: {softmax_fpr95:.4f}")


Softmax baseline:
AUROC: 0.5768
AUPR:  0.5263
FPR@95%TPR: 0.8372


In [11]:
def mc_dropout_entropy(model, x, T=20):
    """Return the predictive entropy of ``x`` under Monte Carlo Dropout.

    The model is run ``T`` times with dropout active, the softmax outputs are
    averaged, and the entropy of the averaged distribution is returned.

    The model is moved to the global ``device``. It is put into training mode
    only for the duration of the sampling; the mode it had on entry is
    restored afterwards, so callers do not silently end up with a model left
    in training mode.

    Args:
        model: Classifier returning logits; expected to contain dropout.
        x: Input batch, already on the same device as the model.
        T: Number of stochastic forward passes. Must be at least 1.

    Returns:
        1-D tensor with one entropy value per sample in ``x``.

    Raises:
        ValueError: If ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError(f"T must be at least 1, got {T}")

    model.to(device)

    was_training = model.training
    # Training mode keeps dropout stochastic across the T passes.
    # NOTE(review): this also switches any other train/eval-dependent layers
    # (e.g. batch norm) into training mode while sampling.
    model.train()
    try:
        with torch.no_grad():
            probs_T = torch.stack(
                [F.softmax(model(x), dim=1) for _ in range(T)],
                dim=0,
            )
    finally:
        model.train(was_training)

    p_mean = probs_T.mean(dim=0)
    eps = 1e-8  # avoids log(0) for classes with zero mean probability
    entropy = -torch.sum(p_mean * torch.log(p_mean + eps), dim=1)

    return entropy


In [12]:
def get_mcd_ood_scores(model, id_loader, ood_loader, T=20):
    """Score ID and OOD samples by MC Dropout predictive entropy.

    Each batch is moved to the global ``device`` and scored with
    ``mc_dropout_entropy``; larger entropy means higher uncertainty.

    Args:
        model: Classifier passed through to ``mc_dropout_entropy``.
        id_loader: Loader yielding ``(inputs, labels)`` for ID data.
        ood_loader: Loader yielding ``(inputs, labels)`` for OOD data.
        T: Number of stochastic forward passes per batch.

    Returns:
        Tuple ``(id_scores, ood_scores)`` of 1-D NumPy arrays.
    """

    def _entropy_scores(loader):
        return np.concatenate([
            mc_dropout_entropy(model, images.to(device), T=T).cpu().numpy()
            for images, _ in loader
        ])

    id_scores = _entropy_scores(id_loader)
    ood_scores = _entropy_scores(ood_loader)
    return id_scores, ood_scores


In [13]:
# MC Dropout: score ID/OOD test data by the entropy of the prediction
# averaged over mc_samples stochastic forward passes.
mcd_id_scores, mcd_ood_scores = get_mcd_ood_scores(
    model,
    test_id_loader,
    test_ood_loader,
    T=mc_samples
)

# OOD detection metrics for the MC Dropout scores.
mcd_auroc, mcd_aupr, mcd_fpr95 = comp_ood_metrics(
    mcd_id_scores,
    mcd_ood_scores
)

print("Monte Carlo Dropout:")
print(f"AUROC: {mcd_auroc:.4f}")
print(f"AUPR:  {mcd_aupr:.4f}")
print(f"FPR@95%TPR: {mcd_fpr95:.4f}")


Monte Carlo Dropout:
AUROC: 0.6740
AUPR:  0.5900
FPR@95%TPR: 0.7279


In [14]:
def collect_penultimate_activations(model, loader):
    """Collect the ``fc1`` outputs of ``model`` for every sample in ``loader``.

    The model is moved to the global ``device`` and switched to eval mode.
    The convolutional stages are replayed without dropout and the raw output
    of ``model.fc1`` (before the ReLU) is gathered on the CPU.

    Args:
        model: Network exposing ``conv1``, ``conv2``, ``conv3``, ``pool`` and
            ``fc1``.
        loader: Loader yielding ``(inputs, labels)`` batches.

    Returns:
        CPU tensor with one row of ``fc1`` outputs per sample.
    """
    model.to(device)
    model.eval()

    conv_stages = (model.conv1, model.conv2, model.conv3)
    collected = []

    with torch.no_grad():
        for images, _ in loader:
            feats = images.to(device)
            for conv in conv_stages:
                feats = model.pool(F.relu(conv(feats)))

            flat = feats.view(feats.size(0), -1)
            collected.append(model.fc1(flat).cpu())

    return torch.cat(collected, dim=0)


In [15]:
# Calibrate the ReAct clipping threshold on ID training data: gather the fc1
# outputs for the whole training split.
train_activations = collect_penultimate_activations(model, train_id_loader)

# A single scalar threshold: the 0.99 quantile taken over all collected
# values (torch.quantile flattens the input when no dim is given).
# NOTE(review): torch.quantile limits the size of its input; confirm this
# still works if the dataset or the hidden layer grows.
quantile = 0.99
react_threshold = torch.quantile(train_activations, quantile)

# Store the threshold on the model so forward(..., react=True) can clip with it.
model.react_threshold = react_threshold.to(device)

print(f"ReAct threshold (q={quantile}): {model.react_threshold.item():.4f}")


ReAct threshold (q=0.99): 3.1420


In [16]:
def get_react_ood_scores(model, id_loader, ood_loader):
    """Score ID and OOD samples with ReAct-clipped softmax confidence.

    The model is moved to the global ``device``, switched to eval mode and
    called with ``react=True``. The score of a sample is
    ``1 - max softmax probability`` of the resulting logits.

    Args:
        model: Classifier whose ``forward`` accepts ``react=True``.
        id_loader: Loader yielding ``(inputs, labels)`` for ID data.
        ood_loader: Loader yielding ``(inputs, labels)`` for OOD data.

    Returns:
        Tuple ``(id_scores, ood_scores)`` of 1-D NumPy arrays.
    """
    model.to(device)
    model.eval()

    def _batch_scores(images):
        clipped_logits = model(images.to(device), react=True)
        top_prob = F.softmax(clipped_logits, dim=1).max(dim=1).values
        return (1.0 - top_prob).cpu().numpy()

    with torch.no_grad():
        id_chunks = [_batch_scores(images) for images, _ in id_loader]
        ood_chunks = [_batch_scores(images) for images, _ in ood_loader]

    return np.concatenate(id_chunks), np.concatenate(ood_chunks)


In [17]:
# ReAct: score ID/OOD test data with 1 - max softmax probability computed
# from logits obtained with the hidden layer clipped at the stored threshold.
react_id_scores, react_ood_scores = get_react_ood_scores(
    model,
    test_id_loader,
    test_ood_loader
)

# OOD detection metrics for the ReAct scores.
react_auroc, react_aupr, react_fpr95 = comp_ood_metrics(
    react_id_scores,
    react_ood_scores
)

print("ReAct:")
print(f"AUROC: {react_auroc:.4f}")
print(f"AUPR:  {react_aupr:.4f}")
print(f"FPR@95%TPR: {react_fpr95:.4f}")


ReAct:
AUROC: 0.6055
AUPR:  0.5468
FPR@95%TPR: 0.8033


## выводы

- softmax confidence показывает слабое качество ood-детекции из-за переоценки уверенности на данных вне обучающего распределения  
- monte carlo dropout улучшает ood-детекцию за счёт учёта предсказательной неопределённости
- react даёт лишь небольшой прирост относительно softmax (AUROC 0.6055 против 0.5768) и заметно уступает monte carlo dropout (AUROC 0.6740), но не увеличивает вычислительную сложность на этапе inference  
- эффект react связан с подавлением аномально больших внутренних активаций, приводящих к ложной уверенности модели  

