In [None]:
# Cell 1
import torch
import torch.nn as nn
from torch.quasirandom import SobolEngine
import numpy as np
from abc import ABC, abstractmethod
from typing import List, Dict, Tuple, Optional, Union, Callable
from functools import cache

class PDEProblem(ABC):
    """Abstract description of a PDE problem used for PINN training.

    A subclass supplies the domain, the PDE residual, boundary (and, for
    time-dependent problems, initial) conditions and optionally a ground
    truth.  The base class provides point samplers for axis-aligned
    (hyper-rectangular) domains.

    ``input_vars`` is stored de-duplicated and alphabetically sorted; that
    order is the column order ``Trainer._prepare_model_input`` uses when it
    builds the model input tensor.
    """

    def __init__(
            self,
            name: str,
            input_vars: Optional[List[str]] = None,
            output_vars: Optional[List[str]] = None,
            time_var: Optional[str] = None,
            kappa_name: str = "kappa",
            default_kappa_value: float = 1.0
        ):
        """
        Args:
            name: Human-readable problem name.
            input_vars: Names of the independent variables. Defaults to ['x'].
            output_vars: Names of the model outputs, in output-column order.
                Defaults to ['u'].
            time_var: Name of the time variable, or None for steady problems.
            kappa_name: Name of the hardness parameter.
            default_kappa_value: Default value of the hardness parameter.
        """
        # None sentinels instead of list literals: a list default is created
        # once and would be shared (and mutable) across every instance.
        if input_vars is None:
            input_vars = ['x']
        if output_vars is None:
            output_vars = ['u']

        self.name: str = name
        self.input_vars: List[str] = sorted(set(input_vars))
        # Copy so later mutation of the caller's list cannot change this problem.
        self.output_vars: List[str] = list(output_vars)
        self.time_var: Optional[str] = time_var

        self.output_dim: int = len(self.output_vars)
        self.spatial_domain_dim: int = len(self.input_vars) - (1 if time_var else 0)
        self.time_dependent: bool = bool(time_var)

        self.kappa_name: str = kappa_name
        self.default_kappa_value: float = default_kappa_value

    @abstractmethod
    def get_domain_bounds(self) -> Dict[str, Tuple[float, float]]:
        """
        Returns a dictionary mapping input variable names to their (min, max) bounds.
        Example: {'x': (0.0, 1.0), 't': (0.0, 2.0)}
        """
        pass

    @abstractmethod
    def pde_residual(
            self,
            inputs: Dict[str, torch.Tensor],
            model_outputs: torch.Tensor, # Shape: (batch, output_dim)
            derivatives: Dict[str, torch.Tensor], # Keys like 'd1(u)_dx(1)', 'd2(u)_dx(2)', 'd1(v)_dt(1)'
            kappa_value: float
        ) -> torch.Tensor: # Expected shape: (batch, num_pde_equations)
        """
        Calculates the PDE residual(s).
        - model_outputs: Tensor of shape (batch_size, self.output_dim)
        - derivatives: Dictionary keyed the way Trainer._compute_derivatives
                       names its results:
                         'd1(u)_dx(1)'      first derivative of u wrt x
                         'd2(u)_dx(2)'      second derivative of u wrt x
                         'd2(u)_dx(1)y(1)'  mixed derivative d/dy(du/dx)
        Should return a tensor where each column is the residual of one PDE equation.
        For scalar PDEs, this will be (batch_size, 1).
        """
        pass

    @abstractmethod
    def boundary_conditions(
        self,
        inputs_bc: Dict[str, torch.Tensor],
        model_outputs_bc: torch.Tensor, # Shape: (batch_bc, output_dim),
        derivatives_bc: Dict[str, torch.Tensor], # Keys like 'd1(u)_dx(1)', 'd1(v)_dt(1)' etc.
        model: nn.Module,
        kappa_value: float
        ) -> torch.Tensor: # Scalar loss term
        """Returns the scalar boundary-condition loss term for the given boundary points."""
        pass

    def initial_conditions(
            self,
            inputs_ic: Dict[str, torch.Tensor],
            model_outputs_ic: torch.Tensor, # Shape: (batch_ic, output_dim)
            derivatives_ic: Dict[str, torch.Tensor], # Keys like 'd1(u)_dx(1)', 'd1(v)_dt(1)' etc.
            model: nn.Module,
            kappa_value: float
        ) -> torch.Tensor: # Scalar loss term
        """
        Returns the scalar initial-condition loss term.

        For problems without a time variable this default returns a zero
        tensor placed on the model's device (falling back to the device of
        ``model_outputs_ic``, then to CPU).

        Raises:
            NotImplementedError: for time-dependent problems; subclasses must override.
        """
        if self.time_dependent:
            raise NotImplementedError("Initial conditions must be implemented for time-dependent PDEs.")

        device: Union[str, torch.device] = 'cpu'
        first_param = None
        # Explicit None test: nn.Sequential defines __len__, so a bare
        # truthiness check is False for an empty Sequential.
        if model is not None:
            first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        elif isinstance(model_outputs_ic, torch.Tensor):
            # Also covers a model without parameters, which has no device of its own.
            device = model_outputs_ic.device
        return torch.tensor(0.0, device=device)

    @abstractmethod
    def get_ground_truth(self,
                         inputs: Dict[str, torch.Tensor],
                         kappa_value: float) -> Optional[torch.Tensor]: # Shape: (batch, output_dim)
        """Returns the reference solution at ``inputs``, or None when none is available."""
        pass

    def get_collocation_points(self,
                               num_points: int,
                               kappa_value: float,
                               device: Union[str, torch.device] = 'cpu',
                               strategy: str = 'uniform') -> Dict[str, torch.Tensor]:
        """
        Samples interior points inside the bounds from get_domain_bounds().

        Args:
            num_points: Number of points to draw.
            kappa_value: Unused by this default sampler.
            device: Device of the returned tensors.
            strategy: 'uniform' (independent torch.rand per variable) or
                'sobol' (one scrambled Sobol sequence over all variables).

        Returns:
            One (num_points, 1) tensor per input variable, each with
            requires_grad=True so derivatives can be taken with respect to it.

        Raises:
            ValueError: if a variable has no entry in the domain bounds.
            NotImplementedError: for any other strategy.
        """
        domain_bounds = self.get_domain_bounds()
        inputs: Dict[str, torch.Tensor] = {}

        unit_samples = None
        if strategy == 'sobol':
            if not self.input_vars: # Should not happen for collocation
                return {}
            sobol = SobolEngine(dimension=len(self.input_vars), scramble=True)
            # Sobol samples are generated on CPU and moved afterwards.
            unit_samples = sobol.draw(num_points).to(device)

        for i, var_name in enumerate(self.input_vars):
            if var_name not in domain_bounds:
                raise ValueError(f"Domain bounds not defined for variable: {var_name}")
            var_min, var_max = domain_bounds[var_name]

            if strategy == 'uniform':
                unit_column = torch.rand(num_points, 1, device=device)
            elif strategy == 'sobol':
                unit_column = unit_samples[:, i:i+1]
            else:
                raise NotImplementedError(f"Collocation sampling strategy '{strategy}' not implemented for variable '{var_name}'.")

            # Affine map from [0, 1) to [var_min, var_max).
            samples_var = unit_column * (var_max - var_min) + var_min
            inputs[var_name] = samples_var.requires_grad_(True)
        return inputs

    def get_boundary_points_hyperrect(self,
                            num_points_per_face: int,
                            kappa_value: float,
                            device: Union[str, torch.device] = 'cpu',
                            strategy: str = 'uniform') -> Dict[str, torch.Tensor]:
        """
        Samples points on every face of the axis-aligned spatial domain.

        For each spatial variable and each of its two bounds, that variable is
        held at the bound while the remaining spatial variables (and the time
        variable, if any) are sampled within their bounds.

        Args:
            num_points_per_face: Number of points on each face.
            kappa_value: Unused by this default sampler.
            device: Device of the returned tensors.
            strategy: 'uniform' or 'sobol'. With 'sobol' one scrambled point
                set is drawn once and reused on every face.

        Returns:
            One tensor per input variable with the points of all faces
            concatenated along dim 0 (detached, no grad).  When there are no
            spatial variables, empty (0, 1) tensors are returned.

        Raises:
            NotImplementedError: for any other strategy.
        """
        domain_bounds = self.get_domain_bounds()
        # Spatial vars are input_vars excluding the time_var
        spatial_vars = [v for v in self.input_vars if v != self.time_var]

        if not spatial_vars: # No spatial dimensions, so no spatial boundaries
            return {v: torch.empty(0, 1, device=device) for v in self.input_vars}

        if strategy not in ('uniform', 'sobol'):
            raise NotImplementedError(f"Boundary sampling strategy '{strategy}' not supported.")

        num_dims_to_sample_on_face = len(spatial_vars) - 1 + (1 if self.time_dependent else 0)
        sobol_unit_samples = None
        if strategy == 'sobol' and num_dims_to_sample_on_face > 0:
            sobol_bc = SobolEngine(dimension=num_dims_to_sample_on_face, scramble=True)
            sobol_unit_samples = sobol_bc.draw(num_points_per_face).to(device)

        # Fixed coordinates use the default dtype, like torch.rand and
        # SobolEngine.draw, so every column of a variable shares one dtype.
        coord_dtype = torch.get_default_dtype()
        chunks_per_var: Dict[str, List[torch.Tensor]] = {v: [] for v in self.input_vars}

        for fixed_var_name in spatial_vars:
            other_sampling_vars = [v for v in spatial_vars if v != fixed_var_name]
            if self.time_dependent:
                other_sampling_vars.append(self.time_var)

            for boundary_value in domain_bounds[fixed_var_name]: # min face, then max face
                face_inputs = {
                    fixed_var_name: torch.full((num_points_per_face, 1), boundary_value,
                                               dtype=coord_dtype, device=device)
                }
                for column, other_var_name in enumerate(other_sampling_vars):
                    ov_min, ov_max = domain_bounds[other_var_name]
                    if sobol_unit_samples is not None:
                        unit_column = sobol_unit_samples[:, column:column+1]
                    else:
                        unit_column = torch.rand(num_points_per_face, 1, device=device)
                    face_inputs[other_var_name] = unit_column * (ov_max - ov_min) + ov_min

                for var_n in self.input_vars:
                    chunks_per_var[var_n].append(face_inputs[var_n])

        # Every input variable received one chunk per face above.
        # BC points usually don't need grad, hence detach().
        return {var_n: torch.cat(chunks_per_var[var_n], dim=0).detach()
                for var_n in self.input_vars}

    def get_boundary_points_general(self,
                                       num_total_points: int, # Note: parameter name change
                                       kappa_value: float,
                                       device: Union[str, torch.device] = 'cpu',
                                       strategy: str = 'uniform' # Strategy for sampling on the general boundary
                                      ) -> Optional[Dict[str, torch.Tensor]]:
        """
        To be implemented by subclasses for non-rectangular/complex domains.
        Should return points lying *on* the boundary.
        Returns None to indicate this method is not implemented or not applicable,
        allowing fallback to get_boundary_points_hyperrect.
        """
        return None  # Indicating no general boundary points available

    def get_initial_points(self,
                           num_points: int,
                           kappa_value: float,
                           device: Union[str, torch.device] = 'cpu',
                           strategy: str = 'uniform') -> Dict[str, torch.Tensor]:
        """
        Samples points on the initial-time slice of the domain.

        The time variable is fixed at the lower time bound; the remaining
        input variables are sampled within their bounds.

        Args:
            num_points: Number of points to draw.
            kappa_value: Unused by this default sampler.
            device: Device of the returned tensors.
            strategy: 'uniform' or 'sobol'.

        Returns:
            One (num_points, 1) tensor per variable (no grad).  For problems
            without a time variable, empty (0, 1) tensors are returned.

        Raises:
            NotImplementedError: for any other strategy.
        """
        if not self.time_dependent:
            return {v: torch.empty(0, 1, device=device) for v in self.input_vars}

        domain_bounds = self.get_domain_bounds()
        inputs: Dict[str, torch.Tensor] = {}

        t_initial_val = domain_bounds[self.time_var][0]
        # Default dtype keeps the time column consistent with the sampled columns.
        inputs[self.time_var] = torch.full((num_points, 1), t_initial_val,
                                           dtype=torch.get_default_dtype(), device=device)

        if strategy not in ('uniform', 'sobol'):
            raise NotImplementedError(f"IC sampling strategy '{strategy}' not supported.")

        spatial_vars = [v for v in self.input_vars if v != self.time_var]
        unit_samples = None
        if strategy == 'sobol' and spatial_vars: # only use sobol if there are spatial vars to sample
            sobol_ic = SobolEngine(dimension=len(spatial_vars), scramble=True)
            unit_samples = sobol_ic.draw(num_points).to(device)

        for i, var_name in enumerate(spatial_vars):
            var_min, var_max = domain_bounds[var_name]
            if unit_samples is not None:
                unit_column = unit_samples[:, i:i+1]
            else:
                unit_column = torch.rand(num_points, 1, device=device)
            # Plain tensor arithmetic: the result does not require grad,
            # which is what IC coordinates generally want.
            inputs[var_name] = unit_column * (var_max - var_min) + var_min
        return inputs

    @abstractmethod
    def get_required_derivative_orders(self) -> Dict[str, Dict[Tuple[str, ...], int]]:
        """
        Returns a dictionary specifying derivative requirements for each output variable.
        Structure:
          {
            'output_var_name_1': { # For the first output variable (e.g., 'u')
                # Simple derivatives:
                ('input_var_for_deriv',): order,  # e.g., ('x',): 2 for d2(u)/dx2
                # Mixed derivatives (sequence of differentiation):
                ('input_var_1', 'input_var_2', ...): 1, # e.g., ('x', 'y'): 1 for d/dy(d(u)/dx)
                                                       # The value (e.g., 1) indicates one application
                                                       # of this sequence of differentiations.
            },
            'output_var_name_2': { ... } # For the second output variable (e.g., 'v')
          }
        Each inner dict lists derivatives OF its key output variable only; a
        derivative that appears in any equation must be listed under the
        output variable being differentiated.
        Example for -u_xx - u_yy = f (output_vars=['u']):
          {'u': {('x',): 2, ('y',): 2}}
        Example for u_t + v_x = 0, v_t + u_x = 0 (output_vars=['u', 'v']):
          {
            'u': {('t',): 1, ('x',): 1},  # u_t (first equation), u_x (second equation)
            'v': {('x',): 1, ('t',): 1}   # v_x (first equation), v_t (second equation)
          }
        """
        pass

    def get_required_derivative_orders_for_bc(self) -> Optional[Dict[str, Dict[Tuple[str, ...], int]]]:
        """Specifies derivatives needed for boundary condition evaluation. Returns None if none needed."""
        return None

    def get_required_derivative_orders_for_ic(self) -> Optional[Dict[str, Dict[Tuple[str, ...], int]]]:
        """Specifies derivatives needed for initial condition evaluation. Returns None if none needed."""
        return None

    def calculate_specific_observables(self,
                                       inputs: Dict[str, torch.Tensor],
                                       model_outputs: torch.Tensor,
                                       ground_truth_outputs: Optional[torch.Tensor],
                                       kappa_value: float) -> Dict[str, float]:
        """
        Calculates PDE-specific physical observables and their errors.
        To be implemented by subclasses if relevant.
        Args:
            inputs: Dictionary of input tensors for the test grid.
            model_outputs: Tensor of model predictions on the test grid.
            ground_truth_outputs: Tensor of ground truth solutions on the test grid (if available).
            kappa_value: Current hardness parameter.
        Returns:
            A dictionary of observable names to their scalar values (e.g., errors).
            Example: {'soliton_amplitude_error': 0.01, 'shock_speed_error': 0.05}
        """
        return {} # Default implementation returns no specific observables


