# Demo 1: Complete Hessian Eigenvector-Based Loss Landscape Analysis

This comprehensive notebook demonstrates the **complete end-to-end workflow** for computing Hessian eigenvectors from a trained ALIGNN model and generating 2D loss landscapes for materials property prediction analysis.

## **What This Demo Covers**

This notebook provides a **complete, self-contained tutorial** covering two major phases:

### **Part 1: Hessian Eigenvector Computation**
1. **Model and Data Loading** - Load demo model and demo dataset
2. **Prediction and Error Analysis** - Make predictions and identify the lowest-error samples. Hessian computation is resource-intensive, so we use a smaller subset of samples (e.g. the lowest-error samples, a random sample, or a selection based on input features) to approximate the full dataset.
3. **Hessian Matrix Computation** - Compute the eigenvectors associated with the maximum and minimum eigenvalues of the loss Hessian on this subset
4. **Eigenvector Processing** - Convert eigenvectors to model weight format
5. **Saving** - Create and save eigenvectors

### **Part 2: Loss Landscape Generation and Analysis**
6. **2D Landscape Generation** - Use planar interpolation between the original model and the two eigenvector models (max and min) computed from the subset to generate loss landscapes for the full dataset.
7. **Visualization** - Create informative plots of loss surface topology

## **Scientific Context**

**Loss landscapes** reveal the topology of the loss function around a trained model, providing insights into:
- **Local minima structure** and escape paths  
- **Sensitivity to parameter perturbations**
- **Relationship between loss geometry and prediction quality**

**Hessian eigenvectors** define the principal directions of curvature in the loss surface, representing:
- **Maximum curvature direction** (steepest/sharpest changes)
- **Minimum curvature direction** (flattest/most stable changes)
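A second-order Taylor expansion makes the link between eigenvalues and curvature concrete. For a unit eigenvector $v_i$ of the Hessian $H$ at the trained weights $\theta_0$, with eigenvalue $\lambda_i$ (so that $v_i^\top H v_i = \lambda_i$):

$$
L(\theta_0 + \alpha v_i) \approx L(\theta_0) + \alpha \, \nabla L(\theta_0)^\top v_i + \tfrac{1}{2} \alpha^2 \lambda_i
$$

Near a minimum the gradient term is close to zero, so the loss rises fastest along the maximum-curvature direction (largest $\lambda_i$) and most slowly along the minimum-curvature direction (smallest $\lambda_i$).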

## **Technical Details**

### **Dataset**: 
- **Source**: JARVIS-DFT (Joint Automated Repository for Various Integrated Simulations)
- **Property**: Formation energy (dHf) prediction
- **Samples**: 50 random materials from the database for demo purposes
- **Target Selection**: 20 lowest-error samples for eigenvector computation for demo purposes

### **Model**: 
- **Architecture**: ALIGNN (Atomistic Line Graph Neural Network)
- **Task**: Regression on formation energy values
- **Parameters**: ~4 million
- **Training**: Pre-trained on JARVIS-DFT formation energy data

### **Computational Requirements**:
- **Memory**: Hessian computation requires significant GPU memory
- **Time**: The complete workflow takes about 10 minutes on an RTX 4070 GPU

### **Key Algorithms**:
- **Hessian-Vector Products**: Efficient computation without storing full Hessian matrix
- **Power Iteration**: For computing dominant eigenvectors
- **Planar Interpolation**: Linear combinations of three models (original + 2 eigenvectors)
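The first two algorithms can be illustrated on a toy problem. The sketch below is an illustration under stated assumptions, not the implementation inside `src.hessian_eigenvector` (which may use a different eigensolver): it computes Hessian-vector products by double backpropagation (differentiating $\nabla L \cdot v$), runs power iteration to get the largest eigenvalue, and then runs power iteration on the shifted operator $H - \lambda_{\max} I$ to reach the smallest one. The quadratic loss, the helper names `hvp` and `power_iteration`, and the iteration count are hypothetical choices for this demo.

```python
import torch

def hvp(loss_fn, params, vec):
    """Hessian-vector product H @ vec via double backpropagation; H is never formed."""
    loss = loss_fn(params)
    (grad,) = torch.autograd.grad(loss, params, create_graph=True)
    # Differentiating (grad . vec) with respect to params gives H @ vec
    (hv,) = torch.autograd.grad(grad, params, grad_outputs=vec)
    return hv

def power_iteration(matvec, dim, iters=500, seed=0):
    """Eigenvalue of largest magnitude (and its eigenvector) of a symmetric operator."""
    gen = torch.Generator().manual_seed(seed)
    v = torch.randn(dim, generator=gen, dtype=torch.float64)
    v = v / v.norm()
    eig = 0.0
    for _ in range(iters):
        w = matvec(v)
        eig = torch.dot(v, w).item()  # Rayleigh quotient v^T A v
        v = w / w.norm()
    return eig, v

# Toy quadratic loss L(p) = 0.5 p^T A p, whose Hessian is exactly A
torch.manual_seed(0)
Q, _ = torch.linalg.qr(torch.randn(4, 4, dtype=torch.float64))
true_eigs = torch.tensor([10.0, 4.0, 1.0, 0.5], dtype=torch.float64)
A = Q @ torch.diag(true_eigs) @ Q.T
theta = torch.randn(4, dtype=torch.float64, requires_grad=True)
loss_fn = lambda p: 0.5 * p @ A @ p
matvec = lambda v: hvp(loss_fn, theta, v)

# Largest eigenvalue: plain power iteration on H
lam_max, v_max = power_iteration(matvec, dim=4)

# Smallest eigenvalue: the dominant eigenvalue of H - lam_max * I is
# lam_min - lam_max, so run power iteration on it and shift back
lam_shifted, v_min = power_iteration(lambda v: matvec(v) - lam_max * v, dim=4)
lam_min = lam_shifted + lam_max

print(f"max eigenvalue ~ {lam_max:.4f}, min eigenvalue ~ {lam_min:.4f}")
```

On a real model, `params` would be the flattened network weights and `loss_fn` the training loss over the selected subset; each `matvec` call then costs a small constant multiple of one gradient evaluation, which is why no full Hessian ever has to be stored.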

## **Output Files Generated**

This notebook creates several important output files:

### **Low Error Subset**:
- `demo_JVDFT_dHf_dataset_lowest_20_error_samples_from_50.pkl` - The 20 lowest-error samples

### **Eigenvector Models**:
- `demo_computed_eigenvectors/test_max_eig.pt` - Model weights set to the maximum-eigenvalue eigenvector
- `demo_computed_eigenvectors/test_min_eig.pt` - Model weights set to the minimum-eigenvalue eigenvector

### **Loss Landscapes**:
- `demo_computed_landscapes/demo_loss_landscapes_df.pkl` - Structured DataFrame with landscapes and metadata
- `demo_computed_landscapes/demo_raw_loss_landscape_array.npy` - Raw 3D NumPy array (grid_x × grid_y × samples)

## **How to Use This Demo**

1. **Run all cells sequentially** - The notebook is designed to be executed from top to bottom
2. **Monitor memory usage** - Watch GPU/CPU memory during Hessian computation
3. **Adjust parameters** - Modify grid size, sample count, or scaling factors as needed
 

## 1. Import Required Libraries

We'll need several libraries for model loading, data processing, and Hessian computation.


In [None]:
import sys
sys.path.append('..')
import json
import os
import torch
import copy
from alignn.models.alignn import ALIGNN, ALIGNNConfig
from collections import OrderedDict
from torchinfo import summary
import pandas as pd
from util.utils_AD import *
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from src.hessian_eigenvector import min_max_hessian_eigs, force_wts_into_model, npvec_to_tensorlist

print("All libraries imported successfully!")

## 2. Setup Device and Paths

Configure the computational device and define paths to our demo data and model.


In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Define paths to demo data
data_path = 'demo_JVDFT_dHf_dataset_50.pkl'
model_path = 'demo_JVDFT_dHf_model.pt'
target = 'formation_energy_peratom'  # Formation energy target

print(f"Data path: {data_path}")
print(f"Model path: {model_path}")
print(f"Target property: {target}")

## 3. Load and Examine Dataset

Load the demo dataset and examine its structure. This dataset contains 50 samples from JARVIS-DFT with formation energy values.


In [None]:
# Load the dataset
data_df = pd.read_pickle(data_path)
print(f"Dataset shape: {data_df.shape}")
print(f"Columns: {list(data_df.columns)}")
print(f"\nFirst few rows:")

In [None]:
data_df.head()

The dataset should be a pandas DataFrame with at least three columns, plus an optional formula column:
- **'jid'**: An arbitrary identifier for the structure, such as a formula or any other identifier you prefer.
- **'atoms'**: The input column. Each entry is created by converting your structure to a JARVIS `Atoms` object and exporting it as a dictionary (for example with `Atoms.to_dict()`).
- **'formula'**: The formula of the structure (optional).
- A target label column of your choice.
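A quick sanity check before building the data loader can catch a malformed DataFrame early. The helper below is a hypothetical sketch (`check_alignn_dataframe` is not part of ALIGNN or this repository) that only checks the requirements listed above: the `jid`, `atoms`, and target columns exist, every `atoms` entry is a dictionary, and no target value is missing.

```python
import pandas as pd

def check_alignn_dataframe(df: pd.DataFrame, target: str) -> list:
    """Return a list of problems; an empty list means the required columns look usable."""
    problems = []
    for col in ("jid", "atoms", target):
        if col not in df.columns:
            problems.append(f"missing column: {col!r}")
    if "atoms" in df.columns:
        # Each entry should be a JARVIS Atoms object exported as a dict
        bad_rows = [i for i, a in enumerate(df["atoms"]) if not isinstance(a, dict)]
        if bad_rows:
            problems.append(f"'atoms' entries that are not dicts at rows {bad_rows}")
    if target in df.columns:
        n_missing = int(df[target].isna().sum())
        if n_missing:
            problems.append(f"{n_missing} missing values in {target!r}")
    return problems

# Example: a frame without a target column and with a string instead of a dict
toy = pd.DataFrame({"jid": ["sample-1"], "atoms": ["Si2"]})
for problem in check_alignn_dataframe(toy, "formation_energy_peratom"):
    print(problem)
```

If the demo file follows the format described above, `check_alignn_dataframe(data_df, target)` should return an empty list.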

## 4. Load Trained ALIGNN Model

Load the pre-trained ALIGNN model from the checkpoint file. This model has been trained on formation energy prediction.

In [None]:
# Load model checkpoint
checkpoint = torch.load(model_path, map_location=torch.device(device), weights_only=False)
print("Checkpoint loaded successfully")

# Initialize ALIGNN model and load weights
model = ALIGNN()
model.load_state_dict(checkpoint["model"])
model.eval()
model.to(device)

print(f"Model loaded and moved to {device}")
print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")

# Store model weights dictionary for later use
model_wt_dict = OrderedDict([i for i in model.named_parameters()])


## 5. Create Data Loader and Make Predictions

Convert the dataset to the format expected by ALIGNN and create a data loader. Then make predictions to calculate errors.


In [None]:
# Convert dataframe to list format
data_list = [row.to_dict() for _, row in data_df.iterrows()]
print(f"Converted {len(data_list)} samples to data list format")

# Create data loader
data_loader = get_data_loader(data_list, target, workers=0)
print(f"Data loader created with {len(data_loader)} batches")

# Make predictions and calculate errors
predictions = []
true_values = []
errors = []

print("Making predictions...")
model.eval()
with torch.no_grad():
    for i, batch in enumerate(data_loader):
        if i % 50 == 0:  # Progress indicator
            print(f"  Processing batch {i+1}/{len(data_loader)}")
        
        s0, s1, target_batch = batch
        s0, s1, target_batch = s0.to(device), s1.to(device), target_batch.to(device)
        
        # Make prediction
        pred = model((s0, s1))
        
        # Store results
        predictions.extend(pred.cpu().numpy().flatten())
        true_values.extend(target_batch.cpu().numpy().flatten())
        
        # Calculate absolute errors
        batch_errors = torch.abs(pred.flatten() - target_batch.flatten()).cpu().numpy()
        errors.extend(batch_errors)

print(f"Predictions completed for {len(predictions)} samples")


## 6. Identify 20 Lowest Error Samples

Computing Hessian eigenvectors over the full dataset is often infeasible because of the amount of memory required, so it is best to choose a subset that approximates the full dataset. One option is to take the samples with the lowest prediction errors, i.e. the cases where the model performs best. Alternatives include random sampling and selection based on input features.
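Each of these selection strategies takes only a line or two of pandas. The sketch below is illustrative: `select_subset` is a hypothetical helper, and the `n_atoms` column in the example is an assumed input feature, not a column of the demo dataset. The feature-based option stratifies by quantile bins of the feature so that the subset spans its range.

```python
import numpy as np
import pandas as pd

def select_subset(df, n, strategy="lowest_error", error_col="absolute_error",
                  feature_col=None, n_bins=4, seed=0):
    """Pick a subset of rows to approximate the full dataset in the Hessian step."""
    if strategy == "lowest_error":
        # Samples the model predicts best
        return df.nsmallest(n, error_col)
    if strategy == "random":
        # Uniform random sample, reproducible through the seed
        return df.sample(n=n, random_state=seed)
    if strategy == "feature_stratified":
        # Quantile bins of an input feature, sampled evenly per bin; may return
        # fewer than n rows when n is not divisible by the number of bins
        bins = pd.qcut(df[feature_col], q=n_bins, labels=False, duplicates="drop")
        per_bin = max(1, n // bins.nunique())
        parts = [g.sample(n=min(per_bin, len(g)), random_state=seed)
                 for _, g in df.groupby(bins)]
        return pd.concat(parts)
    raise ValueError(f"unknown strategy: {strategy!r}")

# Toy frame with ten samples
toy = pd.DataFrame({"absolute_error": np.linspace(0.9, 0.0, 10),
                    "n_atoms": np.arange(1, 11)})
print(select_subset(toy, 3).index.tolist())
print(select_subset(toy, 4, strategy="random").index.tolist())
print(select_subset(toy, 4, strategy="feature_stratified", feature_col="n_atoms").index.tolist())
```

In this notebook, `select_subset(results_df, 20)` would reproduce the lowest-error selection made with `nsmallest` in the next cell.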

In [None]:
# Create results dataframe
results_df = data_df.copy()
results_df['predicted'] = predictions
results_df['true_value'] = true_values
results_df['absolute_error'] = errors

# Sort by error and get 20 lowest error samples
lowest_error_df = results_df.nsmallest(20, 'absolute_error')

print("Error Statistics:")
print(f"Mean error: {np.mean(errors):.4f}")
print(f"Median error: {np.median(errors):.4f}")
print(f"Min error: {np.min(errors):.4f}")
print(f"Max error: {np.max(errors):.4f}")

print(f"\n20 Lowest Error Samples:")
print(f"Error range: {lowest_error_df['absolute_error'].min():.6f} - {lowest_error_df['absolute_error'].max():.6f}")
print(f"JIDs: {list(lowest_error_df['jid'].values)}")

# Save the subset so it can be reused later
lowest_error_df.to_pickle('demo_JVDFT_dHf_dataset_lowest_20_error_samples_from_50.pkl')

# Show the 10 lowest-error samples (only the last expression in a cell is displayed)
lowest_error_df[['jid', 'true_value', 'predicted', 'absolute_error']].head(10)

## 7. Prepare Data for Hessian Computation

Create a data loader using only the 20 lowest error samples for Hessian eigenvector computation.


In [None]:
# Convert lowest error samples to list format
lowest_error_list = [row.to_dict() for _, row in lowest_error_df.iterrows()]
print(f"Selected {len(lowest_error_list)} lowest error samples for Hessian computation")

# Create data loader for Hessian computation
hessian_data_loader = get_data_loader(lowest_error_list, target, workers=0)
print(f"Hessian data loader created with {len(hessian_data_loader)} batches")

# Prepare model for Hessian computation
loss_func = torch.nn.MSELoss()
func = copy.deepcopy(model)
func.to(device)
func.eval()

print(f"Model prepared for Hessian computation on {device}")


## 8. Compute Hessian Eigenvectors

This is the computationally intensive step, where we compute the eigenvectors associated with the largest and smallest eigenvalues of the Hessian matrix.

**Note**: This computation may take several minutes depending on your hardware and model size.

You can change the number of samples in the subset to see how it affects the memory required.


In [None]:
# Get original model parameters for eigenvector conversion
og_params = [i[1] for i in func.named_parameters() if len(i[1].size()) >= 1]
og_layer_names = [i[0] for i in func.named_parameters() if len(i[1].size()) >= 1]

print(f"Model structure:")
print(f"Parameter tensors included: {len(og_params)}")
print(f"Total parameters included: {sum(p.numel() for p in og_params)}")

print(f"\nStarting Hessian eigenvector computation...")
print("This may take several minutes...")

# Compute Hessian eigenvectors
maxeig, mineig, maxeigvec, mineigvec, second_maxeig, second_maxeigvec = min_max_hessian_eigs(
    func, hessian_data_loader, loss_func, 
    all_params=False, verbose=False, use_cuda=(device=='cuda')
)

print(f"\nHessian computation completed!")


With the demo model and data, a correct run should give approximately (within ±0.1%):
- Maximum eigenvalue: 16.390605
- Minimum eigenvalue: 6.854265

In [None]:
# If the solver returned the eigenpairs in the wrong order, swap them
if maxeig < mineig:
    maxeig, mineig = mineig, maxeig
    maxeigvec, mineigvec = mineigvec, maxeigvec
    print("The assumption that the minimum eigenvalue is negative is false. Swapped the maximum and minimum eigenvalues and eigenvectors.")

print(f"Maximum eigenvalue: {maxeig:.6f}")
print(f"Minimum eigenvalue: {mineig:.6f}")

## 9. Convert Eigenvectors to Model Weights

Convert the computed eigenvectors back into the tensor format that matches the model's parameter structure.
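Conceptually, this conversion splits one flat vector (one entry per model weight) into consecutive chunks and reshapes each chunk to the shape of the corresponding parameter tensor. The sketch below shows that idea in plain PyTorch; `flat_to_tensor_list` is a hypothetical stand-in and may differ in detail from `npvec_to_tensorlist`. Its inverse is `torch.nn.utils.parameters_to_vector`, which the example uses for a round-trip check.

```python
import torch

def flat_to_tensor_list(flat, reference_params):
    """Split a flat vector into tensors shaped like reference_params, in order."""
    ref = reference_params[0]
    flat = torch.as_tensor(flat, dtype=ref.dtype, device=ref.device)
    sizes = [p.numel() for p in reference_params]
    if flat.numel() != sum(sizes):
        raise ValueError(f"vector has {flat.numel()} entries, parameters need {sum(sizes)}")
    chunks = torch.split(flat, sizes)
    return [chunk.reshape(p.shape) for chunk, p in zip(chunks, reference_params)]

# Round trip on a small layer: flatten the parameters, then split them back
layer = torch.nn.Linear(3, 2)
params = list(layer.parameters())
flat = torch.nn.utils.parameters_to_vector(params).detach().numpy()
restored = flat_to_tensor_list(flat, params)
print([tuple(t.shape) for t in restored])
```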


In [None]:
# Convert eigenvectors to model weight tensors
print("Converting eigenvectors to model weight format...")

max_model_wts = npvec_to_tensorlist(maxeigvec, og_params)
min_model_wts = npvec_to_tensorlist(mineigvec, og_params)

print(f"Max eigenvector converted to {len(max_model_wts)} weight tensors")
print(f"Min eigenvector converted to {len(min_model_wts)} weight tensors")

# Verify shapes match
print(f"\nShape verification:")
for i, (orig, max_eig, min_eig) in enumerate(zip(og_params, max_model_wts, min_model_wts)):
    if i < 3:  # Show first 3 layers only
        print(f"  Layer {i}: Original {orig.shape} == Max {max_eig.shape} == Min {min_eig.shape}")
    if orig.shape != max_eig.shape or orig.shape != min_eig.shape:
        print(f"Shape mismatch at layer {i}")
        break
else:
    print("All shapes match correctly")


## 10. Create and Save Eigenvector Models

Create new model instances with weights set to the eigenvectors and save them for later use in loss landscape generation.
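If an eigenvector is available as a single flat vector, PyTorch's `torch.nn.utils.vector_to_parameters` can copy it into a deep-copied model in one call. This is a sketch of the idea only, with `model_with_flat_weights` as a hypothetical helper: unlike `force_wts_into_model`, which works by layer name, it overwrites *every* parameter in `model.parameters()` order, so the vector must cover all parameters, including any 0-dimensional ones that `og_params` skips.

```python
import copy
import torch

def model_with_flat_weights(model, flat):
    """Deep-copy `model` and overwrite all of its parameters with the flat vector."""
    new_model = copy.deepcopy(model)
    first = next(new_model.parameters())
    vec = torch.as_tensor(flat, dtype=first.dtype, device=first.device)
    torch.nn.utils.vector_to_parameters(vec, new_model.parameters())
    return new_model

# Example: a small layer whose 8 parameters are replaced by 0, 1, ..., 7
layer = torch.nn.Linear(3, 2)
direction_model = model_with_flat_weights(layer, torch.arange(8.0))
print(direction_model.weight, direction_model.bias)
```

The original model is left untouched, and the copy's `state_dict()` can be saved with `torch.save` in the same way as the eigenvector models in the cell below.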


In [None]:
# Create copies of the original model
model_eig_max = copy.deepcopy(func)
model_eig_min = copy.deepcopy(func)

print("Loading eigenvectors into model copies...")

# Load eigenvectors into models
model_eig_max = force_wts_into_model(og_layer_names, max_model_wts, model_eig_max, model_wt_dict)
model_eig_min = force_wts_into_model(og_layer_names, min_model_wts, model_eig_min, model_wt_dict)

print("Eigenvectors loaded into models successfully")

# Save the eigenvector models
print("Saving eigenvector models...")

os.makedirs('demo_computed_eigenvectors', exist_ok=True)
torch.save(model_eig_max.state_dict(), 'demo_computed_eigenvectors/test_max_eig.pt')
torch.save(model_eig_min.state_dict(), 'demo_computed_eigenvectors/test_min_eig.pt')

print("Models saved:")
print("  - test_max_eig.pt (maximum eigenvector model)")
print("  - test_min_eig.pt (minimum eigenvector model)")

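# Also save eigenvalues for reference
eigenvalue_info = {
    'max_eigenvalue': float(maxeig),
    'min_eigenvalue': float(mineig),
    'second_max_eigenvalue': float(second_maxeig),
    'num_samples': len(lowest_error_list),
    'target': target
}
with open('demo_computed_eigenvectors/eigenvalue_info.json', 'w') as f:
    json.dump(eigenvalue_info, f, indent=2)
print("  - eigenvalue_info.json (eigenvalues and run settings)")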



# Part 2: Loss Landscape Generation

With the Hessian eigenvectors computed, we can now generate the 2D loss landscapes, using the eigenvector models we just created to explore the loss surface around the trained model.


## 12. Import Additional Libraries for Loss Landscapes

We need the `loss_landscapes` library to generate the 2D interpolation grids. The installed version must provide `batch_planar_interpolation`, which is used in Section 17.


In [None]:
import loss_landscapes
import loss_landscapes.metrics
from loss_landscapes.model_interface.model_wrapper import ModelWrapper
from abc import ABC, abstractmethod

print("Loss landscapes libraries imported successfully!")

## 13. Define Custom Loss Metric

Create a custom loss metric class that will be used to evaluate the model at different points in the loss landscape.


In [None]:
class Metric(ABC):
    """ A quantity that can be computed given a model or an agent. """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def __call__(self, model_wrapper: ModelWrapper):
        pass

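class Loss(Metric):
    """ Computes a specified loss function over specified input-output pairs. """
    def __init__(self, loss_fn, model, inputs, target: torch.Tensor):
        super().__init__()
        self.loss_fn = loss_fn
        self.inputs = inputs  # (graph, line graph) tuple for one batch
        self.model = model    # kept for reference; evaluation goes through model_wrapper
        self.target = target

    def __call__(self, model_wrapper: ModelWrapper) -> torch.Tensor:
        # Gradients are not needed to evaluate the loss at a grid point
        with torch.no_grad():
            outputs = model_wrapper.forward(self.inputs)
        # PyTorch losses take (prediction, target); with one sample per batch,
        # target[0] is that sample's label
        err = self.loss_fn(outputs, self.target[0])
        return err

def split_3d_array(array):
    """Split a (grid_x, grid_y, n) array into n arrays of shape (grid_x, grid_y, 1)."""
    return [array[:,:,i:i+1] for i in range(array.shape[2])]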

print("Custom metric classes defined successfully!")

## 14. Load Saved Eigenvector Models

Load the eigenvector models we just saved to use as the perturbation directions for our loss landscape.


In [None]:
# Load the eigenvector models we just created
print("Loading saved eigenvector models...")

# Create fresh model copies for loss landscape computation
model_eig_max_ll = copy.deepcopy(model)
model_eig_min_ll = copy.deepcopy(model)

# Load the saved eigenvector weights
model_eig_max_ll.load_state_dict(torch.load('demo_computed_eigenvectors/test_max_eig.pt', map_location=device, weights_only=True))
model_eig_min_ll.load_state_dict(torch.load('demo_computed_eigenvectors/test_min_eig.pt', map_location=device, weights_only=True))

# Move models to device and set to eval mode
model_eig_max_ll.to(device)
model_eig_min_ll.to(device)
model_eig_max_ll.eval()
model_eig_min_ll.eval()

print("Eigenvector models loaded successfully!")
print(f"All models are on {device} and in eval mode")


## 15. Create Loss Metrics for Each Sample

Create loss metric objects for all 50 samples. Each metric evaluates the loss on one batch of the data loader; with one sample per batch, as this demo assumes, that is one metric per sample. Here we use the MSE loss as the metric.

Other viable metrics include:
1. MAE
2. Huber Loss
3. LogCosh
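Any of these can replace `torch.nn.MSELoss()` as `loss_func` in the next cell. PyTorch provides `torch.nn.L1Loss` (MAE) and `torch.nn.HuberLoss`, but no built-in LogCosh loss, so the `LogCoshLoss` class below is a hypothetical implementation using the identity $\log\cosh x = |x| + \log(1 + e^{-2|x|}) - \log 2$, which avoids overflow in `cosh` for large errors. The Huber threshold `delta=1.0` is an assumed value. All four losses depend only on |prediction - target|, so they give the same result whichever argument order `Loss` passes them.

```python
import math
import torch

class LogCoshLoss(torch.nn.Module):
    """Mean log(cosh(pred - target)), computed in an overflow-safe form."""
    def forward(self, pred, target):
        a = (pred - target).abs()
        return torch.mean(a + torch.log1p(torch.exp(-2.0 * a)) - math.log(2.0))

candidate_losses = {
    "MSE": torch.nn.MSELoss(),
    "MAE": torch.nn.L1Loss(),
    "Huber": torch.nn.HuberLoss(delta=1.0),
    "LogCosh": LogCoshLoss(),
}

# Compare the losses on a toy prediction with errors 0, 1, and 3
pred = torch.tensor([0.0, 1.0, 3.0])
target = torch.zeros(3)
values = {name: fn(pred, target).item() for name, fn in candidate_losses.items()}
print(values)
```

MSE weights the error of 3 most heavily, while MAE, Huber, and LogCosh grow only linearly for large errors, which makes the resulting landscapes less dominated by poorly predicted samples.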


In [None]:
# Create one loss metric per batch of the full 50-sample dataset
print("Creating loss metrics for each sample...")

metric_list = []
sample_count = 0

loss_func = torch.nn.MSELoss()

for batch in data_loader:
    s0_device = batch[0].to(device)
    s1_device = batch[1].to(device)
    s2_device = batch[2].to(device)
    
    x_train = (s0_device, s1_device)
    y_train = s2_device
    
    # Create a loss metric for this batch
    metric_list.append(Loss(loss_func, model.eval(), x_train, y_train))
    sample_count += len(s2_device)

print(f"Created {len(metric_list)} loss metrics for {sample_count} samples")
print(f"Each metric will evaluate loss for one batch of data")


## 16. Configure Loss Landscape Parameters

Set up the parameters for generating the 2D loss landscape grid.


In [None]:
# Configure loss landscape parameters
steps = 20  # Creates a 20x20 grid (you can adjust this)
scale_factor = 1.0  # Scaling factor for eigenvector perturbations
half = False  # Set to True to skip every other computation for speed

# Total number of model evaluations (grid points x metrics)
total_evals = steps * steps * len(metric_list)
if half:
    total_evals //= 2

print(f"Loss landscape configuration:")
print(f"  Grid size: {steps} x {steps}")
print(f"  Scale factor: {scale_factor}")
print(f"  Half computation: {half}")
print(f"  Total model evaluations: {total_evals:,}")
print(f"This may take several minutes depending on your hardware...")


## 17. Generate 2D Loss Landscapes

Now we'll generate the actual loss landscapes by interpolating between the original model and the two eigenvector models.

**Note**: This may take several minutes to complete, but it is not memory-intensive.
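The idea behind planar interpolation can be sketched independently of the library. Assuming the grid is centered on the trained weights $\theta_0$ and the two eigenvector models supply direction vectors $d_1$ and $d_2$, each grid point evaluates the loss at $\theta(\alpha, \beta) = \theta_0 + \alpha d_1 + \beta d_2$ for $\alpha, \beta \in [-s, s]$, where $s$ plays the role of `scale_factor`. How `batch_planar_interpolation` actually spaces and scales its grid is not shown in this notebook, so treat the function below (`planar_grid`, hypothetical) as an illustration of the technique, applied to a toy quadratic loss with curvature 4 along $d_1$ and 1 along $d_2$.

```python
import numpy as np

def planar_grid(loss_fn, theta0, d1, d2, steps=21, scale=1.0):
    """Evaluate loss_fn on the plane theta0 + a*d1 + b*d2, a and b in [-scale, scale]."""
    coords = np.linspace(-scale, scale, steps)
    grid = np.empty((steps, steps))
    for i, a in enumerate(coords):      # rows follow d1
        for j, b in enumerate(coords):  # columns follow d2
            grid[i, j] = loss_fn(theta0 + a * d1 + b * d2)
    return coords, grid

# Toy quadratic loss with curvature 4 along e1 and 1 along e2
H = np.diag([4.0, 1.0])
loss_fn = lambda theta: 0.5 * theta @ H @ theta
theta0 = np.zeros(2)
coords, grid = planar_grid(loss_fn, theta0, np.array([1.0, 0.0]), np.array([0.0, 1.0]))
print(grid.shape, grid[10, 10], grid[20, 10], grid[10, 20])
```

At the edge of the grid the loss along $d_1$ is four times the loss along $d_2$, mirroring the larger curvature of the maximum-eigenvalue direction. An odd `steps` puts a grid point exactly on $\theta_0$ in this sketch; with an even value such as `steps = 20`, this spacing has no exact center point.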


In [None]:
# Generate loss landscapes using batch planar interpolation
print("Starting loss landscape generation...")
print("This will create a 2D grid by interpolating between:")
print("  - Original model (center)")
print("  - Maximum eigenvector model (axis 1)")
print("  - Minimum eigenvector model (axis 2)")

import time
start_time = time.time()

try:
    loss_data_fin = loss_landscapes.batch_planar_interpolation(
        model_start=model.eval(),
        model_end_one=model_eig_max_ll.eval(),
        model_end_two=model_eig_min_ll.eval(),
        metric_list=metric_list,
        steps=steps,
        deepcopy_model=True,
        scale=scale_factor,
        half=half
    )
    
    end_time = time.time()
    computation_time = end_time - start_time
    
    print(f"Loss landscape generation completed!")
    print(f"Total computation time: {computation_time:.2f} seconds")
    print(f"Result shape: {loss_data_fin.shape}")
    
except Exception as e:
    print(f"Error during loss landscape generation:")
    print(f"Error: {str(e)}")
    raise


## 18. Process and Organize Results

Split the 3D result array and create a structured DataFrame for easy analysis and visualization.


In [None]:
# Process the results
print("Processing loss landscape results...")

# Split the 3D array into individual 2D landscapes for each sample
landscapes_list = split_3d_array(loss_data_fin)

print(f"Processing results:")
print(f"  Original shape: {loss_data_fin.shape}")
print(f"  Number of individual landscapes: {len(landscapes_list)}")
print(f"  Each landscape shape: {landscapes_list[0].shape}")

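# Pairing landscapes with jids requires one landscape per sample (batch size 1)
assert len(landscapes_list) == len(data_df), "Expected one landscape per sample"

# Create a results DataFrame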
loss_landscapes_df = pd.DataFrame()
loss_landscapes_df['jid'] = data_df['jid'].values
loss_landscapes_df['raw_loss_landscapes'] = landscapes_list

print(f"Created results DataFrame with {len(loss_landscapes_df)} samples")
print(f"DataFrame columns: {list(loss_landscapes_df.columns)}")


## 19. Save Results

Save the loss landscape results in multiple formats for future analysis.


In [None]:
# Save the results
print("Saving loss landscape results...")

os.makedirs('demo_computed_landscapes', exist_ok=True)

# Save DataFrame as pickle
loss_landscapes_df.to_pickle('demo_computed_landscapes/demo_loss_landscapes_df.pkl')
print("Saved: demo_loss_landscapes_df.pkl")

# Save raw array as numpy file
np.save('demo_computed_landscapes/demo_raw_loss_landscape_array.npy', loss_data_fin)
print("Saved: demo_raw_loss_landscape_array.npy")


## 20. Visualize Sample Loss Landscapes

Let's create some sample visualizations to see what the loss landscapes look like!


In [None]:
# Create visualizations for the first few samples
print("Creating sample loss landscape visualizations...")

# Set up the plot
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

center = steps // 2

# Plot the first 6 samples
for i in range(min(6, len(landscapes_list))):
    landscape = landscapes_list[i][:, :, 0]  # Remove the singleton dimension
    jid = loss_landscapes_df.iloc[i]['jid']  

    # Plot the log loss as a heatmap
    im = axes[i].imshow(np.log(landscape), cmap='viridis', origin='lower', extent=[-center, center, -center, center])
    axes[i].set_title(f'Sample {i+1}: {jid}', fontsize=10)
    axes[i].set_xlabel('Max Eigenvector Direction')
    axes[i].set_ylabel('Min Eigenvector Direction')
    
    # Add colorbar
    cbar = plt.colorbar(im, ax=axes[i], shrink=0.8)
    cbar.set_label('log(MSE error)', fontsize=8)
    
    # Mark the center point (original model)
    axes[i].plot(0, 0, 'r*', markersize=10, label='Original Model')
    axes[i].legend()

plt.tight_layout()
plt.suptitle('Loss Landscapes for the First Six Samples', fontsize=16, y=1.02)
plt.show()


In [None]:
# Calculate the average landscape
avg_landscape = np.mean([landscape[:, :, 0] for landscape in landscapes_list], axis=0)

# Set up the plot for the average landscape
plt.figure(figsize=(8, 6))
plt.title('Average Loss Landscape', fontsize=14)
plt.xlabel(f'Max Eigenvector Direction (max eigenvalue: {maxeig:.6f})')
plt.ylabel(f'Min Eigenvector Direction (min eigenvalue: {mineig:.6f})')

# Plot the log of the average landscape as a heatmap
im = plt.imshow(np.log(avg_landscape), cmap='viridis', origin='lower', extent=[-center, center, -center, center])
cbar = plt.colorbar(im, shrink=0.8)
cbar.set_label('log(MSE error)', fontsize=10)

# Mark the center point (original model)
plt.plot(0, 0, 'r*', markersize=10, label='Original Model')
plt.legend()

plt.tight_layout()
plt.show()


print("Sample visualizations created!")
