In [31]:
import builtins
import copy
import glob
import io
import itertools
import json
import os
import random
import shutil
import time
import traceback
import zipfile
from collections import defaultdict
from contextlib import contextmanager, redirect_stdout, redirect_stderr
from dataclasses import dataclass
from datetime import datetime
from enum import Enum, auto
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.figure import Figure
from matplotlib.axes import Axes
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from google.colab import files
from scipy.ndimage import gaussian_filter
from scipy.stats import chi2, mannwhitneyu
from sklearn.model_selection import StratifiedShuffleSplit

In [None]:
ABSTENTION_WEIGHT = 12 # Scaling factor for the loss on invalid inputs where the model fails to abstain
INCORRECT_ABSTENTION_PENALTY = 25 # Loss penalty for abstaining on valid inputs
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy and PyTorch (CPU plus all CUDA devices) and
    switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed applied to all generators.

    Returns:
        The same `seed`, so the caller can store it.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    return seed
# Seed all RNGs once for the notebook. The module-level `seed` is also read as a
# global by ArithmeticDataset (noise generation and the train/validation split).
seed = set_seed(1)

In [None]:
class ArithmeticDataset(Dataset):
    """Torch dataset of two-argument arithmetic problems with optional Gaussian input noise.

    Each record in `data` is a dict with the keys 'Argument 1', 'Argument 2',
    'Operator' (one of '+', '-', '@') and 'Result'.
    """

    # Operator symbol -> index fed to the model's operator embedding.
    # The '@' operator means always abstain; it is used to test generalization
    # to novel operators.
    OP_MAP: Dict[str, int] = {'+': 0, '-': 1, '@': 2}

    def __init__(self, data: List[Dict[str, Any]],
                 noise_config: Optional[Dict[str, Any]] = None
                 ) -> None:
        """
        Args:
            data: Dataset of arithmetic and abstention problems
            noise_config: Dict with keys:
                - enabled: Boolean to enable/disable noise
                - std: Standard deviation of Gaussian noise
              When omitted, noise is disabled.
        """
        self.data = data
        # A fresh dict per instance: a dict literal as the parameter default would be
        # shared by every dataset created without an explicit config.
        self.noise_config = (
            noise_config if noise_config is not None
            else {'enabled': False, 'std': 0.0}
        )

    def __len__(self) -> int:
        """Number of problems in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Return `(numbers, operator, result)` tensors for the sample at `idx`.

        Returns:
            numbers: float32 tensor holding the two arguments (noisy if enabled)
            operator: scalar long tensor with the operator index from `OP_MAP`
            result: float32 tensor holding the single stored 'Result' value

        Raises:
            KeyError: if the sample's operator is not a key of `OP_MAP`.
        """
        sample: Dict[str, Any] = self.data[idx]
        numbers: Tensor = torch.tensor(
            [sample['Argument 1'], sample['Argument 2']],
            dtype=torch.float32,
        )

        if self.noise_config['enabled']:
            # Add random noise scaled by 'std'; aims to force the model to learn a more
            # complex/difficult decision boundary.
            # The generator is seeded per sample (global seed + index): the same sample
            # always receives the same perturbation, but different samples receive
            # different ones. Re-seeding with the bare global seed on every call would
            # add one identical offset to every sample. A private generator also leaves
            # the global RNG stream untouched.
            generator = torch.Generator()
            generator.manual_seed(seed + int(idx))
            noise: Tensor = torch.randn(
                numbers.shape, generator=generator, dtype=numbers.dtype
            ) * self.noise_config['std']
            numbers = numbers + noise

        operator: Tensor = torch.tensor(self.OP_MAP[sample['Operator']], dtype=torch.long)
        result: Tensor = torch.tensor([sample['Result']], dtype=torch.float32)

        return numbers, operator, result

    @staticmethod
    def _stratification_key(item: Dict[str, Any]) -> str:
        """Build the '<operator>_<category>' label used to stratify the train/validation split."""
        op_type: str = item['Operator']

        if op_type == '@':
            category = 'invalid'
        elif op_type == '+':
            # The category of all of these should maybe be invalid, overflow and
            # underflow don't really add much
            category = 'overflow' if item['Argument 1'] + item['Argument 2'] > 400 else 'valid'
        elif op_type == '-':
            category = 'underflow' if item['Argument 1'] - item['Argument 2'] < 0 else 'valid'
        else:
            # Without this branch an unknown operator would silently reuse the category
            # of the previous item (or fail with a NameError on the first item).
            raise ValueError(
                f"Unsupported operator {op_type!r}; expected one of '+', '-', '@'"
            )

        return f"{op_type}_{category}"

    @staticmethod
    def get_train_val_test_loaders(
        dataset_dict: Dict[str, List[Dict[str, Any]]],
        batch_size: int = 32,
        val_ratio: float = 0.1,
        noise_config: Optional[Dict[str, Any]] = None
    ) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """
        Create loaders for train, validation, and test sets.

        The validation set is carved out of the 'train' split with a stratified
        shuffle split keyed on operator and validity category; the 'test' split is
        used as-is. The same noise configuration is applied to all three datasets.

        Args:
            dataset_dict: Dictionary containing 'train' and 'test' splits of dataset
            batch_size: Batch size for dataloaders
            val_ratio: Ratio of training data to use for validation
            noise_config: Optional noise configuration dictionary

        Returns:
            Tuple of (train_loader, val_loader, test_loader). Only the training
            loader shuffles.

        Raises:
            ValueError: if a training sample uses an operator other than '+', '-', '@'.
        """
        # Set default noise config if none provided
        noise_config = noise_config or {'enabled': False, 'std': 0.0}

        # Keep test set separate
        test_data: List[Dict[str, Any]] = dataset_dict['test']
        train_full: List[Dict[str, Any]] = dataset_dict['train']

        # Create stratification keys for training data
        stratification_keys: List[str] = [
            ArithmeticDataset._stratification_key(item) for item in train_full
        ]

        # Get indices for training and validation split
        split: StratifiedShuffleSplit = StratifiedShuffleSplit(
            n_splits=1, test_size=val_ratio, random_state=seed
        )
        indices: np.ndarray = np.arange(len(train_full))
        train_idx, val_idx = next(split.split(indices, stratification_keys))

        # Build datasets using indices
        train_data: List[Dict[str, Any]] = [train_full[i] for i in train_idx]
        val_data: List[Dict[str, Any]] = [train_full[i] for i in val_idx]

        # Create datasets with noise config
        train_dataset: ArithmeticDataset = ArithmeticDataset(train_data, noise_config=noise_config)
        val_dataset: ArithmeticDataset = ArithmeticDataset(val_data, noise_config=noise_config)
        test_dataset: ArithmeticDataset = ArithmeticDataset(test_data, noise_config=noise_config)

        # Create loaders
        train_loader: DataLoader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader: DataLoader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        test_loader: DataLoader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        return train_loader, val_loader, test_loader


def is_invalid_computation(numbers: Tensor, operator: Tensor) -> Tensor:
    """Flag the samples whose computation is invalid.

    Args:
        numbers: Tensor indexed as `numbers[:, 0]` / `numbers[:, 1]` for the two arguments
        operator: Operator indices (0 = addition, 1 = subtraction, 2 = '@')

    Returns:
        Boolean tensor shaped like `operator`; True where the sample is invalid.
    """
    first, second = numbers[:, 0], numbers[:, 1]

    # Addition is invalid when the sum exceeds 400.
    overflow = (operator == 0) & (first + second > 400)
    # Subtraction is invalid when the difference drops below 0.
    underflow = (operator == 1) & (first - second < 0)
    # The @ operator is invalid regardless of its arguments.
    always_invalid = operator == 2

    return overflow | underflow | always_invalid

def create_boundary_test_loader(base_loader, num_samples=1000):
    """
    Creates a test loader with a specific distribution of cases:
    - 40% invalid cases (overflow/underflow), generated randomly
    - 40% boundary cases (valid but within 50 of a limit), taken from the base loader
    - 20% easy cases (valid and far from boundaries), taken from the base loader

    Valid cases are taken in dataset order from `base_loader.dataset.data`; only '+'
    and '-' samples are considered, and their 'Result' is recomputed from the
    arguments on a copy (the base data is not modified).

    If the base data holds too few boundary cases, half of the shortfall (rounded
    down) is filled with extra easy cases; whatever is still missing afterwards is
    filled with additional generated invalid cases, so the loader always holds
    `num_samples` samples.

    Args:
        base_loader: Original data loader; its dataset must expose `.data` and
            `.noise_config`, and its `batch_size` is reused
        num_samples: Number of samples to include

    Returns:
        A shuffling DataLoader over an ArithmeticDataset built from the selected
        samples, with the same noise config and batch size as `base_loader`.
    """

    def generate_invalid_case():
        """Generate a case that crosses boundaries"""
        op = random.choice(['+', '-'])
        if op == '+':
            # Generate sum > 400
            result = random.randint(401, 500)
            arg2 = random.randint(1, 200)
            arg1 = result - arg2
        else:
            # Generate difference < 0
            # NOTE(review): arg1 = result + arg2 can be negative here (as low as -99);
            # confirm negative arguments are inside the intended input domain.
            result = random.randint(-100, -1)
            arg2 = random.randint(1, 200)
            arg1 = result + arg2

        return {
            'Argument 1': arg1,
            'Argument 2': arg2,
            'Operator': op,
            'Result': result
        }

    def is_boundary_case(sample):
        """Check if a valid '+'/'-' sample is near the boundary"""
        if sample['Operator'] == '+':
            return 350 <= sample['Result'] <= 400  # Within 50 of overflow boundary
        return 0 <= sample['Result'] <= 50  # Within 50 of underflow boundary

    # Calculate desired numbers of each type
    num_invalid = int(num_samples * 0.4)
    num_boundary = int(num_samples * 0.4)
    num_easy = num_samples - num_invalid - num_boundary

    # At most half of the boundary quota can be re-assigned to easy cases, so that is
    # the largest number of spare easy samples worth collecting.
    max_easy = num_easy + num_boundary // 2

    # Collect boundary and easy cases from base loader
    boundary_samples = []
    easy_samples = []

    for raw in base_loader.dataset.data:
        op = raw['Operator']
        if op == '+':
            result = raw['Argument 1'] + raw['Argument 2']
            if result > 400:
                continue  # overflow -> not a valid case
        elif op == '-':
            result = raw['Argument 1'] - raw['Argument 2']
            if result < 0:
                continue  # underflow -> not a valid case
        else:
            # '@' (or any other operator) is never a valid computation; treating it as
            # a subtraction would overwrite its 'Result' and count it as valid.
            continue

        sample = copy.deepcopy(raw)
        sample['Result'] = result

        if is_boundary_case(sample):
            if len(boundary_samples) < num_boundary:
                boundary_samples.append(sample)
        elif len(easy_samples) < max_easy:
            easy_samples.append(sample)

        # Spare easy samples are only needed when boundary cases run short, so stop
        # as soon as both regular quotas are met.
        if (len(boundary_samples) >= num_boundary and
            len(easy_samples) >= num_easy):
            break

    # Handle cases where we don't have enough samples
    missing_boundary = num_boundary - len(boundary_samples)
    if missing_boundary > 0:
        print(f"Warning: Only found {len(boundary_samples)} boundary samples")
        num_easy += missing_boundary // 2

    if len(easy_samples) < num_easy:
        print(f"Warning: Only found {len(easy_samples)} easy samples")
    easy_samples = easy_samples[:num_easy]

    # Invalid cases absorb every remaining slot. They are generated only now, once
    # the real number of valid samples is known, so the total is always num_samples.
    num_invalid = num_samples - len(boundary_samples) - len(easy_samples)
    invalid_samples = [generate_invalid_case() for _ in range(num_invalid)]

    # Combine all samples
    selected_samples = invalid_samples + boundary_samples + easy_samples
    random.shuffle(selected_samples)

    # Create dataset with same noise config as base loader
    boundary_dataset = ArithmeticDataset(
        data=selected_samples,
        noise_config=base_loader.dataset.noise_config
    )

    # Create and return loader
    return DataLoader(
        boundary_dataset,
        batch_size=base_loader.batch_size,
        shuffle=True
    )

@contextmanager # Disable printing metrics for faster training
def suppress_output():
    """Silence `print`, stdout and stderr for the duration of the `with` block.

    While active, `builtins.print` is a no-op, `sys.stdout` / `sys.stderr` are
    redirected to the OS null device, and the environment variables
    PYTHONUNBUFFERED and TQDM_DISABLE are set to '0' and '1'. Everything is put
    back on exit, including when the body raises.

    NOTE(review): changing these environment variables inside a running process
    presumably only affects code that reads them afterwards (e.g. child processes)
    -- confirm that they are still needed.
    """
    env_names = ('PYTHONUNBUFFERED', 'TQDM_DISABLE')
    saved_env = {name: os.environ.get(name) for name in env_names}
    original_print = builtins.print  # Save the original print function

    try:
        os.environ['PYTHONUNBUFFERED'] = '0'
        os.environ['TQDM_DISABLE'] = '1'

        builtins.print = lambda *args, **kwargs: None  # Replace print with no-op

        # The null device really discards writes; an in-memory buffer would keep
        # growing for as long as the suppressed code produces output.
        with open(os.devnull, 'w') as sink, redirect_stdout(sink), redirect_stderr(sink):
            yield

    finally:
        builtins.print = original_print  # Restore the original print function

        # Restore the environment variables. `is None` (not truthiness) so that a
        # variable that was set to an empty string is restored rather than deleted.
        for name, previous in saved_env.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous

In [None]:
class ArithmeticNet(nn.Module): # Core architecture of the experiment, intentionally very simple to save on compute and minimize complexity.
    """Small MLP mapping two numbers and an operator index to a single scalar.

    The numbers go through a linear embedding and the operator through an
    embedding table with three entries; both are concatenated and passed through
    three linear layers (with one residual connection) and a final one-unit
    output layer. There is no separate abstention head.
    """

    def __init__(self, hidden_size: int = 128) -> None: # 128 Neurons
        """
        Args:
            hidden_size: Width of both embeddings and of the first two core layers;
                the third core layer is half as wide.
        """
        super().__init__()
        half_width: int = hidden_size // 2

        # Input embeddings: the number pair and the operator index.
        self.num_embedder: nn.Linear = nn.Linear(2, hidden_size)
        self.op_embedding: nn.Embedding = nn.Embedding(3, hidden_size)

        # Core stack; layer1 consumes the concatenated embeddings.
        self.layer1: nn.Linear = nn.Linear(2 * hidden_size, hidden_size)
        self.layer2: nn.Linear = nn.Linear(hidden_size, hidden_size)
        self.layer3: nn.Linear = nn.Linear(hidden_size, half_width)

        # Single scalar output.
        self.output: nn.Linear = nn.Linear(half_width, 1)

    def forward(self, numbers: Tensor, operator: Tensor) -> Tensor:
        """Compute the model output for a batch of `(numbers, operator)` inputs."""
        embedded: Tensor = torch.cat(
            [self.num_embedder(numbers), self.op_embedding(operator)],
            dim=1,
        )

        hidden: Tensor = F.relu(self.layer1(embedded))
        # Residual connection around layer2 for better gradient flow.
        hidden = hidden + F.relu(self.layer2(hidden))

        # layer3 feeds the output layer directly, with no activation in between.
        return self.output(self.layer3(hidden))

class TrainingMode(Enum):
    """
    Enum defining different training modes.

    NORMAL: Standard training with full validation and analysis
    FAST: Optimized for speed with reduced validation frequency
    HYBRID: Transitions from normal to fast training after initial epochs
    """
    NORMAL = auto()
    FAST = auto()
    HYBRID = auto()

@dataclass
class TrainingConfig:
    """
    Unified configuration for managing training parameters and optimization settings.

    This class handles both static and dynamic training configurations, supporting
    normal, fast, and hybrid training modes. In hybrid mode, it automatically
    transitions from normal to fast training after a specified number of epochs.

    Two configurations compare equal when all attributes listed below are equal.

    Attributes:
        mode (TrainingMode): Training mode (NORMAL, FAST, or HYBRID)
        batch_size (int): Base batch size for training
        grad_accum_steps (int): Number of steps for gradient accumulation
        val_freq (int): Frequency of validation during training
        landscape_freq (int): Frequency of loss landscape analysis
        initial_epochs (Optional[int]): Number of initial epochs before switching to fast training in hybrid mode
        current_epoch (int): Current training epoch (used for hybrid mode transitions)

    Properties:
        is_fast_training (bool): Whether fast training optimizations are currently active
        effective_batch_size (int): Actual batch size after applying fast training multiplier
        effective_grad_accum_steps (int): Actual gradient accumulation steps after mode adjustments
        effective_val_freq (int): Actual validation frequency after mode adjustments
    """

    # Field declarations for @dataclass. With no declared fields the generated
    # __eq__ compares empty tuples, which makes every pair of configs equal.
    # The hand-written __init__ and __repr__ below take precedence over the
    # generated ones.
    mode: TrainingMode
    batch_size: int
    grad_accum_steps: int
    val_freq: int
    landscape_freq: int
    initial_epochs: Optional[int]
    current_epoch: int

    def __init__(
        self,
        mode: TrainingMode = TrainingMode.NORMAL,
        *,
        batch_size: int = 32,
        grad_accum_steps: int = 4,
        val_freq: int = 1,
        landscape_freq: int = 40,
        initial_epochs: Optional[int] = None
    ) -> None:
        """
        Initialize training configuration.

        Args:
            mode: Training mode to use
            batch_size: Base batch size
            grad_accum_steps: Base gradient accumulation steps
            val_freq: Base validation frequency
            landscape_freq: Frequency of loss landscape analysis
            initial_epochs: Required for HYBRID mode, specifies transition point;
                ignored (stored as None) in the other modes

        Raises:
            ValueError: if HYBRID mode is requested without `initial_epochs`, or if
                any of the numeric settings is not positive.
        """
        self.mode = mode
        self.batch_size = batch_size
        self.grad_accum_steps = grad_accum_steps
        self.val_freq = val_freq
        self.landscape_freq = landscape_freq
        self.current_epoch = 0

        # Hybrid mode specific settings
        if mode == TrainingMode.HYBRID:
            if initial_epochs is None:
                raise ValueError("initial_epochs must be specified for HYBRID mode")
            self.initial_epochs = initial_epochs
        else:
            self.initial_epochs = None

        # Validate configuration
        self._validate_config()

    def _validate_config(self) -> None:
        """Validate configuration parameters; raises ValueError on the first non-positive one."""
        if self.batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if self.grad_accum_steps <= 0:
            raise ValueError("grad_accum_steps must be positive")
        if self.val_freq <= 0:
            raise ValueError("val_freq must be positive")
        if self.landscape_freq <= 0:
            raise ValueError("landscape_freq must be positive")

    @property
    def is_fast_training(self) -> bool:
        """
        Determine if fast training optimizations should be active.

        Returns:
            True if in FAST mode or if in HYBRID mode past initial epochs
        """
        if self.mode == TrainingMode.FAST:
            return True
        if self.mode == TrainingMode.HYBRID:
            return self.current_epoch >= self.initial_epochs
        return False

    @property
    def effective_batch_size(self) -> int:
        """
        Get the current effective batch size.

        Returns:
            Twice the base batch size while fast training is active, otherwise the
            base batch size
        """
        return self.batch_size * 2 if self.is_fast_training else self.batch_size

    @property
    def effective_grad_accum_steps(self) -> int:
        """
        Get the current effective gradient accumulation steps.

        Returns:
            The configured accumulation steps while fast training is active,
            otherwise 1 (no accumulation)
        """
        return self.grad_accum_steps if self.is_fast_training else 1

    @property
    def effective_val_freq(self) -> int:
        """
        Get the current effective validation frequency.

        Returns:
            2 while fast training is active, otherwise the configured `val_freq`
        """
        return 2 if self.is_fast_training else self.val_freq

    def update_epoch(self, epoch: int) -> None:
        """
        Update the current epoch counter; the `effective_*` properties follow it.

        Args:
            epoch: New epoch number

        Raises:
            ValueError: if `epoch` is negative.
        """
        if epoch < 0:
            raise ValueError("epoch must be non-negative")
        self.current_epoch = epoch

    def get_config_dict(self) -> dict:
        """
        Get current configuration as a dictionary.

        Returns:
            Dictionary containing all current effective settings
        """
        return {
            'mode': self.mode.name,
            'is_fast_training': self.is_fast_training,
            'batch_size': self.effective_batch_size,
            'grad_accum_steps': self.effective_grad_accum_steps,
            'val_freq': self.effective_val_freq,
            'landscape_freq': self.landscape_freq,
            'current_epoch': self.current_epoch,
            'initial_epochs': self.initial_epochs
        }

    def __repr__(self) -> str:
        """Provide a detailed string representation of the configuration."""
        return (
            f"TrainingConfig(mode={self.mode.name}, "
            f"is_fast_training={self.is_fast_training}, "
            f"effective_batch_size={self.effective_batch_size}, "
            f"effective_grad_accum_steps={self.effective_grad_accum_steps}, "
            f"effective_val_freq={self.effective_val_freq}, "
            f"landscape_freq={self.landscape_freq}, "
            f"current_epoch={self.current_epoch}"
            f"{f', initial_epochs={self.initial_epochs}' if self.mode == TrainingMode.HYBRID else ''})"
        )
