# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
import os
import time
import pickle
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# Seed the NumPy and TensorFlow global random number generators with a fixed value
np.random.seed(1234)
tf.random.set_seed(1234)

#---------------------------------------------------------------------------
# Section 1: Parameter Space Exploration with Langevin Dynamics
#---------------------------------------------------------------------------

class LangevinParameterExplorer:
    """
    Class for analyzing parameter space exploration with Langevin dynamics

    The explorer stores, per recorded snapshot, summary statistics of the
    flattened parameters across the ensemble and (for small enough models)
    the full per-model parameter vectors. It can then plot how the ensemble
    spreads out and how that spread relates to the recorded losses.
    """

    # Full per-model parameter snapshots are only kept when each model has
    # fewer parameters than this; summary statistics are always kept.
    MAX_TRAJECTORY_PARAMS = 10000

    def __init__(self, models, trainer):
        """
        Initialize the parameter explorer

        Parameters:
        -----------
        models : list
            List of PINN models
        trainer : EntropyLangevinPINNTrainer
            Trainer object with Entropy-Langevin algorithm
        """
        self.models = models
        self.trainer = trainer
        self.num_models = len(models)

        # One (num_models, num_params) array per recorded snapshot
        self.parameter_trajectories = []
        # One dict of per-parameter ensemble statistics per recorded snapshot
        self.parameter_statistics = []

        # One dict of loss values per recorded loss snapshot
        self.loss_trajectories = []

    @staticmethod
    def _save_current_figure(path):
        """Save the active matplotlib figure to *path*, creating its directory first."""
        directory = os.path.dirname(path)
        if directory:
            # savefig raises if the target directory is missing
            os.makedirs(directory, exist_ok=True)
        plt.savefig(path, dpi=300, bbox_inches='tight')

    def extract_parameters(self, model_idx):
        """
        Extract flattened parameters from a model

        Parameters:
        -----------
        model_idx : int
            Index of the model

        Returns:
        --------
        np.ndarray
            Flattened parameters (all weight arrays of all layers in
            ``model.layers_list``, concatenated in order)
        """
        model = self.models[model_idx]
        flat_weights = [
            w.flatten()
            for layer in model.layers_list
            for w in layer.get_weights()
        ]
        return np.concatenate(flat_weights)

    def record_parameter_snapshot(self):
        """
        Record parameter snapshot for all models

        Always appends the ensemble statistics (mean, std, min, max, median
        per parameter). The full snapshot is appended to
        ``parameter_trajectories`` only when the number of parameters per
        model is below ``MAX_TRAJECTORY_PARAMS``.
        """
        snapshot = np.array(
            [self.extract_parameters(i) for i in range(self.num_models)]
        )

        self.parameter_statistics.append({
            'mean': np.mean(snapshot, axis=0),
            'std': np.std(snapshot, axis=0),
            'min': np.min(snapshot, axis=0),
            'max': np.max(snapshot, axis=0),
            'median': np.median(snapshot, axis=0)
        })

        if snapshot.shape[1] < self.MAX_TRAJECTORY_PARAMS:
            self.parameter_trajectories.append(snapshot)

    def record_loss_snapshot(self, losses):
        """
        Record loss snapshot for all models

        Parameters:
        -----------
        losses : dict
            Dictionary with loss values
        """
        self.loss_trajectories.append(losses)

    def visualize_parameter_diversity(self, n_components=2, method='pca'):
        """
        Visualize parameter diversity using dimensionality reduction

        Parameters:
        -----------
        n_components : int
            Number of components for dimensionality reduction. Only the
            first two components are drawn, so this must be at least 2.
        method : str
            Method for dimensionality reduction: 'pca' or 'tsne'

        Raises:
        -------
        ValueError
            If no trajectories were recorded, ``n_components`` is below 2,
            or ``method`` is unknown.
        """
        if len(self.parameter_trajectories) == 0:
            raise ValueError("No parameter trajectories recorded.")
        if n_components < 2:
            # The plot below indexes components 0 and 1
            raise ValueError("n_components must be at least 2 for the 2-D trajectory plot.")

        # Select a subset of snapshots (every 10th) to avoid clutter,
        # always keeping the final snapshot
        last_idx = len(self.parameter_trajectories) - 1
        snapshot_indices = np.arange(0, len(self.parameter_trajectories), 10)
        if snapshot_indices[-1] != last_idx:
            snapshot_indices = np.append(snapshot_indices, last_idx)

        # Stack the selected snapshots into one (n_snapshots * num_models, P) matrix
        X = np.vstack([
            self.parameter_trajectories[i].reshape(self.num_models, -1)
            for i in snapshot_indices
        ])

        # Apply dimensionality reduction
        if method == 'pca':
            reducer = PCA(n_components=n_components)
            X_reduced = reducer.fit_transform(X)
            print(f"Explained variance ratio: {reducer.explained_variance_ratio_}")
        elif method == 'tsne':
            # Keep the perplexity below the number of samples
            reducer = TSNE(n_components=n_components,
                           perplexity=min(30, X.shape[0] - 1))
            X_reduced = reducer.fit_transform(X)
        else:
            raise ValueError(f"Unknown method: {method}")

        # Reshape to get trajectory for each model
        X_reduced = X_reduced.reshape(len(snapshot_indices), self.num_models, n_components)

        plt.figure(figsize=(12, 10))

        # Plot parameter space trajectories, one colour per model
        colors = plt.cm.tab10(np.linspace(0, 1, self.num_models))

        for i in range(self.num_models):
            plt.plot(X_reduced[:, i, 0], X_reduced[:, i, 1], '-', color=colors[i],
                     alpha=0.7, linewidth=1.5, label=f'Model {i+1}')
            plt.plot(X_reduced[0, i, 0], X_reduced[0, i, 1], 'o', color=colors[i], markersize=8)
            plt.plot(X_reduced[-1, i, 0], X_reduced[-1, i, 1], 's', color=colors[i], markersize=8)

        # Empty series that only add start/end marker entries to the legend
        plt.plot([], [], 'ko', markersize=8, label='Start')
        plt.plot([], [], 'ks', markersize=8, label='End')

        plt.title(f'Parameter Space Exploration with {method.upper()} ({n_components} components)', fontsize=16)
        plt.xlabel('Component 1', fontsize=14)
        plt.ylabel('Component 2', fontsize=14)
        plt.legend(fontsize=12)
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        self._save_current_figure(f'figures/parameter_space_{method}.png')
        plt.show()

    def plot_parameter_statistics(self):
        """
        Plot statistics of parameters over training iterations

        Raises:
        -------
        ValueError
            If no parameter statistics were recorded.
        """
        if len(self.parameter_statistics) == 0:
            raise ValueError("No parameter statistics recorded.")

        # x axis counts recorded snapshots, starting at 1
        iterations = range(1, len(self.parameter_statistics) + 1)

        # Aggregate the per-parameter statistics over all parameters
        mean_std = [np.mean(stats['std']) for stats in self.parameter_statistics]
        max_std = [np.max(stats['std']) for stats in self.parameter_statistics]
        mean_range = [np.mean(stats['max'] - stats['min']) for stats in self.parameter_statistics]

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

        # Plot standard deviation
        ax1.plot(iterations, mean_std, 'b-', linewidth=2, label='Mean Std Dev')
        ax1.plot(iterations, max_std, 'r--', linewidth=1.5, label='Max Std Dev')
        ax1.set_ylabel('Standard Deviation', fontsize=14)
        ax1.set_title('Parameter Diversity Over Training', fontsize=16)
        ax1.legend(fontsize=12)
        ax1.grid(True, alpha=0.3)

        # Plot parameter range
        ax2.plot(iterations, mean_range, 'g-', linewidth=2, label='Mean Parameter Range')
        ax2.set_xlabel('Training Iteration', fontsize=14)
        ax2.set_ylabel('Parameter Range', fontsize=14)
        ax2.legend(fontsize=12)
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        self._save_current_figure('figures/parameter_statistics.png')
        plt.show()

    def _aligned_diversity_and_loss(self):
        """
        Pair each recorded loss with the parameter statistics taken at the same time.

        The two histories are aligned at their most recent entries. Callers
        in this file record one extra parameter snapshot before training
        starts, which has no matching loss; aligning from the start would
        pair every loss with the previous snapshot's diversity.

        Returns:
        --------
        tuple
            (mean_std, total_loss) lists of equal length
        """
        n = min(len(self.parameter_statistics), len(self.loss_trajectories))
        if n == 0:
            return [], []

        stats = self.parameter_statistics[-n:]
        losses = self.loss_trajectories[-n:]

        mean_std = [float(np.mean(s['std'])) for s in stats]
        # Loss metric (assuming total loss is available under 'total')
        total_loss = [float(np.mean(entry['total'])) for entry in losses]
        return mean_std, total_loss

    def plot_loss_correlation(self):
        """
        Plot correlation between parameter diversity and loss

        Raises:
        -------
        ValueError
            If no parameter statistics or no loss snapshots were recorded.
        """
        if len(self.parameter_statistics) == 0 or len(self.loss_trajectories) == 0:
            raise ValueError("No parameter statistics or loss trajectories recorded.")

        mean_std, total_loss = self._aligned_diversity_and_loss()
        iterations = list(range(1, len(mean_std) + 1))

        # Create scatter plot, coloured by snapshot order
        plt.figure(figsize=(10, 8))

        plt.scatter(mean_std, total_loss, c=iterations, cmap='viridis',
                    s=50, alpha=0.8, edgecolors='k', linewidths=0.5)

        plt.colorbar(label='Training Iteration')
        plt.xlabel('Mean Parameter Standard Deviation', fontsize=14)
        plt.ylabel('Mean Total Loss', fontsize=14)
        plt.title('Correlation Between Parameter Diversity and Loss', fontsize=16)
        plt.grid(True, alpha=0.3)

        # Add a linear trend line; a line fit needs at least two points
        if len(mean_std) > 1:
            coeffs = np.polyfit(mean_std, total_loss, 1)
            sorted_std = sorted(mean_std)
            plt.plot(sorted_std, np.polyval(coeffs, sorted_std), 'r--', linewidth=2)

        plt.tight_layout()
        self._save_current_figure('figures/diversity_loss_correlation.png')
        plt.show()

