In [None]:
import torch
import numpy as np
import faiss
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
from wnnnk_core.experiments.exp6.config import IMAGENET1K_ACTIVATIONS_PATH, INATURALIST21_ACTIVATIONS_PATH

# ============================================================
# PARAMETER: Number of samples per class
# ============================================================
SAMPLES_PER_CLASS = 10  # Change this to test different values

print("="*60)
print(f"KNN OOD Detection ({SAMPLES_PER_CLASS} samples per class)")
print("="*60)

print("\nLoading ImageNet-1K activations...")
imagenet1k_data = torch.load(IMAGENET1K_ACTIVATIONS_PATH)
imagenet1k_acts = imagenet1k_data["activations"].numpy().astype('float32')
imagenet1k_labels = imagenet1k_data["labels"].numpy()

print(f"✓ ImageNet-1K - Activations: {imagenet1k_acts.shape}, Labels: {imagenet1k_labels.shape}")

print("\nLoading iNaturalist-21 activations...")
inaturalist21_data = torch.load(INATURALIST21_ACTIVATIONS_PATH)
inaturalist21_acts = inaturalist21_data["activations"].numpy().astype('float32')
inaturalist21_labels = inaturalist21_data["labels"].numpy()

print(f"✓ iNaturalist-21 - Activations: {inaturalist21_acts.shape}, Labels: {inaturalist21_labels.shape}")

# ============================================================
# Step 1: Select N samples per class from ImageNet-1K
# ============================================================
print("\n" + "="*60)
print(f"Step 1: Selecting {SAMPLES_PER_CLASS} samples per class from ImageNet-1K")
print("="*60)

unique_classes = np.unique(imagenet1k_labels)
print(f"Number of unique classes: {len(unique_classes)}")

selected_indices = []
for cls in unique_classes:
    cls_indices = np.where(imagenet1k_labels == cls)[0]
    # Randomly select N samples (or all if less than N)
    n_samples = min(SAMPLES_PER_CLASS, len(cls_indices))
    selected = np.random.choice(cls_indices, size=n_samples, replace=False)
    selected_indices.extend(selected)

selected_indices = np.array(selected_indices)
print(f"Selected {len(selected_indices)} samples ({len(unique_classes)} classes × ~{SAMPLES_PER_CLASS} samples)")

# Create reference set (training bank for KNN)
train_acts = imagenet1k_acts[selected_indices]
train_labels = imagenet1k_labels[selected_indices]

print(f"Training bank shape: {train_acts.shape}")

# ============================================================
# Step 2: Normalize activations (L2 normalization)
# ============================================================
print("\n" + "="*60)
print("Step 2: L2 Normalizing activations")
print("="*60)

def l2_normalize(x):
    """L2 normalize along feature dimension"""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / (norms + 1e-10)

train_acts_norm = l2_normalize(train_acts)
imagenet1k_acts_norm = l2_normalize(imagenet1k_acts)
inaturalist21_acts_norm = l2_normalize(inaturalist21_acts)

print(f"✓ Normalized training bank: {train_acts_norm.shape}")
print(f"✓ Normalized ImageNet-1K: {imagenet1k_acts_norm.shape}")
print(f"✓ Normalized iNaturalist-21: {inaturalist21_acts_norm.shape}")

# ============================================================
# Step 3: Build FAISS index
# ============================================================
print("\n" + "="*60)
print("Step 3: Building FAISS index")
print("="*60)

d = train_acts_norm.shape[1]  # Dimension
print(f"Feature dimension: {d}")

# Create FAISS index for L2 distance
index = faiss.IndexFlatL2(d)

# Add training vectors to index
index.add(train_acts_norm)
print(f"✓ Added {index.ntotal} vectors to FAISS index")

# ============================================================
# Step 4: Compute KNN distances using FAISS
# ============================================================
print("\n" + "="*60)
print("Step 4: Computing KNN distances with FAISS")
print("="*60)

def compute_knn_distance_faiss(test_acts_norm, index, k=1):
    """
    Compute k-nearest neighbor distance using FAISS
    
    Returns: array of distances to k-th nearest neighbor for each test sample
    """
    # Search for k nearest neighbors
    distances, indices = index.search(test_acts_norm, k)
    
    # Return k-th nearest neighbor distance (last column)
    return distances[:, -1]

