In [None]:
!pip install pydantic

In [None]:
import numpy as np
import torch
import torch.nn as nn
from pydantic import BaseModel, Field

def fourier_series(n):
    """Return the n-th series coefficient ``-480 / (pi**4 * n**4)``.

    ``heat_function`` multiplies this value by ``cos(n * pi * x)`` and only
    ever calls it with even ``n`` (``n = 2 * i``).

    Args:
        n: Mode number; expected to be even and non-zero.

    Returns:
        The coefficient as a float.
    """
    denominator = np.pi**4 * (n**4)
    return -480 / denominator


def compute_residual(model, x, t, alpha=None):
    """Return the heat-equation residual ``u_t - alpha * u_xx`` of ``model``.

    The inputs are cloned and detached before differentiation, so the
    caller's tensors are never modified and do not need ``requires_grad``.

    Args:
        model: Callable ``model(x, t)`` returning ``u`` as a tensor that is
            differentiable with respect to both inputs.
        x: Tensor of spatial sample points.
        t: Tensor of time sample points (same shape as ``x``).
        alpha: Diffusivity multiplying ``u_xx``. When ``None`` (the default)
            the value is read from ``pinnConfig().alpha``.

    Returns:
        Tensor ``d_t - alpha * d_xx`` evaluated at the sample points. It stays
        attached to the autograd graph so a loss built from it can be
        back-propagated into the model parameters.
    """
    if alpha is None:
        alpha = pinnConfig().alpha

    # Work on fresh leaf tensors so gradients w.r.t. the inputs are available.
    x = x.clone().detach().requires_grad_(True)
    t = t.clone().detach().requires_grad_(True)

    u = model(x, t)

    # First derivatives in a single backward pass. ``create_graph=True`` keeps
    # the graph alive (``retain_graph`` defaults to the same value), which is
    # required both for the second derivative and for ``loss.backward()``.
    d_t, d_x = torch.autograd.grad(
        u, (t, x), grad_outputs=torch.ones_like(u), create_graph=True
    )
    d_xx = torch.autograd.grad(
        d_x, x, grad_outputs=torch.ones_like(d_x), create_graph=True
    )[0]

    return d_t - alpha * d_xx


def initial_condition(x):
    """Return the initial profile ``10 * (x - x**2)**2 + 3`` at ``x``.

    Works element-wise on anything supporting arithmetic operators (plain
    numbers, NumPy arrays, torch tensors).
    """
    bump = x - x**2
    return 10 * bump**2 + 3


def heat_function(x, t, alpha=None, n_terms=19):
    """Evaluate the truncated cosine-series reference solution at ``(x, t)``.

    The value returned is::

        a_0 + sum_{k=1..n_terms} fourier_series(2k)
                                 * cos(2k * pi * x)
                                 * exp(-alpha * (2k * pi)**2 * t)

    with ``a_0 = 1/3 + 3``.

    Args:
        x: Spatial coordinate(s); scalar or NumPy array.
        t: Time coordinate(s); scalar or NumPy array broadcastable with ``x``.
        alpha: Diffusivity used in the exponential decay. When ``None`` (the
            default) the value is read from ``pinnConfig().alpha``.
        n_terms: Number of series terms to sum. The default of 19 matches the
            original hard-coded ``range(1, 20)``.

    Returns:
        The series value, with the broadcast shape of ``x`` and ``t``.
    """
    if alpha is None:
        # Read the config once instead of rebuilding it for every term.
        alpha = pinnConfig().alpha

    a_0 = 1/3 + 3
    series = 0  # named ``series`` so the builtin ``sum`` is not shadowed
    for k in range(1, n_terms + 1):
        mode = 2 * k  # only even modes contribute
        decay = np.exp(-1 * alpha * (mode * np.pi)**2 * t)
        series += fourier_series(mode) * np.cos(np.pi * mode * x) * decay
    return a_0 + series

class netConfig(BaseModel):
    """Network architecture and optimiser settings used by this file."""

    # Checkpoint file written by ``train_pinn`` and read back by ``main``.
    save_path: str = Field("parameters_ic.pth", description="Parameter's path")

    # Width of the first Linear layer's input in ``NeuralNetwork``.
    neuron_inputs: int = Field(2, description='Number of neurons')

    # Width of every hidden Linear layer.
    neuron_hidden: int = Field(100, description='Number of neurons')

    # How many hidden-to-hidden Linear layers ``NeuralNetwork`` stacks.
    hidden_layers_numbers: int = Field(8)

    # Width of the last Linear layer's output.
    neuron_outputs: int = Field(1)

    # Number of optimiser steps taken in ``train_pinn``.
    epochs: int = Field(
        5000, description='Number of times that the parameter actualize'
    )

    # Learning rate handed to ``torch.optim.Adam``.
    lr: float = Field(1e-3, description='Learning rate')


