In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import numpy as np
import plotly.graph_objects as go

# Set seed
torch.manual_seed(42)
np.random.seed(42)

# Simple MNIST Network
class SimpleNet(nn.Module):
    def __init__(self, hidden_dim=20):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)
        
    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Load MNIST subset
def load_mnist(train_samples=1000, batch_size=64):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform
    )
    
    train_subset = Subset(train_dataset, range(train_samples))
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    
    return train_loader

# Train and record trajectory
def train_and_record(model, train_loader, epochs=30, lr=0.01):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    
    # Helper to get/set all parameters as single vector
    def get_params():
        return torch.cat([p.data.view(-1) for p in model.parameters()])
    
    # Record trajectory
    initial_params = get_params().clone()
    param_trajectory = [initial_params.clone()]
    losses = []
    
    print("Training...")
    for epoch in range(epochs):
        epoch_loss = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        
        param_trajectory.append(get_params().clone())
        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
    
    # Get PCA directions for projection
    param_changes = torch.stack([p - initial_params for p in param_trajectory[1:]])
    U, S, V = torch.svd(param_changes.t())
    direction1 = U[:, 0]
    direction2 = U[:, 1] if U.shape[1] > 1 else torch.randn_like(direction1)
    direction2 = direction2 - direction1.dot(direction2) * direction1
    direction2 = direction2 / direction2.norm()
    
    # Project trajectory to 2D
    trajectory_2d = []
    for params in param_trajectory[1:]:  # Skip initial (it's at origin)
        diff = params - initial_params
        alpha = diff.dot(direction1).item()
        beta = diff.dot(direction2).item()
        trajectory_2d.append([alpha, beta])
    
    return trajectory_2d, losses, direction1, direction2, initial_params

# Compute loss landscape
def compute_landscape(model, train_loader, initial_params, direction1, direction2, 
                     trajectory_2d, resolution=40):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    
    # Helper to set parameters
    def set_params(params):
        idx = 0
        for p in model.parameters():
            size = p.numel()
            p.data = params[idx:idx+size].view_as(p)
            idx += size
    
    # Get scale from trajectory
    alphas = [t[0] for t in trajectory_2d]
    betas = [t[1] for t in trajectory_2d]
    scale = max(abs(min(alphas)), abs(max(alphas)), abs(min(betas)), abs(max(betas)), 0.1) * 1.5
    
    # Create grid
    alpha_range = np.linspace(-scale, scale, resolution)
    beta_range = np.linspace(-scale, scale, resolution)
    Alpha, Beta = np.meshgrid(alpha_range, beta_range)
    Loss_surface = np.zeros_like(Alpha)
    
    print(f"Computing landscape ({resolution}x{resolution})...")
    
    # Compute loss at each point
    for i in range(resolution):
        if i % 10 == 0:
            print(f"  {i}/{resolution}")
        for j in range(resolution):
            # Set parameters to this point
            test_params = initial_params + alpha_range[i] * direction1 + beta_range[j] * direction2
            set_params(test_params)
            
            # Compute loss (on subset for speed)
            total_loss = 0
            count = 0
            with torch.no_grad():
                for batch_idx, (data, target) in enumerate(train_loader):
                    if batch_idx >= 5:  # Use only first 5 batches
                        break
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    loss = criterion(output, target)
                    total_loss += loss.item()
                    count += 1
            
            Loss_surface[j, i] = total_loss / count
    
    return Alpha, Beta, Loss_surface

