In [3]:
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, roc_curve


def get_energy_score(logits, temperature=1.0):
    """
    Compute Energy scores (negative free energy) from logits.

    The score is ``T * logsumexp(logits / T, dim=1)``, i.e. the negative of the
    energy ``E(x; T) = -T * logsumexp(f(x) / T)`` from Liu et al. (2020).
    At ``temperature=1.0`` this reduces to ``logsumexp(logits, dim=1)``, which
    is what the KNN-OOD paper implementation uses.

    Args:
        logits: torch.Tensor, numpy array, or array-like of shape
            (N, num_classes).
        temperature: float, temperature scaling parameter (default=1.0).
            Must be positive.

    Returns:
        scores: numpy array (N,) - higher = more ID-like

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # as_tensor shares memory with numpy input (no copy) and also accepts
    # plain array-likes such as nested lists; detach/cpu so .numpy() works.
    logits = torch.as_tensor(logits).detach().cpu()

    # logsumexp needs a floating dtype; integer and half-precision inputs are
    # upcast to float32 (half-precision CPU kernels may be unavailable in
    # older torch versions). float32/float64 are kept as-is.
    if logits.dtype not in (torch.float32, torch.float64):
        logits = logits.float()

    if temperature != 1.0:
        # Scaling back by T keeps the value equal to -E(x; T). It is a positive
        # constant factor, so rankings (and thus AUROC / FPR95) are unaffected.
        return (temperature * torch.logsumexp(logits / temperature, dim=1)).numpy()

    # Positive logsumexp: higher = more ID-like.
    return torch.logsumexp(logits, dim=1).numpy()


def compute_auroc(id_scores, ood_scores):
    """Return the AUROC as a percentage (higher is better).

    ID samples form the positive class (label 1) and OOD samples the negative
    class (label 0), so the scores are expected to be higher for ID data.
    """
    n_id = len(id_scores)
    y_true = np.zeros(n_id + len(ood_scores))
    y_true[:n_id] = 1  # first n_id entries of y_score are the ID samples
    y_score = np.concatenate((id_scores, ood_scores))
    auc = roc_auc_score(y_true, y_score)
    return auc * 100


def compute_fpr95(id_scores, ood_scores, target_tpr=0.95):
    """
    Compute the FPR (in %) at the first ROC point where TPR >= ``target_tpr``.

    ID samples are the positive class, so TPR is the fraction of ID samples
    kept and FPR the fraction of OOD samples wrongly kept. Lower is better.

    Args:
        id_scores: 1-D array of scores for ID samples (higher = more ID-like).
        ood_scores: 1-D array of scores for OOD samples.
        target_tpr: TPR level to read the FPR at (default 0.95 -> FPR95).

    Returns:
        FPR at the requested TPR level, as a percentage.

    Raises:
        ValueError: if either score array is empty (the ROC curve is undefined;
            an empty ID set would otherwise silently report 0% FPR).
    """
    if len(id_scores) == 0 or len(ood_scores) == 0:
        raise ValueError("compute_fpr95 needs at least one ID and one OOD score")

    scores = np.concatenate([id_scores, ood_scores])
    labels = np.concatenate([np.ones(len(id_scores)), np.zeros(len(ood_scores))])

    # drop_intermediate=False keeps every threshold. With the default (True),
    # sklearn prunes collinear ROC points; with tied scores that can remove the
    # exact point where TPR first reaches the target and overstate the FPR.
    fpr, tpr, _ = roc_curve(labels, scores, pos_label=1, drop_intermediate=False)

    # tpr is non-decreasing, so argmax on the mask gives the first (lowest-FPR)
    # operating point that reaches the target TPR.
    idx = np.argmax(tpr >= target_tpr)
    return fpr[idx] * 100


def load_logits(path, num_classes=1000):
    """
    Load a logits matrix from a ``.pt`` or ``.npy`` file.

    Supported formats:
      * ``.pt`` holding a tensor, or a dict with a ``'logits'`` entry.
      * ``.npy`` holding a numeric array.

    If the loaded matrix looks transposed (``num_classes`` rows and more than
    ``num_classes`` columns) it is transposed to (num_samples, num_classes).

    Args:
        path: str or os.PathLike pointing to the file.
        num_classes: number of classes, used only by the transpose heuristic
            (default 1000, matching the original hard-coded value).

    Returns:
        torch.Tensor of shape (num_samples, num_classes).

    Raises:
        ValueError: unsupported extension, unexpected file contents, or data
            that is not 2-D.
    """
    import os

    path_str = os.fspath(path)
    print(f"  Loading from: {path_str}")
    lowered = path_str.lower()

    if lowered.endswith('.pt'):
        # NOTE(security): torch.load unpickles the file; only load trusted files.
        data = torch.load(path_str, map_location='cpu')

        if isinstance(data, dict):
            if 'logits' in data:
                logits = data['logits']
            elif 'activations' in data:
                raise ValueError("File contains 'activations' (features), not 'logits'.")
            else:
                raise ValueError(f"Dictionary doesn't contain 'logits'. Keys: {data.keys()}")
        elif isinstance(data, torch.Tensor):
            logits = data
        else:
            raise ValueError(f"Unexpected data type: {type(data)}")

    elif lowered.endswith('.npy'):
        logits = torch.from_numpy(np.load(path_str))
    else:
        raise ValueError(f"Unsupported file format: {path}")

    # A dict entry may hold a numpy array; normalise to a tensor so every
    # branch returns the same type.
    logits = torch.as_tensor(logits)

    print(f"  Loaded logits shape: {logits.shape}")

    if logits.ndim != 2:
        raise ValueError(
            f"Expected 2-D logits (num_samples, num_classes), got shape {tuple(logits.shape)}"
        )

    # Heuristic: a (num_classes, N) matrix with N > num_classes is assumed to
    # be stored class-major and is flipped to (num_samples, num_classes).
    if logits.shape[0] == num_classes and logits.shape[1] > num_classes:
        print(f"  Transposing from {logits.shape}")
        logits = logits.T

    return logits


def _print_energy_summary(results):
    """Print the results table for the dict built by ``evaluate_energy_ood``.

    Expects one entry per OOD dataset (keys 'fpr95', 'auroc', 'num_samples')
    plus an 'average' entry (keys 'fpr95', 'auroc').
    """
    average = results['average']
    print(f"\n{'='*70}")
    print("Energy Score OOD Detection Results")
    print(f"{'='*70}")
    print(f"{'OOD Dataset':<20} {'FPR95 ↓':>12} {'AUROC ↑':>12} {'Samples':>12}")
    print(f"{'-'*70}")

    for ood_name, metrics in results.items():
        if ood_name == 'average':
            continue
        print(f"{ood_name:<20} {metrics['fpr95']:>11.2f}% {metrics['auroc']:>11.2f}% {metrics['num_samples']:>12}")

    print(f"{'-'*70}")
    print(f"{'Average':<20} {average['fpr95']:>11.2f}% {average['auroc']:>11.2f}% {'-':>12}")
    print(f"{'='*70}\n")


def evaluate_energy_ood(id_logits_path, ood_logits_dict, temperature=1.0):
    """
    Evaluate Energy score OOD detection.

    Loads the ID logits once, scores every OOD dataset against them, prints
    per-dataset progress and a summary table, and returns the metrics.

    Args:
        id_logits_path: Path to ID logits file
        ood_logits_dict: Dictionary {ood_name: ood_logits_path}; must be
            non-empty and must not use the reserved name 'average'.
        temperature: Temperature scaling parameter

    Returns:
        Dictionary mapping each OOD name to {'fpr95', 'auroc', 'num_samples'},
        plus an 'average' entry with the mean 'fpr95' and 'auroc'.

    Raises:
        ValueError: if ``ood_logits_dict`` is empty (averages would be NaN) or
            contains the key 'average' (it would be overwritten by the summary).
    """
    # Validate before loading the (potentially large) ID logits.
    if not ood_logits_dict:
        raise ValueError("ood_logits_dict must contain at least one OOD dataset")
    if 'average' in ood_logits_dict:
        raise ValueError("'average' is reserved for the summary entry; rename that OOD dataset")

    print("="*70)
    print("Energy Score OOD Detection")
    print(f"Temperature: {temperature}")
    print("="*70)

    # ID scores are computed once and reused for every OOD dataset.
    print("\nLoading ID data...")
    id_logits = load_logits(id_logits_path)

    print("\nComputing Energy scores for ID data...")
    id_scores = get_energy_score(id_logits, temperature=temperature)
    print(f"ID Energy scores: min={id_scores.min():.4f}, max={id_scores.max():.4f}, mean={id_scores.mean():.4f}\n")

    results = {}

    for ood_name, ood_path in ood_logits_dict.items():
        print(f"{'='*70}")
        print(f"Evaluating on {ood_name}...")

        ood_logits = load_logits(ood_path)

        ood_scores = get_energy_score(ood_logits, temperature=temperature)
        print(f"  OOD Energy scores: min={ood_scores.min():.4f}, max={ood_scores.max():.4f}, mean={ood_scores.mean():.4f}")

        auroc = compute_auroc(id_scores, ood_scores)
        fpr95 = compute_fpr95(id_scores, ood_scores)

        results[ood_name] = {
            'fpr95': fpr95,
            'auroc': auroc,
            'num_samples': len(ood_logits)
        }

        print(f"  ✓ Results: FPR95={fpr95:.2f}%, AUROC={auroc:.2f}%\n")

    # Unweighted mean over OOD datasets (each dataset counts equally,
    # regardless of its sample count).
    avg_fpr95 = np.mean([r['fpr95'] for r in results.values()])
    avg_auroc = np.mean([r['auroc'] for r in results.values()])
    results['average'] = {'fpr95': avg_fpr95, 'auroc': avg_auroc}

    _print_energy_summary(results)

    return results


# ============================================================================
# USAGE
# ============================================================================

# In-distribution logits (ImageNet-1k, judging by the file name)
id_data_path = "imagenet1k_logits.pt"

# OOD benchmark name -> logits file; insertion order sets the report order.
ood_data_dict = dict(
    iNaturalist="selected_inaturalist21_logits.pt",
    SUN="sun_logits.pt",
    Places="places_logits.pt",
    Textures="textures_logits.pt",
)

# temperature=1.0 is the standard setting; other values can be tried when tuning.
results = evaluate_energy_ood(
    id_data_path,
    ood_data_dict,
    temperature=1.0,
)

# Per-dataset metrics followed by the cross-dataset average
print("\nSummary:")
for ood_name in ood_data_dict:
    print(f"  {ood_name}: FPR95={results[ood_name]['fpr95']:.2f}%, AUROC={results[ood_name]['auroc']:.2f}%")
print(f"\nAverage: FPR95={results['average']['fpr95']:.2f}%, AUROC={results['average']['auroc']:.2f}%")

Energy Score OOD Detection
Temperature: 1.0

Loading ID data...
  Loading from: imagenet1k_logits.pt
  Loaded logits shape: torch.Size([50000, 1000])

Computing Energy scores for ID data...
ID Energy scores: min=7.6989, max=45.4658, mean=17.1768

Evaluating on iNaturalist...
  Loading from: selected_inaturalist21_logits.pt
  Loaded logits shape: torch.Size([20000, 1000])
  OOD Energy scores: min=7.4395, max=26.7316, mean=11.3818
  ✓ Results: FPR95=53.78%, AUROC=90.62%

Evaluating on SUN...
  Loading from: sun_logits.pt
  Loaded logits shape: torch.Size([20000, 1000])
  OOD Energy scores: min=7.9612, max=30.8284, mean=11.9306
  ✓ Results: FPR95=58.83%, AUROC=86.57%

Evaluating on Places...
  Loading from: places_logits.pt
  Loaded logits shape: torch.Size([20000, 1000])
  OOD Energy scores: min=7.7614, max=37.3130, mean=12.3254
  ✓ Results: FPR95=66.03%, AUROC=83.96%

Evaluating on Textures...
  Loading from: textures_logits.pt
  Loaded logits shape: torch.Size([5640, 1000])
  OOD Energ