In [7]:
import numpy as np
import torch
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.covariance import EmpiricalCovariance


# Select the compute device: CUDA when PyTorch reports it as available, CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")


def compute_mahalanobis_scores_batch(features, means, precision, num_classes, batch_size=500):
    """
    Compute Mahalanobis confidence scores in batches.

    For every sample the score is the maximum over classes of
    -0.5 * (x - mean_c)^T @ precision @ (x - mean_c), i.e. the (negated, halved)
    squared Mahalanobis distance to the closest class mean.

    Runs on the GPU when CUDA is available and falls back to the CPU otherwise.

    Args:
        features: numpy array (N, feature_dim)
        means: numpy array (num_classes, feature_dim)
        precision: numpy array (feature_dim, feature_dim)
        num_classes: int, number of rows of `means` to score against
        batch_size: int, number of samples to process at once

    Returns:
        scores: numpy array (N,); empty when N == 0

    Raises:
        ValueError: if batch_size or num_classes is smaller than 1
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if num_classes < 1:
        raise ValueError(f"num_classes must be a positive integer, got {num_classes}")

    num_samples = features.shape[0]

    print(f"  Computing Mahalanobis scores for {num_samples} samples (batch_size={batch_size})...")

    # Nothing to score: return early instead of failing in np.concatenate([])
    if num_samples == 0:
        return np.empty(0, dtype=np.float32)

    # Resolve the device locally so the function also works on CPU-only machines
    compute_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Convert to float32 torch tensors on the compute device
    features_torch = torch.from_numpy(features).float().to(compute_device)
    means_torch = torch.from_numpy(means).float().to(compute_device)
    precision_torch = torch.from_numpy(precision).float().to(compute_device)

    all_scores = []
    num_batches = (num_samples + batch_size - 1) // batch_size
    # Report progress roughly every 10% of the batches (and always on the last one)
    report_every = max(1, num_batches // 10)

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, num_samples)

        batch_features = features_torch[start_idx:end_idx]  # (batch, feature_dim)

        gaussian_scores = []

        # Compute the score with respect to each class mean
        for cls in range(num_classes):
            zero_f = batch_features - means_torch[cls]  # (batch, feature_dim)

            # Row-wise quadratic form zero_f[i] @ precision @ zero_f[i].
            # Equivalent to diag(zero_f @ precision @ zero_f.T) but without
            # materialising the (batch, batch) matrix just to keep its diagonal.
            term_gau = -0.5 * (torch.mm(zero_f, precision_torch) * zero_f).sum(dim=1)
            gaussian_scores.append(term_gau)

        # (batch, num_classes) -> best (largest) score per sample
        gaussian_scores = torch.stack(gaussian_scores, dim=1)
        batch_scores, _ = torch.max(gaussian_scores, dim=1)

        all_scores.append(batch_scores.cpu().numpy())

        # Print progress
        if (batch_idx + 1) % report_every == 0 or (batch_idx + 1) == num_batches:
            progress = (batch_idx + 1) / num_batches * 100
            print(f"    Progress: {end_idx}/{num_samples} ({progress:.1f}%)")

    return np.concatenate(all_scores)


def fit_mahalanobis_params(features, labels, num_classes):
    """
    Estimate per-class means and a single shared (tied) precision matrix.

    Each sample is centered by the mean of its own class; the centered samples
    of all classes are pooled and one covariance estimator is fitted on them.
    Classes without any sample keep an all-zero mean and contribute nothing to
    the pooled data.

    Args:
        features: numpy array (N, feature_dim)
        labels: numpy array (N,)
        num_classes: int

    Returns:
        sample_class_mean: numpy array (num_classes, feature_dim)
        precision: numpy array (feature_dim, feature_dim)
    """
    print("Fitting Mahalanobis parameters...")

    sample_class_mean = np.zeros((num_classes, features.shape[1]))
    centered_blocks = []

    # One pass over the classes: store the mean and keep the centered samples
    print("  Computing class means...")
    for cls in range(num_classes):
        members = features[labels == cls]
        if len(members) > 0:
            sample_class_mean[cls] = members.mean(axis=0)
            # Subtract the stored row so the dtype promotion matches the mean array
            centered_blocks.append(members - sample_class_mean[cls])

        done = cls + 1
        if done % 100 == 0:
            print(f"    Processed {done}/{num_classes} classes")

    print(f"  Completed all {num_classes} classes")

    # Pool the class-centered samples for the tied covariance
    print("  Computing tied covariance matrix...")
    X = np.vstack(centered_blocks)

    # Fit the empirical covariance and read off its inverse
    print("  Fitting covariance estimator...")
    estimator = EmpiricalCovariance(assume_centered=False)
    estimator.fit(X)
    precision = estimator.precision_

    print(f"  Precision matrix shape: {precision.shape}")
    print(f"  Condition number: {np.linalg.cond(precision):.2e}")

    return sample_class_mean, precision


def compute_auroc(id_scores, ood_scores):
    """Return the AUROC (in percent) with ID samples as the positive class."""
    # Ground truth: 1 for every ID score, 0 for every OOD score, in that order
    y_true = np.concatenate([np.ones(len(id_scores)), np.zeros(len(ood_scores))])
    y_score = np.concatenate([id_scores, ood_scores])
    auroc = roc_auc_score(y_true, y_score)
    return auroc * 100


def compute_fpr95(id_scores, ood_scores):
    """Return the false positive rate (in percent) at the first ROC point with TPR >= 0.95.

    ID samples are the positive class, OOD samples the negative class.
    """
    # Ground truth: 1 for every ID score, 0 for every OOD score, in that order
    y_true = np.concatenate([np.ones(len(id_scores)), np.zeros(len(ood_scores))])
    y_score = np.concatenate([id_scores, ood_scores])
    fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=1)
    # argmax over a boolean array yields the index of its first True entry
    first_hit = np.argmax(tpr >= 0.95)
    return fpr[first_hit] * 100


def _tensor_to_numpy(value):
    """Return `value` as a numpy array if it is a torch tensor, else unchanged."""
    if isinstance(value, torch.Tensor):
        # detach() is required for tensors saved with requires_grad=True;
        # calling .numpy() on those directly raises a RuntimeError.
        return value.detach().cpu().numpy()
    return value


def load_embeddings(path):
    """
    Load embeddings saved with torch.save.

    Two layouts are accepted:
      * a dict with an 'activations' entry and an optional 'labels' entry
      * a bare object (e.g. a tensor) holding only the activations

    Torch tensors are converted to numpy arrays; other objects are passed
    through unchanged.

    Args:
        path: path of the file to load

    Returns:
        activations: the loaded activations (numpy array when stored as a tensor)
        labels: the loaded labels, or None when the file has none

    Raises:
        KeyError: if the file holds a dict without an 'activations' entry
    """
    print(f"  Loading: {path}")
    # NOTE(security): torch.load unpickles the file; only load files from
    # trusted sources.
    data = torch.load(path, map_location='cpu')

    if isinstance(data, dict):
        activations = _tensor_to_numpy(data['activations'])
        labels = _tensor_to_numpy(data.get('labels', None))
    else:
        activations = _tensor_to_numpy(data)
        labels = None

    print(f"    Shape: {activations.shape}")
    return activations, labels


def evaluate_mahalanobis_ood(id_val_path, ood_dict, num_classes=1000, batch_size=500):
    """
    Run the full Mahalanobis OOD evaluation and print a summary table.

    The ID embeddings are used both to fit the class means / tied precision and
    as the in-distribution scores every OOD dataset is compared against.

    Args:
        id_val_path: Path to ID validation embeddings (with labels)
        ood_dict: Dictionary {ood_name: ood_path}
        num_classes: Number of classes
        batch_size: Batch size for computing scores

    Returns:
        Dictionary mapping each OOD name to {'fpr95', 'auroc', 'num_samples'},
        plus an 'average' entry holding the mean 'fpr95' and 'auroc'.

    Raises:
        ValueError: if the ID embeddings file carries no labels
    """
    banner = "=" * 70
    divider = "-" * 70

    print(banner)
    print("Mahalanobis Distance OOD Detection")
    print(f"Device: {device}")
    print(banner)

    # Step 1: in-distribution embeddings (labels are mandatory)
    print(f"\n[Step 1/4] Loading ID validation data...")
    id_activations, id_labels = load_embeddings(id_val_path)
    if id_labels is None:
        raise ValueError("ID validation data must have labels!")
    id_features = id_activations.astype(np.float32)
    id_labels = id_labels.astype(np.int64)

    # Step 2: class means and tied precision
    print(f"\n[Step 2/4] Fitting Mahalanobis parameters...")
    class_means, precision = fit_mahalanobis_params(id_features, id_labels, num_classes)

    # Step 3: scores of the ID samples themselves
    print(f"\n[Step 3/4] Computing ID scores...")
    id_scores = compute_mahalanobis_scores_batch(
        id_features, class_means, precision, num_classes, batch_size
    )
    print(f"  ID scores: min={id_scores.min():.4f}, max={id_scores.max():.4f}, mean={id_scores.mean():.4f}")

    # Step 4: score every OOD dataset against the ID scores
    print(f"\n[Step 4/4] Evaluating OOD datasets...")
    print(banner)

    results = {}
    position = 0
    for ood_name, ood_path in ood_dict.items():
        position += 1
        print(f"\nOOD Dataset {position}/{len(ood_dict)}: {ood_name}")

        ood_activations, _ = load_embeddings(ood_path)
        ood_features = ood_activations.astype(np.float32)

        ood_scores = compute_mahalanobis_scores_batch(
            ood_features, class_means, precision, num_classes, batch_size
        )
        print(f"  OOD scores: min={ood_scores.min():.4f}, max={ood_scores.max():.4f}, mean={ood_scores.mean():.4f}")

        auroc = compute_auroc(id_scores, ood_scores)
        fpr95 = compute_fpr95(id_scores, ood_scores)

        results[ood_name] = dict(fpr95=fpr95, auroc=auroc, num_samples=len(ood_features))

        print(f"  ✓ Results: FPR95={fpr95:.2f}%, AUROC={auroc:.2f}%")

    # Averages over the per-dataset entries (computed before 'average' is added)
    fpr95_values = [entry['fpr95'] for entry in results.values()]
    auroc_values = [entry['auroc'] for entry in results.values()]
    avg_fpr95 = np.mean(fpr95_values)
    avg_auroc = np.mean(auroc_values)
    results['average'] = dict(fpr95=avg_fpr95, auroc=avg_auroc)

    # Summary table
    print(f"\n{banner}")
    print("Mahalanobis Distance OOD Detection Results")
    print(banner)
    print(f"{'OOD Dataset':<20} {'FPR95 ↓':>12} {'AUROC ↑':>12} {'Samples':>12}")
    print(divider)

    for ood_name, metrics in results.items():
        # The aggregate row is printed separately below the divider
        if ood_name != 'average':
            print(f"{ood_name:<20} {metrics['fpr95']:>11.2f}% {metrics['auroc']:>11.2f}% {metrics['num_samples']:>12}")

    print(divider)
    print(f"{'Average':<20} {avg_fpr95:>11.2f}% {avg_auroc:>11.2f}% {'-':>12}")
    print(f"{banner}\n")

    return results


# ============================================================================
# USAGE
# ============================================================================

# In-distribution embeddings (must contain labels, see evaluate_mahalanobis_ood)
id_val_path = r"C:\Users\gabri\Local Desktop\Research\wnnnk\experiments\exp6_deep_inversion_for_ood\data\activations\imagenet1k_activations\resnet50_imagenet1k_val_avgpool.pt"

# OOD dataset name -> embeddings file; insertion order is the evaluation order
ood_dict = dict(
    iNaturalist=r"C:\Users\gabri\Local Desktop\Research\wnnnk\experiments\exp6_deep_inversion_for_ood\data\activations\selected_inaturalist21_activations\resnet50_inaturalist21_10k_avgpool.pt",
    SUN=r"C:\Users\gabri\Local Desktop\Research\wnnnk\experiments\exp6_deep_inversion_for_ood\data\activations\selected_sun_activations\resnet50_sun_10k_avgpool.pt",
    Places=r"C:\Users\gabri\Local Desktop\Research\wnnnk\experiments\exp6_deep_inversion_for_ood\data\activations\selected_places_activations\resnet50_places_10k_avgpool.pt",
    Textures=r"C:\Users\gabri\Local Desktop\Research\wnnnk\experiments\exp6_deep_inversion_for_ood\data\activations\textures_activations\resnet50_textures_avgpool.pt",
)

# Run evaluation
# batch_size is the number of samples scored at once; raise it if GPU memory allows
results = evaluate_mahalanobis_ood(id_val_path, ood_dict, num_classes=1000, batch_size=500)

# Per-dataset recap followed by the aggregate row
print("\nSummary:")
for ood_name in ood_dict:
    print(f"  {ood_name}: FPR95={results[ood_name]['fpr95']:.2f}%, AUROC={results[ood_name]['auroc']:.2f}%")
print(f"\nAverage: FPR95={results['average']['fpr95']:.2f}%, AUROC={results['average']['auroc']:.2f}%")

Using device: cuda
Mahalanobis Distance OOD Detection
Device: cuda

[Step 1/4] Loading ID validation data...
  Loading: C:\Users\gabri\Local Desktop\Research\wnnnk\experiments\exp6_deep_inversion_for_ood\data\activations\imagenet1k_activations\resnet50_imagenet1k_val_avgpool.pt
    Shape: (50000, 2048)

[Step 2/4] Fitting Mahalanobis parameters...
Fitting Mahalanobis parameters...
  Computing class means...
    Processed 100/1000 classes
    Processed 200/1000 classes
    Processed 300/1000 classes
    Processed 400/1000 classes
    Processed 500/1000 classes
    Processed 600/1000 classes
    Processed 700/1000 classes
    Processed 800/1000 classes
    Processed 900/1000 classes
    Processed 1000/1000 classes
  Completed all 1000 classes
  Computing tied covariance matrix...
  Fitting covariance estimator...
  Precision matrix shape: (2048, 2048)
  Condition number: 1.20e+04

[Step 3/4] Computing ID scores...
  Computing Mahalanobis scores for 50000 samples (batch_size=500)...
    P