# Trajectory Code

In [4]:
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import numpy as np
import plotly.graph_objects as go

# Set seed
torch.manual_seed(42)
np.random.seed(42)

class SimpleNet(nn.Module):
    """Simple fully-connected neural network for MNIST classification.
    
    A basic 2-layer neural network with one hidden layer and ReLU activation.
    Suitable for small-scale loss landscape visualization experiments.
    
    Attributes:
        fc1 (nn.Linear): First fully-connected layer (784 -> hidden_dim).
        fc2 (nn.Linear): Second fully-connected layer (hidden_dim -> 10).
    """
    
    def __init__(self, hidden_dim: int = 20):
        """Initializes the SimpleNet with specified hidden dimension.
        
        Args:
            hidden_dim (int): Number of neurons in the hidden layer. Default is 20.
        """
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network.
        
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 1, 28, 28) or 
                             (batch_size, 784).
        
        Returns:
            torch.Tensor: Output logits of shape (batch_size, 10).
        """
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

def load_mnist(train_samples: int = 1000, batch_size: int = 64) -> DataLoader:
    """Loads a subset of MNIST dataset for training.
    
    Creates a DataLoader for a subset of the MNIST training set with standard
    normalization applied. Useful for quick experiments and visualizations.
    
    Args:
        train_samples (int): Number of training samples to use from MNIST.
                           Default is 1000. Must be <= 60000.
        batch_size (int): Batch size for the DataLoader. Default is 64.
    
    Returns:
        DataLoader: PyTorch DataLoader containing the MNIST subset.
    
    Raises:
        ValueError: If train_samples > 60000 or batch_size <= 0.
    
    Example:
        >>> train_loader = load_mnist(train_samples=5000, batch_size=128)
        >>> for batch_idx, (data, target) in enumerate(train_loader):
        ...     print(f"Batch {batch_idx}: data shape {data.shape}")
    """
    if train_samples > 60000:
        raise ValueError("train_samples cannot exceed 60000 (MNIST training set size)")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
    ])
    
    train_dataset = torchvision.datasets.MNIST(
        root='./data', 
        train=True, 
        download=True, 
        transform=transform
    )
    
    train_subset = Subset(train_dataset, range(train_samples))
    train_loader = DataLoader(
        train_subset, 
        batch_size=batch_size, 
        shuffle=True
    )
    
    return train_loader


def filter_normalize(direction: torch.Tensor, 
                    model: nn.Module, 
                    params_at_point: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Applies filter normalization to a direction vector.
    
    Implements the filter normalization technique from Li et al. 2018, which
    normalizes each filter (row in weight matrices) independently to respect
    the scale of network layers. This enables meaningful comparisons between
    loss landscapes of different architectures.
    
    Args:
        direction (torch.Tensor): Direction vector in parameter space to normalize.
                                 Must have same dimension as model parameters.
        model (nn.Module): Neural network model providing parameter structure.
        params_at_point (torch.Tensor, optional): Reference parameters for normalization.
                                                  If None, uses current model parameters.
                                                  Default is None.
    
    Returns:
        torch.Tensor: Filter-normalized direction vector with same shape as input.
    
    Raises:
        ValueError: If direction dimension doesn't match model parameter count.
    
    Note:
        Filter normalization formula: d_ij <- (d_ij / ||d_ij||) * ||θ_ij||
        where d_ij is the j-th filter of i-th layer in direction d,
        and θ_ij is the corresponding filter in the parameter vector.
    
    Example:
        >>> model = SimpleNet(hidden_dim=20)
        >>> direction = torch.randn(sum(p.numel() for p in model.parameters()))
        >>> normalized_dir = filter_normalize(direction, model)
    """
    if params_at_point is None:
        # Use current model parameters
        params_at_point = torch.cat([p.data.view(-1) for p in model.parameters()])
    
    total_params = sum(p.numel() for p in model.parameters())
    if direction.numel() != total_params:
        raise ValueError(f"Direction size {direction.numel()} doesn't match "
                        f"model parameters {total_params}")
    
    normalized_direction = []
    idx = 0
    
    # Iterate through model parameters to get shapes
    for param in model.parameters():
        param_size = param.numel()
        param_shape = param.shape
        
        # Extract the portion of direction vector for this parameter
        dir_param = direction[idx:idx+param_size].view(param_shape)
        
        # Extract corresponding weights from params_at_point
        weight_param = params_at_point[idx:idx+param_size].view(param_shape)
        
        if len(param_shape) >= 2:  # Weight matrix (FC or Conv layer)
            # For FC layers: each row is a filter (one per output neuron)
            for i in range(param_shape[0]):
                filter_dir = dir_param[i]
                filter_weight = weight_param[i]
                
                # Compute norms
                dir_norm = filter_dir.norm()
                weight_norm = filter_weight.norm()
                
                # Apply filter normalization
                if dir_norm > 0:
                    dir_param[i] = filter_dir * (weight_norm / dir_norm)
            
            normalized_direction.append(dir_param.view(-1))
        else:
            # Bias vector - no normalization needed
            normalized_direction.append(dir_param.view(-1))
        
        idx += param_size
    
    return torch.cat(normalized_direction)