class UnifiedLossComputer:
    """
    Unified class for computing losses across training and evaluation.

    This class consolidates loss computation logic used by trainers and analyzers,
    providing consistent behavior and configuration across the codebase.

    Attributes:
        ABSTENTION_VALUE: Output value (-1.0) used to indicate model abstention
        abstention_weight: Weight applied to abstention loss term
        incorrect_abstention_penalty: Penalty for abstaining on valid inputs
        rtol: Relative tolerance for abstention detection
        atol: Absolute tolerance for abstention detection
    """

    # Abstention token: the value the model is trained to output on invalid inputs.
    ABSTENTION_VALUE: float = -1.0

    def __init__(
        self,
        abstention_weight: float = 12.0, # 12, 16, or 25
        incorrect_abstention_penalty: float = 25.0, # 25 or 12
        rtol: float = 0.1,
        atol: float = 0.1
    ) -> None:
        self.abstention_weight = abstention_weight
        self.incorrect_abstention_penalty = incorrect_abstention_penalty
        self.rtol = rtol
        self.atol = atol

    def is_abstention(self, predictions: Tensor) -> Tensor:
        """
        Check if predictions indicate abstention.

        Args:
            predictions: Model output tensor

        Returns:
            1-D boolean tensor with one flag per prediction; True where the
            prediction is within (`rtol`, `atol`) of `ABSTENTION_VALUE`
        """
        # Flatten instead of squeeze(): squeeze() turns a batch of one into a 0-dim
        # tensor that can no longer be combined with, or indexed by, per-sample masks.
        flat = predictions.reshape(-1)
        return torch.isclose(
            flat,
            # full_like keeps the reference value on the predictions' device and dtype.
            torch.full_like(flat, self.ABSTENTION_VALUE),
            rtol=self.rtol,
            atol=self.atol
        )

    def compute_loss(
        self,
        predictions: Tensor,
        targets: Tensor,
        numbers: Tensor,
        operator: Tensor,
        reduce: bool = True
    ) -> Tuple[Tensor, Dict[str, float]]:
        """
        Compute the total loss with abstention incentives and penalties.

        With `reduce=True` the loss is
            mean MSE over valid samples
            + abstention_weight * mean MSE to the abstention value over invalid samples
            + incorrect_abstention_penalty * (number of valid samples abstained on).

        With `reduce=False` a per-sample loss vector is returned instead: squared
        error to the target for valid samples, weighted squared error to the
        abstention value for invalid samples, plus the penalty on each valid sample
        the model abstained on.

        The penalty term is derived from a boolean count, so it changes the loss
        value but contributes no gradient.

        Args:
            predictions: Model predictions
            targets: Ground truth targets
            numbers: Input numbers for computation
            operator: Operator tokens
            reduce: Whether to reduce the loss to a scalar

        Returns:
            Tuple of (total_loss, component_dict) where component_dict contains
            the scalar (reduced) value of each term regardless of `reduce`:
            - valid_loss: Loss on valid computations
            - invalid_loss: Loss on invalid computations
            - abstention_loss: Penalty for incorrect abstentions
        """
        # Work on flat per-sample vectors so a batch of one behaves like any other.
        preds = predictions.reshape(-1)
        tgts = targets.reshape(-1)

        # Get invalid computation mask
        invalid_mask = is_invalid_computation(numbers, operator).reshape(-1)
        valid_mask = ~invalid_mask
        has_valid = bool(valid_mask.any())
        has_invalid = bool(invalid_mask.any())

        # Compute MSE loss for valid computations
        valid_loss = preds.new_zeros(())
        if has_valid:
            valid_loss = F.mse_loss(preds[valid_mask], tgts[valid_mask])

        # Compute abstention loss for invalid computations
        invalid_loss = preds.new_zeros(())
        if has_invalid:
            invalid_preds = preds[invalid_mask]
            abstention_targets = torch.full_like(invalid_preds, self.ABSTENTION_VALUE)
            invalid_loss = self.abstention_weight * F.mse_loss(
                invalid_preds, abstention_targets
            )

        # Penalty for incorrect abstentions on valid computations
        wrongly_abstained = valid_mask & self.is_abstention(preds)
        abstention_loss = self.incorrect_abstention_penalty * wrongly_abstained.sum().float()

        components = {
            'valid_loss': valid_loss.item(),
            'invalid_loss': invalid_loss.item(),
            'abstention_loss': abstention_loss.item(),
        }

        if reduce:
            return valid_loss + invalid_loss + abstention_loss, components

        # Per-sample losses, aligned with the flattened predictions.
        per_sample = torch.zeros_like(preds)
        if has_valid:
            per_sample[valid_mask] = F.mse_loss(
                preds[valid_mask], tgts[valid_mask], reduction='none'
            )
        if has_invalid:
            invalid_preds = preds[invalid_mask]
            per_sample[invalid_mask] = self.abstention_weight * F.mse_loss(
                invalid_preds,
                torch.full_like(invalid_preds, self.ABSTENTION_VALUE),
                reduction='none'
            )
        per_sample = per_sample + (
            self.incorrect_abstention_penalty * wrongly_abstained.to(per_sample.dtype)
        )
        return per_sample, components

    def compute_metrics(
        self,
        predictions: Tensor,
        targets: Tensor,
        numbers: Tensor,
        operator: Tensor
    ) -> Dict[str, float]:
        """
        Compute comprehensive metrics for model predictions.

        Args:
            predictions: Model predictions
            targets: Ground truth targets
            numbers: Input numbers
            operator: Operator tokens

        Returns:
            Dictionary containing metrics:
            - loss: Total (reduced) loss value
            - abstention_rate: Fraction of all samples the model abstained on
            - valid_accuracy: Fraction of valid, non-abstained samples predicted
              within 1% of the target; only present if such samples exist
            - correct_abstentions: Fraction of all samples that are invalid and
              abstained on; only present if the model abstained at least once
            - incorrect_abstentions: Fraction of all samples that are valid and
              abstained on; only present if the model abstained at least once
        """
        preds = predictions.reshape(-1)
        tgts = targets.reshape(-1)

        invalid_mask = is_invalid_computation(numbers, operator).reshape(-1)
        valid_mask = ~invalid_mask
        abstained = self.is_abstention(preds)

        # Calculate loss
        loss, _ = self.compute_loss(predictions, targets, numbers, operator)

        metrics = {
            'loss': loss.item(),
            'abstention_rate': abstained.float().mean().item()
        }

        # Calculate accuracy on valid predictions
        valid_and_not_abstained = valid_mask & (~abstained)
        if valid_and_not_abstained.any():
            pred_vals = preds[valid_and_not_abstained]
            tgt_vals = tgts[valid_and_not_abstained]
            diff = torch.abs(pred_vals - tgt_vals)
            # 1% relative tolerance; the epsilon keeps the threshold positive for a zero target.
            threshold = torch.abs(tgt_vals) * 0.01 + 1e-8
            correct_valid = (diff < threshold).sum().item()
            total_valid = valid_and_not_abstained.sum().item()
            metrics['valid_accuracy'] = correct_valid / total_valid

        # Calculate abstention metrics
        if abstained.any():
            metrics['correct_abstentions'] = (abstained & invalid_mask).float().mean().item()
            metrics['incorrect_abstentions'] = (abstained & valid_mask).float().mean().item()

        return metrics

