# Requirements

In [37]:
from collections import OrderedDict
import torch
from torch.func import functional_call, grad, vmap
import torchopt

# Differential equation

The differential equation to solve is:
$$
    \frac{d f}{dt} = R f(t)\left(1 - f(t)\right)
$$
with initial condition $f(0) = 0.5$.

This equation can be used to model population growth.

# Loss function

The loss function used to train the neural network is the sum of two terms: the first evaluates the residual of the differential equation at $M$ time points, and the second enforces the initial condition.
$$
    \begin{array}{lcl}
        L_{\mathrm{DE}} & = & \frac{1}{M} \sum_{j=1}^{M} \left( \frac{df_{\mathrm{NN}}}{dt}(t_j) - R f_{\mathrm{NN}}(t_j) \left( 1 - f_{\mathrm{NN}}(t_j) \right) \right)^2 \\
        L_{\mathrm{BC}} & = & \left( f_{\mathrm{NN}}(0) - 0.5 \right)^2 \\
        L & = & L_{\mathrm{DE}} + L_{\mathrm{BC}}
    \end{array}
$$

In [35]:
def tuple_to_dict_parameters(model, params):
    """Pair a flat sequence of parameter tensors with the model's parameter names.

    Args:
        model: module whose ``named_parameters()`` supplies the names, in order.
        params: iterable of tensors, one per named parameter, in the same order.

    Returns:
        OrderedDict mapping each parameter name to the matching entry of
        ``params``.

    Raises:
        ValueError: if ``params`` does not hold exactly one entry per named
            parameter of ``model``.
    """
    keys = [name for name, _ in model.named_parameters()]
    values = list(params)
    # zip() stops at the shorter input, so a length mismatch would otherwise
    # yield a partial mapping without any error.
    if len(keys) != len(values):
        raise ValueError(
            f'model has {len(keys)} named parameters but {len(values)} were given'
        )
    return OrderedDict(zip(keys, values))

In [36]:
def make_forward_fn(model, derivative_order=1):

    def f(x: torch.Tensor, params: dict[str, torch.nn.Parameter] | tuple[torch.nn.Parameter, ...]) -> torch.Tensor:
        if isinstance(params, tuple):
            params_dict = tuple_to_dict_parameters(model, params)
        else:
            params_dict = params
        return functional_call(model, params_dict, (x, ))

    fns = [f]
    dfunc = f
    for _ in range(derivative_order):
        dfunc = grad(dfunc)
        fns.append(vmap(dfunc, in_dims=(0, None)))
    return fns

# Neural network

