In [1]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from PIL import Image
from pathlib import Path

class FlatImageDataset(Dataset):
    """Dataset for loading all images from a flat directory structure"""
    def __init__(self, root_dir, transform=None, extensions=('.jpg', '.jpeg', '.png', '.JPEG', '.JPG', '.PNG')):
        self.root_dir = Path(root_dir)
        self.transform = transform
        
        # Get all image files recursively
        self.image_paths = []
        for ext in extensions:
            self.image_paths.extend(list(self.root_dir.rglob(f'*{ext}')))
        
        self.image_paths.sort()  # For reproducibility
        print(f"Found {len(self.image_paths)} images")
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a blank image if loading fails
            image = Image.new('RGB', (224, 224), color='black')
        
        if self.transform:
            image = self.transform(image)
        
        # Return dummy label (0) since we only care about logits for OOD detection
        return image, 0


class ODINPostprocessor:
    def __init__(self, temperature=1000, noise=0.0014, input_std=[0.229, 0.224, 0.225]):
        """
        ODIN (Out-of-DIstribution detector for Neural networks) postprocessor
        
        Args:
            temperature: Temperature scaling parameter (default: 1000)
            noise: Magnitude of input perturbation (default: 0.0014)
            input_std: Standard deviation used for normalization (default: ImageNet std)
        """
        self.temperature = temperature
        self.noise = noise
        self.input_std = input_std

    def postprocess_logits(self, model, data, device='cuda'):
        """
        Apply ODIN postprocessing to get confidence scores
        
        Args:
            model: Neural network model
            data: Input images [batch_size, 3, H, W]
            device: Device to run on
            
        Returns:
            pred: Predicted classes
            conf: ODIN confidence scores
        """
        model.eval()
        data = data.to(device)
        data.requires_grad = True
        
        # Forward pass
        output = model(data)
        
        # Get predicted labels
        labels = output.detach().argmax(axis=1)
        
        # Apply temperature scaling
        output = output / self.temperature
        
        # Calculate loss and gradients
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, labels)
        loss.backward()
        
        # Normalize gradient to binary {-1, 1}
        gradient = torch.ge(data.grad.detach(), 0)
        gradient = (gradient.float() - 0.5) * 2
        
        # Scale gradient by input std
        gradient[:, 0] = gradient[:, 0] / self.input_std[0]
        gradient[:, 1] = gradient[:, 1] / self.input_std[1]
        gradient[:, 2] = gradient[:, 2] / self.input_std[2]
        
        # Add perturbation to input
        perturbed_data = torch.add(data.detach(), gradient, alpha=-self.noise)
        
        # Forward pass with perturbed input
        output = model(perturbed_data)
        output = output / self.temperature
        
        # Calculate confidence scores
        nnOutput = output.detach()
        nnOutput = nnOutput - nnOutput.max(dim=1, keepdims=True).values
        nnOutput = nnOutput.exp() / nnOutput.exp().sum(dim=1, keepdims=True)
        
        conf, pred = nnOutput.max(dim=1)
        
        return pred, conf

    def compute_scores_from_images(self, model, dataloader, device='cuda'):
        """
        Compute ODIN scores for a dataset
        
        Args:
            model: Neural network model
            dataloader: DataLoader with images
            device: Device to run on
            
        Returns:
            pred_list: Predictions
            conf_list: ODIN confidence scores
        """
        pred_list, conf_list = [], []
        
        for batch in tqdm(dataloader, desc="Computing ODIN scores"):
            images = batch[0] if isinstance(batch, (list, tuple)) else batch['data']
            images = images.to(device)
            
            pred, conf = self.postprocess_logits(model, images, device)
            
            pred_list.append(pred.cpu())
            conf_list.append(conf.cpu())
        
        pred_list = torch.cat(pred_list).numpy().astype(int)
        conf_list = torch.cat(conf_list).numpy()
        
        return pred_list, conf_list


