In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
import os
import time
import pickle
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.spatial import KDTree

import import_ipynb
import Mathematical_Framework_Setup as nb1
import Entropy_Langevin_Training as nb2





TensorFlow version: 2.19.0
TensorFlow Probability version: 0.25.0
No GPUs available, using CPU

Physical parameters initialized:
Diffusion coefficients: D_SiH4 = 1e-05, D_Si = 5e-06, D_H2 = 4e-05, D_SiH2 = 1.5e-05
Thermal parameters: k = 0.1, Cp = 700.0, ρ = 1.0
Reaction parameters: A1 = 1000000.0, E1 = 150000.0, A2 = 200000.0, E2 = 120000.0, A3 = 300000.0, E3 = 100000.0
Gas constant: R = 8.314

Domain bounds:
x_min = 0.0
x_max = 0.1
y_min = 0.0
y_max = 0.05
t_min = 0.0
t_max = 10.0

Test input shape: (5, 3)
Test output shape: (5, 5)

Computed gradients:
dy_dx shape: (5, 5, 3)
y_x shape: (5, 5)
y_y shape: (5, 5)
y_t shape: (5, 5)
y_xx shape: (5, 5)
y_yy shape: (5, 5)

Computed residuals:
Residual 1 shape: (5, 1)
Residual 2 shape: (5, 1)
Residual 3 shape: (5, 1)
Residual 4 shape: (5, 1)
Residual 5 shape: (5, 1)

Generated collocation points shape: (10, 3)
Generated boundary points:
inlet shape: (5, 3)
substrate shape: (5, 3)
left_wall shape: (5, 3)
right_wall shape: (5, 3)
Generated 

In [2]:

# Global random seed so that repeated runs reproduce the same results:
# numpy's global RNG drives every point sampler below (np.random.uniform /
# np.random.choice), and TensorFlow's global seed covers TF random ops.
SEED = 1234
np.random.seed(SEED)
tf.random.set_seed(SEED)

#---------------------------------------------------------------------------
# Section 1: Residual-Based Adaptive Sampling
#---------------------------------------------------------------------------