In [14]:
class PINN(torch.nn.Module):
    """Fully connected network producing one scalar per input sample.

    Layer stack: ``Linear(nr_inputs, nr_neurons)``, then ``nr_layers``
    repetitions of ``[Linear(nr_neurons, nr_neurons), activation]``, then
    ``Linear(nr_neurons, 1)``.
    """

    def __init__(self, nr_inputs, nr_layers, nr_neurons, activation=torch.nn.Tanh()):
        """Assemble the layer stack.

        Args:
            nr_inputs: number of input features per sample.
            nr_layers: number of hidden ``Linear`` + activation pairs.
            nr_neurons: width of every hidden layer.
            activation: module applied after each hidden ``Linear``; the same
                instance is reused at every position.
        """
        super().__init__()
        self.num_inputs = nr_inputs
        self.num_layers = nr_layers
        self.num_neurons = nr_neurons
        # NOTE: no activation follows the input layer, so it composes with the
        # first hidden Linear into a single affine map.
        layers = [torch.nn.Linear(self.num_inputs, self.num_neurons)]
        for _ in range(self.num_layers):
            layers.append(torch.nn.Linear(self.num_neurons, self.num_neurons))
            layers.append(activation)
        layers.append(torch.nn.Linear(self.num_neurons, 1))
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Evaluate the network on ``x``.

        ``x`` is reshaped to ``(-1, num_inputs)`` before entering the network,
        so the first Linear layer always receives the feature count it was
        built for. The trailing ``squeeze()`` removes every size-1 axis, so a
        single sample yields a 0-dim tensor.
        """
        return self.network(x.reshape(-1, self.num_inputs)).squeeze()

In [15]:
# One scalar input (time), three hidden Linear + activation pairs of width 20.
model = PINN(nr_inputs=1, nr_layers=3, nr_neurons=20)

In [38]:
# With derivative_order=1 the returned list has two entries:
# the network evaluator and its first derivative with respect to the input.
(f, dfdx) = make_forward_fn(model=model, derivative_order=1)

In [39]:
R = 1.0           # coefficient R of the differential equation
x_boundary = 0.0  # point at which the initial condition is imposed
f_boundary = 0.5  # required value f(x_boundary)

In [40]:
def loss_function(params, x):
    """Mean squared ODE residual plus squared initial-condition error.

    Args:
        params: network parameters in a form accepted by ``f`` and ``dfdx``.
        x: tensor of sample points, batched along axis 0.

    Returns:
        Scalar tensor: MSE of the residual over ``x`` plus MSE of the
        boundary mismatch at ``x_boundary``.
    """
    f_value = f(x, params)
    # Residual of df/dt = R f (1 - f) at the sampled points.
    interior = dfdx(x, params) - R * f_value * (1.0 - f_value)
    # Create the boundary tensors with x's dtype and device so the loss does
    # not mix devices/dtypes when x is not a default CPU tensor.
    boundary_point = torch.tensor([x_boundary], dtype=x.dtype, device=x.device)
    boundary_value = torch.tensor([f_boundary], dtype=x.dtype, device=x.device)
    boundaries = f(boundary_point, params) - boundary_value
    # Functional form avoids constructing a loss module on every call.
    mse = torch.nn.functional.mse_loss
    return (mse(interior, torch.zeros_like(interior)) +
            mse(boundaries, torch.zeros_like(boundaries)))

In [41]:
batch_size = 30        # sample points drawn per iteration
nr_iters = 100         # number of optimisation steps
learning_rate = 0.1    # step size handed to the Adam optimiser
domain = (-5.0, 5.0)   # (low, high) bounds for uniform sampling

In [42]:
# Functional Adam optimiser: step() takes the loss and the current parameters
# and returns the updated parameters.
optimizer = torchopt.FuncOptimizer(
    torchopt.adam(lr=learning_rate),
)

In [44]:
# Flat tuple of the model's parameters, in the order yielded by parameters().
params = (*model.parameters(),)

In [46]:
# NOTE(review): the trained weights live in `params`, which is rebound on every
# step. Whether `model`'s own parameters are also modified depends on the
# optimiser's in-place setting - confirm before evaluating `model` directly.
for iteration in range(nr_iters):
    # Draw a fresh batch uniformly from the domain. torch.empty follows the
    # default dtype, unlike the legacy torch.FloatTensor constructor.
    x = torch.empty(batch_size).uniform_(*domain)
    loss = loss_function(params, x)
    params = optimizer.step(loss, params)
    print(f'iteration {iteration + 1} with loss {loss.item()}')

iteration 1 with loss 0.11463413387537003
iteration 2 with loss 27.16790771484375
iteration 3 with loss 0.6420497894287109
iteration 4 with loss 2.8561654090881348
iteration 5 with loss 3.5104665756225586
iteration 6 with loss 2.035282611846924
iteration 7 with loss 0.14833369851112366
iteration 8 with loss 0.17722176015377045
iteration 9 with loss 0.29466158151626587
iteration 10 with loss 0.13089382648468018
iteration 11 with loss 0.3749958872795105
iteration 12 with loss 0.1261375993490219
iteration 13 with loss 0.17460887134075165
iteration 14 with loss 0.5559177994728088
iteration 15 with loss 0.41323214769363403
iteration 16 with loss 0.40801599621772766
iteration 17 with loss 0.33771443367004395
iteration 18 with loss 0.20494796335697174
iteration 19 with loss 0.10330705344676971
iteration 20 with loss 0.10414065420627594
iteration 21 with loss 0.18002170324325562
iteration 22 with loss 0.19086192548274994
iteration 23 with loss 0.1662447154521942
iteration 24 with loss 0.120684