def load_dataset(dataset_path, transform, flat_structure=False, batch_size=64, num_workers=0):
    """
    Load dataset with either ImageFolder or FlatImageDataset
    
    Args:
        dataset_path: Path to dataset directory
        transform: Torchvision transforms to apply
        flat_structure: If True, use FlatImageDataset. If False, use ImageFolder
        batch_size: Batch size for DataLoader
        num_workers: Number of workers for DataLoader
        
    Returns:
        DataLoader for the dataset
    """
    print(f"Loading dataset from {dataset_path}...")
    print(f"Using {'flat' if flat_structure else 'structured (ImageFolder)'} dataset loading")
    
    if flat_structure:
        dataset = FlatImageDataset(dataset_path, transform=transform)
    else:
        dataset = ImageFolder(dataset_path, transform=transform)
    
    print(f"Dataset size: {len(dataset)}")
    
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    
    return dataloader


def compute_ood_metrics(id_conf, ood_conf):
    """
    Compute OOD detection metrics
    
    Args:
        id_conf: Confidence scores for in-distribution data
        ood_conf: Confidence scores for out-of-distribution data
        
    Returns:
        Dictionary with FPR95, AUROC, AUPR metrics
    """
    # Create labels (1 for ID, 0 for OOD)
    labels = np.concatenate([np.ones_like(id_conf), np.zeros_like(ood_conf)])
    scores = np.concatenate([id_conf, ood_conf])
    
    # Compute AUROC
    auroc = roc_auc_score(labels, scores) * 100
    
    # Compute AUPR
    aupr = average_precision_score(labels, scores) * 100
    
    # Compute FPR95 (FPR at 95% TPR)
    fpr, tpr, _ = roc_curve(labels, scores)
    fpr95 = fpr[np.argmax(tpr >= 0.95)] * 100
    
    return {
        'AUROC': auroc,
        'AUPR': aupr,
        'FPR95': fpr95
    }


# Usage example:
if __name__ == "__main__":
    import torchvision
    from torchvision import transforms
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Load model
    model = torchvision.models.resnet50(pretrained=True)
    model = model.to(device)
    model.eval()
    
    # Define transforms (same as used for logit extraction)
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])
    
    # Load ID dataset (structured - ImageFolder)
    id_loader = load_dataset(
        dataset_path=r'E:\datasets\ImageNet\ILSVRC2012_img_val',
        transform=transform,
        flat_structure=False,  # Structured dataset
        batch_size=64,
        num_workers=0
    )
    
    # Load OOD dataset (flat structure)
    ood_loader = load_dataset(
        dataset_path=r'C:\Users\gabri\Local Desktop\Research\wnnnk\experiments\exp6_deep_inversion_for_ood\data\datasets\iNaturalist\images',
        transform=transform,
        flat_structure=True,  # Flat dataset
        batch_size=64,
        num_workers=0
    )
    
    # Initialize ODIN postprocessor
    odin = ODINPostprocessor(temperature=1000, noise=0.0014)
    
    # Compute ODIN scores
    print("\nComputing ID scores...")
    id_preds, id_conf = odin.compute_scores_from_images(model, id_loader, device)
    
    print("\nComputing OOD scores...")
    ood_preds, ood_conf = odin.compute_scores_from_images(model, ood_loader, device)
    
    # Compute metrics
    metrics = compute_ood_metrics(id_conf, ood_conf)
    
    print(f"\nODIN OOD Detection Results:")
    print(f"ID samples: {len(id_conf)}")
    print(f"OOD samples: {len(ood_conf)}")
    print(f"AUROC: {metrics['AUROC']:.2f}%")
    print(f"AUPR: {metrics['AUPR']:.2f}%")
    print(f"FPR95: {metrics['FPR95']:.2f}%")
    print(f"\nID confidence - Mean: {id_conf.mean():.4f}, Std: {id_conf.std():.4f}")
    print(f"OOD confidence - Mean: {ood_conf.mean():.4f}, Std: {ood_conf.std():.4f}")



Loading dataset from E:\datasets\ImageNet\ILSVRC2012_img_val...
Using structured (ImageFolder) dataset loading
Dataset size: 50000
Loading dataset from C:\Users\gabri\Local Desktop\Research\wnnnk\experiments\exp6_deep_inversion_for_ood\data\datasets\iNaturalist\images...
Using flat dataset loading
Found 20000 images
Dataset size: 20000

Computing ID scores...


Computing ODIN scores: 100%|██████████| 782/782 [11:28<00:00,  1.14it/s]



Computing OOD scores...