class ResidualBasedSampler:
    """
    Residual-based adaptive sampling of collocation points.

    Keeps a database of candidate (x, y, t) points together with the PDE
    residual magnitude at each point, and concentrates sampling on regions
    where that residual is large.

    Invariant: ``self.points`` has shape (n, 3) and ``self.residuals`` is a
    1-D array of length n aligned with it (both ``None`` before
    initialisation).
    """

    def __init__(self, model, pde_calculator, domain_bounds, max_points=10000):
        """
        Initialize the residual-based sampler.

        Parameters
        ----------
        model : PINN or list of PINN
            Model (or ensemble of models) used to evaluate the PDE residual.
            Each model is called on a point tensor and must expose
            ``get_gradients(points, y_pred)``.
        pde_calculator : CVDPDE
            Object whose ``compute_residuals(points, y_pred, derivatives)``
            returns the per-equation residuals.
        domain_bounds : dict
            Domain bounds with keys 'x_min', 'x_max', 'y_min', 'y_max',
            't_min', 't_max'.
        max_points : int
            Maximum number of points kept in the database; when exceeded,
            only the highest-residual points are retained.
        """
        self.model = model
        self.is_ensemble = isinstance(model, list)
        self.pde_calculator = pde_calculator
        self.domain_bounds = domain_bounds
        self.max_points = max_points

        # Point database (n, 3) and matching 1-D residual magnitudes (n,)
        self.points = None
        self.residuals = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _random_points(self, n_points):
        """Draw n_points uniformly from the (x, y, t) box; returns shape (n_points, 3)."""
        b = self.domain_bounds
        # Draw order x, y, t is kept fixed so seeded runs stay reproducible
        x = np.random.uniform(b['x_min'], b['x_max'], n_points)
        y = np.random.uniform(b['y_min'], b['y_max'], n_points)
        t = np.random.uniform(b['t_min'], b['t_max'], n_points)
        return np.stack([x, y, t], axis=1)

    def _flat_residuals(self):
        """Return the stored residuals as a 1-D float array aligned with self.points."""
        return np.ravel(np.asarray(self.residuals, dtype=float))

    def _top_k_indices(self, k):
        """Indices of the k largest residuals in ascending residual order (empty if k <= 0)."""
        residuals = self._flat_residuals()
        k = int(np.clip(k, 0, len(residuals)))
        # Slice from len - k rather than -k: `[-0:]` would select everything
        return np.argsort(residuals)[len(residuals) - k:]

    def _sampling_probs(self):
        """Residuals normalised into a probability vector; uniform if they cannot be normalised."""
        residuals = self._flat_residuals()
        total = np.sum(residuals)
        if not np.isfinite(total) or total <= 0:
            # All-zero or NaN/inf residuals do not define a distribution
            return np.full(len(residuals), 1.0 / len(residuals))
        return residuals / total

    def _append_and_cap(self, new_points, new_residuals):
        """Append points/residuals to the database, then keep only the top max_points by residual."""
        new_residuals = np.ravel(np.asarray(new_residuals, dtype=float))
        if self.points is None:
            self.points = new_points
            self.residuals = new_residuals
        else:
            self.points = np.vstack([self.points, new_points])
            self.residuals = np.concatenate([self._flat_residuals(), new_residuals])

        if len(self.points) > self.max_points:
            keep = self._top_k_indices(self.max_points)
            self.points = self.points[keep]
            self.residuals = self._flat_residuals()[keep]

    def _residual_magnitude(self, model, points):
        """Pointwise L2 norm over all PDE residual components for a single model."""
        y_pred = model(points)
        derivatives = model.get_gradients(points, y_pred)
        residuals = self.pde_calculator.compute_residuals(points, y_pred, derivatives)
        return tf.sqrt(tf.add_n([tf.square(r) for r in residuals]))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def compute_residuals(self, points):
        """
        Compute the PDE residual magnitude at the given points.

        Parameters
        ----------
        points : np.ndarray or tf.Tensor
            Points of shape (n, 3) holding (x, y, t). numpy input is cast to
            float32 tensors.

        Returns
        -------
        np.ndarray
            1-D array of length n with the residual magnitude at each point.
            For an ensemble this is the mean magnitude over all models.
        """
        if isinstance(points, np.ndarray):
            points = tf.convert_to_tensor(points, dtype=tf.float32)

        models = self.model if self.is_ensemble else [self.model]
        magnitudes = [self._residual_magnitude(m, points) for m in models]

        if self.is_ensemble:
            magnitude = tf.reduce_mean(tf.stack(magnitudes, axis=0), axis=0)
        else:
            magnitude = magnitudes[0]

        # Per-equation residuals are (n, 1); flatten so that argsort and
        # np.random.choice(p=...) operate per point rather than per column.
        return np.reshape(magnitude.numpy(), -1)

    def initialize_points(self, n_points):
        """
        Replace the database with n_points uniformly random points and their residuals.

        Parameters
        ----------
        n_points : int
            Number of initial points.
        """
        self.points = self._random_points(n_points)
        self.residuals = self.compute_residuals(self.points)

        print(f"Initialized {n_points} points with random sampling")

    def get_training_points(self, n_points, method='residual_weighted'):
        """
        Select training points from the database.

        Parameters
        ----------
        n_points : int
            Number of points to select.
        method : str
            'residual_weighted' - sample with replacement, probability
            proportional to residual (uniform if all residuals are zero);
            'top_residual' - the n_points highest-residual points;
            'mixed' - half residual-weighted, half uniform, with replacement.

        Returns
        -------
        np.ndarray
            Selected points, shape (n_selected, 3).

        Raises
        ------
        ValueError
            If the database is not initialised or ``method`` is unknown.
        """
        if self.points is None:
            raise ValueError("Points not initialized. Call initialize_points() first.")

        n_total = len(self.points)

        if method == 'residual_weighted':
            # Replacement allowed for true importance sampling
            indices = np.random.choice(n_total, size=n_points, replace=True,
                                       p=self._sampling_probs())
        elif method == 'top_residual':
            indices = self._top_k_indices(n_points)
        elif method == 'mixed':
            n_random = n_points // 2
            n_residual = n_points - n_random
            residual_indices = np.random.choice(n_total, size=n_residual, replace=True,
                                                p=self._sampling_probs())
            random_indices = np.random.choice(n_total, size=n_random, replace=True)
            indices = np.concatenate([residual_indices, random_indices])
        else:
            raise ValueError(f"Unknown sampling method: {method}")

        return self.points[indices]

    def update_points(self, n_new_points, n_keep=None):
        """
        Add uniformly random points to the database.

        Parameters
        ----------
        n_new_points : int
            Number of new random points to generate.
        n_keep : int or None
            If given, first prune the existing database to its n_keep
            highest-residual points (0 discards all of them). None keeps all.

        The combined database is then capped at ``max_points``.
        """
        new_points = self._random_points(n_new_points)
        new_residuals = self.compute_residuals(new_points)

        if self.points is not None and n_keep is not None:
            keep = self._top_k_indices(n_keep)
            self.points = self.points[keep]
            self.residuals = self._flat_residuals()[keep]

        self._append_and_cap(new_points, new_residuals)

        print(f"Updated point database. Now contains {len(self.points)} points.")

    def refine_near_high_residuals(self, n_refine, n_per_point=10, refine_radius=0.01):
        """
        Add points sampled uniformly in a box around the highest-residual points.

        Parameters
        ----------
        n_refine : int
            Number of highest-residual points to refine around (0 = no-op).
        n_per_point : int
            Number of new points generated around each selected point.
        refine_radius : float or sequence of 3 floats
            Half-width of the sampling box; a scalar applies to x, y and t
            alike, a sequence gives per-dimension (x, y, t) half-widths.
            New points are clipped to the domain bounds.

        Raises
        ------
        ValueError
            If the database is not initialised.
        """
        if self.points is None:
            raise ValueError("Points not initialized. Call initialize_points() first.")

        centres = self.points[self._top_k_indices(n_refine)]
        if len(centres) == 0 or n_per_point <= 0:
            print(f"No refinement points added. Now contains {len(self.points)} points.")
            return

        b = self.domain_bounds
        radius = np.broadcast_to(np.asarray(refine_radius, dtype=float), (3,))
        lower = np.array([b['x_min'], b['y_min'], b['t_min']])
        upper = np.array([b['x_max'], b['y_max'], b['t_max']])

        refined = []
        for centre in centres:
            # One draw per dimension (x, y, t), in that order, per centre
            offsets = np.stack(
                [np.random.uniform(-r, r, n_per_point) for r in radius], axis=1
            )
            refined.append(np.clip(centre + offsets, lower, upper))
        new_points = np.vstack(refined)

        new_residuals = self.compute_residuals(new_points)
        self._append_and_cap(new_points, new_residuals)

        print(f"Added {len(new_points)} refined points. Now contains {len(self.points)} points.")

    def visualize_residuals(self, t_idx=5, nx=50, ny=50, nt=10):
        """
        Plot the residual magnitude on a uniform (x, y) grid at one time level.

        Parameters
        ----------
        t_idx : int
            Index into ``np.linspace(t_min, t_max, nt)`` selecting the time level.
        nx, ny, nt : int
            Number of grid points in x, y and t.

        Database points within one time step of the slice whose residual is
        above the slice's 90th percentile are overlaid as crosses. The figure
        is saved to ``figures/residuals_t{t_idx}.png`` (directory created if
        needed) and shown.
        """
        b = self.domain_bounds
        x = np.linspace(b['x_min'], b['x_max'], nx)
        y = np.linspace(b['y_min'], b['y_max'], ny)
        t = np.linspace(b['t_min'], b['t_max'], nt)
        X, Y = np.meshgrid(x, y, indexing='ij')
        time_val = t[t_idx]

        grid_points = np.stack([X.ravel(), Y.ravel(), np.full(X.size, time_val)], axis=1)
        residual_map = self.compute_residuals(grid_points).reshape(X.shape)

        fig, ax = plt.subplots(figsize=(10, 8))
        contour = ax.contourf(X, Y, residual_map, 50, cmap='hot')
        fig.colorbar(contour, ax=ax, label='Residual Magnitude')

        if self.points is not None and len(self.points) > 0:
            # Time window of one grid step around the slice (everything if nt == 1)
            window = t[1] - t[0] if nt > 1 else np.inf
            time_mask = np.abs(self.points[:, 2] - time_val) < window
            if np.any(time_mask):
                slice_points = self.points[time_mask]
                slice_values = self._flat_residuals()[time_mask]
                top = slice_values > np.percentile(slice_values, 90)
                if np.any(top):
                    ax.scatter(slice_points[top, 0], slice_points[top, 1],
                               c='k', marker='x', s=40, label='High Residual Points')
                    # Only add a legend when there is a labelled artist
                    ax.legend()

        ax.set_xlabel('x (m)')
        ax.set_ylabel('y (m)')
        ax.set_title(f'Residual Magnitude at t = {time_val:.2f}s')
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        os.makedirs('figures', exist_ok=True)
        fig.savefig(f'figures/residuals_t{t_idx}.png', dpi=300, bbox_inches='tight')
        plt.show()


