In [1]:
from typing import List
import torch
from torch import nn, optim
import numpy as np
#from pinn import PINN, IPINN

In [None]:
from typing import List, Callable, Optional

In [None]:
from plotly import graph_objects as go, io as pio
from pathlib import Path
import matplotlib.pyplot as plt

In [None]:
# Global tensor defaults for the rest of the notebook: floating-point tensors created
# without an explicit dtype/device are float64 and are placed on the GPU when available.
SEED = 0  # fixes the random collocation points and the network initialisation
torch.manual_seed(SEED)
torch.set_default_dtype(torch.double)
torch.set_default_device(torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))

In [None]:
class MLP(nn.Module):
    """Fully connected feed-forward network.

    ``layer_size`` lists the layer widths from input to output, so
    ``MLP([1, 10, 1])`` builds ``Linear(1, 10) -> activation -> Linear(10, 1)``.
    The activation follows every layer except the last one, which stays linear.
    """

    def __init__(self, layer_size: List[int], activation: nn.Module = nn.Tanh()):
        super().__init__()
        # One Linear layer per consecutive (in_width, out_width) pair.
        width_pairs = zip(layer_size[:-1], layer_size[1:])
        self.linear = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in width_pairs]
        )
        self.activation = activation

    def forward(self, x):
        hidden_layers = self.linear[:-1]
        output_layer = self.linear[-1]
        for layer in hidden_layers:
            x = self.activation(layer(x))
        return output_layer(x)