class plotConfig(BaseModel):
    """Plotting / snapshot settings."""

    # NOTE(review): the description text looks truncated; its intended
    # meaning is not visible in this file, so it is left as-is.
    sample: int = Field(100, description='Number o')

    # Epoch stride of the (currently disabled) snapshot hook in ``train_pinn``.
    snapshot_step: int = Field(10, description="")

    # The next three values are the dimensions, in this order, of the
    # ``snapshots`` tensor allocated in ``train_pinn``.
    snap_x: int = Field(1000, description="")
    snap_t: int = Field(100, description="")
    frames_snap: int = Field(100, description="")


class pinnConfig(BaseModel):
    """Physics-informed training settings."""

    # Coefficient multiplying u_xx in ``compute_residual`` and appearing in
    # the exponential decay of ``heat_function``.
    alpha: float = Field(0.1, description="Important for the PDE")

    # Number of random points drawn in ``train_pinn`` for the PDE residual,
    # the initial condition and each boundary, respectively.
    num_collocation_res: int = Field(500, description="")
    num_collocation_ic: int = Field(100, description='')
    num_collocation_bc: int = Field(200, description='')

    # Weights for the residual, initial-condition and boundary loss terms.
    lambda_residual: float = Field(10.0, description='')
    lambda_ic: float = Field(10.0, description='')
    lambda_bc: float = Field(10.0, description='')

    # Not read anywhere in the visible part of this file; presumably grid
    # sizes for an error evaluation defined elsewhere -- TODO confirm.
    error_x_sample: int = Field(10000, description='')
    error_t_sample: int = Field(100, description='')