# Test different k values (scale with reference set size)
k_values = [1, 5, 10, 50, 100]

results = {}

for k in k_values:
    print(f"\n{'='*60}")
    print(f"K = {k}")
    print(f"{'='*60}")
    
    # Compute KNN distances
    print("Computing distances for ImageNet-1K...")
    imagenet1k_knn = compute_knn_distance_faiss(imagenet1k_acts_norm, index, k=k)
    
    print("Computing distances for iNaturalist-21...")
    inaturalist21_knn = compute_knn_distance_faiss(inaturalist21_acts_norm, index, k=k)
    
    # Create labels (1 for ID, 0 for OOD)
    y_true = np.concatenate([
        np.ones(len(imagenet1k_knn)),
        np.zeros(len(inaturalist21_knn))
    ])
    
    # Combine scores (LOWER distance = more likely ID, so we negate for AUROC)
    y_scores = np.concatenate([-imagenet1k_knn, -inaturalist21_knn])
    
    # Calculate metrics
    auroc = roc_auc_score(y_true, y_scores)
    aupr = average_precision_score(y_true, y_scores)
    
    # Calculate FPR95
    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    fpr95 = fpr[np.argmax(tpr >= 0.95)]
    
    # Calculate ID accuracy (TPR at FPR95)
    tpr95 = tpr[np.argmax(tpr >= 0.95)]
    id_acc = tpr95 * 100
    
    # Store results
    results[k] = {
        'auroc': auroc,
        'aupr': aupr,
        'fpr95': fpr95,
        'id_acc': id_acc,
        'imagenet_mean': imagenet1k_knn.mean(),
        'imagenet_std': imagenet1k_knn.std(),
        'inaturalist_mean': inaturalist21_knn.mean(),
        'inaturalist_std': inaturalist21_knn.std()
    }
    
    print(f"\nMetrics:")
    print(f"  AUROC:  {auroc:.4f}")
    print(f"  AUPR:   {aupr:.4f}")
    print(f"  FPR95:  {fpr95:.4f}")
    print(f"  ID Acc: {id_acc:.2f}%")
    
    print(f"\nDistance Statistics:")
    print(f"  ImageNet-1K (ID):     Mean={imagenet1k_knn.mean():.4f}, Std={imagenet1k_knn.std():.4f}")
    print(f"  iNaturalist-21 (OOD): Mean={inaturalist21_knn.mean():.4f}, Std={inaturalist21_knn.std():.4f}")
    print(f"  Separation (OOD-ID):  {inaturalist21_knn.mean() - imagenet1k_knn.mean():.4f}")

# ============================================================
# Summary
# ============================================================
print("\n" + "="*60)
print(f"Summary: KNN Performance with {SAMPLES_PER_CLASS} samples/class")
print("="*60)

print(f"\n{'K':<10} {'AUROC':<10} {'AUPR':<10} {'FPR95':<10} {'ID Acc':<10}")
print("-" * 50)
for k in k_values:
    r = results[k]
    print(f"{k:<10} {r['auroc']:<10.4f} {r['aupr']:<10.4f} {r['fpr95']:<10.4f} {r['id_acc']:<10.2f}")

print("\n" + "="*60)
print("KNN Analysis Complete!")
print("="*60)

KNN OOD Detection (10 samples per class)

Loading ImageNet-1K activations...
✓ ImageNet-1K - Activations: (50000, 2048), Labels: (50000,)

Loading iNaturalist-21 activations...
✓ iNaturalist-21 - Activations: (100000, 2048), Labels: (100000,)

Step 1: Selecting 10 samples per class from ImageNet-1K
Number of unique classes: 1000
Selected 10000 samples (1000 classes × ~10 samples)
Training bank shape: (10000, 2048)

Step 2: L2 Normalizing activations
✓ Normalized training bank: (10000, 2048)
✓ Normalized ImageNet-1K: (50000, 2048)
✓ Normalized iNaturalist-21: (100000, 2048)

Step 3: Building FAISS index
Feature dimension: 2048
✓ Added 10000 vectors to FAISS index

Step 4: Computing KNN distances with FAISS

K = 1
Computing distances for ImageNet-1K...
Computing distances for iNaturalist-21...

Metrics:
  AUROC:  0.8130
  AUPR:   0.7421
  FPR95:  0.7026
  ID Acc: 95.00%

