# PDESystem and Network Architectures

In [1]:
experiments = [
    ((64, 64), "SmallEqual"),
    ((128, 64), "Baseline"),
    ((128, 64, 32), "MediumThreeTapered"),
    ((64, 64, 64), "MediumThreeEqual"),
    ((64, 128, 128, 64), "MidInverseTapered"),
    ((256, 256), "MediumEqual"),
    ((64, 64, 64, 64), "MediumDeepEqual"),
    ((512, 256, 128, 64), "LargeTapered"),
    ((256, 256, 256, 256), "LargeEqual"),
]

total_params = [ # as calculated by TensorFlow
    ((64, 64), 4547),
    ((128, 64), 8835),
    ((128, 64, 32), 10819),
    ((64, 64, 64), 8707),
    ((64, 128, 128, 64), 33475),
    ((256, 256), 67331),
    ((64, 64, 64, 64), 12867),
    ((512, 256, 128, 64), 174211),
    ((256, 256, 256, 256), 198915),
]

In [None]:
# @title PDESystem
import tensorflow as tf
import numpy as np
from abc import ABC, abstractmethod
import matplotlib.pyplot as plt
from typing import List, Tuple
import time
import shap
import os
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

class PDESystem:
    '''
    Abstract class for defining a PDE system.
    '''

    def __init__(self, num_components: int, x_domains=[(0, 1)], t_domain=(0,1)):
        """
        Initialize the PDE system with the number of components and domain ranges.

        Parameters:
        num_components (int): Number of components in the PDE system.
        x_domains (list): List of tuples defining the spatial domain ranges.
        t_domain (tuple): Tuple defining the temporal domain range.
        """
        self.num_components = num_components
        self.x_domains = x_domains
        self.t_domain = t_domain

    @abstractmethod
    def residuals(self, model: tf.keras.Model, t_points: tf.Tensor, x_points: tf.Tensor):
        """
        Abstract method to compute the residuals of the PDE system.

        Parameters:
        model (tf.keras.Model): The neural network model.
        t_points (tf.Tensor): Tensor of temporal points.
        x_points (tf.Tensor): Tensor of spatial points.

        Returns:
        List[tf.Tensor]: List of residuals for each component.
        """
        raise NotImplementedError

    def compute_bounary_loss(self, model: tf.keras.Model, t_points: tf.Tensor, x_points: tf.Tensor) -> tf.Tensor:
        """
        Compute the boundary loss for the PDE system.

        Parameters:
        model (tf.keras.Model): The neural network model.
        t_points (tf.Tensor): Tensor of temporal points.
        x_points (tf.Tensor): Tensor of spatial points.

        Returns:
        tf.Tensor: Boundary loss.
        """
        return tf.constant(0.0)

    def compute_loss(self, model: tf.keras.Model, t_points: tf.Tensor, x_points: tf.Tensor, lambda_bc: float = 1) -> Tuple[tf.Tensor, List[tf.Tensor]]:
        """
        Compute the overall loss and component losses for the PDE system.

        Parameters:
        model (tf.keras.Model): The neural network model.
        t_points (tf.Tensor): Tensor of temporal points.
        x_points (tf.Tensor): Tensor of spatial points.
        lambda_bc (float): Weight for the boundary condition loss.

        Returns:
        Tuple[tf.Tensor, List[tf.Tensor]]: Overall loss and list of component losses.
        """
        res_list = self.residuals(model, t_points, x_points)
        res_concat = tf.concat(res_list, axis=1)

        res_concat = tf.cast(res_concat, dtype=tf.float64)
        overall_loss = tf.reduce_mean(tf.square(res_concat)) + tf.cast(lambda_bc, dtype=tf.float64) * tf.cast(
            self.compute_bounary_loss(model, t_points, x_points), dtype=tf.float64)

        component_losses = [tf.reduce_mean(tf.square(tf.cast(res, dtype=tf.float64))) for res in res_list]
        return tf.cast(overall_loss, dtype=tf.float32), [tf.cast(loss, dtype=tf.float32) for loss in component_losses]

    def generate_points(self, N: int = 1000) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Generate random points within the spatial and temporal domains.

        Parameters:
        N (int): Number of points to generate.

        Returns:
        Tuple[tf.Tensor, tf.Tensor]: Tensors of spatial and temporal points.
        """
        space_coords = []
        for (x_min, x_max) in self.x_domains:
            coord = np.random.uniform(x_min, x_max, size=(N, 1))
            space_coords.append(coord)
        space_coords_stacked = np.hstack(space_coords) if space_coords else np.zeros((N, 0))

        t_min, t_max = self.t_domain
        time_coords = np.random.uniform(t_min, t_max, size=(N, 1))

        return tf.convert_to_tensor(space_coords_stacked, dtype=tf.float32), tf.convert_to_tensor(time_coords, dtype=tf.float32)

    def train(self, model: tf.keras.Model, epochs: int, N: int, lr: float) -> Tuple[np.ndarray, np.ndarray, float]:
        """
        Train the neural network model on the PDE system.

        Parameters:
        model (tf.keras.Model): The neural network model.
        epochs (int): Number of training epochs.
        N (int): Number of points to generate for each epoch.
        lr (float): Learning rate for the optimizer.

        Returns:
        Tuple[np.ndarray, np.ndarray, float]: Loss history, component losses history, and elapsed training time.
        """
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
        loss_history = np.zeros(epochs)
        component_losses_history = np.zeros((epochs, self.num_components))
        start_time = time.time()

        for epoch in range(epochs):
            x_points, t_points = self.generate_points(N)

            with tf.GradientTape() as tape:
                overall_loss, component_losses = self.compute_loss(model, t_points, x_points)
                overall_loss = tf.cast(overall_loss, dtype=tf.float32)
            gradients = tape.gradient(overall_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            loss_history[epoch] = overall_loss.numpy()
            component_losses_history[epoch, :] = [loss.numpy() for loss in component_losses]
            if epoch % 100 == 0:
                print(f"Epoch {epoch}: Loss = {overall_loss.numpy()}")

        elapsed_time = time.time() - start_time
        return loss_history, component_losses_history, elapsed_time

In [3]:
# @title Network Types
class UnifiedPINN(tf.keras.Model):
    def __init__(self, output_dim: int, architecture: Tuple[int, ...] = (256, 256), activation: str ='tanh'):
        """
        Initialize the UnifiedPINN model.

        Parameters:
        output_dim (int): The dimension of the output layer.
        architecture (Tuple[int, ...]): A tuple defining the number of units in each hidden layer.
        activation (str): The activation function to use in the hidden layers.
        """
        super(UnifiedPINN, self).__init__()
        # Create hidden layers based on the architecture
        self.hidden_layers = [
            tf.keras.layers.Dense(units, activation=activation)
            for units in architecture
        ]
        # Create the output layer
        self.output_layer = tf.keras.layers.Dense(output_dim)

    def call(self, x_points: tf.Tensor, t_points: tf.Tensor) -> tf.Tensor:
        """
        Forward pass of the UnifiedPINN model.

        Parameters:
        x_points (tf.Tensor): Tensor of spatial points.
        t_points (tf.Tensor): Tensor of temporal points.

        Returns:
        tf.Tensor: The output of the model.
        """
        # Combine temporal and spatial inputs
        combined = tf.concat([t_points, x_points], axis=1)
        out = combined
        # Pass through hidden layers
        for layer in self.hidden_layers:
            out = layer(out)
        # Pass through the output layer
        return self.output_layer(out)

class ModularPINNs(tf.keras.Model):
    def __init__(self, num_components: int, architectures: List[Tuple[int, ...]] = None, activation: str = 'tanh'):
        """
        Initialize the ModularPINNs model.

        Parameters:
        num_components (int): Number of components in the PDE system.
        architectures (List[Tuple[int, ...]]): List of architectures for each component model.
        activation (str): The activation function to use in the hidden layers.
        """
        super(ModularPINNs, self).__init__()
        self.num_components = num_components

        # Set default architectures if none are provided
        if architectures is None:
            architectures = [(256, 256)] * num_components

        # Ensure the number of architectures matches the number of components
        assert len(architectures) == num_components, (
            "The number of architectures must match the number of components."
        )

        # Create a list of UnifiedPINN models for each component
        self.models = [
            UnifiedPINN(output_dim=1, architecture=arch, activation=activation)
            for arch in architectures
        ]

    def call(self, x_points: tf.Tensor, t_points: tf.Tensor) -> tf.Tensor:
        """
        Forward pass of the ModularPINNs model.

        Parameters:
        x_points (tf.Tensor): Tensor of spatial points.
        t_points (tf.Tensor): Tensor of temporal points.

        Returns:
        tf.Tensor: The concatenated output of all component models.
        """
        outputs = []
        # Pass through each component model and collect outputs
        for model in self.models:
            out = model(x_points, t_points)
            outputs.append(out)
        # Concatenate outputs from all component models
        return tf.concat(outputs, axis=1)

    @property
    def trainable_weights(self) -> List[tf.Variable]:
        """
        Get the trainable weights of the ModularPINNs model.

        Returns:
        List[tf.Variable]: List of trainable weights from all component models.
        """
        return [weight for model in self.models for weight in model.trainable_weights]

class SemiModularElasticWavePINN(tf.keras.Model):
    pass # Placeholder for Semi-Modular Elastic Wave PINN defined later

In [None]:
# @title Comparison Class
class UnifiedVsModularComparison:
    def __init__(self, system: PDESystem):
        self.system = system

    def train_unified_and_modular(self, epochs: int, N: int, lr: float, unified_architecture: Tuple[int, ...] = (256, 256), modular_architectures: List[Tuple[int, ...]] = None, activation: str = 'tanh', train_scaled=False, path="") -> Tuple[UnifiedPINN, ModularPINNs]:
        unified_model = UnifiedPINN(output_dim=self.system.num_components, architecture=unified_architecture, activation=activation)

        if modular_architectures is None:
            modular_architectures = [unified_architecture] * self.system.num_components
        modular_model = ModularPINNs(num_components=self.system.num_components, architectures=modular_architectures, activation=activation)

        unified_loss_history, unified_component_losses_history, unified_time = self.system.train(unified_model, epochs, N, lr)
        modular_loss_history, modular_component_losses_history, modular_time = self.system.train(modular_model, epochs, N, lr)
        if train_scaled:
            arch = self.find_modular_architecture(unified_architecture, self.system.num_components)
            print(f"Modular Scaled Architecture: {arch}")
            scaled_modular_model = ModularPINNs(num_components=self.system.num_components, architectures=([arch]*self.system.num_components), activation=activation)
            scaled_modular_loss_history, scaled_modular_component_losses_history, scaled_modular_time = self.system.train(scaled_modular_model, epochs, N, lr)

        x_val_points, t_val_points = self.generate_validation_points()
        for model in [unified_model, modular_model]:
            final_overall_loss, _ = self.system.compute_loss(model, t_val_points, x_val_points)
            print(f"Final Loss for {model}: {final_overall_loss.numpy()}")
        if train_scaled:
            for model in [scaled_modular_model]:
                final_overall_loss, _ = self.system.compute_loss(model, t_val_points, x_val_points)
                print(f"Final Loss for {model}: {final_overall_loss.numpy()}")

        if not train_scaled:
            self.compare_overall_loss(unified_loss_history, modular_loss_history)
            self.compare_component_losses(unified_component_losses_history, modular_component_losses_history)
        if train_scaled:
            self.compare_overall_loss(unified_loss_history, modular_loss_history, scaled_modular_loss=scaled_modular_loss_history)
            self.compare_component_losses(unified_component_losses_history, modular_component_losses_history, scaled_modular_loss=scaled_modular_component_losses_history)

        unified_model.summary()
        modular_model.summary()
        if train_scaled:
            scaled_modular_model.summary()

        self.compute_convergence_epochs(unified_loss_history)
        self.compute_convergence_epochs(modular_loss_history)
        if train_scaled:
            self.compute_convergence_epochs(scaled_modular_loss_history)

        print(f"Unified PINN training time: {unified_time:.2f} seconds")
        print(f"Modular PINN training time: {modular_time:.2f} seconds")
        print(f"Relative training time: {((modular_time-unified_time)/unified_time):.2f}")
        if train_scaled:
            print(f"Scaled Modular PINN training time: {scaled_modular_time:.2f} seconds")
            print(f"Relative training time (to scaled): {((scaled_modular_time-unified_time)/unified_time):.2f}")

        if path == "":
            path = self.system.__class__.__name__
        else:
            path = self.system.__class__.__name__ + "_" + path
        self.save_model(unified_model, f"{path}_unified", modular=False)
        self.save_model(modular_model, f"{path}_modular", modular=True)
        if train_scaled:
            self.save_model(scaled_modular_model, f"{path}_scaled_modular_model", modular=True)

        if train_scaled:
            return unified_model, modular_model, scaled_modular_model
        return unified_model, modular_model

    @staticmethod
    def compare_component_losses(unified_component_losses, modular_component_losses, semi_modular_loss=None, scaled_modular_loss=None):
      assert unified_component_losses.shape[1] == modular_component_losses.shape[1], \
          "Number of components must match between Unified and Modular PINNs."

      num_components = unified_component_losses.shape[1]
      fig, axes = plt.subplots(1, num_components, figsize=(5 * num_components, 6), sharey=True)

      if num_components == 1:
          axes = [axes]

      for i in range(num_components):
          axes[i].plot(unified_component_losses[:, i], label=f"Unified Component {i + 1}", color="blue")
          axes[i].plot(modular_component_losses[:, i], label=f"Modular Component {i + 1}", color="green")
          if semi_modular_loss is not None:
              axes[i].plot(semi_modular_loss[:, i], label=f"Semi-Modular Component {i + 1}", color="red")
          if scaled_modular_loss is not None:
              axes[i].plot(scaled_modular_loss[:, i], label=f"Scaled Modular Component {i + 1}", color="purple")
          axes[i].set_xlabel("Epochs")
          axes[i].set_yscale("log")
          axes[i].set_title(f"Component {i + 1} Loss")
          axes[i].legend()
          axes[i].grid(True, which="both", linestyle="--", linewidth=0.5)

      plt.tight_layout()
      plt.show()

    @staticmethod
    def compare_overall_loss(unified_loss, modular_loss, semi_modular_loss=None, scaled_modular_loss=None):
        assert len(unified_loss) == len(modular_loss), \
            "Unified and Modular loss histories must have the same length."

        plt.figure(figsize=(8, 6))
        plt.plot(unified_loss, label="Unified Overall Loss", color="blue")
        plt.plot(modular_loss, label="Modular Overall Loss", color="green")
        if semi_modular_loss is not None:
            plt.plot(semi_modular_loss, label="Semi-Modular Overall Loss", color="red")
        if scaled_modular_loss is not None:
            plt.plot(scaled_modular_loss, label="Scaled Modular Overall Loss", color="purple")

        plt.xlabel("Epochs")
        plt.ylabel("Log Loss")
        plt.yscale("log")
        if semi_modular_loss is not None:
            plt.title("Unified vs Modular vs Semi-Modular Overall Loss")
        else:
            plt.title("Unified vs Modular Overall Loss")
        plt.legend()
        plt.grid(True, which="both", linestyle="--", linewidth=0.5)

        plt.tight_layout()
        plt.show()

    @staticmethod
    def compute_convergence_epochs(loss_history, thresholds: List[float] = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]):
        convergence_epochs = {}
        for threshold in thresholds:
            epoch = next((i for i, loss in enumerate(loss_history) if loss < threshold), None)
            convergence_epochs[threshold] = epoch
        print(f"Convergence epochs for different thresholds: {convergence_epochs}")
        return convergence_epochs

    @staticmethod
    def find_modular_architecture(unified_arch, num_components, max_iter=2000, tolerance=0.05):
        def compute_parameters(layers):
            weights = sum(layers[i] * layers[i + 1] for i in range(len(layers) - 1))
            biases = sum(layers)
            return weights + biases

        unified_params = compute_parameters(unified_arch)
        print(f"Unified Network Parameters: {unified_params}")

        # Preserve the proportion of neurons across layers
        total_neurons = sum(unified_arch)
        layer_proportions = [layer / total_neurons for layer in unified_arch]

        modular_arch = [max(1, int(neurons * sum(unified_arch) // num_components)) for neurons in unified_arch]

        for _ in range(max_iter):
            modular_params = num_components * compute_parameters(modular_arch)

            if abs(modular_params - unified_params) / unified_params < tolerance:
                print(f"Modular Network Parameters: {modular_params}")
                return modular_arch

            # Scale each layer proportionally to match parameters
            scale_factor = np.sqrt(unified_params / modular_params)
            modular_arch = [max(1, int(layer * scale_factor)) for layer in modular_arch]

        raise ValueError("Failed to match modular architecture within tolerance.")

    def generate_validation_points(self, N: int = 1000):
        spatial_dims = len(self.system.x_domains)
        spatial_points_per_dim = int(np.ceil(N ** (1 / (spatial_dims + 1))))
        temporal_points = int(np.ceil(N / (spatial_points_per_dim ** spatial_dims)))

        spatial_grids = [
            np.linspace(domain[0], domain[1], spatial_points_per_dim)
            for domain in self.system.x_domains
        ]
        spatial_mesh = np.meshgrid(*spatial_grids)
        spatial_coords = np.stack([mesh.flatten() for mesh in spatial_mesh], axis=1)

        t_min, t_max = self.system.t_domain
        temporal_coords = np.linspace(t_min, t_max, temporal_points).reshape(-1, 1)

        spatial_repeats = temporal_coords.shape[0]
        spatial_expanded = np.repeat(spatial_coords, spatial_repeats, axis=0)
        temporal_tiled = np.tile(temporal_coords, (spatial_coords.shape[0], 1))

        x_points = tf.convert_to_tensor(spatial_expanded, dtype=tf.float32)
        t_points = tf.convert_to_tensor(temporal_tiled, dtype=tf.float32)

        return x_points, t_points

    @staticmethod
    def save_model(model: tf.keras.Model, path: str, modular: bool = False):
        if modular:
            for i, sub_model in enumerate(model.models):
                sub_model.save_weights(f"{path}_model_{i}.weights.h5")
        else:
            model.save_weights(f"{path}.weights.h5")

    @staticmethod
    def load_model(path: str, num_components: int = None, modular: bool = False, architectures: List[Tuple[int, ...]]  = [(256, 256)], activation: str = 'tanh'):
        if modular:
            if num_components is None:
                raise ValueError("num_components must be specified for ModularPINNs.")
            modular_model = ModularPINNs(num_components=num_components,
                                         architectures=architectures,
                                         activation=activation)
            for i, sub_model in enumerate(modular_model.models):
                sub_model.load_weights(f"{path}_model_{i}.weights.h5")
            return modular_model
        else:
            unified_model = UnifiedPINN(output_dim=1, architecture=architectures[0], activation=activation)
            unified_model.load_weights(f"{path}.weights.h5")
            return unified_model

    def plot_neural_network_activation(self, model, x_points, t_points, save_path="network_visualization"):
        """Visualize neuron activations for all output components."""
        os.makedirs(save_path, exist_ok=True)
        combined_input = tf.concat([t_points, x_points], axis=1)

        if isinstance(model, SemiModularElasticWavePINN):
            # Handle semi-modular case
            self.plot_single_network_activation(
                model.displacement_hidden_layers, model.displacement_output_layer,
                combined_input, save_path, label="Displacement"
            )
            self.plot_single_network_activation(
                model.stress_hidden_layers, model.stress_output_layer,
                combined_input, save_path, label="Stress"
            )
        elif isinstance(model, ModularPINNs):
            # Handle modular networks
            for i, sub_model in enumerate(model.models):
                self.plot_single_network_activation(
                    sub_model.hidden_layers, sub_model.output_layer,
                    combined_input, save_path, label=f"Component_{i + 1}"
                )
        else:
            # Handle unified network
            self.plot_single_network_activation(
                model.hidden_layers, model.output_layer,
                combined_input, save_path, label="Unified"
            )

    def plot_single_network_activation(self, hidden_layers, output_layer, combined_input, save_path, label):
        """Helper function to visualize activations for a single network."""
        activations = []
        out = combined_input

        # Collect activations layer by layer
        for layer in hidden_layers:
            out = layer(out)
            activations.append(out.numpy())
        activations.append(output_layer(out).numpy())

        # Plot the network structure with activations
        plt.figure(figsize=(12, 8))
        max_neurons = max(layer.shape[1] for layer in activations)

        for layer_idx, layer_activations in enumerate(activations):
            num_neurons = layer_activations.shape[1]
            x_positions = np.linspace(-max_neurons / 2, max_neurons / 2, num_neurons)
            y_position = -layer_idx

            # Normalize activations to color scale
            norm_activations = (layer_activations - layer_activations.min()) / (
                layer_activations.max() - layer_activations.min()
            )

            # Plot neurons
            scatter = plt.scatter(
                x_positions,
                [y_position] * num_neurons,
                s=200,
                c=norm_activations.mean(axis=0),
                cmap="viridis",
                edgecolor="k",
                zorder=2,
            )

            # Connect layers
            if layer_idx < len(activations) - 1:
                next_layer_activations = activations[layer_idx + 1]
                next_x_positions = np.linspace(
                    -max_neurons / 2, max_neurons / 2, next_layer_activations.shape[1]
                )
                for i, x_pos in enumerate(x_positions):
                    for j, next_x_pos in enumerate(next_x_positions):
                        plt.plot(
                            [x_pos, next_x_pos], [y_position, y_position - 1], "gray", alpha=0.3, zorder=1
                        )

        cbar = plt.colorbar(scatter, pad=0.02)
        cbar.set_label("Activation Intensity")

        plt.title(f"Neuron Activations ({label})")
        plt.xlabel("Neuron Index")
        plt.ylabel("Layer")
        plt.gca().invert_yaxis()
        plt.grid(False)
        plt.tight_layout()
        plt.savefig(f"{save_path}/{label}_activations.png")
        plt.close()

    def collect_activations(self, model, x_points, t_points):
        combined_input = tf.concat([t_points, x_points], axis=1)
        activations = {}

        if isinstance(model, SemiModularElasticWavePINN):
            # Collect activations for displacement sub-network
            displacement_activations = []
            out_displacement = combined_input
            for layer in model.displacement_hidden_layers:
                out_displacement = layer(out_displacement)
                displacement_activations.append(out_displacement.numpy())
            displacement_activations.append(model.displacement_output_layer(out_displacement).numpy())

            # Collect activations for stress sub-network
            stress_activations = []
            out_stress = combined_input
            for layer in model.stress_hidden_layers:
                out_stress = layer(out_stress)
                stress_activations.append(out_stress.numpy())
            stress_activations.append(model.stress_output_layer(out_stress).numpy())

            # Store activations separately
            activations['displacement'] = displacement_activations
            activations['stress'] = stress_activations

        elif isinstance(model, ModularPINNs):
            # Collect activations for modular sub-networks
            for i, sub_model in enumerate(model.models):
                modular_activations = []
                out = combined_input
                for layer in sub_model.hidden_layers:
                    out = layer(out)
                    modular_activations.append(out.numpy())
                modular_activations.append(sub_model.output_layer(out).numpy())
                activations[f'component_{i + 1}'] = modular_activations

        else:  # Unified PINN
            unified_activations = []
            out = combined_input
            for layer in model.hidden_layers:
                out = layer(out)
                unified_activations.append(out.numpy())
            unified_activations.append(model.output_layer(out).numpy())
            activations['unified'] = unified_activations

        return activations


    def compute_neuron_activation_overlap(self, activations, threshold=0.05):
        """
        Compute overlap and correlation metrics between neuron activations
        for different layers (layer-wise) and across components (component-wise)
        across unified, modular, and semi-modular architectures.
        """
        layer_overlap_metrics = []
        layer_correlation_metrics = []
        component_overlap_metrics = []
        component_correlation_metrics = []

        def normalize_activations(act1, act2):
            min_dim = min(act1.shape[1], act2.shape[1])
            act1 = act1[:, :min_dim]
            act2 = act2[:, :min_dim]
            return act1, act2

        def calculate_overlap(activated_i, activated_j):
            union = np.sum(activated_i | activated_j)
            if union == 0:
                return 0  # Avoid division by zero
            return np.sum(activated_i & activated_j) / union

        # Unified Network
        if 'unified' in activations:
            unified_activations = activations['unified']
            num_layers = len(unified_activations)

            # Layer-wise analysis
            for i in range(num_layers - 1):
                layer_activations_i = unified_activations[i]
                layer_activations_j = unified_activations[i + 1]

                if layer_activations_i.shape[1] != layer_activations_j.shape[1]:
                    layer_activations_i, layer_activations_j = normalize_activations(layer_activations_i, layer_activations_j)

                activated_i = np.abs(layer_activations_i) > threshold
                activated_j = np.abs(layer_activations_j) > threshold
                overlap = calculate_overlap(activated_i, activated_j)
                correlation = np.corrcoef(layer_activations_i.flatten(), layer_activations_j.flatten())[0, 1]

                layer_overlap_metrics.append(overlap)
                layer_correlation_metrics.append(correlation)

            # Component-wise analysis
            for i in range(self.system.num_components):
                for j in range(i + 1, self.system.num_components):
                    component_activations_i = unified_activations[i]
                    component_activations_j = unified_activations[j]

                    if component_activations_i.shape[1] != component_activations_j.shape[1]:
                        component_activations_i, component_activations_j = normalize_activations(component_activations_i, component_activations_j)

                    activated_i = np.abs(component_activations_i) > threshold
                    activated_j = np.abs(component_activations_j) > threshold
                    overlap = calculate_overlap(activated_i, activated_j)
                    correlation = np.corrcoef(component_activations_i.flatten(), component_activations_j.flatten())[0, 1]

                    component_overlap_metrics.append(overlap)
                    component_correlation_metrics.append(correlation)

        # Modular Network
        elif isinstance(activations, dict) and all(key.startswith('component') for key in activations.keys()):
            keys = list(activations.keys())

            # Layer-wise analysis for each sub-network
            for key in keys:
                component_activations = activations[key]
                num_layers = len(component_activations)

                for i in range(num_layers - 1):
                    layer_activations_i = component_activations[i]
                    layer_activations_j = component_activations[i + 1]

                    if layer_activations_i.shape[1] != layer_activations_j.shape[1]:
                        layer_activations_i, layer_activations_j = normalize_activations(layer_activations_i, layer_activations_j)

                    activated_i = np.abs(layer_activations_i) > threshold
                    activated_j = np.abs(layer_activations_j) > threshold
                    overlap = calculate_overlap(activated_i, activated_j)
                    correlation = np.corrcoef(layer_activations_i.flatten(), layer_activations_j.flatten())[0, 1]

                    layer_overlap_metrics.append(overlap)
                    layer_correlation_metrics.append(correlation)

            # Component-wise analysis
            for i in range(len(keys)):
                for j in range(i + 1, len(keys)):
                    component_activations_i = activations[keys[i]]
                    component_activations_j = activations[keys[j]]

                    for layer_idx in range(len(component_activations_i)):
                        layer_activations_i = component_activations_i[layer_idx]
                        layer_activations_j = component_activations_j[layer_idx]

                        if layer_activations_i.shape[1] != layer_activations_j.shape[1]:
                            layer_activations_i, layer_activations_j = normalize_activations(layer_activations_i, layer_activations_j)

                        activated_i = np.abs(layer_activations_i) > threshold
                        activated_j = np.abs(layer_activations_j) > threshold
                        overlap = calculate_overlap(activated_i, activated_j)
                        correlation = np.corrcoef(layer_activations_i.flatten(), layer_activations_j.flatten())[0, 1]

                        component_overlap_metrics.append(overlap)
                        component_correlation_metrics.append(correlation)

        # Return all metrics
        avg_layer_overlap = np.mean(layer_overlap_metrics) if layer_overlap_metrics else 0
        avg_layer_correlation = np.mean(layer_correlation_metrics) if layer_correlation_metrics else 0
        avg_component_overlap = np.mean(component_overlap_metrics) if component_overlap_metrics else 0
        avg_component_correlation = np.mean(component_correlation_metrics) if component_correlation_metrics else 0

        return (
            layer_overlap_metrics,
            layer_correlation_metrics,
            avg_layer_overlap,
            avg_layer_correlation,
            component_overlap_metrics,
            component_correlation_metrics,
            avg_component_overlap,
            avg_component_correlation,
        )

    def analyze_interpretability(self, models, save_path="interpretability_analysis"):
        os.makedirs(save_path, exist_ok=True)

        # Generate test points
        x_points, t_points = self.generate_validation_points()

        for model, name in models:
            print(f"Analyzing {name}...")

            try:
                # Neuron Activation Visualization
                activations = self.collect_activations(model, x_points, t_points)
                (
                    layer_overlaps, layer_correlations, avg_layer_overlap, avg_layer_correlation,
                    component_overlaps, component_correlations, avg_component_overlap, avg_component_correlation
                ) = self.compute_neuron_activation_overlap(activations)

                print(f"{name} - Average Layer-wise Overlap: {avg_layer_overlap:.4f}, Average Layer-wise Correlation: {avg_layer_correlation:.4f}")
                print(f"{name} - Average Component-wise Overlap: {avg_component_overlap:.4f}, Average Component-wise Correlation: {avg_component_correlation:.4f}")

                # Save plots for neuron activations
                self.plot_neural_network_activation(model, x_points, t_points, save_path=f"{save_path}/{name}_activation")

                # Neuron sparsity
                self.analyze_network_sparsity(model, x_points, t_points)

            except Exception as e:
                print(f"Interpretability analysis failed for {name}: {e}")
                import traceback
                traceback.print_exc()

    def compute_activation_sparsity(self, activations, threshold=0.001):
        """
        Computes activation sparsity for each hidden layer:
        - Fraction of neurons that remain inactive (activation < threshold).
        - Returns sparsity per layer and overall average sparsity.
        """
        sparsity_metrics = []

        # Exclude the output layer (last layer)
        hidden_layers = activations[:-1]

        for layer_activations in hidden_layers:
            inactive_neurons = np.mean(np.abs(layer_activations) < threshold, axis=0)
            layer_sparsity = np.mean(inactive_neurons)
            sparsity_metrics.append(layer_sparsity)

        return sparsity_metrics, np.mean(sparsity_metrics)


    def analyze_network_sparsity(self, model, x_points, t_points):
        """
        Computes and plots activation sparsity for both
        unified and modular networks, **excluding the output layer**.
        """
        activations = self.collect_activations(model, x_points, t_points)

        sparsity_results = {}

        if isinstance(model, ModularPINNs):
            # For Modular Networks, compute sparsity per component
            for i, (component, component_activations) in enumerate(activations.items()):
                sparsity_metrics, avg_sparsity = self.compute_activation_sparsity(component_activations)

                sparsity_results[f'Component {i+1}'] = (sparsity_metrics, avg_sparsity)
        elif isinstance(model, SemiModularElasticWavePINN):
            # Compute sparsity for displacement and stress subnetworks separately
            sparsity_displacement, avg_sparsity_displacement = self.compute_activation_sparsity(activations['displacement'])
            sparsity_stress, avg_sparsity_stress = self.compute_activation_sparsity(activations['stress'])

            sparsity_results['Displacement'] = (sparsity_displacement, avg_sparsity_displacement)
            sparsity_results['Stress'] = (sparsity_stress, avg_sparsity_stress)
        else:
            # For Unified Networks, compute sparsity across all layers
            sparsity_metrics, avg_sparsity = self.compute_activation_sparsity(activations['unified'])

            sparsity_results['Unified'] = (sparsity_metrics, avg_sparsity)

        self.plot_sparsity(sparsity_results, activations)

    def compute_sparsity_cdf(self, activations_dict):
        """
        Computes the cumulative fraction of neurons that remain inactive (below threshold).
        This creates a sparsity CDF across different activation thresholds for unified and modular networks.
        """

        # Expand thresholds by adding intermediate values
        base_thresholds = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
        fine_thresholds = []

        for i in range(len(base_thresholds) - 1):
            start, end = base_thresholds[i], base_thresholds[i + 1]
            fine_thresholds.extend(np.linspace(start, end, 10, endpoint=False))  # Add 9 intermediate points

        fine_thresholds.append(1e-1)  # Ensure last value is included
        thresholds = sorted(fine_thresholds)  # Sorted for clarity

        sparsity_cdf_results = {}

        for key, activations in activations_dict.items():
            if not isinstance(activations, list) or len(activations) == 0:
                print(f"Skipping CDF for {key}: No activations available.")
                continue

            sparsity_cdf = []
            hidden_layers = activations[:-1]  # Exclude output layer

            for threshold in thresholds:
                inactive_counts = []
                for layer_activations in hidden_layers:
                    if isinstance(layer_activations, np.ndarray) and layer_activations.shape[0] > 0:
                        inactive_neurons = np.mean(np.abs(layer_activations) < threshold, axis=0)
                        inactive_counts.append(np.mean(inactive_neurons))
                    else:
                        inactive_counts.append(0)  # Default to 0 if no activations exist
                sparsity_cdf.append(np.mean(inactive_counts))  # Average across layers

            sparsity_cdf_results[key] = (thresholds, sparsity_cdf)

        return sparsity_cdf_results

    def plot_sparsity(self, sparsity_results, activations):
        """
        Plots activation sparsity and a CDF of sparsity for unified and modular networks.
        """
        fig, axes = plt.subplots(1, 2, figsize=(18, 5))

        for component, (sparsity_metrics, avg_sparsity) in sparsity_results.items():
            layer_indices = list(range(len(sparsity_metrics)))  # Ensure integer indices
            axes[0].plot(layer_indices, sparsity_metrics, label=f"{component} (Avg: {avg_sparsity:.4f})")

        axes[0].set_xlabel("Layer Index")
        axes[0].set_ylabel("Activation Sparsity")
        axes[0].set_title("Activation Sparsity Across Hidden Layers")
        axes[0].legend()
        axes[0].grid()
        axes[0].set_xticks(layer_indices)  # Set ticks to integer layer indices

        sparsity_cdf_results = self.compute_sparsity_cdf(activations)

        for component, (thresholds, sparsity_cdf) in sparsity_cdf_results.items():
            axes[1].plot(thresholds, sparsity_cdf, marker='o', label=f"{component}")

        axes[1].set_xlabel("Activation Threshold")
        axes[1].set_xscale("log")  # Log scale for better visualization
        axes[1].set_ylabel("Fraction of Neurons Inactive")
        axes[1].set_title("Sparsity CDF")
        axes[1].legend()
        axes[1].grid()

        # Modify x-axis ticks to only show base 1e-x values
        base_thresholds = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
        axes[1].set_xticks(base_thresholds)  # Keep major thresholds
        axes[1].set_xticklabels([f"$10^{{-{int(np.log10(x))}}}$" for x in base_thresholds])  # Use scientific notation

        plt.tight_layout()
        plt.show()

# Systems

In [5]:
# @title Reaction Diffusion Systems
class ReactionDiffusionSystem(PDESystem):
    def __init__(self):
        super().__init__(num_components=3, x_domains=[(0, 1)], t_domain=(0, 1))

    def residuals(self, model, t_points, x_points):
        with tf.GradientTape(persistent=True) as tape2:
            tape2.watch([t_points, x_points])

            with tf.GradientTape(persistent=True) as tape1:
                tape1.watch([t_points, x_points])
                u_pred = model(x_points, t_points)
                u1, u2, u3 = tf.split(u_pred, num_or_size_splits=3, axis=1)

            u1_t = tape1.gradient(u1, t_points)
            u2_t = tape1.gradient(u2, t_points)
            u3_t = tape1.gradient(u3, t_points)

            u1_x = tape1.gradient(u1, x_points)
            u2_x = tape1.gradient(u2, x_points)
            u3_x = tape1.gradient(u3, x_points)

        u1_xx = tape2.gradient(u1_x, x_points)
        u2_xx = tape2.gradient(u2_x, x_points)
        u3_xx = tape2.gradient(u3_x, x_points)

        del tape1, tape2

        res1 = u1_t - 0.1 * u1_xx + u1 * u2 - u3
        res2 = u2_t - 0.2 * u2_xx + u2 * u3 - u1
        res3 = u3_t - 0.3 * u3_xx + u3 * u1 - u2

        return [res1, res2, res3]

class ReactionDiffusionFirstOrderSystem(PDESystem):
    def __init__(self):
        super().__init__(num_components=6, x_domains=[(0, 1)], t_domain=(0, 1))

    def residuals(self, model, t_points, x_points):
        with tf.GradientTape(persistent=True) as tape:
            tape.watch([t_points, x_points])
            u_pred = model(x_points, t_points)
            u1, u2, u3, v1, v2, v3 = tf.split(u_pred, num_or_size_splits=6, axis=1)

            u1_t = tape.gradient(u1, t_points)
            u2_t = tape.gradient(u2, t_points)
            u3_t = tape.gradient(u3, t_points)

            v1_t = tape.gradient(v1, t_points)
            v2_t = tape.gradient(v2, t_points)
            v3_t = tape.gradient(v3, t_points)

            v1_x = tape.gradient(v1, x_points)
            v2_x = tape.gradient(v2, x_points)
            v3_x = tape.gradient(v3, x_points)

        del tape

        res1 = u1_t - 0.1 * v1_x + u1 * u2 - u3
        res2 = u2_t - 0.2 * v2_x + u2 * u3 - u1
        res3 = u3_t - 0.3 * v3_x + u3 * u1 - u2
        res4 = v1_t + v1_x
        res5 = v2_t + v2_x
        res6 = v3_t + v3_x

        return [res1, res2, res3, res4, res5, res6]

In [6]:
# @title Elastic Wave Systems
class ElasticWaveSystem(PDESystem):
    def __init__(self, rho=1.0, C11=1.0, C12=1.0, C22=1.0, C33=1.0):
        super().__init__(num_components=5, x_domains=[(0, 1), (0, 1)], t_domain=(0, 1))
        self.rho = rho
        self.C11 = C11
        self.C12 = C12
        self.C22 = C22
        self.C33 = C33

    def residuals(self, model, t_points, x_points):
        with tf.GradientTape(persistent=True) as tape:
            tape.watch([t_points, x_points])

            u_pred = model(x_points, t_points)
            u_x, u_y, sigma_xx, sigma_yy, sigma_xy = tf.split(u_pred, num_or_size_splits=5, axis=1)

            u_x_t = tape.gradient(u_x, t_points)
            u_y_t = tape.gradient(u_y, t_points)

            u_x_gradients = tape.gradient(u_x, x_points)
            u_y_gradients = tape.gradient(u_y, x_points)

            u_x_x1 = u_x_gradients[:, 0:1]
            u_x_x2 = u_x_gradients[:, 1:2]
            u_y_x1 = u_y_gradients[:, 0:1]
            u_y_x2 = u_y_gradients[:, 1:2]

            sigma_xx = self.C11 * u_x_x1 + self.C12 * u_y_x1
            sigma_yy = self.C22 * u_y_x2 + self.C12 * u_x_x2
            sigma_xy = self.C33 * (u_x_x2 + u_y_x1)

            u_x_tt = tape.gradient(u_x_t, t_points)
            u_y_tt = tape.gradient(u_y_t, t_points)

            sigma_xx_gradients = tape.gradient(sigma_xx, x_points)
            sigma_xy_gradients = tape.gradient(sigma_xy, x_points)
            sigma_yy_gradients = tape.gradient(sigma_yy, x_points)

        res_u_x = self.rho * u_x_tt - (sigma_xx_gradients[:, 0:1] + sigma_xy_gradients[:, 1:2])
        res_u_y = self.rho * u_y_tt - (sigma_xy_gradients[:, 0:1] + sigma_yy_gradients[:, 1:2])

        del tape
        return [res_u_x, res_u_y, sigma_xx, sigma_yy, sigma_xy]

class ElasticWaveSystemStronglyCoupled(ElasticWaveSystem):
    def __init__(self):
        super().__init__(rho=1.0, C11=1.0, C12=0.9, C22=1.0, C33=0.8)

class ElasticWaveSystemDirectionalDominance(ElasticWaveSystem):
    def __init__(self):
        super().__init__(rho=1.0, C11=2.0, C12=0.5, C22=1.0, C33=0.8)

class ElasticWaveSystemAnisotropic(ElasticWaveSystem):
    def __init__(self):
        super().__init__(rho=1.0, C11=5.0, C12=0.2, C22=1.0, C33=0.2)

# 1. Reaction-Diffusion System

In [None]:
#@title Experiments
comp = UnifiedVsModularComparison(ReactionDiffusionSystem())

for (arch, name) in experiments:
    uni, mod, scaled = comp.train_unified_and_modular(500, 1000, 1e-3, unified_architecture=arch, train_scaled=True)

    models = [
        (uni, "UnifiedPINN"),
        (mod, "ModularPINN"),
        (scaled, "ScaledModularPINN")
    ]

    comp.analyze_interpretability(models, save_path=f"ReactionDiffusion_{name}_Interpretability")

In [None]:
#@title Data
import pandas as pd
import matplotlib.pyplot as plt
from adjustText import adjust_text

data = {
    'Architecture': [(64, 64), (128, 64), (128, 64, 32), (64, 64, 64), (64, 128, 128, 64),
                     (256, 256), (64, 64, 64, 64), (512, 256, 128, 64), (256, 256, 256, 256)],
    'Total_Params': [4547, 8835, 10819, 8707, 33475, 67331, 12867, 174211, 198915],
    'Unified_Final_Loss': [4.93e-6, 1.21e-6, 1.37e-6, 1.55e-6, 5.76e-7, 9.25e-8, 1.46e-6, 5.22e-8, 1.82e-8],
    'Unified_Layerwise_Overlap': [0.3032, 0.2481, 0.316, 0.3397, 0.3136, 0.1633, 0.3673, 0.1249, 0.0853],
    'Unified_Layerwise_Abs_Correlation': [0.2121, 0.2498, 0.1045, 0.1357, 0.1067, 0.1426, 0.1124, 0.0748, 0.1374],
    'Unified_Componentwise_Overlap': [0.2113, 0.1635, 0.4492, 0.5120, 0.4282, 0.1062, 0.4940, 0.1408, 0.1097],
    'Unified_Componentwise_Abs_Correlation': [0.2735, 0.2343, 0.0893, 0.2809, 0.0577, 0.1368, 0.0170, 0.0282, 0.0306],
    'Unified_Componentwise_Correlation': [0.2735, 0.2343, 0.0893, 0.2809, 0.03999, 0.1368, -0.01128, 0.1328, 0.03068],
    'Modular_Final_Loss': [2.74e-6, 1.80e-6, 4.14e-6, 1.19e-6, 1.58e-6, 1.88e-7, 1.23e-6, 9.48e-6, 3.32e-7],
    'Modular_Layerwise_Overlap': [0.3731, 0.3285, 0.3716, 0.4088, 0.3513, 0.2185, 0.3943, 0.3289, 0.2095],
    'Modular_Layerwise_Abs_Correlation': [0.1959, 0.0919, 0.1252, 0.0530, 0.0509, 0.0874, 0.0752, 0.0740, 0.0481],
    'Modular_Componentwise_Overlap': [0.3647, 0.3192, 0.3368, 0.4012, 0.3720, 0.2209, 0.4096, 0.1807, 0.1927],
    'Modular_Componentwise_Abs_Correlation': [0.2135, 0.0861, 0.1523, 0.0645, 0.0654, 0.1010, 0.1126, 0.0278, 0.0290],
    'Modular_Componentwise_Correlation': [0.0021, -0.0758, 0.0071, 0.0021, -0.0478, -0.0758, -0.04318, -0.0361, -0.02264],
    'Scaled_Final_Loss': [1.33e-5, 2.47e-6, 4.71e-6, 6.06e-6, 2.36e-6, 3.67e-7, 1.60e-6, 3.45e-6, 4.38e-7],
    'Scaled_Layerwise_Overlap': [0.4287, 0.3818, 0.4261, 0.4423, 0.4735, 0.2897, 0.4771, 0.3497, 0.2544],
    'Scaled_Layerwise_Abs_Correlation': [0.1278, 0.1586, 0.1397, 0.1000, 0.0789, 0.1391, 0.0883, 0.0895, 0.0474],
    'Scaled_Componentwise_Overlap': [0.4511, 0.3758, 0.4424, 0.4733, 0.4400, 0.2861, 0.4994, 0.2562, 0.2563],
    'Scaled_Componentwise_Abs_Correlation': [0.1535, 0.1547, 0.0953, 0.1035, 0.0949, 0.1046, 0.0832, 0.0366, 0.0495],
    'Scaled_Componentwise_Correlation': [-0.0063, 0.0681, 0.0043, -0.0063, 0.0432, 0.0681, -0.01724, -0.0274, -0.01182],
    'Unified_Epochs_10^-4': [51, 20.6, 16.2, 16.8, 19.6, 15.6, 20.6, 36.8, 33.2],
    'Unified_Epochs_10^-5': [130.2, 43.8, 42.4, 43.2, 43.8, 41.2, 39.2, 82.0, 66.2],
    'Unified_Epochs_10^-6': [None, 93.4, 84.0, 92.4, 85.4, 75.6, 75.2, 176.4, 132.6],
    'Unified_Epochs_10^-7': [None, 258.0, None, None, None, 111.8, None, 264.6, 197.8],
    'Modular_Epochs_10^-4': [52.6, 18.4, 24.0, 24.6, 34.0, 33.2, 23.2, 52.0, 43.6],
    'Modular_Epochs_10^-5': [122.8, 42.0, 52.0, 45.8, 63.2, 70.4, 51.0, 86.0, 89.2],
    'Modular_Epochs_10^-6': [None, 81.0, 91.2, 88.6, 94.2, 116.6, 76.8, 256.8, 188.8],
    'Modular_Epochs_10^-7': [None, None, None, None, None, None, None, None, 306.3],
    'Scaled_Epochs_10^-4': [90.8, 21.2, 25.6, 25.0, 27.6, 21.2, 26.8, 46.4, 37.8],
    'Scaled_Epochs_10^-5': [224.3, 47.4, 58.0, 54.0, 57.6, 51.0, 51.2, 88.6, 78.2],
    'Scaled_Epochs_10^-6': [None, 122.0, 219.8, 278.2, 167.0, 88.8, 195.6, 101.3, 150.2],
    'Scaled_Epochs_10^-7': [None, None, None, None, None, None, None, 178.3, 262.0]
}

df = pd.DataFrame(data)

# Function to create scatter plots with adjusted labels
def scatter_plot_with_labels(x, y, title, x_label, y_label, filename):
    plt.figure(figsize=(12, 8))
    texts = []
    for label, col, color in zip(["Unified", "Modular", "Scaled"],
                                 ["Unified_" + y, "Modular_" + y, "Scaled_" + y],
                                 ['blue', 'green', 'red']):
        plt.scatter(df[x], df[col], label=label, color=color)
        for i, arch in enumerate(df['Architecture']):
            texts.append(plt.text(df[x][i], df[col][i], str(arch), fontsize=8))
    plt.xscale('log')
    if y in ["Final_Loss"]:
        plt.yscale('log')
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    adjust_text(texts, arrowprops=dict(arrowstyle='->', color='gray'))
    plt.savefig(filename)
    plt.show()

# for y in ["-4", "-5", "-6", "-7"]:
#     scatter_plot_with_labels(
#         'Total_Params',
#         f'Epochs_10^{y}',
#         f'Convergence Speed vs. Total Parameters for $10^{{{y}}}$',
#         'Total Parameters (log scale)',
#         f'Average Epochs to Reach $10^{{{y}}}$',
#         f'convergence_speed_vs_params_{y}.png'
#     )

# scatter_plot_with_labels('Total_Params', 'Final_Loss',
#                          'Final Loss vs. Total Parameters',
#                          'Total Parameters (log scale)',
#                          'Final Loss (log scale)',
#                          'final_loss_vs_params.png')

scatter_plot_with_labels('Total_Params', 'Layerwise_Overlap',
                         'Layer-wise Overlap vs. Total Parameters',
                         'Total Parameters (log scale)',
                         'Layer-wise Overlap',
                         'overlap_vs_params.png')

scatter_plot_with_labels('Total_Params', 'Layerwise_Abs_Correlation',
                         'Layer-wise Abs. Correlation vs. Total Parameters',
                         'Total Parameters (log scale)',
                         'Layer-wise Abs. Correlation',
                         'correlation_vs_params.png')

scatter_plot_with_labels('Total_Params', 'Componentwise_Overlap',
                         'Component-wise Overlap vs. Total Parameters',
                         'Total Parameters (log scale)',
                         'Component-wise Overlap',
                         'overlap_vs_params.png')

scatter_plot_with_labels('Total_Params', 'Componentwise_Abs_Correlation',
                         'Component-wise Abs. Correlation vs. Total Parameters',
                         'Total Parameters (log scale)',
                         'Component-wise Abs. Correlation',
                         'correlation_vs_params.png')

scatter_plot_with_labels('Total_Params', 'Componentwise_Correlation',
                         'Component-wise Correlation vs. Total Parameters',
                         'Total Parameters (log scale)',
                         'Component-wise Correlation',
                         'correlation_vs_params.png')

# 2. Elastic Wave System

In [None]:
#@title Strongly Coupled
# strongly_coupled = UnifiedVsModularComparison(ElasticWaveSystemStronglyCoupled())

# for (arch, name) in experiments:
#     uni_strongly, mod_strongly = strongly_coupled.train_unified_and_modular(500, 1000, 1e-3, unified_architecture=arch, train_scaled=False)

#     models = [
#         (uni_strongly, "UnifiedPINN"),
#         (mod_strongly, "ModularPINN"),
#     ]

#     strongly_coupled.analyze_interpretability(models, save_path=f"ElasticWave_StronglyCoupled_{name}_Interpretability")

In [None]:
#@title Directionally Dominant
# directionally_dominant = UnifiedVsModularComparison(ElasticWaveSystemDirectionalDominance())

# for (arch, name) in experiments:
#     uni_dominant, mod_dominant = directionally_dominant.train_unified_and_modular(500, 1000, 1e-3, unified_architecture=arch, train_scaled=False)

#     models = [
#         (uni_dominant, "UnifiedPINN"),
#         (mod_dominant, "ModularPINN"),
#     ]

#     directionally_dominant.analyze_interpretability(models, save_path=f"ElasticWave_DirectionalDominance_{name}_Interpretability")

In [None]:
#@title Anisotropic
# anisotropic = UnifiedVsModularComparison(ElasticWaveSystemAnisotropic())

# for (arch, name) in experiments:
#     uni_anisotropic, mod_anisotropic = anisotropic.train_unified_and_modular(500, 1000, 1e-3, unified_architecture=arch, train_scaled=False)

#     models = [
#         (uni_anisotropic, "UnifiedPINN"),
#         (mod_anisotropic, "ModularPINN"),
#     ]

#     anisotropic.analyze_interpretability(models, save_path=f"ElasticWave_Anisotropic_{name}_Interpretability")

In [None]:
# @title SemiModularElasticWave Network
class SemiModularElasticWavePINN(tf.keras.Model):
    def __init__(self, displacement_architecture=(128, 64), stress_architecture=(128, 64, 32), activation='tanh'):
        super(SemiModularElasticWavePINN, self).__init__()

        # Displacement sub-network (u_x, u_y)
        self.displacement_hidden_layers = [
            tf.keras.layers.Dense(units, activation=activation) for units in displacement_architecture
        ]
        self.displacement_output_layer = tf.keras.layers.Dense(4)  # Outputs: u_x, u_y

        # Stress sub-network (sigma_xx, sigma_yy, sigma_xy)
        self.stress_hidden_layers = [
            tf.keras.layers.Dense(units, activation=activation) for units in stress_architecture
        ]
        self.stress_output_layer = tf.keras.layers.Dense(1)  # Outputs: sigma_xx, sigma_yy, sigma_xy

    def call(self, x_points, t_points):
        combined_input = tf.concat([t_points, x_points], axis=1)

        # Displacement sub-network
        out_displacement = combined_input
        for layer in self.displacement_hidden_layers:
            out_displacement = layer(out_displacement)
        displacement_outputs = self.displacement_output_layer(out_displacement)

        # Stress sub-network
        out_stress = combined_input
        for layer in self.stress_hidden_layers:
            out_stress = layer(out_stress)
        stress_outputs = self.stress_output_layer(out_stress)

        # Combine outputs
        return tf.concat([displacement_outputs, stress_outputs], axis=1)

    @property
    def hidden_layers(self):
        # Combine displacement and stress hidden layers for unified access
        return self.displacement_hidden_layers + self.stress_hidden_layers


comp = UnifiedVsModularComparison(ElasticWaveSystem())
semi_modular_model = SemiModularElasticWavePINN(
    displacement_architecture=(256, 256),
    stress_architecture=(256, 256)
)
unified_model = UnifiedPINN(
    output_dim=5,
    architecture=(256, 256),
    activation='tanh'
)
modular_model = ModularPINNs(
    num_components=5,
    architectures=[(256, 256)] * 5,
    activation='tanh'
)


# Train the semi-modular model
loss_history_semi, component_losses_semi, training_time_semi = comp.system.train(
    semi_modular_model, epochs=5, N=1000, lr=1e-3
)
loss_history_uni, component_losses_uni, training_time_uni = comp.system.train(
    unified_model, epochs=5, N=1000, lr=1e-3
)
loss_history_mod, component_losses_mod, training_time_mod = comp.system.train(
    modular_model, epochs=5, N=1000, lr=1e-3
)

UnifiedVsModularComparison.compare_overall_loss(loss_history_uni, loss_history_mod, loss_history_semi)
UnifiedVsModularComparison.compare_component_losses(component_losses_uni, component_losses_mod, component_losses_semi)

comp = UnifiedVsModularComparison(ElasticWaveSystem())
comp.analyze_interpretability([(semi_modular_model, "SemiModularPINN"), (unified_model, "UnifiedPINN"), (modular_model, "ModularPINN")], save_path="SemiModularElasticWave_Interpretability")