In [9]:
import torch
import torch.nn as nn
from typing import Union

def create_pinn_model(
    input_dim: int,
    output_dim: int,
    n_neurons_per_layer: int,
    n_hidden_layers: int = 1, # Default to SLN
    activation_str: str = "tanh",
    device: Union[str, torch.device] = 'cpu'
) -> nn.Module:
    """
    Creates a feedforward neural network (PINN model).

    Every hidden layer gets its own activation module instance, so hooks or
    in-place edits on one activation do not affect the others.

    Weight initialisation of every Linear layer depends on the activation:
    Xavier (Glorot) normal for 'tanh'/'sigmoid', Kaiming normal for
    'relu'/'leakyrelu' (using the LeakyReLU negative slope for the latter).
    All biases are zero-initialised.

    Args:
        input_dim (int): Dimension of the input (e.g., 1 for u(x), 2 for u(x,t)).
        output_dim (int): Dimension of the output (e.g., 1 for scalar u).
        n_neurons_per_layer (int): Number of neurons in each hidden layer.
        n_hidden_layers (int): Number of hidden layers. Default is 1.
                               0 builds a purely linear model; in that case
                               the activation name is not validated.
                               Negative values behave like 1.
        activation_str (str): Activation function to use ('tanh', 'relu', 'sigmoid', 'leakyrelu'),
                              case-insensitive. Default is 'tanh'.
        device (Union[str, torch.device]): Device to send the model to ('cpu' or 'cuda').
                                           Default is 'cpu'.

    Returns:
        nn.Module: The PyTorch neural network model (nn.Sequential).

    Raises:
        ValueError: If hidden layers are requested with an unsupported activation.
    """
    activation_key = activation_str.lower()
    leaky_slope = 0.01  # nn.LeakyReLU's default negative_slope, also fed to Kaiming init

    # Factories rather than instances: each call yields a fresh module.
    activation_factories = {
        'tanh': nn.Tanh,
        'relu': nn.ReLU,
        'sigmoid': nn.Sigmoid,
        'leakyrelu': lambda: nn.LeakyReLU(negative_slope=leaky_slope),
    }

    layers: list[nn.Module] = []

    if n_hidden_layers == 0: # Special case: linear model (no hidden layers)
        layers.append(nn.Linear(input_dim, output_dim))
    else:
        make_activation = activation_factories.get(activation_key)
        if make_activation is None:
            raise ValueError(f"Unsupported activation: {activation_str}")

        in_features = input_dim
        # max(..., 1): a negative count historically still produced one hidden layer.
        for _ in range(max(n_hidden_layers, 1)):
            layers.append(nn.Linear(in_features, n_neurons_per_layer))
            layers.append(make_activation())
            in_features = n_neurons_per_layer

        # Output layer (connects last hidden layer to output_dim)
        layers.append(nn.Linear(n_neurons_per_layer, output_dim))

    model = nn.Sequential(*layers).to(device)

    # Apply initializations
    for layer in model:
        if not isinstance(layer, nn.Linear):
            continue

        if activation_key in ('tanh', 'sigmoid'):
            nn.init.xavier_normal_(layer.weight) # Glorot normal
        elif activation_key == 'relu':
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
        elif activation_key == 'leakyrelu':
            # `a` must be the rectifier's negative slope; with the default
            # a=0 the 'leaky_relu' gain is just the plain ReLU gain.
            nn.init.kaiming_normal_(layer.weight, a=leaky_slope, nonlinearity='leaky_relu')

        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

    return model


In [10]:
# Cell 3
import torch
import torch.optim as optim
import time
import numpy as np

