### Extract Logits for Use in MSP

In [6]:
import torch
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import os
from PIL import Image
from pathlib import Path

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f"Using device: {device}")

# Per-channel statistics handed to Normalize
# (presumably the ImageNet channel mean/std -- confirm against the model used).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image: resize to 256, centre-crop to 224,
# convert to a tensor, then normalise each channel.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)

def extract_logits(model, dataloader, device):
    """Run ``model`` over every batch of ``dataloader`` and collect its outputs.

    Args:
        model: Callable applied to each image batch.
        dataloader: Iterable yielding ``(images, labels)`` pairs; the labels
            are ignored.
        device: Device the image batch is moved to before the forward pass.

    Returns:
        A single CPU tensor made by concatenating the per-batch outputs
        along dimension 0, in dataloader order.
    """
    # Gradients are not needed for pure inference.
    with torch.no_grad():
        batches = tqdm(dataloader, desc="Extracting logits")
        # Each output is moved to the CPU right away so results do not pile
        # up on the compute device.
        per_batch = [model(images.to(device)).cpu() for images, _ in batches]

    return torch.cat(per_batch, dim=0)


class FlatImageDataset(Dataset):
    """Dataset that yields every image found anywhere below a directory.

    No class-subdirectory layout is required: the tree under ``root_dir`` is
    walked recursively and each matching file becomes one sample.

    Args:
        root_dir: Directory to search (str or path-like).
        transform: Optional callable applied to each loaded RGB PIL image.
        extensions: File-name endings that count as images. Matching is
            case-insensitive, so ``'.jpg'`` also accepts ``'.JPG'`` and
            ``'.Jpg'``.
    """

    def __init__(self, root_dir, transform=None, extensions=('.jpg', '.jpeg', '.png', '.JPEG', '.JPG', '.PNG')):
        self.root_dir = Path(root_dir)
        self.transform = transform

        # Walk the tree once and compare lower-cased names. Globbing once per
        # extension spelling returns every file twice on case-insensitive
        # filesystems (e.g. Windows) and misses mixed-case endings on
        # case-sensitive ones.
        suffixes = tuple({ext.lower() for ext in extensions})
        self.image_paths = sorted(  # sorted for reproducibility
            path
            for path in self.root_dir.rglob('*')
            # is_file() skips directories that merely look like images.
            if path.is_file() and path.name.lower().endswith(suffixes)
        )
        print(f"Found {len(self.image_paths)} images")

    def __len__(self):
        """Return the number of image files discovered at construction."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, 0)`` for the file at position ``idx``.

        A file that cannot be opened or decoded is reported on stdout and
        replaced by a black 224x224 RGB image, so one corrupt file does not
        abort a long extraction run. Such placeholders still produce logits,
        so check the log if any error is printed.
        """
        img_path = self.image_paths[idx]
        try:
            # The context manager closes the file handle promptly;
            # convert() returns an independent, fully loaded image.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a blank image if loading fails
            image = Image.new('RGB', (224, 224), color='black')

        if self.transform:
            image = self.transform(image)

        # Return dummy label (0) since we only care about logits for OOD detection
        return image, 0