def train_and_record(model: nn.Module,
                    train_loader: DataLoader,
                    epochs: int = 30,
                    lr: float = 0.01,
                    use_filter_norm: bool = True) -> Tuple[List[List[float]], 
                                                           List[float], 
                                                           torch.Tensor, 
                                                           torch.Tensor, 
                                                           torch.Tensor]:
    """Trains model and records parameter trajectory for visualization.
    
    Trains a neural network while recording the parameter trajectory at each epoch.
    Uses PCA to find the two principal directions of parameter change, optionally
    applying filter normalization for scale-invariant visualization.
    
    Args:
        model (nn.Module): Neural network model to train.
        train_loader (DataLoader): DataLoader containing training data.
        epochs (int): Number of training epochs. Default is 30.
        lr (float): Learning rate for SGD optimizer. Default is 0.01.
        use_filter_norm (bool): Whether to apply filter normalization to PCA directions.
                               Default is True.
    
    Returns:
        Tuple containing:
            - trajectory_2d (List[List[float]]): 2D trajectory points projected onto
                                                 PCA directions, shape (epochs, 2).
            - losses (List[float]): Average loss value at each epoch.
            - direction1 (torch.Tensor): First principal component direction
                                        (possibly filter-normalized).
            - direction2 (torch.Tensor): Second principal component direction
                                        (possibly filter-normalized and orthogonalized).
            - initial_params (torch.Tensor): Initial parameter values before training.
    
    Note:
        The trajectory is projected onto a 2D subspace defined by the top 2
        principal components of the parameter changes during training.
    
    Example:
        >>> model = SimpleNet(hidden_dim=20)
        >>> train_loader = load_mnist(train_samples=1000)
        >>> trajectory_2d, losses, dir1, dir2, init_params = train_and_record(
        ...     model, train_loader, epochs=50, lr=0.1, use_filter_norm=True
        ... )
        >>> print(f"Final loss: {losses[-1]:.4f}")
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    
    # Helper to get all parameters as single vector
    def get_params() -> torch.Tensor:
        """Extracts all model parameters as a single flattened vector."""
        return torch.cat([p.data.view(-1) for p in model.parameters()])
    
    # Record trajectory
    initial_params = get_params().clone()
    param_trajectory = [initial_params.clone()]
    losses = []
    
    print(f"Training (filter_norm={use_filter_norm})...")
    for epoch in range(epochs):
        epoch_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        
        param_trajectory.append(get_params().clone())
        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
    
    # Get PCA directions for projection
    param_changes = torch.stack([p - initial_params for p in param_trajectory[1:]])
    U, S, V = torch.svd(param_changes.t())
    direction1 = U[:, 0]
    direction2 = U[:, 1] if U.shape[1] > 1 else torch.randn_like(direction1)
    
    # Apply filter normalization if requested
    if use_filter_norm:
        print("Applying filter normalization to PCA directions...")
        # Use initial params as the reference point for normalization
        direction1 = filter_normalize(direction1, model, initial_params)
        direction2 = filter_normalize(direction2, model, initial_params)
    
    # Ensure orthogonality after normalization
    direction2 = direction2 - direction1.dot(direction2) * direction1 / direction1.dot(direction1)
    
    # Re-normalize direction2 if filter norm is on
    if use_filter_norm:
        direction2 = filter_normalize(direction2, model, initial_params)
    else:
        direction2 = direction2 / direction2.norm()
    
    # Project trajectory to 2D
    trajectory_2d = []
    for params in param_trajectory[1:]:  # Skip initial (it's at origin)
        diff = params - initial_params
        alpha = diff.dot(direction1).item()
        beta = diff.dot(direction2).item()
        trajectory_2d.append([alpha, beta])
    
    return trajectory_2d, losses, direction1, direction2, initial_params


def compute_landscape(model: nn.Module,
                     train_loader: DataLoader,
                     initial_params: torch.Tensor,
                     direction1: torch.Tensor,
                     direction2: torch.Tensor,
                     trajectory_2d: List[List[float]],
                     resolution: int = 40) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Computes loss landscape on a 2D grid defined by two directions.
    
    Evaluates the loss function at grid points in a 2D subspace of the parameter
    space, defined by two direction vectors. The grid scale is automatically
    determined from the trajectory extent.
    
    Args:
        model (nn.Module): Neural network model for loss evaluation.
        train_loader (DataLoader): DataLoader containing data for loss computation.
        initial_params (torch.Tensor): Center point of the grid in parameter space.
        direction1 (torch.Tensor): First direction vector defining the grid.
        direction2 (torch.Tensor): Second direction vector defining the grid.
        trajectory_2d (List[List[float]]): 2D trajectory points used to determine
                                          grid scale.
        resolution (int): Number of grid points along each axis. Default is 40.
                         Total evaluations will be resolution^2.
    
    Returns:
        Tuple containing:
            - Alpha (np.ndarray): Meshgrid x-coordinates, shape (resolution, resolution).
            - Beta (np.ndarray): Meshgrid y-coordinates, shape (resolution, resolution).
            - Loss_surface (np.ndarray): Loss values at each grid point,
                                        shape (resolution, resolution).
    
    Note:
        For computational efficiency, only the first 5 batches are used for
        loss evaluation at each grid point.
    
    Example:
        >>> Alpha, Beta, Loss_surface = compute_landscape(
        ...     model, train_loader, init_params, dir1, dir2, trajectory_2d,
        ...     resolution=50
        ... )
        >>> print(f"Loss range: [{Loss_surface.min():.2f}, {Loss_surface.max():.2f}]")
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    
    def set_params(params: torch.Tensor) -> None:
        """Sets model parameters from a flattened parameter vector.
        
        Args:
            params (torch.Tensor): Flattened parameter vector to set.
        """
        idx = 0
        for p in model.parameters():
            size = p.numel()
            p.data = params[idx:idx+size].view_as(p)
            idx += size
    
    # Get scale from trajectory
    alphas = [t[0] for t in trajectory_2d]
    betas = [t[1] for t in trajectory_2d]
    
    # Determine grid scale based on trajectory extent
    if alphas and betas:
        scale = max(abs(min(alphas)), abs(max(alphas)), 
                   abs(min(betas)), abs(max(betas)), 0.1) * 1.5
    else:
        scale = 1.0  # Default scale if no trajectory
    
    # Create grid
    alpha_range = np.linspace(-scale, scale, resolution)
    beta_range = np.linspace(-scale, scale, resolution)
    Alpha, Beta = np.meshgrid(alpha_range, beta_range)
    Loss_surface = np.zeros_like(Alpha)
    
    print(f"Computing landscape ({resolution}x{resolution})...")
    print(f"Grid scale: ±{scale:.3f}")
    
    # Compute loss at each point
    for i in range(resolution):
        if i % 10 == 0:
            print(f"  Progress: {i}/{resolution}")
        for j in range(resolution):
            # Set parameters to this point
            test_params = initial_params + alpha_range[i] * direction1 + beta_range[j] * direction2
            set_params(test_params)
            
            # Compute loss (on subset for speed)
            total_loss = 0.0
            count = 0
            with torch.no_grad():
                for batch_idx, (data, target) in enumerate(train_loader):
                    if batch_idx >= 5:  # Use only first 5 batches for efficiency
                        break
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    loss = criterion(output, target)
                    total_loss += loss.item()
                    count += 1
            
            Loss_surface[j, i] = total_loss / count if count > 0 else float('inf')
    
    # Restore original parameters
    set_params(initial_params)
    
    return Alpha, Beta, Loss_surface


def create_3d_plot(Alpha: np.ndarray,
                  Beta: np.ndarray,
                  Loss_surface: np.ndarray,
                  trajectory_2d: List[List[float]]) -> go.Figure:
    """Creates interactive 3D visualization of loss landscape with trajectory.
    
    Generates a Plotly 3D surface plot of the loss landscape with the training
    trajectory overlaid. The trajectory z-values are interpolated from the
    computed loss surface for consistency.
    
    Args:
        Alpha (np.ndarray): Meshgrid x-coordinates from compute_landscape,
                          shape (resolution, resolution).
        Beta (np.ndarray): Meshgrid y-coordinates from compute_landscape,
                         shape (resolution, resolution).
        Loss_surface (np.ndarray): Loss values at each grid point,
                                  shape (resolution, resolution).
        trajectory_2d (List[List[float]]): 2D trajectory points in the PCA subspace,
                                          each point is [alpha, beta].
    
    Returns:
        go.Figure: Plotly figure object containing the 3D visualization.
                  Can be displayed with fig.show() or saved with fig.write_html().
    
    Note:
        The trajectory z-values are interpolated from the loss surface rather than
        using the training losses directly to ensure visual consistency.
    
    Example:
        >>> fig = create_3d_plot(Alpha, Beta, Loss_surface, trajectory_2d, losses)
        >>> fig.show()  # Display interactive plot
        >>> fig.write_html("loss_landscape.html")  # Save as HTML
    """
    fig = go.Figure()
    
    # Add loss surface
    fig.add_trace(go.Surface(
        x=Alpha,
        y=Beta,
        z=Loss_surface,
        colorscale='Viridis',
        opacity=0.9,
        contours={
            "z": {
                "show": True,
                "usecolormap": True,
                "project": {"z": True},
                "width": 2,
                "highlight": True
            }
        },
        name='Loss Surface',
        showscale=True,
        colorbar=dict(
            title="Loss",
            tickmode="linear",
            tick0=0,
            dtick=0.5
        )
    ))
    
    # Extract trajectory coordinates
    x_traj = [t[0] for t in trajectory_2d]
    y_traj = [t[1] for t in trajectory_2d]
    z_traj = []
    
    # Interpolate z-values from the loss surface
    for x, y in zip(x_traj, y_traj):
        # Find nearest grid point
        i = np.argmin(np.abs(Alpha[0, :] - x))
        j = np.argmin(np.abs(Beta[:, 0] - y))
        z_traj.append(Loss_surface[j, i])
    
    # Add training trajectory
    if x_traj and y_traj and z_traj:
        fig.add_trace(go.Scatter3d(
            x=x_traj,
            y=y_traj,
            z=z_traj,
            mode='lines+markers',
            line=dict(
                color='red',
                width=6
            ),
            marker=dict(
                size=4,
                color='red',
                symbol='circle'
            ),
            name='Training Path'
        ))
        
        # Mark start and end points
        fig.add_trace(go.Scatter3d(
            x=[x_traj[0]],
            y=[y_traj[0]],
            z=[z_traj[0]],
            mode='markers',
            marker=dict(
                size=10,
                color='green',
                symbol='diamond'
            ),
            name='Start',
            showlegend=True
        ))
        
        fig.add_trace(go.Scatter3d(
            x=[x_traj[-1]],
            y=[y_traj[-1]],
            z=[z_traj[-1]],
            mode='markers',
            marker=dict(
                size=10,
                color='blue',
                symbol='square'
            ),
            name='End',
            showlegend=True
        ))
    
    # Update layout
    fig.update_layout(
        title={
            'text': "Loss Landscape with Training Trajectory",
            'x': 0.5,
            'xanchor': 'center'
        },
        scene=dict(
            xaxis=dict(
                title="PC1 Direction",
                showbackground=True,
                backgroundcolor="rgb(230, 230,230)"
            ),
            yaxis=dict(
                title="PC2 Direction",
                showbackground=True,
                backgroundcolor="rgb(230, 230,230)"
            ),
            zaxis=dict(
                title="Loss",
                showbackground=True,
                backgroundcolor="rgb(230, 230,230)"
            ),
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.3)
            )
        ),
        showlegend=True,
        legend=dict(
            x=0.02,
            y=0.98,
            bgcolor="rgba(255, 255, 255, 0.8)",
            bordercolor="Black",
            borderwidth=1
        ),
        height=700,
        width=900
    )
    
    return fig

In [6]:
# Load data
print("Loading MNIST...")
train_loader = load_mnist(train_samples=1000, batch_size=64)

# Create and train model
model = SimpleNet(hidden_dim=20)
trajectory_2d, losses, dir1, dir2, init_params = train_and_record(
    model, train_loader, epochs=30, lr=0.01, use_filter_norm=False
)

# Compute landscape
Alpha, Beta, Loss_surface = compute_landscape(
    model, train_loader, init_params, dir1, dir2, trajectory_2d, resolution=50
)

# Create and show plot
fig = create_3d_plot(Alpha, Beta, Loss_surface, trajectory_2d)
fig.show()

Loading MNIST...
Training (filter_norm=False)...
Epoch 0: Loss = 2.1873
Epoch 10: Loss = 0.6790
Epoch 20: Loss = 0.4249
Computing landscape (50x50)...
Grid scale: ±3.477
  Progress: 0/50
  Progress: 10/50
  Progress: 20/50
  Progress: 30/50
  Progress: 40/50


# Visualize Loss Plane

## Loss Landscape Visualization: Following Li et al. 2018

This notebook implements the loss landscape visualization technique from the paper:
"Visualizing the Loss Landscape of Neural Nets" by Li et al. (NeurIPS 2018)

## Key Concepts:
- **Filter Normalization**: The paper's main contribution for fair landscape comparison
- **Random Directions**: Using random (not gradient) directions to explore the landscape
- **Effect of Architecture**: How depth and skip connections affect landscape geometry

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import numpy as np
import plotly.graph_objects as go
from scipy.ndimage import gaussian_filter

# Set seed
torch.manual_seed(42)
np.random.seed(42)

In [8]:
class SimpleNet(nn.Module):
    """Simple fully-connected neural network for MNIST classification.
    
    A basic 2-layer neural network with one hidden layer and ReLU activation.
    Suitable for small-scale loss landscape visualization experiments.
    
    Attributes:
        fc1 (nn.Linear): First fully-connected layer (784 -> hidden_dim).
        fc2 (nn.Linear): Second fully-connected layer (hidden_dim -> 10).
    """
    
    def __init__(self, hidden_dim: int = 20):
        """Initializes the SimpleNet with specified hidden dimension.
        
        Args:
            hidden_dim (int): Number of neurons in the hidden layer. Default is 20.
        """
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network.
        
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 1, 28, 28) or 
                             (batch_size, 784).
        
        Returns:
            torch.Tensor: Output logits of shape (batch_size, 10).
        """
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