class Trainer:
    def __init__(self, model, pde_problem: 'PDEProblem', optimizer_str="adam", learning_rate=1e-3, device='cpu'):
        self.model = model.to(device)
        self.pde_problem = pde_problem # Type hint for clarity
        self.device = device
        self.lr = learning_rate

        if optimizer_str.lower() == "adam":
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        elif optimizer_str.lower() == "lbfgs":
            self.optimizer = optim.LBFGS(self.model.parameters(), lr=self.lr, max_iter=20, line_search_fn="strong_wolfe")
        else:
            raise ValueError(f"Unsupported optimizer: {optimizer_str}")

        self.optimizer_str = optimizer_str
        self.epoch_wise_log = []

    def _prepare_model_input(self, inputs_dict: dict) -> torch.Tensor | None:
        """
        Prepares a single tensor input for the model from the inputs_dict.
        The order of concatenation is defined by self.pde_problem.input_vars.
        """
        if not inputs_dict:
            return None # Or handle as appropriate if model expects input even for empty dict

        ordered_input_tensors = []
        for var_name in self.pde_problem.input_vars:
            if var_name in inputs_dict:
                ordered_input_tensors.append(inputs_dict[var_name])
            else:
                # This should ideally not happen if PDEProblem methods are consistent
                raise ValueError(f"Input variable '{var_name}' expected by PDEProblem.input_vars "
                                 f"but not found in provided inputs_dict keys: {list(inputs_dict.keys())}")

        if not ordered_input_tensors: # Should be caught by the first check if inputs_dict is empty
             return torch.empty(0, device=self.device)

        return torch.cat(ordered_input_tensors, dim=1)

    def _compute_derivatives(self,
                             inputs_dict_with_grad: Dict[str, torch.Tensor],
                             model_outputs_tensor: torch.Tensor,
                             required_specs: Dict[str, Dict[Tuple[str, ...], int]] = None
                             ) -> Dict[str, torch.Tensor]:
        """
        Computes derivatives based on pde_problem.get_required_derivative_orders().
        model_outputs_tensor has shape (batch, pde_problem.output_dim)
        Derivatives are taken with respect to the individual tensors in inputs_dict_with_grad.

        Returns a dictionary of derivatives.
        Naming convention examples:
        - d(u)_dx(1)       for first derivative of 'u' wrt 'x'
        - d2(u)_dx(2)      for second derivative of 'u' wrt 'x'
        - d(v)_dt(1)       for first derivative of 'v' wrt 't'
        - d(u)_dx(1)dy(1)  for d/dy(du/dx)
        """
        derivatives: Dict[str, torch.Tensor] = {}
        if required_specs is None:
            required_specs = self.pde_problem.get_required_derivative_orders()
        if not required_specs:
            return derivatives
        output_var_names = self.pde_problem.output_vars

        for out_idx, out_var_name in enumerate(output_var_names):
            if out_var_name not in required_specs: # If no derivatives are listed for this output var
                continue

            # current_output_component is (batch_size, 1)
            current_output_component = model_outputs_tensor[:, out_idx:out_idx+1]

            spec_for_this_output_var = required_specs[out_var_name]

            for input_var_sequence, order_val in spec_for_this_output_var.items():
                # input_var_sequence is a tuple, e.g., ('x',) or ('x', 't')
                # order_val for simple derivatives is the max order, e.g., 2 for d2u/dx2
                # order_val for mixed sequence is typically 1 (one application of the sequence)

                if not isinstance(input_var_sequence, tuple) or not input_var_sequence:
                    raise ValueError(f"Invalid input_var_sequence: {input_var_sequence} for {out_var_name}")

                # --- Handle Simple Derivatives (e.g., ('x',): 2 for d2u/dx2) ---
                if len(input_var_sequence) == 1:
                    input_var_name_for_deriv = input_var_sequence[0]
                    max_order = order_val

                    if input_var_name_for_deriv not in inputs_dict_with_grad:
                        raise RuntimeError(f"Input variable '{input_var_name_for_deriv}' needed for derivative of '{out_var_name}' "
                                           f"not found in inputs_dict_with_grad: {list(inputs_dict_with_grad.keys())}")
                    input_tensor_for_grad = inputs_dict_with_grad[input_var_name_for_deriv]
                    if not input_tensor_for_grad.requires_grad:
                        raise RuntimeError(f"Input tensor for '{input_var_name_for_deriv}' does not require grad.")

                    temp_deriv_target = current_output_component
                    for o in range(1, max_order + 1):
                        grads = torch.autograd.grad(
                            outputs=temp_deriv_target,
                            inputs=input_tensor_for_grad,
                            grad_outputs=torch.ones_like(temp_deriv_target),
                            create_graph=True,
                            retain_graph=True,
                            allow_unused=False # Be strict initially
                        )[0]
                        if grads is None:
                            raise RuntimeError(f"Gradient for d{o}({out_var_name})_d({input_var_name_for_deriv}){o} was None.")

                        deriv_name = f"d{o}({out_var_name})_d{input_var_name_for_deriv}({o})"
                        derivatives[deriv_name] = grads
                        temp_deriv_target = grads

                # --- Handle Mixed Derivatives (e.g., ('x', 't'): 1 for d/dt(du/dx)) ---
                elif len(input_var_sequence) > 1:
                    if order_val != 1:
                        # For now, assume mixed derivative specs like ('x','y'):1 mean one application of d/dy(d/dx(...))
                        # Higher order_val for mixed could mean repeated application of the sequence, but that's rare.
                        print(f"Warning: Mixed derivative for {out_var_name} wrt {input_var_sequence} has order_val {order_val} != 1. Interpreting as 1 application.")

                    temp_deriv_target = current_output_component

                    # Build the name like "d(u)_dx(1)dy(1)"
                    # The number before (out_var_name) will be len(input_var_sequence)
                    name_prefix = f"d{len(input_var_sequence)}({out_var_name})_d"
                    name_suffix_parts = []

                    for i, invar_name in enumerate(input_var_sequence):
                        if invar_name not in inputs_dict_with_grad:
                            raise RuntimeError(f"Input variable '{invar_name}' for mixed derivative of '{out_var_name}' "
                                               f"not in inputs_dict_with_grad.")
                        input_tensor_for_grad = inputs_dict_with_grad[invar_name]
                        if not input_tensor_for_grad.requires_grad:
                             raise RuntimeError(f"Input tensor for mixed deriv '{invar_name}' does not require grad.")

                        grads = torch.autograd.grad(
                            outputs=temp_deriv_target,
                            inputs=input_tensor_for_grad,
                            grad_outputs=torch.ones_like(temp_deriv_target),
                            create_graph=True, # Must be true if any further grads in sequence
                            retain_graph=True, # Must be true
                            allow_unused=False
                        )[0]
                        if grads is None:
                            raise RuntimeError(f"Mixed derivative part d/d{invar_name} for {out_var_name} failed.")
                        temp_deriv_target = grads
                        name_suffix_parts.append(f"{invar_name}(1)")

                    deriv_name = name_prefix + "".join(name_suffix_parts)
                    derivatives[deriv_name] = temp_deriv_target
        return derivatives

    def _calculate_error_metrics_on_test_grid(self, kappa_value, num_test_pts=1001):
        self.model.eval()
        domain_bounds = self.pde_problem.get_domain_bounds()
        test_inputs_dict_for_gt = {} # Populate this as before based on input_vars

        # ... (grid generation logic as in your full code for 1D/2D inputs) ...
        if len(self.pde_problem.input_vars) == 1:
            var_name = self.pde_problem.input_vars[0]
            var_min, var_max = domain_bounds[var_name]
            test_values_np = np.linspace(var_min, var_max, num_test_pts)
            test_values_torch = torch.tensor(test_values_np, dtype=torch.float32, device=self.device).unsqueeze(1)
            test_inputs_dict_for_gt[var_name] = test_values_torch
        elif len(self.pde_problem.input_vars) == 2:
            var1_name, var2_name = self.pde_problem.input_vars[0], self.pde_problem.input_vars[1]
            var1_min, var1_max = domain_bounds[var1_name]
            var2_min, var2_max = domain_bounds[var2_name]
            pts_per_dim = int(np.sqrt(num_test_pts))
            # ... (meshgrid logic) ...
            # (ensure num_test_pts is updated based on actual grid size)
            var1_vals = torch.linspace(var1_min, var1_max, pts_per_dim, device=self.device)
            var2_vals = torch.linspace(var2_min, var2_max, pts_per_dim, device=self.device)
            grid_var1, grid_var2 = torch.meshgrid(var1_vals, var2_vals, indexing='ij')
            test_inputs_dict_for_gt[var1_name] = grid_var1.reshape(-1, 1)
            test_inputs_dict_for_gt[var2_name] = grid_var2.reshape(-1, 1)
            num_test_pts = test_inputs_dict_for_gt[var1_name].shape[0]
        else:
            # For >2D, you'll need to implement a more general grid creation or accept it as an argument
            # For now, let's assume we won't hit this for the workshop's core PDEs
            print("Warning: Test grid generation for >2 input_vars not fully implemented in error metrics.")
            # Fallback or raise error
            return {key: float('nan') for key in ['L1_err', 'L2_err', 'Linf_err', 'L1_err_rel',
                                                  'L2_err_rel', 'Linf_err_rel', 'PDE_residual_max',
                                                  'error_median_abs', 'error_p90_abs']}


        test_model_input_tensor = self._prepare_model_input(test_inputs_dict_for_gt)
        if test_model_input_tensor is None or test_model_input_tensor.numel() == 0:
             print("Warning: No test model input tensor generated for error metrics.")
             return {key: float('nan') for key in ['L1_err', 'L2_err', 'Linf_err', 'L1_err_rel',
                                                  'L2_err_rel', 'Linf_err_rel', 'PDE_residual_max',
                                                  'error_median_abs', 'error_p90_abs']}


        with torch.no_grad():
            u_pred_test = self.model(test_model_input_tensor)

        u_true_test_torch = self.pde_problem.get_ground_truth(test_inputs_dict_for_gt, kappa_value)

        if u_true_test_torch is not None and u_pred_test.shape != u_true_test_torch.shape:
            try: u_pred_test = u_pred_test.reshape_as(u_true_test_torch)
            except RuntimeError: print(f"Warning: Cannot reshape u_pred_test for error calc.")

        metrics = {
            'L1_err': float('nan'), 'L2_err': float('nan'), 'Linf_err': float('nan'),
            'L1_err_rel': float('nan'), 'L2_err_rel': float('nan'), 'Linf_err_rel': float('nan'),
            'PDE_residual_max': float('nan'),
            'error_median_abs': float('nan'), 'error_p90_abs': float('nan') # New
        }

        if u_true_test_torch is not None:
            error_vec = (u_pred_test - u_true_test_torch).flatten() # Flatten for norms and quantiles
            actual_num_test_pts = len(error_vec)
            if actual_num_test_pts == 0: actual_num_test_pts = 1

            metrics['L1_err'] = torch.linalg.norm(error_vec, ord=1).item() / actual_num_test_pts
            metrics['L2_err'] = torch.linalg.norm(error_vec, ord=2).item() / np.sqrt(actual_num_test_pts)
            metrics['Linf_err'] = torch.linalg.norm(error_vec, ord=float('inf')).item()

            abs_error_vec = torch.abs(error_vec)
            metrics['error_median_abs'] = torch.median(abs_error_vec).item()
            if actual_num_test_pts > 0 : # Quantile needs at least one element
                 metrics['error_p90_abs'] = torch.quantile(abs_error_vec, 0.9).item()

            u_true_flat = u_true_test_torch.flatten()
            norm_u_true_l1 = torch.linalg.norm(u_true_flat, ord=1)
            norm_u_true_l2 = torch.linalg.norm(u_true_flat, ord=2)
            norm_u_true_linf = torch.linalg.norm(u_true_flat, ord=float('inf'))

            if norm_u_true_l1 > 1e-9: metrics['L1_err_rel'] = torch.linalg.norm(error_vec, ord=1).item() / norm_u_true_l1.item()
            if norm_u_true_l2 > 1e-9: metrics['L2_err_rel'] = torch.linalg.norm(error_vec, ord=2).item() / norm_u_true_l2.item()
            if norm_u_true_linf > 1e-9: metrics['Linf_err_rel'] = metrics['Linf_err'] / norm_u_true_linf.item()

        # Max PDE residual
        test_inputs_dict_for_res = {}
        for k, v_test in test_inputs_dict_for_gt.items():
            if v_test.numel() > 0: # Only process if tensor is not empty
                test_inputs_dict_for_res[k] = v_test.clone().detach().requires_grad_(True)

        if test_inputs_dict_for_res: # Proceed only if there are inputs for residual calculation
            res_model_input_tensor = self._prepare_model_input(test_inputs_dict_for_res)
            if res_model_input_tensor is not None and res_model_input_tensor.numel() > 0:
                u_pred_for_res = self.model(res_model_input_tensor)
                derivatives_for_res = self._compute_derivatives(test_inputs_dict_for_res, u_pred_for_res)
                pde_res_vals_on_grid = self.pde_problem.pde_residual(test_inputs_dict_for_res, u_pred_for_res, derivatives_for_res, kappa_value)
                if pde_res_vals_on_grid is not None and pde_res_vals_on_grid.numel() > 0 :
                    metrics['PDE_residual_max'] = torch.max(torch.abs(pde_res_vals_on_grid.detach())).item()

        # Calculate specific observables if the PDEProblem has this method
        if hasattr(self.pde_problem, 'calculate_specific_observables'):
            try:
                specific_obs = self.pde_problem.calculate_specific_observables(
                    test_inputs_dict_for_gt, # The dict of input tensors for the test grid
                    u_pred_test,             # Model predictions on the test grid
                    u_true_test_torch,       # Ground truth on the test grid
                    kappa_value
                )
                if specific_obs and isinstance(specific_obs, dict):
                    metrics.update(specific_obs) # Add them to the metrics dict for this epoch
            except Exception as e:
                print(f"Warning: Error calculating specific observables for {self.pde_problem.name}: {e}")

        self.model.train()
        return metrics

    def _closure_lbfgs(self, collocation_points_dict, bc_points_dict, ic_points_dict, kappa_value, loss_weights):
        self.optimizer.zero_grad()

        # PDE Loss
        colloc_model_input = self._prepare_model_input(collocation_points_dict)
        model_outputs_colloc = self.model(colloc_model_input)
        derivatives_colloc = self._compute_derivatives(collocation_points_dict, model_outputs_colloc)
        pde_res = self.pde_problem.pde_residual(collocation_points_dict, model_outputs_colloc, derivatives_colloc, kappa_value)
        loss_pde = torch.mean(pde_res**2)

        # BC Loss
        loss_bc = torch.tensor(0.0, device=self.device)
        if bc_points_dict: # Check if not empty
            bc_model_input = self._prepare_model_input(bc_points_dict)
            model_outputs_bc = self.model(bc_model_input)

            derivatives_at_bc = {}
            bc_deriv_spec = self.pde_problem.get_required_derivative_orders_for_bc()
            if bc_deriv_spec:
                # Ensure bc_points_dict tensors used for derivatives have requires_grad=True
                grad_enabled_bc_points_dict = {
                    k: v.clone().detach().requires_grad_(True) for k,v in bc_points_dict.items()
                }
                grad_enabled_bc_model_input = self._prepare_model_input(grad_enabled_bc_points_dict)
                model_outputs_bc_for_deriv = self.model(grad_enabled_bc_model_input) # Re-evaluate for graph

                derivatives_at_bc = self._compute_derivatives(grad_enabled_bc_points_dict, model_outputs_bc_for_deriv)

            loss_bc = self.pde_problem.boundary_conditions(bc_points_dict, model_outputs_bc, derivatives_at_bc, self.model, kappa_value)

        # IC Loss
        loss_ic = torch.tensor(0.0, device=self.device)
        if self.pde_problem.time_dependent and ic_points_dict:
            ic_model_input = self._prepare_model_input(ic_points_dict)
            model_outputs_ic = self.model(ic_model_input)

            derivatives_at_ic = {}
            ic_deriv_spec = self.pde_problem.get_required_derivative_orders_for_ic()
            if ic_deriv_spec:
                # Ensure ic_points_dict tensors used for derivatives have requires_grad=True
                grad_enabled_ic_points_dict = {
                    k: v.clone().detach().requires_grad_(True) for k,v in ic_points_dict.items()
                }
                grad_enabled_ic_model_input = self._prepare_model_input(grad_enabled_ic_points_dict)
                model_outputs_ic_for_deriv = self.model(grad_enabled_ic_model_input)

                derivatives_at_ic = self._compute_derivatives(grad_enabled_ic_points_dict, model_outputs_ic_for_deriv)

            loss_ic = self.pde_problem.initial_conditions(ic_points_dict, model_outputs_ic, derivatives_at_ic, self.model, kappa_value)

        total_loss = (loss_weights['pde'] * loss_pde +
                      loss_weights['bc'] * loss_bc +
                      loss_weights['ic'] * loss_ic)
        total_loss.backward()
        self._current_losses = {'pde': loss_pde.item(), 'bc': loss_bc.item(),
                                'ic': loss_ic.item(), 'total': total_loss.item()}
        return total_loss

    def train(self, num_epochs, kappa_value,
              num_collocation_pts, num_bc_pts_per_face, num_ic_pts,
              collocation_strategy='uniform',
              loss_weights=None,
              log_epochs=None,
              num_test_pts_error_grid=1001):
        """
        Run the training loop and return the per-epoch log.

        Epochs run from 0 to `num_epochs` inclusive. Epoch 0 performs no
        optimizer step, so its log entry describes the untrained model (for
        L-BFGS the loss fields of that entry are NaN unless a previous closure
        call left values in `self._current_losses`). Collocation, boundary and
        initial points are re-sampled from the PDE problem at every epoch.

        Args:
            num_epochs: Number of optimizer steps to take.
            kappa_value: Problem parameter forwarded to the PDE problem callbacks.
            num_collocation_pts: Number of interior points requested per epoch.
            num_bc_pts_per_face: Number of boundary points requested per face.
            num_ic_pts: Number of initial-time points (time-dependent problems only).
            collocation_strategy: Sampling strategy name, used for all point sets.
            loss_weights: Optional mapping with any of the keys 'pde', 'bc', 'ic';
                missing keys (or None) default to 1.0.
            log_epochs: Epochs at which metrics are logged; None means
                [0, 1000, 5000, 10000]. The final epoch is always logged.
            num_test_pts_error_grid: Forwarded to the test-grid error evaluation.

        Returns:
            The list of log-entry dicts, which is also stored in `self.epoch_wise_log`.

        Raises:
            ValueError: If `self.optimizer_str` is neither "adam" nor "lbfgs".
        """
        # Fresh containers per call: no shared mutable defaults, and partial
        # weight dictionaries are completed instead of raising KeyError.
        weights = {'pde': 1.0, 'bc': 1.0, 'ic': 1.0}
        if loss_weights:
            weights.update(loss_weights)
        # A set makes the per-epoch membership test O(1).
        log_epoch_set = {0, 1000, 5000, 10000} if log_epochs is None else set(log_epochs)

        def has_points(points_dict):
            # Non-empty dict whose tensors all contain at least one element.
            return bool(points_dict) and all(t.numel() > 0 for t in points_dict.values())

        def derivatives_at(points_dict, deriv_spec):
            # Derivatives need inputs that track gradients, so the model is
            # re-evaluated on grad-enabled copies of the points.
            if not deriv_spec:
                return {}
            grad_points = {
                k: v.clone().detach().requires_grad_(True) for k, v in points_dict.items()
            }
            outputs_for_deriv = self.model(self._prepare_model_input(grad_points))
            return self._compute_derivatives(grad_points, outputs_for_deriv, deriv_spec)

        cumulative_time_s = 0.0
        self.epoch_wise_log = []

        for epoch in range(num_epochs + 1):
            epoch_start_time = time.time()

            # Point sampling is shared by both optimizers; each call returns a
            # dict such as {'x': tensor, 't': tensor}.
            collocation_points_dict = self.pde_problem.get_collocation_points(
                num_collocation_pts, kappa_value, self.device, collocation_strategy
            )

            bc_points_dict = self.pde_problem.get_boundary_points_general(
                num_bc_pts_per_face, kappa_value, self.device, strategy=collocation_strategy
            )
            if bc_points_dict is None:  # Fall back to the hyper-rectangle sampler
                bc_points_dict = self.pde_problem.get_boundary_points_hyperrect(
                    num_bc_pts_per_face, kappa_value, self.device, strategy=collocation_strategy
                )

            ic_points_dict = {}
            if self.pde_problem.time_dependent:
                ic_points_dict = self.pde_problem.get_initial_points(
                    num_ic_pts, kappa_value, self.device, strategy=collocation_strategy
                )

            if self.optimizer_str == "adam":
                self.model.train()
                self.optimizer.zero_grad()

                # PDE loss
                colloc_model_input = self._prepare_model_input(collocation_points_dict)
                model_outputs_colloc = self.model(colloc_model_input)
                pde_deriv_spec = self.pde_problem.get_required_derivative_orders()
                derivatives_colloc = self._compute_derivatives(collocation_points_dict, model_outputs_colloc, pde_deriv_spec)
                pde_res = self.pde_problem.pde_residual(collocation_points_dict, model_outputs_colloc, derivatives_colloc, kappa_value)
                loss_pde = torch.mean(pde_res**2)

                # BC loss: skipped (left at zero) when there are no boundary points.
                loss_bc = torch.tensor(0.0, device=self.device)
                if has_points(bc_points_dict):
                    model_outputs_bc = self.model(self._prepare_model_input(bc_points_dict))
                    bc_deriv_spec = self.pde_problem.get_required_derivative_orders_for_bc()
                    derivatives_at_bc = derivatives_at(bc_points_dict, bc_deriv_spec)
                    loss_bc = self.pde_problem.boundary_conditions(
                        inputs_bc=bc_points_dict,
                        model_outputs_bc=model_outputs_bc,  # evaluated on the original points
                        derivatives_bc=derivatives_at_bc,   # evaluated on grad-enabled copies
                        model=self.model,
                        kappa_value=kappa_value
                    )

                # IC loss: only for time-dependent problems with initial points.
                loss_ic = torch.tensor(0.0, device=self.device)
                if self.pde_problem.time_dependent and has_points(ic_points_dict):
                    model_outputs_ic = self.model(self._prepare_model_input(ic_points_dict))
                    ic_deriv_spec = self.pde_problem.get_required_derivative_orders_for_ic()
                    derivatives_at_ic = derivatives_at(ic_points_dict, ic_deriv_spec)
                    loss_ic = self.pde_problem.initial_conditions(
                        inputs_ic=ic_points_dict,
                        model_outputs_ic=model_outputs_ic,
                        derivatives_ic=derivatives_at_ic,
                        model=self.model,
                        kappa_value=kappa_value
                    )

                total_loss = (weights['pde'] * loss_pde +
                              weights['bc'] * loss_bc +
                              weights['ic'] * loss_ic)

                if epoch > 0:  # epoch 0 only measures the untrained model
                    total_loss.backward()
                    self.optimizer.step()

                current_total_loss = total_loss.item()
                current_pde_loss = loss_pde.item()
                current_bc_loss = loss_bc.item()
                current_ic_loss = loss_ic.item()

            elif self.optimizer_str == "lbfgs":
                if epoch > 0:
                    self.model.train()
                    self.optimizer.step(lambda: self._closure_lbfgs(
                        collocation_points_dict, bc_points_dict, ic_points_dict, kappa_value, weights
                    ))
                # The closure records the loss components of its latest evaluation.
                lbfgs_losses = getattr(self, '_current_losses', {})
                current_total_loss = lbfgs_losses.get('total', float('nan'))
                current_pde_loss = lbfgs_losses.get('pde', float('nan'))
                current_bc_loss = lbfgs_losses.get('bc', float('nan'))
                current_ic_loss = lbfgs_losses.get('ic', float('nan'))

            else:
                # Fail with a clear message instead of an UnboundLocalError at logging time.
                raise ValueError(f"Unsupported optimizer_str: {self.optimizer_str!r}")

            epoch_duration_s = time.time() - epoch_start_time
            if epoch > 0:
                cumulative_time_s += epoch_duration_s

            if epoch in log_epoch_set or epoch == num_epochs:
                # Global L2 norm of the gradients left by the last backward pass.
                # A NaN gradient stays NaN here rather than being reported as 0.
                grad_norm = 0.0
                if epoch > 0:
                    grad_sq_sum = sum(
                        p.grad.detach().norm(2).item() ** 2
                        for p in self.model.parameters() if p.grad is not None
                    )
                    grad_norm = grad_sq_sum ** 0.5

                error_metrics_on_grid = self._calculate_error_metrics_on_test_grid(kappa_value, num_test_pts_error_grid)

                # Global L2 norm of the trainable parameters (diagnostic only).
                l2_norm_weights = sum(
                    torch.linalg.norm(param.data).item() ** 2
                    for param in self.model.parameters() if param.requires_grad
                )
                l2_norm_weights = np.sqrt(l2_norm_weights) if l2_norm_weights > 0 else 0.0

                gpu_mem_peak_mb = float('nan')
                if self.device.type == 'cuda':
                    # Peak allocation since the previous reset, in MiB; reset so the
                    # next log entry covers only the following interval.
                    gpu_mem_peak_mb = torch.cuda.max_memory_allocated(self.device) / (1024**2)
                    torch.cuda.reset_peak_memory_stats(self.device)

                log_entry = {
                    'epoch': epoch, 'time_s': cumulative_time_s,
                    'loss_total': current_total_loss, 'loss_pde': current_pde_loss,
                    'loss_bc': current_bc_loss, 'loss_ic': current_ic_loss,
                    'grad_norm_l2': grad_norm, 'l2_norm_weights': l2_norm_weights,
                    'gpu_mem_peak_mb': gpu_mem_peak_mb,
                }
                log_entry.update(error_metrics_on_grid)
                self.epoch_wise_log.append(log_entry)

                print(f"Epoch {epoch}/{num_epochs}, Loss: {current_total_loss:.3e}, "
                      f"L2_err_rel: {log_entry.get('L2_err_rel', float('nan')):.3e}, GradNorm: {grad_norm:.3e}")

        print(f"Training finished. Total active time: {cumulative_time_s:.2f}s")
        return self.epoch_wise_log