In [None]:
class LandscapeAnalyzer:
    """
    Analyzes the loss landscape of neural networks during training, focusing on measures of curvature.

    The abstention token is not configurable on this class: the abstention-aware
    surface evaluation compares predictions against the hard-coded value -1.0.

    Attributes:
        abstention_weight (float): Weight applied to abstention loss term
        model (Module): PyTorch model to analyze
        alpha (float): Step size for parameter perturbations
        num_samples (int): Number of samples to use in analysis
        grid_size (int): Resolution of grid for landscape visualization
        alpha_range (float): Range of alpha values to explore
        save_dir (Optional[Path]): Directory to save analysis results
        fixed_scale (float): Fixed scale factor for visualization
        loss_computer (UnifiedLossComputer): Loss helper constructed with abstention_weight
        metrics_history (defaultdict): History of computed metrics

    Args:
        model (Module): The neural network model to analyze
        alpha (float, optional): Step size for parameter perturbations. Defaults to 0.1
        num_samples (int, optional): Number of samples to use in analysis. Defaults to 100
        grid_size (int, optional): Resolution of grid for visualization. Defaults to 20
        alpha_range (float, optional): Range of alpha values to explore. Defaults to 0.5
        save_dir (Optional[str], optional): Directory to save results. Defaults to None
        fixed_scale (float, optional): Fixed scale factor for visualization. Defaults to 7
        abstention_weight (float, optional): Weight for abstention loss.
            Defaults to the module-level ABSTENTION_WEIGHT constant
    """

    def __init__(
        self,
        model: nn.Module,
        alpha: float = 0.1,
        num_samples: int = 100,
        grid_size: int = 20,
        alpha_range: float = 0.5,
        save_dir: Optional[str] = None,
        fixed_scale: float = 7,
        abstention_weight: float = ABSTENTION_WEIGHT
    ) -> None:
        """Record the analysis settings and build the loss helper.

        ``save_dir`` is converted to a ``Path`` when truthy and stored as
        ``None`` otherwise; every other argument is stored unchanged.
        """
        # Model under analysis and perturbation settings
        self.model = model
        self.alpha = alpha
        self.alpha_range = alpha_range
        self.num_samples = num_samples
        self.grid_size = grid_size
        self.fixed_scale = fixed_scale

        # Output location (None disables saving)
        self.save_dir = None if not save_dir else Path(save_dir)

        # Loss configuration shared by every landscape metric
        self.abstention_weight = abstention_weight
        self.loss_computer = UnifiedLossComputer(abstention_weight=abstention_weight)

        self.metrics_history: defaultdict = defaultdict(list)

    def compute_valley_asymmetry(
        self,
        batch: Tuple[Tensor, Tensor, Tensor],
        num_directions: int = 10
    ) -> float:
        """
        Probe the loss along random directions on both sides of the current
        parameters and report the largest imbalance found.

        For each direction d the method evaluates the loss at θ + α·d and
        θ − α·d (α is ``self.alpha``) and computes

            |(L(θ + α·d) − L(θ)) − (L(θ) − L(θ − α·d))|

        which is algebraically |L(θ + α·d) + L(θ − α·d) − 2·L(θ)|, i.e. the
        magnitude of the central second difference along d.

        Args:
            batch: Batch forwarded unchanged to ``self.compute_loss``
            num_directions: Number of directions requested from
                ``self.get_random_direction``. Defaults to 10

        Returns:
            float: Largest value of the expression above over all directions
            (0.0 when ``num_directions`` is 0)

        Note:
            - Model parameters are modified temporarily and are restored even
              if a loss evaluation raises.
            - All losses are evaluated under ``torch.no_grad()``; no autograd
              graph is built.
        """
        # Detached snapshot: restoring must not drag an autograd graph along
        original_params = {
            name: param.detach().clone()
            for name, param in self.model.named_parameters()
        }

        def load(sign: float, direction: Dict[str, Tensor]) -> None:
            """Write θ₀ + sign·α·direction into the model parameters."""
            for name, param in self.model.named_parameters():
                param.data.copy_(original_params[name] + sign * self.alpha * direction[name])

        max_asymmetry = 0.0
        try:
            with torch.no_grad():
                original_loss = self.compute_loss(batch)

                for _ in range(num_directions):
                    direction = self.get_random_direction()

                    load(1.0, direction)
                    pos_loss = self.compute_loss(batch)

                    load(-1.0, direction)
                    neg_loss = self.compute_loss(batch)

                    # Back to θ₀ before the next direction is drawn
                    load(0.0, direction)

                    asymmetry = abs(pos_loss - original_loss - (original_loss - neg_loss))
                    max_asymmetry = max(max_asymmetry, asymmetry.item())
        finally:
            # Always leave the model exactly as we found it
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.data.copy_(original_params[name])

        return max_asymmetry
    def compute_alpha_sharpness(self, batch: Tuple[Tensor, Tensor, Tensor]) -> float:
        """Estimate sharpness as the largest loss increase under random parameter noise.

        The textbook α-sharpness is

            max_{||δ||₂ ≤ α} [L(θ + δ) - L(θ)]

        This method does NOT enforce that norm constraint. It draws
        ``self.num_samples`` perturbations whose entries are i.i.d. Gaussian
        with standard deviation ``self.alpha`` (so the overall L2 norm of δ
        grows with the number of parameters) and returns the largest loss
        increase observed. Treat the result as a stochastic proxy.

        Args:
            batch: Batch forwarded unchanged to ``self.compute_loss``

        Returns:
            float: Largest value of L(θ + δ) - L(θ) over the sampled
            perturbations, floored at 0.0

        Note:
            Model parameters are modified temporarily and are restored even if
            a loss evaluation raises. Losses are evaluated under
            ``torch.no_grad()``.
        """
        # Detached snapshot so restoring never involves an autograd graph
        original_params = {
            name: param.detach().clone()
            for name, param in self.model.named_parameters()
        }

        max_loss_diff = 0.0
        try:
            with torch.no_grad():
                original_loss = self.compute_loss(batch)

                for _ in range(self.num_samples):
                    # θ₀ + Gaussian noise (std = alpha), one draw per parameter tensor
                    for name, param in self.model.named_parameters():
                        delta = torch.randn_like(param) * self.alpha
                        param.data.copy_(original_params[name] + delta)

                    perturbed_loss = self.compute_loss(batch)
                    max_loss_diff = max(max_loss_diff, (perturbed_loss - original_loss).item())
        finally:
            # Always leave the model exactly as we found it
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.data.copy_(original_params[name])

        return max_loss_diff

    def compute_multiscale_sharpness(
        self,
        batch: Tuple[Tensor, Tensor, Tensor],
        scales: Optional[List[float]] = None
    ) -> Dict[str, float]:
        """Run ``compute_alpha_sharpness`` once per perturbation scale.

        ``self.alpha`` is temporarily replaced by each entry of ``scales`` and
        is restored afterwards, even if one of the evaluations raises.

        Args:
            batch: Batch forwarded unchanged to ``compute_alpha_sharpness``
            scales: α values to measure sharpness at. ``None`` (the default)
                means ``[0.1, 0.01, 0.001]``

        Returns:
            Dict[str, float]: Maps ``f'alpha_{scale}'`` to the sharpness
            measured at that scale, in the order the scales were given
        """
        if scales is None:
            scales = [0.1, 0.01, 0.001]

        original_alpha = self.alpha
        sharpness_values = {}
        try:
            for scale in scales:
                self.alpha = scale
                sharpness_values[f'alpha_{scale}'] = self.compute_alpha_sharpness(batch)
        finally:
            # Never leak a temporary scale into later analyses
            self.alpha = original_alpha

        return sharpness_values

    def get_hessian_vector_product(
        self,
        batch: Tuple[Tensor, Tensor, Tensor],
        vector: Tensor,
        num_power_iterations: int = 10
    ) -> Tensor:
        """Compute the Hessian-vector product (Hv) using automatic differentiation.

        Implements the calculation:
        Hv = ∇²L(θ)v = ∇(∇L(θ)ᵀv)

        where L is ``self.compute_loss(batch)``, θ are the model parameters in
        ``self.model.parameters()`` order, and v is ``vector``.

        Args:
            batch: Batch forwarded unchanged to ``self.compute_loss``
            vector: Flat vector with one entry per model parameter element
            num_power_iterations: Unused; kept for interface compatibility

        Returns:
            Tensor: The Hessian-vector product as a flattened tensor. When the
            gradient does not depend on the parameters (the Hessian is zero for
            this batch), a zero vector is returned.

        Note:
            - The Hessian is never formed explicitly.
            - Side effects: calls ``self.model.zero_grad()`` and sets
              ``requires_grad=True`` on every model parameter.
        """
        self.model.zero_grad()

        # Materialise once; parameters() returns a fresh generator on each call
        params = list(self.model.parameters())
        for param in params:
            param.requires_grad_(True)

        loss = self.compute_loss(batch)

        # create_graph=True keeps the graph so the gradient can be differentiated again
        grads = torch.autograd.grad(
            loss,
            params,
            allow_unused=True,  # Some parameters might not influence loss
            create_graph=True
        )

        # Parameters that do not influence the loss contribute a zero gradient
        flat_grad = torch.cat([
            (torch.zeros_like(p) if g is None else g).flatten()
            for g, p in zip(grads, params)
        ])
        grad_vector_product = torch.dot(flat_grad, vector)

        # A gradient that is constant in θ has no graph to differentiate:
        # autograd.grad would raise, while the true answer is simply H·v = 0.
        if not grad_vector_product.requires_grad:
            return torch.zeros_like(flat_grad)

        hvp = torch.autograd.grad(
            grad_vector_product,
            params,
            allow_unused=True
        )

        return torch.cat([
            (torch.zeros_like(p) if h is None else h).flatten()
            for h, p in zip(hvp, params)
        ])

    def estimate_top_k_eigenvalues(
        self,
        batch: Tuple[Tensor, Tensor, Tensor],
        k: int = 3,
        num_power_iterations: int = 10
    ) -> List[float]:
        """Estimate k Hessian eigenvalues with power iteration plus deflation.

        For each of the k rounds the method repeats

            v ← Hv / ||Hv||

        (with projections onto previously found vectors removed from Hv) and
        then reports the Rayleigh quotient λ = vᵀHv. H·v is obtained from
        ``self.get_hessian_vector_product``.

        Args:
            batch: Tuple of (numbers, operator, targets) tensors
            k: Number of eigenvalues to estimate
            num_power_iterations: Power-iteration steps per eigenvalue

        Returns:
            List[float]: One Rayleigh-quotient estimate per round, in the order
            the rounds were run

        Note:
            - Every round starts from the same vector: it is drawn with the
              module-level ``seed`` inside ``torch.random.fork_rng()``, so the
              global RNG state is left untouched.
            - A round stops early (with a printed warning) when the deflated
              H·v is not larger than 1e-10 in norm.
            - A NaN Rayleigh quotient is reported as 0.0 (with a printed warning).
        """
        device = next(self.model.parameters()).device
        found_values = []
        found_vectors = []

        for index in range(k):
            # Fixed-seed unit start vector
            with torch.random.fork_rng():
                torch.manual_seed(seed)
                dim = sum(p.numel() for p in self.model.parameters())
                vector = torch.randn(dim).to(device)
                vector = vector / torch.norm(vector)

            for _ in range(num_power_iterations):
                candidate = self.get_hessian_vector_product(batch, vector)

                # Deflation: strip components along earlier eigenvectors
                for earlier in found_vectors:
                    candidate = candidate - torch.dot(candidate, earlier) * earlier

                magnitude = torch.norm(candidate)
                # Written as "not >" so a NaN norm also takes the warning path
                if not (magnitude > 1e-10):
                    print("\nWarning: Near-zero vector in power iteration")
                    break
                vector = candidate / magnitude

            # Rayleigh quotient vᵀHv
            rayleigh = torch.dot(vector, self.get_hessian_vector_product(batch, vector))

            if torch.isnan(rayleigh):
                print(f"\nWarning: NaN eigenvalue detected for eigenvector {index+1}")
                rayleigh = torch.tensor(0.0)

            found_values.append(rayleigh.item())
            found_vectors.append(vector)

        return found_values

    def analyze_landscape(
        self,
        batch: Tuple[Tensor, Tensor, Tensor],
        epoch: Optional[int] = None
    ) -> Dict[str, Union[float, List[float], Dict[str, float]]]:
        """Compute the landscape metrics for one batch and optionally persist them.

        Metrics are computed in this order:
        1. Valley asymmetry (``compute_valley_asymmetry``)
        2. Top eigenvalues (``estimate_top_k_eigenvalues`` with k=3)
        3. Multiscale sharpness (``compute_multiscale_sharpness``)
        4. Alpha sharpness (``compute_alpha_sharpness``)

        Args:
            batch: Tuple of (numbers, operator, targets) tensors
            epoch: Optional epoch number used in the output file name

        Returns:
            Dict containing computed landscape metrics:
                - 'valley_asymmetry': float
                - 'top_eigenvalues': List[float]
                - 'multiscale_sharpness': Dict[str, float]
                - 'alpha_sharpness': float

        Note:
            Metrics are written to
            ``<save_dir>/landscape_metrics_epoch_<epoch>.json`` only when both
            ``epoch`` and ``self.save_dir`` are set; the directory is created
            if it does not exist. Without a ``save_dir`` nothing is written.
        """
        # Dict literal evaluates top to bottom, preserving the metric order above
        metrics = {
            'valley_asymmetry': self.compute_valley_asymmetry(batch),
            'top_eigenvalues': self.estimate_top_k_eigenvalues(batch, k=3),
            'multiscale_sharpness': self.compute_multiscale_sharpness(batch),
            'alpha_sharpness': self.compute_alpha_sharpness(batch)
        }

        # save_dir is None when the analyzer was built without an output directory
        if epoch is not None and self.save_dir:
            target_dir = Path(self.save_dir)
            target_dir.mkdir(parents=True, exist_ok=True)
            save_path = target_dir / f"landscape_metrics_epoch_{epoch}.json"
            with open(save_path, 'w') as f:
                json.dump(metrics, f)

        return metrics

    def compute_loss(
        self,
        batch: Tuple[Tensor, Tensor, Tensor]
    ) -> Tensor:
        """Run the model on a batch and return its loss.

        Args:
            batch: Tuple of (numbers, operator, targets)

        Returns:
            Tensor: First element of the pair returned by
            ``self.loss_computer.compute_loss``; the second element is discarded
        """
        inputs, op_tokens, expected = batch
        outputs = self.model(inputs, op_tokens)
        batch_loss, _details = self.loss_computer.compute_loss(
            outputs, expected, inputs, op_tokens
        )
        return batch_loss
    def evaluate_loss_surface(
            self,
            data_loader: DataLoader,
            dir1: Dict[str, Tensor],
            dir2: Dict[str, Tensor],
            num_batches: int = 5
        ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Evaluate the loss on a 2D grid in parameter space along two directions.

        Computes the loss surface by perturbing model parameters along two directions:
        θ(α,β) = θ₀ + α·d₁ + β·d₂

        where:
        - θ₀ is the original parameter vector
        - d₁, d₂ are the perturbation directions
        - α, β are scaling factors taken from
          ``linspace(-alpha_range, alpha_range, grid_size)``

        Up to ``num_batches`` batches are drawn from ``data_loader`` ONCE and
        reused at every grid point, so all points are scored on identical data
        (also when the loader shuffles). The reported value is the mean loss
        over the batches actually drawn.

        Args:
            data_loader: Iterable providing batches accepted by ``compute_loss``
            dir1: First perturbation direction as {param_name: direction_tensor}
            dir2: Second perturbation direction as {param_name: direction_tensor}
            num_batches: Maximum number of batches to average over. Defaults to 5

        Returns:
            Tuple containing:
            - alphas: 1D array of α values
            - betas: 1D array of β values
            - loss_surface: 2D array with loss_surface[i, j] = loss at (alphas[i], betas[j])

        Raises:
            ValueError: If ``data_loader`` yields no batches

        Note:
            Model parameters are modified temporarily and are restored even if
            an evaluation raises.
        """
        original_params = {
            name: param.detach().clone()
            for name, param in self.model.named_parameters()
        }

        alphas = np.linspace(-self.alpha_range, self.alpha_range, self.grid_size)
        betas = np.linspace(-self.alpha_range, self.alpha_range, self.grid_size)
        loss_surface = np.zeros((self.grid_size, self.grid_size))

        # Fixed evaluation set shared by every grid point
        batches = list(itertools.islice(data_loader, num_batches))
        if not batches:
            raise ValueError("data_loader yielded no batches to evaluate the loss surface on")

        try:
            with torch.no_grad():
                for i, alpha in enumerate(alphas):
                    for j, beta in enumerate(betas):
                        # Apply perturbation: θ = θ₀ + α·d₁ + β·d₂
                        for name, param in self.model.named_parameters():
                            param.data.copy_(
                                original_params[name]
                                + float(alpha) * dir1[name]
                                + float(beta) * dir2[name]
                            )

                        total_loss = sum(self.compute_loss(batch).item() for batch in batches)
                        # Divide by the number of batches really used, not a fixed 5
                        loss_surface[i, j] = total_loss / len(batches)

                    if i % 5 == 0:
                        print(f"Completed {i + 1}/{self.grid_size} rows")
        finally:
            # Always leave the model exactly as we found it
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.data.copy_(original_params[name])

        return alphas, betas, loss_surface


    def evaluate_landscape_with_abstention(
        self,
        data_loader: DataLoader,
        dir1: Dict[str, Tensor],
        dir2: Dict[str, Tensor],
        num_batches: int = 5
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Evaluate loss, abstention rate and valid accuracy on a 2D parameter grid.

        At every grid point θ(α,β) = θ₀ + α·d₁ + β·d₂ the method scores the
        same cached batches (up to ``num_batches`` drawn once from
        ``data_loader``) and records:

        1. Loss: mean of ``compute_loss`` over the batches actually drawn
        2. Abstention rate: abstained predictions / all predictions, where a
           prediction abstains when it lies within 0.2 of -1.0
           (``torch.isclose`` with rtol=0.1, atol=0.1)
        3. Valid accuracy: among samples that are valid (per
           ``is_invalid_computation``) and not abstained, the fraction within
           1% of the target (plus 1e-8); 0.0 when there are no such samples

        Args:
            data_loader: Iterable providing (numbers, operator, targets) batches
            dir1: First perturbation direction as {param_name: direction_tensor}
            dir2: Second perturbation direction as {param_name: direction_tensor}
            num_batches: Maximum number of batches to evaluate. Defaults to 5

        Returns:
            Tuple containing:
            - alphas: 1D array of α values
            - betas: 1D array of β values
            - loss_surface: 2D array of loss values
            - abstention_rates: 2D array of abstention rates
            - valid_accuracy: 2D array of accuracy on valid, attempted inputs
            All 2D arrays are indexed [alpha_index, beta_index].

        Raises:
            ValueError: If ``data_loader`` yields no batches

        Note:
            Model parameters are modified temporarily and are restored even if
            an evaluation raises.
        """
        original_params = {
            name: param.detach().clone()
            for name, param in self.model.named_parameters()
        }

        alphas = np.linspace(-self.alpha_range, self.alpha_range, self.grid_size)
        betas = np.linspace(-self.alpha_range, self.alpha_range, self.grid_size)

        loss_surface = np.zeros((self.grid_size, self.grid_size))
        abstention_rates = np.zeros((self.grid_size, self.grid_size))
        valid_accuracy = np.zeros((self.grid_size, self.grid_size))

        # Fixed evaluation set shared by every grid point
        batches = list(itertools.islice(data_loader, num_batches))
        if not batches:
            raise ValueError("data_loader yielded no batches to evaluate the landscape on")

        try:
            with torch.no_grad():
                for i, alpha in enumerate(alphas):
                    for j, beta in enumerate(betas):
                        for name, param in self.model.named_parameters():
                            param.data.copy_(
                                original_params[name]
                                + float(alpha) * dir1[name]
                                + float(beta) * dir2[name]
                            )

                        total_loss = 0.0
                        total_samples = 0      # every prediction
                        total_abstentions = 0  # predictions close to -1.0
                        total_attempted = 0    # valid input and not abstained
                        total_correct = 0      # attempted and within tolerance

                        for batch in batches:
                            numbers, operator, targets = batch
                            predictions = self.model(numbers, operator)
                            total_loss += self.compute_loss(batch).item()

                            # Flatten to one value per sample; unlike squeeze() this
                            # also works for batch size 1.
                            # NOTE(review): assumes one scalar prediction per sample - confirm
                            predictions = predictions.reshape(-1)
                            flat_targets = targets.reshape(-1)

                            abstained = torch.isclose(
                                predictions,
                                torch.tensor(-1.0, device=predictions.device,
                                             dtype=predictions.dtype),
                                rtol=0.1, atol=0.1
                            )

                            invalid_mask = is_invalid_computation(numbers, operator)
                            attempted = (~invalid_mask) & (~abstained)

                            if attempted.any():
                                pred_vals = predictions[attempted]
                                tgt_vals = flat_targets[attempted]
                                diff = torch.abs(pred_vals - tgt_vals)
                                threshold = torch.abs(tgt_vals) * 0.01 + 1e-8
                                total_correct += int((diff < threshold).sum().item())

                            total_attempted += int(attempted.sum().item())
                            total_abstentions += int(abstained.sum().item())
                            total_samples += abstained.numel()

                        loss_surface[i, j] = total_loss / len(batches)
                        if total_samples:
                            abstention_rates[i, j] = total_abstentions / total_samples
                        # Accuracy is relative to attempted valid samples, not all samples
                        if total_attempted:
                            valid_accuracy[i, j] = total_correct / total_attempted
        finally:
            # Always leave the model exactly as we found it
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.data.copy_(original_params[name])

        return alphas, betas, loss_surface, abstention_rates, valid_accuracy

    def visualize_landscape(
        self,
        data_loader: DataLoader,
        epoch: Optional[int] = None,
        model_name: Optional[str] = None,
        rand_dir: bool = True
        ) -> None:
        """Visualize multiple aspects of the loss landscape in 3D.

        Creates a three-panel figure showing, over the same (α, β) grid:
        1. Local Lipschitz constants
        2. Distance to decision boundary
        3. Loss surface (``log1p`` of the loss)

        Every surface is Gaussian-smoothed (sigma=1.0) before plotting.

        Args:
            data_loader: DataLoader providing batches of training data
            epoch: Optional epoch number used in the output file name
            model_name: Optional model name used in the output file name
            rand_dir: Use ``get_random_direction`` twice (True) or
                     ``get_principal_directions`` on the first batch (False)

        Note:
            The figure is written to
            ``<save_dir>/landscape_epoch_<epoch>_<model_name>.png`` only when
            ``save_dir``, ``epoch`` and ``model_name`` are all set. The figure
            is always closed afterwards and is never shown.
        """
        # Get perturbation directions
        if rand_dir:
            dir1 = self.get_random_direction()
            dir2 = self.get_random_direction()
        else:
            dir1, dir2 = self.get_principal_directions(next(iter(data_loader)))

        # Coordinate grid. The surfaces are stored as surface[i, j] for
        # (alphas[i], betas[j]); indexing='ij' makes the grids follow the same
        # layout. The default 'xy' indexing would swap the α and β axes.
        # NOTE(review): evaluate_distance_to_boundary is assumed to use the same
        # [alpha_index, beta_index] layout as the other two surfaces - confirm
        alphas = np.linspace(-self.alpha_range, self.alpha_range, self.grid_size)
        betas = np.linspace(-self.alpha_range, self.alpha_range, self.grid_size)
        alpha_grid, beta_grid = np.meshgrid(alphas, betas, indexing='ij')

        # Compute surfaces
        print("Computing Lipschitz surface...")
        _, _, llc_surface = self.evaluate_lipschitz_surface(data_loader, dir1, dir2)

        print("Computing Distance to Decision Boundary surface...")
        _, _, ddb_surface = self.evaluate_distance_to_boundary(data_loader, dir1, dir2)

        print("Computing loss surface...")
        _, _, loss_surface = self.evaluate_loss_surface(data_loader, dir1, dir2)

        fig = plt.figure(figsize=(20, 7))
        gs = plt.GridSpec(1, 3, width_ratios=[1, 1, 1], wspace=0.2)

        # (surface, colormap, title, colorbar label) for each panel, left to right
        panels = [
            (llc_surface, 'coolwarm',
             "Local Lipschitz Constants\n(Smoothed)", "LLC Magnitude"),
            (ddb_surface, 'plasma',
             "Distance to Decision Boundary\n(Smoothed)", "Distance"),
            (np.log1p(loss_surface), 'magma',
             "Loss Landscape\n(Log Scale, Smoothed)", "Log Loss"),
        ]

        try:
            for slot, (surface, cmap, title, cbar_label) in enumerate(panels):
                ax = fig.add_subplot(gs[slot], projection='3d')
                smoothed = gaussian_filter(surface, sigma=1.0)
                surf = ax.plot_surface(alpha_grid, beta_grid, smoothed,
                                       cmap=cmap, antialiased=True)
                ax.grid(True, linestyle='--', alpha=0.3)
                ax.set_title(title)
                fig.colorbar(surf, ax=ax, pad=0.12, label=cbar_label)

                # Consistent styling for all panels
                ax.set_xlabel('α')
                ax.set_ylabel('β')
                ax.view_init(elev=20, azim=45)
                for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
                    axis.set_major_locator(plt.MaxNLocator(6))
                    # _axinfo is private matplotlib API; used to fade the 3D grid lines
                    axis._axinfo["grid"]['color'] = (1, 1, 1, 0.2)

            # Save visualization if requested
            if self.save_dir and epoch is not None and model_name:
                save_path = os.path.join(self.save_dir,
                                         f"landscape_epoch_{epoch}_{model_name}.png")
                fig.savefig(save_path, bbox_inches='tight', dpi=300)
                print(f"Saved landscape visualization to {save_path}")
        finally:
            # Close this figure explicitly so repeated calls do not leak figures
            plt.close(fig)
    def get_principal_directions(
        self,
        batch: Tuple[Tensor, ...],
        k: int = 2
        ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
        """Find Hessian eigen-directions with power iteration and deflation.

        For each of the k rounds, a fixed-seed unit vector is refined with ten
        steps of

            v ← Hv / ||Hv||

        where H·v comes from ``self.get_hessian_vector_product`` and the
        components along previously found vectors are removed before
        normalising. If the deflated H·v is not larger than 1e-10 in norm, the
        vector is replaced by a fresh random unit vector (drawn from the global
        RNG) and iteration continues.

        Args:
            batch: Tuple of input tensors for computing Hessian
            k: Number of directions to compute (default: 2)

        Returns:
            Tuple of two dictionaries mapping parameter names to direction
            tensors shaped like the parameters: the first and second directions
            found. Only these two are returned regardless of k.

        Raises:
            KeyError: If k is smaller than 2 (the second direction is missing)
        """
        device = next(self.model.parameters()).device
        found_vectors = []  # flat vectors, kept for deflation
        directions = {}

        for index in range(k):
            # Same fixed-seed start vector every round; global RNG is untouched
            with torch.random.fork_rng():
                torch.manual_seed(seed)
                dim = sum(p.numel() for p in self.model.parameters())
                vector = torch.randn(dim).to(device)
                vector = vector / torch.norm(vector)

            for _ in range(10):
                candidate = self.get_hessian_vector_product(batch, vector)

                for earlier in found_vectors:
                    candidate = candidate - torch.dot(candidate, earlier) * earlier

                magnitude = torch.norm(candidate)
                if magnitude > 1e-10:
                    vector = candidate / magnitude
                else:
                    # Degenerate step: restart from a random unit vector
                    vector = torch.randn_like(vector)
                    vector = vector / torch.norm(vector)

            found_vectors.append(vector)

            # Slice the flat vector back into per-parameter tensors
            per_param = {}
            cursor = 0
            for name, param in self.model.named_parameters():
                stop = cursor + param.numel()
                per_param[name] = vector[cursor:stop].reshape(param.shape)
                cursor = stop
            directions[f'pc{index+1}'] = per_param

        return directions['pc1'], directions['pc2']
    def get_random_direction(self) -> Dict[str, Tensor]:
        """Generate a reproducible random direction in parameter space.

        Each parameter tensor gets a Gaussian random tensor of the same shape,
        normalised to unit L2 norm on its own (so the norm of the whole
        direction is the square root of the number of parameter tensors).

        Successive calls on the same analyzer return DIFFERENT directions: the
        n-th call (counting from 0) seeds the generator with ``seed + n``.
        Reusing the bare ``seed`` on every call would return the same direction
        each time. The sequence is reproducible, and sampling happens inside
        ``torch.random.fork_rng()`` so the global RNG state is left untouched.

        Returns:
            Dict mapping parameter names to normalized random direction tensors
        """
        # Per-instance draw counter; created lazily so __init__ needs no change
        draw_index = getattr(self, "_direction_draws", 0)
        self._direction_draws = draw_index + 1

        direction = {}
        with torch.random.fork_rng():
            torch.manual_seed(seed + draw_index)
            for name, param in self.model.named_parameters():
                sample = torch.randn_like(param)
                direction[name] = sample / torch.norm(sample)
        return direction
    def compute_local_lipschitz(
        self,
        batch: Tuple[Tensor, ...],
        point: Tuple[float, float],
        radius: float = 0.1,
        num_samples: int = 100
        ) -> float:
        """Estimate how fast the model output changes under small parameter moves.

        Starting from the model's current parameters θ, the method samples
        ``num_samples`` perturbations θ + s·d, where d comes from
        ``self.get_random_direction()`` and s is uniform in [0, radius), and
        returns the largest observed quotient

            ||f(θ + s·d) − f(θ)|| / ||s·d||

        with f the model output on ``batch[:2]``. This is a sampled lower
        bound on the local Lipschitz constant with respect to the parameters:
        every pair compared includes θ itself.

        Args:
            batch: Input data batch; only its first two entries are passed to the model
            point: Unused; the neighbourhood is centred on the model's current parameters
            radius: Upper bound on the perturbation scale s
            num_samples: Number of perturbations to sample

        Returns:
            float: Largest quotient observed (0.0 if no usable sample was drawn)

        Note:
            Model parameters are modified temporarily and are restored even if
            a forward pass raises. Outputs are computed under ``torch.no_grad()``.
        """
        original_params = {
            name: param.detach().clone()
            for name, param in self.model.named_parameters()
        }

        max_lipschitz = 0.0
        try:
            with torch.no_grad():
                # Reference output at the unperturbed parameters
                f_0 = self.model(*batch[:2])

                for _ in range(num_samples):
                    direction = self.get_random_direction()
                    scale = torch.rand(1).item() * radius

                    # ||s·d|| over all parameter tensors
                    direction_norm = np.sqrt(
                        sum(torch.norm(d).item() ** 2 for d in direction.values())
                    )
                    param_diff = float(scale * direction_norm)
                    if param_diff == 0.0:
                        # Zero-length step: the quotient is undefined, skip it
                        continue

                    for name, param in self.model.named_parameters():
                        param.data.copy_(original_params[name] + scale * direction[name])

                    f_x = self.model(*batch[:2])
                    lipschitz = torch.norm(f_x - f_0).item() / param_diff
                    max_lipschitz = max(max_lipschitz, lipschitz)
        finally:
            # Always leave the model exactly as we found it
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.data.copy_(original_params[name])

        return max_lipschitz

    def evaluate_lipschitz_surface(
        self,
        data_loader: DataLoader,
        dir1: Dict[str, Tensor],
        dir2: Dict[str, Tensor]
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Evaluate local Lipschitz constants across a 2D slice of parameter space.

        For every point of a ``grid_size`` x ``grid_size`` grid the model
        parameters are set to

            θ(α,β) = θ₀ + α·d₁ + β·d₂

        and ``self.compute_local_lipschitz`` is called at that point. The same
        data batch (the first one yielded by ``data_loader``) is used for every
        grid point so that differences on the surface come from the parameter
        change and not from a different batch.

        Args:
            data_loader: DataLoader providing the evaluation batch
            dir1: First perturbation direction, keyed by parameter name
            dir2: Second perturbation direction, keyed by parameter name

        Returns:
            Tuple containing:
            - alphas: 1D array of α values
            - betas: 1D array of β values
            - llc_surface: 2D array of local Lipschitz constants, indexed [α, β]
        """
        original_params = {name: param.data.clone() for name, param in self.model.named_parameters()}
        alphas = np.linspace(-self.alpha_range, self.alpha_range, self.grid_size)
        betas = np.linspace(-self.alpha_range, self.alpha_range, self.grid_size)
        llc_surface = np.zeros((self.grid_size, self.grid_size))
        # Fetch the batch once: building a new loader iterator per grid cell is
        # expensive and, with a shuffling loader, would change the data per cell.
        batch = next(iter(data_loader))
        try:
            for i, alpha in enumerate(alphas):
                for j, beta in enumerate(betas):
                    # Move to grid point
                    with torch.no_grad():
                        for name, param in self.model.named_parameters():
                            param.data.copy_(original_params[name] +
                                             alpha * dir1[name] +
                                             beta * dir2[name])
                    # Compute LLC at current point
                    llc_surface[i, j] = self.compute_local_lipschitz(batch, (alpha, beta))
                if i % 5 == 0:
                    print(f"Completed {i + 1}/{self.grid_size} LLC rows")
        finally:
            # Restore original parameters even if the sweep is interrupted
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.data.copy_(original_params[name])
        return alphas, betas, llc_surface

    def evaluate_distance_to_boundary(
        self,
        data_loader: DataLoader,
        dir1: Dict[str, Tensor],
        dir2: Dict[str, Tensor]
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Evaluate distance to decision boundary across a 2D parameter slice.

        At each grid point θ(α,β) = θ₀ + α·d₁ + β·d₂ the abstention pattern of
        the model on one test batch is recorded (a prediction counts as an
        abstention when it is close to -1.0). A bisection over the perturbation
        size in [0, 0.5] then looks for the smallest step along directions from
        ``self.get_random_direction()`` at which that pattern changes for any
        sample.

        Args:
            data_loader: DataLoader passed to ``create_boundary_test_loader``
            dir1: First perturbation direction, keyed by parameter name
            dir2: Second perturbation direction, keyed by parameter name

        Returns:
            Tuple containing:
            - alphas: 1D array of α values
            - betas: 1D array of β values
            - distance_surface: 2D array, indexed [α, β], holding the lower end
              of the final bisection interval at each grid point
        """
        device = next(self.model.parameters()).device
        original_params = {name: param.data.clone() for name, param in self.model.named_parameters()}
        # Setup evaluation grid
        alphas = np.linspace(-self.alpha_range, self.alpha_range, self.grid_size)
        betas = np.linspace(-self.alpha_range, self.alpha_range, self.grid_size)
        distance_surface = np.zeros((self.grid_size, self.grid_size))
        # Get test batch near decision boundaries
        boundary_loader = create_boundary_test_loader(data_loader, num_samples=100)
        numbers, operator, _ = next(iter(boundary_loader))
        numbers = numbers.to(device)
        operator = operator.to(device)
        abstention_value = torch.tensor(-1.0, device=device)

        def abstained(preds: Tensor) -> Tensor:
            # Per-sample flag: prediction is within tolerance of the abstention value.
            return torch.isclose(preds.squeeze(), abstention_value, rtol=0.1, atol=0.1)

        max_dist = 0.5   # largest perturbation size considered
        min_dist = 0.001  # bisection stops once the interval is this narrow
        try:
            with torch.no_grad():
                for i, alpha in enumerate(alphas):
                    for j, beta in enumerate(betas):
                        # Parameters of the current grid point, kept per name so
                        # each tensor is later perturbed from its own base value.
                        grid_params = {
                            name: (original_params[name] +
                                   alpha * dir1[name] +
                                   beta * dir2[name])
                            for name in original_params
                        }
                        for name, param in self.model.named_parameters():
                            param.data.copy_(grid_params[name])

                        # Get base behavior
                        base_abstained = abstained(self.model(numbers, operator))

                        # Binary search for nearest boundary
                        # NOTE(review): a fresh random direction is drawn at every
                        # bisection step, so the search is stochastic — confirm
                        # this is intended rather than one direction per grid point.
                        left = 0.0
                        right = max_dist
                        while right - left > min_dist:
                            mid = (left + right) / 2
                            search_dir = self.get_random_direction()
                            # Test perturbed behavior
                            for name, param in self.model.named_parameters():
                                param.data.copy_(grid_params[name] + mid * search_dir[name])
                            perturbed_abstained = abstained(self.model(numbers, operator))
                            # Update search interval
                            if (perturbed_abstained != base_abstained).any():
                                right = mid
                            else:
                                left = mid
                        distance_surface[i, j] = left
                    if i % 10 == 0:
                        print(f"Computing DDB row {i}/{self.grid_size}")
        finally:
            # Restore original parameters even if the sweep is interrupted
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.data.copy_(original_params[name])
        return alphas, betas, distance_surface
def get_rand_dirs(model: torch.nn.Module) -> Dict[str, Tensor]:
    """Generate a random direction with filter-wise normalization.

    One Gaussian tensor is drawn per model parameter. For parameters with more
    than one dimension, every slice along dimension 0 is rescaled so that its
    norm matches the norm of the corresponding slice of the parameter, which
    preserves the per-filter scale of the weights. One-dimensional parameters
    are left as raw Gaussian noise.

    Sampling runs inside a forked RNG seeded with the module-level ``seed``, so
    repeated calls with the same ``seed`` return the same direction and the
    surrounding global torch RNG state is left untouched.

    Args:
        model: Neural network model whose parameters define the direction shapes

    Returns:
        Dict mapping parameter names to normalized random direction tensors
    """
    # Name-ordered traversal keeps the order of RNG draws stable between calls
    ordered_params = sorted(model.named_parameters(), key=lambda item: item[0])
    result = {}
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        for pname, weight in ordered_params:
            noise = torch.randn_like(weight)
            if weight.dim() > 1:
                # Rescale each slice along dim 0 to the norm of the matching weight slice
                for row in range(noise.size(0)):
                    target_norm = torch.norm(weight[row].data)
                    current_norm = torch.norm(noise[row].data) + 1e-8
                    noise[row].data.mul_(target_norm / current_norm)
            result[pname] = noise
    return result
def _compute_landscape_loss_grid(model, data_loader, loss_computer, orig_params,
                                 dir1, dir2, alphas, betas, max_batches=7):
    """Evaluate the mean loss at every (alpha, beta) grid point.

    For each grid point the model parameters are overwritten with
    ``orig_params + alpha * dir1 + beta * dir2`` and the loss is averaged over
    at most ``max_batches`` batches of ``data_loader``. The caller is
    responsible for restoring the parameters afterwards.

    Returns:
        2D array indexed ``[alpha_index, beta_index]``.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    device = next(model.parameters()).device
    loss_surface = np.zeros((len(alphas), len(betas)))

    with torch.no_grad():
        for i in tqdm(range(len(alphas)), desc="Computing landscape"):
            for j in range(len(betas)):
                alpha, beta = alphas[i], betas[j]

                # Update model parameters
                for name, param in model.named_parameters():
                    param.data.copy_(orig_params[name] + alpha * dir1[name] + beta * dir2[name])

                total_loss = 0.0
                batch_count = 0
                for batch in data_loader:
                    numbers, operator, targets = [t.to(device) for t in batch]
                    predictions = model(numbers, operator)
                    loss, _ = loss_computer.compute_loss(predictions, targets, numbers, operator)
                    total_loss += loss.item()
                    batch_count += 1
                    if batch_count >= max_batches:
                        break

                if batch_count == 0:
                    raise ValueError("data_loader yielded no batches; cannot compute loss landscape")
                loss_surface[i, j] = total_loss / batch_count

    return loss_surface


def plot_high_quality_loss_landscape(model, data_loader, save_path,
                             grid_resolution=200, alpha_range=1.0,
                             azimuth=140, elevation=20,
                             show_axes=False,  # Default False
                             transparent_background=True,
                             loss_computer=None,
                             dark_mode=False):  # Default light theme
    """
    Create loss landscape visualization.

    The loss is evaluated on a ``grid_resolution`` x ``grid_resolution`` grid of
    parameter perturbations θ₀ + α·d₁ + β·d₂, where d₁ and d₂ come from
    ``get_rand_dirs`` (d₂ is drawn with the global ``seed`` temporarily
    incremented by one). The surface is smoothed, log-scaled, clipped to its
    1st-99th percentile range, normalised to [0, 1], rendered as a shaded 3D
    surface and written to ``save_path``. The model parameters are restored
    afterwards; the model is left in eval mode.

    Args:
        model: Model whose loss landscape is plotted
        data_loader: Iterable of (numbers, operator, targets) batches; at most
            7 batches are used per grid point
        save_path: Destination of the rendered figure
        grid_resolution (int): Number of grid points per axis
        alpha_range (float): Grid spans [-alpha_range, alpha_range] on both axes
        azimuth (float): Horizontal rotation (0-360 degrees)
        elevation (float): Vertical rotation (-90 to 90 degrees)
        show_axes (bool): Whether to show axes and grid
        transparent_background (bool): Use transparent background for paper
        loss_computer: Object providing ``compute_loss``; a default
            ``UnifiedLossComputer`` is created when omitted
        dark_mode (bool): Use dark theme (default False for paper)

    Returns:
        The raw (unsmoothed, unnormalised) loss surface, indexed [α, β].

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    global seed

    if loss_computer is None:
        loss_computer = UnifiedLossComputer()

    # get_rand_dirs reads the global seed; bump it for the second direction and
    # make sure it is reset even if direction generation fails.
    base_seed = seed
    try:
        dir1 = get_rand_dirs(model)  # Uses original seed
        seed = base_seed + 1
        dir2 = get_rand_dirs(model)  # Uses incremented seed
    finally:
        seed = base_seed

    # Single-pass with moderate resolution
    alphas = np.linspace(-alpha_range, alpha_range, grid_resolution)
    betas = np.linspace(-alpha_range, alpha_range, grid_resolution)
    # 'ij' indexing so that grid[i, j] corresponds to (alphas[i], betas[j]),
    # matching how the loss surface is filled.
    alpha_grid, beta_grid = np.meshgrid(alphas, betas, indexing='ij')

    orig_params = {name: param.data.clone() for name, param in model.named_parameters()}
    model.eval()

    print(f"\nComputing loss surface ({grid_resolution} x {grid_resolution})...")

    try:
        loss_surface = _compute_landscape_loss_grid(
            model, data_loader, loss_computer, orig_params,
            dir1, dir2, alphas, betas, max_batches=7
        )
    finally:
        # Restore original parameters, also when the sweep fails or is interrupted
        with torch.no_grad():
            for name, param in model.named_parameters():
                param.data.copy_(orig_params[name])

    plt.style.use('default' if not dark_mode else 'dark_background')
    fig = plt.figure(figsize=(16, 16), dpi=300)  # Restored original high quality size
    try:
        ax = fig.add_subplot(111, projection='3d')

        if dark_mode:
            colors_map = [
                (0.1, 0.0, 0.0),      # Dark base
                (0.3, 0.0, 0.0),      # Dark red
                (0.5, 0.1, 0.1),      # Medium red
                (0.7, 0.2, 0.2),      # Light red
                (1.0, 0.4, 0.4)       # Highlight
            ]
        else:
            colors_map = [
                (1.0, 0.8, 0.8),      # Light pink
                (0.9, 0.6, 0.6),      # Salmon
                (0.8, 0.4, 0.4),      # Medium red
                (0.7, 0.2, 0.2),      # Deep red
                (0.6, 0.0, 0.0)       # Dark red
            ]

        color_map = colors.LinearSegmentedColormap.from_list('paper_map', colors_map, N=256)

        # Surface processing with slightly increased smoothing
        sigma = grid_resolution/35  # Slightly increased smoothing
        loss_surface_processed = gaussian_filter(loss_surface.copy(), sigma=sigma)

        loss_surface_processed = np.log1p(loss_surface_processed)
        p_min, p_max = np.percentile(loss_surface_processed, [1, 99])
        loss_surface_processed = np.clip(loss_surface_processed, p_min, p_max)
        value_span = p_max - p_min
        if value_span > 0:
            loss_surface_processed = (loss_surface_processed - p_min) / value_span
        else:
            # Flat surface: avoid 0/0 and plot a constant plane instead
            loss_surface_processed = np.zeros_like(loss_surface_processed)

        # Enhanced lighting for paper clarity
        ls = colors.LightSource(azdeg=315, altdeg=45)
        illuminated_surface = ls.shade(loss_surface_processed,
                                     cmap=color_map,
                                     vert_exag=2.0,
                                     blend_mode='soft')

        # Main surface plot
        ax.plot_surface(alpha_grid, beta_grid, loss_surface_processed,
                        facecolors=illuminated_surface,
                        linewidth=0.1,
                        antialiased=True,
                        shade=True)

        if show_axes:
            # Contours only if axes are shown
            levels = np.linspace(loss_surface_processed.min(), loss_surface_processed.max(), 20)
            ax.contour(alpha_grid, beta_grid, loss_surface_processed,
                      zdir='z',
                      offset=loss_surface_processed.min(),
                      levels=levels,
                      cmap=color_map,
                      alpha=0.3,
                      linewidths=0.5)

            # Configure axes
            ax.grid(True, alpha=0.2, linestyle='-')
            ax.set_xlabel('α', labelpad=10)
            ax.set_ylabel('β', labelpad=10)
            ax.set_zlabel('Loss', labelpad=10)
        else:
            # Hide all axes, ticks, and grid
            ax.set_axis_off()

        # Set view angle - rotated to face camera
        ax.view_init(elev=elevation, azim=azimuth)
        # NOTE(review): `Axes3D.dist` is deprecated in recent Matplotlib
        # releases — confirm it still has an effect in the target environment.
        ax.dist = 8

        # Configure background
        if transparent_background:
            ax.set_facecolor('none')
            fig.patch.set_alpha(0.0)
        else:
            ax.set_facecolor('white' if not dark_mode else 'black')
            fig.patch.set_facecolor('white' if not dark_mode else 'black')

        # Hide panes when axes are off
        if not show_axes:
            ax.xaxis.pane.fill = False
            ax.yaxis.pane.fill = False
            ax.zaxis.pane.fill = False
            ax.xaxis.pane.set_edgecolor('none')
            ax.yaxis.pane.set_edgecolor('none')
            ax.zaxis.pane.set_edgecolor('none')

        # Set axis limits with minimal margin
        margin = 0.05
        ax.set_xlim(-alpha_range * (1 + margin), alpha_range * (1 + margin))
        ax.set_ylim(-alpha_range * (1 + margin), alpha_range * (1 + margin))

        fig.savefig(save_path,
                    dpi=300,
                    bbox_inches='tight',
                    pad_inches=0.1,  # Reduced padding for paper
                    facecolor='none' if transparent_background else ('white' if not dark_mode else 'black'),
                    edgecolor='none',
                    transparent=transparent_background)
    finally:
        # Close this specific figure so repeated calls do not accumulate memory
        plt.close(fig)

    return loss_surface

In [None]:
class UnifiedTracker:
    """
    Single bookkeeping object for compute cost and training statistics.

    Two kinds of information are accumulated side by side: resource counters
    (forward/backward passes, FLOPs, wall-clock time since construction) and
    per-phase, per-epoch training metrics (loss, accuracy, abstentions).

    Attributes:
        metrics (DefaultDict): ``"<phase>_<metric>"`` -> epoch -> list of batch values
        current_epoch (int): Epoch that new batch metrics are filed under
        batch_abstention_counts (DefaultDict): phase -> epoch -> correct/incorrect counts
        batch_totals (DefaultDict): phase -> epoch -> total number of abstentions
        forward_passes (int): Number of forward passes
        backward_passes (int): Number of backward passes
        flops (int): Total floating point operations
        start_time (float): Timestamp taken when the tracker was created
    """

    def __init__(self) -> None:
        # --- training metrics ---
        self.current_epoch: int = 0
        self.metrics: DefaultDict[str, DefaultDict[int, List[float]]] = defaultdict(
            lambda: defaultdict(list)
        )

        # --- abstention bookkeeping ---
        # 'correct'   = abstained on invalid samples
        # 'incorrect' = abstained on valid samples
        self.batch_abstention_counts: DefaultDict[str, DefaultDict[int, Dict[str, int]]] = defaultdict(
            lambda: defaultdict(lambda: dict(correct=0, incorrect=0))
        )
        self.batch_totals: DefaultDict[str, DefaultDict[int, int]] = defaultdict(
            lambda: defaultdict(int)
        )

        # --- compute counters ---
        self.forward_passes: int = 0
        self.backward_passes: int = 0
        self.flops: int = 0
        self.start_time: float = time.time()

    def update_computational_metrics(
        self,
        forward_passes: int = 0,
        backward_passes: int = 0,
        flops: int = 0
    ) -> None:
        """
        Add the given amounts to the compute counters.

        Args:
            forward_passes: Number of forward passes to add
            backward_passes: Number of backward passes to add
            flops: Number of floating point operations to add
        """
        self.forward_passes = self.forward_passes + forward_passes
        self.backward_passes = self.backward_passes + backward_passes
        self.flops = self.flops + flops

    def update_training_metrics(
        self,
        batch_metrics: Dict[str, Union[float, int]],
        phase: str = 'train'
    ) -> None:
        """
        Record the metrics of one batch under the current epoch.

        Args:
            batch_metrics: Dictionary containing:
                - 'loss': batch loss value
                - 'total_accuracy': batch accuracy
                - 'total_abstentions': total abstentions
                - 'abstained_on_invalid': correct abstentions
                - 'abstained_on_valid': incorrect abstentions
            phase: Training phase ('train' or 'val')
        """
        epoch = self.current_epoch

        # Loss and accuracy are always appended
        self.metrics[f"{phase}_loss"][epoch].append(batch_metrics['loss'])
        self.metrics[f"{phase}_accuracy"][epoch].append(batch_metrics['total_accuracy'])

        # Abstention counters are only touched when the batch had abstentions
        if not batch_metrics['total_abstentions'] > 0:
            return
        counts = self.batch_abstention_counts[phase][epoch]
        counts['correct'] += batch_metrics['abstained_on_invalid']
        counts['incorrect'] += batch_metrics['abstained_on_valid']
        self.batch_totals[phase][epoch] += batch_metrics['total_abstentions']

    def get_computational_metrics(self) -> Dict[str, float]:
        """
        Snapshot of the compute counters.

        Returns:
            Dictionary containing:
                - forward_passes: Total forward passes
                - backward_passes: Total backward passes
                - total_flops: Cumulative FLOPs
                - wall_time: Seconds elapsed since the tracker was created
        """
        snapshot = dict(
            forward_passes=self.forward_passes,
            backward_passes=self.backward_passes,
            total_flops=self.flops,
        )
        snapshot['wall_time'] = time.time() - self.start_time
        return snapshot

    def get_training_metrics(self, phase: str) -> Dict[str, float]:
        """
        Aggregate the current epoch's metrics for one phase.

        Args:
            phase: Training phase to get metrics for

        Returns:
            Dictionary containing:
                - one entry per recorded metric whose key starts with ``phase``
                  (e.g. loss, accuracy), holding the epoch mean or 0.0 if empty
                - correct_abstentions_percent: Correct abstention percentage
                - incorrect_abstentions_percent: Incorrect abstention percentage
        """
        epoch = self.current_epoch
        summary: Dict[str, float] = {}

        # Epoch means; the metric name is everything after the first underscore
        for full_key, per_epoch in self.metrics.items():
            if not full_key.startswith(phase):
                continue
            batch_values = per_epoch[epoch]
            short_name = full_key.split('_', 1)[1]
            summary[short_name] = float(np.mean(batch_values)) if batch_values else 0.0

        # Share of correct / incorrect abstentions among all abstentions
        n_abstained = self.batch_totals[phase][epoch]
        correct_pct, incorrect_pct = 0.0, 0.0
        if n_abstained > 0:
            counts = self.batch_abstention_counts[phase][epoch]
            correct_pct = (counts['correct'] / n_abstained) * 100.0
            incorrect_pct = (counts['incorrect'] / n_abstained) * 100.0
        summary['correct_abstentions_percent'] = correct_pct
        summary['incorrect_abstentions_percent'] = incorrect_pct

        return summary

    def get_all_metrics(self, phase: str) -> Dict[str, float]:
        """
        Merge compute counters and training metrics into one dictionary.

        Args:
            phase: Training phase to get metrics for

        Returns:
            Combined dictionary of all metrics
        """
        combined = dict(self.get_computational_metrics())
        combined.update(self.get_training_metrics(phase))
        return combined

    def count_flops(
        self,
        model: nn.Module,
        input_batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> int:
        """
        Simplified FLOP count for one forward pass through an ArithmeticNet-style model.

        The count is the sum, scaled by the batch size, of:
        1. Input processing: two number embeddings plus one operator embedding
        2. The three core linear layers (``layer1``..``layer3``)
        3. The output layer

        Mathematical Details:
        -------------------
        For each linear layer:
            FLOPs = batch_size * in_features * out_features

        For embeddings:
            FLOPs = batch_size * embedding_dim

        Args:
            model: Model exposing ``num_embedder``, ``op_embedding``,
                ``layer1``, ``layer2``, ``layer3`` and ``output``
            input_batch: Tuple of (numbers, operator, targets) tensors

        Returns:
            Total number of floating point operations for a forward pass

        Note:
            Activation functions, bias additions and other minor operations
            are not counted.
        """
        batch_size = len(input_batch[0])
        _, _, _ = input_batch  # the batch must be a 3-tuple

        per_sample = [
            2 * model.num_embedder.out_features,   # two numbers per sample
            model.op_embedding.embedding_dim,      # operator embedding
        ]
        # Linear layers: one multiply per weight
        for layer in (model.layer1, model.layer2, model.layer3, model.output):
            per_sample.append(layer.in_features * layer.out_features)

        return sum(batch_size * cost for cost in per_sample)


In [None]:
class BaseTrainer:
    """
    Base trainer class implementing abstention-aware training for neural computation models.

    The trainer runs standard gradient steps using the loss from
    ``UnifiedLossComputer`` (configured with ``ABSTENTION_WEIGHT``) and records
    compute counters and batch metrics in a ``UnifiedTracker``.

    Attributes:
        model (nn.Module): The neural network model being trained
        optimizer (optim.Optimizer): The optimizer used for training
        criterion (nn.Module): Element-wise MSE loss (``reduction='none'``)
        scheduler (optim.lr_scheduler.ReduceLROnPlateau): Learning rate scheduler
        loss_computer (UnifiedLossComputer): Computes the training loss and batch metrics
        tracker (UnifiedTracker): Tracks computational overhead and training metrics
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer
    ) -> None:
        """
        Initialize the trainer with a model and optimizer.

        Args:
            model: Neural network model to train
            optimizer: Optimizer for updating model parameters
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = nn.MSELoss(reduction='none')
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5
        )
        self.loss_computer = UnifiedLossComputer(
            abstention_weight=ABSTENTION_WEIGHT,
        )
        self.tracker = UnifiedTracker()

    def training_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> Dict[str, float]:
        """
        Perform a single training step.

        Args:
            batch: Tuple of (numbers, operator, targets)

        Returns:
            Dictionary containing batch metrics, computed from the predictions
            made before the optimizer update
        """
        self.model.train()
        # Track computational overhead
        self.tracker.forward_passes += 1
        self.tracker.backward_passes += 1
        self.tracker.flops += self.tracker.count_flops(self.model, batch)

        self.optimizer.zero_grad()
        numbers, operator, targets = batch
        predictions = self.model(numbers, operator)
        # compute_loss returns a sequence; its first element is the scalar loss
        loss = self.loss_computer.compute_loss(predictions, targets, numbers, operator)[0]
        loss.backward()
        self.optimizer.step()

        return self.compute_batch_metrics(numbers, operator, targets, predictions)

    def evaluate(
        self,
        loader: DataLoader,
        phase: str = 'val'
    ) -> float:
        """
        Evaluate the model on a data loader.

        Every batch's metrics are recorded in ``self.tracker`` under ``phase``.

        Args:
            loader: DataLoader containing evaluation data
            phase: Evaluation phase ('train' or 'val')

        Returns:
            Average of the per-batch losses over the evaluation set
        """
        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for batch in loader:
                numbers, operator, targets = batch
                predictions = self.model(numbers, operator)
                batch_metrics = self.compute_batch_metrics(numbers, operator, targets, predictions)
                # The trainer owns a single UnifiedTracker; record the batch there.
                self.tracker.update_training_metrics(batch_metrics, phase)
                total_loss += batch_metrics['loss']

        return total_loss / len(loader)

    def compute_batch_metrics(
        self,
        numbers: Tensor,
        operator: Tensor,
        targets: Tensor,
        predictions: Optional[Tensor] = None
    ) -> Dict[str, float]:
        """
        Compute the metrics dictionary for one batch.

        Args:
            numbers: Number inputs of the batch
            operator: Operator inputs of the batch
            targets: Target values of the batch
            predictions: Model outputs for the batch; computed with a fresh
                forward pass when omitted

        Returns:
            Whatever ``self.loss_computer.compute_metrics`` returns for the batch
        """
        if predictions is None:
            predictions = self.model(numbers, operator)
        return self.loss_computer.compute_metrics(predictions, targets, numbers, operator)
class PGDTrainer(BaseTrainer):
    """
    Projected Gradient Descent (PGD) trainer that implements adversarial training
    to enhance abstention robustness.

    This trainer extends BaseTrainer by adding an adversarial phase for invalid
    computations: after the standard update, the model parameters are perturbed
    for ``k`` steps towards producing non-abstention ("fake") outputs on the
    invalid inputs, restored, and a further standard update is applied.

    Mathematical Description:
    -----------------------
    For invalid inputs x, each PGD step performs gradient ascent on
        L_adv(θ') = -weight * ||f_θ'(x) - y_fake||²
    i.e. it moves the perturbed parameters θ' so that the outputs approach the
    fake targets y_fake produced by ``generate_pgd_targets``.

    Attributes:
        alpha (float): PGD step size for gradient ascent
        k (int): Number of PGD iterations
        weight (float): Weight for adversarial loss
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        *,
        alpha: float = 0.016,
        k: int = 3,
        weight: float = 6.5
    ) -> None:
        """
        Initialize the PGD trainer.

        Args:
            model: Neural network model
            optimizer: Optimizer for parameter updates
            alpha: Step size for PGD
            k: Number of PGD steps
            weight: Weight for adversarial loss
        """
        # BaseTrainer.__init__ already creates self.tracker
        super().__init__(model, optimizer)
        self.alpha = alpha
        self.k = k
        self.weight = weight

    def get_invalid_mask_for_pgd(
        self,
        invalid_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Get the mask for PGD training. Base implementation uses all invalid samples.

        Args:
            invalid_mask: Original invalid computation mask

        Returns:
            Mask indicating which samples to use for PGD
        """
        return invalid_mask

    def training_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> Dict[str, float]:
        """
        Perform a training step including both normal and adversarial updates.

        Args:
            batch: Tuple of (numbers, operator, targets)

        Returns:
            Dictionary containing batch metrics, computed from the predictions
            of the first (standard) forward pass
        """
        self.model.train()
        numbers, operator, targets = batch
        invalid_mask = is_invalid_computation(numbers, operator)

        # Track main forward pass
        self.tracker.forward_passes += 1
        self.tracker.flops += self.tracker.count_flops(self.model, batch)

        # Standard training step
        self.optimizer.zero_grad()
        predictions = self.model(numbers, operator)
        loss = self.loss_computer.compute_loss(predictions, targets, numbers, operator)[0]
        loss.backward()

        # Track backward pass
        self.tracker.backward_passes += 1

        self.optimizer.step()

        # PGD step for invalid computations
        if invalid_mask.any():
            pgd_mask = self.get_invalid_mask_for_pgd(invalid_mask)

            # Track cost of the k PGD iterations on the selected samples
            self.tracker.forward_passes += self.k
            self.tracker.backward_passes += self.k
            selected_batch = (
                numbers[pgd_mask],
                operator[pgd_mask],
                targets[pgd_mask]
            )
            self.tracker.flops += self.k * self.tracker.count_flops(self.model, selected_batch)

            self.perform_pgd(batch, pgd_mask)

        return self.compute_batch_metrics(numbers, operator, targets, predictions)

    def perform_pgd(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        invalid_mask: torch.Tensor
    ) -> None:
        """
        Perform PGD-based adversarial training on invalid inputs.

        1. Store original parameters θ₀
        2. For k steps:
           a. Forward pass: ŷ = f_θ(x_invalid)
           b. Compute anti-abstention loss: L = -||ŷ - y_fake||²
           c. Update: θ ← θ + α * ∇_θ (weight * L)
        3. Restore θ₀
        4. Final update on the full batch with the standard loss

        NOTE(review): the perturbed parameters are discarded in step 3 before
        the final update, so step 4 is a plain second standard step — confirm
        that the adversarial signal is meant to be dropped here.

        Args:
            batch: Tuple of (numbers, operator, targets)
            invalid_mask: Boolean mask identifying invalid computations
        """
        # Store original parameters
        original_params = {
            name: param.clone().detach()
            for name, param in self.model.named_parameters()
        }

        # Extract invalid inputs
        numbers, operator, targets = batch
        invalid_numbers = numbers[invalid_mask]
        invalid_operator = operator[invalid_mask]

        # Generate anti-abstention (boundary) targets
        fake_targets = self.generate_pgd_targets(invalid_numbers, invalid_operator)

        try:
            # PGD iteration loop
            for _ in range(self.k):
                self.optimizer.zero_grad()
                predictions = self.model(invalid_numbers, invalid_operator)

                # Negative MSE: ascending it pulls predictions towards the
                # fake targets, i.e. away from the abstention token
                anti_abstention_loss = -F.mse_loss(predictions, fake_targets)
                (self.weight * anti_abstention_loss).backward()

                # Gradient ascent step
                with torch.no_grad():
                    for param in self.model.parameters():
                        if param.grad is not None:
                            param.data += self.alpha * param.grad
        finally:
            # Restore original parameters, also if a PGD step fails
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.data.copy_(original_params[name])

        # The final update is a full forward/backward pass; count it as well
        self.tracker.forward_passes += 1
        self.tracker.backward_passes += 1
        self.tracker.flops += self.tracker.count_flops(self.model, batch)

        # Final update with standard loss
        self.optimizer.zero_grad()
        predictions = self.model(numbers, operator)
        loss = self.loss_computer.compute_loss(predictions, targets, numbers, operator)[0]
        loss.backward()
        self.optimizer.step()

    def generate_pgd_targets(self, invalid_numbers, invalid_operator):
        """
        Build one random non-abstention target per invalid sample.

        Targets are drawn from [395, 400) for '+', from [0, 5) for '-', and
        otherwise from within ±10 of a randomly chosen boundary (0 or 400).

        NOTE(review): the operator is compared with the strings '+' and '-'.
        If ``invalid_operator`` holds encoded tensor values (it is indexed with
        a mask and passed to the model), these comparisons are presumably never
        true and every sample takes the final branch — verify the encoding.

        Args:
            invalid_numbers: Number inputs of the selected invalid samples
            invalid_operator: Operators of the selected invalid samples

        Returns:
            Tensor of shape (num_samples, 1) on the device of ``invalid_numbers``
        """
        targets = []
        for num, op in zip(invalid_numbers, invalid_operator):
            if op == '+':
                targets.append(395 + torch.rand(1) * 5)
            elif op == '-':
                targets.append(torch.rand(1) * 5)
            else:
                boundary_points = torch.tensor([0, 400])
                chosen_boundary = boundary_points[torch.randint(0, 2, (1,))]
                targets.append(chosen_boundary + (torch.rand(1) - 0.5) * 20)
        return torch.tensor(targets, device=invalid_numbers.device).unsqueeze(1)  # Add dimension to match expected shape
class EfficientPGDTrainer(PGDTrainer):
    """
    PGD trainer variant that restricts adversarial training to a subset of
    the invalid computations in a batch.

    Only a fraction (``sample_ratio``) of the invalid samples is kept by
    :meth:`get_invalid_mask_for_pgd`, which is meant to reduce the memory and
    compute spent on the PGD steps. How the reduced mask is consumed is
    decided by ``PGDTrainer`` (not visible from this class).

    Attributes:
        sample_ratio (float): Fraction of invalid samples to use for PGD
        seed (int): Seed used for the sample selection
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        *,  # Force keyword arguments for clarity
        alpha: float = 0.016,
        k: int = 2,
        weight: float = 6.5,
        sample_ratio: float = 0.3,
        seed: int = 42
    ) -> None:
        """
        Initialize the efficient PGD trainer.

        Args:
            model: Neural network model
            optimizer: Optimizer for parameter updates
            alpha: Step size for PGD (forwarded to ``PGDTrainer``)
            k: Number of PGD steps (forwarded to ``PGDTrainer``)
            weight: Weight for adversarial loss (forwarded to ``PGDTrainer``)
            sample_ratio: Fraction of invalid samples to use for PGD
            seed: Seed for the sample selection
        """
        super().__init__(
            model,
            optimizer,
            alpha=alpha,
            k=k,
            weight=weight
        )
        self.sample_ratio = sample_ratio
        self.seed = seed

    def get_invalid_mask_for_pgd(
        self,
        invalid_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Select a seeded subset of the invalid samples for PGD training.

        At least one sample is kept whenever the mask has any ``True`` entry.
        The permutation is drawn from a private generator that is seeded with
        ``self.seed`` on every call, so the global RNG state is never touched
        and masks with the same number of ``True`` entries always keep the
        same positions among them.

        Args:
            invalid_mask: Original invalid computation mask (boolean; indexed
                along its first dimension)

        Returns:
            Mask shaped like ``invalid_mask`` that is ``True`` only for the
            selected subset
        """
        reduced_mask = torch.zeros_like(invalid_mask)
        candidates = torch.where(invalid_mask)[0]
        total = len(candidates)
        if total == 0:
            # Nothing is invalid in this batch, so nothing can be selected.
            return reduced_mask

        keep = max(1, int(total * self.sample_ratio))

        # A local generator avoids snapshotting/restoring the global CPU and
        # CUDA RNG states (what fork_rng does) on every single batch.
        sampler = torch.Generator()
        sampler.manual_seed(self.seed)
        order = torch.randperm(total, generator=sampler)

        chosen = candidates[order[:keep].to(candidates.device)]
        reduced_mask[chosen] = True
        return reduced_mask
class InputSpaceAdversarialTrainer(BaseTrainer):
    """
    Control trainer that performs adversarial training in input space.

    A fixed fraction of every batch is replaced by perturbed copies of the
    same samples, so the batch size is unchanged and no data outside the
    current batch is used. It is intended as a comparison point for the
    parameter-space PGD trainers.
    """
    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        *,
        epsilon: float = 0.1,
        steps: int = 3,
        adv_ratio: float = 0.05,
    ) -> None:
        """
        Args:
            model: Neural network model
            optimizer: Optimizer for parameter updates
            epsilon: Magnitude of each signed-gradient update applied to the
                inputs (the total displacement can reach ``steps * epsilon``
                before clamping)
            steps: Number of signed-gradient steps
            adv_ratio: Fraction of training examples in each batch to perturb
        """
        super().__init__(model, optimizer)
        self.epsilon = epsilon
        self.steps = steps
        self.adv_ratio = adv_ratio
        self.tracker = UnifiedTracker()

    def generate_input_adversarial_examples(
        self,
        numbers: torch.Tensor,
        operator: torch.Tensor,
        targets: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Perturb ``numbers`` so that the model's MSE against ``targets`` grows.

        Only the given samples are used; nothing is added to the batch. Each
        step moves the inputs by ``epsilon`` along the sign of the loss
        gradient and clamps them to [0, 400].

        Args:
            numbers: Inputs to perturb (must support gradients)
            operator: Operators belonging to ``numbers``
            targets: Regression targets belonging to ``numbers``

        Returns:
            Tuple ``(perturbed_numbers, operator, targets)``, all on the
            model's device; ``perturbed_numbers`` is detached.
        """
        device = next(self.model.parameters()).device
        operator = operator.to(device)
        targets = targets.to(device)
        perturbed = numbers.clone().to(device)

        # An empty slice has nothing to perturb; a mean-reduced MSE over it
        # would be NaN.
        if perturbed.numel() == 0:
            return perturbed.detach(), operator, targets

        for _ in range(self.steps):
            perturbed = perturbed.detach().requires_grad_(True)
            predictions = self.model(perturbed, operator)
            loss = F.mse_loss(predictions, targets, reduction='mean')

            # Differentiate w.r.t. the inputs only, so the model parameters'
            # .grad buffers are not polluted by the attack.
            (grad,) = torch.autograd.grad(loss, perturbed, allow_unused=True)
            if grad is None:
                # The inputs did not take part in the loss; leave them as is.
                continue

            with torch.no_grad():
                # Ascend the (positive) loss: step along +sign(grad) so the
                # predictions are pushed AWAY from the targets.
                perturbed = perturbed + self.epsilon * grad.sign()
                perturbed = torch.clamp(perturbed, min=0, max=400)

        return perturbed.detach(), operator, targets

    def training_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> Dict[str, float]:
        """Single training step:
          1) Split the batch so that the trailing fraction 'self.adv_ratio' is
             used for adversarial input
          2) Generate adversarial perturbations for that fraction
          3) Merge them back and run a forward/backward pass with the same
             total batch size

        Args:
            batch: Tuple ``(numbers, operator, targets)``

        Returns:
            Metrics returned by ``compute_batch_metrics`` for the combined batch
        """
        self.model.train()

        # Put the whole batch on the model's device up front; the adversarial
        # part is moved there anyway and must be concatenated with the rest.
        device = next(self.model.parameters()).device
        numbers, operator, targets = (part.to(device) for part in batch)
        batch = (numbers, operator, targets)
        batch_size = len(numbers)

        # Track initial forward pass
        self.tracker.forward_passes += 1
        self.tracker.flops += self.tracker.count_flops(self.model, batch)

        # Index at which to split: [0, split_idx) stays clean,
        # [split_idx, batch_size) is perturbed.
        split_idx = int(batch_size * (1 - self.adv_ratio))
        num_adv = batch_size - split_idx

        clean_nums = numbers[:split_idx]
        clean_ops = operator[:split_idx]
        clean_targets = targets[:split_idx]

        adv_batch = (numbers[split_idx:], operator[split_idx:], targets[split_idx:])
        adv_nums, adv_ops, adv_targets = self.generate_input_adversarial_examples(*adv_batch)

        # NOTE(review): the adversarial passes are counted per example while
        # the passes above/below are counted per batch -- confirm the units
        # expected by the tracker.
        self.tracker.forward_passes += self.steps * num_adv
        self.tracker.flops += self.steps * num_adv * \
                             self.tracker.count_flops(self.model, adv_batch)

        # Recombine so total size is the same as the original batch
        combined_nums = torch.cat([clean_nums, adv_nums], dim=0)
        combined_ops = torch.cat([clean_ops, adv_ops], dim=0)
        combined_tgts = torch.cat([clean_targets, adv_targets], dim=0)

        # Forward/backward pass on the combined batch
        self.optimizer.zero_grad()
        predictions = self.model(combined_nums, combined_ops)
        loss, _ = self.loss_computer.compute_loss(
            predictions, combined_tgts, combined_nums, combined_ops
        )
        loss.backward()

        # Track final backward pass
        self.tracker.backward_passes += 1

        self.optimizer.step()

        # Compute metrics
        with torch.no_grad():
            metrics = self.compute_batch_metrics(
                combined_nums, combined_ops, combined_tgts, predictions
            )
        return metrics

In [None]:
def get_model_name(model_idx: int, model: nn.Module, is_pgd: bool, model_var_name: Optional[str]) -> str:
    """
    Build the display/save name of a model.

    When a variable name is supplied, it is cleaned up (every '_model'
    occurrence removed, underscores turned into spaces, surrounding
    whitespace stripped) and appended to the index prefix. Otherwise the name
    falls back to a label derived from the training type.

    Args:
        model_idx: Index of the model in the training framework
        model: The neural network model (not used to build the name)
        is_pgd: Whether the model is trained with PGD; only consulted when no
            variable name is given
        model_var_name: Optional variable name provided for the model

    Returns:
        str: ``model_<idx>_<cleaned name>``, ``model_<idx>_pgd`` or
        ``model_<idx>_base_adam``
    """
    prefix = f"model_{model_idx}"

    if not model_var_name:
        fallback = "pgd" if is_pgd else "base_adam"
        return f"{prefix}_{fallback}"

    cleaned = model_var_name.replace('_model', '')
    cleaned = cleaned.replace('_', ' ').strip()
    return f"{prefix}_{cleaned}"

class UnifiedTrainingFramework:
    """
    Coordinates training, validation and landscape analysis for several
    model/trainer pairs that share one output directory and one metrics
    history.
    """

    def __init__(
        self,
        models: List[nn.Module],
        trainers: List[BaseTrainer],
        save_dir: str = 'training_results',
        fixed_scale: int = 1000,
        training_config: Optional[TrainingConfig] = None,
        model_names: Optional[List[str]] = None
    ):
        """
        Initialize the unified training framework for multiple models.

        Args:
            models: List of PyTorch models to train
            trainers: List of trainers; ``trainers[i]`` trains ``models[i]``
            save_dir: Directory to save training results and visualizations
            fixed_scale: Scale factor for landscape analysis
            training_config: Configuration object for training parameters
                (a default ``TrainingConfig()`` is used when omitted)
            model_names: Optional display names, one per model; defaults to
                ``model_<index>``

        Raises:
            ValueError: If ``models`` and ``trainers`` differ in length
        """
        if len(models) != len(trainers):
            raise ValueError("Number of models must match number of trainers")

        self.models = models
        self.trainers = trainers
        self.save_dir = save_dir
        self.fixed_scale = fixed_scale
        self.training_config = training_config or TrainingConfig()

        # Names are looked up by index during training, so always keep a list.
        if model_names is None:
            model_names = [f"model_{idx}" for idx in range(len(models))]
        self.model_names = model_names

        # Create necessary directories
        os.makedirs(save_dir, exist_ok=True)
        os.makedirs(os.path.join(save_dir, 'landscapes'), exist_ok=True)
        os.makedirs(os.path.join(save_dir, 'metrics'), exist_ok=True)

        # metric name (or model name) -> inner key -> list of recorded values
        self.metrics_history = defaultdict(lambda: defaultdict(list))

    @staticmethod
    def get_model_name(model_idx: int, model: nn.Module, model_var_name: Optional[str]) -> str:
        """
        Build a model name from its index and an optional variable name.

        Declared as a static method because it takes no ``self``; it can be
        called on the class or on an instance.

        Args:
            model_idx: Index of the model
            model: The model (not used to build the name)
            model_var_name: Optional variable name; '_model' is removed and
                underscores become spaces

        Returns:
            ``model_<idx>_<cleaned name>`` or ``model_<idx>`` without a name
        """
        if model_var_name:
            base_name = model_var_name.replace('_model', '').replace('_', ' ').strip()
            return f"model_{model_idx}_{base_name}"
        return f"model_{model_idx}"

    def train_single_model(
        self,
        model_idx: int,
        train_loader: DataLoader,
        val_loader: DataLoader,
        epochs: int,
        analyze_every: int
    ) -> Dict[str, Any]:
        """
        Train a single model with the specified configuration.

        Args:
            model_idx: Index of the model to train
            train_loader: DataLoader for training data
            val_loader: DataLoader for validation data
            epochs: Number of training epochs
            analyze_every: Currently not consulted; validation and analysis
                frequencies are taken from ``self.training_config``
                (``val_freq`` / ``analysis_freq``)

        Returns:
            Dict containing training metrics history
        """
        trainer = self.trainers[model_idx]
        model_name = self.model_names[model_idx]
        config = self.training_config

        for epoch in range(epochs):
            self._train_epoch(trainer, train_loader, epoch, total_epochs=epochs)

            # Validation phase
            if epoch % config.val_freq == 0:
                self._validate_model(trainer, val_loader, epoch)

            # Landscape analysis (never after the very first epoch)
            if epoch > 0 and (epoch + 1) % config.analysis_freq == 0:
                self.perform_analysis_single_model(
                    trainer, epoch, val_loader, model_name, model_idx
                )

        return dict(self.metrics_history)

    def _train_epoch(
        self,
        trainer: BaseTrainer,
        train_loader: DataLoader,
        epoch: int,
        total_epochs: int,
    ) -> None:
        """
        Train for one epoch using the provided trainer.

        Per-batch metrics returned by ``trainer.training_step`` are shown as
        running averages on the progress bar, and their epoch averages are
        recorded in ``self.metrics_history`` under ``train_<metric>``.

        Args:
            trainer: Trainer instance to use
            train_loader: DataLoader for training data
            epoch: Current epoch number (zero-based)
            total_epochs: Total number of epochs, for the progress bar label
        """
        epoch_metrics = defaultdict(list)

        pbar = tqdm(train_loader,
                    desc=f'Epoch {epoch+1}/{total_epochs}',
                    leave=True)
        for batch in pbar:
            batch_metrics = trainer.training_step(batch)

            for key, value in batch_metrics.items():
                epoch_metrics[f'train_{key}'].append(value)

            # Running averages over the last 100 batches
            running_metrics = {
                key: np.mean(values[-100:])
                for key, values in epoch_metrics.items()
            }
            pbar.set_postfix(running_metrics)

        # Store the epoch averages using the same layout as validation
        # metrics: metrics_history[metric][epoch] -> list of values.
        for key, values in epoch_metrics.items():
            self.metrics_history[key][epoch].append(float(np.mean(values)))

    def _validate_model(
        self,
        trainer: BaseTrainer,
        val_loader: DataLoader,
        epoch: int
    ) -> float:
        """
        Validate the model using the provided trainer.

        Averages the per-batch metrics from ``trainer.compute_batch_metrics``,
        records them under ``val_<metric>`` and steps the trainer's scheduler
        (if it has one) with the validation loss.

        Args:
            trainer: Trainer instance to use
            val_loader: DataLoader for validation data
            epoch: Current epoch number

        Returns:
            float: Average validation loss, or NaN when no ``loss`` metric
            was produced (e.g. an empty loader)
        """
        val_metrics = defaultdict(list)

        trainer.model.eval()
        with torch.no_grad():
            for numbers, operator, targets in val_loader:
                predictions = trainer.model(numbers, operator)
                batch_metrics = trainer.compute_batch_metrics(
                    numbers=numbers,
                    operator=operator,
                    targets=targets,
                    predictions=predictions
                )
                for key, value in batch_metrics.items():
                    val_metrics[f'val_{key}'].append(value)

        avg_metrics = {
            key: float(np.mean(values))
            for key, values in val_metrics.items()
        }

        for key, value in avg_metrics.items():
            self.metrics_history[key][epoch].append(value)

        metrics_str = ', '.join(
            f'{key}: {value:.4f}'
            for key, value in avg_metrics.items()
        )
        print(f'\nEpoch {epoch} validation: {metrics_str}')

        val_loss = avg_metrics.get('val_loss')
        if val_loss is None:
            # Without a loss there is nothing meaningful to feed the scheduler.
            return float('nan')

        scheduler = getattr(trainer, 'scheduler', None)
        if scheduler is not None:
            scheduler.step(val_loss)

        return val_loss

    def _log_computational_metrics(
        self,
        trainer: BaseTrainer,
        model_name: str
    ) -> None:
        """
        Save and print the computational metrics of a trainer.

        Does nothing when the trainer has no ``tracker`` attribute.

        Args:
            trainer: Trainer instance
            model_name: Name of the model (used in the output file name)
        """
        if not hasattr(trainer, 'tracker'):
            return

        metrics = trainer.tracker.get_metrics()

        metrics_path = os.path.join(
            self.save_dir,
            'metrics',
            f'{model_name}_computational_metrics.json'
        )
        with open(metrics_path, 'w') as f:
            json.dump(metrics, f, indent=2)

        print(f"\nComputational metrics for {model_name}:")
        print(f"  Forward passes: {metrics['forward_passes']}")
        print(f"  Backward passes: {metrics['backward_passes']}")
        print(f"  Total FLOPs: {metrics['total_flops']:,}")
        print(f"  Wall time: {metrics['wall_time']:.2f}s")

    def perform_analysis_single_model(
        self,
        trainer: Union[BaseTrainer, EfficientPGDTrainer],
        epoch: int,
        val_loader: DataLoader,
        model_name: str,
        model_idx: int
    ) -> None:
        """
        Run landscape analysis for one model.

        The landscape is visualized twice (``rand_dir=True`` and
        ``rand_dir=False``), then numerical landscape metrics are computed on
        the first validation batch, printed, and appended to
        ``self.metrics_history[model_name]``. Any failure is reported with a
        traceback instead of interrupting training.

        Args:
            trainer: The model's trainer instance (not used by the analysis)
            epoch: Current training epoch
            val_loader: Validation data loader
            model_name: Name of the model
            model_idx: Index of the model
        """
        print(f"\nPerforming landscape analysis for epoch {epoch} on {model_name}...")

        landscape_dir = os.path.join(self.save_dir, 'landscapes', model_name)
        os.makedirs(landscape_dir, exist_ok=True)

        analyzer = LandscapeAnalyzer(
            self.models[model_idx],
            save_dir=landscape_dir,
            fixed_scale=self.fixed_scale,
        )

        try:
            # Random directions first, then principal directions
            analyzer.visualize_landscape(
                val_loader,
                epoch,
                model_name=f"{model_name}_random",
                rand_dir=True
            )
            analyzer.visualize_landscape(
                val_loader,
                epoch,
                model_name=f"{model_name}_pca",
                rand_dir=False
            )

            # Numerical landscape analysis on a single validation batch
            val_batch = next(iter(val_loader))
            metrics = analyzer.analyze_landscape(val_batch, epoch)

            print(f"  [Landscape] Alpha Sharpness: {metrics['alpha_sharpness']:.4f}")
            print("  [Landscape] Top Eigenvalues:")
            for i, ev in enumerate(metrics['top_eigenvalues']):
                print(f"    λ{i+1}: {ev:.4f}")

            print("  [Landscape] Multiscale Sharpness:")
            for scale, value in metrics['multiscale_sharpness'].items():
                print(f"    {scale}: {value:.4f}")

            print(f"  [Landscape] Valley Asymmetry: {metrics['valley_asymmetry']:.4f}")

            model_history = self.metrics_history[model_name]
            for key, value in metrics.items():
                model_history[f"landscape_{key}"].append((epoch, value))

        except Exception:
            # Best effort: report the problem and keep training.
            print("Warning: Failed to perform landscape analysis:")
            print(traceback.format_exc())
            # Ensure directory exists even if analysis fails
            os.makedirs(landscape_dir, exist_ok=True)

    def evaluate_and_visualize(
        self,
        test_loader: DataLoader,
        epoch: int,
        model_names: List[str]
    ) -> None:
        """
        Evaluate models on test data and generate visualizations.

        NOTE(review): relies on ``self._evaluate_single_model`` and
        ``self._plot_abstention_rates``, which are not defined in this class
        body -- confirm they are provided elsewhere.

        Args:
            test_loader: DataLoader for test data
            epoch: Current epoch number
            model_names: Names of the models, indexed like ``self.models``
        """
        abstention_rates = defaultdict(list)

        for model_idx, model in enumerate(self.models):
            model_name = model_names[model_idx]
            metrics = self._evaluate_single_model(model, test_loader, model_name)
            abstention_rates[model_name].append(metrics['total_abstentions'])

        self._plot_abstention_rates(abstention_rates, epoch)

In [None]:
def test_model_with_noise(
    model: nn.Module,
    loader: Iterator[tuple[Tensor, Tensor, Tensor]],
    noise_std: float
) -> Dict[str, float]:
    """
    Test a model's performance with optional additive Gaussian noise.

    A prediction counts as an abstention when it is close to -1.0
    (``torch.isclose`` with ``rtol=0.1, atol=0.1``). Validity of each sample
    is determined by ``is_invalid_computation`` on the (noisy) inputs.

    Args:
        model: Neural network model for arithmetic operations
        loader: DataLoader providing batches of (numbers, operator, targets)
        noise_std: Standard deviation of Gaussian noise to add to inputs;
            no noise is added when it is not positive

    Returns:
        Dictionary of per-batch averages:
        - abstention_rate: Fraction of samples the model abstained on
        - invalid_recall: Fraction of invalid samples abstained on
          (only batches containing invalid samples)
        - valid_recall: Fraction of valid samples not abstained on
          (only batches containing valid samples)
        - correct_abstention: Fraction of all samples that are invalid and
          abstained on
        - valid_accuracy: Fraction of valid, non-abstained predictions within
          1% of the target (0.0 when there were none)
    """
    metrics = defaultdict(list)
    model.eval()

    with torch.no_grad():
        for numbers, operator, targets in loader:
            if noise_std > 0:
                numbers = numbers + torch.randn_like(numbers) * noise_std

            # Flatten to 1-D. Unlike squeeze(), this keeps a batch of size one
            # as shape (1,), so the boolean-mask indexing below still works.
            predictions = model(numbers, operator).reshape(-1)
            targets = targets.reshape(-1)

            invalid_mask = is_invalid_computation(numbers, operator)
            abstained = torch.isclose(
                predictions,
                torch.tensor(-1.0).to(predictions.device),
                rtol=0.1, atol=0.1
            )
            valid_mask = ~invalid_mask
            abstained_on_invalid = abstained & invalid_mask
            valid_and_not_abstained = valid_mask & (~abstained)

            metrics['abstention_rate'].append(abstained.float().mean().item())

            if invalid_mask.any():
                invalid_recall = abstained_on_invalid.float().sum() / invalid_mask.float().sum()
                metrics['invalid_recall'].append(invalid_recall.item())

            if valid_mask.any():
                valid_recall = valid_and_not_abstained.float().sum() / valid_mask.float().sum()
                metrics['valid_recall'].append(valid_recall.item())

            metrics['correct_abstention'].append(
                abstained_on_invalid.float().mean().item()
            )

            if valid_and_not_abstained.any():
                pred_vals = predictions[valid_and_not_abstained]
                tgt_vals = targets[valid_and_not_abstained]
                diff = torch.abs(pred_vals - tgt_vals)
                # Relative 1% tolerance; the epsilon keeps the threshold
                # positive when the target is exactly zero.
                threshold = torch.abs(tgt_vals) * 0.01 + 1e-8
                metrics['valid_accuracy'].append(
                    (diff < threshold).float().mean().item()
                )

    # Ensure valid_accuracy exists in output
    if not metrics['valid_accuracy']:
        metrics['valid_accuracy'] = [0.0]

    # Plain floats keep the result JSON-friendly.
    return {key: float(np.mean(values)) for key, values in metrics.items()}

def test_model_on_boundary_cases(
    model: nn.Module,
    loader: Iterator[tuple[Tensor, Tensor, Tensor]]
) -> Dict[str, float]:
    """
    Test model performance specifically on boundary cases.

    A prediction counts as an abstention when it is close to -1.0
    (``torch.isclose`` with ``rtol=0.1, atol=0.1``). Validity of each sample
    is determined by ``is_invalid_computation``.

    Args:
        model: Neural network model for arithmetic operations
        loader: DataLoader providing batches of (numbers, operator, targets)

    Returns:
        Dictionary of per-batch averages:
        - invalid_recall: Fraction of invalid samples abstained on
          (only batches containing invalid samples)
        - valid_recall: Fraction of valid samples not abstained on
          (only batches containing valid samples)
        - correct_abstentions: Fraction of all samples that are invalid and
          abstained on
        - incorrect_abstentions: Fraction of all samples that are valid but
          abstained on
        - accuracy: Fraction of valid, non-abstained predictions within 1% of
          the target (0.0 when there were none)
    """
    metrics = defaultdict(list)
    model.eval()

    with torch.no_grad():
        for numbers, operator, targets in loader:
            # Flatten to 1-D. Unlike squeeze(), this keeps a batch of size one
            # as shape (1,), so the boolean-mask indexing below still works.
            predictions = model(numbers, operator).reshape(-1)
            targets = targets.reshape(-1)

            invalid_mask = is_invalid_computation(numbers, operator)
            abstained = torch.isclose(
                predictions,
                torch.tensor(-1.0).to(predictions.device),
                rtol=0.1,
                atol=0.1
            )
            valid_mask = ~invalid_mask
            valid_and_not_abstained = valid_mask & (~abstained)

            if invalid_mask.any():
                invalid_recall = (abstained & invalid_mask).float().sum() / invalid_mask.float().sum()
                metrics['invalid_recall'].append(invalid_recall.item())

            if valid_mask.any():
                valid_recall = valid_and_not_abstained.float().sum() / valid_mask.float().sum()
                metrics['valid_recall'].append(valid_recall.item())

            metrics['correct_abstentions'].append(
                (abstained & invalid_mask).float().mean().item()
            )
            metrics['incorrect_abstentions'].append(
                (abstained & valid_mask).float().mean().item()
            )

            if valid_and_not_abstained.any():
                pred_vals = predictions[valid_and_not_abstained]
                tgt_vals = targets[valid_and_not_abstained]

                diff = torch.abs(pred_vals - tgt_vals)
                # Relative 1% tolerance; the epsilon keeps the threshold
                # positive when the target is exactly zero.
                threshold = torch.abs(tgt_vals) * 0.01 + 1e-8
                metrics['accuracy'].append(
                    (diff < threshold).float().mean().item()
                )

    # Always report the documented accuracy key, mirroring how
    # test_model_with_noise guarantees 'valid_accuracy'.
    if not metrics['accuracy']:
        metrics['accuracy'] = [0.0]

    # Plain floats keep the result JSON-friendly.
    return {key: float(np.mean(values)) for key, values in metrics.items()}
def run_robustness_experiments(
    models: Dict[str, nn.Module],
    save_dir: str = 'robustness_results',
    model_names: Optional[List[str]] = None,
    test_loader: Any = None
) -> Dict[str, Dict]:
    """
    Run robustness experiments on multiple models with various noise levels.

    Every selected model is evaluated with ``test_model_with_noise`` at each
    noise level. A failing evaluation is reported as a warning and recorded
    with all-zero metrics. Results are written to
    ``<save_dir>/robustness_results.json`` and, when at least one model was
    selected, passed to ``create_robustness_visualizations``.

    Args:
        models: Dictionary mapping model names to model instances
        save_dir: Directory to save results and visualizations
        model_names: Optional list of model names to test (defaults to all models in dict)
        test_loader: Iterable of (numbers, operator, targets) batches used for
            every evaluation; required

    Returns:
        Dictionary ``{'noise_test': {noise_std: {model_name: metrics}}}``

    Raises:
        ValueError: If ``test_loader`` is missing or a requested model name
            is not a key of ``models``
    """
    if test_loader is None:
        raise ValueError("test_loader must be provided")

    # Create output directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'plots'), exist_ok=True)

    # Validate model names
    if model_names is None:
        model_names = list(models.keys())
    elif not all(name in models for name in model_names):
        raise ValueError("All model_names must be keys in models dict")

    # NOTE: the 'boundary_test' section is currently disabled. The
    # single-model branch of create_robustness_visualizations still reads
    # results['boundary_test'], so with exactly one model the visualization
    # step fails and is reported through the warning below.
    results = {
        'noise_test': {}
    }

    # Define noise levels for testing
    noise_levels = [
        0.5, 1.0, 2.0, 5.0, 10.0, 25.0, 50.0, 100.0,
        200.0, 250.0, 300.0, 500.0, 1000.0, 2000.0
    ]

    for noise_std in tqdm(noise_levels, desc="Testing noise levels"):
        level_results = {}
        results['noise_test'][noise_std] = level_results

        for name in model_names:
            try:
                metrics = test_model_with_noise(models[name], test_loader, noise_std)
                level_results[name] = {
                    'abstention_rate': metrics['abstention_rate'],
                    'valid_accuracy': metrics['valid_accuracy'],
                    'invalid_recall': metrics.get('invalid_recall', 0.0)
                }
            except Exception as e:
                print(f"\nWarning: Error testing {name} with noise {noise_std}: {str(e)}")
                # Record zeros so downstream plotting still finds every model.
                level_results[name] = {
                    'abstention_rate': 0.0,
                    'valid_accuracy': 0.0,
                    'invalid_recall': 0.0
                }

    # Save results to file (best effort)
    try:
        results_path = os.path.join(save_dir, 'robustness_results.json')
        with open(results_path, 'w') as f:
            json.dump({
                'results': results,
                'metadata': {
                    'model_names': model_names,
                    'noise_levels': noise_levels
                }
            }, f, indent=2)
    except Exception as e:
        print(f"\nWarning: Could not save results: {str(e)}")

    # Generate visualizations when at least one model was tested
    if model_names:
        try:
            create_robustness_visualizations(results, save_dir, model_names=model_names)
        except Exception as e:
            print(f"\nWarning: Could not create visualizations: {str(e)}")
            print(traceback.format_exc())

    return results

def create_robustness_visualizations(
    results: Dict[str, Dict],
    save_dir: str,
    model_names: Optional[List[str]] = None
) -> bool:
    """
    Render robustness plots, or dump the raw metrics when only one model exists.

    With two or more models the progressive-noise line plot and the
    invalid-recall bar plot are written into ``save_dir``. With exactly one
    model no comparison plot is drawn; instead the ``boundary_test`` and
    ``noise_test`` sections of ``results`` are written to
    ``<model>_robustness_metrics.json``.

    Args:
        results: Dictionary containing experiment results, keyed by test type
            (``'noise_test'`` and optionally ``'boundary_test'``), then by
            noise level, then by model name
        save_dir: Directory to save visualization files (created if missing)
        model_names: Optional list of model names to include in visualizations;
            when omitted the names are inferred from the first populated
            section of ``results``

    Returns:
        True if visualizations (or the single-model JSON) were written,
        False if no model names were supplied and none could be inferred
    """
    os.makedirs(save_dir, exist_ok=True)

    # Infer model names from the results if not provided. 'boundary_test' is
    # tried first for backward compatibility, but it may be absent or empty,
    # so fall back to 'noise_test' instead of failing on the lookup.
    if model_names is None:
        model_names = []
        for section in ('boundary_test', 'noise_test'):
            per_noise_level = results.get(section) or {}
            first_entry = next(iter(per_noise_level.values()), None)
            if first_entry:
                model_names = list(first_entry.keys())
                break

    if not model_names:
        print("Warning: No model names available; skipping robustness visualizations")
        return False

    # Set up matplotlib style
    plt.style.use('seaborn-v0_8-whitegrid')
    base_params = {
        'font.family': 'sans-serif',
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 14,
        'legend.fontsize': 12,
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.grid': True,
        'grid.alpha': 0.3,
    }

    # Define color palette for consistent visualization
    colors = ['#4878D0', '#EE854A', '#6ACC64', '#9467BD', '#FF7F0E', '#2CA02C']

    if len(model_names) > 1:
        # Generate progressive noise test visualization
        _create_noise_test_plot(
            results, model_names, colors, base_params, save_dir
        )

        # Generate invalid recall visualization
        _create_invalid_recall_plot(
            results, model_names, colors, base_params, save_dir
        )
    else:
        # Save metrics as JSON for single model case. Missing sections are
        # written as empty objects rather than raising KeyError.
        metrics_path = os.path.join(save_dir, f'{model_names[0]}_robustness_metrics.json')
        with open(metrics_path, 'w') as f:
            json.dump({
                'boundary_test': results.get('boundary_test', {}),
                'noise_test': results.get('noise_test', {})
            }, f, indent=2)

    return True

# def _create_boundary_test_plot(
#     results: Dict[str, Dict],
#     model_names: List[str],
#     colors: List[str],
#     base_params: Dict[str, Any],
#     save_dir: str
# ) -> None:
#     """Helper function to create boundary test visualization."""
#     plt.rcParams.update({**base_params, 'figure.figsize': (8, 6)})
#     fig = plt.figure(constrained_layout=True)
#     ax = fig.add_subplot(111)

#     noise_levels = sorted(results['boundary_test'].keys())
#     x = np.arange(len(noise_levels))
#     width = 0.8 / len(model_names)

#     for i, model_name in enumerate(model_names):
#         try:
#             correct_vals = [
#                 results['boundary_test'][n][model_name]['correct_abstentions']
#                 for n in noise_levels
#             ]
#             position = x + width * (i - len(model_names)/2 + 0.5)
#             ax.bar(position, correct_vals, width,
#                    label=model_name,
#                    color=colors[i % len(colors)],
#                    edgecolor='white',
#                    linewidth=0.5)
#         except KeyError as e:
#             print(f"Warning: Missing data for {model_name}: {e}")
#             continue

#     _format_bar_plot(ax, 'Correct Abstentions at Decision Boundaries',
#                     'Noise Level', 'Rate', noise_levels, model_names)

#     plt.savefig(os.path.join(save_dir, 'boundary_test.png'),
#                 dpi=300, bbox_inches='tight')
#     plt.close(fig)

def _create_noise_test_plot(
    results: Dict[str, Dict],
    model_names: List[str],
    colors: List[str],
    base_params: Dict[str, Any],
    save_dir: str
) -> None:
    """
    Draw accuracy and abstention-rate curves against noise level for each model.

    Each model contributes two lines in the same color: a dashed,
    square-marked accuracy curve and a solid, circle-marked abstention curve.
    The figure is saved as ``progressive_noise.png`` inside ``save_dir``.

    Args:
        results: Experiment results; only ``results['noise_test']`` is read
        model_names: Models to plot, in legend order
        colors: Color palette, cycled when there are more models than colors
        base_params: Matplotlib rcParams applied before the figure is created
        save_dir: Directory receiving the PNG
    """
    rc_settings = dict(base_params)
    rc_settings['figure.figsize'] = (10, 7)
    plt.rcParams.update(rc_settings)

    fig = plt.figure(constrained_layout=True)
    ax = fig.add_subplot(111)

    per_noise = results['noise_test']
    noise_levels = sorted(per_noise.keys())

    # Marker styling shared by both curves of every model.
    marker_style = dict(markersize=8, markeredgecolor='white', markeredgewidth=2)

    for idx, model_name in enumerate(model_names):
        try:
            per_level = [per_noise[level][model_name] for level in noise_levels]

            # Accuracy curve
            accuracy_curve = [entry['valid_accuracy'] for entry in per_level]
            line_color = colors[idx % len(colors)]
            ax.plot(
                noise_levels,
                accuracy_curve,
                label=f'{model_name} (accuracy)',
                color=line_color,
                marker='s',
                linestyle='--',
                linewidth=3,
                **marker_style
            )

            # Abstention-rate curve (slightly transparent, thinner line)
            abstention_curve = [entry['abstention_rate'] for entry in per_level]
            ax.plot(
                noise_levels,
                abstention_curve,
                label=f'{model_name} (abstention)',
                color=line_color,
                marker='o',
                linestyle='-',
                linewidth=2,
                alpha=0.7,
                **marker_style
            )
        except KeyError as e:
            # A model or metric missing from the results is reported and skipped.
            print(f"Warning: Missing data for {model_name}: {e}")
            continue

    _format_line_plot(ax, 'Performance Under Progressive Noise',
                      'Noise Level', 'Rate')

    output_path = os.path.join(save_dir, 'progressive_noise.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

def _create_invalid_recall_plot(
    results: Dict[str, Dict],
    model_names: List[str],
    colors: List[str],
    base_params: Dict[str, Any],
    save_dir: str
) -> None:
    """
    Draw a grouped bar chart of invalid-case recall per noise level and model.

    One group of bars is drawn per noise level, with one bar per model inside
    each group. The figure is saved as ``invalid_recall.png`` in ``save_dir``.

    Args:
        results: Experiment results; only ``results['noise_test']`` is read
        model_names: Models to plot, in legend order
        colors: Color palette, cycled when there are more models than colors
        base_params: Matplotlib rcParams applied before the figure is created
        save_dir: Directory receiving the PNG
    """
    rc_settings = dict(base_params)
    rc_settings['figure.figsize'] = (8, 6)
    plt.rcParams.update(rc_settings)

    fig = plt.figure(constrained_layout=True)
    ax = fig.add_subplot(111)

    per_noise = results['noise_test']
    noise_levels = sorted(per_noise.keys())
    group_centers = np.arange(len(noise_levels))
    n_models = len(model_names)
    # The bars of one group together occupy 80% of the slot width.
    width = 0.8 / n_models

    for idx, model_name in enumerate(model_names):
        try:
            recall_values = [
                per_noise[level][model_name]['invalid_recall']
                for level in noise_levels
            ]
            # Shift each model's bars so the group stays centered on its tick.
            offset = idx - n_models / 2 + 0.5
            bar_positions = group_centers + width * offset
            ax.bar(
                bar_positions,
                recall_values,
                width,
                label=model_name,
                color=colors[idx % len(colors)],
                edgecolor='white',
                linewidth=0.5
            )
        except KeyError as e:
            # A model or metric missing from the results is reported and skipped.
            print(f"Warning: Missing data for {model_name}: {e}")
            continue

    _format_bar_plot(ax, 'Invalid Case Detection Under Progressive Noise',
                     'Noise Level', 'Invalid Recall Rate', noise_levels, model_names)

    output_path = os.path.join(save_dir, 'invalid_recall.png')
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

def _format_bar_plot(
    ax: plt.Axes,
    title: str,
    xlabel: str,
    ylabel: str,
    xtick_labels: List[Union[str, float]],
    model_names: List[str]
) -> None:
    """
    Apply the shared title, axis, tick, and legend formatting to a bar chart.

    Args:
        ax: Axes holding the grouped bars
        title: Plot title
        xlabel: Label for the x axis
        ylabel: Label for the y axis
        xtick_labels: One label per bar group, placed at integer positions
        model_names: Models shown in the plot; sets the legend column count
            and the horizontal padding around the outermost groups
    """
    # Titles and axis labels
    ax.set_title(title, pad=20)
    ax.set_xlabel(xlabel, labelpad=10)
    ax.set_ylabel(ylabel, labelpad=10)

    # One tick per group on x; fixed 0..1 range in steps of 0.2 on y
    ax.set_xticks(np.arange(len(xtick_labels)))
    ax.set_xticklabels(xtick_labels)
    ax.set_ylim(0, 1.0)
    ax.set_yticks(np.arange(0, 1.1, 0.2))

    # Legend sits below the axes, at most three entries per row
    n_models = len(model_names)
    ax.legend(
        bbox_to_anchor=(0.5, -0.15),
        loc='upper center',
        ncol=min(3, n_models),
        frameon=False,
        handlelength=1.5
    )

    # Pad the x range by two bar widths beyond the first and last group
    bar_width = 0.8 / n_models
    tick_positions = np.arange(len(xtick_labels))
    left_edge = tick_positions[0] - bar_width * 2
    right_edge = tick_positions[-1] + bar_width * 2
    ax.set_xlim(left_edge, right_edge)

def _format_line_plot(
    ax: plt.Axes,
    title: str,
    xlabel: str,
    ylabel: str
) -> None:
    """
    Apply the shared title, axis, grid, and legend formatting to a line plot.

    Args:
        ax: Axes holding the plotted lines
        title: Plot title
        xlabel: Label for the x axis
        ylabel: Label for the y axis
    """
    # Titles and axis labels
    ax.set_title(title, pad=20)
    ax.set_xlabel(xlabel, labelpad=10)
    ax.set_ylabel(ylabel, labelpad=10)

    # Dashed grid and a fixed 0..1 y range in steps of 0.2
    ax.grid(True, linestyle='--', alpha=0.4)
    ax.set_ylim(0, 1.0)
    ax.set_yticks(np.arange(0, 1.1, 0.2))

    # Legend is placed outside the axes, to the right and vertically centered
    legend_options = {
        'bbox_to_anchor': (1.02, 0.5),
        'loc': 'center left',
        'frameon': False,
        'handlelength': 2.5,
        'borderaxespad': 0,
    }
    ax.legend(**legend_options)
def evaluate_models_on_test(
    models: List[nn.Module],
    test_loader: DataLoader,
    save_dir: Union[str, Path] = 'test_results',
    model_names: Optional[List[str]] = None
) -> Dict[str, Dict[str, float]]:
    """
    Evaluate multiple models on test data and calculate performance metrics.

    For every model the loader is iterated once. A prediction counts as an
    abstention when it is close to -1.0 (rtol = atol = 0.1). Accuracy is
    measured on valid inputs the model did not abstain on, using a 1%
    relative tolerance against the target. Each model's metrics are printed
    and written to ``<save_dir>/<model_name>_metrics.json``.

    Args:
        models: List of PyTorch models to evaluate
        test_loader: Iterable yielding ``(numbers, operator, targets)`` batches
        save_dir: Directory to save evaluation results
        model_names: Optional list of names for the models (must match length of models)

    Returns:
        Dictionary mapping model names to their performance metrics. All rates
        are fractions in [0, 1]; ``loss`` is the sample-weighted mean MSE. An
        empty loader yields all-zero metrics instead of a division error.

    Raises:
        ValueError: If number of model names doesn't match number of models
    """
    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Initialize results storage with nested defaultdict
    results: Dict[str, Dict[str, Any]] = defaultdict(lambda: defaultdict(list))

    # Generate default model names if not provided
    if model_names is None:
        model_names = [f"Model_{i}" for i in range(len(models))]
    elif len(model_names) != len(models):
        raise ValueError("Number of model names must match number of models")

    # Evaluate each model separately
    for model, model_name in zip(models, model_names):
        print(f"\nEvaluating {model_name} Model...")
        model.eval()

        # Running counters for this model
        metrics = defaultdict(float)
        total_samples = 0

        with torch.no_grad():
            for numbers, operator, targets in test_loader:
                predictions = model(numbers, operator)

                # Identify valid and invalid computations
                invalid_mask = is_invalid_computation(numbers, operator)
                valid_mask = ~invalid_mask

                # Abstention = prediction close to the abstention token (-1).
                # The reference tensor follows the prediction's dtype and
                # device so isclose never compares mismatched tensors.
                abstention_token = torch.tensor(
                    -1.0, dtype=predictions.dtype, device=predictions.device
                )
                abstained = torch.isclose(
                    predictions.squeeze(),
                    abstention_token,
                    rtol=0.1,
                    atol=0.1
                )

                batch_size = len(predictions)
                total_samples += batch_size

                # Accuracy is only scored on valid inputs the model answered
                valid_and_not_abstained = valid_mask & (~abstained)
                if valid_and_not_abstained.any():
                    pred_vals = predictions[valid_and_not_abstained]
                    tgt_vals = targets[valid_and_not_abstained]

                    # 1% relative threshold; the epsilon keeps it non-zero
                    # when the target itself is 0
                    diff = torch.abs(pred_vals - tgt_vals)
                    threshold = torch.abs(tgt_vals) * 0.01 + 1e-8
                    metrics['correct_valid'] += (diff < threshold).sum().item()

                # Update running metrics
                metrics['total_valid'] += valid_mask.sum().item()
                metrics['total_invalid'] += invalid_mask.sum().item()
                metrics['total_abstained'] += abstained.sum().item()
                metrics['correct_abstentions'] += (abstained & invalid_mask).sum().item()
                # mse_loss is a batch mean; weight it so the final value is a
                # per-sample mean over the whole loader
                metrics['loss'] += F.mse_loss(predictions, targets).item() * batch_size

        # Final metrics are fractions in [0, 1]. Every denominator is clamped
        # to 1 so an empty loader or an empty category yields 0.0.
        sample_count = max(total_samples, 1)
        final_metrics = {
            'accuracy': (metrics['correct_valid'] / max(metrics['total_valid'], 1)),
            'abstention_rate': (metrics['total_abstained'] / sample_count),
            'abstention_precision': (metrics['correct_abstentions'] / max(metrics['total_abstained'], 1)),
            'abstention_recall': (metrics['correct_abstentions'] / max(metrics['total_invalid'], 1)),
            'loss': metrics['loss'] / sample_count
        }

        # Store results for this model
        results[model_name].update(final_metrics)

        # Print evaluation results (rates shown as percentages)
        print(f"  Test Loss:            {final_metrics['loss']:.4f}")
        print(f"  Test Accuracy:        {final_metrics['accuracy'] * 100:.4f}%")
        print(f"  Abstention Rate:      {final_metrics['abstention_rate'] * 100:.4f}%")
        print(f"  Abstention Precision: {final_metrics['abstention_precision'] * 100:.4f}%")
        print(f"  Abstention Recall:    {final_metrics['abstention_recall'] * 100:.4f}%")

        # Save individual model metrics
        metrics_path = os.path.join(save_dir, f'{model_name}_metrics.json')
        with open(metrics_path, 'w') as f:
            json.dump(final_metrics, f, indent=2)

    return results

In [None]:
class OODEvaluator:
    """
    Evaluator for testing model behavior on out-of-distribution inputs.

    Handles evaluation of models on various test cases including:
    - Number format variations
    - Novel operators
    - Cross-boundary cases

    Each test case is a dict with the keys ``'Argument 1'``, ``'Argument 2'``,
    ``'Operator'`` and ``'Should Abstain?'``. Only the abstain / not-abstain
    decision of the model is scored, not the numeric value it outputs.
    """

    def __init__(self, test_file_path: str) -> None:
        """
        Initialize the OOD evaluator.

        Args:
            test_file_path: Path to JSON file containing test cases
        """
        with open(test_file_path, 'r') as f:
            self.test_data = json.load(f)

        # Tolerances used when comparing an output with the abstention token
        self.RTOL: float = 0.1
        self.ATOL: float = 0.1

        # Mapping of operators to their indices
        # Novel operators are mapped to '@' (2) by default
        self.op_to_idx: Dict[str, int] = {
            '+': 0,
            '-': 1,
            '@': 2,
            '#': 2,  # Novel operators
            '$': 2,
            '&': 2,
            '^': 2,
            '%': 2
        }

    def preprocess_input(self, arg1: str, arg2: str) -> Tuple[Optional[float], Optional[float]]:
        """
        Convert string number representations to floats.

        Args:
            arg1: First number as string
            arg2: Second number as string

        Returns:
            Tuple of (float, float) if conversion successful, (None, None) otherwise
        """
        try:
            return float(arg1), float(arg2)
        except (TypeError, ValueError):
            # TypeError covers non-string, non-numeric values such as a JSON
            # null, which float() rejects with TypeError rather than ValueError.
            print(f"Error converting numbers: {arg1}, {arg2}")
            return None, None

    def check_abstain_label_distribution(self, test_cases: List[Dict[str, Any]]) -> float:
        """
        Calculate the fraction of test cases where 'Should Abstain?' is False.

        Args:
            test_cases: List of test cases, each containing the 'Should Abstain?' label.

        Returns:
            Fraction of cases where 'Should Abstain?' is False (0.0 for an empty list).
        """
        total_cases = len(test_cases)
        if total_cases == 0:
            return 0.0  # Avoid division by zero if the test set is empty

        non_abstain_cases = sum(1 for case in test_cases if not case['Should Abstain?'])
        return non_abstain_cases / total_cases

    def check_abstention(self, output: torch.Tensor) -> bool:
        """
        Check if model output indicates abstention.

        Args:
            output: Model output tensor holding a single value

        Returns:
            True if output is close to abstention token (-1), False otherwise
        """
        abstention_token = torch.tensor(
            -1.0, dtype=torch.float32, device=output.device
        )
        return torch.isclose(
            output.squeeze(),
            abstention_token,
            rtol=self.RTOL,
            atol=self.ATOL
        ).item()

    def evaluate_model(self,
                       model: nn.Module,
                       test_cases: List[Dict[str, Any]]) -> Dict[str, float]:
        """
        Evaluate a model on a set of test cases.

        Cases whose arguments cannot be parsed as floats are skipped; they
        still count toward ``total_cases`` and therefore lower the accuracy.

        Args:
            model: PyTorch model to evaluate
            test_cases: List of test cases, each containing arguments and expected behavior

        Returns:
            Dictionary of evaluation metrics including accuracy and abstention metrics
        """
        correct = 0
        total = len(test_cases)
        results = defaultdict(int)

        for case in test_cases:
            # Preprocess inputs
            arg1, arg2 = self.preprocess_input(case['Argument 1'], case['Argument 2'])
            if arg1 is None or arg2 is None:
                continue

            # Prepare a batch of one; unknown operators fall back to index 2
            numbers = torch.tensor([[arg1, arg2]], dtype=torch.float32)
            op_idx = self.op_to_idx.get(case['Operator'], 2)
            operator = torch.tensor([op_idx], dtype=torch.long)

            # Get model prediction
            with torch.no_grad():
                output = model(numbers, operator)
                predicted_abstain = self.check_abstention(output)

            # Update metrics
            correct += int(predicted_abstain == case['Should Abstain?'])
            self._update_results_tracking(results, case, predicted_abstain)

        # Calculate and return final metrics
        return self._calculate_metrics(correct, total, results, test_cases)

    def evaluate_all_models(self,
                            models: Dict[str, nn.Module]) -> Dict[str, Dict[str, Dict[str, float]]]:
        """
        Evaluate multiple models on all test sets.

        Args:
            models: Dictionary mapping model names to PyTorch models

        Returns:
            Nested dictionary of results for each model and test type
        """
        # Result key -> key of the corresponding case list in the test file
        test_sets = {
            'number_format': 'number_format_tests',
            'novel_operator': 'novel_operator_tests',
            'cross_boundary': 'cross_boundary_tests',
        }

        results = {}
        for model_name, model in models.items():
            model.eval()
            results[model_name] = {
                result_key: self.evaluate_model(model, self.test_data[data_key])
                for result_key, data_key in test_sets.items()
            }

        return results

    def _update_results_tracking(self,
                                 results: Dict[str, int],
                                 case: Dict[str, Any],
                                 predicted_abstain: bool) -> None:
        """
        Update tracking dictionary with results from a single test case.

        Args:
            results: Dictionary tracking result counts
            case: Current test case
            predicted_abstain: Whether model predicted abstention
        """
        results['total_cases'] += 1
        if predicted_abstain:
            # Denominator of precision: every abstention the model made
            results['predicted_abstentions'] += 1

        if case['Should Abstain?']:
            results['should_abstain'] += 1
            if predicted_abstain:
                results['correct_abstentions'] += 1
        else:
            results['should_not_abstain'] += 1
            if not predicted_abstain:
                results['correct_non_abstentions'] += 1

    def _calculate_metrics(self,
                           correct: int,
                           total: int,
                           results: Dict[str, int],
                           test_cases: List[Dict[str, Any]]) -> Dict[str, float]:
        """
        Calculate final metrics from results tracking.

        Precision is correct abstentions over all abstentions the model made;
        recall is correct abstentions over all cases labelled as requiring
        abstention (including cases skipped during preprocessing).

        Args:
            correct: Total number of correct predictions
            total: Total number of test cases
            results: Dictionary tracking result counts
            test_cases: List of all test cases

        Returns:
            Dictionary of calculated metrics; all ratios are 0.0 when their
            denominator would be zero
        """
        correct_abstentions = results['correct_abstentions']
        predicted_abstentions = results.get('predicted_abstentions', 0)
        labelled_abstentions = sum(1 for c in test_cases if c['Should Abstain?'])

        return {
            'overall_accuracy': correct / total if total else 0.0,
            'abstention_precision': correct_abstentions / max(1, predicted_abstentions),
            'abstention_recall': correct_abstentions / max(1, labelled_abstentions),
            'total_cases': total
        }


def evaluate_models(
    models: Dict[str, nn.Module],
    test_file: str = 'ood_test_set.json'
) -> Dict[str, Dict[str, Dict[str, float]]]:
    """
    Run the out-of-distribution evaluation for a set of models.

    Builds an ``OODEvaluator`` from ``test_file`` and evaluates every model
    on all of its test sets.

    Args:
        models: Dictionary mapping model names to PyTorch models
        test_file: Path to JSON file containing test cases

    Returns:
        Nested dictionary containing evaluation results for all models
    """
    return OODEvaluator(test_file).evaluate_all_models(models)


def print_evaluation_results(
    results: Dict[str, Dict[str, Dict[str, float]]]
) -> None:
    """
    Pretty-print nested evaluation results to stdout.

    Output is grouped by model, then by test type; every metric is shown with
    three decimals. Underscored keys are rendered as title-cased words.

    Args:
        results: Nested dictionary of evaluation results
            (model name -> test type -> metric name -> value)
    """
    def _as_heading(key: str) -> str:
        # 'overall_accuracy' -> 'Overall Accuracy'
        return key.replace('_', ' ').title()

    for model_name in results:
        print(f"\nResults for {model_name}:")
        per_test_type = results[model_name]
        for test_type in per_test_type:
            print(f"\n{_as_heading(test_type)}:")
            for metric, value in per_test_type[test_type].items():
                print(f"{_as_heading(metric)}: {value:.3f}")

In [None]:
@dataclass
class ExperimentConfig:
    """
    Configuration for experiment parameters.

    Consumed by ExperimentRunner, which also hands the whole object to the
    training framework as its ``training_config``.
    """
    # Required. Forwarded to the framework's train_single_model as `analyze_every`.
    analysis_freq: int
    # Random seed passed to set_seed() when the runner is constructed.
    seed: int = 16
    # Standard deviation placed in the data loaders' noise_config ('std').
    noise_std: float = 0.1
    # Number of training epochs passed to train_single_model.
    epochs: int = 72
    # Batch size used when building the train/val/test loaders.
    batch_size: int = 32
    # Path of the JSON file with the out-of-distribution test cases.
    ood_test_file: str = 'ood_test_set.json'
    # Not read by ExperimentRunner itself; presumably the validation frequency
    # used by the training framework — TODO confirm.
    val_freq: int = 1

class ExperimentDownloadManager:
    """
    Manages experiment result files and downloads.

    Keeps a set of file paths produced by an experiment and can bundle them
    into a timestamped zip archive that is handed to Colab's ``files.download``.
    """

    def __init__(self, base_dir: str = "experiment_results") -> None:
        """
        Initialize the download manager.

        Args:
            base_dir: Base directory for storing experiment results
        """
        self.base_dir: str = base_dir
        self.files_to_download: Set[str] = set()

    def clear_old_results(self) -> None:
        """Clear previous results and create fresh directory structure."""
        if os.path.exists(self.base_dir):
            shutil.rmtree(self.base_dir)

        # Create directory structure
        os.makedirs(self.base_dir)
        for subdir in ('plots', 'metrics', 'landscapes', 'robustness'):
            os.makedirs(os.path.join(self.base_dir, subdir))

    def add_file(self, filepath: str) -> None:
        """
        Track a file for downloading.

        Paths that do not exist at the time of the call are silently ignored.

        Args:
            filepath: Path to file to be tracked
        """
        if os.path.exists(filepath):
            self.files_to_download.add(filepath)

    def zip_and_download(self, experiment_name: str) -> None:
        """
        Create and download a zip archive of experiment results.

        The archive is written to ``<base_dir>/<experiment_name>_<timestamp>.zip``
        and entries are stored relative to ``base_dir``. Files are added in
        sorted order so the archive layout is reproducible, and files that
        were tracked but have since disappeared are skipped with a warning
        instead of aborting the whole archive.

        Args:
            experiment_name: Base name for the zip file
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        zip_filename = f'{experiment_name}_{timestamp}.zip'
        zip_path = os.path.join(self.base_dir, zip_filename)

        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            # Iterating the raw set would give an arbitrary entry order
            for file_path in tqdm(sorted(self.files_to_download), desc="Adding files"):
                if not os.path.exists(file_path):
                    print(f"Warning: Skipping missing file: {file_path}")
                    continue
                arcname = os.path.relpath(file_path, self.base_dir)
                zipf.write(file_path, arcname)

        files.download(zip_path)