#---------------------------------------------------------------------------
# Section 2: Entropy-Langevin Hyperparameter Analysis
#---------------------------------------------------------------------------

class EntropyLangevinAnalyzer:
    """
    Class for analyzing the effects of Entropy-Langevin hyperparameters

    Runs one training experiment per (alpha, beta, ensemble_size)
    combination, stores one result row per experiment in ``self.results``
    and plots the aggregated results.
    """

    def __init__(self, domain_bounds, phys_params):
        """
        Initialize the analyzer

        Parameters:
        -----------
        domain_bounds : dict
            Dictionary with domain bounds
        phys_params : CVDPhysicalParams
            Object containing physical parameters
        """
        self.domain_bounds = domain_bounds
        self.phys_params = phys_params

        # Results storage: parallel lists, one entry per experiment
        self.results = {
            'alpha': [],
            'beta': [],
            'ensemble_size': [],
            'final_loss': [],
            'param_diversity': [],
            'convergence_rate': []
        }

    @staticmethod
    def _snapshot_interval(n_epochs):
        """Return the number of epochs between snapshots (about 10% of the run, never 0)."""
        # n_epochs // 10 is 0 for n_epochs < 10, which would make the
        # modulo test in the training loop divide by zero
        return max(1, n_epochs // 10)

    def run_experiment(self, alpha_values, beta_values, ensemble_sizes, n_epochs=500, n_points=1000):
        """
        Run experiments with different hyperparameter values

        Parameters:
        -----------
        alpha_values : list
            List of alpha values to test
        beta_values : list
            List of beta values to test
        ensemble_sizes : list
            List of ensemble sizes to test
        n_epochs : int
            Number of epochs for each experiment (must be at least 1)
        n_points : int
            Number of training points

        Returns:
        --------
        dict
            Experiment results

        Raises:
        -------
        ValueError
            If ``n_epochs`` is smaller than 1.
        """
        if n_epochs < 1:
            raise ValueError("n_epochs must be at least 1.")

        print("Starting Entropy-Langevin hyperparameter analysis...")
        start_time = time.time()

        # Explorer and result pickles are written below
        os.makedirs('models', exist_ok=True)

        # Create data generator once
        data_generator = CVDDataGenerator(self.domain_bounds)

        # Generate training data once and share it across all experiments
        x_collocation = data_generator.generate_collocation_points(n_points)
        x_collocation = tf.convert_to_tensor(x_collocation, dtype=tf.float32)

        boundary_points = data_generator.generate_boundary_points(n_points // 10)
        boundary_points_tensor = {
            key: tf.convert_to_tensor(values, dtype=tf.float32)
            for key, values in boundary_points.items()
        }

        initial_points = data_generator.generate_initial_points(n_points // 10)
        initial_points_tensor = tf.convert_to_tensor(initial_points, dtype=tf.float32)

        snapshot_interval = self._snapshot_interval(n_epochs)

        # Iterate over all combinations
        total_experiments = len(alpha_values) * len(beta_values) * len(ensemble_sizes)
        experiment_count = 0

        for alpha in alpha_values:
            for beta in beta_values:
                for ensemble_size in ensemble_sizes:
                    experiment_count += 1
                    print(f"\nExperiment {experiment_count}/{total_experiments}: "
                          f"alpha={alpha}, beta={beta}, ensemble_size={ensemble_size}")

                    # Fresh ensemble and trainer for every combination
                    models = create_model_ensemble(num_models=ensemble_size)
                    trainer = EntropyLangevinPINNTrainer(
                        models, self.phys_params, self.domain_bounds,
                        alpha=alpha, beta=beta, learning_rate=1e-3
                    )

                    explorer = LangevinParameterExplorer(models, trainer)

                    # Record initial parameter snapshot (before any training step)
                    explorer.record_parameter_snapshot()

                    print(f"Training for {n_epochs} epochs...")

                    losses_history = []
                    for epoch in range(n_epochs):
                        # Update entropy-Langevin parameters
                        trainer.entropy_reg.update_parameters(epoch, n_epochs)

                        # Perform one training step
                        total_losses, pde_losses, bc_losses, ic_losses = trainer.train_step(
                            epoch, x_collocation, boundary_points_tensor, initial_points_tensor
                        )

                        avg_total_loss = tf.reduce_mean(total_losses).numpy()
                        losses_history.append(avg_total_loss)

                        if (epoch + 1) % snapshot_interval == 0:
                            print(f"Epoch {epoch+1}/{n_epochs}, Loss: {avg_total_loss:.6e}")

                            # Record parameter and loss snapshots together
                            explorer.record_parameter_snapshot()
                            explorer.record_loss_snapshot({
                                'total': total_losses.numpy(),
                                'pde': pde_losses.numpy(),
                                'bc': bc_losses.numpy(),
                                'ic': ic_losses.numpy()
                            })

                    # Measure final loss
                    final_loss = losses_history[-1]

                    # Measure parameter diversity from the latest snapshot
                    param_diversity = np.mean(explorer.parameter_statistics[-1]['std'])

                    # Measure convergence rate (simplified as relative loss reduction)
                    initial_loss = losses_history[0]
                    convergence_rate = (initial_loss - final_loss) / initial_loss if initial_loss > 0 else 0

                    # Store results
                    self.results['alpha'].append(alpha)
                    self.results['beta'].append(beta)
                    self.results['ensemble_size'].append(ensemble_size)
                    self.results['final_loss'].append(final_loss)
                    self.results['param_diversity'].append(param_diversity)
                    self.results['convergence_rate'].append(convergence_rate)

                    # Save parameter explorer for later analysis
                    # NOTE(review): the explorer references the models and the
                    # trainer, so this assumes both are picklable - confirm.
                    with open(f'models/explorer_a{alpha}_b{beta}_e{ensemble_size}.pkl', 'wb') as f:
                        pickle.dump(explorer, f)

        total_time = time.time() - start_time
        print(f"Analysis completed in {total_time:.2f} seconds.")

        # Save results
        with open('models/hyperparameter_analysis.pkl', 'wb') as f:
            pickle.dump(self.results, f)

        return self.results

    def visualize_results(self):
        """
        Visualize hyperparameter analysis results

        If no results are held in memory, they are loaded from
        'models/hyperparameter_analysis.pkl'.

        Raises:
        -------
        ValueError
            If there are no in-memory results and the results file cannot
            be loaded.
        """
        if len(self.results['alpha']) == 0:
            try:
                # NOTE: pickle executes code on load; only load trusted files
                with open('models/hyperparameter_analysis.pkl', 'rb') as f:
                    self.results = pickle.load(f)
            except Exception as err:
                raise ValueError("No results found. Run experiments first.") from err

        # Figures are written below
        os.makedirs('figures', exist_ok=True)

        # Convert to DataFrame for easier analysis
        df = pd.DataFrame(self.results)

        # Visualize the effect of alpha and beta on final loss
        plt.figure(figsize=(12, 10))
        pivot_table = df.pivot_table(
            index='alpha', columns='beta', values='final_loss', aggfunc='mean'
        )
        sns.heatmap(pivot_table, annot=True, cmap='viridis_r', fmt='.2e',
                    cbar_kws={'label': 'Final Loss'})
        plt.title('Effect of Alpha and Beta on Final Loss', fontsize=16)
        plt.xlabel('Beta (Inverse Temperature)', fontsize=14)
        plt.ylabel('Alpha (Entropy Weight)', fontsize=14)
        plt.tight_layout()
        plt.savefig('figures/alpha_beta_loss.png', dpi=300, bbox_inches='tight')
        plt.show()

        # Visualize the effect of alpha and beta on parameter diversity
        plt.figure(figsize=(12, 10))
        pivot_table = df.pivot_table(
            index='alpha', columns='beta', values='param_diversity', aggfunc='mean'
        )
        sns.heatmap(pivot_table, annot=True, cmap='plasma', fmt='.4f',
                    cbar_kws={'label': 'Parameter Diversity'})
        plt.title('Effect of Alpha and Beta on Parameter Diversity', fontsize=16)
        plt.xlabel('Beta (Inverse Temperature)', fontsize=14)
        plt.ylabel('Alpha (Entropy Weight)', fontsize=14)
        plt.tight_layout()
        plt.savefig('figures/alpha_beta_diversity.png', dpi=300, bbox_inches='tight')
        plt.show()

        # Visualize the effect of ensemble size on convergence rate
        plt.figure(figsize=(10, 6))
        sns.barplot(x='ensemble_size', y='convergence_rate', data=df)
        plt.title('Effect of Ensemble Size on Convergence Rate', fontsize=16)
        plt.xlabel('Ensemble Size', fontsize=14)
        plt.ylabel('Convergence Rate', fontsize=14)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig('figures/ensemble_convergence.png', dpi=300, bbox_inches='tight')
        plt.show()

        # Scatter plot of parameter diversity vs final loss
        plt.figure(figsize=(10, 8))
        scatter = plt.scatter(df['param_diversity'], df['final_loss'],
                              c=df['alpha'], s=df['ensemble_size']*10,
                              alpha=0.7, cmap='viridis', edgecolors='k', linewidths=0.5)
        plt.colorbar(scatter, label='Alpha Value')
        plt.xlabel('Parameter Diversity', fontsize=14)
        plt.ylabel('Final Loss', fontsize=14)
        plt.title('Relationship Between Parameter Diversity and Final Loss', fontsize=16)
        plt.grid(True, alpha=0.3)
        plt.xscale('log')
        plt.yscale('log')

        # Add a power-law trend line (linear fit in log-log space). Points
        # with non-positive or non-finite values have no logarithm and are
        # left out of the fit; at least two points are needed for a line.
        diversity = df['param_diversity'].to_numpy(dtype=float)
        loss = df['final_loss'].to_numpy(dtype=float)
        usable = np.isfinite(diversity) & np.isfinite(loss) & (diversity > 0) & (loss > 0)
        if np.count_nonzero(usable) >= 2:
            coeffs = np.polyfit(np.log(diversity[usable]), np.log(loss[usable]), 1)
            x_trend = np.logspace(np.log10(diversity[usable].min()),
                                  np.log10(diversity[usable].max()), 100)
            y_trend = np.exp(np.polyval(coeffs, np.log(x_trend)))
            plt.plot(x_trend, y_trend, 'r--', linewidth=2)

        plt.tight_layout()
        plt.savefig('figures/diversity_loss_relationship.png', dpi=300, bbox_inches='tight')
        plt.show()

#---------------------------------------------------------------------------
# Section 3: Fokker-Planck Analysis of Parameter Distributions
#---------------------------------------------------------------------------

class FokkerPlanckAnalyzer:
    """
    Class for analyzing parameter distributions through Fokker-Planck lens

    Works on the per-snapshot (num_models, num_params) arrays stored in
    ``parameter_explorer.parameter_trajectories``.
    """

    def __init__(self, parameter_explorer):
        """
        Initialize the Fokker-Planck analyzer

        Parameters:
        -----------
        parameter_explorer : LangevinParameterExplorer
            Object with parameter trajectories

        Raises:
        -------
        ValueError
            If the explorer holds no parameter trajectories.
        """
        self.explorer = parameter_explorer

        # Check if we have trajectory data
        if len(self.explorer.parameter_trajectories) == 0:
            raise ValueError("No parameter trajectories available in the explorer.")

    def compute_kernel_density(self, snapshot_idx, param_idx, bandwidth=0.1):
        """
        Compute kernel density estimation for parameter distribution

        A Gaussian kernel density estimate is evaluated directly with NumPy:
        the density at x is the average over the ensemble of a normal pdf
        with standard deviation ``bandwidth`` centred on each model's value.
        (TensorFlow Probability offers no ``KernelDensity`` class with a
        ``fit`` method, so the estimate is computed explicitly.)

        Parameters:
        -----------
        snapshot_idx : int
            Index of the parameter snapshot
        param_idx : int
            Index of the parameter (column of the snapshot) to analyze
        bandwidth : float
            Bandwidth for kernel density estimation (must be positive)

        Returns:
        --------
        tuple
            (x_grid, density_values), two arrays of length 1000

        Raises:
        -------
        ValueError
            If ``bandwidth`` is not positive.
        """
        if bandwidth <= 0:
            raise ValueError("bandwidth must be positive.")

        # Values of one parameter across all models in the chosen snapshot
        snapshot = self.explorer.parameter_trajectories[snapshot_idx]
        param_values = np.asarray(snapshot[:, param_idx], dtype=np.float64)

        # Evaluation grid, padded by two bandwidths on each side
        min_val = np.min(param_values) - 2 * bandwidth
        max_val = np.max(param_values) + 2 * bandwidth
        x_grid = np.linspace(min_val, max_val, 1000)

        # (grid point, sample) matrix of standardized distances
        z = (x_grid[:, np.newaxis] - param_values[np.newaxis, :]) / bandwidth
        kernel_values = np.exp(-0.5 * z ** 2) / (bandwidth * np.sqrt(2.0 * np.pi))
        density = kernel_values.mean(axis=1)

        return x_grid, density

    def visualize_parameter_evolution(self, param_idx=0, n_snapshots=5):
        """
        Visualize evolution of parameter distribution over time

        Parameters:
        -----------
        param_idx : int
            Index of the parameter to analyze
        n_snapshots : int
            Number of snapshots to visualize (fewer are drawn when fewer
            distinct snapshots are available)
        """
        # Select snapshots (evenly spaced); drop duplicates that appear when
        # more snapshots are requested than were recorded
        total_snapshots = len(self.explorer.parameter_trajectories)
        snapshot_indices = np.unique(
            np.linspace(0, total_snapshots - 1, n_snapshots, dtype=int)
        )

        plt.figure(figsize=(12, 8))

        # Colors for different snapshots
        colors = plt.cm.viridis(np.linspace(0, 1, len(snapshot_indices)))

        # Compute and plot KDEs for each snapshot
        for i, idx in enumerate(snapshot_indices):
            x_grid, density = self.compute_kernel_density(idx, param_idx)

            # Scale each curve to a peak of 1 so snapshots are comparable
            scaled_density = density / np.max(density)

            plt.plot(x_grid, scaled_density, '-', color=colors[i], linewidth=2,
                     label=f'Iteration {idx+1}')

            # Mark the actual parameter values, offset vertically per snapshot
            params = self.explorer.parameter_trajectories[idx][:, param_idx]
            plt.scatter(params, np.zeros_like(params) + 0.05*i, color=colors[i],
                        alpha=0.7, s=30, marker='|')

        plt.xlabel('Parameter Value', fontsize=14)
        plt.ylabel('Normalized Density', fontsize=14)
        plt.title('Evolution of Parameter Distribution Over Training', fontsize=16)
        plt.legend(fontsize=12)
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        os.makedirs('figures', exist_ok=True)  # savefig needs the directory
        plt.savefig(f'figures/parameter_evolution_p{param_idx}.png', dpi=300, bbox_inches='tight')
        plt.show()

    def _estimate_drift_diffusion(self, param_idx):
        """
        Estimate drift and diffusion of one parameter between consecutive snapshots.

        Both are expressed per snapshot interval (no division by a time step):
        drift is the ensemble mean of x_{t+1} - x_t, diffusion is the
        ensemble mean of (x_{t+1} - x_t)^2 divided by 2.

        Returns:
        --------
        tuple
            (drift, diffusion), arrays of length n_snapshots - 1
        """
        # (n_snapshots, num_models) values of the chosen parameter
        param_values = np.array([
            snapshot[:, param_idx]
            for snapshot in self.explorer.parameter_trajectories
        ])

        increments = np.diff(param_values, axis=0)
        drift = increments.mean(axis=1)
        diffusion = (increments ** 2).mean(axis=1) / 2
        return drift, diffusion

    def analyze_fokker_planck_dynamics(self, param_idx=0):
        """
        Analyze parameter dynamics through Fokker-Planck equation

        Parameters:
        -----------
        param_idx : int
            Index of the parameter to analyze

        Raises:
        -------
        ValueError
            If fewer than 3 snapshots are available.
        """
        # Need at least 3 snapshots for drift and diffusion estimation
        if len(self.explorer.parameter_trajectories) < 3:
            raise ValueError("Need at least 3 snapshots for drift and diffusion estimation.")

        drift, diffusion = self._estimate_drift_diffusion(param_idx)
        steps = range(1, len(drift) + 1)

        os.makedirs('figures', exist_ok=True)  # savefig needs the directory

        # Plot drift and diffusion against the snapshot interval index
        plt.figure(figsize=(12, 10))

        plt.subplot(2, 1, 1)
        plt.plot(steps, drift, 'b-o', linewidth=2, markersize=6)
        plt.axhline(y=0, color='k', linestyle='--', alpha=0.5)
        plt.xlabel('Training Iteration', fontsize=14)
        plt.ylabel('Drift Term', fontsize=14)
        plt.title('Drift Term in Fokker-Planck Equation', fontsize=16)
        plt.grid(True, alpha=0.3)

        plt.subplot(2, 1, 2)
        plt.plot(steps, diffusion, 'r-o', linewidth=2, markersize=6)
        plt.xlabel('Training Iteration', fontsize=14)
        plt.ylabel('Diffusion Term', fontsize=14)
        plt.title('Diffusion Term in Fokker-Planck Equation', fontsize=16)
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'figures/fokker_planck_p{param_idx}.png', dpi=300, bbox_inches='tight')
        plt.show()

        # Plot drift and diffusion together to show relationship
        plt.figure(figsize=(10, 8))
        plt.scatter(drift, diffusion, c=np.arange(len(drift)), cmap='viridis',
                    s=80, alpha=0.8, edgecolors='k', linewidths=0.5)
        plt.colorbar(label='Training Iteration')
        plt.xlabel('Drift Term', fontsize=14)
        plt.ylabel('Diffusion Term', fontsize=14)
        plt.title('Relationship Between Drift and Diffusion in Parameter Space', fontsize=16)
        plt.grid(True, alpha=0.3)

        # Add trend line (a line fit needs at least two points)
        if len(drift) > 1:
            coeffs = np.polyfit(drift, diffusion, 1)
            sorted_drift = np.sort(drift)
            plt.plot(sorted_drift, np.polyval(coeffs, sorted_drift), 'r--', linewidth=2)

            # Add correlation coefficient
            corr_coef = np.corrcoef(drift, diffusion)[0, 1]
            plt.annotate(f'Correlation: {corr_coef:.4f}',
                         xy=(0.05, 0.95), xycoords='axes fraction',
                         fontsize=12, bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="k", alpha=0.8))

        plt.tight_layout()
        plt.savefig(f'figures/drift_diffusion_relation_p{param_idx}.png', dpi=300, bbox_inches='tight')
        plt.show()

#---------------------------------------------------------------------------
# Section 4: Demo and Testing
#---------------------------------------------------------------------------

def demo_langevin_analysis(load_existing=True):
    """
    Demonstrate Langevin dynamics and entropy regularization analysis

    Either loads a previously pickled parameter explorer or trains a small
    ensemble to produce one, then runs the parameter-space, Fokker-Planck
    and hyperparameter visualizations. If loading is requested but fails,
    the function falls back to running new experiments.

    Parameters:
    -----------
    load_existing : bool
        Whether to load existing data or run new experiments
    """
    # Define domain bounds
    domain_bounds = {
        'x_min': 0.0,
        'x_max': 0.1,
        'y_min': 0.0,
        'y_max': 0.05,
        't_min': 0.0,
        't_max': 10.0
    }

    # Create physical parameters
    phys_params = CVDPhysicalParams()

    # Output directories used by this demo and the analysis classes
    os.makedirs('models', exist_ok=True)
    os.makedirs('figures', exist_ok=True)

    # Becomes True when no usable saved data exists; it also controls whether
    # the hyperparameter analysis is re-run further below.
    run_new_experiment = not load_existing

    if load_existing:
        try:
            # Load explorer from file
            # NOTE: pickle executes code on load; only load trusted files
            with open('models/explorer_example.pkl', 'rb') as f:
                explorer = pickle.load(f)
            print("Loaded existing parameter explorer data.")
        except Exception:
            # Missing or unreadable file: fall back to generating data.
            # (Exception, not a bare except, so Ctrl-C still interrupts.)
            print("No existing explorer data found. Running a small experiment to generate data.")
            run_new_experiment = True

    if run_new_experiment:
        # Create models
        ensemble_size = 5
        models = create_model_ensemble(num_models=ensemble_size)

        # Create trainer
        trainer = EntropyLangevinPINNTrainer(
            models, phys_params, domain_bounds,
            alpha=0.2, beta=5.0, learning_rate=1e-3
        )

        # Create parameter explorer
        explorer = LangevinParameterExplorer(models, trainer)

        # Run a small training experiment
        print("Running a small training experiment to collect parameter data...")
        n_epochs = 200

        # Create data generator
        data_generator = CVDDataGenerator(domain_bounds)

        # Generate training data
        x_collocation = data_generator.generate_collocation_points(1000)
        x_collocation = tf.convert_to_tensor(x_collocation, dtype=tf.float32)

        boundary_points = data_generator.generate_boundary_points(100)
        boundary_points_tensor = {
            key: tf.convert_to_tensor(values, dtype=tf.float32)
            for key, values in boundary_points.items()
        }

        initial_points = data_generator.generate_initial_points(100)
        initial_points_tensor = tf.convert_to_tensor(initial_points, dtype=tf.float32)

        # Record initial parameter snapshot (before any training step)
        explorer.record_parameter_snapshot()

        # Train and record parameter snapshots
        for epoch in range(n_epochs):
            # Update entropy-Langevin parameters
            trainer.entropy_reg.update_parameters(epoch, n_epochs)

            # Perform one training step
            total_losses, pde_losses, bc_losses, ic_losses = trainer.train_step(
                epoch, x_collocation, boundary_points_tensor, initial_points_tensor
            )

            avg_total_loss = tf.reduce_mean(total_losses).numpy()

            if (epoch + 1) % 20 == 0:
                print(f"Epoch {epoch+1}/{n_epochs}, Loss: {avg_total_loss:.6e}")

                # Record parameter and loss snapshots together
                explorer.record_parameter_snapshot()
                explorer.record_loss_snapshot({
                    'total': total_losses.numpy(),
                    'pde': pde_losses.numpy(),
                    'bc': bc_losses.numpy(),
                    'ic': ic_losses.numpy()
                })

        # Save explorer for future use
        with open('models/explorer_example.pkl', 'wb') as f:
            pickle.dump(explorer, f)

    # Analyze parameter space exploration
    print("\nAnalyzing parameter space exploration...")
    explorer.visualize_parameter_diversity(method='pca')
    explorer.visualize_parameter_diversity(method='tsne')
    explorer.plot_parameter_statistics()
    explorer.plot_loss_correlation()

    # Analyze parameter distributions with Fokker-Planck (best effort: a
    # failure here is reported and the demo continues)
    try:
        print("\nAnalyzing parameter distributions with Fokker-Planck...")
        fp_analyzer = FokkerPlanckAnalyzer(explorer)
        fp_analyzer.visualize_parameter_evolution(param_idx=0)
        fp_analyzer.visualize_parameter_evolution(param_idx=100)
        fp_analyzer.analyze_fokker_planck_dynamics(param_idx=0)
    except Exception as e:
        print(f"Error in Fokker-Planck analysis: {e}")

    # Hyperparameter analysis
    analyzer = EntropyLangevinAnalyzer(domain_bounds, phys_params)
    if run_new_experiment:
        print("\nRunning hyperparameter analysis (limited scope for demo)...")

        # Run very limited experiments for demo purposes
        analyzer.run_experiment(
            alpha_values=[0.1, 0.2],
            beta_values=[5.0, 10.0],
            ensemble_sizes=[3, 5],
            n_epochs=100,
            n_points=500
        )

        analyzer.visualize_results()
    else:
        try:
            # Load and plot existing results
            analyzer.visualize_results()
        except ValueError:
            # visualize_results raises ValueError when no saved results exist
            print("No existing hyperparameter analysis results found.")
        except Exception as e:
            # Keep the demo best-effort, but report what actually failed
            print(f"Error in hyperparameter visualization: {e}")

# When run as a script, only print usage instructions; the demo itself is not
# started here and must be called explicitly (to avoid long computation times).
if __name__ == "__main__":
    print("This notebook analyzes Langevin dynamics and entropy regularization for PINNs.")
    print("To run the demo, execute: demo_langevin_analysis(load_existing=True)")