In [None]:
class IPINN:
    """Inverse Physics-Informed Neural Network (PINN).

    An ``MLP`` surrogate of the solution and a list of unknown physical
    parameters are optimised jointly with L-BFGS. The objective is a weighted
    sum of

    * mean-squared residuals ``ftns[i](pts[i], mlp(pts[i]), parameter)``
      (physics and initial-condition terms), and
    * mean-squared misfits between ``mlp(observation_pts[j])`` and
      ``observation_val[j]`` (data terms).

    Loss and parameter values are recorded once per closure evaluation.
    """

    # Display names for the flattened trainable parameters, used in the progress log
    # and in the parameter-history plot. Entries past this tuple are labelled "p<index>".
    PARAM_NAMES = ("L", "R", "C")

    def __init__(self, layer_size: List[int], activation: nn.Module = nn.Tanh()):
        """
        Parameters
        ----------
        layer_size : List[int]
            Widths of the MLP layers, input first and output last.
        activation : nn.Module
            Activation passed to the MLP for its hidden layers.
        """
        self.mlp = MLP(layer_size, activation)
        self.optimizer: Optional[optim.Optimizer] = None  # created by `compile`
        self.loss = nn.MSELoss()
        self.loss_history: List[float] = []
        self.param_history: List[List[np.ndarray]] = []
        # Number of closure (loss + gradient) evaluations performed so far.
        self.epoch = 0

    def compile(
        self,
        ftns: List[
            Callable[[torch.Tensor, torch.Tensor, List[torch.Tensor]], torch.Tensor]
        ],
        pts: List[torch.Tensor],
        observation_pts: List[torch.Tensor],
        observation_val: List[torch.Tensor],
        parameter: List[torch.Tensor],
    ):
        """Register the loss terms and trainable parameters and build the optimizer.

        Parameters
        ----------
        ftns : list of callables
            Residual functions called as ``ftn(pt, mlp(pt), parameter)``.
        pts : list of torch.Tensor
            Points for each residual function (same length as ``ftns``).
        observation_pts, observation_val : list of torch.Tensor
            Observation locations and the values the network must match there.
        parameter : list of torch.Tensor
            Unknown physical parameters; ``requires_grad`` is switched on in place.

        Raises
        ------
        ValueError
            If ``ftns``/``pts`` or ``observation_val``/``observation_pts`` differ in length.
        """
        # Validate first so a bad call leaves the object untouched.
        if len(ftns) != len(pts):
            raise ValueError("Arguments `ftns` and `pts` must have the same length.")
        if len(observation_val) != len(observation_pts):
            raise ValueError(
                "Arguments `observation_val` and `observation_pts` must have the same length."
            )

        self.ftns = ftns
        self.pts = pts
        self.observation_pts = observation_pts
        self.observation_val = observation_val
        self.parameter = [param.requires_grad_(True) for param in parameter]

        self.optimizer = optim.LBFGS(
            list(self.mlp.parameters()) + list(self.parameter),
            lr=1,
            max_iter=1000,
            max_eval=1250 * 1000 // 15000,  # = 83 closure evaluations per `step` call
            tolerance_grad=1e-8,
            tolerance_change=0,
            history_size=100,
            line_search_fn="strong_wolfe",
        )

    def train(self, epochs: int, loss_weights: Optional[List[float]] = None):
        """Run L-BFGS until ``self.epoch`` (closure evaluations) reaches ``epochs``.

        Parameters
        ----------
        epochs : int
            Target total number of closure evaluations (cumulative across calls).
        loss_weights : list of float, optional
            One weight per residual term followed by one per observation term;
            defaults to all ones.

        Raises
        ------
        RuntimeError
            If `compile` has not been called.
        ValueError
            If ``loss_weights`` has the wrong length.

        Notes
        -----
        A single L-BFGS ``step`` may evaluate the closure many times, so training
        can end a few evaluations past ``epochs``.
        """
        if self.optimizer is None:
            raise RuntimeError("Call `compile` before `train`.")
        self.mlp.train()

        n_terms = len(self.ftns) + len(self.observation_val)
        if loss_weights is None:
            loss_weights = [1.0] * n_terms
        elif len(loss_weights) != n_terms:
            raise ValueError(
                "Arguments `loss_weights` and `ftns + observation_val` must have same length."
            )
        residual_weights = loss_weights[: len(self.ftns)]
        observation_weights = loss_weights[len(self.ftns):]

        def closure():
            losses = []
            for weight, ftn, pt in zip(residual_weights, self.ftns, self.pts):
                residual = ftn(pt, self.mlp(pt), self.parameter)
                losses.append(weight * self.loss(residual, torch.zeros_like(residual)))
            for weight, pt, target in zip(
                observation_weights, self.observation_pts, self.observation_val
            ):
                losses.append(weight * self.loss(self.mlp(pt), target))

            total_loss = sum(losses)
            self.optimizer.zero_grad()
            total_loss.backward()

            self.loss_history.append(total_loss.item())
            # `.copy()` is required: on CPU `.cpu().numpy()` shares memory with the
            # parameter, which L-BFGS updates in place, so every snapshot would
            # otherwise end up showing the latest value.
            self.param_history.append(
                [param.detach().cpu().numpy().copy() for param in self.parameter]
            )
            self.epoch += 1

            if self.epoch % 100 == 0:
                print(
                    f"Epoch {self.epoch}/{epochs}, Loss: {self.loss_history[-1]:.6f}, "
                    + self._format_parameters()
                )

            return total_loss

        while self.epoch < epochs:
            self.optimizer.step(closure)

    def validation(
        self,
        exact_param: List[float],
        exact_solution: Callable[[torch.Tensor], torch.Tensor],
        t_min: float = 0.0,
        t_max: float = 6.0,
        n_points: int = 101,
    ):
        """Report errors against the exact solution/parameters and save HTML plots.

        Prints the relative L2 error of the solution on ``[t_min, t_max]`` and the
        RMS relative error of the estimated parameters (both in percent). It then
        writes ``inverse_pinn_validation.html``, ``inverse_pinn_loss_history.html``
        and ``inverse_pinn_param_history.html`` to the current working directory.
        """
        self.mlp.eval()

        # validation points
        x = torch.linspace(t_min, t_max, n_points).reshape(-1, 1)
        y = exact_solution(x)
        with torch.no_grad():
            y_eval = self.mlp(x)

        # relative L2 error of the solution, integrated with the trapezoidal rule
        error = torch.sqrt(
            torch.trapezoid((y - y_eval) ** 2, x, dim=0)
            / torch.trapezoid(y**2, x, dim=0)
        ).item()
        param_error = self._parameter_error(exact_param)

        print(f"Validation Error: {error * 100:.4f} [%]")
        print(f"Parameter Error:  {param_error * 100:.4f} [%]")

        layout = go.Layout(
            template="plotly_white",
            width=1300,
            height=1300,
            font=go.layout.Font(family="Times New Roman", size=25),
        )
        out_dir = Path.cwd()
        self._save_solution_figure(x, y, y_eval, exact_solution, layout, out_dir)
        self._save_loss_figure(layout, out_dir)
        self._save_param_figure(exact_param, layout, out_dir)
        print(f"Results File Saved in {out_dir.absolute()}")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Detached, host-side, flattened copy of ``tensor`` for plotting."""
        return tensor.detach().cpu().numpy().flatten()

    def _flat_parameters(self) -> torch.Tensor:
        """All trainable parameters concatenated into one detached 1-D tensor."""
        return torch.cat([param.detach().reshape(-1) for param in self.parameter])

    def _param_label(self, index: int) -> str:
        """Display name for the ``index``-th flattened parameter."""
        if index < len(self.PARAM_NAMES):
            return self.PARAM_NAMES[index]
        return f"p{index}"

    def _format_parameters(self) -> str:
        """``"L: 1.000000, R: 1.000000, ..."`` for the current parameter values."""
        values = self._flat_parameters().tolist()
        return ", ".join(
            f"{self._param_label(i)}: {value:.6f}" for i, value in enumerate(values)
        )

    def _parameter_error(self, exact_param: List[float]) -> float:
        """Root-mean-square relative error of the estimated parameters."""
        estimated = self._flat_parameters()
        exact = torch.as_tensor(exact_param, dtype=estimated.dtype, device=estimated.device)
        return torch.sqrt(torch.mean(((exact - estimated) / exact) ** 2)).item()

    def _save_solution_figure(self, x, y, y_eval, exact_solution, layout, out_dir: Path) -> None:
        """Plot exact vs. predicted solution with the training and observation points."""
        train_x = torch.cat(self.pts)
        train_y = exact_solution(train_x)
        grid = self._to_numpy(x)

        data = [
            go.Scatter(
                x=grid,
                y=self._to_numpy(y),
                mode="lines",
                line=go.scatter.Line(width=5),
                name="True solution",
            ),
            go.Scatter(
                x=grid,
                y=self._to_numpy(y_eval),
                mode="lines",
                line=go.scatter.Line(width=5, dash="dash"),
                name="Predicted solution",
            ),
            go.Scatter(
                x=self._to_numpy(train_x),
                y=self._to_numpy(train_y),
                mode="markers",
                marker=go.scatter.Marker(size=10),
                name="Training points",
            ),
        ]
        if self.observation_pts:
            data.append(
                go.Scatter(
                    x=self._to_numpy(torch.cat(self.observation_pts)),
                    y=self._to_numpy(torch.cat(self.observation_val)),
                    mode="markers",
                    marker=go.scatter.Marker(size=10),
                    name="Observation points",
                )
            )
        pio.write_html(go.Figure(data, layout), out_dir / "inverse_pinn_validation.html")

    def _save_loss_figure(self, layout, out_dir: Path) -> None:
        """Plot the recorded total loss per closure evaluation."""
        data = [
            go.Scatter(
                y=self.loss_history,
                mode="lines",
                line=go.scatter.Line(width=5),
                name="loss",
            )
        ]
        pio.write_html(go.Figure(data, layout), out_dir / "inverse_pinn_loss_history.html")

    def _save_param_figure(self, exact_param: List[float], layout, out_dir: Path) -> None:
        """Plot the exact parameter values against their estimated history."""
        if self.param_history:
            # rows: scalar parameters, columns: closure evaluations
            history = np.asarray(
                [np.concatenate([np.ravel(p) for p in entry]) for entry in self.param_history]
            ).T
        else:
            history = np.empty((len(exact_param), 0))
        n_steps = history.shape[1]

        data = [
            go.Scatter(
                y=exact_param[i] * np.ones(n_steps),
                mode="lines",
                line=go.scatter.Line(width=5, dash="dash"),
                name=f"Exact parameter {self._param_label(i)}",
            )
            for i in range(len(exact_param))
        ]
        data.extend(
            go.Scatter(
                y=row,
                mode="lines",
                line=go.scatter.Line(width=5),
                name=f"Estimated parameter {self._param_label(i)}",
            )
            for i, row in enumerate(history)
        )
        pio.write_html(go.Figure(data, layout), out_dir / "inverse_pinn_param_history.html")

In [None]:
def differential_equation(input: torch.Tensor, output: torch.Tensor, params: List[torch.Tensor]) -> torch.Tensor:
    """
    Physics residual ``L * d2i/dt2 + R * di/dt + i / C`` (zero for the exact solution).

    Parameters
    ----------
    input : torch.Tensor
            Time points of shape (batch_size, 1); must have ``requires_grad=True``
            so the derivatives can be taken with autograd.
    output : torch.Tensor
            Network prediction ``i(t)`` of shape (batch_size, 1).
    params: List[torch.Tensor]
            ``params[0]`` holds the parameters in the order (L, R, C).

    Returns
    -------
    torch.Tensor
            Residual of the equation at every point, same shape as ``output``.
    """

    i = output
    L, R, C = params[0][0], params[0][1], params[0][2]
    di_dt = derivative(input, i)
    # Differentiate the first derivative we already have instead of calling
    # `second_derivative`, which would rebuild di/dt and double the autograd work.
    d2i_dt2 = derivative(input, di_dt)

    return L * d2i_dt2 + R * di_dt + i / C

def initial_condition(input: torch.Tensor, output: torch.Tensor, params: List[torch.Tensor]) -> torch.Tensor:
    """
    Residual of the initial condition ``i(0) = 0``.

    Since the target value is zero, the network output at the initial points is
    itself the residual and is returned unchanged. ``input`` and ``params`` are
    unused; they are accepted so every residual function shares the
    ``(input, output, params)`` calling convention.

    Parameters
    ---------
    input : torch.Tensor
            Initial time points, shape (the number of initial points, 1).
    output : torch.Tensor
            Network prediction at those points, shape (the number of initial points, 1).
    params: List[torch.Tensor]
            Trainable parameters (unused).

    Returns
    -------
    torch.Tensor
            The prediction ``i(0)``, i.e. ``output`` itself.
    """
    return output

def initial_condition2(input: torch.Tensor, output: torch.Tensor, params: List[torch.Tensor]) -> torch.Tensor:
    """
    Residual of the derivative initial condition ``L * di/dt(0) = V0``.

    ``V0`` is read from the notebook's global namespace at call time, so the cell
    that defines it must have run first.

    Parameters
    ----------
    input : torch.Tensor
            Initial time points, shape (the number of initial points, 1), with
            ``requires_grad=True``.
    output : torch.Tensor
            Network prediction at those points, shape (the number of initial points, 1).
    params : List[torch.Tensor]
            ``params[0][0]`` is the (trainable) inductance L.

    Returns
    -------
    torch.Tensor
            ``L * di/dt - V0`` at every initial point.
    """
    inductance = params[0][0]
    return inductance * derivative(input, output) - V0

In [None]:
def derivative(input: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    """
    First derivative of ``output`` with respect to ``input`` via autograd.

    The backward pass is seeded with ones, which yields the gradient of
    ``output.sum()``. This equals the pointwise derivative whenever each output
    row depends only on the matching input row, as for a network applied
    sample by sample.

    Parameters
    ----------
    input : torch.Tensor
            Tensor of shape (batch_size, 1) with ``requires_grad=True``.
    output : torch.Tensor
            Tensor of shape (batch_size, 1) computed from ``input``.

    Returns
    -------
    torch.Tensor
            d(output)/d(input), shaped like ``input``; it keeps its graph so it can
            be differentiated again.
    """
    ones = torch.ones_like(output)
    grads = torch.autograd.grad(
        outputs=output,
        inputs=input,
        grad_outputs=ones,
        create_graph=True,  # keep the graph for higher-order derivatives
    )
    return grads[0]

def second_derivative(input: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    """
    Second derivative of ``output`` with respect to ``input`` via two autograd passes.

    Parameters
    ----------
    input : torch.Tensor
            Tensor of shape (batch_size, 1) with ``requires_grad=True``.
    output : torch.Tensor
            Tensor of shape (batch_size, 1) computed from ``input``.

    Returns
    -------
    torch.Tensor
            d2(output)/d(input)2, shaped like ``input``, with its graph retained.
    """
    first = torch.autograd.grad(
        outputs=output,
        inputs=input,
        grad_outputs=torch.ones_like(output),
        create_graph=True,  # the second pass must differentiate through this result
    )[0]
    second = torch.autograd.grad(
        outputs=first,
        inputs=input,
        grad_outputs=torch.ones_like(first),
        create_graph=True,
    )[0]
    return second

In [None]:
# exact parameter
R = 1.2
L = 1.5
C = 0.3

def exact_solution(input: torch.Tensor, v0: Optional[float] = None) -> torch.Tensor:
    """
    Closed-form solution of ``L i'' + R i' + i / C = 0`` with ``i(0) = 0`` and
    ``L i'(0) = v0`` (underdamped case):

        i(t) = v0 / (L * w_d) * exp(-alpha * t) * sin(w_d * t),
        alpha = R / (2 L),  w_d = sqrt(1 / (L C) - alpha^2).

    The coefficients are derived from the exact ``R``, ``L``, ``C`` above rather
    than hand-rounded, so the generated observation data satisfies the same ODE
    the physics residual enforces.

    Parameters
    ----------
    input : torch.Tensor
            Time points (any shape).
    v0 : float, optional
            Right-hand side of ``L i'(0) = v0``; defaults to the notebook-global
            ``V0`` (defined in the training-setup cell) looked up at call time.

    Returns
    -------
    torch.Tensor
            ``i(t)`` evaluated element-wise, same shape as ``input``.

    Raises
    ------
    ValueError
            If the parameters are not underdamped (``R^2 >= 4 L / C``).
    """
    if v0 is None:
        v0 = V0
    alpha = R / (2.0 * L)
    discriminant = 1.0 / (L * C) - alpha**2
    if discriminant <= 0:
        raise ValueError("exact_solution assumes an underdamped circuit (R^2 < 4L/C).")
    omega_d = discriminant**0.5
    amplitude = v0 / (L * omega_d)

    t = input
    return amplitude * torch.exp(-alpha * t) * torch.sin(omega_d * t)

In [None]:
# Right-hand side of the second initial condition L * di/dt(0) = V0
# (read as a global by `initial_condition2`).
V0 = 12.0

# End of the time window [0, T_END] for collocation and observation points.
# Keep it consistent with the window `IPINN.validation` evaluates on (0 to 6).
T_END = 6.0

# Initial guess for the unknown parameters, in the order (L, R, C)
param = torch.as_tensor([1.0, 1.0, 1.0])

# the number of collocation points for the ODE residual
n_points = 41

# Collocation points drawn uniformly at random in [0, T_END]; gradients w.r.t.
# them are needed to evaluate di/dt and d2i/dt2.
x = (T_END * torch.rand((n_points, 1))).requires_grad_(True)

# Initial-condition point (t = 0), shared by both initial-condition residuals
bp = torch.zeros((1, 1)).requires_grad_(True)

# the number of observation points
n_observation = 20

# Observation points evenly spaced in [0, T_END]
observation_pts = torch.linspace(0, T_END, n_observation).reshape(-1, 1)

# Observation value (exact solution at observation points)
observation_value = exact_solution(observation_pts)

# Residual functions and the points each one is evaluated on (same order)
ftns = [differential_equation, initial_condition, initial_condition2]
pts = [x, bp, bp]

pinn = IPINN(layer_size=[1, 10, 10, 10, 1], activation=torch.nn.Tanh())
pinn.compile(
    ftns=list(ftns),
    pts=list(pts),
    observation_pts=[observation_pts],
    observation_val=[observation_value],
    parameter=[param],
)

pinn.train(epochs=1500)
pinn.validation([L, R, C], exact_solution)

Epoch 100/1500, Loss: 0.539962, L: 1.080745, R: 1.958168, C: 0.364849
Epoch 200/1500, Loss: 0.121694, L: 1.193725, R: 1.391230, C: 0.362414
Epoch 300/1500, Loss: 0.010572, L: 1.571493, R: 1.202770, C: 0.287699
Epoch 400/1500, Loss: 0.001659, L: 1.502306, R: 1.200264, C: 0.298464
Epoch 500/1500, Loss: 0.000314, L: 1.497300, R: 1.197113, C: 0.299228
Epoch 600/1500, Loss: 0.000147, L: 1.496914, R: 1.197950, C: 0.299264
Epoch 700/1500, Loss: 0.000107, L: 1.497341, R: 1.198046, C: 0.299149
Epoch 800/1500, Loss: 0.000096, L: 1.496552, R: 1.197629, C: 0.299264
Epoch 900/1500, Loss: 0.000077, L: 1.495455, R: 1.197162, C: 0.299457
Epoch 1000/1500, Loss: 0.000033, L: 1.495510, R: 1.196556, C: 0.299445
Epoch 1100/1500, Loss: 0.000017, L: 1.495518, R: 1.196098, C: 0.299403
Epoch 1200/1500, Loss: 0.000014, L: 1.495396, R: 1.196755, C: 0.299438
Epoch 1300/1500, Loss: 0.000010, L: 1.495424, R: 1.196805, C: 0.299421
Epoch 1400/1500, Loss: 0.000007, L: 1.495807, R: 1.196762, C: 0.299355
Epoch 1500/1500