Computing ODIN scores: 100%|██████████| 313/313 [03:59<00:00,  1.30it/s]


ODIN OOD Detection Results:
ID samples: 50000
OOD samples: 20000
AUROC: 90.96%
AUPR: 95.97%
FPR95: 42.49%

ID confidence - Mean: 0.0010, Std: 0.0000
OOD confidence - Mean: 0.0010, Std: 0.0000





In [1]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from sklearn.metrics import roc_auc_score, roc_curve
from pathlib import Path
from PIL import Image


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


def get_odin_score(inputs, model, forward_func, method_args):
    """
    Compute ODIN scores following the paper implementation.
    
    Args:
        inputs: torch.Tensor (batch_size, C, H, W)
        model: torch.nn.Module
        forward_func: function that takes (inputs, model) and returns logits
        method_args: dict with 'temperature' and 'magnitude' keys
    
    Returns:
        scores: numpy array (batch_size,)
    """
    temper = method_args['temperature']
    noiseMagnitude1 = method_args['magnitude']
    criterion = nn.CrossEntropyLoss()
    
    inputs = torch.autograd.Variable(inputs, requires_grad=True)
    outputs = forward_func(inputs, model)
    
    maxIndexTemp = np.argmax(outputs.data.cpu().numpy(), axis=1)
    
    # Using temperature scaling
    outputs = outputs / temper
    
    labels = torch.autograd.Variable(torch.LongTensor(maxIndexTemp).cuda())
    loss = criterion(outputs, labels)
    loss.backward()
    
    # Normalizing the gradient to binary in {0, 1}
    gradient = torch.ge(inputs.grad.data, 0)
    gradient = (gradient.float() - 0.5) * 2
    
    # Adding small perturbations to images
    tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient)
    
    # Forward pass on perturbed inputs
    with torch.no_grad():
        outputs = forward_func(tempInputs, model)
    
    outputs = outputs / temper
    
    # Calculating the confidence after adding perturbations
    nnOutputs = outputs.data.cpu()
    nnOutputs = nnOutputs.numpy()
    nnOutputs = nnOutputs - np.max(nnOutputs, axis=1, keepdims=True)
    nnOutputs = np.exp(nnOutputs) / np.sum(np.exp(nnOutputs), axis=1, keepdims=True)
    
    scores = np.max(nnOutputs, axis=1)
    
    return scores


def forward_func(inputs, model):
    """Simple forward function."""
    return model(inputs)


def compute_auroc(id_scores, ood_scores):
    """Compute AUROC."""
    scores = np.concatenate([id_scores, ood_scores])
    labels = np.concatenate([np.ones(len(id_scores)), np.zeros(len(ood_scores))])
    return roc_auc_score(labels, scores) * 100


def compute_fpr95(id_scores, ood_scores):
    """Compute FPR@95."""
    scores = np.concatenate([id_scores, ood_scores])
    labels = np.concatenate([np.ones(len(id_scores)), np.zeros(len(ood_scores))])
    fpr, tpr, _ = roc_curve(labels, scores, pos_label=1)
    idx = np.argmax(tpr >= 0.95)
    return fpr[idx] * 100


class FlatImageDataset(Dataset):
    """Dataset for loading images from flat directory structure."""
    def __init__(self, root_dir, transform=None, extensions=('.jpg', '.jpeg', '.png', '.JPEG', '.JPG', '.PNG')):
        self.root_dir = Path(root_dir)
        self.transform = transform
        
        # Get all image files
        self.image_paths = []
        for ext in extensions:
            self.image_paths.extend(list(self.root_dir.rglob(f'*{ext}')))
        
        self.image_paths.sort()
        print(f"  Found {len(self.image_paths)} images")
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            image = Image.new('RGB', (224, 224), color='black')
        
        if self.transform:
            image = self.transform(image)
        
        return image, 0  # Dummy label


