# Cluster-Based Parameterization of Vertical Mixing Coefficients in the Ocean Surface Boundary Layer

This notebook extends the analysis of **Sane et al. (2023)**, "Parameterizing vertical mixing coefficients in the ocean surface boundary layer using neural networks", with a **conditional modeling approach based on clustered ocean states**: the input space is partitioned into clusters and a separate network is trained for each one.

## Motivation

The original paper demonstrated that neural networks can effectively predict the vertical shape functions of the mixing coefficient. However, we hypothesize that **different ocean states might benefit from specialized models**. By clustering the input feature space, we can train models that are tailored to specific oceanic conditions, which may further improve prediction accuracy.

## Approach

1. **Reproduce baseline results** from Sane et al. (2023)
2. **Cluster ocean states** based on physical parameters
3. **Train separate models** for each cluster
4. **Compare performance** of the cluster-specific approach vs. the baseline

## Expected Outcomes

- Identification of distinct oceanic regimes with different mixing behaviors
- Improved prediction accuracy, especially in regions with complex dynamics
- Better understanding of how different physical drivers influence mixing in different regimes

In [ ]:
import os
import numpy as np
import matplotlib.pyplot as plt
import copy
import torch
import torch.nn as nn
from torch import optim
import seaborn as sns
import pandas as pd
from tqdm import tqdm
import xarray as xr
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score
from datetime import datetime
import warnings
from matplotlib.gridspec import GridSpec

warnings.filterwarnings("ignore", category=FutureWarning)

# Setup notebook environment
today = datetime.today()
np.random.seed(100)
torch.manual_seed(100)

# Fix paths
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
os.chdir(parent_dir)
cwd = parent_dir
print("Current working directory:", os.getcwd())

# Import custom modules
import lib.func_file as ff
from lib.visual_figure4 import performance_sigma_point
from lib.cluster_visualizations import (
    plot_silhouette_scores,
    plot_cluster_distributions, 
    plot_shape_functions_by_cluster, 
    plot_cluster_2d_projections,
    plot_cluster_size_distribution,
    plot_cluster_centers,
    plot_model_performance_comparison,
    plot_sample_predictions,
    plot_error_distributions,
    plot_overall_performance_comparison
)

# Set up directories
cwd_data = cwd + '/Data/'
cwd_output = cwd + '/output/'
os.makedirs(cwd_output, exist_ok=True)

# Set device for PyTorch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
%%capture
import importlib.util

# Check for required packages and install any that are missing
if importlib.util.find_spec("torch") is None:
    !pip install torch
if importlib.util.find_spec("zarr") is None:
    !pip install zarr
# scikit-learn is imported as "sklearn", so that is the name find_spec needs
if importlib.util.find_spec("sklearn") is None:
    !pip install scikit-learn

# Ensure xarray is upgraded
!pip install --upgrade xarray

# 1. Baseline Model

## 1.1 Load and Preprocess Data

We'll load the GOTM (General Ocean Turbulence Model) data provided by Sane et al. (2023) and apply the same preprocessing steps.

In [None]:
# Load GOTM training data produced by Sane et al. 2023
store = 'https://nyu1.osn.mghpcc.org/leap-pangeo-manual/GOTM/sf_training_data.zarr'
d = xr.open_dataset(store, engine='zarr', chunks={})

# Coriolis parameter calculation
def corio(lat):
    return 2*(2*np.pi/(24*60*60)) * np.sin(lat*(np.pi/180))

# Extract variables
l0 = corio(d['l'][:])
b00 = d['b0'][:]
ustar0 = d['ustar'][:]
h0 = d['h'][:]
lat0 = d['lat'][:]
heat0 = d['heat'][:]
tx0 = d['tx'][:]
tx0 = np.round(tx0, 2)
SF0 = d['SF'][:]
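
The `corio` helper above evaluates the standard Coriolis parameter, assuming `d['l']` holds latitude $\varphi$ in degrees (the code converts it to radians). Written out, with the rotation rate taken from a 24-hour day as in the cell:

$$
f = 2\,\Omega \sin\varphi, \qquad \Omega = \frac{2\pi}{24 \times 60 \times 60\ \mathrm{s}} \approx 7.27 \times 10^{-5}\ \mathrm{s^{-1}}
$$

As a quick check, $\varphi = 45^\circ$ gives $f \approx 1.03 \times 10^{-4}\ \mathrm{s^{-1}}$.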

In [None]:
# Quick visualization of input features
variables = {
    "Coriolis parameter (l0) (s$^{-1}$)": l0.values.flatten(),
    "Surface Buoyancy Flux (b00) (m$^{2}$ s$^{-3}$)": b00.values.flatten(),
    "Surface Friction Velocity (ustar0) (m s$^{-1}$)": ustar0.values.flatten(),
    "Boundary Layer Depth (h0) (m)": h0.values.flatten()
}

# Create histograms with improved styling
plt.figure(figsize=(15, 10))
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

for i, (name, data) in enumerate(variables.items()):
    plt.subplot(2, 2, i + 1)
    plt.hist(data, bins=30, color=colors[i], alpha=0.7, edgecolor='black')
    plt.title(name, fontsize=14, fontweight='bold')
    plt.xlabel("Value", fontsize=12)
    plt.ylabel("Frequency", fontsize=12)
    plt.grid(axis="y", linestyle="--", alpha=0.6)

plt.tight_layout()
plt.suptitle("Distribution of Input Features", fontsize=16, fontweight='bold', y=1.02)
plt.show()

In [None]:
# Apply the same filtering criteria as in Sane et al. (2023)
ind101 = np.where(np.abs(heat0) < 601)[0]
ind1 = ind101
ind2 = np.where(tx0 < 1.2)[0]
ind3 = np.where(h0 > 29)[0]
ind4 = np.where(h0 < 301)[0]

# Combine all filters
ind5 = np.intersect1d(ind1, ind2)
ind6 = np.intersect1d(ind3, ind5)
ind7 = np.intersect1d(ind4, ind6)

# Prepare data for model training
mm1 = 0; mm2 = 16  # 16 levels (level 1 at bottom, level 16 at top)
data_load_main = np.zeros([len(h0[ind7]), 4 + mm2 - mm1])
data_load_main[:, 0] = l0[ind7]     # Coriolis parameter
data_load_main[:, 1] = b00[ind7]    # Surface buoyancy flux
data_load_main[:, 2] = ustar0[ind7] # Surface friction velocity
data_load_main[:, 3] = h0[ind7]     # Boundary layer depth
data_load_main[:, 4:(mm2 - mm1 + 4)] = SF0[ind7, mm1:mm2] # Shape functions

# Store additional forcing variables for reference
data_forc = np.zeros([len(ind7), 3])
data_forc[:, 0] = lat0[ind7]  # Latitude
data_forc[:, 1] = heat0[ind7] # Heat flux
data_forc[:, 2] = tx0[ind7]   # Wind stress

# Create a copy of the data for processing
data_load3 = copy.deepcopy(data_load_main)

# Preprocess data using the same function as in the original code
print('Preprocessing data...')
data, x, y, stats, k_mean, k_std = ff.preprocess_train_data(data_load3)
print('Done preprocessing')
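
The four index filters in the cell above amount to a single logical AND of the conditions. The sketch below checks that equivalence on a few made-up values; the arrays `heat`, `tx` and `h` are hypothetical stand-ins for `heat0`, `tx0` and `h0`, not data from the notebook.

```python
import numpy as np

# Hypothetical stand-ins for heat0, tx0 and h0 (values chosen only for illustration)
heat = np.array([100.0, -700.0, 50.0, 600.0, 0.0])
tx = np.array([0.10, 0.20, 1.50, 0.30, 0.05])
h = np.array([50.0, 100.0, 80.0, 300.0, 20.0])

# Index-based version, mirroring the chained np.intersect1d calls
idx = np.where(np.abs(heat) < 601)[0]
idx = np.intersect1d(idx, np.where(tx < 1.2)[0])
idx = np.intersect1d(idx, np.where(h > 29)[0])
idx = np.intersect1d(idx, np.where(h < 301)[0])

# Equivalent boolean mask: one condition per filter, combined with &
mask = (np.abs(heat) < 601) & (tx < 1.2) & (h > 29) & (h < 301)

print(idx)
print(np.flatnonzero(mask))
```

Both approaches select the same rows; the mask form keeps all four thresholds on one line, which can make them easier to audit.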

In [None]:
# Load validation data
url = "https://nyu1.osn.mghpcc.org/leap-pangeo-manual/GOTM/data_testing_4_paper.txt"
df = pd.read_csv(url, sep=r'\s+', header=None)  # delim_whitespace is deprecated in recent pandas
valid_data = df.iloc[:, 3:].values

# Apply same filtering criteria
ind3 = np.where(valid_data[:, 3] > 29)[0]
ind4 = np.where(valid_data[:, 3] < 301)[0]
ind = np.intersect1d(ind3, ind4)

# Extract and normalize validation features
valid_x = valid_data[ind, 0:4]
valid_x[:, 0] = (valid_x[:, 0] - stats[0]) / stats[1]  # Normalize Coriolis
valid_x[:, 1] = (valid_x[:, 1] - stats[2]) / stats[3]  # Normalize buoyancy flux
valid_x[:, 2] = (valid_x[:, 2] - stats[4]) / stats[5]  # Normalize friction velocity
valid_x[:, 3] = (valid_x[:, 3] - stats[6]) / stats[7]  # Normalize layer depth

# Extract and normalize validation targets
valid_y = valid_data[ind, 5:]
for i in range(len(valid_y)):
    valid_y[i, :] = np.log(valid_y[i, :] / np.max(valid_y[i, :]))
for i in range(16):
    valid_y[:, i] = (valid_y[:, i] - k_mean[i]) / k_std[i]

# Convert to PyTorch tensors
x = torch.FloatTensor(x).to(device)
y = torch.FloatTensor(y).to(device)
valid_x = torch.FloatTensor(valid_x).to(device)
valid_y = torch.FloatTensor(valid_y).to(device)

print(f"Training data shape: {x.shape}, {y.shape}")
print(f"Validation data shape: {valid_x.shape}, {valid_y.shape}")
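
The target normalization applied to `valid_y` above (log of the max-normalized profile, then per-level standardization) and the inverse `exp(y * k_std + k_mean)` used later for the losses can be sketched in isolation. This is an illustration only: it uses toy profiles with 3 levels, and `k_mean` and `k_std` are computed from the toy data here, whereas in the notebook they come from `ff.preprocess_train_data`, which is assumed to apply the same transform to the training targets.

```python
import numpy as np

def normalize_targets(sf, k_mean, k_std):
    """Log of the max-normalized profile, then per-level standardization."""
    log_shape = np.log(sf / sf.max(axis=1, keepdims=True))
    return (log_shape - k_mean) / k_std

def denormalize_targets(y, k_mean, k_std):
    """Inverse transform; returns the max-normalized profile, not the raw values."""
    return np.exp(y * k_std + k_mean)

# Hypothetical profiles with 3 levels instead of 16, for illustration only
sf = np.array([[0.2, 1.0, 0.5],
               [0.1, 0.4, 0.8]])
log_shape = np.log(sf / sf.max(axis=1, keepdims=True))
k_mean = log_shape.mean(axis=0)
k_std = log_shape.std(axis=0)

y = normalize_targets(sf, k_mean, k_std)
recovered = denormalize_targets(y, k_mean, k_std)
print(recovered)
```

The inverse recovers each profile divided by its own maximum, so losses computed "in original space" are on profiles whose largest value is 1.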

## 1.2 Define the Neural Network Model

We'll use the same architecture as in the original paper: a multilayer perceptron with 2 hidden layers and 32 neurons per layer.
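
For a sense of model size, the number of trainable parameters in a 4-32-32-16 fully connected network can be counted directly. The helper name `mlp_parameter_count` is introduced here for illustration and is not part of the notebook or its `lib` modules.

```python
def mlp_parameter_count(layer_sizes):
    """Count weights and biases of a fully connected network."""
    return sum(
        n_in * n_out + n_out
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
    )

# 4 inputs, two hidden layers of 32 units, 16 outputs
total = mlp_parameter_count([4, 32, 32, 16])
print(total)  # 1744
```

Each cluster-specific model has the same size, so the clustered approach multiplies the parameter count by the number of clusters.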

In [None]:
class OceanMixingNN(nn.Module):
    def __init__(self, in_nodes, hidden_nodes, out_nodes):
        super(OceanMixingNN, self).__init__()
        self.linear1 = nn.Linear(in_nodes, hidden_nodes)
        self.linear2 = nn.Linear(hidden_nodes, hidden_nodes)
        self.linear3 = nn.Linear(hidden_nodes, out_nodes)
        self.dropout = nn.Dropout(0.25)  # Regularization
        
        # Initialize weights using the same approach as in original paper
        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.xavier_uniform_(self.linear2.weight)
        nn.init.xavier_uniform_(self.linear3.weight)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.dropout(x)
        
        x = self.linear2(x)
        x = torch.relu(x)
        x = self.dropout(x)
        
        return self.linear3(x)

In [None]:
def train_model(model, x, y, valid_x, valid_y, k_mean, k_std, epochs=3000, lr=1e-3, patience=20):
    """Train a model with early stopping based on validation loss"""
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.L1Loss(reduction='mean')
    losses = []
    
    # Convert k_mean and k_std to tensors on appropriate device
    k_mean_tensor = torch.tensor(k_mean, dtype=torch.float32).to(device)
    k_std_tensor = torch.tensor(k_std, dtype=torch.float32).to(device)
    
    # For early stopping
    best_loss = float('inf')
    best_model = None
    no_improve = 0
    
    with tqdm(total=epochs, desc="Training") as pbar:
        for epoch in range(epochs):
            # Training step
            model.train()
            optimizer.zero_grad()
            
            # Forward pass
            train_pred = model(x)
            
            # Compute loss (same as in original paper)
            loss = loss_fn(train_pred, y)
            
            # Backpropagation
            loss.backward()
            optimizer.step()
            
            # Validation
            model.eval()
            with torch.no_grad():
                valid_pred = model(valid_x)
                
                # Calculate loss in original space, not normalized space
                train_loss = torch.mean(torch.abs(
                    torch.exp(train_pred * k_std_tensor + k_mean_tensor) - 
                    torch.exp(y * k_std_tensor + k_mean_tensor)
                ))
                
                valid_loss = torch.mean(torch.abs(
                    torch.exp(valid_pred * k_std_tensor + k_mean_tensor) - 
                    torch.exp(valid_y * k_std_tensor + k_mean_tensor)
                ))
            
            # Store losses
            losses.append((epoch, train_loss.item(), valid_loss.item()))
            
            # Update progress bar
            pbar.update(1)
            pbar.set_postfix({
                'train_loss': f"{train_loss.item():.4f}", 
                'valid_loss': f"{valid_loss.item():.4f}",
                'patience': no_improve
            })
            
            # Early stopping
            if valid_loss.item() < best_loss:
                best_loss = valid_loss.item()
                best_model = copy.deepcopy(model.state_dict())
                no_improve = 0
            else:
                no_improve += 1
                if no_improve >= patience:
                    print(f"\nEarly stopping at epoch {epoch+1}")
                    break
    
    # Load the best model
    if best_model is not None:
        model.load_state_dict(best_model)
    
    return model, np.array(losses)
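
The early-stopping rule inside `train_model` can be isolated as a small function for inspection. This is a sketch of the same patience logic applied to a plain list of validation losses; the name `early_stopping_epoch` is hypothetical.

```python
def early_stopping_epoch(valid_losses, patience=20):
    """Return (stop_index, best_index) for a sequence of validation losses.

    Training stops once `patience` consecutive epochs fail to improve on
    the best loss seen so far; the best epoch's weights would be restored.
    """
    best_loss = float("inf")
    best_index = None
    no_improve = 0
    for i, loss in enumerate(valid_losses):
        if loss < best_loss:
            best_loss, best_index, no_improve = loss, i, 0
        else:
            no_improve += 1
            if no_improve >= patience:
                return i, best_index
    return len(valid_losses) - 1, best_index

stop, best = early_stopping_epoch([5.0, 4.0, 3.0, 3.5, 3.2, 3.1, 2.0], patience=3)
print(stop, best)  # 5 2
```

Because `train_model` takes one full-batch gradient step per epoch, `patience=20` corresponds to only 20 optimizer steps without improvement; in the toy sequence above, the lower loss at index 6 is never reached.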

## 1.3 Train the Baseline Model

First, we'll train a single model on all data, following the exact approach in Sane et al. (2023).

In [None]:
# Define model parameters
in_nodes = 4      # 4 input features
hidden_nodes = 32 # 32 nodes per hidden layer
out_nodes = 16    # 16 output points (one for each sigma level)

# Initialize and train the baseline model
print("Training baseline model on all data...")
baseline_model = OceanMixingNN(in_nodes, hidden_nodes, out_nodes).to(device)
baseline_model, baseline_losses = train_model(
    baseline_model, x, y, valid_x, valid_y, k_mean, k_std
)

In [None]:
# Plot training and validation losses
plt.figure(figsize=(10, 6))
epochs = baseline_losses[:, 0]
train_loss = baseline_losses[:, 1]
valid_loss = baseline_losses[:, 2]

plt.plot(epochs, train_loss, 'b-', linewidth=2, label='Training Loss')
plt.plot(epochs, valid_loss, 'r-', linewidth=2, label='Validation Loss')
plt.fill_between(epochs, train_loss, valid_loss, alpha=0.2, color='gray', label='Generalization Gap')

# Add the best validation loss
best_epoch = np.argmin(valid_loss)
best_valid = valid_loss[best_epoch]
plt.scatter(epochs[best_epoch], best_valid, c='g', s=100, zorder=3)
plt.annotate(f'Best: {best_valid:.4f}', 
             (epochs[best_epoch], best_valid),
             xytext=(10, -20), textcoords='offset points',
             arrowprops=dict(arrowstyle='->', color='green'))

plt.grid(True, alpha=0.3, linestyle='--')
plt.xlabel('Epochs', fontsize=14)
plt.ylabel('L1 Loss', fontsize=14)
plt.title('Baseline Model Training', fontsize=16, fontweight='bold')
plt.legend(fontsize=12)
plt.tight_layout()
plt.show()

# 2. Clustering Analysis

Now we'll implement our extension to the original approach by clustering the ocean states and training separate models for each cluster.

## 2.1 Feature Clustering

We'll cluster the input data based on the four physical parameters: Coriolis parameter, surface buoyancy flux, surface friction velocity, and boundary layer depth.
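
K-means relies on Euclidean distance, so features with large numeric ranges dominate unless the inputs are standardized first, which is why `StandardScaler` is applied before clustering. The sketch below illustrates this with hypothetical samples; the values are made up and only meant to have realistic orders of magnitude.

```python
import numpy as np

# Hypothetical samples: [Coriolis (1/s), buoyancy flux (m^2/s^3),
#                        friction velocity (m/s), layer depth (m)]
features = np.array([
    [1.0e-4, -5.0e-8, 0.010, 40.0],
    [1.0e-4,  5.0e-8, 0.010, 45.0],
    [0.2e-4, -5.0e-8, 0.030, 42.0],
])

def standardize(a):
    """Zero mean and unit standard deviation per column."""
    return (a - a.mean(axis=0)) / a.std(axis=0)

def pairwise_dist(a):
    """Euclidean distance between every pair of rows."""
    diff = a[:, None, :] - a[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))

raw = pairwise_dist(features)
scaled = pairwise_dist(standardize(features))

print(raw[0, 1])     # essentially the 5 m depth difference
print(scaled[0, 1])  # all four features now contribute
```

In the raw space the distance between the first two samples is just their depth difference, even though they also differ in the sign of the buoyancy flux.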

In [ ]:
def find_optimal_clusters(features, max_clusters=10, fixed_clusters=None):
    """
    Find optimal number of clusters using silhouette score.
    If fixed_clusters is provided, will return that value without calculation.
    """
    # If fixed number of clusters is provided, return it
    if fixed_clusters is not None:
        print(f"Using fixed number of clusters: {fixed_clusters}")
        return fixed_clusters
    
    # Scale features for clustering
    scaler = StandardScaler()
    scaled_features = scaler.fit_transform(features)
    
    # Calculate silhouette scores for different cluster counts
    silhouette_scores_list = []
    for k in range(2, max_clusters + 1):
        kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
        cluster_labels = kmeans.fit_predict(scaled_features)
        score = silhouette_score(scaled_features, cluster_labels)
        silhouette_scores_list.append((k, score))
        print(f"Clusters: {k}, Silhouette Score: {score:.4f}")
    
    # Plot silhouette scores using our visualization function
    plot_silhouette_scores(silhouette_scores_list)
    
    # Return optimal number of clusters (highest silhouette score)
    return max(silhouette_scores_list, key=lambda x: x[1])[0]
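
For reference, the silhouette coefficient of a sample $i$ follows the standard definition

$$
s(i) = \frac{b(i) - a(i)}{\max\{a(i),\, b(i)\}},
$$

where $a(i)$ is the mean distance from $i$ to the other samples in its own cluster and $b(i)$ is the mean distance from $i$ to the samples of the nearest other cluster. `silhouette_score` returns the mean of $s(i)$ over all samples, which lies in $[-1, 1]$; higher values indicate more compact, better separated clusters. Because it needs pairwise distances, the cost grows roughly quadratically with the number of samples; the `sample_size` argument of `silhouette_score` can be used to evaluate it on a random subset.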

In [ ]:
# Get original (non-normalized) input features for clustering
original_features = data_load_main[:, :4].astype(np.float64)

# Use a fixed number of clusters (4) to skip the silhouette search, which is slow
# on large datasets because silhouette_score computes pairwise distances
n_clusters = 4
print(f"Using a fixed number of clusters: {n_clusters}")

# Optionally, you can uncomment the following line to find optimal clusters instead
# n_clusters = find_optimal_clusters(original_features, max_clusters=8)

In [ ]:
def cluster_data(features, n_clusters):
    """Cluster the data based on physical parameters"""
    # Scale features
    scaler = StandardScaler()
    scaled_features = scaler.fit_transform(features)
    
    # Perform K-means clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    cluster_assignments = kmeans.fit_predict(scaled_features)
    
    # Print cluster sizes
    for i in range(n_clusters):
        count = np.sum(cluster_assignments == i)
        print(f"Cluster {i}: {count} samples ({count/len(cluster_assignments)*100:.2f}%)")
    
    # Extract cluster centroids in original feature space
    centroids_scaled = kmeans.cluster_centers_
    centroids_original = scaler.inverse_transform(centroids_scaled)
    
    # Create a DataFrame with cluster statistics
    feature_names = ['Coriolis', 'Buoyancy Flux', 'Friction Velocity', 'Layer Depth']
    cluster_stats = pd.DataFrame(centroids_original, columns=feature_names)
    cluster_stats.index = [f"Cluster {i}" for i in range(n_clusters)]
    
    # Add standard deviations
    for i in range(n_clusters):
        mask = cluster_assignments == i
        cluster_features = features[mask]
        stds = np.std(cluster_features, axis=0)
        for j, col in enumerate(cluster_stats.columns):
            cluster_stats.at[f"Cluster {i}", f"{col} StdDev"] = stds[j]
    
    # Visualize cluster distributions
    print("\nVisualizing cluster feature distributions...")
    plot_cluster_distributions(features, cluster_assignments, n_clusters, feature_names)
    
    # Visualize cluster centers
    print("\nVisualizing cluster centers...")
    plot_cluster_centers(kmeans, feature_names, scaler)
    
    # Visualize cluster size distribution
    print("\nVisualizing cluster size distribution...")
    plot_cluster_size_distribution(cluster_assignments, n_clusters)
    
    return cluster_assignments, scaler, kmeans, cluster_stats

In [ ]:
def visualize_shape_functions(data_load_main, cluster_assignments, n_clusters):
    """
    Visualize the shape functions for each cluster and 2D projections of clusters.
    
    Parameters:
        data_load_main: Raw data containing shape functions
        cluster_assignments: Cluster labels
        n_clusters: Number of clusters
    """
    # Extract input features
    features = data_load_main[:, :4]
    feature_names = ['Coriolis Parameter', 'Surface Buoyancy Flux', 
                     'Surface Friction Velocity', 'Boundary Layer Depth']
    
    # Visualize shape functions by cluster
    print("Visualizing mean shape functions by cluster...")
    shape_function_fig = plot_shape_functions_by_cluster(data_load_main, cluster_assignments, n_clusters)
    
    # Visualize 2D projections of features colored by cluster
    print("Visualizing 2D projections of clusters...")
    projection_fig = plot_cluster_2d_projections(features, cluster_assignments, n_clusters, feature_names)
    
    return shape_function_fig, projection_fig

In [ ]:
# Cluster the data
print("Clustering the data with", n_clusters, "clusters...")
cluster_assignments, scaler, kmeans, cluster_stats = cluster_data(original_features, n_clusters)

In [ ]:
# Visualize the clusters
shape_function_fig, projection_fig = visualize_shape_functions(data_load_main, cluster_assignments, n_clusters)

## 2.2 Train Cluster-Specific Models

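Now we'll train separate models for each cluster and compare their performance with the baseline model. Each cluster model is trained only on the samples assigned to its cluster, but early stopping still monitors the loss on the full validation set.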
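
At prediction time, a conditional model of this kind routes each sample to the model of its nearest cluster centroid, which is the rule `KMeans.predict` applies in the scaled feature space. A minimal NumPy sketch of that dispatch follows; the function names and the toy "models" (plain callables) are hypothetical and stand in for the trained networks.

```python
import numpy as np

def assign_clusters(x, centroids):
    """Label each row of x with the index of its nearest centroid."""
    d2 = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return d2.argmin(axis=1)

def predict_by_cluster(x, centroids, models, n_out):
    """Apply the model of each sample's cluster and collect the outputs."""
    labels = assign_clusters(x, centroids)
    preds = np.full((x.shape[0], n_out), np.nan)
    for cluster_id, model in models.items():
        mask = labels == cluster_id
        if mask.any():
            preds[mask] = model(x[mask])
    return preds, labels

# Toy setup: two centroids and two constant "models"
centroids = np.array([[0.0, 0.0], [10.0, 10.0]])
models = {
    0: lambda a: np.zeros((len(a), 2)),
    1: lambda a: np.ones((len(a), 2)),
}
x = np.array([[1.0, 1.0], [9.0, 9.0], [0.0, -1.0]])

preds, labels = predict_by_cluster(x, centroids, models, n_out=2)
print(labels)
print(preds)
```

The hard assignment means predictions can jump when a sample crosses a cluster boundary.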

In [None]:
def train_cluster_models(x, y, valid_x, valid_y, k_mean, k_std, cluster_assignments, n_clusters):
    """Train separate models for each cluster"""
    # Convert variables to numpy for indexing
    x_np = x.cpu().numpy()
    y_np = y.cpu().numpy()
    
    # Initialize a dictionary to store models
    cluster_models = {}
    cluster_losses = {}
    
    # Train a model for each cluster
    for cluster_id in range(n_clusters):
        print(f"\nTraining model for Cluster {cluster_id}")
        
        # Get data points belonging to this cluster
        cluster_mask = cluster_assignments == cluster_id
        x_cluster = torch.FloatTensor(x_np[cluster_mask]).to(device)
        y_cluster = torch.FloatTensor(y_np[cluster_mask]).to(device)
        
        print(f"Cluster {cluster_id} data shape: {x_cluster.shape}, {y_cluster.shape}")
        
        # Initialize and train model
        model = OceanMixingNN(in_nodes, hidden_nodes, out_nodes).to(device)
        model, losses = train_model(
            model, x_cluster, y_cluster, valid_x, valid_y, k_mean, k_std
        )
        
        # Store model and losses
        cluster_models[cluster_id] = model
        cluster_losses[cluster_id] = losses
    
    return cluster_models, cluster_losses

In [None]:
# Train cluster-specific models
cluster_models, cluster_losses = train_cluster_models(
    x, y, valid_x, valid_y, k_mean, k_std, cluster_assignments, n_clusters
)

In [ ]:
# Plot validation-loss curves for each cluster model against the baseline
# (reads the global baseline_losses produced in Section 1.3)
def plot_cluster_losses(cluster_losses, n_clusters):
    """Plot training and validation losses for each cluster model"""
    # Set up colors
    colors = plt.cm.tab10(np.linspace(0, 1, n_clusters))
    
    plt.figure(figsize=(12, 8))
    
    for i in range(n_clusters):
        losses = cluster_losses[i]
        epochs = losses[:, 0]
        valid_loss = losses[:, 2]
        
        plt.plot(epochs, valid_loss, '-', color=colors[i], linewidth=2, 
                label=f'Cluster {i} Validation Loss')
        
        # Mark the best validation loss
        best_epoch = np.argmin(valid_loss)
        best_valid = valid_loss[best_epoch]
        plt.scatter(epochs[best_epoch], best_valid, color=colors[i], s=80, zorder=3)
        
    # Add baseline validation loss for comparison
    baseline_epochs = baseline_losses[:, 0]
    baseline_valid = baseline_losses[:, 2]
    plt.plot(baseline_epochs, baseline_valid, 'k-', linewidth=3, 
            label='Baseline Validation Loss')
    
    best_baseline_epoch = np.argmin(baseline_valid)
    best_baseline = baseline_valid[best_baseline_epoch]
    plt.scatter(baseline_epochs[best_baseline_epoch], best_baseline, c='k', s=100, zorder=3)
    plt.annotate(f'Baseline: {best_baseline:.4f}', 
                 (baseline_epochs[best_baseline_epoch], best_baseline),
                 xytext=(10, 10), textcoords='offset points',
                 arrowprops=dict(arrowstyle='->', color='black'))
    
    plt.grid(True, alpha=0.3, linestyle='--')
    plt.xlabel('Epochs', fontsize=14)
    plt.ylabel('Validation Loss', fontsize=14)
    plt.title('Validation Losses by Cluster', fontsize=16, fontweight='bold')
    plt.legend(fontsize=12)
    plt.tight_layout()
    
    return plt.gcf()

In [None]:
# Plot losses for each cluster model
plot_cluster_losses(cluster_losses, n_clusters)

## 2.3 Performance Evaluation

We'll now evaluate the performance of our cluster-specific models compared to the baseline model.

In [ ]:
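# NOTE: predict_with_cluster_models is not defined or imported anywhere else in
# this notebook. The version below is a minimal sketch. It assumes that `stats`
# holds alternating (mean, std) values for the four inputs, as used when
# normalizing valid_x, so the inputs can be mapped back to the physical space
# in which `scaler` and `kmeans` were fitted.
def predict_with_cluster_models(valid_x, cluster_models, scaler, kmeans):
    """Route each validation sample to the model of its nearest cluster."""
    x_np = valid_x.cpu().numpy().astype(np.float64)
    x_phys = np.empty_like(x_np)
    for j in range(x_np.shape[1]):
        x_phys[:, j] = x_np[:, j] * stats[2 * j + 1] + stats[2 * j]
    val_clusters = kmeans.predict(scaler.transform(x_phys))

    preds = torch.zeros((valid_x.shape[0], out_nodes), device=valid_x.device)
    for cluster_id, model in cluster_models.items():
        mask = torch.from_numpy(val_clusters == cluster_id).to(valid_x.device)
        if mask.any():
            model.eval()
            with torch.no_grad():
                preds[mask] = model(valid_x[mask])
    return preds, val_clusters

# Generate predictions with both the baseline and cluster models
print("Generating predictions with baseline model...")
baseline_model.eval()  # make sure dropout is disabled
with torch.no_grad():
    baseline_preds = baseline_model(valid_x)

print("Generating predictions with cluster-specific models...")
cluster_preds, val_clusters = predict_with_cluster_models(valid_x, cluster_models, scaler, kmeans)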

# Calculate performance metrics in original space
k_mean_tensor = torch.tensor(k_mean, dtype=torch.float32).to(device)
k_std_tensor = torch.tensor(k_std, dtype=torch.float32).to(device)

# Transform predictions back to original space
baseline_preds_orig = torch.exp(baseline_preds * k_std_tensor + k_mean_tensor)
cluster_preds_orig = torch.exp(cluster_preds * k_std_tensor + k_mean_tensor)
valid_y_orig = torch.exp(valid_y * k_std_tensor + k_mean_tensor)

# Calculate node-wise losses for both models
baseline_node_losses = []
cluster_node_losses = []

for i in range(valid_y_orig.shape[1]):
    baseline_loss = torch.mean(torch.abs(baseline_preds_orig[:, i] - valid_y_orig[:, i])).item()
    cluster_loss = torch.mean(torch.abs(cluster_preds_orig[:, i] - valid_y_orig[:, i])).item()
    
    baseline_node_losses.append(baseline_loss)
    cluster_node_losses.append(cluster_loss)

print("Overall Baseline Loss:", np.mean(baseline_node_losses))
print("Overall Cluster Model Loss:", np.mean(cluster_node_losses))
improvement = (np.mean(baseline_node_losses) - np.mean(cluster_node_losses)) / np.mean(baseline_node_losses) * 100
print(f"Improvement: {improvement:.2f}%")
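
The quantities printed above can be written out as follows. With $\mu_i$ and $\sigma_i$ denoting `k_mean` and `k_std` at level $i$, and $N$ validation samples, the per-level mean absolute error in the un-normalized space is

$$
\mathrm{MAE}_i = \frac{1}{N} \sum_{n=1}^{N} \left| \hat{g}_{n,i} - g_{n,i} \right|,
\qquad g_{n,i} = \exp\left( y_{n,i}\,\sigma_i + \mu_i \right),
$$

where $\hat{g}_{n,i}$ is obtained from the model output in the same way. The reported improvement is the relative reduction of the level-averaged error:

$$
\text{Improvement} = 100 \times \frac{\overline{\mathrm{MAE}}_{\text{baseline}} - \overline{\mathrm{MAE}}_{\text{cluster}}}{\overline{\mathrm{MAE}}_{\text{baseline}}},
\qquad \overline{\mathrm{MAE}} = \frac{1}{16} \sum_{i=1}^{16} \mathrm{MAE}_i .
$$

A positive value means the cluster-specific models have the lower error.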

In [ ]:
def visualize_model_performance(baseline_preds_orig, cluster_preds_orig, valid_y_orig, 
                               baseline_node_losses, cluster_node_losses, val_clusters):
    """Visualize performance comparison between baseline and cluster models"""
    # Convert tensors to numpy
    baseline_np = baseline_preds_orig.cpu().numpy()
    cluster_np = cluster_preds_orig.cpu().numpy()
    valid_np = valid_y_orig.cpu().numpy()
    
    # 1. Node-wise performance comparison
    print("Visualizing node-wise performance improvements...")
    node_perf_fig = plot_model_performance_comparison(baseline_node_losses, cluster_node_losses)
    
    # 2. Overall performance comparison
    print("Visualizing overall performance comparison...")
    baseline_loss = np.mean(baseline_node_losses)
    cluster_loss = np.mean(cluster_node_losses)
    overall_perf_fig = plot_overall_performance_comparison(baseline_loss, cluster_loss)
    
    # 3. Error distribution comparison
    print("Visualizing error distributions...")
    error_fig = plot_error_distributions(baseline_np, cluster_np, valid_np)
    
    # 4. Sample predictions
    print("Visualizing sample predictions...")
    # Select samples from different clusters
    unique_clusters = np.unique(val_clusters)
    sample_indices = []
    for cluster_id in unique_clusters:
        cluster_mask = val_clusters == cluster_id
        if np.sum(cluster_mask) > 0:
            samples = np.where(cluster_mask)[0]
            sample_idx = np.random.choice(samples)
            sample_indices.append(sample_idx)
    
    # Limit to 3 samples
    sample_indices = sample_indices[:3]
    sample_fig = plot_sample_predictions(sample_indices, val_clusters, 
                                        baseline_np, cluster_np, valid_np)
    
    return node_perf_fig, overall_perf_fig, error_fig, sample_fig

In [ ]:
# Create visualization of performance comparison
node_perf_fig, overall_perf_fig, error_fig, sample_fig = visualize_model_performance(
    baseline_preds_orig, cluster_preds_orig, valid_y_orig, 
    baseline_node_losses, cluster_node_losses, val_clusters
)

In [ ]:
# Combine the four figures into one summary. Axes cannot be moved between
# figures, so each existing figure is rendered to an image and shown in a panel.
# Assumes each plot_* helper returned a matplotlib Figure on an Agg-based canvas.
def figure_to_image(fig):
    """Render a Figure to an RGBA array."""
    fig.canvas.draw()
    return np.asarray(fig.canvas.buffer_rgba())

panels = [
    (node_perf_fig, 'Performance Improvement by Layer'),
    (overall_perf_fig, 'Overall Model Performance'),
    (error_fig, 'Error Distribution'),
    (sample_fig, 'Sample Prediction'),
]

summary_fig = plt.figure(figsize=(15, 12))
gs = GridSpec(2, 2, figure=summary_fig)

for idx, (fig, title) in enumerate(panels):
    ax = summary_fig.add_subplot(gs[idx // 2, idx % 2])
    ax.imshow(figure_to_image(fig))
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.axis('off')

summary_fig.suptitle('Cluster-Based vs. Baseline Model Performance', fontsize=18, fontweight='bold', y=1.02)
summary_fig.tight_layout()
plt.show()

# 3. Discussion and Future Work

## 3.1 Key Findings

1. **Distinct Oceanic Regimes**: We identified distinct clusters in the ocean state space with different characteristic shape functions, confirming our hypothesis that ocean mixing processes vary systematically across different physical regimes.

2. **Performance Improvement**: The cluster-specific models improve on the baseline model's mean absolute error on the validation set (Section 2.3), which suggests that conditional modeling can capture more nuanced relationships between physical parameters and vertical mixing profiles.

3. **Layer-Specific Improvements**: The improvement varies across different vertical layers, suggesting that some parts of the water column benefit more from the cluster-based approach than others.

## 3.2 Implications

These findings suggest that ocean mixing parameterizations could benefit from a regime-based approach, where different models are applied depending on the oceanic conditions. This could lead to more accurate representations of vertical mixing in ocean models, ultimately improving climate simulations.

## 3.3 Future Work

1. **Optimal Clustering Approach**: Explore different clustering methods (e.g., Gaussian Mixture Models, hierarchical clustering) and compare their effectiveness.

2. **Cluster Selection Mechanism**: Develop a smooth transition method between clusters to avoid discontinuities when ocean states evolve over time.

3. **Output-Based Clustering**: Instead of clustering on input features, cluster on output shape functions to identify regimes based on their mixing behavior directly.

4. **Physical Interpretation**: Analyze cluster characteristics more deeply to understand the physical mechanisms driving different mixing regimes.

5. **Global Model Implementation**: Test the clustered approach in a full global ocean model to evaluate its impact on larger-scale ocean and climate dynamics.
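
As one possible starting point for the smooth-transition idea in item 2, predictions from all cluster models could be blended with weights that decay with distance to each centroid. The sketch below is an assumption for illustration, not something implemented or evaluated in this notebook: the softmax weighting, the `temperature` parameter and the function names are all hypothetical.

```python
import numpy as np

def soft_cluster_weights(x, centroids, temperature=1.0):
    """Softmax over negative squared distances to each centroid."""
    d2 = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    logits = -d2 / temperature
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    w = np.exp(logits)
    return w / w.sum(axis=1, keepdims=True)

def blended_prediction(x, centroids, models, temperature=1.0):
    """Weighted average of all cluster models' predictions."""
    weights = soft_cluster_weights(x, centroids, temperature)  # (n, k)
    preds = np.stack([m(x) for m in models], axis=1)           # (n, k, n_out)
    return (weights[:, :, None] * preds).sum(axis=1)

# Toy setup: two centroids and two constant "models"
centroids = np.array([[0.0, 0.0], [2.0, 0.0]])
models = [lambda a: np.zeros((len(a), 1)), lambda a: np.ones((len(a), 1))]
x = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])

weights = soft_cluster_weights(x, centroids, temperature=0.5)
out = blended_prediction(x, centroids, models, temperature=0.5)
print(out.ravel())
```

A sample midway between two centroids receives equal weights, and lowering `temperature` moves the blend toward the hard nearest-centroid assignment used in this notebook.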