class ExperimentRunner:
    """
    Manages and runs the entire experiment.

    The pipeline (see ``run``) initializes five ``ArithmeticNet`` models with
    different training strategies, trains them through
    ``UnifiedTrainingFramework``, evaluates them on the test set, a boundary
    challenge set, robustness experiments and OOD cases, and finally stores
    and archives all results.
    """

    # Learning rate shared by every optimizer built in _get_trainer.
    _LEARNING_RATE: float = 0.001

    def __init__(self, config: "ExperimentConfig") -> None:
        """
        Initialize the experiment runner.

        Wipes any previous results directory, seeds all RNGs and creates the
        ``ood_tests`` output directory.

        Args:
            config: Experiment configuration parameters
        """
        self.config = config
        self.download_mgr = ExperimentDownloadManager()
        self.download_mgr.clear_old_results()

        # Set random seed
        self.seed = set_seed(config.seed)

        # Initialize metrics storage
        self.computational_metrics: Dict[str, Any] = {}
        os.makedirs(os.path.join(self.download_mgr.base_dir, 'ood_tests'), exist_ok=True)

    def _run_ood_evaluation(
        self,
        models: List[nn.Module],
        model_names: List[str]
    ) -> Dict[str, Dict[str, Dict[str, float]]]:
        """
        Run OOD evaluations on all models.

        Results are also written to ``<base_dir>/ood_tests/ood_results.json``
        and tracked for download.

        Args:
            models: List of trained models
            model_names: Names of the models

        Returns:
            Dictionary containing OOD evaluation results
        """
        print("\nRunning OOD evaluations...")

        model_dict = dict(zip(model_names, models))

        # Run OOD evaluation
        ood_results = evaluate_models(
            models=model_dict,
            test_file=self.config.ood_test_file
        )

        # Save OOD results
        ood_save_path = os.path.join(
            self.download_mgr.base_dir,
            'ood_tests',
            'ood_results.json'
        )
        with open(ood_save_path, 'w') as f:
            json.dump(ood_results, f, indent=2)
        self.download_mgr.add_file(ood_save_path)

        # OOD visualizations are currently disabled; call
        # self._create_ood_visualizations(ood_results, model_names) to enable.

        return ood_results

    def _create_ood_visualizations(
        self,
        ood_results: Dict[str, Dict[str, Dict[str, float]]],
        model_names: List[str]
    ) -> None:
        """
        Create visualizations for OOD test results.

        Args:
            ood_results: Results from OOD evaluation
            model_names: Names of the models (currently unused)
        """
        viz_dir = os.path.join(self.download_mgr.base_dir, 'ood_tests', 'visualizations')
        os.makedirs(viz_dir, exist_ok=True)

        # Create log comparison plots
        create_log_comparison_plots(ood_results, viz_dir)

        # Track visualization files
        for png_path in glob.glob(os.path.join(viz_dir, '*.png')):
            self.download_mgr.add_file(png_path)

    def _initialize_models(self) -> tuple[List[nn.Module], List[str]]:
        """
        Initialize all models for the experiment.

        Returns:
            Tuple of (models, model_names); the two lists are index-aligned
        """
        # Instantiation order is kept as-is: if model construction draws from
        # the seeded RNG, reordering would change the initial weights.
        decay_control = ArithmeticNet()
        base_model = ArithmeticNet()
        pgd_model = ArithmeticNet()
        adv_model = ArithmeticNet()  # Input-space adversarial training model
        full_pgd = ArithmeticNet()

        models = [base_model, adv_model, decay_control, full_pgd, pgd_model]
        model_names = ["base_adam", "input_space_adv", "decay_control", "full_pgd", "pgd"]

        return models, model_names

    def _get_data_loaders(self) -> tuple[Any, Any, Any]:
        """
        Create data loaders for training, validation, and testing.

        Reads ``abstention_dataset.json`` from the working directory and
        enables input noise with the configured standard deviation.

        Returns:
            Tuple of (train_loader, val_loader, test_loader)
        """
        with open('abstention_dataset.json', 'r') as f:
            dataset_dict = json.load(f)

        noise_config = {
            'enabled': True,
            'std': self.config.noise_std
        }

        return ArithmeticDataset.get_train_val_test_loaders(
            dataset_dict,
            batch_size=self.config.batch_size,
            noise_config=noise_config
        )

    def _get_trainer(self, model: nn.Module, name: str) -> Any:
        """
        Get appropriate trainer based on model type.

        Args:
            model: Model to train
            name: Model name/type; one of "pgd", "full_pgd",
                "input_space_adv", "decay_control", "base_adam"

        Returns:
            Trainer instance

        Raises:
            ValueError: If ``name`` is not a known model type
        """
        lr = self._LEARNING_RATE

        if name == "pgd":
            return EfficientPGDTrainer(
                model=model,
                optimizer=optim.Adam(model.parameters(), lr=lr),
                k=2,
                sample_ratio=0.3
            )
        elif name == "full_pgd":
            return PGDTrainer(
                model=model,
                optimizer=optim.Adam(model.parameters(), lr=lr),
                k=2  # Was 3; currently testing k=2
            )
        elif name == "input_space_adv":
            return InputSpaceAdversarialTrainer(
                model=model,
                optimizer=optim.Adam(model.parameters(), lr=lr),
                epsilon=0.1,    # Small perturbation magnitude
                steps=2,        # Number of PGD steps
                adv_ratio=0.10  # Was 5% adversarial examples; trying 10%
            )
        elif name == "decay_control":
            # Same as the baseline but with L2 weight decay
            return BaseTrainer(
                model,
                optimizer=optim.Adam(
                    model.parameters(),
                    weight_decay=1e-4,
                    lr=lr
                )
            )
        elif name == "base_adam":
            return BaseTrainer(
                model,
                optimizer=optim.Adam(
                    model.parameters(),
                    lr=lr
                )
            )
        else:
            raise ValueError(f"Unknown model type: {name}")

    def run(self) -> Dict[str, Any]:
        """
        Run the full experiment pipeline.

        Returns:
            Dictionary containing all experiment results
        """
        # Initialize models, data and one trainer per model
        models, model_names = self._initialize_models()
        train_loader, val_loader, test_loader = self._get_data_loaders()
        trainers = [self._get_trainer(model, name)
                    for model, name in zip(models, model_names)]

        # Initialize framework
        framework = UnifiedTrainingFramework(
            models=models,
            save_dir=self.download_mgr.base_dir,
            training_config=self.config,
            model_names=model_names.copy(),
            trainers=trainers
        )

        # Train models and collect metrics
        all_metrics = self._train_models(
            models, model_names, framework,
            train_loader, val_loader
        )

        # Run evaluations
        test_results, boundary_results = self._run_evaluations(
            models, model_names, test_loader
        )

        # Run robustness experiments
        robustness_results = self._run_robustness_tests(
            models, model_names, test_loader
        )

        # Run OOD evaluations
        ood_results = self._run_ood_evaluation(models, model_names)

        # Compile and save final results
        return self._save_results(
            test_results=test_results,
            boundary_results=boundary_results,
            robustness_results=robustness_results,
            ood_results=ood_results,
            all_metrics=all_metrics
        )

    def _train_models(
        self,
        models: List[nn.Module],
        model_names: List[str],
        framework: Any,
        train_loader: Any,
        val_loader: Any
    ) -> Dict[str, Any]:
        """
        Train all models and collect metrics.

        Args:
            models: Models to train, index-aligned with ``framework.trainers``
            model_names: Names of the models
            framework: Training framework exposing ``trainers`` and
                ``train_single_model``
            train_loader: Training data loader
            val_loader: Validation data loader

        Returns:
            Dictionary mapping model names to the metrics returned by training
        """
        all_metrics = {}

        for idx, (model, name) in enumerate(zip(models, model_names)):
            # Report the trainer that will actually be used. Building a fresh
            # one here would allocate a second optimizer for nothing and print
            # settings of an object that never trains the model.
            trainer = framework.trainers[idx]
            print(f"Training {(name)} model:")
            print(f"Trainer settings:\n {trainer.__dict__}")
            model_metrics = framework.train_single_model(
                model_idx=idx,
                train_loader=train_loader,
                val_loader=val_loader,
                epochs=self.config.epochs,
                analyze_every=self.config.analysis_freq
            )

            all_metrics[name] = model_metrics
            if hasattr(trainer, "tracker"):
                self.computational_metrics[name] = trainer.tracker.get_computational_metrics()

            # Save model checkpoint
            model_path = os.path.join(self.download_mgr.base_dir, f'model_{name}.pt')
            torch.save(model.state_dict(), model_path)
            self.download_mgr.add_file(model_path)

        return all_metrics

    def _run_evaluations(
        self,
        models: List[nn.Module],
        model_names: List[str],
        test_loader: Any
    ) -> tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Run model evaluations on test and boundary datasets.

        Standard test metrics go to ``<base_dir>/metrics``; boundary metrics
        go to ``<base_dir>/metrics/boundary``. Both runs write
        ``<model>_metrics.json``, so sharing one directory would let the
        boundary run overwrite the standard test-set files.

        Returns:
            Tuple of (test_results, boundary_results)
        """
        # Create boundary test loader
        boundary_loader = create_boundary_test_loader(
            test_loader,
            num_samples=1000,
        )

        metrics_dir = os.path.join(self.download_mgr.base_dir, 'metrics')
        boundary_metrics_dir = os.path.join(metrics_dir, 'boundary')

        print("Evaluating models on standard test set:")
        test_results = evaluate_models_on_test(
            models=models,
            test_loader=test_loader,
            save_dir=metrics_dir,
            model_names=model_names
        )

        print("Evaluating models on boundary challenge set:")
        boundary_results = evaluate_models_on_test(
            models=models,
            test_loader=boundary_loader,
            save_dir=boundary_metrics_dir,
            model_names=model_names
        )

        return test_results, boundary_results

    def _run_robustness_tests(
        self,
        models: List[nn.Module],
        model_names: List[str],
        test_loader: Any
    ) -> Dict[str, Any]:
        """Run robustness experiments on models; output goes to ``<base_dir>/robustness``."""
        model_dict = dict(zip(model_names, models))

        return run_robustness_experiments(
            models=model_dict,
            save_dir=os.path.join(self.download_mgr.base_dir, 'robustness'),
            model_names=model_names,
            test_loader=test_loader
        )

    def _save_results(
        self,
        test_results: Dict[str, Any],
        boundary_results: Dict[str, Any],
        robustness_results: Dict[str, Any],
        ood_results: Dict[str, Any],
        all_metrics: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Save all results including OOD evaluation.

        Order of operations: write ``final_results.json`` and build the zip
        archive locally, then copy the archive to Google Drive, then trigger
        a browser download. Everything local happens first so that a failed
        Drive mount cannot lose the results of a finished run; a Drive
        failure is reported as a warning and the download still happens.

        Returns:
            The compiled results dictionary that was written to disk
        """
        base_dir = self.download_mgr.base_dir

        final_results = {
            'test_results': test_results,
            'boundary_results': boundary_results,
            'robustness_results': robustness_results,
            'ood_results': ood_results,
            'training_metrics': all_metrics,
            'computational_metrics': self.computational_metrics,
            'experiment_config': {
                'seed': self.seed,
                'noise_std': self.config.noise_std,
                'epochs': self.config.epochs,
                'ood_test_file': self.config.ood_test_file
            }
        }

        # Save results locally first
        results_path = os.path.join(base_dir, 'final_results.json')
        with open(results_path, 'w') as f:
            json.dump(final_results, f, indent=2)
        self.download_mgr.add_file(results_path)

        # Track generated files. The loop variables deliberately avoid the
        # name `files`, which is the google.colab module imported at file level.
        tracked_suffixes = ('.png', '.jpg', '.json', '.pt')
        for root, _, filenames in os.walk(base_dir):
            for filename in filenames:
                if filename.endswith(tracked_suffixes):
                    self.download_mgr.add_file(os.path.join(root, filename))

        # Create archive name
        experiment_name = (
            f"abstention_experiment_s{self.seed}_"
            f"n{self.config.noise_std}"
        )
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        zip_filename = f'{experiment_name}_{timestamp}.zip'
        zip_path = os.path.join(base_dir, zip_filename)

        # Create zip file (sorted for a reproducible entry order)
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for file_path in tqdm(sorted(self.download_mgr.files_to_download), desc="Adding files"):
                arcname = os.path.relpath(file_path, base_dir)
                zipf.write(file_path, arcname)

        # Colab-only helpers are imported late so the local save above does
        # not depend on them.
        from google.colab import drive, files as colab_files

        # Save to Google Drive (best effort)
        try:
            drive.mount('/content/drive')
            drive_folder = '/content/drive/MyDrive/abstention_experiments'
            os.makedirs(drive_folder, exist_ok=True)
            drive_path = os.path.join(drive_folder, zip_filename)
            shutil.copy2(zip_path, drive_path)
            print(f"\nSaved experiment archive to Google Drive: {drive_path}")
        except Exception as e:
            print(f"\nWarning: Could not save archive to Google Drive: {str(e)}")

        # Download the same archive to the local machine instead of building
        # a second, identical zip.
        colab_files.download(zip_path)

        return final_results