In [3]:
#---------------------------------------------------------------------------
# Section 2: RAR-Net (Residual-Adaptive Refinement Network)
#---------------------------------------------------------------------------

class RARNet:
    """
    Residual-Adaptive Refinement Network (RAR-Net).

    PINN training strategy that alternates between training the model(s) and
    refining a database of candidate points around regions of high PDE
    residual, using ``ResidualBasedSampler``.
    """

    def __init__(self, models, phys_params, domain_bounds, use_entropy_langevin=True):
        """
        Initialize the RAR-Net.

        Parameters
        ----------
        models : list or PINN
            The PINN model (or list of models forming an ensemble) to train.
        phys_params : CVDPhysicalParams
            Object containing physical parameters.
        domain_bounds : dict
            Domain bounds with keys 'x_min', 'x_max', 'y_min', 'y_max',
            't_min', 't_max'.
        use_entropy_langevin : bool
            For an ensemble: train jointly with Entropy-Langevin dynamics
            (True) or with one traditional trainer per model (False).
        """
        self.models = models
        self.is_ensemble = isinstance(models, list)
        self.num_models = len(models) if self.is_ensemble else 1
        self.phys_params = phys_params
        self.domain_bounds = domain_bounds
        self.use_entropy_langevin = use_entropy_langevin

        # PDE residual calculator and point generator from the framework notebook
        self.pde_calculator = nb1.CVDPDE(phys_params)
        self.data_generator = nb1.CVDDataGenerator(domain_bounds)

        # Residual-based sampler sharing the same model(s)
        self.sampler = ResidualBasedSampler(
            models, self.pde_calculator, domain_bounds
        )

        if self.is_ensemble and use_entropy_langevin:
            # Single joint trainer for the whole ensemble
            self.trainer = nb2.EntropyLangevinPINNTrainer(
                models, phys_params, domain_bounds
            )
        elif self.is_ensemble:
            # Independent traditional trainer per ensemble member
            self.trainer = [
                nb2.TraditionalPINNTrainer(model, phys_params, domain_bounds)
                for model in models
            ]
        else:
            self.trainer = nb2.TraditionalPINNTrainer(
                models, phys_params, domain_bounds
            )

        # 'total' is a flat list; for ensembles 'pde'/'bc'/'ic' hold one list
        # per model, for a single model they are flat lists.
        self.loss_history = {
            'total': [],
            'pde': [],
            'bc': [],
            'ic': []
        }

        self.refinement_history = {
            'iteration': [],
            'num_points': [],
            'max_residual': [],
            'avg_residual': []
        }

    def _extend_model_losses(self, model_idx, pde, bc, ic):
        """Append one ensemble member's per-epoch component losses to its own history list."""
        for key, values in (('pde', pde), ('bc', bc), ('ic', ic)):
            per_model = self.loss_history[key]
            if len(per_model) <= model_idx:
                # Copy so later extends never mutate the list the trainer returned.
                # Assumes members are recorded in index order 0..num_models-1.
                per_model.append(list(values))
            else:
                per_model[model_idx].extend(values)

    @staticmethod
    def _refinement_schedule(n_database, refinement_fraction, refinement_points):
        """
        Split a refinement budget into (n_refine centres, n_per_point samples each).

        n_refine is ``int(refinement_fraction * n_database)`` clamped to
        [1, refinement_points], so the division below is always defined and
        the number of added points (n_refine * n_per_point) stays close to
        ``refinement_points``. Returns (0, 0) when there is nothing to do.
        """
        if refinement_points <= 0 or n_database <= 0:
            return 0, 0
        n_refine = int(refinement_fraction * n_database)
        n_refine = max(1, min(n_refine, refinement_points))
        n_per_point = max(1, int(np.ceil(refinement_points / n_refine)))
        return n_refine, n_per_point

    def train_with_adaptive_refinement(self, n_iterations=5, n_epochs_per_iter=1000,
                                       initial_points=5000, refinement_points=1000,
                                       refinement_fraction=0.2, sampling_method='mixed'):
        """
        Train the PINN(s) with adaptive residual-based refinement.

        Parameters
        ----------
        n_iterations : int
            Number of train/refine iterations.
        n_epochs_per_iter : int
            Number of training epochs per iteration.
        initial_points : int
            Size of the initial point database and of the training set drawn
            each iteration; one tenth of it is used for boundary and initial
            condition points.
        refinement_points : int
            Target number of points added per refinement step (the actual
            count is n_refine * ceil(refinement_points / n_refine)).
        refinement_fraction : float
            Fraction of the database (highest residuals first) used as
            refinement centres, capped at ``refinement_points``.
        sampling_method : str
            Method passed to ``ResidualBasedSampler.get_training_points``.

        Returns
        -------
        dict
            {'loss_history': ..., 'refinement_history': ...}

        Side effects: saves the trained model(s) under ``models/`` and the
        residual figures under ``figures/`` (directories created if needed).
        """
        print("Starting RAR-Net training with adaptive refinement...")
        start_time = time.time()

        # Step 1: seed the point database with uniformly random points
        self.sampler.initialize_points(initial_points)

        # Boundary and initial-condition points (fixed across iterations).
        # NOTE(review): these tensors and `training_points_tensor` below are
        # never passed to the trainer -- train() only receives a point count,
        # so the residual-adaptive selection presumably does not reach the
        # optimiser. TODO: pass them once the trainer API accepts points.
        boundary_points = self.data_generator.generate_boundary_points(initial_points // 10)
        boundary_points_tensor = {
            key: tf.convert_to_tensor(value, dtype=tf.float32)
            for key, value in boundary_points.items()
        }
        initial_points_data = self.data_generator.generate_initial_points(initial_points // 10)
        initial_points_tensor = tf.convert_to_tensor(initial_points_data, dtype=tf.float32)

        # Avoid print_frequency == 0 (modulo-by-zero) for short runs
        print_frequency = max(1, n_epochs_per_iter // 10)

        for iteration in range(n_iterations):
            print(f"\n===== Refinement Iteration {iteration+1}/{n_iterations} =====")

            # Step 2: sample training points based on residuals
            training_points = self.sampler.get_training_points(
                initial_points, method=sampling_method
            )
            training_points_tensor = tf.convert_to_tensor(training_points, dtype=tf.float32)
            n_train = len(training_points)

            # Step 3: train the model(s)
            print(f"Training with {n_train} points...")

            if self.is_ensemble and self.use_entropy_langevin:
                history = self.trainer.train(
                    n_epochs=n_epochs_per_iter,
                    n_collocation_points=n_train,
                    print_frequency=print_frequency
                )
                self.loss_history['total'].extend(history['total'])
                # Component losses are indexed per model in the joint trainer
                for i in range(self.num_models):
                    self._extend_model_losses(
                        i, history['pde'][i], history['bc'][i], history['ic'][i]
                    )

            elif self.is_ensemble:
                for i, single_trainer in enumerate(self.trainer):
                    print(f"Training model {i+1}/{self.num_models}...")
                    history = single_trainer.train(
                        n_epochs=n_epochs_per_iter,
                        n_collocation_points=n_train,
                        print_frequency=print_frequency
                    )
                    # Only the first member's total loss is tracked
                    if i == 0:
                        self.loss_history['total'].extend(history['total'])
                    self._extend_model_losses(i, history['pde'], history['bc'], history['ic'])

            else:
                history = self.trainer.train(
                    n_epochs=n_epochs_per_iter,
                    n_collocation_points=n_train,
                    print_frequency=print_frequency
                )
                for key in ('total', 'pde', 'bc', 'ic'):
                    self.loss_history[key].extend(history[key])

            # Step 4: re-evaluate residuals of the whole database with the updated model(s)
            self.sampler.residuals = self.sampler.compute_residuals(self.sampler.points)

            self.refinement_history['iteration'].append(iteration + 1)
            self.refinement_history['num_points'].append(len(self.sampler.points))
            self.refinement_history['max_residual'].append(np.max(self.sampler.residuals))
            self.refinement_history['avg_residual'].append(np.mean(self.sampler.residuals))

            self.sampler.visualize_residuals(t_idx=5)

            # Step 5: add points around the highest-residual locations
            n_refine, n_per_point = self._refinement_schedule(
                len(self.sampler.points), refinement_fraction, refinement_points
            )
            if n_refine > 0:
                self.sampler.refine_near_high_residuals(
                    n_refine=n_refine,
                    n_per_point=n_per_point
                )

        total_time = time.time() - start_time
        print(f"RAR-Net training completed in {total_time:.2f} seconds.")

        os.makedirs("models", exist_ok=True)
        if self.is_ensemble and self.use_entropy_langevin:
            self.trainer.save_models("models/rarnet_entropy_langevin")
        elif self.is_ensemble:
            for i, single_trainer in enumerate(self.trainer):
                single_trainer.save_model(f"models/rarnet_traditional_ensemble_{i}.h5")
        else:
            self.trainer.save_model("models/rarnet_traditional.h5")

        return {
            'loss_history': self.loss_history,
            'refinement_history': self.refinement_history
        }

    def plot_refinement_history(self):
        """
        Plot database size (left axis) and max/avg residual (right axis) per iteration.

        Saves the figure to ``figures/refinement_history.png`` (directory
        created if needed) and shows it.
        """
        iterations = self.refinement_history['iteration']
        fig, ax1 = plt.subplots(figsize=(12, 6))

        color = 'tab:blue'
        ax1.set_xlabel('Refinement Iteration', fontsize=14)
        ax1.set_ylabel('Number of Points', color=color, fontsize=14)
        # Labelled so the combined legend pairs each handle with its own label
        ax1.plot(iterations, self.refinement_history['num_points'],
                 'o-', color=color, linewidth=2, markersize=8, label='Number of Points')
        ax1.tick_params(axis='y', labelcolor=color)
        ax1.grid(True, alpha=0.3)

        ax2 = ax1.twinx()
        color = 'tab:red'
        ax2.set_ylabel('Residual', color=color, fontsize=14)
        ax2.plot(iterations, self.refinement_history['max_residual'],
                 's--', color='tab:red', linewidth=2, markersize=8, label='Max Residual')
        ax2.plot(iterations, self.refinement_history['avg_residual'],
                 '^--', color='tab:orange', linewidth=2, markersize=8, label='Avg Residual')
        ax2.tick_params(axis='y', labelcolor=color)

        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper left', fontsize=12)

        ax1.set_title('Adaptive Refinement Progress', fontsize=16)

        fig.tight_layout()
        os.makedirs('figures', exist_ok=True)
        fig.savefig('figures/refinement_history.png', dpi=300, bbox_inches='tight')
        plt.show()

    def evaluate_on_grid(self, nx=50, ny=50, nt=10):
        """
        Evaluate the PDE residual on a uniform grid and report statistics.

        Parameters
        ----------
        nx, ny, nt : int
            Number of grid points in x, y and t.

        Returns
        -------
        dict
            'max_residual', 'mean_residual', 'median_residual',
            'std_residual' and 'max_location' (x, y, t of the maximum).
        """
        grid_points, grid_shape = self.data_generator.generate_uniform_grid(nx, ny, nt)
        grid_tensor = tf.convert_to_tensor(grid_points, dtype=tf.float32)
        grid_residuals = np.reshape(self.sampler.compute_residuals(grid_tensor), grid_shape)

        max_residual = np.max(grid_residuals)
        mean_residual = np.mean(grid_residuals)
        median_residual = np.median(grid_residuals)
        std_residual = np.std(grid_residuals)

        # NOTE(review): assumes grid_shape == (nx, ny, nt) in 'ij' order so the
        # unravelled index maps onto these linspace axes -- confirm against
        # CVDDataGenerator.generate_uniform_grid.
        b = self.domain_bounds
        max_idx = np.unravel_index(np.argmax(grid_residuals), grid_residuals.shape)
        max_x = np.linspace(b['x_min'], b['x_max'], nx)[max_idx[0]]
        max_y = np.linspace(b['y_min'], b['y_max'], ny)[max_idx[1]]
        max_t = np.linspace(b['t_min'], b['t_max'], nt)[max_idx[2]]

        error_stats = {
            'max_residual': max_residual,
            'mean_residual': mean_residual,
            'median_residual': median_residual,
            'std_residual': std_residual,
            'max_location': (max_x, max_y, max_t)
        }

        print("\nError Statistics on Uniform Grid:")
        print(f"Max Residual: {max_residual:.6e}")
        print(f"Mean Residual: {mean_residual:.6e}")
        print(f"Median Residual: {median_residual:.6e}")
        print(f"Std Residual: {std_residual:.6e}")
        print(f"Location of Max Residual: x={max_x:.4f}, y={max_y:.4f}, t={max_t:.4f}")

        return error_stats

    def visualize_solution(self, output_idx=0, t_idx=5, nx=50, ny=50, nt=10):
        """
        Plot one model output on a uniform (x, y) grid at one time level.

        Parameters
        ----------
        output_idx : int
            Output index (0: SiH4, 1: Si, 2: H2, 3: SiH2, 4: T).
        t_idx : int
            Time index into ``np.linspace(t_min, t_max, nt)``.
        nx, ny, nt : int
            Number of grid points in x, y and t.

        For an ensemble, the mean and standard deviation across members are
        shown side by side. The figure is saved to
        ``figures/rarnet_{name}_t{t_idx}.png`` and shown.
        """
        species_names = ["SiH4", "Si", "H2", "SiH2", "Temperature"]
        name = species_names[output_idx]

        grid_points, grid_shape = self.data_generator.generate_uniform_grid(nx, ny, nt)
        grid_tensor = tf.convert_to_tensor(grid_points, dtype=tf.float32)

        b = self.domain_bounds
        x = np.linspace(b['x_min'], b['x_max'], nx)
        y = np.linspace(b['y_min'], b['y_max'], ny)
        t = np.linspace(b['t_min'], b['t_max'], nt)
        X, Y = np.meshgrid(x, y, indexing='ij')
        time_val = t[t_idx]

        def time_slice(values):
            # Reshape flat (n_points, n_outputs) predictions onto the grid and
            # pick the requested time level and output
            return values.reshape(*grid_shape, -1)[:, :, t_idx, output_idx]

        if self.is_ensemble:
            predictions = np.stack([model(grid_tensor).numpy() for model in self.models], axis=0)
            mean_slice = time_slice(np.mean(predictions, axis=0))
            std_slice = time_slice(np.std(predictions, axis=0))

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

            cf1 = ax1.contourf(X, Y, mean_slice, 50, cmap='viridis')
            fig.colorbar(cf1, ax=ax1, label=name)
            ax1.set_xlabel('x (m)')
            ax1.set_ylabel('y (m)')
            ax1.set_title(f'Mean {name} at t = {time_val:.2f}s')

            cf2 = ax2.contourf(X, Y, std_slice, 50, cmap='plasma')
            fig.colorbar(cf2, ax=ax2, label=f'Std Dev of {name}')
            ax2.set_xlabel('x (m)')
            ax2.set_ylabel('y (m)')
            ax2.set_title(f'Uncertainty in {name} at t = {time_val:.2f}s')
        else:
            pred_slice = time_slice(self.models(grid_tensor).numpy())

            fig, ax = plt.subplots(figsize=(10, 8))
            cf = ax.contourf(X, Y, pred_slice, 50, cmap='viridis')
            fig.colorbar(cf, ax=ax, label=name)
            ax.set_xlabel('x (m)')
            ax.set_ylabel('y (m)')
            ax.set_title(f'{name} at t = {time_val:.2f}s')

        fig.tight_layout()
        os.makedirs('figures', exist_ok=True)
        fig.savefig(f'figures/rarnet_{name}_t{t_idx}.png', dpi=300, bbox_inches='tight')
        plt.show()



In [4]:
#---------------------------------------------------------------------------
# Section 3: Demo and Testing
#---------------------------------------------------------------------------

def demo_adaptive_sampling(train_models=True, use_entropy_langevin=True):
    """
    Demonstrate RAR-Net with adaptive sampling.

    Builds either an ensemble of PINN models (Entropy-Langevin mode) or a
    single PINN model, wraps it in a ``RARNet``, then either trains it with
    adaptive refinement or loads previously saved weights and training
    history. Finally it plots the refinement history, evaluates on a grid,
    and visualizes outputs 0 (SiH4), 1 (Si) and 4 (temperature).

    Parameters:
    -----------
    train_models : bool
        Whether to train new models (True) or load existing ones (False).
    use_entropy_langevin : bool
        Whether to use Entropy-Langevin for ensemble training (True) or a
        single traditionally trained model (False).

    Notes
    -----
    * The training history is pickled to ``models/rarnet_history.pkl``.
    * Figures produced by ``RARNet.visualize_solution`` are written under
      ``figures/``; both directories are created if they do not exist.
    * Returns early (None) if the saved history cannot be loaded.
    """
    history_path = 'models/rarnet_history.pkl'

    # Define domain bounds (same values as printed by Mathematical_Framework_Setup)
    domain_bounds = {
        'x_min': 0.0,
        'x_max': 0.1,
        'y_min': 0.0,
        'y_max': 0.05,
        't_min': 0.0,
        't_max': 10.0
    }

    # Create physical parameters
    phys_params = nb1.CVDPhysicalParams()

    # Create models
    if use_entropy_langevin:
        # Create ensemble of models
        ensemble_size = 5  # Smaller ensemble for faster demo
        models = nb1.create_model_ensemble(num_models=ensemble_size)
    else:
        # Create single model
        models = nb1.create_pinn_model()

    # Create RAR-Net
    rarnet = RARNet(models, phys_params, domain_bounds, use_entropy_langevin)

    if train_models:
        # Make sure the output directory exists before anything is saved there;
        # open(..., 'wb') below would otherwise raise FileNotFoundError.
        os.makedirs('models', exist_ok=True)

        # Train with adaptive refinement
        history = rarnet.train_with_adaptive_refinement(
            n_iterations=3,  # Reduced iterations for demo
            n_epochs_per_iter=200,  # Reduced epochs for demo
            initial_points=1000,
            refinement_points=500,
            refinement_fraction=0.2,
            sampling_method='mixed'
        )

        # Save history
        with open(history_path, 'wb') as f:
            pickle.dump(history, f)
    else:
        # Load models and history
        if use_entropy_langevin:
            rarnet.trainer.load_models("models/rarnet_entropy_langevin.h5")
        else:
            rarnet.trainer.load_model("models/rarnet_traditional.h5")

        # Load history. NOTE: pickle.load can execute arbitrary code, so only
        # load history files produced by this notebook (trusted input).
        # Catch only the failures that mean "history unusable"; a bare
        # `except:` would also swallow KeyboardInterrupt and real bugs.
        try:
            with open(history_path, 'rb') as f:
                history = pickle.load(f)
            rarnet.loss_history = history['loss_history']
            rarnet.refinement_history = history['refinement_history']
        except (OSError, EOFError, pickle.UnpicklingError, KeyError, TypeError) as err:
            print("No history found. Cannot visualize results without training first.")
            print(f"  Reason: {type(err).__name__}: {err}")
            return

    # visualize_solution saves its figures under figures/; create it so
    # plt.savefig does not fail on a fresh checkout.
    os.makedirs('figures', exist_ok=True)

    # Plot refinement history
    rarnet.plot_refinement_history()

    # Evaluate on grid
    rarnet.evaluate_on_grid()

    # Visualize solution for different species
    rarnet.visualize_solution(output_idx=0)  # SiH4
    rarnet.visualize_solution(output_idx=1)  # Si
    rarnet.visualize_solution(output_idx=4)  # Temperature

# The demo involves long training runs, so it is never started automatically.
# When this notebook is the entry point (``__name__ == "__main__"``) only the
# usage hints below are printed; importing it as a module prints nothing.
if __name__ == "__main__":
    print(
        "This notebook implements advanced sampling techniques for Physics-Informed Neural Networks.",
        "To run the demo, execute: demo_adaptive_sampling(train_models=True, use_entropy_langevin=True)",
        sep="\n",
    )

This notebook implements advanced sampling techniques for Physics-Informed Neural Networks.
To run the demo, execute: demo_adaptive_sampling(train_models=True, use_entropy_langevin=True)