def evaluate_odin_ood(model, id_loader, ood_loaders_dict, temperature=1000, magnitude=0.0014):
    """
    Evaluate ODIN OOD detection.
    
    Args:
        model: torch.nn.Module
        id_loader: DataLoader for ID data
        ood_loaders_dict: dict {ood_name: ood_loader}
        temperature: float - temperature scaling
        magnitude: float - perturbation magnitude (epsilon)
    
    Returns:
        Dictionary with results
    """
    print("="*70)
    print("ODIN OOD Detection")
    print(f"Temperature: {temperature}")
    print(f"Magnitude (epsilon): {magnitude}")
    print("="*70)
    
    model.eval()
    
    method_args = {
        'temperature': temperature,
        'magnitude': magnitude
    }
    
    # Compute ODIN scores for ID data
    print(f"\n[Step 1/2] Computing ODIN scores for ID data...")
    id_scores = []
    
    for batch_idx, (inputs, _) in enumerate(id_loader):
        inputs = inputs.to(device)
        batch_scores = get_odin_score(inputs, model, forward_func, method_args)
        id_scores.append(batch_scores)
        
        if (batch_idx + 1) % 50 == 0:
            processed = (batch_idx + 1) * id_loader.batch_size
            total = len(id_loader.dataset)
            print(f"  Progress: {processed}/{total} ({100*processed/total:.1f}%)")
    
    id_scores = np.concatenate(id_scores)
    print(f"\nID ODIN scores: min={id_scores.min():.4f}, max={id_scores.max():.4f}, mean={id_scores.mean():.4f}")
    
    # Evaluate on each OOD dataset
    print(f"\n[Step 2/2] Evaluating on OOD datasets...")
    print("="*70)
    
    results = {}
    
    for ood_idx, (ood_name, ood_loader) in enumerate(ood_loaders_dict.items(), 1):
        print(f"\nOOD Dataset {ood_idx}/{len(ood_loaders_dict)}: {ood_name}")
        
        ood_scores = []
        for batch_idx, (inputs, _) in enumerate(ood_loader):
            inputs = inputs.to(device)
            batch_scores = get_odin_score(inputs, model, forward_func, method_args)
            ood_scores.append(batch_scores)
            
            if (batch_idx + 1) % 50 == 0:
                processed = (batch_idx + 1) * ood_loader.batch_size
                total = len(ood_loader.dataset)
                print(f"  Progress: {processed}/{total} ({100*processed/total:.1f}%)")
        
        ood_scores = np.concatenate(ood_scores)
        print(f"  OOD ODIN scores: min={ood_scores.min():.4f}, max={ood_scores.max():.4f}, mean={ood_scores.mean():.4f}")
        
        # Compute metrics
        auroc = compute_auroc(id_scores, ood_scores)
        fpr95 = compute_fpr95(id_scores, ood_scores)
        
        results[ood_name] = {
            'fpr95': fpr95,
            'auroc': auroc,
            'num_samples': len(ood_scores)
        }
        
        print(f"  ✓ Results: FPR95={fpr95:.2f}%, AUROC={auroc:.2f}%")
    
    # Compute averages
    avg_fpr95 = np.mean([r['fpr95'] for r in results.values()])
    avg_auroc = np.mean([r['auroc'] for r in results.values()])
    results['average'] = {'fpr95': avg_fpr95, 'auroc': avg_auroc}
    
    # Print summary
    print(f"\n{'='*70}")
    print("ODIN OOD Detection Results")
    print(f"{'='*70}")
    print(f"{'OOD Dataset':<20} {'FPR95 ↓':>12} {'AUROC ↑':>12} {'Samples':>12}")
    print(f"{'-'*70}")
    
    for ood_name, metrics in results.items():
        if ood_name == 'average':
            continue
        print(f"{ood_name:<20} {metrics['fpr95']:>11.2f}% {metrics['auroc']:>11.2f}% {metrics['num_samples']:>12}")
    
    print(f"{'-'*70}")
    print(f"{'Average':<20} {avg_fpr95:>11.2f}% {avg_auroc:>11.2f}% {'-':>12}")
    print(f"{'='*70}\n")
    
    return results


# ============================================================================
# SETUP
# ============================================================================

# Load model
model = torchvision.models.resnet50(pretrained=True)
model = model.to(device)
model.eval()

# Define transforms (same as used for logit extraction)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
])

# ============================================================================
# LOAD DATASETS
# ============================================================================

# ID dataset (ImageNet validation)
print("\nLoading ID dataset...")
id_dataset_path = r"E:\datasets\ImageNet\ILSVRC2012_img_val"
id_dataset = ImageFolder(id_dataset_path, transform=transform)
id_loader = DataLoader(id_dataset, batch_size=128, shuffle=False, num_workers=0, pin_memory=True)
print(f"ID dataset: {len(id_dataset)} samples")

