# Demo 4: Machine Learning Analysis of Loss Landscapes

This notebook demonstrates how to apply **unsupervised machine learning techniques** to analyze loss landscape data and extract meaningful patterns. We'll explore the data structure, perform clustering, and quantify relationships between material properties and loss landscape characteristics.

## **What This Demo Covers**

This notebook teaches you how to:
1. **Load and prepare processed data** from Demo 3
2. **Explore data with interactive selectors** to choose properties and landscapes  
3. **Apply dimensionality reduction** using UMAP and PCA
4. **Perform clustering analysis** with K-Means and Spectral Clustering
5. **Visualize relationships** between material properties and clusters
6. **Quantify associations** using mutual information analysis

## **Expected Input Data**

This demo requires processed data from **Demo 3**:

### **Required Input Files:**
- `computed_loss_landscapes/demo2_automated_landscape/processed_loss_function_dict.pkl` - Processed loss landscape metrics
- `computed_loss_landscapes/demo2_automated_landscape/feat_sample_df.pkl` - Enhanced sample features
- `computed_loss_landscapes/demo2_automated_landscape/feat_sample_composition_df.pkl` - Chemical composition features
- `computed_loss_landscapes/demo2_automated_landscape/feat_sample_structure_df.pkl` - Crystal structure features

## **Key Analysis Methods**

### **1. Dimensionality Reduction**
- **UMAP (Uniform Manifold Approximation and Projection)**: Non-linear embedding that emphasizes local neighborhood structure while retaining some global structure
- **PCA (Principal Component Analysis)**: Linear dimensionality reduction with interpretable components

### **2. Clustering Algorithms**
- **K-Means Clustering**: Partitions data into k clusters based on feature similarity
- **Spectral Clustering**: Clusters on the leading eigenvectors of a graph Laplacian built from a similarity matrix, which allows non-linear cluster boundaries

### **3. Property Analysis**
- **Interactive visualization**: Explore relationships between material properties and loss landscapes
- **Mutual information**: Quantify statistical dependencies between variables

## **No Files Generated**

Unlike previous demos, this notebook focuses on **interactive analysis and visualization**. All results are displayed within the notebook and no output files are created. This makes it perfect for **exploratory data analysis** and **parameter experimentation**.

---

**Let's explore the hidden patterns in our loss landscape data!**

## 1. Setup and Data Loading

### 1.1 Import Required Libraries

We'll import all the necessary libraries for machine learning analysis, visualization, and data processing. This includes:

- **Machine Learning**: scikit-learn for clustering, dimensionality reduction, and metrics
- **Visualization**: matplotlib, seaborn for plotting 
- **Data Processing**: pandas, numpy for data manipulation
- **Dimensionality Reduction**: UMAP, PCA for data visualization
- **Interactive Widgets**: ipywidgets for parameter exploration
- **Custom Utilities**: Our local modules for specialized functions


In [None]:
import os
os.chdir('..')
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider, Checkbox, Dropdown, FloatSlider, fixed
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_kernels
from sklearn.neighbors import kneighbors_graph
from sklearn.cluster import KMeans, DBSCAN
from sklearn.preprocessing import normalize
from scipy.linalg import eigh
import umap
from util.general import flatten_and_vstack, create_selectors
from util.landscape_processing import restore_to_square_shape
from util.plot import plot_loss_landscape, visualize_image_clusters, plot_categorical_data, plot_numerical_data, plot_umap_parameter_grid, plot_umap_scatter, plot_twin_umap_scatter
from src.pca_analysis import visualize_top_pcs, plot_pairwise_pc, plot_explained_variance
from src.spectral_clustering import manual_spectral_clustering
from sklearn.metrics import mutual_info_score
from sklearn.cluster import SpectralClustering
import seaborn as sns
import warnings
from src.tilt_angle import compute_best_tilt
warnings.filterwarnings("ignore", category=FutureWarning, message="'force_all_finite' was renamed to 'ensure_all_finite'")

### 1.2 Configure Data Source

Here we specify which folder contains our processed data from Demo 3. 

**You can modify this section to:**
- Change the folder path to analyze different datasets
- Add multiple folders to combine results from different experiments
- Switch between different loss landscape analysis results

**Current Configuration:**
- Using results from `demo2_automated_landscape` (from Demo 2 + Demo 3)


In [None]:
folders = [os.path.join('computed_loss_landscapes', 'demo2_automated_landscape')]

### 1.3 Load Processed Data

This section loads all the processed data files from Demo 3's post-processing step. We load:

1. **Loss Function Metrics** (`processed_loss_function_dict.pkl`): Contains various loss landscape metrics like:
   - `loss_at_origin`: Loss value at the original model position
   - `average_loss`: Mean loss across the landscape
   - `standard_deviation_of_loss`: Loss variability measure
   - `euclidean_distance_best_to_original`: Distance to optimal point
   - Transformed versions (log, z-score, min-max normalized)

2. **Enhanced Sample Features** (`feat_sample_df.pkl`): Original sample data with additional computed features

3. **Composition Features** (`feat_sample_composition_df.pkl`): Chemical composition descriptors from matminer

4. **Structure Features** (`feat_sample_structure_df.pkl`): Crystal structure descriptors from matminer

**Note:** If you have multiple experiment folders, this code merges the loss function dictionaries from all of them. The three sample feature DataFrames are read from the first folder only.


In [None]:
loss_function_dicts = []
for folder in folders:
    with open(os.path.join(folder,'processed_loss_function_dict.pkl'), 'rb') as file:
        loss_function_dicts.append(pickle.load(file))


# Merge all dictionaries from loss_function_dicts into a single dictionary
merged_loss_function_dict = {}
for d in loss_function_dicts:
    for key, value in d.items():
        if key not in merged_loss_function_dict:
            merged_loss_function_dict[key] = value
        else:
            # If the key already exists, merge the nested dicts (later folders overwrite duplicate sub-keys)
            merged_loss_function_dict[key].update(value)

# Update the loss_function_dict to use merged version
loss_function_dict = merged_loss_function_dict

In [None]:
with open(os.path.join(folders[0],'feat_sample_df.pkl'), 'rb') as file:
    feat_sample_df = pickle.load(file)

with open(os.path.join(folders[0],'feat_sample_composition_df.pkl'), 'rb') as file:
    feat_sample_composition_df = pickle.load(file)

with open(os.path.join(folders[0],'feat_sample_structure_df.pkl'), 'rb') as file:
    feat_sample_structure_df = pickle.load(file)

sample_dict = {
    'feat_sample_df': feat_sample_df,
    'feat_sample_composition_df': feat_sample_composition_df,
    'feat_sample_structure_df': feat_sample_structure_df}

combined_dict = {**sample_dict, **loss_function_dict}

## 2. Interactive Data Selection

### 2.1 Create Interactive Selectors

This section creates interactive widgets that allow you to explore different combinations of:
- **Material properties** (composition, structure, formation energy)
- **Loss landscape metrics** (various processed metrics from Demo 3)
- **Transformed loss landscapes**

**How to use the selectors:**
1. **Property Dict Selector**: Choose which data type (sample features, composition, structure)
2. **Property Column Selector**: Pick specific properties within that data type  
3. **Loss Function Selector**: Choose which loss landscape analysis results
4. **Loss Function Column Selector**: Pick specific loss metrics

**Available property types:**
- `feat_sample_df`: Original sample properties (formation energy, bandgap, etc.)
- `feat_sample_composition_df`: Chemical composition features (elemental fractions, compound descriptors)
- `feat_sample_structure_df`: Crystal structure features (density, volume, symmetry)

**Available loss metrics:**
- Raw metrics: `loss_at_origin`, `average_loss`, `standard_deviation_of_loss`
- Derived metrics: `is_original_loss_the_lowest`, `euclidean_distance_best_to_original`
- Transformed versions: `log_*`, `z_*`, `minmax_*` variants
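
The exact transformations applied in Demo 3 are not shown in this notebook. As a rough illustration only, the sketch below assumes the `log_*`, `z_*`, and `minmax_*` variants correspond to the usual natural log, z-score, and min-max definitions; the function names are hypothetical.

```python
import numpy as np

def log_transform(values):
    """Natural log; assumes strictly positive inputs (e.g. MSE losses)."""
    return np.log(np.asarray(values, dtype=float))

def z_score(values):
    """Center to zero mean and scale to unit (population) standard deviation."""
    values = np.asarray(values, dtype=float)
    return (values - values.mean()) / values.std()

def min_max(values):
    """Rescale linearly so the smallest value maps to 0 and the largest to 1."""
    values = np.asarray(values, dtype=float)
    return (values - values.min()) / (values.max() - values.min())

# Toy loss values spanning several orders of magnitude
losses = np.array([0.01, 0.1, 1.0, 10.0])
log_losses = log_transform(losses)
z_losses = z_score(losses)
minmax_losses = min_max(losses)
```

The log transform is typically the most useful of the three when losses span several orders of magnitude, because it keeps a few very large values from dominating distance-based methods.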


In [None]:
selectors = create_selectors(sample_dict, loss_function_dict)

### Recommended Demo Selections

For this demonstration, we recommend using the following interactive widget selections to explore meaningful relationships between material properties and loss landscapes:

#### **Material Properties**
| Selection | Dataset | Property | Description |
|-----------|---------|----------|-------------|
| **Property 1** | `feat_sample_df` | `formation_energy_peratom` | Formation energy per atom (fundamental thermodynamic property) |
| **Property 2** | `feat_sample_structure_df` | `density` | Material density (structural characteristic) |

#### **Loss Landscape Data**
| Selection | Dataset | Landscape Type | Description |
|-----------|---------|----------------|-------------|
| **Landscape 1** | `demo2_automated_landscape_mse` | `log_loss_landscape_array` | Log-transformed loss landscapes |
| **Landscape 2** | `demo2_automated_landscape_mse` | `loss_landscape_array` | Raw loss landscape values |

#### **Why These Selections?**
- **Formation energy**: Directly related to material stability and synthesizability
- **Density**: Captures structural compactness and packing efficiency  
- **Raw vs Log landscapes**: Compare original vs transformed loss surface topology

**Feel free to experiment with different combinations after running through this demo!**

### 2.2 Extract Selected Data

This cell extracts the specific properties and loss landscapes based on your widget selections above. 

**What gets extracted:**
- `property_1` & `property_2`: Material properties for analysis and visualization
- `landscape_array_1` & `landscape_array_2`: Loss landscape data arrays for dimensionality reduction

**Note:** You can modify the widget selections above and re-run this cell to analyze different property-landscape combinations.


In [None]:
property_1 = combined_dict[selectors['property_dict_selector_1'].value][selectors['property_column_selector_1'].value].copy()
property_2 = combined_dict[selectors['property_dict_selector_2'].value][selectors['property_column_selector_2'].value].copy()
landscape_array_1 = loss_function_dict[selectors['loss_function_selector_1'].value][selectors['loss_function_column_selector_1'].value].copy()
landscape_array_2 = loss_function_dict[selectors['loss_function_selector_2'].value][selectors['loss_function_column_selector_2'].value].copy()

## 3. Dimensionality Reduction with UMAP

### 3.1 UMAP Parameter Grid Search

**UMAP (Uniform Manifold Approximation and Projection)** is a non-linear dimensionality reduction technique that emphasizes local neighborhood structure while retaining some of the global data structure.

**Key UMAP Parameters to explore:**

1. **`n_neighbors` (2-50)**: 
   - **Low values (2-5)**: Focus on local structure, more fragmented clusters
   - **High values (30-50)**: Emphasize global structure, smoother embeddings

2. **`min_dist` (0.0-1.0)**:
   - **Low values (0.0-0.1)**: Tighter clusters, points can be very close
   - **High values (0.5-1.0)**: More spread out, prevents overlapping

**What this visualization shows:**
- Grid of UMAP embeddings with different parameter combinations
- Helps you identify which parameters reveal the most interesting structure in your loss landscape data

**How to interpret:**
- Look for parameter combinations that show clear clustering patterns
- Consider how the structure relates to your domain knowledge of the materials

**Parameter lists you can modify:**
```python
n_neighbors_list = [2, 5, 15, 20, 30, 50]      # Try different neighbor counts
min_dist_list = [0, 0.01, 0.05, 0.1, 0.5, 1.0] # Try different minimum distances
```


In [None]:
# Extract the selected loss landscape array from the dictionary using the chosen selectors
landscape_array_1 = loss_function_dict[selectors['loss_function_selector_1'].value][selectors['loss_function_column_selector_1'].value].copy()

# Define a list of possible values for the number of neighbors parameter in UMAP
n_neighbors_list = [2, 5, 15, 20, 30, 50]

# Define a list of possible values for the minimum distance parameter in UMAP
min_dist_list = [0, 0.01, 0.05, 0.1, 0.5, 1.0]

# Plot the UMAP parameter grid using the defined hyperparameter lists
plot_umap_parameter_grid(landscape_array_1, n_neighbors_list, min_dist_list)


### 3.2 Apply Optimal UMAP Parameters

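Based on the parameter grid above, we pick one parameter combination and create the final embedding. The values used here are also UMAP's defaults, so treat them as a starting point rather than a tuned optimum.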

**Current Parameters:**
- `n_neighbors=15`: Balanced focus on local and global structure  
- `min_dist=0.1`: Allows tight clusters while preventing excessive overlap

**You can modify these parameters:**
```python
n_neighbors, min_dist = 15, 0.1  # Experiment with different values!
```

**Parameter suggestions (heuristic starting points, not guarantees):**
- For **materials discovery**: Try `n_neighbors=30, min_dist=0.0` (emphasize global patterns)
- For **anomaly detection**: Try `n_neighbors=5, min_dist=0.5` (emphasize local detail, which can make isolated points easier to spot)
- For **cluster analysis**: Try `n_neighbors=15, min_dist=0.1` (balanced view)

**Output:**
- 2D UMAP embedding stored in the variable `umap_transformed_data` (no file is written)
- Scatter plot showing the embedding structure
- Each point represents one material sample's loss landscape
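
UMAP expects one row of features per sample, which is why the landscapes are passed through `flatten_and_vstack` before fitting. That helper lives in `util.general` and its source is not shown here. The sketch below is only an assumption about what such a helper plausibly does, under the premise that every landscape is a 2D array of the same shape; `flatten_and_vstack_sketch` is a hypothetical name.

```python
import numpy as np

def flatten_and_vstack_sketch(landscapes):
    """Stack equally shaped 2D landscapes into an (n_samples, n_pixels) matrix."""
    return np.vstack([np.asarray(arr).ravel() for arr in landscapes])

# Three toy 4x4 "landscapes" become three rows of 16 features each
toy_landscapes = [np.full((4, 4), fill_value=i, dtype=float) for i in range(3)]
feature_matrix = flatten_and_vstack_sketch(toy_landscapes)
```

Each grid point of a landscape becomes one feature, so distances between rows compare landscapes point by point.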


In [None]:

# Set UMAP hyperparameters for dimensionality reduction
n_neighbors, min_dist = 15, 0.1

# Initialize UMAP with the specified hyperparameters and a fixed random state for reproducibility
umap_reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=2, random_state=42)

# Apply UMAP to fit and transform the flattened landscape array data
umap_transformed_data = umap_reducer.fit_transform(flatten_and_vstack(landscape_array_1))

# Visualize the UMAP-transformed data using a scatter plot
plot_umap_scatter(umap_transformed_data, n_neighbors=n_neighbors, min_dist=min_dist, landscape_array=landscape_array_1)

## 4. Loss Landscape Tilt Analysis

### 4.1 Compute Tilt Angles

As you saw at the end of Demo 1, most loss landscapes appear as straight lines with different slopes.

This section analyzes the **directionality** in loss landscapes by computing the tilt angle of the loss ridge.

**What this analysis reveals:**
- **Tilt angle**: Orientation of the line traced by the loss minima (the loss ridge)
- **Relationship to material properties**: How loss landscape orientation correlates with formation energy

**Key function: `compute_best_tilt()`**
- Fits lines to loss minima in both row and column directions
- Selects the fit with higher R² value
- Returns the tilt angle in the specified convention

**Parameters you can modify:**
- `degrees=True`: Return angle in degrees (vs radians)
- `angle_mode='unsigned180'`: Angle range (0°-180°) vs `'signed90'` (-90°-90°)

**Interpretation:**
- **0° or 180°**: Vertical loss ridge
- **90°**: Horizontal loss ridge  
- **45°**: Diagonal loss ridge

The histogram and scatter plot help identify if certain materials have preferred loss landscape orientations.
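
The source of `compute_best_tilt()` is not shown in this notebook. The sketch below is a simplified reading of the three steps listed above (fit a line to the per-row minima, fit a line to the per-column minima, keep the fit with the higher R²). The function name `tilt_angle_sketch`, the angle convention (measured from the vertical axis, chosen to agree with the interpretation above), and the handling of degenerate fits are all assumptions, so its output may differ from the real function.

```python
import numpy as np

def _fit_line(x, y):
    """Least-squares fit y = slope * x + intercept; returns (slope, r_squared)."""
    slope, intercept = np.polyfit(x, y, deg=1)
    residuals = y - (slope * x + intercept)
    ss_res = float(np.sum(residuals ** 2))
    ss_tot = float(np.sum((y - y.mean()) ** 2))
    # A constant y has no variance to explain; report 0 in that case.
    r_squared = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0
    return slope, r_squared

def tilt_angle_sketch(landscape):
    """Return (angle in degrees within [0, 180), R^2 of the better fit).

    Limitation: for exactly axis-aligned ridges one of the two fits is
    degenerate, and this sketch does not treat that case specially.
    """
    landscape = np.asarray(landscape, dtype=float)
    rows = np.arange(landscape.shape[0], dtype=float)
    cols = np.arange(landscape.shape[1], dtype=float)
    # Row direction: column of the minimum in each row, fitted as col = a * row + b
    a, r2_rows = _fit_line(rows, landscape.argmin(axis=1).astype(float))
    # Column direction: row of the minimum in each column, fitted as row = c * col + d
    c, r2_cols = _fit_line(cols, landscape.argmin(axis=0).astype(float))
    if r2_rows >= r2_cols:
        d_col, d_row = a, 1.0   # moving one row shifts the minimum by a columns
    else:
        d_col, d_row = 1.0, c   # moving one column shifts the minimum by c rows
    # Angle between the ridge direction and the vertical (row) axis
    angle = float(np.degrees(np.arctan2(d_col, d_row)) % 180.0)
    return angle, max(r2_rows, r2_cols)

# Toy landscape whose valley follows col = 2 * row
row_idx, col_idx = np.indices((11, 21))
toy_landscape = (col_idx - 2 * row_idx) ** 2
toy_angle, toy_r2 = tilt_angle_sketch(toy_landscape)
```

Fitting in both directions matters because a steep ridge is poorly described as a function of one axis but well described as a function of the other; the R² comparison picks whichever parameterization fits better.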


In [None]:
# Compute the tilt angle of every selected loss landscape
angles = []
for landscape in landscape_array_1.to_list():
    angle, _ = compute_best_tilt(landscape, degrees=True, angle_mode='unsigned180')
    angles.append(angle)

# Plot histogram of angles
plt.hist(angles, bins=30, alpha=0.5, color='blue')
plt.title('Histogram of Tilt Angles')
plt.xlabel('Tilt Angle (degrees)')
plt.ylabel('Frequency')
plt.show()

# Plot formation energy per atom as a function of tilt angle
true_vals = combined_dict['feat_sample_df']['formation_energy_peratom']

plt.figure(figsize=(12,4))
plt.scatter(angles, true_vals, alpha=0.5, color='royalblue', edgecolor='k')
plt.xlabel('Tilt Angle (degrees)')
plt.ylabel('Formation Energy per Atom (eV/atom)')
plt.title('Formation Energy per Atom vs. Tilt Angle')
plt.tight_layout()
plt.show()

## 5. Principal Component Analysis (PCA)

### 5.1 Linear Dimensionality Reduction

**PCA** provides a linear approach to dimensionality reduction that's complementary to UMAP's non-linear approach.

**What PCA reveals:**
- **Principal components**: Linear combinations of loss landscape features that capture maximum variance
- **Explained variance**: How much information each component retains
- **Component interpretation**: Which aspects of loss landscapes drive the main variations

**Key parameters to explore:**

1. **`n_components`**: 
   - `None` (current): Retain all components for full analysis
   - Integer: Specify exact number of components to keep
   - Float (0.0-1.0): Retain components explaining this fraction of variance

2. **Interactive sliders:**
   - `n_pcs_to_visualize`: How many PC pairs to plot (1-10)
   - `n_pcs_to_analyze`: How many top PCs to examine in detail (2-10)

**Outputs:**
- **Explained variance plot**: Shows information content of each PC
- **Pairwise PC plots**: 2D scatter plots of different PC combinations  
- **PC visualization**: Heatmaps showing what each PC represents

**Comparison with UMAP:**
- **PCA**: Reveals linear relationships, interpretable components
- **UMAP**: Captures non-linear structure, often better at revealing cluster structure visually
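
To make "explained variance" concrete, the sketch below computes the explained variance ratio from the singular values of the centered data and then finds how many components are needed to reach a target fraction. It is a NumPy illustration with hypothetical function names and synthetic data, not the notebook's `plot_explained_variance` helper. The selection rule is the plain "smallest count whose cumulative ratio reaches the threshold" and may differ in edge cases from scikit-learn's handling of a float `n_components`.

```python
import numpy as np

def explained_variance_ratio(X):
    """Fraction of total variance captured by each principal component."""
    X = np.asarray(X, dtype=float)
    centered = X - X.mean(axis=0)
    singular_values = np.linalg.svd(centered, compute_uv=False)
    variances = singular_values ** 2 / (X.shape[0] - 1)
    return variances / variances.sum()

def components_for_variance(ratios, threshold):
    """Smallest number of components whose cumulative ratio reaches threshold."""
    cumulative = np.cumsum(ratios)
    count = int(np.searchsorted(cumulative, threshold)) + 1
    return min(count, len(ratios))

# Synthetic data: three features with very different spreads
rng = np.random.default_rng(0)
toy_data = rng.normal(size=(50, 3)) * np.array([5.0, 1.0, 0.1])
ratios = explained_variance_ratio(toy_data)
n_needed = components_for_variance(ratios, 0.9)
```

If the first few ratios already sum to nearly 1, the landscapes vary along only a handful of directions, and the pairwise PC plots of those components show most of the structure.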


In [None]:
# Retrieve and flatten the selected loss landscape array for PCA analysis
flattened_landscape_array_1 = flatten_and_vstack(landscape_array_1)

# Keep all principal components (n_components=None); pass a float such as 0.98 to retain 98% of variance instead
n_components = None
pca = PCA(n_components=n_components)
# Transform the data using PCA
pca_transformed_data = pca.fit_transform(flattened_landscape_array_1)

# Plot the explained variance to understand the contribution of each principal component
plot_explained_variance(pca)

# Define an interactive function for PCA analysis and visualization
def interactive_pca_analysis(n_pcs_to_visualize, n_pcs_to_analyze):
    # Plot pairwise principal components for visualization
    plot_pairwise_pc(pca_transformed_data, n_pcs_to_visualize)
    # Visualize the top principal components
    visualize_top_pcs(pca, n_pcs_to_analyze)

# Create interactive sliders for selecting the number of principal components to visualize and analyze
interact(interactive_pca_analysis, n_pcs_to_visualize=IntSlider(min=1, max=10, step=1, value=3), n_pcs_to_analyze=IntSlider(min=2, max=10, step=1, value=2))


## 6. K-Means Clustering Analysis

### 6.1 Determine Optimal Number of Clusters

**K-Means clustering** partitions data into k clusters by minimizing within-cluster variance.

**The Elbow Method:**
- **Inertia**: Sum of squared distances from points to their cluster centers
- **Elbow point**: Where adding more clusters provides diminishing returns
- **Optimal k**: Usually where the curve "bends" most sharply
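
As a small illustration of the inertia definition above, the sketch below computes it directly from a set of points and their cluster labels, using each cluster's centroid as its center. The function name and toy points are made up for this example; for a converged K-Means fit this quantity corresponds to `kmeans.inertia_`.

```python
import numpy as np

def inertia(X, labels):
    """Sum of squared distances from each point to its cluster centroid."""
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    total = 0.0
    for cluster in np.unique(labels):
        members = X[labels == cluster]
        centroid = members.mean(axis=0)
        total += float(np.sum((members - centroid) ** 2))
    return total

# Two tight pairs of points that are far apart from each other
points = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]])
one_cluster = inertia(points, [0, 0, 0, 0])    # 104.0
two_clusters = inertia(points, [0, 0, 1, 1])   # 4.0
```

The large drop from one cluster to two, followed by only small gains for further splits, is the "elbow" pattern to look for in the plot.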

**Parameters you can modify:**
```python
k_values = range(2, 11)  # Try different ranges like range(2, 20)
```


In [None]:
# Retrieve and copy the selected loss landscape array for K-means clustering
landscape_array_1 = loss_function_dict[selectors['loss_function_selector_1'].value][selectors['loss_function_column_selector_1'].value].copy()

# Flatten the landscapes once; the same matrix is reused for every k
flattened_landscape_array_1 = flatten_and_vstack(landscape_array_1)

# Initialize a list to store inertia values for different numbers of clusters
inertias = []
# Define the range of k values (number of clusters) to evaluate
k_values = range(2, 11)

# Iterate over each k value to perform K-means clustering
for k in k_values:
    # Initialize K-means with the current number of clusters and a fixed random state for reproducibility
    kmeans = KMeans(n_clusters=k, random_state=42)
    # Fit K-means to the flattened landscape array data
    kmeans.fit(flattened_landscape_array_1)
    # Append the inertia (sum of squared distances to the nearest cluster center) to the list
    inertias.append(kmeans.inertia_)

# Plot the elbow curve to visualize the inertia for each k value
plt.figure(figsize=(5, 4))
plt.plot(k_values, inertias, 'bo-')  # Plot inertia values as a line with blue circle markers
plt.xlabel('Number of Clusters (k)')  # Label the x-axis
plt.ylabel('Inertia')  # Label the y-axis
plt.title('Elbow Plot for K-means Clustering')  # Set the title of the plot
plt.grid(True)  # Enable grid for better readability
plt.show()  # Display the plot

### 6.2 Interactive K-Means Visualization

This interactive section lets you explore different numbers of clusters and see how they appear in both UMAP and PCA space.

**Interactive slider: `n_k_means_clusters`**
- **Range**: 2-20 clusters
- **Default**: 3 clusters (good starting point)

**Dual visualization:**
1. **UMAP scatter plot**: Shows clusters in non-linear reduced space
   - Colors represent different cluster assignments
   - Points that are close in UMAP space tend to have similar loss landscapes

2. **PCA pairwise plots**: Shows clusters in linear reduced space
   - Multiple 2D projections of the first 4 principal components
   - Reveals linear separability of clusters

**How to use this interactively:**
1. Start with k=3 and observe cluster patterns
2. Increase k gradually and watch how clusters split
3. Compare UMAP vs PCA representations
4. Look for consistent cluster boundaries across both methods

**Good clusters typically:**
- Are well-separated in both UMAP and PCA space
- Have roughly balanced sizes
- Make physical/chemical sense for your materials


In [None]:
warnings.filterwarnings('ignore', category=UserWarning, module='sklearn.cluster._kmeans')

def run_kmeans_and_plot(n_k_means_clusters):
    # Initialize and fit K-means clustering with the specified number of clusters
    kmeans = KMeans(n_clusters=n_k_means_clusters, random_state=42)
    kmeans.fit(flatten_and_vstack(landscape_array_1))
    
    # Retrieve the cluster labels from the fitted K-means model
    k_means_labels = kmeans.labels_
    
    # Plot the UMAP scatter plot with the K-means cluster labels
    plot_umap_scatter(
        umap_transformed_data, 
        labels=k_means_labels, 
        label_name='K-means Clusters', 
        label_type='categorical', 
        n_neighbors=n_neighbors, 
        min_dist=min_dist, 
        landscape_array=landscape_array_1
    )
    
    # Plot pairwise principal components with the K-means cluster labels
    plot_pairwise_pc(
        pca_transformed_data, 
        4, 
        labels=k_means_labels, 
        label_name='K-means Clusters', 
        label_type='continuous'
    )

# Create an interactive slider to select the number of clusters for K-means
interact(run_kmeans_and_plot, n_k_means_clusters=IntSlider(min=2, max=20, step=1, value=3))

### 6.3 Visualize Cluster Representatives

This section provides a detailed look at the actual loss landscapes within each cluster.

**Fixed clustering**: Uses k=2 for clear comparison between two main groups

**What you'll see:**
- **UMAP plot**: Shows the 2-cluster assignment in embedding space
- **Loss landscape grid**: Shows up to 20 example loss landscapes from each cluster

**Parameters you can modify:**
```python
kmeans = KMeans(n_clusters=2, random_state=42)  # Try different k values
max_arrays_per_cluster=20  # Show more/fewer examples per cluster
```

**How to interpret the landscape visualization:**
- **Cluster consistency**: Do landscapes within a cluster look similar?
- **Cluster differences**: How do the landscape patterns differ between clusters?
- **Outliers**: Are there landscapes that don't fit their assigned cluster well?

**This analysis helps validate that:**
- Clusters correspond to meaningful differences in loss landscape structure
- The clustering algorithm is capturing interpretable patterns
- The number of clusters is appropriate for your data


In [None]:
landscape_array_1 = loss_function_dict[selectors['loss_function_selector_1'].value][selectors['loss_function_column_selector_1'].value].copy()

# Fit a fixed 2-cluster K-means model and keep k_means_labels for the comparison in Section 7.3
kmeans = KMeans(n_clusters=2, random_state=42)
kmeans.fit(flatten_and_vstack(landscape_array_1))
k_means_labels = kmeans.labels_

plot_umap_scatter(umap_transformed_data, labels=k_means_labels, label_name='K-means Clusters', label_type='continuous', n_neighbors=n_neighbors, min_dist=min_dist, landscape_array=landscape_array_1)

visualize_image_clusters(landscape_array_1, k_means_labels,max_arrays_per_cluster=20)

## 7. Spectral Clustering Analysis

### 7.1 Interactive Spectral Clustering

**Spectral clustering** embeds the data using eigenvectors of a graph Laplacian built from a similarity (affinity) matrix and then clusters in that embedding, making it effective for non-convex cluster shapes that K-means might miss.

**Key parameters to explore:**

1. **`affinity_type`**:
   - **`'nearest_neighbors'`**: Uses k-nearest neighbor graph (good for local structure)
   - **`'rbf'`**: Uses radial basis function (good for smooth density variations)

2. **`affinity_param`** (one slider, 0.1-50, shared by both affinity types):
   - For `'nearest_neighbors'`: Number of neighbors (the slider value is truncated to an integer)
   - For `'rbf'`: Gamma parameter (higher = more localized)

3. **`n_clusters`**: Number of clusters to find (2-10)

**Interactive options:**
- **`visualize_affinity_scree`**: Shows the affinity matrix heatmap
- **`visualize_in_2d`**: Displays clusters in UMAP and PCA space
- **`visualize_image`**: Shows example loss landscapes per cluster

**How to use this section:**
1. Start with `affinity_type='nearest_neighbors'` and `affinity_param=13`
2. Try different numbers of neighbors and observe cluster changes
3. Compare with `affinity_type='rbf'` to see different cluster shapes
4. Enable visualization options to understand cluster quality

**Spectral vs K-means:**
- **Spectral**: Can find non-spherical clusters, considers global structure
- **K-means**: Assumes spherical clusters, faster computation
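
To show what happens inside spectral clustering, the sketch below builds an RBF affinity matrix, forms the symmetric normalized graph Laplacian, and splits the data by the sign of the second-smallest eigenvector (the Fiedler vector). It is a two-cluster illustration in plain NumPy with a hypothetical function name. It is neither the notebook's `manual_spectral_clustering` helper nor scikit-learn's implementation, which by default runs K-means on several eigenvectors instead of a sign threshold.

```python
import numpy as np

def spectral_bipartition(X, gamma=1.0):
    """Split points into two groups using the Fiedler vector of an RBF graph."""
    X = np.asarray(X, dtype=float)
    # Pairwise squared Euclidean distances and RBF affinities
    sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    affinity = np.exp(-gamma * sq_dists)
    # Symmetric normalized Laplacian: I - D^(-1/2) W D^(-1/2)
    degree = affinity.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(degree)
    laplacian = np.eye(len(X)) - d_inv_sqrt[:, None] * affinity * d_inv_sqrt[None, :]
    # eigh returns eigenvalues in ascending order; column 1 is the Fiedler vector
    _, eigenvectors = np.linalg.eigh(laplacian)
    fiedler = eigenvectors[:, 1]
    return (fiedler > 0).astype(int)

# Two small squares of points, offset by 3 units along both axes
square = np.array([[0.0, 0.0], [0.5, 0.0], [0.0, 0.5], [0.5, 0.5]])
toy_points = np.vstack([square, square + 3.0])
toy_labels = spectral_bipartition(toy_points, gamma=1.0)
```

Which group receives label 0 or 1 is arbitrary, since an eigenvector's overall sign is not determined. A larger `gamma` makes affinities fall off faster with distance, which is the "more localized" behavior described above.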


In [None]:
def visualized_spectral_clustering(data, data_visualize, affinity_type, affinity_param, visualize_affinity_scree, n_clusters, visualize_in_2d, visualize_image):
    # One slider value doubles as gamma (used by 'rbf') and n_neighbors (used by 'nearest_neighbors').
    # n_neighbors must be an integer >= 1, so clamp it for slider values below 1.
    spectral_clustering = SpectralClustering(n_clusters=n_clusters, gamma=affinity_param, affinity=affinity_type, n_neighbors=max(1, int(affinity_param)))
    label = spectral_clustering.fit_predict(data)

    if visualize_affinity_scree:
        # The affinity matrix is sparse for 'nearest_neighbors' but a dense array for 'rbf'
        affinity_matrix = spectral_clustering.affinity_matrix_
        if hasattr(affinity_matrix, 'toarray'):
            affinity_matrix = affinity_matrix.toarray()
        plt.figure(figsize=(6,5))
        sns.heatmap(affinity_matrix, cmap='viridis')
        plt.title('Affinity Matrix Heatmap')
        plt.show()

    # Check if 2D visualization is enabled
    if visualize_in_2d:
        # Plot UMAP scatter plot with spectral cluster labels
        plot_umap_scatter(umap_transformed_data, labels=label, label_name='Spectral Clusters', label_type='continuous', n_neighbors=n_neighbors, min_dist=min_dist, landscape_array=landscape_array_1)
        # Plot pairwise principal components with spectral cluster labels
        plot_pairwise_pc(pca_transformed_data, 3, labels=label, label_name='Spectral Clusters', label_type='continuous')

    # Check if image visualization is enabled
    if visualize_image:
        # Visualize image clusters based on spectral clustering labels
        visualize_image_clusters(data_visualize, label,max_arrays_per_cluster=30)

    # Print a message only when every visualization option is switched off
    if not (visualize_affinity_scree or visualize_in_2d or visualize_image):
        print("Visualization is disabled. Enable visualization options to view results.")

# Copy the selected loss function arrays for further processing
landscape_array_1 = loss_function_dict[selectors['loss_function_selector_1'].value][selectors['loss_function_column_selector_1'].value].copy()

# Create an interactive widget for spectral clustering visualization
interact(visualized_spectral_clustering,
        data=fixed(flatten_and_vstack(landscape_array_1)),  # Flatten and stack the first landscape array
        data_visualize=fixed(landscape_array_2),  # Use the second landscape array for visualization
        affinity_type=Dropdown(options=['rbf', 'nearest_neighbors'], value='nearest_neighbors', description='Affinity Type'),  # Dropdown for selecting affinity type
        affinity_param=FloatSlider(min=0.1, max=50, step=0.1, value=13, description='γ/Neighbors'),  # Slider for gamma (rbf) or number of neighbors
        visualize_affinity_scree=Checkbox(value=False, description='Affinity Plot'),  # Checkbox for the affinity matrix heatmap
        n_clusters=IntSlider(min=2, max=10, step=1, value=3, description='Clusters'),  # Slider for number of clusters
        visualize_in_2d=Checkbox(value=False, description='Visualize in 2D'),  # Checkbox for 2D visualization
        visualize_image=Checkbox(value=False, description='Visualize Image'))  # Checkbox for image visualization


### 7.2 Save Spectral Clustering Results

This cell creates a fixed spectral clustering result for comparison with K-means.

**Fixed parameters:**
- `n_clusters=2`: Two-cluster solution, matching the K-means result from Section 6.3
- `affinity='nearest_neighbors'`: k-NN based similarity
- `n_neighbors=15`: Moderate neighborhood size

**You can modify these parameters:**
```python
spectral_clustering = SpectralClustering(
    n_clusters=2,           # Try 2-6 clusters
    affinity='nearest_neighbors',  # Or 'rbf'
    n_neighbors=15          # Try 5-30 neighbors
)
```


In [None]:
landscape_array_1 = loss_function_dict[selectors['loss_function_selector_1'].value][selectors['loss_function_column_selector_1'].value].copy()

spectral_clustering = SpectralClustering(n_clusters=2, affinity='nearest_neighbors', n_neighbors=15)
spectral_labels = spectral_clustering.fit_predict(flatten_and_vstack(landscape_array_1))

### 7.3 Compare K-means vs Spectral Clustering

This visualization shows both clustering methods side-by-side to compare their results.

**Twin UMAP plot shows:**
- **Left panel**: K-means cluster assignments
- **Right panel**: Spectral clustering assignments
- **Same UMAP embedding**: Both panels use the identical dimensionality reduction

**What to look for:**
- **Agreement**: Do both methods identify similar cluster boundaries?
- **Differences**: Where do they disagree and why?
- **Cluster shapes**: Does spectral clustering find non-spherical patterns that K-means misses?
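
Agreement can also be checked numerically. The sketch below builds a contingency table between two label arrays and reports the fraction of samples that fall in the best-matching cluster pair, which is insensitive to how the clusters happen to be numbered. The function names and toy labels are hypothetical; scikit-learn's `adjusted_rand_score` is a more standard, chance-corrected alternative.

```python
import numpy as np

def contingency_table(labels_a, labels_b):
    """Count how many samples fall in each (cluster_a, cluster_b) pair."""
    _, a_idx = np.unique(np.asarray(labels_a), return_inverse=True)
    _, b_idx = np.unique(np.asarray(labels_b), return_inverse=True)
    table = np.zeros((a_idx.max() + 1, b_idx.max() + 1), dtype=int)
    np.add.at(table, (a_idx, b_idx), 1)
    return table

def matched_fraction(table):
    """Fraction of samples in the largest overlap of each row's cluster."""
    return table.max(axis=1).sum() / table.sum()

# Toy labels: the two methods disagree on one of six samples,
# and their cluster numbering is swapped
toy_kmeans = [0, 0, 0, 1, 1, 1]
toy_spectral = [1, 1, 0, 0, 0, 0]
toy_table = contingency_table(toy_kmeans, toy_spectral)
toy_agreement = matched_fraction(toy_table)
```

A value close to 1 means the two methods found essentially the same partition. Values near the share of the largest cluster suggest the overlap is little better than chance.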

**Parameters you can modify:**
```python
plot_twin_umap_scatter(
    umap_data=umap_transformed_data, 
    n_neighbors=n_neighbors,  # UMAP neighbors shown in the plot; should match the embedding
    min_dist=min_dist,        # UMAP min_dist shown in the plot; should match the embedding
    labels1=k_means_labels, 
    labels2=spectral_labels,
    # ... other parameters
)
```


In [None]:
plot_twin_umap_scatter(umap_data=umap_transformed_data, n_neighbors=n_neighbors, min_dist=min_dist, labels1=k_means_labels, labels2=spectral_labels, label_name1='K-Means Clusters', label_name2='Spectral Clusters', label_type='continuous', landscape_array=landscape_array_1)

## 8. DBSCAN Clustering (Optional)

### 8.1 Density-Based Clustering Parameter Grid

**DBSCAN (Density-Based Spatial Clustering of Applications with Noise)** identifies clusters based on point density, automatically determining the number of clusters and identifying outliers.

**Key DBSCAN parameters:**

1. **`eps`**: Maximum distance for two points to be considered neighbors
   - **Small values**: Tighter groups, more clusters, and more points labeled as noise
   - **Large values**: Fewer, larger clusters and fewer noise points

2. **`min_samples`**: Minimum number of points (including the point itself) within `eps` for a point to count as a core point
   - **Small values**: More sensitive to local density variations
   - **Large values**: Requires stronger evidence for cluster formation

**Parameter grid exploration:**
- **`eps_values`**: Logarithmic scale from ~6.3 to 31.6
- **`min_samples_values`**: Linear scale from 1 to 7

**How to interpret the grid:**
- **Colors**: Different clusters (noise points usually in dark)
- **Good parameters**: Clear cluster separation with minimal noise
- **Too restrictive**: Most points labeled as noise (dark)
- **Too permissive**: Everything in one large cluster
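
The effect of the two parameters is easiest to see from the core-point rule itself. The sketch below is not a full DBSCAN implementation: it only marks which points are core points and which would be labeled noise, following the definition that a point is core when at least `min_samples` points (itself included) lie within `eps`. The function name and toy points are hypothetical.

```python
import numpy as np

def core_and_noise_masks(X, eps, min_samples):
    """Return boolean masks (is_core, is_noise) under the DBSCAN definitions."""
    X = np.asarray(X, dtype=float)
    dists = np.sqrt(np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1))
    within_eps = dists <= eps
    # Core point: enough neighbors within eps, counting the point itself
    is_core = within_eps.sum(axis=1) >= min_samples
    # Noise: neither a core point nor within eps of any core point
    near_core = (within_eps & is_core[None, :]).any(axis=1)
    return is_core, ~near_core

# Three nearby points plus one distant outlier
toy_points = np.array([[0.0, 0.0], [0.5, 0.0], [0.0, 0.5], [10.0, 10.0]])
core_mid, noise_mid = core_and_noise_masks(toy_points, eps=1.0, min_samples=3)
core_small, noise_small = core_and_noise_masks(toy_points, eps=0.1, min_samples=3)
core_large, noise_large = core_and_noise_masks(toy_points, eps=20.0, min_samples=3)
```

With `eps=1.0` only the outlier is noise; shrinking `eps` to 0.1 turns every point into noise, while `eps=20.0` puts all four points within reach of each other so none is noise.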

In [None]:
# Copy the selected loss function array for DBSCAN analysis
landscape_array_1 = loss_function_dict[selectors['loss_function_selector_1'].value][selectors['loss_function_column_selector_1'].value].copy()

# Define a range of eps values using a logarithmic scale
eps_values = np.logspace(0.8, 1.5, num=8)  # Log scale from approximately 6.3 to 31.6

# Define a range of min_samples values using a linear scale
min_samples_values = np.arange(1, 8)  # Linear scale from 1 to 7

# Create a grid of subplots with dimensions based on eps and min_samples values
fig, axs = plt.subplots(len(eps_values), len(min_samples_values), figsize=(12, 12))
fig.suptitle('DBSCAN Clusters Shown on the UMAP Embedding', fontsize=16)  # DBSCAN runs on the flattened landscapes; UMAP is used only for display

# Iterate over each combination of eps and min_samples to perform DBSCAN
for i, eps in enumerate(eps_values):
    for j, min_samples in enumerate(min_samples_values):
        # Initialize DBSCAN with the current eps and min_samples parameters
        dbscan = DBSCAN(eps=eps, min_samples=min_samples)
        
        # Fit DBSCAN on the flattened landscape data (not the UMAP embedding); label -1 marks noise
        dbscan_labels = dbscan.fit_predict(flatten_and_vstack(landscape_array_1))
        
        # Plot the DBSCAN clustering result in the corresponding subplot
        ax = axs[i, j]
        scatter = ax.scatter(umap_transformed_data[:, 0], umap_transformed_data[:, 1], c=dbscan_labels, cmap='viridis', s=5)
        ax.set_ylabel(f'eps={eps:.1f}', fontsize=8)  # Label the y-axis with the current eps value (one decimal)
        ax.set_xlabel(f'min_samples={min_samples}', fontsize=8)  # Label the x-axis with the current min_samples value
        ax.label_outer()  # Hide x and y labels for inner plots to reduce clutter

# Add a figure-level legend; it is built from the last subplot only, so other panels may contain cluster IDs it does not list
fig.legend(*scatter.legend_elements(), title="Clusters", loc='upper right', bbox_to_anchor=(1.1, 1))

# Adjust the layout to ensure the main title and subplots are properly spaced
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()  # Display the figure

## 9. Property-Cluster Relationship Analysis

### 9.1 Compare Material Properties with Clustering Results

This final analysis explores whether the clustering patterns discovered in loss landscapes correlate with material properties.

**Analysis setup:**
- **`n_clusters=4`**: Use 4-cluster solutions for comparison
- **`n_bins=4`**: Bin continuous properties into 4 categories (if needed)
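
Binning turns a continuous property into discrete labels that can be compared with cluster labels. How `plot_numerical_data` bins internally is not shown here, so the sketch below is only an illustration of two common strategies in pandas; `property_values` is a hypothetical random series.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# Hypothetical continuous property (assumption: 100 normally distributed values)
property_values = pd.Series(rng.normal(size=100), name="hypothetical_property")

n_bins = 4

# Equal-width bins: same value range per bin, counts can be very uneven
equal_width = pd.cut(property_values, bins=n_bins, labels=False)

# Equal-count (quantile) bins: roughly the same number of samples per bin
equal_count = pd.qcut(property_values, q=n_bins, labels=False)

print("Equal-width bin counts:")
print(equal_width.value_counts().sort_index())
print("Equal-count bin counts:")
print(equal_count.value_counts().sort_index())
```

The choice matters for the later comparison: very unbalanced bins carry less entropy, which lowers the largest mutual information the property can reach.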

**What this section compares:**
1. **K-means clustering**: Based purely on loss landscape similarity
2. **Material properties**: Chemical/physical characteristics (formation energy, etc.)

**Key visualization:**
- **Twin UMAP plot**: Shows K-means clusters vs selected material property
- **Left panel**: Clusters from unsupervised learning  
- **Right panel**: Material property values (continuous coloring)

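**Mutual Information Score:**
- **Range**: 0 (independent labelings) up to the smaller of the two label entropies, measured in nats; with 4 clusters and 4 bins the maximum is ln 4 ≈ 1.39. `mutual_info_score` is not normalized to [0, 1], whereas `normalized_mutual_info_score` is
- **Current result**: ~0.054 (weak relationship)
- **Interpretation**: The K-means clusters share very little information with the selected binned property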
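
scikit-learn's `mutual_info_score` reports values in nats, so its upper end depends on the entropy of the labels. The sketch below checks this on hypothetical balanced labels (four groups of 250 samples, an assumption for illustration) and compares it with `normalized_mutual_info_score`.

```python
import numpy as np
from sklearn.metrics import mutual_info_score, normalized_mutual_info_score

rng = np.random.default_rng(0)

# Hypothetical labels: 4 balanced clusters of 250 samples each
clusters = np.repeat(np.arange(4), 250)
identical = clusters.copy()              # perfectly dependent labeling
shuffled = rng.permutation(clusters)     # same label counts, dependence destroyed

mi_identical = mutual_info_score(clusters, identical)
mi_shuffled = mutual_info_score(clusters, shuffled)
nmi_identical = normalized_mutual_info_score(clusters, identical)

print(f"MI, identical labels:  {mi_identical:.4f} nats (ln 4 = {np.log(4):.4f})")
print(f"MI, shuffled labels:   {mi_shuffled:.4f} nats")
print(f"NMI, identical labels: {nmi_identical:.4f}")
```

For identical balanced labels the unnormalized score equals the label entropy, ln 4, rather than 1; the normalized variant rescales it to 1.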

**Parameters you can modify:**
```python
n_clusters, n_bins = 4, 4          # Try different numbers
property_1 = combined_dict[...]     # Select different properties via selectors above
```

**What different MI scores mean (rough rule of thumb; the attainable scale depends on the number of clusters and bins):**
- **0.0-0.1**: Very weak relationship
- **0.1-0.3**: Weak relationship  
- **0.3-0.7**: Moderate relationship
- **0.7+**: Strong relationship
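
A small score can still be compared against chance. One common approach, sketched here as an assumption rather than as part of this notebook's workflow, is a permutation baseline: shuffle one labeling many times and see how often the shuffled score reaches the observed one. The labels below are synthetic, with the binned property copying the cluster label for about 30% of samples.

```python
import numpy as np
from sklearn.metrics import mutual_info_score

rng = np.random.default_rng(42)
n = 500

# Synthetic labels (assumption): a weak but real link between clusters and bins
cluster_labels = rng.integers(0, 4, size=n)
random_bins = rng.integers(0, 4, size=n)
copy_mask = rng.random(n) < 0.3
binned_property = np.where(copy_mask, cluster_labels, random_bins)

observed = mutual_info_score(cluster_labels, binned_property)

# Null distribution: MI after destroying any link by shuffling one labeling
null = np.array([
    mutual_info_score(cluster_labels, rng.permutation(binned_property))
    for _ in range(200)
])

# Permutation p-value with the usual +1 correction
p_value = (np.sum(null >= observed) + 1) / (len(null) + 1)

print(f"Observed MI:      {observed:.4f} nats")
print(f"Mean shuffled MI: {null.mean():.4f} nats")
print(f"Permutation p:    {p_value:.4f}")
```

The shuffled scores are not exactly zero, because finite samples always produce a little spurious mutual information; an observed score is only informative if it stands clear of that baseline.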

This analysis helps determine if loss landscape clustering reveals chemically meaningful patterns or if it captures purely computational/numerical features.


In [None]:
n_clusters, n_bins = 4, 4

property_1 = combined_dict[selectors['property_dict_selector_1'].value][selectors['property_column_selector_1'].value].copy()
binned_labels = plot_numerical_data(property_1, bin_number=n_bins, display_stats=False)

spectral_labels = manual_spectral_clustering(n_clusters=n_clusters, data=flatten_and_vstack(landscape_array_1), affinity_type='knn', affinity_param=13)

kmeans = KMeans(n_clusters=n_clusters, random_state=42)
kmeans.fit(flatten_and_vstack(landscape_array_1))
k_means_labels = kmeans.labels_

plot_twin_umap_scatter(
    umap_data=umap_transformed_data,
    n_neighbors=20,
    min_dist=0.05,
    labels1=k_means_labels,
    label_name1='K-Means Clusters',
    labels2=property_1,
    label_name2=f'{property_1.name}',  # unbinned values are plotted here; the binned labels are used only for the MI score
    label_type='continuous',
    landscape_array=landscape_array_1
)

mi = mutual_info_score(k_means_labels, binned_labels)
print(f"The mutual information score between the two labels is: {mi}")