In [9]:
def get_params(model: nn.Module) -> torch.Tensor:
   """Extracts all parameters from a model as a single flattened vector.
   
   Concatenates all model parameters into a single 1D tensor, preserving
   the order of parameters as they appear in model.parameters(). This is
   useful for optimization visualization and parameter space analysis.
   
   Args:
       model (nn.Module): PyTorch model from which to extract parameters.
   
   Returns:
       torch.Tensor: Flattened vector containing all model parameters,
                    shape (total_params,) where total_params is the sum
                    of all parameter sizes.
   
   Example:
       >>> model = nn.Linear(10, 5)
       >>> params = get_params(model)
       >>> print(f"Total parameters: {params.numel()}")  # 10*5 + 5 = 55
       
   Note:
       The returned tensor shares memory with the original parameters.
       Use .clone() if you need a separate copy.
   """
   params = []
   for param in model.parameters():
       params.append(param.data.view(-1))
   return torch.cat(params)


def set_params(model: nn.Module, params_vector: torch.Tensor) -> None:
   """Sets model parameters from a single flattened vector.
   
   Updates all model parameters in-place using values from a flattened
   parameter vector. The vector must have the exact same size as the
   total number of model parameters, and values are assigned in the
   same order as model.parameters().
   
   Args:
       model (nn.Module): PyTorch model whose parameters will be updated.
       params_vector (torch.Tensor): Flattened vector containing new parameter
                                    values, shape (total_params,). Must match
                                    the total parameter count of the model.
   
   Raises:
       RuntimeError: If params_vector size doesn't match model parameter count.
       
   Example:
       >>> model = nn.Linear(10, 5)
       >>> original_params = get_params(model).clone()
       >>> perturbed_params = original_params + 0.1 * torch.randn_like(original_params)
       >>> set_params(model, perturbed_params)
       
   Note:
       This function modifies the model parameters in-place. The model's
       computational graph will be broken after calling this function.
   """
   idx = 0
   total_size = sum(p.numel() for p in model.parameters())
   
   if params_vector.numel() != total_size:
       raise RuntimeError(f"Parameter vector size {params_vector.numel()} "
                         f"doesn't match model parameters {total_size}")
   
   for param in model.parameters():
       param_size = param.numel()
       param.data = params_vector[idx:idx+param_size].view_as(param).data
       idx += param_size