In [11]:
# Cell 4
import itertools
import os
import json
import pandas as pd
from dataclasses import dataclass, asdict, field

@dataclass
class ExperimentConfig:
    """
    Settings for a single PINN training run.

    The first nine fields are required; the rest have defaults. Fields whose
    default is None are resolved by the experiment runner when the run starts.
    """
    # Identification
    pde_name: str        # Key used to look the problem up in the runner's pde_map
    kappa_val: float
    activation_str: str
    seed: int

    # Architecture
    depth: int  # Number of hidden layers
    width: int  # Neurons per hidden layer

    # Optimizer
    optimizer_type: str
    lr: float

    # Training
    epochs: int

    # Logging & Error Evaluation
    log_epochs_list: list = field(default_factory=lambda: list(range(0, 10001, 100)))
    num_test_pts_error_grid: int = 1001

    # Loss Weights
    loss_weight_pde: float = 1.0
    loss_weight_bc: float = 1.0
    loss_weight_ic: float = 1.0

    # Collocation (an explicit point count takes precedence over the factor,
    # which is multiplied by `width`)
    M_collocation_pts: Optional[int] = None
    M_collocation_factor: int = 10

    # IC/BC Points
    num_total_bc_pts: Optional[int] = None     # Total BC points across all faces
    num_bc_pts_per_face: Optional[int] = None  # If not specified, a heuristic is used
    num_ic_pts: Optional[int] = None           # Only for time-dependent PDEs
    collocation_scheme: str = 'uniform'