# Create 3D plot
def create_3d_plot(Alpha, Beta, Loss_surface, trajectory_2d, losses):
    fig = go.Figure()
    
    # Add surface
    fig.add_trace(go.Surface(
        x=Alpha, y=Beta, z=Loss_surface,
        colorscale='Viridis',
        opacity=0.9,
        contours={"z": {"show": True, "project": {"z": True}}}
    ))
    
    # Add trajectory
    x_traj = [t[0] for t in trajectory_2d]
    y_traj = [t[1] for t in trajectory_2d]
    
    fig.add_trace(go.Scatter3d(
        x=x_traj, y=y_traj, z=losses,
        mode='lines+markers',
        line=dict(color='red', width=6),
        marker=dict(size=4, color='red'),
        name='Training Path'
    ))
    
    # Mark start and end
    fig.add_trace(go.Scatter3d(
        x=[0], y=[0], z=[losses[0]],  # Start at origin
        mode='markers',
        marker=dict(size=10, color='lime', symbol='diamond'),
        name='Start'
    ))
    
    fig.add_trace(go.Scatter3d(
        x=[x_traj[-1]], y=[y_traj[-1]], z=[losses[-1]],
        mode='markers',
        marker=dict(size=10, color='cyan', symbol='square'),
        name='End'
    ))
    
    fig.update_layout(
        title=f"MNIST Loss Landscape: {losses[0]:.3f} → {losses[-1]:.3f}",
        scene=dict(
            xaxis_title="PC1", yaxis_title="PC2", zaxis_title="Loss",
            camera=dict(eye=dict(x=1.5, y=1.5, z=1.3))
        ),
        height=700
    )
    
    return fig

# ============================================
# MAIN EXECUTION
# ============================================

# Load data
print("Loading MNIST...")
train_loader = load_mnist(train_samples=1000, batch_size=64)

# Create and train model
model = SimpleNet(hidden_dim=20)
trajectory_2d, losses, dir1, dir2, init_params = train_and_record(
    model, train_loader, epochs=30, lr=0.01
)

# Compute landscape
Alpha, Beta, Loss_surface = compute_landscape(
    model, train_loader, init_params, dir1, dir2, trajectory_2d, resolution=40
)

# Create and show plot
fig = create_3d_plot(Alpha, Beta, Loss_surface, trajectory_2d, losses)
fig.show()

Loading MNIST...
Training...
Epoch 0: Loss = 2.1914
Epoch 10: Loss = 0.6710
Epoch 20: Loss = 0.4265
Computing landscape (40x40)...
  0/40
  10/40
  20/40
  30/40


# Visualize Loss Plane

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import numpy as np
import plotly.graph_objects as go
from scipy.ndimage import gaussian_filter

# Set seed
torch.manual_seed(42)
np.random.seed(42)

# Simple MNIST Network
class SimpleNet(nn.Module):
    def __init__(self, hidden_dim=20):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 10)
        
    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Load MNIST subset
def load_mnist(train_samples=1000, batch_size=64):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform
    )
    
    train_subset = Subset(train_dataset, range(train_samples))
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    
    return train_loader

def get_params(model):
    """Extract all parameters from model as a single vector"""
    params = []
    for param in model.parameters():
        params.append(param.data.view(-1))
    return torch.cat(params)

def set_params(model, params_vector):
    """Set model parameters from a single vector"""
    idx = 0
    for param in model.parameters():
        param_size = param.numel()
        param.data = params_vector[idx:idx+param_size].view_as(param)
        idx += param_size

def create_random_direction(model):
    """Create a random direction vector with same dimension as model parameters"""
    direction = []
    for param in model.parameters():
        direction.append(torch.randn_like(param).view(-1))
    return torch.cat(direction)

def filter_normalize(direction, model):
    """
    Filter normalization - the KEY innovation from Li et al. 2018
    Normalizes each filter independently to respect the scale of network layers
    """
    normalized_direction = []
    idx = 0
    
    for param in model.parameters():
        param_size = param.numel()
        param_direction = direction[idx:idx+param_size]
        
        if len(param.shape) >= 2:  # FC or Conv layer (has filters)
            # Reshape to match parameter shape
            param_direction = param_direction.view_as(param)
            param_weights = param.data
            
            # Normalize each filter (row for FC layers)
            for i in range(param.shape[0]):
                filter_dir = param_direction[i]
                filter_weight = param_weights[i]
                
                # Compute norms
                filter_dir_flat = filter_dir.view(-1)
                filter_weight_flat = filter_weight.view(-1)
                
                dir_norm = filter_dir_flat.norm()
                weight_norm = filter_weight_flat.norm()
                
                # Apply filter normalization: scale direction to match weight norm
                if dir_norm > 0:
                    filter_dir_normalized = filter_dir * (weight_norm / dir_norm)
                    param_direction[i] = filter_dir_normalized
            
            normalized_direction.append(param_direction.view(-1))
        else:
            # Bias terms - no normalization needed
            normalized_direction.append(param_direction)
        
        idx += param_size
    
    return torch.cat(normalized_direction)