def create_random_direction(model: nn.Module, 
                         distribution: str = 'gaussian',
                         seed: Optional[int] = None) -> torch.Tensor:
   """Creates a random direction vector with same dimension as model parameters.
   
   Generates a random direction in parameter space, useful for loss landscape
   visualization and random perturbation analysis. Each parameter's direction
   is sampled independently from the specified distribution.
   
   Args:
       model (nn.Module): PyTorch model defining the parameter space dimension.
       distribution (str): Type of random distribution to use.
                         Options: 'gaussian' (default), 'uniform', 'rademacher'.
                         - 'gaussian': Standard normal distribution N(0,1)
                         - 'uniform': Uniform distribution U(-1,1)
                         - 'rademacher': Random ±1 with equal probability
       seed (int, optional): Random seed for reproducibility. If None, uses
                            current random state. Default is None.
   
   Returns:
       torch.Tensor: Random direction vector with same total dimension as
                    model parameters, shape (total_params,).
   
   Raises:
       ValueError: If distribution is not one of the supported types.
       
   Example:
       >>> model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))
       >>> direction = create_random_direction(model, distribution='gaussian')
       >>> print(f"Direction norm: {direction.norm():.4f}")
       
       >>> # For reproducible random direction
       >>> direction1 = create_random_direction(model, seed=42)
       >>> direction2 = create_random_direction(model, seed=42)
       >>> assert torch.allclose(direction1, direction2)
       
   Note:
       The returned direction is not normalized. Use direction/direction.norm()
       if you need a unit vector.
   """
   if distribution not in ['gaussian', 'uniform', 'rademacher']:
       raise ValueError(f"Unknown distribution: {distribution}. "
                       "Choose from 'gaussian', 'uniform', 'rademacher'")
   
   if seed is not None:
       torch.manual_seed(seed)
   
   direction = []
   for param in model.parameters():
       if distribution == 'gaussian':
           dir_param = torch.randn_like(param)
       elif distribution == 'uniform':
           dir_param = torch.rand_like(param) * 2 - 1  # U(-1, 1)
       else:  # rademacher
           dir_param = torch.sign(torch.randn_like(param))
       
       direction.append(dir_param.view(-1))
   
   return torch.cat(direction)


