## Import libraries                                                                      

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torch.nn.functional as F

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from typing import Tuple, Optional, Dict

## SoftmaxOOD Class
Softmax-based out-of-distribution (OOD) detection with several uncertainty estimators: `umax`, `uentropy`, `udensity`, and `mental_model`.

* Keeps the interface of the existing `SoftmaxDetector`, so it works as a drop-in replacement while adding the extra estimators.
* Follows the theoretical analysis of softmax-based uncertainty estimation in the accompanying research paper.

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.mixture import GaussianMixture

class SoftmaxOOD:
    """Softmax-based out-of-distribution (OOD) detector with several uncertainty scores.

    ``method`` selects the score returned by :meth:`score`; for every method a
    higher score means "more uncertain / more likely OOD":

    * ``'umax'``         -- negative maximum softmax probability.
    * ``'uentropy'``     -- entropy of the softmax distribution.
    * ``'udensity'``     -- negative log-likelihood of the features under a
      ``GaussianMixture`` fitted by :meth:`fit`.
    * ``'mental_model'`` -- closed-form Umax approximation from the feature norm
      and the best cosine alignment with the final-layer weight vectors.

    The (unwrapped) model's forward pass must return ``(logits, features)``.
    Wrappers that expose the real network as ``.model`` (e.g. OnlineEWC) are
    unwrapped automatically.
    """

    def __init__(self, method="umax"):
        self.method = method
        self.tau = None  # decision threshold; set by set_threshold()
        self.gmm = None  # fitted GaussianMixture; only used by 'udensity'
        self.training_stats = {}  # not written by any method of this class
        self.num_classes = None
        # Transposed weight of the last nn.Linear, shape (feature_dim, num_classes).
        # It is a view of the layer's weight tensor, not a copy.
        self.final_layer_weights = None

    def _get_actual_model(self, model):
        """Extract the actual model from wrapped models like OnlineEWC."""
        if hasattr(model, 'model'):
            return model.model
        return model

    def _get_device_from_model(self, model):
        """Return the device of the model's first parameter (CPU if it has none)."""
        actual_model = self._get_actual_model(model)
        # next() with a default: a bare next() would leak StopIteration for a
        # model without parameters.
        first_param = next(actual_model.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    @staticmethod
    def _forward_no_grad(model, x):
        """Run ``model(x)`` in eval mode without autograd, then restore train/eval flags.

        Every submodule's ``training`` flag is saved and restored individually,
        so calling the detector in the middle of a training loop neither leaves
        dropout/batch-norm layers in eval mode nor un-freezes submodules the
        caller deliberately kept in eval mode.
        """
        saved_modes = [(module, module.training) for module in model.modules()]
        model.eval()
        try:
            with torch.no_grad():
                return model(x)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

    @staticmethod
    def _iter_batches(X, device, batch_size):
        """Yield consecutive ``batch_size`` slices of ``X``, moved to ``device`` and cast to float."""
        for start in range(0, len(X), batch_size):
            yield X[start:start + batch_size].to(device).float()

    @staticmethod
    def _find_final_linear(model):
        """Return the last ``nn.Linear`` in ``model.modules()`` order, or ``None``.

        NOTE: registration order usually matches forward order, but this is a
        heuristic -- a Linear registered after the classifier head wins instead.
        """
        linear_layers = [m for m in model.modules() if isinstance(m, nn.Linear)]
        return linear_layers[-1] if linear_layers else None

    def fit(self, model, buffer, device=None, num_classes=None, reg_cov=1e-5, batch_size=256):
        """
        Fit the detector parameters using training data.

        Always caches the final-layer weights. For ``method == 'udensity'`` it
        additionally extracts features for every sample returned by
        ``buffer.get_all_data()`` and fits a full-covariance GaussianMixture
        with one component per class. Other methods never touch ``buffer``.

        Args:
            model: Neural network model (can be wrapped like OnlineEWC)
            buffer: Object whose ``get_all_data()`` returns ``(X, y)``; either may be None
            device: Computing device (auto-detected if None)
            num_classes: Number of classes; capped at the number of distinct labels
                when labels are available
            reg_cov: Covariance regularization passed to GaussianMixture as ``reg_covar``
            batch_size: Number of samples per forward pass during feature extraction

        Raises:
            ValueError: for 'udensity' when neither labels nor ``num_classes`` are
                available to choose the number of mixture components.
        """
        if device is None:
            device = self._get_device_from_model(model)

        self.num_classes = num_classes
        actual_model = self._get_actual_model(model)

        # Done first so the weights are cached even when the buffer is empty.
        self._extract_final_layer_weights(actual_model)

        if self.method != 'udensity':
            return

        X_train, y_train = buffer.get_all_data()
        if X_train is None or len(X_train) == 0:
            return

        features_list = []
        for batch in self._iter_batches(X_train, device, batch_size):
            _, features = self._forward_no_grad(actual_model, batch)
            features_list.append(features.cpu().numpy())
        all_features = np.concatenate(features_list, axis=0)

        # One mixture component per observed class (capped by num_classes if given).
        if y_train is not None:
            observed_classes = len(torch.unique(torch.as_tensor(y_train)))
            self.num_classes = (
                observed_classes if num_classes is None else min(num_classes, observed_classes)
            )

        if self.num_classes is None:
            raise ValueError(
                "udensity needs num_classes when the buffer provides no labels"
            )

        self.gmm = GaussianMixture(
            n_components=self.num_classes,
            covariance_type='full',
            reg_covar=reg_cov,
            random_state=42
        )
        self.gmm.fit(all_features)

    def _extract_final_layer_weights(self, model):
        """Cache the transposed weight of the final linear layer (left unchanged if none)."""
        final_layer = self._find_final_linear(model)
        if final_layer is not None:
            self.final_layer_weights = final_layer.weight.data.T.detach()

    def score(self, model, x, device=None):
        """
        Compute uncertainty scores for input samples.
        Higher scores indicate higher uncertainty (more likely OOD).

        ``x`` (when it is a tensor) is moved to ``device`` before the forward
        pass. The model's train/eval flags are restored afterwards.

        Raises:
            ValueError: if ``self.method`` is not a supported method name.
        """
        if device is None:
            device = self._get_device_from_model(model)

        actual_model = self._get_actual_model(model)
        if torch.is_tensor(x):
            x = x.to(device)

        logits, features = self._forward_no_grad(actual_model, x)

        if self.method == 'umax':
            return self._compute_umax(logits)
        elif self.method == 'uentropy':
            return self._compute_uentropy(logits)
        elif self.method == 'udensity':
            return self._compute_udensity(features)
        elif self.method == 'mental_model':
            return self._compute_mental_model(logits, features)
        else:
            raise ValueError(f"Unknown method: {self.method}")

    def _compute_umax(self, logits):
        """Compute Umax: negative maximum predicted probability."""
        probs = F.softmax(logits, dim=1)
        max_probs = probs.max(dim=1)[0]
        return -max_probs  # Higher (less negative) = more uncertain

    def _compute_uentropy(self, logits):
        """Compute Uentropy: prediction entropy."""
        probs = F.softmax(logits, dim=1)
        # The epsilon keeps log() finite when a probability is exactly zero.
        entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)
        return entropy  # Higher = more uncertain

    def _compute_udensity(self, features):
        """Compute Udensity: negative log-likelihood under the fitted GMM.

        Raises:
            ValueError: if :meth:`fit` has not fitted a GMM yet.
        """
        if self.gmm is None:
            raise ValueError("Must call fit() first for density-based detection")

        features_np = features.cpu().numpy()
        log_likelihood = self.gmm.score_samples(features_np)
        return torch.tensor(-log_likelihood, device=features.device)

    def _compute_mental_model(self, logits, features):
        """
        Compute uncertainty using the complete mental model from Proposition 6.
        U_max_mental(z) = -1 / (1 + (K-1) * exp(-||z|| * (1/(K-1) + max cos θ_i,z)))

        K is ``self.num_classes`` when set, otherwise the logit width. The
        detector's state is not modified. Without cached final-layer weights the
        max softmax probability is used as a rough stand-in for ``max cos θ``.
        """
        num_classes = self.num_classes if self.num_classes is not None else logits.size(1)

        # ||z|| per sample; assumes features are (N, feature_dim) -- TODO confirm
        # for backbones that return unflattened feature maps.
        feature_norms = torch.norm(features, dim=1)

        if num_classes <= 1:
            # A single class always gets probability 1, so U = -1 exactly; the
            # formula below would divide by K - 1 = 0.
            return -torch.ones_like(feature_norms)

        if self.final_layer_weights is not None:
            max_cos_theta = self._compute_max_cosine_alignment(features, self.final_layer_weights)
        else:
            # Fallback: approximate using softmax probabilities
            probs = F.softmax(logits, dim=1)
            max_cos_theta = torch.max(probs, dim=1)[0]

        # Mental model formula from equation (86)
        exponent = -feature_norms * (1.0 / (num_classes - 1) + max_cos_theta)
        return -1.0 / (1.0 + (num_classes - 1) * torch.exp(exponent))

    def _compute_max_cosine_alignment(self, features, weights):
        """
        Compute max_i cos(θ_{z,i}) - maximum cosine similarity between
        feature vectors (N, H) and the weight columns of ``weights`` (H, K).
        """
        # Cached weights may live on another device/dtype than fresh features.
        weights = weights.to(device=features.device, dtype=features.dtype)

        features_norm = F.normalize(features, dim=1)
        weights_norm = F.normalize(weights, dim=0)  # normalize each class column

        cosine_similarities = torch.mm(features_norm, weights_norm)  # (N, K)
        return torch.max(cosine_similarities, dim=1)[0]

    def compute_feature_statistics(self, features):
        """
        Compute feature statistics as described in Section 5.

        Returns a dict with mean/std/min/max of ||z|| and, when final-layer
        weights are cached, of max cos θ. Standard deviations use torch's
        default (unbiased) estimator, so a single sample yields NaN.
        """
        feature_norms = torch.norm(features, dim=1)

        stats = {
            'feature_norm_mean': torch.mean(feature_norms).item(),
            'feature_norm_std': torch.std(feature_norms).item(),
            'feature_norm_min': torch.min(feature_norms).item(),
            'feature_norm_max': torch.max(feature_norms).item()
        }

        if self.final_layer_weights is not None:
            max_cosines = self._compute_max_cosine_alignment(features, self.final_layer_weights)
            stats.update({
                'max_cosine_mean': torch.mean(max_cosines).item(),
                'max_cosine_std': torch.std(max_cosines).item(),
                'max_cosine_min': torch.min(max_cosines).item(),
                'max_cosine_max': torch.max(max_cosines).item()
            })

        return stats

    def detect(self, model, x, device=None):
        """Binary OOD detection: True where the score exceeds the fitted threshold.

        Raises:
            ValueError: if :meth:`set_threshold` has not set a threshold yet.
        """
        if self.tau is None:
            raise ValueError("Must call set_threshold() first")
        return self.score(model, x, device) > self.tau

    def set_threshold(self, model, buffer, device=None, false_positive_rate=0.2, batch_size=256):
        """Set ``self.tau`` from in-distribution data.

        ``tau`` is the (1 - false_positive_rate) quantile of the scores of the
        samples returned by ``buffer.get_all_data()``. A fraction
        ``false_positive_rate`` of those samples therefore scores above it.
        Leaves ``tau`` untouched when the buffer has no samples.

        Raises:
            ValueError: if ``false_positive_rate`` is outside [0, 1].
        """
        if not 0.0 <= false_positive_rate <= 1.0:
            raise ValueError(
                f"false_positive_rate must be in [0, 1], got {false_positive_rate}"
            )

        if device is None:
            device = self._get_device_from_model(model)

        X_id, _ = buffer.get_all_data()
        if X_id is None or len(X_id) == 0:
            return

        s_id = torch.cat([
            self.score(model, batch, device)
            for batch in self._iter_batches(X_id, device, batch_size)
        ])
        self.tau = torch.quantile(s_id, 1 - false_positive_rate).item()

    def analyze_model_structure(self, model, device=None):
        """
        Analyze the decision boundary structure of the model's final layer.

        ``device`` is accepted for interface compatibility; the analysis only
        reads weights and never runs the model.
        """
        actual_model = self._get_actual_model(model)
        final_layer = self._find_final_linear(actual_model)
        if final_layer is None:
            return {"error": "Could not find final linear layer"}

        biases = final_layer.bias.data if final_layer.bias is not None else None
        # nn.Linear stores (num_classes, feature_dim); transpose to the paper's (H x K).
        weights = final_layer.weight.data.T

        return self._analyze_decision_boundary_structure(weights, biases)

    def _analyze_decision_boundary_structure(self, weights, biases=None):
        """
        Analyze decision boundary structure following Section 4.2.
        Checks for optimal structure properties:
        1. ||w_i|| = constant (equal weight magnitudes)
        2. cos θ_{i,j} = -1/(K-1) (evenly distributed weights)
        3. Bias values ≈ 0

        ``weights`` is (H, K). ``is_structure_optimal`` only checks property 2.
        With fewer than two classes there are no pairs, so the pairwise stats
        report the (trivially met) target instead of NaN.
        """
        weights_np = weights.detach().cpu().numpy()
        num_classes = weights_np.shape[1]

        # 1. Compute weight magnitudes (one per class column)
        weight_norms = np.linalg.norm(weights_np, axis=0)

        # 2. Pairwise cosine similarities between class columns, diagonal removed
        weights_normalized = weights_np / (weight_norms + 1e-8)
        cosine_similarities = np.dot(weights_normalized.T, weights_normalized)
        pairwise_cosines = cosine_similarities[~np.eye(num_classes, dtype=bool)]

        # 3. Theoretical optimal cosine for evenly distributed weights
        optimal_cosine = -1.0 / (num_classes - 1) if num_classes > 1 else 0.0

        if pairwise_cosines.size:
            cosine_mean = np.mean(pairwise_cosines)
            cosine_std = np.std(pairwise_cosines)
        else:
            cosine_mean = optimal_cosine
            cosine_std = 0.0
        deviation = abs(cosine_mean - optimal_cosine)

        norm_mean = np.mean(weight_norms)
        norm_std = np.std(weight_norms)

        results = {
            'num_classes': num_classes,
            'weight_norm_mean': norm_mean,
            'weight_norm_std': norm_std,
            'weight_norm_uniformity': norm_std / (norm_mean + 1e-8),
            'pairwise_cosine_mean': cosine_mean,
            'pairwise_cosine_std': cosine_std,
            'optimal_cosine_target': optimal_cosine,
            'cosine_deviation_from_optimal': deviation,
            'structure_optimality_score': 1.0 / (1.0 + deviation),
            'is_structure_optimal': deviation < 0.1
        }

        if biases is not None:
            biases_np = biases.detach().cpu().numpy()
            results.update({
                'bias_mean': np.mean(biases_np),
                'bias_std': np.std(biases_np),
                'bias_magnitude': np.mean(np.abs(biases_np)),
                'biases_near_zero': np.mean(np.abs(biases_np)) < 0.1
            })

        return results

    def compute_valid_ood_region_size(self, model, device=None, sample_size=10000):
        """
        Estimate the size of the valid OOD region by sampling.

        Draws ``sample_size`` standard-normal feature vectors, projects them
        through the cached final-layer weights (bias is ignored) and summarizes
        their Umax scores. Uses the global torch RNG.
        """
        if device is None:
            device = self._get_device_from_model(model)

        if self.final_layer_weights is None:
            return {"error": "Need final layer weights for region analysis"}

        weights = self.final_layer_weights.to(device)
        feature_dim = weights.size(0)
        sample_points = torch.randn(sample_size, feature_dim, device=device, dtype=weights.dtype)

        with torch.no_grad():
            logits = torch.mm(sample_points, weights)

        uncertainty_scores = self._compute_umax(logits)

        mean_uncertainty = torch.mean(uncertainty_scores).item()
        std_uncertainty = torch.std(uncertainty_scores).item()

        return {
            'sample_size': sample_size,
            'mean_uncertainty': mean_uncertainty,
            'std_uncertainty': std_uncertainty,
            'high_uncertainty_fraction': (uncertainty_scores > mean_uncertainty + std_uncertainty).float().mean().item()
        }

    def get_method_info(self):
        """Return information about the current uncertainty method."""
        info = {
            'umax': 'Maximum predicted probability (negative) - Equation (2)',
            'uentropy': 'Prediction entropy - Equation (2)',
            'udensity': 'Gaussian Mixture Model density - Equation (3)',
            'mental_model': 'Mental model approximation - Proposition 6'
        }
        return {
            'method': self.method,
            'description': info.get(self.method, 'Unknown method'),
            'requires_fitting': self.method in ['udensity'],
            'threshold_set': self.tau is not None,
            'num_classes': self.num_classes,
            'has_weights': self.final_layer_weights is not None
        }

In [None]:
class SoftmaxDetector(SoftmaxOOD):
    """Backward-compatible stand-in for the original ``SoftmaxDetector``.

    The constructor takes no arguments and always selects the Umax score
    (negative maximum softmax probability). All other behaviour is inherited
    unchanged from ``SoftmaxOOD``.
    """

    def __init__(self):
        # Pin the method so existing call sites keep their original scoring.
        super().__init__("umax")
# Usage example. It is wrapped in a module-level string literal, so it is not
# executed when this cell runs; copy it out to try the different detectors.
'''
def enhanced_usage():
    """Demonstrate the enhanced capabilities."""
    
    # Create detectors with different methods
    detectors = {
        'umax': SoftmaxOOD(method='umax'),
        'uentropy': SoftmaxOOD(method='uentropy'), 
        'udensity': SoftmaxOOD(method='udensity'),
        'mental_model': SoftmaxOOD(method='mental_model')
    }
    
    print("Available uncertainty methods:")
    for name, detector in detectors.items():
        info = detector.get_method_info()
        print(f"- {name}: {info['description']}")
    
    return detectors
'''