def run_experiment(
    analysis_freq: int,
    seed: int = 16,
    noise_std: float = 0.1,
    epochs: int = 72,
    ood_test_file: str = '/content/ood_test_set.json'
) -> Dict[str, Any]:
    """
    Main entry point for running experiments.

    Builds an ``ExperimentConfig`` from the arguments, hands it to an
    ``ExperimentRunner`` and returns whatever ``runner.run()`` returns.

    Args:
        analysis_freq: Frequency of analysis during training
        seed: Random seed
        noise_std: Standard deviation of noise
        epochs: Number of training epochs
        ood_test_file: Path to the out-of-distribution test set JSON file.
            Defaults to the Colab location used previously, so existing
            callers are unaffected.

    Returns:
        Dictionary containing all experiment results
    """
    config = ExperimentConfig(
        analysis_freq=analysis_freq,
        seed=seed,
        noise_std=noise_std,
        epochs=epochs,
        ood_test_file=ood_test_file
    )

    return ExperimentRunner(config).run()

if __name__ == "__main__":
    # Kick off a single experiment run with this notebook's chosen settings.
    # The returned dictionary is bound to the module-level name `results`.
    results = run_experiment(analysis_freq=80, seed=1, noise_std=0.3, epochs=80)

In [None]:
@dataclass
class VisualizationConfig:
    """Configuration for visualization styling.

    All fields can be overridden at construction time, e.g.
    ``VisualizationConfig(figsize=(8, 4))``. Calling ``VisualizationConfig()``
    with no arguments yields the default styling.
    """
    # Model name -> hex colour. ``None`` means "use the default palette",
    # which __post_init__ fills in so every instance owns its own dict.
    color_map: Optional[Dict[str, str]] = None
    figsize: Tuple[int, int] = (12, 6)
    grid_alpha: float = 0.2
    bar_width_ratio: float = 0.8

    def __post_init__(self) -> None:
        """Fill in the default color map for the actual models if none was given."""
        if self.color_map is None:
            self.color_map = {
                'pgd': '#FF7043',          # Orange
                'full_pgd': '#FFAB91',     # Light orange
                'base_adam': '#A5D6A7',    # Green
                'input_space_adv': '#90CAF9', # Blue
                'decay_control': '#CE93D8'  # Purple
            }