def filter_normalize(direction: torch.Tensor, 
                    model: nn.Module,
                    params_at_point: Optional[torch.Tensor] = None,
                    epsilon: float = 1e-10) -> torch.Tensor:
   """Applies filter normalization to a direction vector (Li et al. 2018).
   
   Implements the filter normalization technique that enables meaningful
   comparison of loss landscapes across different architectures and scales.
   Each filter (row in weight matrices) is normalized independently to
   respect the natural scale of different network layers.
   
   The normalization formula for each filter:
       d_normalized = (d / ||d||) * ||w||
   
   where d is the direction filter, w is the corresponding weight filter,
   and ||·|| denotes the Frobenius norm.
   
   Args:
       direction (torch.Tensor): Direction vector to normalize, shape (total_params,).
                                Must have same dimension as model parameters.
       model (nn.Module): Neural network model providing parameter structure
                         and shapes for the normalization process.
       params_at_point (torch.Tensor, optional): Reference parameters for computing
                                                 filter norms. If None, uses current
                                                 model parameters. Shape (total_params,).
                                                 Default is None.
       epsilon (float): Small constant to prevent division by zero in normalization.
                       Default is 1e-10.
   
   Returns:
       torch.Tensor: Filter-normalized direction vector, shape (total_params,).
                    Has the same dimension as the input direction.
   
   Raises:
       ValueError: If direction dimension doesn't match model parameter count.
       RuntimeError: If params_at_point is provided but has wrong dimension.
       
   Example:
       >>> # Basic usage with current model parameters
       >>> model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
       >>> direction = create_random_direction(model)
       >>> normalized_dir = filter_normalize(direction, model)
       
       >>> # Using specific reference point for normalization
       >>> init_params = get_params(model).clone()
       >>> # ... train model ...
       >>> direction = create_random_direction(model)
       >>> normalized_dir = filter_normalize(direction, model, params_at_point=init_params)
       
   Note:
       - For fully connected layers, each row of the weight matrix is treated as a filter
       - For convolutional layers, each output channel is treated as a filter
       - Bias terms are not normalized (they don't have multiple filters)
       - This normalization is crucial for comparing loss landscapes across different
         architectures, as it removes the scale ambiguity in neural networks
         
   Reference:
       Li, H., Xu, Z., Taylor, G., Studer, C., & Goldstein, T. (2018).
       Visualizing the Loss Landscape of Neural Nets. NeurIPS.
   """
   # Validate input dimensions
   total_params = sum(p.numel() for p in model.parameters())
   if direction.numel() != total_params:
       raise ValueError(f"Direction size {direction.numel()} doesn't match "
                       f"model parameter count {total_params}")
   
   if params_at_point is None:
       params_at_point = get_params(model)
   elif params_at_point.numel() != total_params:
       raise RuntimeError(f"params_at_point size {params_at_point.numel()} "
                         f"doesn't match model parameters {total_params}")
   
   normalized_direction = []
   idx = 0
   
   for param in model.parameters():
       param_size = param.numel()
       param_direction = direction[idx:idx+param_size]
       
       if len(param.shape) >= 2:  # Weight matrix (FC or Conv layer has filters)
           # Reshape to match parameter shape
           param_direction = param_direction.view_as(param)
           param_weights = params_at_point[idx:idx+param_size].view_as(param)
           
           # Normalize each filter (row for FC layers, output channel for Conv)
           for i in range(param.shape[0]):
               filter_dir = param_direction[i]
               filter_weight = param_weights[i]
               
               # Compute norms (flatten for correct norm calculation)
               filter_dir_flat = filter_dir.view(-1)
               filter_weight_flat = filter_weight.view(-1)
               
               dir_norm = filter_dir_flat.norm()
               weight_norm = filter_weight_flat.norm()
               
               # Apply filter normalization with epsilon for numerical stability
               if dir_norm > epsilon:
                   filter_dir_normalized = filter_dir * (weight_norm / dir_norm)
                   param_direction[i] = filter_dir_normalized
               # If direction norm is too small, keep it as zero
           
           normalized_direction.append(param_direction.view(-1))
       else:
           # Bias terms or 1D parameters - no normalization needed
           normalized_direction.append(param_direction)
       
       idx += param_size
   
   return torch.cat(normalized_direction)

