In [None]:
import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, roc_curve

# ============================================================
# 1. Setup device and model
# ============================================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load the ImageNet-pretrained ResNet-50.
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights=`
# enum; IMAGENET1K_V1 is the checkpoint that `pretrained=True` selects.
# Older torchvision has no weights enum, so fall back to the legacy flag there.
if hasattr(torchvision.models, "ResNet50_Weights"):
    model = torchvision.models.resnet50(
        weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
    )
else:
    model = torchvision.models.resnet50(pretrained=True)
model = model.to(device)
model.eval()

# Penultimate layer for ReAct: in torchvision's ResNet the global-average-pool
# output is flattened and fed to `model.fc`; ReAct clips these activations.
feature_layer = model.avgpool   # <-- this is the layer to clip

# ============================================================
# 2. Define transforms (same normalization as ImageNet)
# ============================================================
# Standard ImageNet evaluation preprocessing: resize the short side to 256,
# take the central 224x224 crop, convert to a float tensor, then normalize
# each channel with the ImageNet mean / std statistics.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
    ]
)

# ============================================================
# 3. Utility functions for metrics
# ============================================================
def compute_msp_scores(logits):
    """Return the maximum softmax probability (MSP) of each sample as NumPy.

    Softmax is taken over the last dimension of ``logits``; the largest class
    probability along that dimension is moved to the CPU and returned.
    """
    max_prob, _argmax = F.softmax(logits, dim=-1).max(dim=-1)
    return max_prob.cpu().numpy()

def compute_auroc(id_scores, ood_scores):
    """Return the ROC AUC, in percent, for separating ID from OOD scores.

    ID samples form the positive class (label 1) and OOD samples the negative
    class (label 0), so higher scores are expected to mean "more ID".
    """
    y_true = np.concatenate((np.ones(len(id_scores)), np.zeros(len(ood_scores))))
    y_score = np.concatenate((id_scores, ood_scores))
    auc = roc_auc_score(y_true, y_score)
    return auc * 100

def compute_fpr95(id_scores, ood_scores, tpr_level=0.95):
    """Return the FPR (in percent) at the first ROC point with TPR >= ``tpr_level``.

    ID samples are the positive class (label 1) and OOD samples the negative
    class (label 0), so higher scores should mean "more in-distribution".

    Args:
        id_scores: 1-D array-like of scores for in-distribution samples.
        ood_scores: 1-D array-like of scores for out-of-distribution samples.
        tpr_level: target true-positive rate; the default 0.95 gives the
            usual FPR@95%TPR.

    Returns:
        False-positive rate at the lowest-FPR ROC point whose TPR reaches
        ``tpr_level``, multiplied by 100.
    """
    scores = np.concatenate([id_scores, ood_scores])
    labels = np.concatenate([np.ones(len(id_scores)), np.zeros(len(ood_scores))])
    # Keep every threshold: the default drop_intermediate=True removes
    # collinear ROC points, which can drop the exact point where TPR first
    # reaches the target and thereby overstate the FPR.
    fpr, tpr, _ = roc_curve(labels, scores, drop_intermediate=False)
    # argmax on a boolean mask yields the first True index; tpr and fpr are
    # non-decreasing along the curve, so this is the lowest-FPR qualifying point.
    idx = np.argmax(tpr >= tpr_level)
    return fpr[idx] * 100

# ============================================================
# 4. Estimate ReAct clipping threshold τ
# ============================================================
def estimate_tau(model, loader, feature_layer, percentile=90, max_batches=10,
                 device=None):
    """Estimate the ReAct clipping threshold τ from ``feature_layer`` activations.

    Runs ``model`` over the first ``max_batches`` batches of ``loader``,
    records every element of ``feature_layer``'s forward output, and returns
    the requested percentile of their absolute values.

    Args:
        model: network to run; it is put into eval mode.
        loader: iterable of ``(inputs, labels)`` batches; labels are ignored.
        feature_layer: submodule whose forward output is sampled.
        percentile: percentile in [0, 100] of |activation| used as τ.
        max_batches: number of batches to consume (exactly this many, or
            fewer if the loader is shorter).
        device: device inputs are moved to; defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        τ as a Python float.

    Raises:
        ValueError: if no activations were collected (empty loader or
            ``max_batches <= 0``).
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device('cpu'))

    activations = []

    def hook_fn(_module, _inputs, output):
        # Move to CPU so sampling many batches does not pile up GPU memory.
        activations.append(output.detach().flatten().cpu())

    handle = feature_layer.register_forward_hook(hook_fn)
    model.eval()
    try:
        with torch.no_grad():
            for i, (x, _) in enumerate(loader):
                # Checked before the forward pass so exactly `max_batches`
                # batches are sampled.
                if i >= max_batches:
                    break
                model(x.to(device))
    finally:
        # Always detach the hook, even if a forward pass raises.
        handle.remove()

    if not activations:
        raise ValueError(
            "estimate_tau: no activations collected; check that the loader "
            "is non-empty and max_batches > 0"
        )

    acts = torch.cat(activations).abs()
    # np.percentile uses linear interpolation (same as torch.quantile's
    # default) but has no input-size cap; torch.quantile rejects very large
    # tensors, which a larger max_batches would easily produce.
    tau = float(np.percentile(acts.numpy(), percentile))
    print(f"Estimated τ (ReAct clip) = {tau:.4f}")
    return tau

# ============================================================
# 5. ReAct forward + feature clipping
# ============================================================
def react_forward(model, x, feature_layer, tau):
    """Compute ReAct logits for one batch.

    Runs ``model`` on ``x`` while a forward hook captures the output of
    ``feature_layer``. Those activations are clipped from above at ``tau``,
    flattened to ``(batch, features)``, and passed through ``model.fc``. The
    logits of the model's own (unclipped) forward pass are discarded.

    Args:
        model: network exposing an ``fc`` classifier head (e.g. torchvision ResNet).
        x: input batch, already on the model's device.
        feature_layer: submodule whose output, once flattened, is the input
            expected by ``model.fc``.
        tau: upper clipping threshold applied to the captured activations.

    Returns:
        Logits tensor of shape ``(batch, num_classes)``, computed without
        autograd tracking.
    """
    feats = []

    def hook_fn(_module, _inputs, output):
        feats.append(output.clone())

    handle = feature_layer.register_forward_hook(hook_fn)
    # Cover the classifier head with no_grad too, so a direct call does not
    # build an autograd graph through model.fc.
    with torch.no_grad():
        try:
            model(x)
        finally:
            # Always detach the hook, even if the forward pass raises.
            handle.remove()

        features = torch.clamp(feats[0], max=tau)    # ReAct clipping
        # flatten(..., 1) keeps the batch dimension; squeeze() would also
        # drop it for a single-sample batch and return 1-D logits.
        features = torch.flatten(features, 1)
        logits = model.fc(features)                  # classifier head
    return logits

def extract_logits_react(model, loader, feature_layer, tau, device):
    """Run ReAct inference over every batch of ``loader`` and stack the logits.

    Each batch's images are moved to ``device`` and passed to
    ``react_forward``; labels are ignored. The per-batch logits are moved
    to the CPU and concatenated along dim 0 in loader order.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        progress = tqdm(loader, desc="Extracting ReAct logits")
        for batch, _labels in progress:
            batch_logits = react_forward(model, batch.to(device), feature_layer, tau)
            chunks.append(batch_logits.cpu())
    return torch.cat(chunks, dim=0)

# ============================================================
# 6. Example usage
# ============================================================
# End-to-end evaluation: estimate τ, extract ReAct logits for the ID and OOD
# sets, score them with MSP, and report AUROC / FPR95 (both in percent).
# Replace these with your dataset paths
# NOTE(review): torchvision's ImageFolder expects one sub-directory per class
# under each root — confirm both paths are organized that way.
id_path  = r"E:\datasets\ImageNet\ILSVRC2012_img_val"
ood_path = r"E:\datasets\Textures"

id_dataset  = datasets.ImageFolder(id_path,  transform=transform)
ood_dataset = datasets.ImageFolder(ood_path, transform=transform)
# shuffle=False keeps the extraction order fixed; the dataset labels are
# discarded downstream (only images are used).
id_loader  = DataLoader(id_dataset,  batch_size=64, shuffle=False, num_workers=4)
ood_loader = DataLoader(ood_dataset, batch_size=64, shuffle=False, num_workers=4)

# 1) Estimate τ on a few ID batches
# NOTE(review): τ is fit on the first batches of the same ID loader that is
# evaluated below, so the threshold partly sees the test data; consider a
# separate ID split (e.g. training images) for estimation.
tau = estimate_tau(model, id_loader, feature_layer, percentile=90)

# 2) Extract logits with ReAct
id_logits  = extract_logits_react(model, id_loader,  feature_layer, tau, device)
ood_logits = extract_logits_react(model, ood_loader, feature_layer, tau, device)

# 3) Compute MSP + metrics
# The metric helpers treat ID samples as the positive class, so a higher MSP
# is taken to mean "more in-distribution".
id_scores  = compute_msp_scores(id_logits)
ood_scores = compute_msp_scores(ood_logits)

auroc = compute_auroc(id_scores, ood_scores)
fpr95 = compute_fpr95(id_scores, ood_scores)

print("\n================ ReAct + MSP Results ================")
print(f"AUROC  : {auroc:.2f}%")
print(f"FPR95  : {fpr95:.2f}%")
print("====================================================\n")