Distance Statistics:
  ImageNet-1K (ID):     Mean=0.2592, Std=0.1576
  iNaturalist-21 (OOD): Mean=0.4267, Std=0.0975
 

In [13]:
import torch
import numpy as np
import faiss
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
from wnnnk_core.experiments.exp6.config import DEEP_INVERSION_ACTIVATIONS_PATH, IMAGENET1K_ACTIVATIONS_PATH, INATURALIST21_ACTIVATIONS_PATH

print("="*60)
print("KNN OOD Detection: Deep Inversion as Reference Set")
print("="*60)

print("\nLoading activations...")
imagenet1k_data = torch.load(IMAGENET1K_ACTIVATIONS_PATH)
imagenet1k_acts = imagenet1k_data["activations"].numpy().astype('float32')

inaturalist21_data = torch.load(INATURALIST21_ACTIVATIONS_PATH)
inaturalist21_acts = inaturalist21_data["activations"].numpy().astype('float32')

deep_inversion_data = torch.load(DEEP_INVERSION_ACTIVATIONS_PATH)
deep_inversion_acts = deep_inversion_data["activations"].numpy().astype('float32')

print(f"✓ ImageNet-1K: {imagenet1k_acts.shape}")
print(f"✓ iNaturalist-21: {inaturalist21_acts.shape}")
print(f"✓ Deep Inversion (reference): {deep_inversion_acts.shape}")

# L2 Normalize
def l2_normalize(x):
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / (norms + 1e-10)

train_acts_norm = l2_normalize(deep_inversion_acts)  # Deep Inversion as reference
imagenet1k_acts_norm = l2_normalize(imagenet1k_acts)
inaturalist21_acts_norm = l2_normalize(inaturalist21_acts)

# Build FAISS index with Deep Inversion
d = train_acts_norm.shape[1]
index = faiss.IndexFlatL2(d)
index.add(train_acts_norm)
print(f"\n✓ Built FAISS index with {index.ntotal} Deep Inversion samples")

# Compute KNN distances
def compute_knn_distance_faiss(test_acts_norm, index, k=1):
    distances, indices = index.search(test_acts_norm, k)
    return distances[:, -1]

k_values = [1, 5, 10, 50, 100]

print("\n" + "="*60)
print("Results: ImageNet-1K (ID) vs iNaturalist-21 (OOD)")
print("="*60)

print(f"\n{'K':<10} {'AUROC':<10} {'AUPR':<10} {'FPR95':<10}")
print("-" * 40)

for k in k_values:
    imagenet1k_knn = compute_knn_distance_faiss(imagenet1k_acts_norm, index, k=k)
    inaturalist21_knn = compute_knn_distance_faiss(inaturalist21_acts_norm, index, k=k)
    
    # Create labels and scores
    y_true = np.concatenate([
        np.ones(len(imagenet1k_knn)),
        np.zeros(len(inaturalist21_knn))
    ])
    y_scores = np.concatenate([-imagenet1k_knn, -inaturalist21_knn])
    
    # Calculate metrics
    auroc = roc_auc_score(y_true, y_scores)
    aupr = average_precision_score(y_true, y_scores)
    fpr, tpr, _ = roc_curve(y_true, y_scores)
    fpr95 = fpr[np.argmax(tpr >= 0.95)]
    
    print(f"{k:<10} {auroc:<10.4f} {aupr:<10.4f} {fpr95:<10.4f}")

print("\n" + "="*60)

KNN OOD Detection: Deep Inversion as Reference Set

Loading activations...
✓ ImageNet-1K: (50000, 2048)
✓ iNaturalist-21: (100000, 2048)
✓ Deep Inversion (reference): (10000, 2048)

✓ Built FAISS index with 10000 Deep Inversion samples

Results: ImageNet-1K (ID) vs iNaturalist-21 (OOD)

K          AUROC      AUPR       FPR95     
----------------------------------------
1          0.8041     0.6483     0.5998    
5          0.7963     0.6359     0.5990    
10         0.7701     0.5839     0.6345    
50         0.7155     0.5108     0.7332    
100        0.7099     0.5062     0.7436    



In [18]:
ALPHA = 0.01
id_train_size = 1281167
ALPHA * id_train_size

12811.67