In [10]:
def compute_loss_landscape(model: nn.Module,
                         train_loader: DataLoader,
                         resolution: int = 51,
                         scale: float = 1.0,
                         use_filter_norm: bool = True,
                         smoothing: float = 0.0,
                         scale_multiplier: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
   """Computes the loss landscape of a neural network in a 2D parameter subspace.
   
   Generates a 2D visualization of the loss landscape by evaluating the loss
   function on a grid of points in a 2-dimensional subspace of the parameter
   space. The subspace is defined by two random direction vectors, which can
   optionally be filter-normalized following Li et al. 2018.
   
   This function is useful for understanding optimization geometry, comparing
   architectures, and visualizing the effect of different training methods
   on the loss landscape structure.
   
   Args:
       model (nn.Module): Neural network model whose loss landscape will be computed.
                         The model parameters define the center point of the landscape.
       train_loader (DataLoader): DataLoader containing training data for loss evaluation.
                                 Only the first few batches are used for efficiency.
       resolution (int): Number of grid points along each axis. Default is 51.
                        Total evaluations will be resolution^2.
                        Higher values give more detail but increase computation time.
       scale (float): Range of the grid in each direction, from -scale to +scale.
                     Default is 1.0. Larger values show wider views of the landscape.
       use_filter_norm (bool): Whether to apply filter normalization to direction vectors.
                              True (default) enables scale-invariant comparison across
                              architectures following Li et al. 2018.
                              False shows raw landscape which may be dominated by scale effects.
       smoothing (float): Gaussian smoothing sigma parameter. Default is 0.0 (no smoothing).
                         Positive values apply smoothing to reduce noise in the landscape.
       scale_multiplier (float): Additional scaling factor for the direction vectors.
                                Default is 1.0. Useful for adjusting the effective
                                step size in parameter space.
   
   Returns:
       Tuple containing:
           - Alpha (np.ndarray): Meshgrid x-coordinates for the grid points,
                                shape (resolution, resolution).
           - Beta (np.ndarray): Meshgrid y-coordinates for the grid points,
                               shape (resolution, resolution).
           - Loss_surface (np.ndarray): Loss values at each grid point,
                                       shape (resolution, resolution).
                                       May contain smoothed values if smoothing > 0.
   
   Raises:
       RuntimeError: If CUDA is requested but not available.
       ValueError: If resolution is less than 2 or scale is non-positive.
       
   Example:
       >>> # Basic usage with filter normalization
       >>> model = SimpleNet(hidden_dim=256)
       >>> train_loader = load_mnist(train_samples=1000, batch_size=64)
       >>> Alpha, Beta, Loss = compute_loss_landscape(
       ...     model, train_loader, resolution=51, scale=1.0, use_filter_norm=True
       ... )
       >>> print(f"Loss range: [{Loss.min():.2f}, {Loss.max():.2f}]")
       
       >>> # Wide view without filter normalization to see raw landscape
       >>> Alpha, Beta, Loss = compute_loss_landscape(
       ...     model, train_loader, resolution=31, scale=5.0, 
       ...     use_filter_norm=False, scale_multiplier=0.1
       ... )
       
       >>> # Smoothed landscape for cleaner visualization
       >>> Alpha, Beta, Loss = compute_loss_landscape(
       ...     model, train_loader, smoothing=2.0
       ... )
   
   Note:
       - Only the first 3 batches of data are used for efficiency. This is a
         approximation that trades accuracy for speed.
       - The landscape is computed around the current model parameters.
       - Filter normalization is crucial for fair comparison between different
         architectures or training stages.
       - Large scale values with small models may cause numerical instability.
       - The function includes error handling for NaN/Inf values that may occur
         with extreme perturbations.
         
   Performance:
       Time complexity: O(resolution^2 * batch_evaluations * forward_pass_time)
       Memory complexity: O(resolution^2 + model_parameters)
       
       Typical timing (on GPU):
       - resolution=51, 3 batches: ~1-2 minutes for small models
       - resolution=101, 3 batches: ~5-10 minutes for small models
       
   Reference:
       Li, H., Xu, Z., Taylor, G., Studer, C., & Goldstein, T. (2018).
       Visualizing the Loss Landscape of Neural Nets. NeurIPS.
       
   See Also:
       filter_normalize: For understanding the normalization technique.
       create_random_direction: For generating direction vectors.
       set_params/get_params: For parameter manipulation.
   """
   # Input validation
   if resolution < 2:
       raise ValueError(f"Resolution must be at least 2, got {resolution}")
   if scale <= 0:
       raise ValueError(f"Scale must be positive, got {scale}")
   
   # Device setup
   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
   model = model.to(device)
   criterion = nn.CrossEntropyLoss()
   
   # Store original parameters (center point of landscape)
   initial_params = get_params(model)
   
   # Create two random directions for defining the 2D subspace
   d1 = create_random_direction(model)
   d2 = create_random_direction(model)
   
   # Apply normalization strategy based on use_filter_norm
   if use_filter_norm:
       print("Using filter normalization (Li et al. 2018)...")
       # Filter normalization: each filter scaled independently
       direction1 = filter_normalize(d1, model)
       direction2 = filter_normalize(d2, model)
       
       # Apply additional scaling for wider/narrower view
       direction1 = direction1 * scale_multiplier
       direction2 = direction2 * scale_multiplier
   else:
       print("Using raw random directions (no filter norm)...")
       # Global normalization: scale by total parameter norm
       # This shows the "raw" landscape without filter-wise scaling
       param_norm = initial_params.norm()
       direction1 = (d1 / d1.norm()) * param_norm * 0.1 * scale_multiplier
       direction2 = (d2 / d2.norm()) * param_norm * 0.1 * scale_multiplier
   
   # Orthogonalize direction2 with respect to direction1
   # This ensures the two directions span a proper 2D subspace
   direction2 = direction2 - direction1.dot(direction2) / direction1.dot(direction1) * direction1
   
   # Re-normalize direction2 to match direction1's scale
   # This prevents distortion in the visualization
   if use_filter_norm:
       direction2 = direction2 / direction2.norm() * direction1.norm()
   
   # Print diagnostic information
   print(f"Direction norms: d1={direction1.norm():.3f}, d2={direction2.norm():.3f}")
   print(f"Parameter norm: {initial_params.norm():.3f}")
   
   # Create evaluation grid
   alpha_range = np.linspace(-scale, scale, resolution)
   beta_range = np.linspace(-scale, scale, resolution)
   Alpha, Beta = np.meshgrid(alpha_range, beta_range)
   Loss_surface = np.zeros_like(Alpha)
   
   print(f"Computing loss landscape ({resolution}x{resolution} grid, scale={scale})...")
   print(f"Actual parameter perturbation range: [{-scale * direction1.norm():.3f}, {scale * direction1.norm():.3f}]")
   
   # Track min/max for debugging and validation
   min_loss = float('inf')
   max_loss = float('-inf')
   
   # Evaluate loss at each grid point
   for i in range(resolution):
       if i % 10 == 0:
           print(f"  Progress: {i}/{resolution}")
       for j in range(resolution):
           # Compute perturbed parameters at this grid point
           # θ_perturbed = θ_initial + α * direction1 + β * direction2
           perturbed_params = initial_params + \
                             alpha_range[i] * direction1 + \
                             beta_range[j] * direction2
           
           # Set model to these parameters
           set_params(model, perturbed_params)
           
           # Compute loss (using subset of batches for efficiency)
           batch_losses = []
           with torch.no_grad():  # No gradients needed for landscape computation
               for batch_idx, (data, target) in enumerate(train_loader):
                   if batch_idx >= 3:  # Use only first 3 batches for speed
                       break
                   
                   data, target = data.to(device), target.to(device)
                   
                   try:
                       output = model(data)
                       loss = criterion(output, target)
                       
                       # Handle numerical issues
                       if torch.isnan(loss) or torch.isinf(loss):
                           # Use large but finite value for unstable points
                           batch_losses.append(1000.0)
                       else:
                           batch_losses.append(loss.item())
                   except RuntimeError as e:
                       # Catch numerical overflow or other runtime errors
                       # This can happen with extreme perturbations
                       batch_losses.append(1000.0)
           
           # Average loss across batches
           if batch_losses:
               Loss_surface[j, i] = np.mean(batch_losses)
           else:
               Loss_surface[j, i] = 1000.0  # Fallback value
           
           # Track statistics
           min_loss = min(min_loss, Loss_surface[j, i])
           max_loss = max(max_loss, Loss_surface[j, i])
   
   print(f"  Raw loss range: [{min_loss:.3f}, {max_loss:.3f}]")
   
   # Restore original parameters
   # Important: This ensures the model is unchanged after landscape computation
   set_params(model, initial_params)
   
   # Apply optional Gaussian smoothing to reduce noise
   if smoothing > 0:
       print(f"  Applying Gaussian smoothing (sigma={smoothing})...")
       from scipy.ndimage import gaussian_filter
       Loss_surface = gaussian_filter(Loss_surface, sigma=smoothing)
       print(f"  Smoothed loss range: [{Loss_surface.min():.3f}, {Loss_surface.max():.3f}]")
   else:
       print("  No smoothing applied")
   
   return Alpha, Beta, Loss_surface

In [11]:
def visualize_loss_landscape(Alpha, Beta, Loss_surface, title="Loss Landscape (Li et al. 2018)"):
    """
    Create 3D visualization of the loss landscape
    """
    fig = go.Figure()
    
    # Add the loss surface
    fig.add_trace(go.Surface(
        x=Alpha,
        y=Beta,
        z=Loss_surface,
        colorscale='Viridis',
        opacity=0.9,
        contours={
            "z": {
                "show": True,
                "usecolormap": True,
                "project": {"z": True},
                "width": 2
            }
        },
        colorbar=dict(
            title="Loss",
            x=1.20
        )
    ))
    
    # Mark the center point (current parameters)
    center_loss = Loss_surface[len(Loss_surface)//2, len(Loss_surface[0])//2]
    fig.add_trace(go.Scatter3d(
        x=[0],
        y=[0],
        z=[center_loss],
        mode='markers',
        marker=dict(size=8, color='red', symbol='diamond'),
        name='Center Point'
    ))
    
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title="Random Direction 1 (filter normalized)",
            yaxis_title="Random Direction 2 (filter normalized)",
            zaxis_title="Loss",
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.3)
            ),
            aspectmode='manual',
            aspectratio=dict(x=1, y=1, z=0.5)
        ),
        height=700,
        showlegend=True
    )
    
    return fig

