In [None]:
# Imports 
from __future__ import annotations

from dataclasses import dataclass
from typing import (
    Callable,
    Dict,
    Iterable,
    Mapping,
    MutableMapping,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import torch
from torch import nn

from pina import Condition, LabelTensor
from pina.equation import Equation, FixedValue
from pina.model import FeedForward
from pina.operator import div, grad
from pina.solver.physics_informed_solver import PINN
from pina.solver.physics_informed_solver.rba_pinn import RBAPINN
from pina.solver.physics_informed_solver.self_adaptive_pinn import SelfAdaptivePINN
from pina.trainer import Trainer

from . import bc as bc_utils
from .geometry import PlanePatch, TowerAGeometry

In [None]:
# Shared type aliases and defaults used by the equation builders below.

# Ordered coordinate labels, e.g. ("x", "y", "z").
Axes = Tuple[str, ...]

# Raw values that `_as_tensor` can coerce into a per-point tensor.
TensorLike = Union[
    float,
    int,
    Sequence[float],
    torch.Tensor,
    LabelTensor,
]

# Conductivity and boundary targets: either a constant-like value or a
# callable of the sampled input points that returns one.
ConductivityLike = Union[TensorLike, Callable[[LabelTensor], TensorLike]]
BoundaryValueLike = Union[TensorLike, Callable[[LabelTensor], TensorLike]]

# PINA solver variants named by these aliases (instances, then classes).
SolverType = Union[PINN, SelfAdaptivePINN, RBAPINN]
SolverClass = Union[
    type[PINN],
    type[SelfAdaptivePINN],
    type[RBAPINN],
]

# Default 3-D Cartesian coordinate labels.
DEFAULT_AXES: Axes = ("x", "y", "z")


def _as_tensor(value: TensorLike, input_: LabelTensor, components: int = 1) -> torch.Tensor:
    """Convert ``value`` into a ``(batch, components)`` tensor aligned with ``input_``.

    The result is placed on ``input_``'s device with ``input_``'s dtype, and its
    first dimension equals ``input_.shape[0]``. It is used both for conductivity
    values and for scalar boundary targets.

    Accepted shapes for ``value`` (after conversion to a tensor):

    * scalar (0-d) or single-element vector ``(1,)``: broadcast to every point
      and every component;
    * vector ``(batch,)`` with ``components == 1``: one value per point;
    * vector ``(components,)``: one value per component, repeated for each point;
    * vector ``(batch,)`` with ``components > 1``: one value per point, repeated
      across components;
    * matrix ``(batch, components)``: used as is;
    * matrix ``(batch, 1)``: repeated across components;
    * matrix ``(1, components)`` or ``(1, 1)``: repeated across points (and
      components).

    When ``batch == components > 1`` a 1-D vector is ambiguous; it is read as
    one value per component.

    Args:
        value: Python scalar, sequence of floats, tensor, or ``LabelTensor``.
        input_: Points the value is aligned with; only its batch size, device
            and dtype are used.
        components: Number of columns of the result.

    Returns:
        Tensor of shape ``(batch, components)``.

    Raises:
        ValueError: If ``value`` cannot be broadcast to ``(batch, components)``.
    """
    device = input_.device
    dtype = input_.dtype
    batch = input_.shape[0]

    if isinstance(value, LabelTensor):
        # Drop the labels but keep the underlying tensor (and its autograd graph).
        tensor = value.tensor.to(device=device, dtype=dtype)
    elif isinstance(value, torch.Tensor):
        tensor = value.to(device=device, dtype=dtype)
    elif isinstance(value, Sequence):
        tensor = torch.as_tensor(value, dtype=dtype, device=device)
    else:
        tensor = torch.tensor(value, dtype=dtype, device=device)

    if tensor.ndim == 0 or (tensor.ndim == 1 and tensor.shape[0] == 1):
        # A single value is shared by every point and every component.
        tensor = tensor.reshape(1, 1).repeat(batch, components)
    elif tensor.ndim == 1:
        length = tensor.shape[0]
        if length == batch and components == 1:
            tensor = tensor.reshape(batch, 1)
        elif length == components:
            # One value per component, shared by every point. Checked before the
            # per-point case below, which resolves the batch == components tie.
            tensor = tensor.reshape(1, components).repeat(batch, 1)
        elif length == batch:
            # One value per point, shared by every component.
            tensor = tensor.reshape(batch, 1).repeat(1, components)
        else:
            raise ValueError(
                f"Cannot reshape tensor of shape {tuple(tensor.shape)} to ({batch}, {components})"
            )
    elif tensor.ndim == 2:
        if tensor.shape[0] == 1 and batch != 1:
            # A single row of constants is shared by every point.
            tensor = tensor.repeat(batch, 1)
        if tensor.shape[0] != batch:
            raise ValueError(
                f"Expected first dimension equal to batch size {batch}, got {tensor.shape}"
            )
        if tensor.shape[1] != components:
            if tensor.shape[1] == 1 and components > 1:
                tensor = tensor.repeat(1, components)
            else:
                raise ValueError(
                    f"Expected second dimension {components}, got {tensor.shape}"
                )
    else:
        raise ValueError(
            "Expected scalar, vector, or matrix value representation, "
            f"got tensor of shape {tuple(tensor.shape)}."
        )

    return tensor


def evaluate_sigma_diag(conductivity: ConductivityLike, input_: LabelTensor, axes: Axes) -> torch.Tensor:
    """Evaluate (potentially spatially varying) conductivity on ``input_`` points.

    ``conductivity`` is either a constant (anything ``_as_tensor`` accepts) or a
    callable that receives ``input_`` and returns such a value.

    Returns an ``(batch, len(axes))`` tensor corresponding to the diagonal of
    :math:`\\boldsymbol{\\sigma}` in the coordinate system spanned by ``axes``.

    Raises:
        ValueError: Propagated from ``_as_tensor`` when the evaluated value cannot
            be broadcast to ``(batch, len(axes))``.
    """
    value = conductivity(input_) if callable(conductivity) else conductivity
    # ``_as_tensor`` already repeats single-column values across all axes, so the
    # result always has exactly ``len(axes)`` columns; no extra broadcast needed.
    return _as_tensor(value, input_, components=len(axes))


def evaluate_scalar(value: BoundaryValueLike, input_: LabelTensor) -> torch.Tensor:
    """Evaluate a scalar boundary target at every point of ``input_``.

    ``value`` is either a constant or a callable applied to ``input_``. The
    result is coerced by ``_as_tensor`` to a single column, and that trailing
    column dimension is then squeezed away.

    Returns:
        Tensor with one entry per point of ``input_``.
    """
    resolved = value(input_) if callable(value) else value
    single_column = _as_tensor(resolved, input_, components=1)
    return single_column.squeeze(-1)



# Equation Builders — Laplace, Neumann, Dirichlet

These three functions define the **physics-informed residuals**.
Each returns a `pina.Equation` object that computes the deviation from the desired physical law or boundary condition at sampled points.

---

## `build_laplace_equation`

```python
def build_laplace_equation(conductivity: ConductivityLike, axes: Axes = DEFAULT_AXES) -> Equation:
    """Return the ∇·(σ∇φ) residual as a PINA equation."""
```

### Purpose

Builds the **core PDE residual** for the quasi-static electric potential:

$$
\nabla \cdot (\sigma \nabla \phi) = 0,
$$

which ensures current conservation within the domain (no internal sources or sinks).

---

### Step-by-step

```python
def residual(input_: LabelTensor, output_: LabelTensor) -> LabelTensor:
```

Defines how the residual is computed for a batch of spatial coordinates (`input_`) and predicted potentials (`output_`).

---

```python
grad_phi = grad(output_, input_, components=["phi"], d=list(axes))
```

Computes the gradient of φ with respect to the coordinates,
i.e. $\nabla \phi = [\partial\phi/\partial x,\ \partial\phi/\partial y,\ \partial\phi/\partial z]$.
Since $\mathbf{E} = -\nabla\phi$, this gives the local **electric field direction** (up to sign).

---

```python
sigma_diag = evaluate_sigma_diag(conductivity, input_, axes)
```

Evaluates the diagonal of the conductivity tensor σ(x) at each point.
The conductivity is either a constant (a scalar for a homogeneous medium, or one value per axis) or a callable σ(x).

---

```python
flux_tensor = sigma_diag * grad_phi.tensor
```

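Computes the flux

$$
\sigma \nabla \phi
$$

elementwise along each axis (valid because σ is diagonal).
Up to sign, this is the **current density**: Ohm's law gives $\mathbf{J} = \sigma\mathbf{E} = -\sigma\nabla\phi$.
The sign does not affect the zero-divergence condition.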

---

```python
flux = LabelTensor(flux_tensor, labels=[f"J_{ax}" for ax in axes])
```

Wraps the flux components (`J_x`, `J_y`, `J_z`) into a `LabelTensor` so that the next operator (divergence) knows which axis each belongs to.

---

```python
divergence = div(flux, input_, components=flux.labels, d=list(axes))
```

Computes the divergence $\nabla \cdot \mathbf{J}$,
i.e. $\partial J_x/\partial x + \partial J_y/\partial y + \partial J_z/\partial z$.
If the PDE is satisfied, this value should be **zero everywhere**.

---

```python
return Equation(residual)
```

Wraps the residual function into a `pina.Equation` so it can be used as a loss term in the PINN training loop.

---

### Summary

| Symbol             | Meaning                               |
| :----------------- | :------------------------------------ |
| ∇φ                 | Electric field direction (up to sign) |
| σ∇φ                | Current density (up to sign)          |
| ∇·(σ∇φ)            | Conservation of current               |
| Residual = ∇·(σ∇φ) | Target → 0 inside the domain          |

---

## 🧭 `build_neumann_equation`

```python
def build_neumann_equation(
    patch: PlanePatch,
    target_flux_density: BoundaryValueLike,
    conductivity: ConductivityLike,
    axes: Axes = DEFAULT_AXES,
) -> Equation:
    """Return the Neumann residual enforcing `-n·σ∇φ = target` on a patch."""
```

### Purpose

Defines a **flux boundary condition** (Neumann BC):

$$
-\mathbf{n} \cdot \sigma \nabla \phi = q_{\text{target}},
$$

used for electrodes or insulating surfaces where the **normal current density** is prescribed.

---

### Step-by-step

```python
axis = patch.axis
axis_index = axes.index(axis)
```

Determines the coordinate axis normal to the patch (e.g., `'z'` for a top/bottom face) and that axis's column index within `axes`.

---

```python
grad_component = grad(output_, input_, components=["phi"], d=[axis])
```

Computes $\partial\phi/\partial x_{\text{axis}}$, the derivative of the potential along the patch's normal axis. The direction of the normal (`normal_sign`) is applied in a later step.

---

```python
sigma_diag = evaluate_sigma_diag(conductivity, input_, axes)
sigma_axis = sigma_diag[:, axis_index]
```

Extracts the relevant conductivity component $\sigma_n$ along that axis.

---

```python
normal_flux = -patch.normal_sign * sigma_axis * grad_component.tensor.squeeze(-1)
```

Forms the **normal current density**

$$
-\mathbf{n} \cdot \sigma \nabla \phi,
$$

where $\mathbf{n} = \pm\hat{\mathbf{e}}_{\text{axis}}$. The `normal_sign` supplies the ± sign, i.e. which way the patch normal points along its axis.

---

```python
target = evaluate_scalar(target_flux_density, input_)
return normal_flux - target
```

Subtracts the desired flux (the applied or known boundary value).
The residual → 0 when the model reproduces the target boundary flux.

---

### Summary

| Term                       | Meaning                                       |
| :------------------------- | :-------------------------------------------- |
| −n·σ∇φ                     | Normal current leaving the surface            |
| target                     | Desired flux (e.g. electrode current density) |
| residual = −n·σ∇φ − target | Enforced to 0 on boundary points              |

---

## ⚡ `build_dirichlet_equation`

```python
def build_dirichlet_equation(target_value: BoundaryValueLike) -> Equation:
    """Return a Dirichlet residual enforcing `phi = target`."""
```

### Purpose

Defines a **potential boundary condition** (Dirichlet BC):

$$
\phi = \phi_{\text{target}},
$$

used to fix reference or grounded surfaces.

---

### Step-by-step

```python
target = evaluate_scalar(target_value, input_)
```

Evaluates the prescribed potential value (can be constant or spatially varying).

---

```python
phi = output_.extract(["phi"]).tensor.squeeze(-1)
return phi - target
```

Computes the difference between the predicted potential and the target value.
The residual → 0 when φ matches the boundary condition.

---

### Summary

| Term                  | Meaning                              |
| :-------------------- | :----------------------------------- |
| φ                     | Predicted potential                  |
| target                | Fixed potential (boundary condition) |
| residual = φ − target | Enforced to 0 at Dirichlet surfaces  |

---

## 🧠 Conceptual Recap

| Builder                    | Enforces                | Equation        | Region                 |
| :------------------------- | :---------------------- | :-------------- | :--------------------- |
| `build_laplace_equation`   | Conservation of current | ∇·(σ∇φ) = 0     | Interior               |
| `build_neumann_equation`   | Specified current flux  | −n·σ∇φ = target | Boundary (flux patch)  |
| `build_dirichlet_equation` | Fixed potential         | φ = target      | Boundary (fixed patch) |

---

*TODO:* add a short **diagram (Markdown + math)** showing how the three regions interact (interior vs. boundary patches). It would help visualize how these residuals combine during training.


In [None]:
def build_laplace_equation(conductivity: ConductivityLike, axes: Axes = DEFAULT_AXES) -> Equation:
    """Build the interior residual ∇·(σ∇φ) as a PINA ``Equation``.

    Args:
        conductivity: Diagonal conductivity σ, either a constant or a callable of
            the input points (see ``evaluate_sigma_diag``).
        axes: Input coordinate labels used for differentiation.

    Returns:
        An ``Equation`` whose residual is the divergence of σ∇φ, computed from
        the ``"phi"`` output component; it is zero where current is conserved.
    """

    def residual(input_: LabelTensor, output_: LabelTensor) -> LabelTensor:
        coords = list(axes)
        # ∇φ: one column per coordinate in ``axes``.
        potential_gradient = grad(output_, input_, components=["phi"], d=coords)
        # σ is diagonal, so σ∇φ is a column-by-column product.
        sigma = evaluate_sigma_diag(conductivity, input_, axes)
        flux_values = sigma * potential_gradient.tensor
        # Label the flux columns so ``div`` can pair each with its coordinate.
        flux_labels = [f"J_{ax}" for ax in axes]
        flux = LabelTensor(flux_values, labels=flux_labels)
        return div(flux, input_, components=flux.labels, d=list(axes))

    return Equation(residual)


def build_neumann_equation(
    patch: PlanePatch,
    target_flux_density: BoundaryValueLike,
    conductivity: ConductivityLike,
    axes: Axes = DEFAULT_AXES,
) -> Equation:
    """Build the Neumann residual ``-n·σ∇φ - target`` for an axis-aligned patch.

    Args:
        patch: Plane patch. ``patch.axis`` names the coordinate normal to it and
            ``patch.normal_sign`` multiplies the flux to orient the normal.
        target_flux_density: Prescribed value of ``-n·σ∇φ``, either a constant
            or a callable of the boundary points.
        conductivity: Diagonal conductivity σ (see ``evaluate_sigma_diag``).
        axes: Coordinate labels; ``patch.axis`` must be one of them.

    Returns:
        An ``Equation`` whose residual has one entry per boundary point.

    Raises:
        ValueError: If ``patch.axis`` is not listed in ``axes``.
    """
    normal_axis = patch.axis
    if normal_axis not in axes:
        raise ValueError(f"Patch axis {normal_axis!r} not present in axes {axes}.")
    sigma_column = axes.index(normal_axis)

    def residual(input_: LabelTensor, output_: LabelTensor) -> torch.Tensor:
        # ∂φ/∂(normal axis), flattened to one value per point.
        dphi_daxis = grad(output_, input_, components=["phi"], d=[normal_axis]).tensor.squeeze(-1)
        # Only the conductivity entry along the normal axis contributes.
        sigma_n = evaluate_sigma_diag(conductivity, input_, axes)[:, sigma_column]
        outward_flux = -patch.normal_sign * sigma_n * dphi_daxis
        return outward_flux - evaluate_scalar(target_flux_density, input_)

    return Equation(residual)


def build_dirichlet_equation(target_value: BoundaryValueLike) -> Equation:
    """Build the Dirichlet residual ``phi - target`` as a PINA ``Equation``.

    Args:
        target_value: Prescribed potential, either a constant or a callable of
            the boundary points (see ``evaluate_scalar``).

    Returns:
        An ``Equation`` whose residual is zero where the ``"phi"`` output equals
        the target.
    """

    def residual(input_: LabelTensor, output_: LabelTensor) -> torch.Tensor:
        expected = evaluate_scalar(target_value, input_)
        predicted = output_.extract(["phi"]).tensor.squeeze(-1)
        return predicted - expected

    return Equation(residual)