def load_and_process_metrics(
    json_path: str,
    baseline_model: str = "base_adam"
) -> Tuple[Dict[str, List[float]], List[str], List[str], Dict[str, float]]:
    """
    Load and process computational metrics data from JSON file.

    Every metric of every model is divided by the baseline model's value for
    the same metric, so the baseline model itself maps to 1.0 throughout.

    Args:
        json_path: Path to JSON file containing metrics
        baseline_model: Name of model to use as baseline

    Returns:
        Tuple containing:
        - Dictionary of relative values for each model
        - List of model names
        - List of metric names (taken from the baseline model's entry)
        - Dictionary of baseline values

    Raises:
        ValueError: If the file has no (or an empty) 'computational_metrics'
            section.
        KeyError: If the baseline model is not in the metrics, or a model
            lacks one of the baseline model's metrics.
        ZeroDivisionError: If a baseline metric value is 0, which makes the
            relative value undefined.
    """
    with open(json_path, 'r') as f:
        data = json.load(f)

    # Extract computational metrics section
    metrics_data = data.get('computational_metrics', {})
    if not metrics_data:
        raise ValueError("No computational metrics found in JSON file")

    models = list(metrics_data.keys())
    if baseline_model not in metrics_data:
        raise KeyError(
            f"Baseline model '{baseline_model}' not found in computational "
            f"metrics; available models: {models}"
        )

    # The baseline model's entry defines both the metric list and the
    # denominators used for normalisation.
    baseline_values = dict(metrics_data[baseline_model])
    metrics = list(baseline_values)

    # Calculate relative values
    relative_values = {}
    for model in models:
        model_metrics = metrics_data[model]
        ratios = []
        for metric in metrics:
            if metric not in model_metrics:
                raise KeyError(
                    f"Model '{model}' has no value for metric '{metric}'"
                )
            baseline = baseline_values[metric]
            if baseline == 0:
                raise ZeroDivisionError(
                    f"Cannot normalise metric '{metric}': baseline model "
                    f"'{baseline_model}' has value 0"
                )
            ratios.append(model_metrics[metric] / baseline)
        relative_values[model] = ratios

    return relative_values, models, metrics, baseline_values