In [12]:
def train_model(model, train_loader, epochs=30, lr=0.01):
    """
    Optional: Train the model before visualizing its loss landscape
    This gives a more interesting landscape around a trained minimum
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    
    print(f"Training for {epochs} epochs...")
    for epoch in range(epochs):
        epoch_loss = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        
        if epoch % 10 == 0:
            avg_loss = epoch_loss / len(train_loader)
            print(f"  Epoch {epoch}: Loss = {avg_loss:.4f}")
    
    print("Training complete!")
    return model

In [13]:
# Load MNIST data
print("Loading MNIST dataset...")
train_loader = load_mnist(train_samples=1000, batch_size=64)

# Create model
model = SimpleNet()

# Option 1: Visualize landscape at random initialization
print("\n" + "="*50)
print("Visualizing loss landscape at RANDOM INITIALIZATION")
print("="*50)
Alpha_init, Beta_init, Loss_init = compute_loss_landscape(
    model, train_loader, resolution=51, scale=1.0, use_filter_norm=True, smoothing=0.9, scale_multiplier=1.0
)

fig_init = visualize_loss_landscape(
    Alpha_init, Beta_init, Loss_init,
    title="Loss Landscape at Random Initialization"
)
fig_init.show()

# Option 2: Train model then visualize landscape around trained minimum
print("\n" + "="*50)
print("Training model, then visualizing landscape around TRAINED MINIMUM")
print("="*50)

model = train_model(model, train_loader, epochs=30, lr=0.01)

Alpha_trained, Beta_trained, Loss_trained = compute_loss_landscape(
    model, train_loader, resolution=51, scale=1.0, use_filter_norm=True, smoothing=0.9, scale_multiplier=1.0
)

fig_trained = visualize_loss_landscape(
    Alpha_trained, Beta_trained, Loss_trained,
    title="Loss Landscape around Trained Minimum"
)
fig_trained.show()

# Print some statistics
print("\n" + "="*50)
print("LANDSCAPE STATISTICS")
print("="*50)
print(f"At initialization:")
print(f"  Loss range: [{Loss_init.min():.3f}, {Loss_init.max():.3f}]")
print(f"  Center loss: {Loss_init[len(Loss_init)//2, len(Loss_init[0])//2]:.3f}")

print(f"\nAfter training:")
print(f"  Loss range: [{Loss_trained.min():.3f}, {Loss_trained.max():.3f}]")
print(f"  Center loss: {Loss_trained[len(Loss_trained)//2, len(Loss_trained[0])//2]:.3f}")

Loading MNIST dataset...

Visualizing loss landscape at RANDOM INITIALIZATION
Using filter normalization (Li et al. 2018)...
Direction norms: d1=5.621, d2=5.621
Parameter norm: 3.181
Computing loss landscape (51x51 grid, scale=1.0)...
Actual parameter perturbation range: [-5.621, 5.621]
  Progress: 0/51
  Progress: 10/51
  Progress: 20/51
  Progress: 30/51
  Progress: 40/51
  Progress: 50/51
  Raw loss range: [2.278, 3.910]
  Applying Gaussian smoothing (sigma=0.9)...
  Smoothed loss range: [2.312, 3.789]



Training model, then visualizing landscape around TRAINED MINIMUM
Training for 30 epochs...
  Epoch 0: Loss = 2.1893
  Epoch 10: Loss = 0.6757
  Epoch 20: Loss = 0.4243
Training complete!
Using filter normalization (Li et al. 2018)...
Direction norms: d1=7.146, d2=7.146
Parameter norm: 4.262
Computing loss landscape (51x51 grid, scale=1.0)...
Actual parameter perturbation range: [-7.146, 7.146]
  Progress: 0/51
  Progress: 10/51
  Progress: 20/51
  Progress: 30/51
  Progress: 40/51
  Progress: 50/51
  Raw loss range: [0.255, 6.620]
  Applying Gaussian smoothing (sigma=0.9)...
  Smoothed loss range: [0.305, 6.373]



LANDSCAPE STATISTICS
At initialization:
  Loss range: [2.312, 3.789]
  Center loss: 2.322

After training:
  Loss range: [0.305, 6.373]
  Center loss: 0.326