def compute_loss_landscape(model, train_loader, resolution=51, scale=1.0):
    """
    Compute loss landscape following Li et al. 2018 methodology
    Uses random filter-normalized directions
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    
    # Store original parameters
    initial_params = get_params(model)
    
    # Create two random directions
    print("Creating random directions...")
    d1 = create_random_direction(model)
    d2 = create_random_direction(model)
    
    # Apply filter normalization (KEY step from paper)
    print("Applying filter normalization...")
    direction1 = filter_normalize(d1, model)
    direction2 = filter_normalize(d2, model)
    
    # Make directions orthogonal
    direction2 = direction2 - direction1.dot(direction2) * direction1
    direction2 = direction2 / direction2.norm()
    
    # Create evaluation grid
    alpha_range = np.linspace(-scale, scale, resolution)
    beta_range = np.linspace(-scale, scale, resolution)
    Alpha, Beta = np.meshgrid(alpha_range, beta_range)
    Loss_surface = np.zeros_like(Alpha)
    
    print(f"Computing loss landscape ({resolution}x{resolution} grid)...")
    print("This may take a few minutes...")
    
    # Evaluate loss at each grid point
    for i in range(resolution):
        if i % 10 == 0:
            print(f"  Progress: {i}/{resolution}")
        for j in range(resolution):
            # Perturb parameters along the two directions
            perturbed_params = initial_params + \
                              alpha_range[i] * direction1 + \
                              beta_range[j] * direction2
            
            # Set model to these parameters
            set_params(model, perturbed_params)
            
            # Compute loss at this point
            total_loss = 0
            total_samples = 0
            
            with torch.no_grad():
                for data, target in train_loader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    loss = criterion(output, target)
                    total_loss += loss.item() * data.size(0)
                    total_samples += data.size(0)
            
            Loss_surface[j, i] = total_loss / total_samples
    
    # Restore original parameters
    set_params(model, initial_params)
    
    # Apply minimal Gaussian smoothing (as mentioned in paper)
    Loss_surface = gaussian_filter(Loss_surface, sigma=1.0)
    
    print("Loss landscape computation complete!")
    
    return Alpha, Beta, Loss_surface

def visualize_loss_landscape(Alpha, Beta, Loss_surface, title="Loss Landscape (Li et al. 2018)"):
    """
    Create 3D visualization of the loss landscape
    """
    fig = go.Figure()
    
    # Add the loss surface
    fig.add_trace(go.Surface(
        x=Alpha,
        y=Beta,
        z=Loss_surface,
        colorscale='Viridis',
        opacity=0.9,
        contours={
            "z": {
                "show": True,
                "usecolormap": True,
                "project": {"z": True},
                "width": 2
            }
        },
        colorbar=dict(
            title="Loss",
            x=1.02
        )
    ))
    
    # Mark the center point (current parameters)
    center_loss = Loss_surface[len(Loss_surface)//2, len(Loss_surface[0])//2]
    fig.add_trace(go.Scatter3d(
        x=[0],
        y=[0],
        z=[center_loss],
        mode='markers',
        marker=dict(size=8, color='red', symbol='diamond'),
        name='Center Point'
    ))
    
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title="Random Direction 1 (filter normalized)",
            yaxis_title="Random Direction 2 (filter normalized)",
            zaxis_title="Loss",
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.3)
            ),
            aspectmode='manual',
            aspectratio=dict(x=1, y=1, z=0.5)
        ),
        height=700,
        showlegend=True
    )
    
    return fig

def train_model(model, train_loader, epochs=30, lr=0.01):
    """
    Optional: Train the model before visualizing its loss landscape
    This gives a more interesting landscape around a trained minimum
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    
    print(f"Training for {epochs} epochs...")
    for epoch in range(epochs):
        epoch_loss = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        
        if epoch % 10 == 0:
            avg_loss = epoch_loss / len(train_loader)
            print(f"  Epoch {epoch}: Loss = {avg_loss:.4f}")
    
    print("Training complete!")
    return model