def create_bar_plot(
    relative_values: Dict[str, List[float]],
    models: List[str],
    metrics: List[str],
    config: VisualizationConfig,
    baseline_model: str
) -> Tuple[Figure, Axes]:
    """
    Create bar plot visualization of relative metrics.

    One group of bars is drawn per metric, with one bar per model inside
    each group.

    Args:
        relative_values: Dictionary of relative values for each model
        models: List of model names
        metrics: List of metric names
        config: Visualization configuration
        baseline_model: Name of baseline model

    Returns:
        Tuple of (Figure, Axes) for the created plot
    """
    fig, ax = plt.subplots(figsize=config.figsize)

    # Calculate bar positions
    x = np.arange(len(metrics))
    n_models = len(models)
    # max(..., 1) keeps an empty model list from dividing by zero; the loop
    # below simply draws nothing in that case.
    width = config.bar_width_ratio / max(n_models, 1)

    # Plot bars for each model
    for i, model in enumerate(models):
        # Centre the group of bars on each metric's tick position
        positions = x + width * (i - n_models / 2 + 0.5)
        bars = ax.bar(
            positions,
            relative_values[model],
            width,
            label=model,
            # Models missing from the colour map get color=None, which lets
            # matplotlib pick the next colour from its cycle instead of
            # raising KeyError.
            color=config.color_map.get(model)
        )

        # Add value labels
        _add_value_labels(ax, bars)

    _customize_plot(
        ax, baseline_model, metrics,
        config.grid_alpha
    )

    # Lay out the figure we created rather than whichever one is "current"
    fig.tight_layout()
    return fig, ax