class ExperimentRunner:
    """
    Runs single experiments described by an `ExperimentConfig` and stores their results.

    For each run a directory is created under `base_results_dir`; it receives
    `config.json`, `training_log.csv` and `summary.json`.
    """

    # Heuristics used when the config leaves the BC/IC point counts unspecified.
    _BC_HEURISTIC_FACTOR = 20
    _IC_HEURISTIC_FACTOR = 10
    _MIN_HEURISTIC_PTS = 10

    def __init__(self, base_results_dir="data/", pde_map=None, device='cpu'):
        """
        Args:
            base_results_dir: Root directory for all run outputs (created if missing).
            pde_map: Mapping from PDE name to a PDE problem instance.
            device: Device passed on to the model factory and the trainer.
        """
        self.base_results_dir = base_results_dir
        self.pde_map = pde_map if pde_map is not None else {}
        self.device = device
        os.makedirs(self.base_results_dir, exist_ok=True)

    def _get_run_dir(self, config: "ExperimentConfig"):
        """Create (if needed) and return the output directory for `config`."""
        # Make kappa filename-safe: '1.0e-03' -> '1p0e-03', '1.0e+02' -> '1p0e02'.
        kappa_str = f"{config.kappa_val:.1e}".replace('.', 'p').replace('+', '')

        run_path = os.path.join(
            self.base_results_dir,
            config.pde_name,
            f"kappa_{kappa_str}",
            f"act_{config.activation_str}",
            f"N_{config.width}",
            f"D_{config.depth}",
            f"seed_{config.seed}"
        )
        os.makedirs(run_path, exist_ok=True)
        return run_path

    @staticmethod
    def _resolve_collocation_pts(config) -> int:
        """Return the collocation point count: explicit value, else width * factor."""
        if config.M_collocation_pts is not None:
            return config.M_collocation_pts
        return config.width * config.M_collocation_factor

    @classmethod
    def _resolve_bc_pts_per_face(cls, config, pde_instance, num_collocation) -> int:
        """
        Return the number of boundary points per face.

        Precedence: `num_total_bc_pts` split evenly over the 2 * spatial_dim
        faces, then `num_bc_pts_per_face`, then a heuristic based on the
        collocation count. Problems without spatial dimensions get 0 unless a
        per-face count was given explicitly.
        """
        if config.num_total_bc_pts is None and config.num_bc_pts_per_face is not None:
            return config.num_bc_pts_per_face

        num_spatial_dims = pde_instance.spatial_domain_dim
        num_faces = 2 * num_spatial_dims if num_spatial_dims > 0 else 0
        if num_faces == 0:
            return 0
        if config.num_total_bc_pts is not None:
            return config.num_total_bc_pts // num_faces
        heuristic = num_collocation // (cls._BC_HEURISTIC_FACTOR * num_faces)
        return max(cls._MIN_HEURISTIC_PTS, heuristic)

    @classmethod
    def _resolve_ic_pts(cls, config, pde_instance, num_collocation) -> int:
        """Return the initial-condition point count (0 for stationary problems)."""
        if not pde_instance.time_dependent:
            return 0
        if config.num_ic_pts is not None:
            return config.num_ic_pts
        return max(cls._MIN_HEURISTIC_PTS, num_collocation // cls._IC_HEURISTIC_FACTOR)

    @staticmethod
    def _resolve_log_epochs(config) -> list:
        """Return the sorted log epochs within the run, always including 0 and the last epoch."""
        in_range = {e for e in config.log_epochs_list if e <= config.epochs}
        return sorted(in_range | {0, config.epochs})

    @staticmethod
    def _format_final_metric(value) -> str:
        """Format a metric in scientific notation, or 'N/A' when it is not numeric."""
        try:
            return f"{float(value):.3e}"
        except (TypeError, ValueError):
            return "N/A"

    def run_single_experiment(self, config: "ExperimentConfig"):
        """
        Train one model for `config` and write its config, log and summary to disk.

        Returns None in every case; when `config.pde_name` is not in `pde_map`
        an error is printed and nothing is written.
        """
        pde_instance = self.pde_map.get(config.pde_name)
        if pde_instance is None:
            print(f"Error: PDE problem '{config.pde_name}' not found in pde_map.")
            return

        run_dir = self._get_run_dir(config)
        print(f"\n--- Running Experiment: {run_dir} ---")
        print(f"Config: {config}")

        with open(os.path.join(run_dir, "config.json"), 'w', encoding="utf-8") as f:
            json.dump(asdict(config), f, indent=2)

        torch.manual_seed(config.seed)
        np.random.seed(config.seed)

        # One network input per PDE input variable (spatial coordinates and time).
        model = create_pinn_model(
            input_dim=len(pde_instance.input_vars),
            output_dim=pde_instance.output_dim,
            n_hidden_layers=config.depth,
            n_neurons_per_layer=config.width,
            activation_str=config.activation_str,
            device=self.device
        )

        trainer = Trainer(model, pde_instance,
                          optimizer_str=config.optimizer_type,
                          learning_rate=config.lr,
                          device=self.device)

        M_collocation = self._resolve_collocation_pts(config)

        epoch_wise_log_data = trainer.train(
            num_epochs=config.epochs,
            kappa_value=config.kappa_val,
            num_collocation_pts=M_collocation,
            num_bc_pts_per_face=self._resolve_bc_pts_per_face(config, pde_instance, M_collocation),
            num_ic_pts=self._resolve_ic_pts(config, pde_instance, M_collocation),
            collocation_strategy=config.collocation_scheme,
            log_epochs=self._resolve_log_epochs(config),
            num_test_pts_error_grid=config.num_test_pts_error_grid
        )

        df_epoch_log = pd.DataFrame(epoch_wise_log_data)
        df_epoch_log.to_csv(os.path.join(run_dir, "training_log.csv"), index=False)

        # Summarize the last logged epoch; metrics absent from the log become NaN.
        final_metrics = {}
        if not df_epoch_log.empty:
            last_epoch_data = df_epoch_log.iloc[-1]
            final_metrics = {
                key: last_epoch_data.get(key, float('nan'))
                for key in ['time_s', 'loss_total', 'L1_err_rel', 'L2_err_rel',
                            'Linf_err_rel', 'PDE_residual_max', 'grad_norm_l2']
            }

        summary_data = {"final_metrics": final_metrics, "fit_results": {}}  # Config saved separately
        with open(os.path.join(run_dir, "summary.json"), 'w', encoding="utf-8") as f:
            json.dump(summary_data, f, indent=2, cls=NpEncoder)  # Handle numpy types if any

        # The 'N/A' fallback is a string, so it must not go through a float format spec.
        l2_text = self._format_final_metric(final_metrics.get('L2_err_rel', 'N/A'))
        print(f"Finished experiment. Final L2_err_rel: {l2_text}")

# Helper for JSON serialization if numpy types are used in summary
class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in Python types."""

    def default(self, obj):
        """
        Return a JSON-serializable equivalent of `obj`.

        NumPy integers, floats and booleans become int, float and bool;
        arrays become (nested) lists. Anything else is delegated to the base
        class, which raises TypeError.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            # np.bool_ is not a Python bool subclass, so json cannot encode it natively.
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

## Concrete PDEProblem Implementations

In [12]:
class TrivialLinearPDE(PDEProblem):
    """
    One-dimensional test problem  -u_xx = 0  on [0, 1] with u(0) = 0 and u(1) = 1.

    The exact solution is u(x) = x.
    """

    def __init__(self):
        super().__init__(
            name="TrivialLinear",
            input_vars=['x'],
            output_vars=['u'],
            kappa_name="N/A",
            default_kappa_value=1.0,
        )

    def get_domain_bounds(self) -> Dict[str, Tuple[float, float]]:
        """Return the (min, max) interval of each input variable."""
        bounds = {'x': (0.0, 1.0)}
        return bounds

    def pde_residual(self, inputs: Dict[str, torch.Tensor],
                     model_outputs: torch.Tensor,
                     derivatives: Dict[str, torch.Tensor],
                     kappa_value: float) -> torch.Tensor:
        """Return the residual of -u_xx = 0, which is just the second derivative u_xx."""
        u_xx = derivatives['d2(u)_dx(2)']
        return u_xx

    def boundary_conditions(self, inputs_bc: Dict[str, torch.Tensor],
                            model_outputs_bc: torch.Tensor,
                            derivatives_bc: Dict[str, torch.Tensor],
                            model: nn.Module,
                            kappa_value: float) -> torch.Tensor:
        """
        Return the summed mean-squared Dirichlet error for u(0) = 0 and u(1) = 1.

        Boundary points are picked out by exact comparison of their x value
        with the domain bounds; an end with no matching points adds nothing.
        `derivatives_bc`, `model` and `kappa_value` are accepted but not used.
        """
        x_lo, x_hi = self.get_domain_bounds()['x']
        x_flat = inputs_bc['x'].squeeze()
        total = torch.tensor(0.0, device=model_outputs_bc.device)

        # (boundary location, prescribed value of u there)
        for x_face, u_target in ((x_lo, 0.0), (x_hi, 1.0)):
            u_on_face = model_outputs_bc[x_flat == x_face]
            if u_on_face.numel() > 0:
                total += torch.mean((u_on_face - u_target)**2)
        return total

    def get_ground_truth(self, inputs: Dict[str, torch.Tensor],
                         kappa_value: float) -> Optional[torch.Tensor]:
        """Return the exact solution u(x) = x as a copy of the x tensor."""
        return inputs['x'].clone()

    def get_required_derivative_orders(self) -> Dict[str, Dict[Tuple[str, ...], int]]:
        """Request derivatives of u with respect to x up to second order."""
        orders_for_u = {('x',): 2}
        return {'u': orders_for_u}

class PoissonPDE(PDEProblem):
    """
    One-dimensional Poisson problem  -u_xx = sin(pi*x)  on [0, 1] with u(0) = u(1) = 0.

    The exact solution is u(x) = sin(pi*x) / pi^2.
    """

    def __init__(self):
        super().__init__(
            name="Poisson",
            input_vars=['x'],
            output_vars=['u'],
            kappa_name="N/A",
            default_kappa_value=1.0,
        )

        def forcing(x_tensor):
            # Right-hand side f(x) = sin(pi*x), evaluated with torch.
            return torch.sin(torch.pi * x_tensor)

        def exact_solution(x_np):
            # sin(pi*x) / pi^2 on NumPy input; a trailing axis is squeezed
            # away when the array has more than one dimension.
            arg = x_np.squeeze(-1) if x_np.ndim > 1 else x_np
            return (1.0 / (np.pi**2)) * np.sin(np.pi * arg)

        self.forcing_fn_torch = forcing
        self.analytical_sol_np = exact_solution

    def get_domain_bounds(self) -> Dict[str, Tuple[float, float]]:
        """Return the (min, max) interval of each input variable."""
        bounds = {'x': (0.0, 1.0)}
        return bounds

    def pde_residual(self, inputs: Dict[str, torch.Tensor],
                     model_outputs: torch.Tensor,
                     derivatives: Dict[str, torch.Tensor],
                     kappa_value: float) -> torch.Tensor:
        """Return u_xx + sin(pi*x), the residual of -u_xx = sin(pi*x)."""
        x = inputs['x']
        second_derivative = derivatives['d2(u)_dx(2)']
        forcing_values = self.forcing_fn_torch(x)
        return second_derivative + forcing_values

    def boundary_conditions(self, inputs_bc: Dict[str, torch.Tensor],
                            model_outputs_bc: torch.Tensor,
                            derivatives_bc: Dict[str, torch.Tensor],
                            model: nn.Module,
                            kappa_value: float) -> torch.Tensor:
        """
        Return the summed mean-squared Dirichlet error for u(0) = 0 and u(1) = 0.

        Boundary points are picked out by exact comparison of their x value
        with the domain bounds; an end with no matching points adds nothing.
        `derivatives_bc`, `model` and `kappa_value` are accepted but not used.
        """
        x_flat = inputs_bc['x'].squeeze()
        total = torch.tensor(0.0, device=model_outputs_bc.device)

        # Homogeneous Dirichlet data at both ends of the interval.
        for x_face in self.get_domain_bounds()['x']:
            u_on_face = model_outputs_bc[x_flat == x_face]
            if u_on_face.numel() > 0:
                total += torch.mean((u_on_face - 0.0)**2)
        return total

    def get_ground_truth(self, inputs: Dict[str, torch.Tensor],
                         kappa_value: float) -> Optional[torch.Tensor]:
        """
        Evaluate the exact solution at `inputs['x']`.

        The computation goes through NumPy on the CPU; the result is returned
        as a float32 tensor on the device of x, with the same shape as x.
        """
        x_tensor = inputs['x']
        exact_np = self.analytical_sol_np(x_tensor.detach().cpu().numpy())
        exact = torch.tensor(exact_np, dtype=torch.float32, device=x_tensor.device)
        return exact.reshape_as(x_tensor)

    def get_required_derivative_orders(self) -> Dict[str, Dict[Tuple[str, ...], int]]:
        """Request derivatives of u with respect to x up to second order."""
        orders_for_u = {('x',): 2}
        return {'u': orders_for_u}

In [None]:
class BurgersPDE(PDEProblem):
    def __init__(self, periodic_bc_scheme: str = "soft_constraint"):
        """
        Set up the Burgers problem in the variables (t, x) with a single output u.

        Args:
            periodic_bc_scheme: How the periodic boundary condition is imposed,
                "soft_constraint" or "hard_transform".
        """
        # kappa is the inverse viscosity, so the default of 100 means nu = 0.01.
        super().__init__(
            name="Burgers",
            input_vars=['t', 'x'],
            output_vars=['u'],
            time_var='t',
            kappa_name="1/nu",
            default_kappa_value=100.0,
        )
        # Placeholder for a reference solver; nothing is attached here.
        self.ground_truth_solver = None
        self.periodic_bc_scheme = periodic_bc_scheme

    def get_domain_bounds(self) -> Dict[str, Tuple[float, float]]:
        """Return the (min, max) interval of each input variable: t in [0, 1], x in [-1, 1]."""
        time_interval = (0.0, 1.0)
        space_interval = (-1.0, 1.0)
        return {'t': time_interval, 'x': space_interval}

    def get_required_derivative_orders(self) -> Dict[str, Dict[Tuple[str, ...], int]]:
        """Request u_t (first order in t) and derivatives of u in x up to second order."""
        orders_for_u = {
            ('t',): 1,  # u_t
            ('x',): 2,  # u_x and u_xx
        }
        return {'u': orders_for_u}

    def pde_residual(self, inputs: Dict[str, torch.Tensor],
                     model_outputs: torch.Tensor,
                     derivatives: Dict[str, torch.Tensor],
                     kappa_value: float) -> torch.Tensor:
        """
        Return the residual of the viscous Burgers equation.

        The equation is u_t + u * u_x - nu * u_xx = 0 with nu = 1 / kappa_value.
        u is taken from the first column of `model_outputs`; the derivatives
        are read from the keys 'd1(u)_dt(1)', 'd1(u)_dx(1)' and 'd2(u)_dx(2)'.
        `inputs` is not used.
        """
        viscosity = 1.0 / kappa_value
        velocity = model_outputs[:, 0:1]  # keep the column axis

        time_derivative = derivatives['d1(u)_dt(1)']
        space_derivative = derivatives['d1(u)_dx(1)']
        second_space_derivative = derivatives['d2(u)_dx(2)']

        advection = velocity * space_derivative
        diffusion = viscosity * second_space_derivative
        return time_derivative + advection - diffusion

    def initial_conditions(self, inputs_ic: Dict[str, torch.Tensor],
                           model_outputs_ic: torch.Tensor,
                           derivatives_ic: Dict[str, torch.Tensor],
                           model: nn.Module,
                           kappa_value: float) -> torch.Tensor:
        """
        Return the mean-squared error of the initial condition u(0, x) = -sin(pi*x).

        The parameter order (derivatives before model) matches
        `boundary_conditions` and the positional call made by the trainer's
        L-BFGS closure. `derivatives_ic`, `model` and `kappa_value` are
        accepted for interface compatibility but are not used.

        Args:
            inputs_ic: Input tensors of the initial points; only 'x' is read.
            model_outputs_ic: Network predictions at those points.

        Returns:
            Scalar tensor with the mean of (prediction - target)^2.
        """
        x_ic = inputs_ic['x']
        u_true_ic = -torch.sin(torch.pi * x_ic)

        loss_ic = torch.mean((model_outputs_ic - u_true_ic)**2)
        return loss_ic

    def boundary_conditions(self, inputs_bc: Dict[str, torch.Tensor],
                        model_outputs_bc: torch.Tensor,
                        derivatives_bc: Dict[str, torch.Tensor],
                        model: nn.Module,
                        kappa_value: float) -> torch.Tensor:
        """
        Return the soft periodic boundary loss in x.

        For the "soft_constraint" scheme the loss is
        mean((u(t, x_min) - u(t, x_max))^2) + mean((u_x(t, x_min) - u_x(t, x_max))^2),
        evaluated at freshly sampled random times. Sampling the times here
        avoids having to pair up the points supplied for the two faces; from
        `inputs_bc` only the number of points lying on x = x_min is used, as
        the number of times to sample.

        NOTE(review): the times are re-drawn on every call, so the loss is
        stochastic even for fixed `inputs_bc`; confirm this is acceptable for
        optimizers that re-evaluate a closure within one step.

        Args:
            inputs_bc: Boundary point tensors; only 'x' is read.
            model_outputs_bc: Predictions at the boundary points; only its
                device and dtype are used.
            derivatives_bc: Not used; u_x is computed here with autograd.
            model: Network evaluated at the sampled periodic points.
            kappa_value: Not used.

        Returns:
            Scalar loss tensor. Zero when no point lies on x = x_min or when
            the scheme is not recognized.

        Raises:
            NotImplementedError: For the "hard_transform" scheme.
        """
        device = model_outputs_bc.device

        if self.periodic_bc_scheme == "hard_transform":
            raise NotImplementedError("Hard periodic BC transform not implemented here.")
        if self.periodic_bc_scheme != "soft_constraint":
            return torch.tensor(0.0, device=device)  # Default if no scheme matches

        domain_b = self.get_domain_bounds()
        t_min_b, t_max_b = domain_b[self.time_var]
        x_min_b, x_max_b = domain_b['x']

        num_periodic_pts = int((inputs_bc['x'] == x_min_b).sum().item())
        if num_periodic_pts == 0:
            return torch.tensor(0.0, device=device)

        # Sample in the model's dtype so non-float32 networks can consume the points.
        t_periodic = torch.rand(num_periodic_pts, 1, device=device,
                                dtype=model_outputs_bc.dtype) * (t_max_b - t_min_b) + t_min_b

        def u_and_ux(x_value):
            # Evaluate u and du/dx on the face x = x_value at the shared times.
            x_face = torch.full_like(t_periodic, x_value).requires_grad_(True)
            face_inputs = {self.time_var: t_periodic, 'x': x_face}
            u_face = model(torch.cat([face_inputs[var] for var in self.input_vars], dim=1))
            # create_graph=True keeps u_x differentiable w.r.t. the model parameters.
            du_dx = torch.autograd.grad(u_face, x_face,
                                        grad_outputs=torch.ones_like(u_face),
                                        create_graph=True)[0]
            return u_face, du_dx

        u_at_xmin, du_dx_at_xmin = u_and_ux(x_min_b)
        u_at_xmax, du_dx_at_xmax = u_and_ux(x_max_b)

        loss_val_periodic = torch.mean((u_at_xmin - u_at_xmax)**2)
        loss_deriv_periodic = torch.mean((du_dx_at_xmin - du_dx_at_xmax)**2)
        return loss_val_periodic + loss_deriv_periodic  # Add relative weighting if desired

    def get_required_derivative_orders_for_bc(self) -> Optional[Dict[str, Dict[Tuple[str, ...], int]]]:
        """Report which derivatives the boundary-condition loss asks for.

        Returns:
            ``{'u': {('x',): 1}}`` (first x-derivative of ``u``) when
            ``periodic_bc_scheme`` is ``"soft_constraint"``; ``None`` for any
            other scheme.
        """
        uses_soft_periodic = self.periodic_bc_scheme == "soft_constraint"
        return {'u': {('x',): 1}} if uses_soft_periodic else None

    @cache
    def _compute_theta_0_coeffs(self, nu: float, N_terms_fourier: int,
                             N_points_integrate: int, device: torch.device) -> torch.Tensor:
        """
        Computes Fourier coefficients c_n(0) for theta_0(x) for n = 0, 1, ..., N_terms_fourier.
        theta_0(x) = exp( (1 - cos(pi*x)) / (2*nu*pi) )
        c_n(0) = (1/L) * integral_{-L/2}^{L/2} theta_0(x) * exp(-i*2*pi*n*x/L) dx, with L=2.
            = 0.5 * integral_{-1}^{1} theta_0(x) * exp(-i*n*pi*x) dx
        Since theta_0(x) is real and even, c_n are real and c_n = c_{-n}.
        So, c_n(0) = 0.5 * integral_{-1}^{1} theta_0(x) * cos(n*pi*x) dx.
        We return an array [c_0(0), c_1(0), ..., c_N_terms_fourier(0)].
        """
        xi = torch.linspace(-1.0, 1.0, N_points_integrate, dtype=torch.float32, device=device)

        # theta_0(x) = exp( (1 - cos(pi*x)) / (2*nu*pi) )  <-- Corrected sign from original derivation
        factor = 1.0 / (2.0 * nu * torch.pi)
        theta_0_on_xi = torch.exp(factor * (1.0 - torch.cos(torch.pi * xi)))

        c_n_0_array = torch.zeros(N_terms_fourier + 1, dtype=torch.float32, device=device)

        for n_val in range(N_terms_fourier + 1): # n = 0, 1, ..., N_terms_fourier
            integrand = theta_0_on_xi * torch.cos(n_val * torch.pi * xi)
            c_n_0_array[n_val] = 0.5 * torch.trapezoid(y=integrand, x=xi) # type: ignore

        return c_n_0_array

    def get_ground_truth(self, inputs: Dict[str, torch.Tensor],
                       kappa_value: float) -> Optional[torch.Tensor]:
        """Evaluate the truncated Cole-Hopf series u = -2*nu*theta_x/theta, nu = 1/kappa_value.

        theta(x,t)   = c_0 + sum_{n=1..N} 2*c_n*cos(n*pi*x)*exp(-nu*(n*pi)^2*t)
        theta_x(x,t) =       sum_{n=1..N} 2*c_n*(-n*pi*sin(n*pi*x))*exp(-nu*(n*pi)^2*t)

        The coefficients c_n come from ``self._compute_theta_0_coeffs``. The
        truncation order N and the quadrature resolution are read from the
        optional attributes ``N_terms_fourier_gt`` (default 75) and
        ``N_points_integrate_coeffs_gt`` (default 4096).

        Args:
            inputs: Must contain 'x' and 't' tensors with the same number of elements.
            kappa_value: Inverse viscosity; 0 has no viscous solution.

        Returns:
            Tensor shaped like ``inputs['x']``; all-NaN (with a printed warning)
            when ``kappa_value`` is 0.
        """
        x_eval = inputs['x']
        t_eval = inputs['t']
        device = x_eval.device

        # Guard: nu = 1/kappa is undefined for kappa == 0 (inviscid case forms a shock).
        if kappa_value == 0:
            print("Warning: kappa_value is 0, results for Burgers' GT might be ill-defined with this formula.")
            return torch.full_like(x_eval, float('nan'))

        nu = 1.0 / kappa_value
        num_modes = getattr(self, 'N_terms_fourier_gt', 75)
        num_quad_pts = getattr(self, 'N_points_integrate_coeffs_gt', 4096)

        coeff_row = self._compute_theta_0_coeffs(nu, num_modes, num_quad_pts, device=device).view(1, -1)
        mode_row = torch.arange(0, num_modes + 1, device=device, dtype=torch.float32).view(1, -1)

        # Column vectors so that (batch, 1) broadcasts against (1, num_modes).
        x_col = x_eval.view(-1, 1)
        t_col = t_eval.view(-1, 1)

        harmonics = mode_row[:, 1:]            # n = 1..N
        doubled_coeffs = 2.0 * coeff_row[:, 1:]  # the factor 2 folds in the n < 0 terms
        phase = harmonics * torch.pi * x_col
        damping = torch.exp(-nu * (harmonics * torch.pi)**2 * t_col)

        # n = 0 contributes c_0 to theta and nothing to theta_x.
        theta = coeff_row[:, 0:1] * torch.ones_like(x_col)
        theta += torch.sum(doubled_coeffs * torch.cos(phase) * damping, dim=1, keepdim=True)

        theta_x = torch.zeros_like(x_col)
        theta_x += torch.sum(
            doubled_coeffs * (-harmonics * torch.pi * torch.sin(phase)) * damping,
            dim=1, keepdim=True)

        tiny = 1e-12  # keeps the division defined if theta is numerically zero
        u_exact = -2.0 * nu * (theta_x / (theta + tiny))
        return u_exact.view_as(x_eval)

class KdVPDE(PDEProblem):
    """Korteweg-de Vries problem ``u_t + A*u*u_x + u_xxx = 0`` with periodic x-boundaries.

    ``kappa_value`` is the nonlinearity coefficient ``A`` of the PDE. The
    initial condition and the reference solution are the travelling 1-soliton
    of that equation with speed ``c = A`` and centre ``x0 = 0``:

        u(t, x) = (3*c/A) * sech^2( (sqrt(c)/2) * (x - c*t - x0) )

    Why this form: for ``u_t + alpha*u*u_x + u_xxx = 0`` a profile
    ``a*sech^2(k*(x - c*t))`` is a solution only when ``c = 4*k^2`` and
    ``a = 12*k^2/alpha``. The profile ``A*sech^2(sqrt(A/2)*(x - A*t))`` breaks
    both relations (it would need ``c = 2A`` and ``a = 6``), so it must not be
    used as initial data or ground truth for this residual.
    """

    def __init__(self, periodic_bc_scheme: str = "soft_constraint"):
        """Create the problem.

        Args:
            periodic_bc_scheme: "soft_constraint" penalises the periodic
                mismatch in the loss; "hard_transform" is not implemented; any
                other value disables the boundary loss.
        """
        super().__init__(name="KdV",
                         input_vars=['t', 'x'],
                         output_vars=['u'],
                         time_var='t',
                         kappa_name="A",
                         default_kappa_value=1.0)
        self.periodic_bc_scheme = periodic_bc_scheme

    def get_domain_bounds(self) -> Dict[str, Tuple[float, float]]:
        """Return the (min, max) bounds per input variable."""
        # The soliton moves with speed c = A, so the time horizon scales with
        # 1/A; x is wide enough for the sech^2 tails to be negligible at the ends.
        return {'t': (0.0, 4.0 / self.default_kappa_value),
                'x': (-20.0, 20.0)}

    def get_required_derivative_orders(self) -> Dict[str, Dict[Tuple[str, ...], int]]:
        """Highest derivative orders needed by the PDE residual."""
        return {
            'u': {
                ('t',): 1,  # u_t
                ('x',): 3   # u_x (inside u*u_x) and u_xxx
            }
        }

    def get_required_derivative_orders_for_bc(self) -> Optional[Dict[str, Dict[Tuple[str, ...], int]]]:
        """Derivative orders needed by the boundary loss (u_x for the soft periodic scheme)."""
        if self.periodic_bc_scheme == "soft_constraint":
            return {'u': {('x',): 1}}
        return None

    def pde_residual(self, inputs: Dict[str, torch.Tensor],
                     model_outputs: torch.Tensor,
                     derivatives: Dict[str, torch.Tensor],
                     kappa_value: float) -> torch.Tensor:
        """Residual of ``u_t + A*u*u_x + u_xxx = 0`` with ``A = kappa_value``.

        Args:
            inputs: Collocation inputs (unused here).
            model_outputs: Network output; column 0 is read as u.
            derivatives: Must provide 'd1(u)_dt(1)', 'd1(u)_dx(1)' and 'd3(u)_dx(3)'.
            kappa_value: Nonlinearity coefficient A.

        Returns:
            Residual tensor with the same shape as ``model_outputs[:, 0:1]``.
        """
        A = kappa_value
        u = model_outputs[:, 0:1]

        u_t = derivatives['d1(u)_dt(1)']
        u_x = derivatives['d1(u)_dx(1)']
        u_xxx = derivatives['d3(u)_dx(3)']

        return u_t + A * u * u_x + u_xxx

    def _sech(self, x_tensor: torch.Tensor) -> torch.Tensor:
        """Hyperbolic secant; tends to 0 when cosh overflows."""
        return 1.0 / torch.cosh(x_tensor)

    def _soliton(self, x: torch.Tensor, t: Optional[torch.Tensor],
                 kappa_value: float, x0: float = 0.0) -> torch.Tensor:
        """Exact 1-soliton of ``u_t + A*u*u_x + u_xxx = 0`` travelling at speed c = A.

        Args:
            x: Spatial coordinates.
            t: Times matching ``x``; ``None`` evaluates the profile at t = 0.
            kappa_value: Nonlinearity coefficient A; non-positive values are
                replaced by 1e-6 so the square root stays real.
            x0: Soliton centre at t = 0.
        """
        A = kappa_value
        if A <= 0:
            A = 1e-6
        speed = A
        amplitude = 3.0 * speed / A          # a = 3c/A
        inv_width = 0.5 * float(speed) ** 0.5  # k = sqrt(c)/2, so that c = 4k^2
        travelled = x - x0 if t is None else x - speed * t - x0
        return amplitude * self._sech(inv_width * travelled) ** 2

    # NOTE(review): `model` precedes `derivatives_ic` here, whereas the
    # SineGordon and Allen-Cahn problems in this file take `derivatives_ic`
    # first. Neither argument is read below, so the order is kept unchanged;
    # confirm against the caller before aligning it.
    def initial_conditions(self, inputs_ic: Dict[str, torch.Tensor],
                           model_outputs_ic: torch.Tensor,
                           model: nn.Module,
                           derivatives_ic: Dict[str, torch.Tensor],
                           kappa_value: float) -> torch.Tensor:
        """Mean squared mismatch between the network and the soliton profile at t = 0.

        Args:
            inputs_ic: Must contain 'x'.
            model_outputs_ic: Network output on the initial-condition points.
            model: Unused.
            derivatives_ic: Unused.
            kappa_value: Nonlinearity coefficient A.
        """
        u_true_ic = self._soliton(inputs_ic['x'], None, kappa_value)
        return torch.mean((model_outputs_ic - u_true_ic)**2)

    def _periodic_face_values(self, model: nn.Module, t_col: torch.Tensor,
                              x_col: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate the model and its x-derivative on one x-face of the domain."""
        columns = {self.time_var: t_col, 'x': x_col}
        # Columns are concatenated in self.input_vars order, the order the model expects.
        u_face = model(torch.cat([columns[var] for var in self.input_vars], dim=1))
        # autograd.grad raises if x_col does not influence u_face, so the
        # returned gradient is always a tensor here.
        du_dx_face = torch.autograd.grad(
            u_face, x_col, grad_outputs=torch.ones_like(u_face), create_graph=True)[0]
        return u_face, du_dx_face

    def boundary_conditions(self, inputs_bc: Dict[str, torch.Tensor],
                        model_outputs_bc: torch.Tensor,
                        derivatives_bc: Dict[str, torch.Tensor],
                        model: nn.Module,
                        kappa_value: float) -> torch.Tensor:
        """Soft periodic boundary loss: u and u_x must agree at x_min and x_max.

        Time samples are drawn afresh inside this method so that both faces are
        evaluated at identical times; the supplied boundary points are only
        used to decide how many samples to draw (the number lying on x_min).

        Args:
            inputs_bc: Must contain the time variable and 'x'.
            model_outputs_bc: Used only for its device.
            derivatives_bc: Unused; derivatives are recomputed on the paired points.
            model: Network evaluated on the paired boundary points.
            kappa_value: Unused.

        Returns:
            Scalar loss ``mean((u_min - u_max)^2) + mean((u_x_min - u_x_max)^2)``,
            or 0 when no point lies on x_min or the scheme is unrecognised.

        Raises:
            NotImplementedError: if ``periodic_bc_scheme`` is "hard_transform".
        """
        device = model_outputs_bc.device

        if self.periodic_bc_scheme == "hard_transform":
            raise NotImplementedError("Hard periodic BC transform not implemented here.")
        if self.periodic_bc_scheme != "soft_constraint":
            return torch.tensor(0.0, device=device)

        domain_b = self.get_domain_bounds()
        t_min_b, t_max_b = domain_b[self.time_var]
        x_min_b, x_max_b = domain_b['x']

        num_periodic_pts = inputs_bc[self.time_var][inputs_bc['x'] == x_min_b].shape[0]
        if num_periodic_pts == 0:
            return torch.tensor(0.0, device=device)

        t_periodic = torch.rand(num_periodic_pts, 1, device=device) * (t_max_b - t_min_b) + t_min_b
        t_periodic.requires_grad_(True)  # allows t-derivatives at the boundary if ever needed

        x_at_min = torch.full_like(t_periodic, x_min_b).requires_grad_(True)
        x_at_max = torch.full_like(t_periodic, x_max_b).requires_grad_(True)

        u_at_xmin, du_dx_at_xmin = self._periodic_face_values(model, t_periodic, x_at_min)
        u_at_xmax, du_dx_at_xmax = self._periodic_face_values(model, t_periodic, x_at_max)

        loss_val_periodic = torch.mean((u_at_xmin - u_at_xmax)**2)
        loss_deriv_periodic = torch.mean((du_dx_at_xmin - du_dx_at_xmax)**2)
        return loss_val_periodic + loss_deriv_periodic  # equal weighting of both terms

    def get_ground_truth(self, inputs: Dict[str, torch.Tensor],
                         kappa_value: float) -> Optional[torch.Tensor]:
        """Exact 1-soliton evaluated at ``inputs['t']``, ``inputs['x']``.

        Args:
            inputs: Must contain 't' and 'x'.
            kappa_value: Nonlinearity coefficient A.
        """
        return self._soliton(inputs['x'], inputs['t'], kappa_value)

class SineGordonPDE(PDEProblem):
    """Sine-Gordon problem ``u_tt - u_xx + beta*sin(u) = 0`` with a kink reference solution.

    ``kappa_value`` is ``beta``. The reference solution is the kink
    ``4*arctan(exp(gamma*(x - c*t - x0)))`` with ``gamma = sqrt(beta/(1 - c^2))``
    and ``x0 = 0``; ``c`` is 0 for a stationary kink and 0.5 otherwise.
    """

    def __init__(self, stationary_kink=True):
        """Create the problem.

        Args:
            stationary_kink: If true the kink speed is 0, otherwise 0.5.
        """
        super().__init__(name="SineGordon",
                         input_vars=['t', 'x'],
                         output_vars=['u'],
                         time_var='t',
                         kappa_name="beta", # Coefficient of sin(u)
                         default_kappa_value=1.0)
        self.stationary_kink = stationary_kink

    def get_domain_bounds(self) -> Dict[str, Tuple[float, float]]:
        """Return the (min, max) bounds per input variable."""
        return {'t': (0.0, 10.0), 'x': (-20.0, 20.0)}

    def get_required_derivative_orders(self) -> Dict[str, Dict[Tuple[str, ...], int]]:
        """Highest derivative orders needed by the PDE residual."""
        return {
            'u': {
                ('t',): 2,  # u_tt
                ('x',): 2   # u_xx
            }
        }

    def get_required_derivative_orders_for_ic(self) -> Optional[Dict[str, Dict[Tuple[str, ...], int]]]:
        """Derivative orders needed by the initial-condition loss.

        The network's u_t at t = 0 is compared against the analytical u_t(0, x).
        """
        return {'u': {('t',): 1}}

    def pde_residual(self, inputs: Dict[str, torch.Tensor],
                       model_outputs: torch.Tensor,
                       derivatives: Dict[str, torch.Tensor],
                       kappa_value: float) -> torch.Tensor:
        """Residual of ``u_tt - u_xx + beta*sin(u) = 0`` with ``beta = kappa_value``.

        Args:
            inputs: Collocation inputs (unused here).
            model_outputs: Network output; column 0 is read as u.
            derivatives: Must provide 'd2(u)_dt(2)' and 'd2(u)_dx(2)'.
            kappa_value: Coefficient beta of sin(u).
        """
        u = model_outputs[:, 0:1]
        u_tt = derivatives['d2(u)_dt(2)']
        u_xx = derivatives['d2(u)_dx(2)']
        beta = kappa_value

        return u_tt - u_xx + beta * torch.sin(u)

    def _kink_speed(self) -> float:
        """Kink speed c used by the IC, BC and ground truth."""
        return 0.0 if self.stationary_kink else 0.5

    def _kink_gamma(self, like: torch.Tensor, kappa_value: float, c: float) -> torch.Tensor:
        """Return gamma = sqrt(beta / (1 - c^2)) as a 0-d tensor matching ``like``.

        Non-positive beta is replaced by 1e-6.

        Raises:
            ValueError: if ``|c| >= 1``.
        """
        beta = kappa_value
        if beta <= 0:
            beta = 1e-6
        if abs(c) >= 1.0:
            raise ValueError("Kink speed |c| must be < 1 for this solution form.")
        # The 1e-9 keeps the denominator positive for |c| very close to 1.
        sqrt_arg = torch.tensor(beta / (1.0 - c**2 + 1e-9), dtype=like.dtype, device=like.device)
        return torch.sqrt(sqrt_arg)

    def _kink_solution_val(self, x: torch.Tensor, t: Optional[torch.Tensor], kappa_value: float, c: float = 0.0) -> torch.Tensor:
        """Kink profile ``4*arctan(exp(gamma*(x - c*t - x0)))`` with x0 = 0.

        Args:
            x: Spatial coordinates.
            t: Times matching ``x``; ``None`` drops the ``c*t`` shift.
            kappa_value: Coefficient beta.
            c: Kink speed, ``|c| < 1``.
        """
        x0 = 0.0
        gamma = self._kink_gamma(x, kappa_value, c)

        arg_exp = x - x0
        if t is not None:
            arg_exp = arg_exp - c * t

        # exp may overflow to inf for large arguments; arctan(inf) = pi/2 is still correct.
        return 4 * torch.arctan(torch.exp(gamma * arg_exp))

    def _kink_solution_dt_val(self, x: torch.Tensor, t: torch.Tensor, kappa_value: float, c: float) -> torch.Tensor:
        """Analytical time derivative of the kink: ``u_t = -2*c*gamma*sech(Y)``.

        With ``Y = gamma*(x - c*t - x0)`` and ``u = 4*arctan(exp(Y))``:
        ``du/dY = 4*exp(Y)/(1 + exp(2Y)) = 2*sech(Y)`` and ``dY/dt = -c*gamma``.

        Args:
            x: Spatial coordinates.
            t: Times matching ``x``.
            kappa_value: Coefficient beta.
            c: Kink speed; 0 returns zeros without further checks.
        """
        if c == 0.0:
            return torch.zeros_like(x)

        x0 = 0.0
        gamma = self._kink_gamma(x, kappa_value, c)
        arg_exp = gamma * (x - c * t - x0)

        # sech(Y) is evaluated as 2/(exp(Y) + exp(-Y)). The algebraically equal
        # form 4*exp(Y)/(1 + exp(2Y)) turns into inf/inf = NaN once exp(Y)
        # overflows, whereas this one decays to 0.
        two_sech = 4.0 / (torch.exp(arg_exp) + torch.exp(-arg_exp))
        return two_sech * (-c * gamma)

    def initial_conditions(self, inputs_ic: Dict[str, torch.Tensor],
                               model_outputs_ic: torch.Tensor,
                               derivatives_ic: Dict[str, torch.Tensor],
                               model: nn.Module,
                               kappa_value: float) -> torch.Tensor:
        """Initial-condition loss on u(0, x) and, when available, u_t(0, x).

        Args:
            inputs_ic: Must contain 't' and 'x'.
            model_outputs_ic: Network output on the initial-condition points.
            derivatives_ic: May contain 'd1(u)_dt(1)'; if absent the u_t term is
                0 and a warning is printed when the kink is moving.
            model: Unused.
            kappa_value: Coefficient beta.

        Returns:
            ``mean((u - u_true)^2) + mean((u_t - u_t_true)^2)`` (second term optional).
        """
        t_ic = inputs_ic['t']
        x_ic = inputs_ic['x']
        c_kink = self._kink_speed()

        u_true_at_ic = self._kink_solution_val(x_ic, t_ic, kappa_value, c=c_kink)
        loss_u = torch.mean((model_outputs_ic - u_true_at_ic)**2)

        u_t_true_at_ic = self._kink_solution_dt_val(x_ic, t_ic, kappa_value, c=c_kink)

        loss_ut = torch.tensor(0.0, device=loss_u.device)
        if 'd1(u)_dt(1)' in derivatives_ic:
            pinn_ut_at_ic = derivatives_ic['d1(u)_dt(1)']
            loss_ut = torch.mean((pinn_ut_at_ic - u_t_true_at_ic)**2)
        elif c_kink != 0.0:
            # A moving kink has non-zero u_t at t = 0, which cannot be enforced without the derivative.
            print("Warning (SineGordon IC): Non-zero u_t IC but d1(u)_dt(1) not in derivatives_ic.")

        return loss_u + loss_ut

    def boundary_conditions(self, inputs_bc: Dict[str, torch.Tensor],
                                model_outputs_bc: torch.Tensor,
                                derivatives_bc: Dict[str, torch.Tensor],
                                model: nn.Module,
                                kappa_value: float) -> torch.Tensor:
        """Dirichlet loss pinning u to the analytical kink on the supplied boundary points.

        Args:
            inputs_bc: Must contain 't' and 'x'.
            model_outputs_bc: Network output, row-aligned with ``inputs_bc``.
            derivatives_bc: Unused.
            model: Unused.
            kappa_value: Coefficient beta.
        """
        u_true_at_bc = self._kink_solution_val(
            inputs_bc['x'], inputs_bc['t'], kappa_value, c=self._kink_speed())
        return torch.mean((model_outputs_bc - u_true_at_bc)**2)

    def get_ground_truth(self, inputs: Dict[str, torch.Tensor],
                           kappa_value: float) -> Optional[torch.Tensor]:
        """Analytical kink evaluated at ``inputs['t']``, ``inputs['x']``."""
        return self._kink_solution_val(
            inputs['x'], inputs['t'], kappa_value, c=self._kink_speed())

class AllenCahnPDE(PDEProblem):
    """Allen-Cahn problem ``u_t = D*u_xx - u*(u^2 - 1)`` with a stationary tanh front.

    ``kappa_value`` is ``1/D`` (named "inv_D"); a small D gives a sharp
    interface. The reference solution is ``u(t, x) = tanh(x / sqrt(2*D))``.
    """

    def __init__(self):
        """Register the problem with default inv_D = 10 (D = 0.1)."""
        super().__init__(name="AllenCahn",
                         input_vars=['t', 'x'],
                         output_vars=['u'],
                         time_var='t',
                         kappa_name="inv_D",
                         default_kappa_value=10.0)

    def get_domain_bounds(self) -> Dict[str, Tuple[float, float]]:
        """Return the (min, max) bounds per input variable."""
        return {'t': (0.0, 5.0), 'x': (-5.0, 5.0)}

    def get_required_derivative_orders(self) -> Dict[str, Dict[Tuple[str, ...], int]]:
        """Highest derivative orders needed by the PDE residual."""
        time_orders = {('t',): 1}   # u_t
        space_orders = {('x',): 2}  # u_xx
        return {'u': {**time_orders, **space_orders}}

    def _diffusion_from_kappa(self, kappa_value: float) -> float:
        """Convert inv_D to D, flooring inv_D at 1e-6 to avoid dividing by zero."""
        inv_D = kappa_value
        if inv_D <= 1e-6:
            inv_D = 1e-6
        return 1.0 / inv_D

    def pde_residual(self, inputs: Dict[str, torch.Tensor],
                       model_outputs: torch.Tensor,
                       derivatives: Dict[str, torch.Tensor],
                       kappa_value: float) -> torch.Tensor:
        """Residual ``u_t - D*u_xx + u*(u^2 - 1)`` with ``D = 1/kappa_value``.

        The reaction term is F'(u) for the double-well potential
        F(u) = (u^2 - 1)^2 / 4.

        Args:
            inputs: Collocation inputs (unused here).
            model_outputs: Network output; column 0 is read as u.
            derivatives: Must provide 'd1(u)_dt(1)' and 'd2(u)_dx(2)'.
            kappa_value: inv_D.
        """
        u = model_outputs[:, 0:1]
        du_dt = derivatives['d1(u)_dt(1)']
        d2u_dx2 = derivatives['d2(u)_dx(2)']
        diffusion = self._diffusion_from_kappa(kappa_value)

        double_well_force = u * (u**2 - 1.0)
        return du_dt - diffusion * d2u_dx2 + double_well_force

    def _stationary_front_val(self, x: torch.Tensor, D_coeff: float) -> torch.Tensor:
        """Stationary front ``tanh(x / sqrt(2*D))``; D is floored at 1e-9."""
        if D_coeff <= 1e-9:
            D_coeff = 1e-9

        # Build 2*D as a tensor matching x so the sqrt runs in x's dtype/device.
        two_D = torch.tensor(2.0 * D_coeff, dtype=x.dtype, device=x.device)
        interface_width = torch.sqrt(two_D)
        return torch.tanh(x / interface_width)

    def initial_conditions(self, inputs_ic: Dict[str, torch.Tensor],
                               model_outputs_ic: torch.Tensor,
                               derivatives_ic: Dict[str, torch.Tensor],
                               model: nn.Module,
                               kappa_value: float) -> torch.Tensor:
        """Mean squared mismatch between the network and the tanh front at t = 0.

        Args:
            inputs_ic: Must contain 'x'.
            model_outputs_ic: Network output on the initial-condition points.
            derivatives_ic: Unused.
            model: Unused.
            kappa_value: inv_D.
        """
        diffusion = self._diffusion_from_kappa(kappa_value)
        front_at_t0 = self._stationary_front_val(inputs_ic['x'], diffusion)
        return torch.mean((model_outputs_ic - front_at_t0)**2)

    def boundary_conditions(self, inputs_bc: Dict[str, torch.Tensor],
                                model_outputs_bc: torch.Tensor,
                                derivatives_bc: Dict[str, torch.Tensor],
                                model: nn.Module,
                                kappa_value: float) -> torch.Tensor:
        """Dirichlet loss: u = -1 on x_min and u = +1 on x_max.

        The targets are the asymptotic values of the front.
        NOTE(review): for a wide interface (small inv_D) the tanh profile used
        by the IC and ground truth has not reached +/-1 at the domain ends, so
        these targets then disagree with it - confirm this is intended.

        Args:
            inputs_bc: Must contain 'x'; points are assigned to a face by exact
                equality with the domain bound.
            model_outputs_bc: Network output, row-aligned with ``inputs_bc``.
            derivatives_bc: Unused.
            model: Unused.
            kappa_value: Unused.

        Returns:
            Sum over both faces of the per-face mean squared error (a face with
            no points contributes 0).
        """
        x_bc = inputs_bc['x']
        x_lo, x_hi = self.get_domain_bounds()['x']

        total = torch.tensor(0.0, device=model_outputs_bc.device)
        for face_x, face_target in ((x_lo, -1.0), (x_hi, 1.0)):
            on_face = (x_bc == face_x)
            if not torch.any(on_face):
                continue
            u_pred = model_outputs_bc[on_face]
            u_goal = torch.full_like(u_pred, face_target)
            total += torch.mean((u_pred - u_goal)**2)

        return total

    def get_ground_truth(self, inputs: Dict[str, torch.Tensor],
                           kappa_value: float) -> Optional[torch.Tensor]:
        """Stationary front ``tanh(x / sqrt(2*D))``; independent of t.

        Args:
            inputs: Must contain 'x' ('t' is not read).
            kappa_value: inv_D.
        """
        diffusion = self._diffusion_from_kappa(kappa_value)
        return self._stationary_front_val(inputs['x'], diffusion)


## Experiment Runner

In [14]:
# Driver cell: builds one ExperimentConfig per (PDE, kappa, width, activation,
# seed) combination selected below and runs them one after another.
if __name__ == '__main__':
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    # Every problem is instantiated up front, whether or not it is selected in
    # test_pdes_to_run below.
    pde_instances_map = {
        "TrivialLinear": TrivialLinearPDE(),
        "Poisson": PoissonPDE(),
        "Burgers": BurgersPDE(), # NOTE(review): an earlier remark called GT/BCs placeholders - confirm status
        "KdV": KdVPDE(),         # default periodic_bc_scheme="soft_constraint"
        "SineGordon": SineGordonPDE(stationary_kink=True), # Stationary kink (c = 0) first
        # "SineGordonMoving": SineGordonPDE(stationary_kink=False), # For later
        "AllenCahn": AllenCahnPDE(),
    }

    runner = ExperimentRunner(base_results_dir="experiment_data_test/",
                              pde_map=pde_instances_map,
                              device=DEVICE)

    configs_to_run = []

    # --- Test Configs ---
    # Full list, for reference; only a subset is enabled while testing new PDEs:
    # test_pdes_to_run = ["TrivialLinear", "Poisson", "Burgers", "KdV", "SineGordon", "AllenCahn"] # Focus on one new PDE at a time for initial testing
    test_pdes_to_run = ["Burgers", "KdV"]
    test_widths = [30]
    test_activations = ["tanh"]
    test_seeds = [42]
    test_epochs = 1000 # Short run for testing

    # Kappa values to sweep per PDE; the meaning of kappa is problem-specific.
    KAPPA_VALS_MAP_TEST = {
        "TrivialLinear": [1.0], # Kappa not used
        "Poisson": [1.0],       # Kappa not used
        "Burgers": [10.0],      # nu = 1/kappa = 0.1 (relatively easy)
        "KdV": [0.5],           # A = 0.5
        "SineGordon": [0.25],   # beta = 0.25 (weaker nonlinearity)
        "AllenCahn": [2.0]      # inv_D = 2.0  => D = 0.5 (less sharp interface)
    }

    for pde_name_to_test in test_pdes_to_run:
        if pde_name_to_test not in pde_instances_map:
            print(f"Skipping {pde_name_to_test}, not in pde_instances_map for testing.")
            continue

        # Fall back to the problem's own default kappa when none is listed above.
        kappas_for_this_pde = KAPPA_VALS_MAP_TEST.get(pde_name_to_test, [pde_instances_map[pde_name_to_test].default_kappa_value])

        for kappa_v_test in kappas_for_this_pde:
            for width_v_test in test_widths:
                for act_v_test in test_activations:
                    for seed_v_test in test_seeds:
                        # Log at start, midpoint and end; the set removes
                        # duplicates when test_epochs is very small.
                        log_epochs_test = sorted(list(set([0, test_epochs // 2, test_epochs])))

                        configs_to_run.append(ExperimentConfig(
                            pde_name=pde_name_to_test,
                            kappa_val=kappa_v_test,
                            activation_str=act_v_test,
                            seed=seed_v_test,
                            depth=1, # SLN (presumably: single hidden layer)
                            width=width_v_test,
                            M_collocation_factor=10, # presumably M_collocation = width * 10 - defined by ExperimentConfig
                            # M_collocation_pts = 200, # Or fixed number
                            num_bc_pts_per_face=50, # Reasonable fixed number for 1D spatial
                            num_ic_pts=100,         # Reasonable fixed number for 1D spatial
                            collocation_scheme="uniform",
                            optimizer_type="adam",
                            lr=1e-3,
                            epochs=test_epochs,
                            log_epochs_list=log_epochs_test,
                            num_test_pts_error_grid=501 # Fewer points for faster testing
                        ))

    # Experiments run sequentially, in the order they were queued above.
    for cfg_test in configs_to_run:
        runner.run_single_experiment(cfg_test)

Using device: cpu

--- Running Experiment: experiment_data_test/Burgers\kappa_1p0e01\act_tanh\N_30\D_1\seed_42 ---
Config: ExperimentConfig(pde_name='Burgers', kappa_val=10.0, activation_str='tanh', seed=42, depth=1, width=30, optimizer_type='adam', lr=0.001, epochs=1000, log_epochs_list=[0, 500, 1000], num_test_pts_error_grid=501, loss_weight_pde=1.0, loss_weight_bc=1.0, loss_weight_ic=1.0, M_collocation_pts=None, M_collocation_factor=10, num_total_bc_pts=None, num_bc_pts_per_face=50, num_ic_pts=100, collocation_scheme='uniform')
Epoch 0/1000, Loss: 3.672e+00, L2_err_rel: nan, GradNorm: 0.000e+00
Epoch 500/1000, Loss: 4.370e-01, L2_err_rel: nan, GradNorm: 1.160e-01
Epoch 1000/1000, Loss: 4.057e-01, L2_err_rel: nan, GradNorm: 3.035e-01
Training finished. Total active time: 3.86s
Finished experiment. Final L2_err_rel: nan

--- Running Experiment: experiment_data_test/KdV\kappa_5p0e-01\act_tanh\N_30\D_1\seed_42 ---
Config: ExperimentConfig(pde_name='KdV', kappa_val=0.5, activation_str='t