# ============================================
# MAIN EXECUTION
# ============================================

if __name__ == "__main__":
    # Load MNIST data
    print("Loading MNIST dataset...")
    train_loader = load_mnist(train_samples=1000, batch_size=64)
    
    # Create model
    model = SimpleNet(hidden_dim=20)
    
    # Option 1: Visualize landscape at random initialization
    print("\n" + "="*50)
    print("Visualizing loss landscape at RANDOM INITIALIZATION")
    print("="*50)
    Alpha_init, Beta_init, Loss_init = compute_loss_landscape(
        model, train_loader, resolution=41, scale=1.0
    )
    
    fig_init = visualize_loss_landscape(
        Alpha_init, Beta_init, Loss_init,
        title="Loss Landscape at Random Initialization"
    )
    fig_init.show()
    
    # Option 2: Train model then visualize landscape around trained minimum
    print("\n" + "="*50)
    print("Training model, then visualizing landscape around TRAINED MINIMUM")
    print("="*50)
    
    model = train_model(model, train_loader, epochs=30, lr=0.01)
    
    Alpha_trained, Beta_trained, Loss_trained = compute_loss_landscape(
        model, train_loader, resolution=41, scale=1.0
    )
    
    fig_trained = visualize_loss_landscape(
        Alpha_trained, Beta_trained, Loss_trained,
        title="Loss Landscape around Trained Minimum"
    )
    fig_trained.show()
    
    # Print some statistics
    print("\n" + "="*50)
    print("LANDSCAPE STATISTICS")
    print("="*50)
    print(f"At initialization:")
    print(f"  Loss range: [{Loss_init.min():.3f}, {Loss_init.max():.3f}]")
    print(f"  Center loss: {Loss_init[len(Loss_init)//2, len(Loss_init[0])//2]:.3f}")
    
    print(f"\nAfter training:")
    print(f"  Loss range: [{Loss_trained.min():.3f}, {Loss_trained.max():.3f}]")
    print(f"  Center loss: {Loss_trained[len(Loss_trained)//2, len(Loss_trained[0])//2]:.3f}")
    
    # Check for convexity (negative values indicate non-convexity)
    center_i, center_j = len(Loss_trained)//2, len(Loss_trained[0])//2
    if center_i > 0 and center_i < len(Loss_trained)-1:
        curvature = (Loss_trained[center_i+1, center_j] + 
                    Loss_trained[center_i-1, center_j] - 
                    2*Loss_trained[center_i, center_j])
        print(f"  Approximate curvature: {curvature:.4f}")
        if curvature > 0:
            print("  → Locally convex")
        else:
            print("  → Locally non-convex")
    
    print("\n✅ Done! Rotate the 3D plots to explore the loss landscapes.")

Loading MNIST dataset...

Visualizing loss landscape at RANDOM INITIALIZATION
Creating random directions...
Applying filter normalization...
Computing loss landscape (41x41 grid)...
This may take a few minutes...
  Progress: 0/41
  Progress: 10/41
  Progress: 20/41
  Progress: 30/41
  Progress: 40/41
Loss landscape computation complete!



Training model, then visualizing landscape around TRAINED MINIMUM
Training for 30 epochs...
  Epoch 0: Loss = 2.1878
  Epoch 10: Loss = 0.6761
  Epoch 20: Loss = 0.4268
Training complete!
Creating random directions...
Applying filter normalization...
Computing loss landscape (41x41 grid)...
This may take a few minutes...
  Progress: 0/41
  Progress: 10/41
  Progress: 20/41
  Progress: 30/41
  Progress: 40/41
Loss landscape computation complete!



LANDSCAPE STATISTICS
At initialization:
  Loss range: [2.318, 3.533]
  Center loss: 2.323

After training:
  Loss range: [0.329, 4.531]
  Center loss: 0.330
  Approximate curvature: 0.0002
  → Locally convex

✅ Done! Rotate the 3D plots to explore the loss landscapes.