def extract_and_save_logits(dataset_path, 
                            output_path, 
                            batch_size=128, 
                            num_workers=4, 
                            flat_structure=False):
    """
    Extract ResNet-50 logits for every image in a dataset and save them.

    Args:
        dataset_path: Path to the dataset directory
        output_path: Path where to save the logits .pt file. Missing parent
                     directories are created before extraction starts.
        batch_size: Batch size for processing
        num_workers: Number of workers for data loading
        flat_structure: If False (default), expects ImageFolder structure with class subdirectories.
                       If True, loads all images from directory recursively (flat structure).

    Returns:
        The logits tensor that was written to ``output_path``.
    """
    # Load pre-trained ResNet-50.
    # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of weight
    # enums; IMAGENET1K_V1 is the checkpoint ``pretrained=True`` refers to.
    # Older releases have no enum, so fall back to the legacy keyword.
    weights_enum = getattr(torchvision.models, 'ResNet50_Weights', None)
    if weights_enum is not None:
        model = torchvision.models.resnet50(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = torchvision.models.resnet50(pretrained=True)
    model = model.to(device)
    model.eval()

    # Make sure the destination directory exists *before* the forward pass,
    # so a missing directory cannot discard a finished extraction.
    # (File-like objects are left for torch.save to handle.)
    if isinstance(output_path, (str, os.PathLike)):
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Load dataset
    print(f"Loading dataset from {dataset_path}...")
    print(f"Using {'flat' if flat_structure else 'structured (ImageFolder)'} dataset loading")

    if flat_structure:
        dataset = FlatImageDataset(dataset_path, transform=transform)
    else:
        dataset = ImageFolder(dataset_path, transform=transform)

    print(f"Dataset size: {len(dataset)}")

    # Create dataloader; shuffle=False keeps logits in dataset order.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Extract logits
    print("\nExtracting logits...")
    logits = extract_logits(model, dataloader, device)

    # Save logits
    print(f"\nSaving logits to {output_path}...")
    torch.save(logits, output_path)
    print(f"Saved logits: {logits.shape}")

    return logits


# Example usage:
if __name__ == "__main__":
    # Example 1: Structured dataset (ImageFolder format with class subdirectories)
    # extract_and_save_logits(
    #     dataset_path='/path/to/imagenet/val',
    #     output_path='imagenet_val_logits.pt',
    #     flat_structure=False  # Uses ImageFolder
    # )

    # Example 2: the run that is actually executed. The Textures images are
    # read through ImageFolder here (flat_structure=False); pass
    # flat_structure=True instead for a directory that has no class
    # subfolders (all images in one directory, or nested arbitrarily).
    DATASET_PATH = r'E:\datasets\KNN-OOD\Textures\images'
    OUTPUT_PATH = 'textures_logits.pt'

    # num_workers=0 keeps data loading in the main process.
    extract_and_save_logits(
        DATASET_PATH,
        OUTPUT_PATH,
        batch_size=128,
        num_workers=0,
        flat_structure=False,
    )

    print("\nDone!")

Using device: cuda
Loading dataset from E:\datasets\KNN-OOD\Textures\images...
Using structured (ImageFolder) dataset loading
Dataset size: 5640

Extracting logits...


Extracting logits: 100%|██████████| 45/45 [00:23<00:00,  1.88it/s]


Saving logits to textures_logits.pt...
Saved logits: torch.Size([5640, 1000])

Done!





In [7]:
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, roc_curve


def compute_msp_scores(logits):
    """Compute Maximum Softmax Probability (MSP) scores.

    Args:
        logits: NumPy array, torch tensor (on any device, with or without
            gradient tracking) or nested sequence, with classes on the last
            axis.

    Returns:
        float32 NumPy array with the last axis removed; each entry is the
        largest softmax probability for that sample.
    """
    # as_tensor accepts arrays, tensors and lists alike. detach()/cpu() make
    # the final .numpy() safe for tensors that track gradients or live on a
    # GPU, both of which would otherwise raise.
    logits_tensor = torch.as_tensor(logits).detach().float().cpu()

    probs = F.softmax(logits_tensor, dim=-1)
    msp_scores = probs.max(dim=-1).values.numpy()
    return msp_scores


def compute_auroc(id_scores, ood_scores):
    """Return the AUROC, in percent, for separating ID from OOD scores.

    ID samples are the positive class (label 1) and OOD samples the negative
    class (label 0), so higher scores should indicate ID. Higher is better.
    """
    all_scores = np.concatenate((id_scores, ood_scores))

    # Label vector laid out like all_scores: ones for the ID part, zeros after.
    n_id, n_ood = len(id_scores), len(ood_scores)
    is_id = np.zeros(n_id + n_ood)
    is_id[:n_id] = 1

    return roc_auc_score(is_id, all_scores) * 100


def compute_fpr95(id_scores, ood_scores):
    """Compute FPR when TPR=95% (lower is better).

    ID samples are the positive class, so the result is the percentage of
    OOD samples still accepted as ID at the first threshold that keeps at
    least 95% of the ID samples.

    Args:
        id_scores: 1-D scores for in-distribution samples (higher = more ID).
        ood_scores: 1-D scores for out-of-distribution samples.

    Returns:
        The false-positive rate at that operating point, in percent.
    """
    scores = np.concatenate([id_scores, ood_scores])
    labels = np.concatenate([np.ones(len(id_scores)), np.zeros(len(ood_scores))])

    # Keep every threshold. The default drop_intermediate=True thins the
    # curve for plotting and, when scores tie, can remove the exact first
    # point with TPR >= 0.95, so the next kept point would be read instead.
    fpr, tpr, _ = roc_curve(labels, scores, pos_label=1, drop_intermediate=False)
    # argmax over a boolean array gives the index of the first True.
    idx = np.argmax(tpr >= 0.95)
    return fpr[idx] * 100


def load_logits(path):
    """
    Load logits from either .pt or .npy file.
    Handles both your format and KNN-OOD format.

    Args:
        path: str or path-like ending in ``.pt`` or ``.npy``.

    Returns:
        For ``.pt``: the saved object, converted to a NumPy array when it is
        a tensor. For ``.npy``: a float32 NumPy array, transposed when it has
        fewer rows than columns.

    Raises:
        ValueError: If the file extension is neither ``.pt`` nor ``.npy``.

    SECURITY NOTE: both loaders can unpickle arbitrary objects; only open
    files from sources you trust.
    """
    # Accept pathlib.Path as well as str.
    path_str = str(path)

    if path_str.endswith('.pt'):
        # Your format: torch.save(logits, path)
        logits = torch.load(path_str, map_location='cpu')
        if isinstance(logits, torch.Tensor):
            # detach() so tensors saved with requires_grad=True convert too.
            logits = logits.detach().cpu().numpy()
    elif path_str.endswith('.npy'):
        # KNN-OOD format: (feat_log, score_log, label_log)
        data = np.load(path_str, allow_pickle=True)
        # Only a container of arrays (object dtype, or a stacked >2-D array)
        # is unpacked. A plain 2-D logits matrix that happens to have 2 or 3
        # rows must not be mistaken for a bundle.
        is_bundle = data.dtype == object or data.ndim > 2
        if is_bundle and len(data) == 3:
            _, score_log, _ = data
        elif is_bundle and len(data) == 2:
            _, score_log = data
        else:
            # Assume it's just raw logits
            score_log = data
        # NOTE(review): heuristic -- assumes there are more samples than
        # classes, so an array with fewer rows than columns is taken to be
        # (classes, samples) and flipped. Confirm for small datasets.
        logits = score_log.T if score_log.shape[0] < score_log.shape[1] else score_log
        logits = logits.astype(np.float32)
    else:
        raise ValueError(f"Unsupported file format: {path}")

    return logits


def evaluate_msp_ood(id_data_path, ood_data_dict):
    """Score one ID logit file against several OOD logit files using MSP.

    Args:
        id_data_path: Logit file for the in-distribution set, readable by
            load_logits.
        ood_data_dict: Mapping of OOD dataset name -> logit file path.

    Returns:
        Dict keyed by OOD dataset name with 'fpr95', 'auroc' and
        'num_samples' entries, plus an 'average' key holding the mean
        'fpr95' and 'auroc' over all OOD datasets. The same numbers are
        printed as they are computed and again as a summary table.
    """

    def _load_oriented(path):
        # More than 1000 columns is taken to mean the array arrived as
        # (num_classes, num_samples); flip it to (num_samples, num_classes).
        arr = load_logits(path)
        return arr.T if arr.shape[1] > 1000 else arr

    # --- in-distribution reference scores ---------------------------------
    print(f"Loading ID data from: {id_data_path}")
    id_logits = _load_oriented(id_data_path)
    id_msp_scores = compute_msp_scores(id_logits)

    n_id, n_classes = id_logits.shape[0], id_logits.shape[1]
    print(f"ID data: {n_id} samples, {n_classes} classes")
    lo, hi, avg = id_msp_scores.min(), id_msp_scores.max(), id_msp_scores.mean()
    print(f"ID MSP: min={lo:.4f}, max={hi:.4f}, mean={avg:.4f}\n")

    # --- one pass per OOD dataset -----------------------------------------
    results = {}
    for ood_name, ood_path in ood_data_dict.items():
        print(f"Evaluating on {ood_name}...")

        ood_logits = _load_oriented(ood_path)
        ood_msp_scores = compute_msp_scores(ood_logits)

        auroc = compute_auroc(id_msp_scores, ood_msp_scores)
        fpr95 = compute_fpr95(id_msp_scores, ood_msp_scores)

        results[ood_name] = {
            'fpr95': fpr95,
            'auroc': auroc,
            'num_samples': len(ood_logits),
        }
        print(f"  {ood_name}: FPR95={fpr95:.2f}%, AUROC={auroc:.2f}%")

    # --- averages over the OOD datasets -----------------------------------
    per_dataset = list(results.values())
    avg_fpr95 = np.mean([m['fpr95'] for m in per_dataset])
    avg_auroc = np.mean([m['auroc'] for m in per_dataset])
    results['average'] = {'fpr95': avg_fpr95, 'auroc': avg_auroc}

    heavy_60, heavy_70, light_70 = '=' * 60, '=' * 70, '-' * 70

    print(f"\n{heavy_60}")
    print(f"Average: FPR95={avg_fpr95:.2f}%, AUROC={avg_auroc:.2f}%")
    print(heavy_60)

    # --- summary table ----------------------------------------------------
    print(f"\n{heavy_70}")
    print(f"{'OOD Dataset':<20} {'FPR95 ↓':>12} {'AUROC ↑':>12} {'Samples':>12}")
    print(light_70)
    for ood_name, metrics in results.items():
        # The 'average' entry gets its own row below the separator.
        if ood_name != 'average':
            print(f"{ood_name:<20} {metrics['fpr95']:>11.2f}% {metrics['auroc']:>11.2f}% {metrics['num_samples']:>12}")
    print(light_70)
    print(f"{'Average':<20} {avg_fpr95:>11.2f}% {avg_auroc:>11.2f}% {'-':>12}")
    print(f"{heavy_70}\n")

    return results


# ============================================================================
# USAGE: Replace with your actual .pt file paths
# ============================================================================

# Logit file of the in-distribution (ID) dataset, e.g. the ImageNet validation set.
id_data_path = "imagenet1k_logits.pt"

# OOD dataset name -> logit file, in the order they should be reported.
_OOD_LOGIT_FILES = (
    ("iNaturalist", "selected_inaturalist21_logits.pt"),
    ("SUN", "sun_logits.pt"),
    ("Places", "places_logits.pt"),
    ("Textures", "textures_logits.pt"),
)
ood_data_dict = dict(_OOD_LOGIT_FILES)

# Run the evaluation; this prints per-dataset metrics and a summary table.
results = evaluate_msp_ood(id_data_path, ood_data_dict)

# The returned dict is keyed by dataset name, plus an 'average' entry.
print("Access results:")
for ood_name in ood_data_dict:
    print(f"{ood_name} FPR95: {results[ood_name]['fpr95']:.2f}%")
print(f"Average AUROC: {results['average']['auroc']:.2f}%")

Loading ID data from: imagenet1k_logits.pt
ID data: 50000 samples, 1000 classes
ID MSP: min=0.0191, max=1.0000, mean=0.7970

Evaluating on iNaturalist...
  iNaturalist: FPR95=52.82%, AUROC=88.39%
Evaluating on SUN...
  SUN: FPR95=69.11%, AUROC=81.64%
Evaluating on Places...
  Places: FPR95=72.07%, AUROC=80.53%
Evaluating on Textures...
  Textures: FPR95=66.26%, AUROC=80.43%

Average: FPR95=65.06%, AUROC=82.75%

OOD Dataset               FPR95 ↓      AUROC ↑      Samples
----------------------------------------------------------------------
iNaturalist                52.82%       88.39%        20000
SUN                        69.11%       81.64%        20000
Places                     72.07%       80.53%        20000
Textures                   66.26%       80.43%         5640
----------------------------------------------------------------------
Average                    65.06%       82.75%            -

Access results:
iNaturalist FPR95: 52.82%
SUN FPR95: 69.11%
Places FPR95: 72.07%
T