class NeuralNetwork(nn.Module):
    """Fully connected tanh network mapping ``(x, t)`` to ``u(x, t)``.

    Layout: ``Linear(inputs, hidden) + Tanh``, then ``hidden_layers_numbers``
    blocks of ``Linear(hidden, hidden) + Tanh``, then a final
    ``Linear(hidden, outputs)``.

    The output layer is intentionally linear. A trailing ``Tanh`` would bound
    the prediction to (-1, 1), while ``initial_condition`` is always >= 3, so
    the network could never fit the target. ``Tanh`` has no parameters and was
    the last module, so ``state_dict`` keys are unaffected by its removal.

    Args:
        config: Optional object exposing ``neuron_inputs``, ``neuron_hidden``,
            ``hidden_layers_numbers`` and ``neuron_outputs``. Defaults to a
            fresh ``netConfig()``.
    """

    def __init__(self, config=None):
        super().__init__()
        # Build the config once rather than once per attribute access.
        cfg = netConfig() if config is None else config
        hidden = cfg.neuron_hidden

        layers = [nn.Linear(cfg.neuron_inputs, hidden), nn.Tanh()]
        for _ in range(cfg.hidden_layers_numbers):
            layers += [nn.Linear(hidden, hidden), nn.Tanh()]
        # No activation after the output layer: the solution is unbounded.
        layers.append(nn.Linear(hidden, cfg.neuron_outputs))
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Concatenate ``x`` and ``t`` along dim 1 and run the network."""
        inp = torch.cat([x, t], dim=1)
        return self.net(inp)


def train_pinn():
    """Train the PINN and save its weights to ``netConfig().save_path``.

    The loss is a weighted sum of three mean-squared terms:

    * the PDE residual from ``compute_residual`` at random ``(x, t)`` points,
    * the mismatch with ``initial_condition`` at ``t = 0``,
    * the mismatch between ``du/dx`` and zero at ``x = 0`` and ``x = 1``.

    All random points are drawn once with ``torch.rand`` before the loop and
    reused in every epoch.

    Returns:
        tuple: ``(model, snapshots)``, the trained ``NeuralNetwork`` and the
        snapshot tensor. The snapshot hook is disabled, so the tensor is
        returned filled with zeros.
    """
    # Build each config once instead of on every attribute access.
    net_cfg = netConfig()
    pinn_cfg = pinnConfig()
    plot_cfg = plotConfig()

    model = NeuralNetwork()
    optimizer = torch.optim.Adam(model.parameters(), lr=net_cfg.lr)

    num_collocation_res = pinn_cfg.num_collocation_res
    num_collocation_ic = pinn_cfg.num_collocation_ic
    num_collocation_bc = pinn_cfg.num_collocation_bc

    # Loss weights. The residual and IC weights previously read
    # ``num_collocation_bc`` by mistake, weighting those terms by 200.
    lambda_residual = pinn_cfg.lambda_residual
    lambda_ic = pinn_cfg.lambda_ic
    lambda_bc = pinn_cfg.lambda_bc

    # Residual collocation points
    x_col_res = torch.rand(num_collocation_res, 1)
    t_col_res = torch.rand(num_collocation_res, 1)

    # Initial-condition collocation points (t = 0)
    x_col_ic = torch.rand(num_collocation_ic, 1)
    t_col_ic = torch.zeros((num_collocation_ic, 1))

    # Boundary collocation points: x = 0 and x = 1 at random times.
    # The x tensors need gradients because du/dx is evaluated on them.
    t_x_bc = torch.rand(num_collocation_bc, 1)
    x_bc = torch.zeros((num_collocation_bc, 1), requires_grad=True)
    t_l_bc = torch.rand(num_collocation_bc, 1)
    l_bc = torch.ones((num_collocation_bc, 1), requires_grad=True)

    # Neumann targets: zero spatial derivative on both boundaries
    ux_0_bc = torch.zeros((num_collocation_bc, 1))
    ux_1_bc = torch.zeros((num_collocation_bc, 1))

    # Snapshot storage (only returned; nothing writes to it at the moment)
    snapshots = torch.zeros((plot_cfg.snap_x,
                             plot_cfg.snap_t,
                             plot_cfg.frames_snap))

    for epoch in range(net_cfg.epochs):
        optimizer.zero_grad()

        # Residual
        residual = compute_residual(model, x_col_res, t_col_res)
        loss_residual = torch.mean(residual**2)

        # Initial condition
        model_ic = model(x_col_ic, t_col_ic)
        loss_ic = torch.mean((model_ic-initial_condition(x_col_ic))**2)

        # Boundary: du/dx at x = 0
        u_0_bc = model(x_bc, t_x_bc)
        du_0_bc = torch.autograd.grad(
            u_0_bc, x_bc, grad_outputs=torch.ones_like(u_0_bc),
            create_graph=True
        )[0]

        # Boundary: du/dx at x = 1
        u_l_bc = model(l_bc, t_l_bc)
        du_l_bc = torch.autograd.grad(
            u_l_bc, l_bc, grad_outputs=torch.ones_like(u_l_bc),
            create_graph=True
        )[0]

        loss_0_bc = torch.mean((du_0_bc-ux_0_bc)**2)
        loss_1_bc = torch.mean((du_l_bc-ux_1_bc)**2)
        loss_b = (loss_0_bc + loss_1_bc)

        loss = (lambda_residual*loss_residual
                + lambda_ic*loss_ic
                + lambda_bc*loss_b)
        loss.backward()
        optimizer.step()

        if epoch % 100 == 0:
            print(loss, epoch)

        # Snapshot animation is disabled. The intended hook called
        # ``plots().animate_snapshot(model, snapshots, epoch, is_last)``
        # every ``plotConfig().snapshot_step`` epochs.

    torch.save({'model_state_dict': model.state_dict()}, net_cfg.save_path)
    return model, snapshots


def main(flag: bool):
    """Train a new model, or load the saved one and plot a comparison.

    Args:
        flag: ``False`` runs ``train_pinn()``. ``True`` restores the weights
            stored under ``"model_state_dict"`` in ``netConfig().save_path``,
            puts the model in eval mode and calls
            ``plots().heat_comparation(model)``.
    """
    if not flag:
        # The returned (model, snapshots) pair is not used here.
        train_pinn()
        return

    model = NeuralNetwork()
    checkpoint = torch.load(netConfig().save_path)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # ``plots`` is not defined in the visible part of this file; it is
    # expected to be provided elsewhere.
    plots().heat_comparation(model)


# Script entry point: flag=False selects the training branch of ``main``.
if __name__ == "__main__":
    main(flag=False)