# OOD datasets
print("\nLoading OOD datasets...")
ood_datasets = {
    "iNaturalist": r"E:\datasets\KNN-OOD\iNaturalist\images",
    "SUN": r"E:\datasets\KNN-OOD\SUN\images",
    "Places": r"E:\datasets\KNN-OOD\Places\images",
    "Textures": r"E:\datasets\KNN-OOD\Textures\images",
}

ood_loaders_dict = {}
for ood_name, ood_path in ood_datasets.items():
    print(f"\n{ood_name}:")
    # Use FlatImageDataset if structure is flat, ImageFolder if structured
    try:
        # Try ImageFolder first (if has class subdirectories)
        dataset = ImageFolder(ood_path, transform=transform)
    except:
        # Fall back to FlatImageDataset (flat structure)
        dataset = FlatImageDataset(ood_path, transform=transform)
    
    ood_loaders_dict[ood_name] = DataLoader(
        dataset, 
        batch_size=128, 
        shuffle=False, 
        num_workers=0, 
        pin_memory=True
    )

# ============================================================================
# RUN EVALUATION
# ============================================================================

# Paper uses: temperature=1000, magnitude=0.0014 for ImageNet
results = evaluate_odin_ood(
    model=model,
    id_loader=id_loader,
    ood_loaders_dict=ood_loaders_dict,
    temperature=1000,
    magnitude=0.0014
)

# Access results
print("\nFinal Summary:")
for ood_name in ood_datasets.keys():
    print(f"  {ood_name}: FPR95={results[ood_name]['fpr95']:.2f}%, AUROC={results[ood_name]['auroc']:.2f}%")
print(f"\nAverage: FPR95={results['average']['fpr95']:.2f}%, AUROC={results['average']['auroc']:.2f}%")

Using device: cuda





Loading ID dataset...
ID dataset: 50000 samples

Loading OOD datasets...

iNaturalist:
  Found 20000 images

SUN:
  Found 20000 images

Places:
  Found 20000 images

Textures:
ODIN OOD Detection
Temperature: 1000
Magnitude (epsilon): 0.0014

[Step 1/2] Computing ODIN scores for ID data...


	add(Tensor input, Number alpha, Tensor other, *, Tensor out = None)
Consider using one of the following signatures instead:
	add(Tensor input, Tensor other, *, Number alpha = 1, Tensor out = None) (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\utils\python_arg_parser.cpp:1707.)
  tempInputs = torch.add(inputs.data, -noiseMagnitude1, gradient)


  Progress: 6400/50000 (12.8%)
  Progress: 12800/50000 (25.6%)
  Progress: 19200/50000 (38.4%)
  Progress: 25600/50000 (51.2%)
  Progress: 32000/50000 (64.0%)
  Progress: 38400/50000 (76.8%)
  Progress: 44800/50000 (89.6%)

ID ODIN scores: min=0.0010, max=0.0011, mean=0.0010

[Step 2/2] Evaluating on OOD datasets...

OOD Dataset 1/4: iNaturalist
  Progress: 6400/20000 (32.0%)
  Progress: 12800/20000 (64.0%)
  Progress: 19200/20000 (96.0%)
  OOD ODIN scores: min=0.0010, max=0.0010, mean=0.0010
  ✓ Results: FPR95=41.93%, AUROC=92.24%

OOD Dataset 2/4: SUN
  Progress: 6400/20000 (32.0%)
  Progress: 12800/20000 (64.0%)
  Progress: 19200/20000 (96.0%)
  OOD ODIN scores: min=0.0010, max=0.0010, mean=0.0010
  ✓ Results: FPR95=57.17%, AUROC=86.76%

OOD Dataset 3/4: Places
  Progress: 6400/20000 (32.0%)
  Progress: 12800/20000 (64.0%)
  Progress: 19200/20000 (96.0%)
  OOD ODIN scores: min=0.0010, max=0.0010, mean=0.0010
  ✓ Results: FPR95=64.72%, AUROC=84.11%

OOD Dataset 4/4: Textures
  OOD OD