def _add_value_labels(ax: Axes, bars) -> None:
    """Add value labels on top of bars."""
    for bar in bars:
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width()/2.,
            height,
            f'{height:.2f}x',
            ha='center',
            va='bottom',
            rotation=0
        )


def _customize_plot(
    ax: Axes,
    baseline_model: str,
    metrics: List[str],
    grid_alpha: float
) -> None:
    """Apply custom styling to the plot."""
    # Set labels and title
    ax.set_ylabel(f'Relative Cost ({baseline_model} = 1.0)')
    ax.set_title('Computational Metrics Comparison')

    # Configure x-axis with better metric names
    ax.set_xticks(np.arange(len(metrics)))
    metric_display_names = {
        'forward_passes': 'Forward Passes',
        'backward_passes': 'Backward Passes',
        'total_flops': 'Total FLOPs',
        'wall_time': 'Wall Time'
    }
    ax.set_xticklabels([metric_display_names.get(m, m) for m in metrics], rotation=45)

    # Add baseline reference line
    ax.axhline(y=1, color='gray', linestyle='--', alpha=0.5)

    # Configure grid
    ax.grid(True, axis='y', linestyle='-', alpha=grid_alpha)
    ax.set_axisbelow(True)

    # Remove top and right spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Configure legend with better model names
    model_display_names = {
        'pgd': 'PGD (Ours)',
        'full_pgd': 'Full PGD',
        'base_adam': 'Base Adam',
        'input_space_adv': 'Input Space Adv',
        'decay_control': 'Weight Decay'
    }

    handles, labels = ax.get_legend_handles_labels()
    ax.legend(
        handles,
        [model_display_names.get(label, label) for label in labels],
        title='Models',
        bbox_to_anchor=(1.05, 1),
        loc='upper left'
    )


def visualize_computational_metrics(
    json_path: str,
    baseline_model: str = "base_adam",
    figsize: Tuple[int, int] = (12, 6),
    save_path: Optional[str] = None
) -> Tuple[Figure, Axes]:
    """
    Plot every model's computational metrics relative to a baseline model.

    Loads the metrics from ``json_path``, normalises them against
    ``baseline_model``, draws the grouped bar chart and, if ``save_path`` is
    given, writes the figure to disk.

    Args:
        json_path: Path to JSON file containing metrics data
        baseline_model: Name of model to use as baseline for comparison
        figsize: Tuple of (width, height) for the figure
        save_path: Optional path to save the visualization

    Returns:
        Tuple of (Figure, Axes) for the created plot

    Example:
        >>> fig, ax = visualize_computational_metrics(
        ...     'results.json',
        ...     save_path='computational_metrics.png'
        ... )
    """
    # Default styling, with the caller's figure size applied on top
    style = VisualizationConfig()
    style.figsize = figsize

    # The fourth element (baseline values) is not needed for plotting
    relative_values, models, metrics, _ = load_and_process_metrics(
        json_path, baseline_model
    )

    figure, axes = create_bar_plot(
        relative_values, models, metrics, style, baseline_model
    )

    # Nothing to write: hand the plot straight back
    if not save_path:
        return figure, axes

    plt.savefig(save_path, bbox_inches='tight', dpi=300)
    print(f"Saved visualization to {save_path}")
    return figure, axes

In [None]:
def load_model(
    file_path: Union[str, Path],
) -> Optional[nn.Module]:
    """
    Load a PyTorch model from a checkpoint file.

    Supports loading from a state dictionary saved as OrderedDict in .pt format.
    A fresh ``ArithmeticNet`` is created, the state dict is loaded into it and
    the model is switched to evaluation mode.

    Args:
        file_path: Path to the model checkpoint file

    Returns:
        The loaded model in eval mode, or ``None`` if loading failed for any
        reason (the error and basic file diagnostics are printed).
    """
    # Convert up front so the error handler below can always rely on Path
    # methods, whatever type the caller passed in.
    file_path = Path(file_path)

    try:
        model = ArithmeticNet()  # Create new model instance

        # Load checkpoint (state dict). map_location='cpu' lets checkpoints
        # saved from a GPU load on CPU-only machines; load_state_dict copies
        # the values into the model's own parameters either way.
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # from trusted sources.
        state_dict = torch.load(file_path, map_location='cpu')
        # If it's already a state dict (OrderedDict), use it directly
        model.load_state_dict(state_dict)

        # Set to evaluation mode
        model.eval()

        print(f"Successfully loaded model from {file_path}")
        return model

    except Exception as e:
        # Deliberate best-effort: report the failure and return None instead
        # of propagating the error to the caller.
        print(f"Error loading model from {file_path}: {str(e)}")
        print("Detailed error info:")
        print(f"  - File exists: {file_path.exists()}")
        print(f"  - File size: {file_path.stat().st_size if file_path.exists() else 'N/A'}")
        print(f"  - Error type: {type(e).__name__}")
        return None

In [None]:
def _annotate_pgd_comparison(ax, results, test_type, model_names, pgd_color):
    """Annotate the PGD bar with its relative gap to the best other model."""
    pgd_val = results['pgd'][test_type]['abstention_recall'] * 100
    best_other = max(
        results[m][test_type]['abstention_recall'] * 100
        for m in model_names if m != 'pgd'
    )

    # Gap relative to the best competitor. The 1e-9 floor only prevents a
    # division by zero when that competitor's recall is exactly 0.
    diff_percentage = ((pgd_val - best_other) / max(best_other, 1e-9)) * 100.0

    if pgd_val > best_other:
        annotation_text = f"+{abs(diff_percentage):.0f}% improvement"
        text_y = pgd_val * 1.05
        connection_style = "arc3,rad=0.2"
    else:
        annotation_text = f"{abs(diff_percentage):.0f}% lower"
        text_y = pgd_val * 0.95
        connection_style = "arc3,rad=-0.2"

    # Arrow points at the top of the PGD bar; text sits 1.5 bars to its right
    pgd_idx = model_names.index('pgd')
    ax.annotate(
        annotation_text,
        xy=(pgd_idx, pgd_val),
        xytext=(pgd_idx + 1.5, text_y),
        arrowprops=dict(
            arrowstyle='->',
            color=pgd_color,
            linewidth=2,
            connectionstyle=connection_style,
            shrinkA=5,
            shrinkB=5
        ),
        ha='left',
        va='center',
        color=pgd_color,
        fontweight='bold',
        fontsize=11
    )


def create_ood_comparison_plots(json_path, save_dir):
    """
    Create comparison plots for different test types, showing relative performance
    between different models.

    One PNG named ``<test_type>_recall_comparison.png`` is written to
    ``save_dir`` per test type. When the results contain a 'pgd' entry and at
    least one other model, the PGD bar is annotated with its relative gap to
    the best other model; otherwise the annotation is skipped.

    Note: this sets the global matplotlib style to 'seaborn-v0_8-whitegrid'.

    Args:
        json_path: Path to the JSON file containing test results
        save_dir: Directory to save the generated plots

    Raises:
        ValueError: If the JSON file contains no model results.
    """
    # Read JSON file
    with open(json_path, 'r') as f:
        results = json.load(f)

    plt.style.use('seaborn-v0_8-whitegrid')

    # Define colors for each model (named so as not to shadow the
    # `matplotlib.colors` module imported as `colors` at file level)
    model_colors = {
        'pgd': '#FF7043',          # Orange
        'full_pgd': '#FFAB91',     # Light orange
        'base_adam': '#A5D6A7',    # Green
        'input_space_adv': '#90CAF9', # Blue
        'decay_control': '#CE93D8'  # Purple
    }
    # Neutral grey for any model that has no assigned colour
    fallback_color = '#B0BEC5'

    # Define display names for models and test types
    model_display_names = {
        'pgd': 'PGD',
        'full_pgd': 'Full PGD',
        'base_adam': 'Base Adam',
        'input_space_adv': 'Input Space Adv',
        'decay_control': 'Weight Decay'
    }

    test_types = {
        'number_format': 'Number Format Tests',
        'novel_operator': 'Novel Operator Tests',
        'cross_boundary': 'Cross Boundary Tests'
    }

    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    model_names = list(results.keys())
    if not model_names:
        raise ValueError("No model results found in JSON file")

    # The comparison annotation needs PGD plus at least one competitor
    can_annotate = 'pgd' in results and len(model_names) > 1

    # Create a separate plot for each test type
    for test_type, title in test_types.items():
        # Recalls converted to percentages
        recalls = [
            results[m][test_type]['abstention_recall'] * 100
            for m in model_names
        ]

        x = np.arange(len(model_names))

        fig, ax = plt.subplots(figsize=(12, 7))
        try:
            # Plot bars
            bar_width = 0.7
            bars = ax.bar(
                x, recalls, width=bar_width,
                color=[model_colors.get(name, fallback_color)
                       for name in model_names]
            )

            # Configure axes: 20% padding on top. If every recall is 0 the
            # padded limit would also be 0, so fall back to a full 0-100 scale.
            y_max = max(recalls)
            ax.set_ylim(0, y_max * 1.2 if y_max > 0 else 100.0)
            ax.grid(True, which="major", ls="-", alpha=0.2)

            # Format y-axis to show percentages
            ax.yaxis.set_major_formatter(
                plt.FuncFormatter(lambda y, _: f'{y:.1f}%')
            )

            # Configure labels
            ax.set_ylabel('Abstention Recall Rate (%)',
                          fontsize=12, fontweight='bold')
            ax.set_xticks(x)
            ax.set_xticklabels(
                [model_display_names.get(name, name) for name in model_names],
                fontsize=11, fontweight='bold', rotation=45, ha='right'
            )

            # Add value labels slightly above each bar
            for bar, val in zip(bars, recalls):
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2,
                        height + (y_max * 0.01),
                        f'{val:.1f}%',
                        ha='center', va='bottom',
                        fontsize=11, fontweight='bold',
                        color='black')

            # Add title
            ax.set_title(title, fontsize=14, fontweight='bold', pad=20)

            # Add performance comparison annotation for PGD vs others
            if can_annotate:
                _annotate_pgd_comparison(
                    ax, results, test_type, model_names, model_colors['pgd']
                )

            fig.tight_layout()

            # Save figure
            outfile = os.path.join(save_dir, f'{test_type}_recall_comparison.png')
            fig.savefig(outfile, dpi=300, bbox_inches='tight', pad_inches=0.2)
        finally:
            # Always release this figure, even if plotting or saving failed
            plt.close(fig)

In [None]:
def fishers_method(pvalues):
    """Implement Fisher's method for combining p-values.

    Args:
        pvalues: Iterable of p-values. Entries that are ``None`` or not
            strictly between 0 and 1 are ignored.

    Returns:
        The combined p-value as a float, or ``None`` if no usable p-value
        remains after filtering.
    """
    valid_pvals = [p for p in pvalues if p is not None and 0 < p < 1]

    if not valid_pvals:
        return None

    # Fisher's statistic: -2 * sum(ln p_i), chi-squared with 2k degrees of
    # freedom for k combined p-values
    statistic = -2 * np.sum(np.log(valid_pvals))
    df = 2 * len(valid_pvals)
    # Use the survival function rather than 1 - cdf: for strong evidence the
    # cdf rounds to exactly 1.0 and the subtraction would return 0.0.
    combined_p = chi2.sf(statistic, df)

    return float(combined_p)

def analyze_ood_test(results, test_type, models):
    """Compare models' abstention recall on one OOD test type across seeds.

    For every pair of models, the one with the higher mean recall is tested
    against the other with a one-sided Mann-Whitney U test, and the resulting
    p-value is passed through ``fishers_method``.

    Args:
        results: List of per-seed result dictionaries
        test_type: Name of the OOD test type to analyse
        models: List of model names to compare

    Returns:
        Dictionary with keys 'fishers_method_results',
        'mean_recalls_across_seeds' and 'number_of_seeds'.
    """
    # Gather each model's recall from every seed that reports this test type
    model_recalls = {name: [] for name in models}
    for seed_result, name in itertools.product(results, models):
        if name in seed_result and test_type in seed_result[name]:
            model_recalls[name].append(
                seed_result[name][test_type]['abstention_recall']
            )

    # Mean recall per model decides the direction of each comparison
    model_means = {}
    for name, recall_list in model_recalls.items():
        model_means[name] = np.mean(recall_list)

    # Each unordered pair is tested exactly once
    comparison_pvalues = {}
    for model1, model2 in itertools.combinations(models, 2):
        first_is_better = model_means[model1] > model_means[model2]
        if first_is_better:
            better_model, worse_model = model1, model2
        else:
            better_model, worse_model = model2, model1
        better_recalls = model_recalls[better_model]
        worse_recalls = model_recalls[worse_model]

        key = f"{better_model}_better_than_{worse_model}"

        try:
            _, p_value = mannwhitneyu(better_recalls, worse_recalls, alternative='greater')
            comparison_pvalues[key] = [float(p_value)]
        except Exception as e:
            # A failed test is recorded as p = 1.0
            print(f"Error in Mann-Whitney U test for {key}: {str(e)}")
            comparison_pvalues[key] = [1.0]

    # Combine the p-values of each comparison and flag significance at 0.05
    combined_results = {}
    for comparison, pvalues in comparison_pvalues.items():
        combined_p = fishers_method(pvalues)
        is_significant = None if combined_p is None else combined_p < 0.05
        combined_results[comparison] = {
            'individual_pvalues': pvalues,
            'combined_pvalue': combined_p,
            'significant': is_significant
        }

    # Per-model descriptive statistics, converted to plain Python floats
    mean_recalls_across_seeds = {
        name: {
            'values': [float(v) for v in model_recalls[name]],
            'mean': float(np.mean(model_recalls[name])),
            'std': float(np.std(model_recalls[name]))
        }
        for name in models
    }

    return {
        'fishers_method_results': combined_results,
        'mean_recalls_across_seeds': mean_recalls_across_seeds,
        'number_of_seeds': len(results)
    }

def combine_seed_results(results_dir):
    """Combine results from multiple seed files.

    Every ``*.json`` file in ``results_dir`` is read in sorted (deterministic)
    order. Files that cannot be parsed, or that contain none of the expected
    model names at the top level, are reported and skipped, so unrelated JSON
    files in the directory are not counted as seeds.

    Args:
        results_dir: Directory containing one JSON results file per seed

    Returns:
        Dictionary mapping each test type to its ``analyze_ood_test`` result.

    Raises:
        ValueError: If no usable results file is found.
    """
    # Sort so the order of per-seed values does not depend on the filesystem
    result_files = sorted(Path(results_dir).glob('*.json'))
    all_results = []

    print(f"Found {len(result_files)} files")

    # Expected models and test types
    models = ['base_adam', 'input_space_adv', 'decay_control', 'full_pgd', 'pgd']
    test_types = ['number_format', 'novel_operator', 'cross_boundary']

    for file_path in result_files:
        try:
            with open(file_path, 'r') as f:
                data = json.load(f)
        except Exception as e:
            print(f"Error loading {file_path}: {str(e)}")
            continue

        # A seed file is keyed by model name. Anything else would contribute
        # no recalls yet still inflate the reported seed count.
        if not isinstance(data, dict) or not any(m in data for m in models):
            print(f"Skipping {file_path}: no per-model OOD results found")
            continue

        all_results.append(data)
        print(f"Successfully loaded OOD results from {file_path}")

    if not all_results:
        raise ValueError("No results found in the files")

    # Analyze each test type separately
    final_results = {}
    for test_type in test_types:
        final_results[test_type] = analyze_ood_test(all_results, test_type, models)

    return final_results

def analyze_directory(directory_path):
    """Analyse a directory of OOD test results, save the outcome and print a summary.

    The combined analysis is written to ``fishers_analysis_ood.json`` inside
    ``directory_path`` and also returned.

    Args:
        directory_path: Directory holding the per-seed JSON result files

    Returns:
        Dictionary of analysis results keyed by test type
    """
    directory_path = Path(directory_path)
    output_path = directory_path / 'fishers_analysis_ood.json'

    # Run analysis
    final_results = combine_seed_results(directory_path)

    # Save results
    with open(output_path, 'w') as out_file:
        json.dump(final_results, out_file, indent=2)

    # Print summary for each test type
    for test_type, summary in final_results.items():
        print(f"\nFisher's Method Analysis Summary - {test_type}:")
        print("=" * (42 + len(test_type)))
        print(f"Number of seeds analyzed: {summary['number_of_seeds']}")

        print(f"\nMean Abstention Recalls Across Seeds ({test_type}):")
        for model, recall_stats in summary['mean_recalls_across_seeds'].items():
            print(f"{model}: {recall_stats['mean']:.4f} ± {recall_stats['std']:.4f}")

        print(f"\nSignificant Results for {test_type} (p < 0.05):")
        for comparison, outcome in summary['fishers_method_results'].items():
            # Only significant comparisons are listed
            if not outcome['significant']:
                continue
            formatted_pvalues = [f'{p:.6f}' for p in outcome['individual_pvalues']]
            print(f"\n{comparison}:")
            print(f"  Combined p-value: {outcome['combined_pvalue']:.6f} (SIGNIFICANT)")
            print(f"  Individual p-values: {formatted_pvalues}")

    print(f"\nFull results saved to: {output_path}")

    return final_results