<a href="https://colab.research.google.com/github/matinmoezzi/ebola-virus-ode-dnn/blob/main/neurodiffeq_example_lbfgs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install neurodiffeq

In [None]:
from itertools import chain
from neurodiffeq import diff
from neurodiffeq.conditions import IVP
from neurodiffeq.monitors import Monitor1D
from neurodiffeq.solvers import Solver1D
from neurodiffeq.callbacks import MonitorCallback
from neurodiffeq.networks import FCNN
from neurodiffeq.generators import Generator1D
from torch.optim import LBFGS
import torch.autograd
import matplotlib.pyplot as plt
import torch.nn as nn
import numpy as np
import torch
%matplotlib inline

In [None]:
class MySolver(Solver1D):
    r"""``Solver1D`` variant whose epoch loop works with closure-based optimizers.

    Optimizers such as ``torch.optim.LBFGS`` need a closure that re-evaluates
    the loss and its gradients, and they may call that closure several times
    within a single ``step``. This subclass therefore builds one closure per
    batch and performs one optimizer step per batch.
    """

    def _run_epoch(self, key):
        r"""Run an epoch on train/valid points and update the history.

        When ``key == 'train'`` an optimizer step is performed for every batch.

        :param key: {'train', 'valid'}; phase of the epoch
        :type key: str

        .. note::
            The optimizer may evaluate the closure more than once per step.
            Loss and metrics are recorded from the *last* evaluation of each
            batch only, so every batch contributes exactly once to the epoch
            averages regardless of how often the closure was called.
        """
        n_batches = self.n_batches[key]
        if n_batches <= 0:
            # Nothing to run; also avoids dividing by zero when averaging below.
            return

        self._phase = key
        epoch_loss = 0.0
        metric_values = {name: 0.0 for name in self.metrics_fn}

        for _ in range(n_batches):
            batch = self._generate_batch(key)
            # Per-batch records; overwritten on every closure evaluation.
            batch_loss = 0.0
            batch_metrics = {}

            def closure():
                """Evaluate loss (and gradients when training) on the current batch."""
                nonlocal batch_loss
                if key == 'train':
                    # Gradients must be rebuilt from scratch on each evaluation.
                    self.optimizer.zero_grad()
                funcs = [
                    self.compute_func_val(n, c, *batch) for n, c in zip(self.nets, self.conditions)
                ]

                # Overwrite rather than accumulate: the closure can run many times per step.
                for name in self.metrics_fn:
                    batch_metrics[name] = self.metrics_fn[name](*funcs, *batch).item()

                residuals = torch.cat(self.diff_eqs(*funcs, *batch), dim=1)
                loss = self.criterion(residuals) + self.additional_loss(funcs, key)

                if key == 'train':
                    loss.backward()
                batch_loss = loss.item()
                return loss

            if key == 'train':
                self._do_optimizer_step(closure=closure)
            else:
                # Validation: a single forward evaluation, no parameter update.
                closure()

            epoch_loss += batch_loss
            for name, value in batch_metrics.items():
                metric_values[name] += value

        # calculate mean loss of all batches and register to history
        self._update_history(epoch_loss / n_batches, 'loss', key)

        # update lowest_loss and best_net when validating
        if key == 'valid':
            self._update_best()

        # calculate average metrics across batches and register to history
        for name in self.metrics_fn:
            self._update_history(metric_values[name] / n_batches, name, key)

    def _do_optimizer_step(self, closure=None):
        r"""Perform one optimizer step, forwarding ``closure`` to ``self.optimizer.step``.

        Subclasses can override this method to add work around the step. For
        example, gradient clipping with an optimizer that does not re-evaluate
        the closure could look like::

            import itertools
            class ClippedSolver(MySolver):
                def _do_optimizer_step(self, closure=None):
                    params = itertools.chain.from_iterable(net.parameters() for net in self.nets)
                    nn.utils.clip_grad_norm_(params, 1.0, 'inf')
                    return self.optimizer.step()

        :param closure: callable that re-evaluates the model and returns the loss
        :type closure: callable, optional
        :return: whatever ``self.optimizer.step`` returns
        """
        return self.optimizer.step(closure=closure)

In [None]:
def system_ode(u1, u2, t):
    """Return the residuals of the coupled first-order ODE system.

    The system is written in the form ``residual == 0``:

    * ``u1' - cos(t) - u1**2 - u2 + (1 + t**2 + sin(t)**2) = 0``
    * ``u2' - 2*t + (1 + t**2)*sin(t) - u1*u2 = 0``

    :param u1: first unknown function evaluated at ``t``
    :param u2: second unknown function evaluated at ``t``
    :param t: independent variable
    :return: list ``[residual_1, residual_2]``
    """
    forcing = 1 + t**2 + torch.sin(t)**2
    residual_u1 = diff(u1, t) - torch.cos(t) - u1**2 - u2 + forcing
    residual_u2 = diff(u2, t) - 2*t + (1 + t**2)*torch.sin(t) - u1*u2
    return [residual_u1, residual_u2]

In [None]:
# Initial value conditions at t_0 = 0, one per unknown: the first starts at 0.0,
# the second at 1.0.
init_vals_pc = [IVP(t_0=0.0, u_0=start) for start in (0.0, 1.0)]

In [None]:
# Monitor over t in [0, 3], wrapped so it can be passed to the solver as a callback.
# NOTE(review): check_every=10 presumably means "refresh every 10 epochs" — confirm
# against the neurodiffeq documentation.
monitor = Monitor1D(
    t_min=0,
    t_max=3,
    check_every=10,
)
monitor_callback = MonitorCallback(monitor)


def my_callback(solver):
    """Request an early stop once the solver's lowest loss falls below 1e-6.

    :param solver: object exposing ``lowest_loss``; its ``_stop_training`` flag
        is set to ``True`` when the threshold is reached and left untouched otherwise
    """
    converged = solver.lowest_loss < 1e-6
    if not converged:
        return
    solver._stop_training = True


In [None]:
# Two separately constructed networks with identical settings: scalar input,
# scalar output, n_hidden_units=10 and sigmoid activation.
nets_lv = [
    FCNN(n_input_units=1, n_output_units=1, n_hidden_units=10, actv=nn.Sigmoid)
    for _ in range(2)
]

In [None]:
# A single L-BFGS optimizer over the parameters of every network, in network order.
lbfgs = LBFGS(
    [param for net in nets_lv for param in net.parameters()],
    lr=0.01,
    max_iter=10,
)

In [None]:
# Assemble the solver for the ODE system on t in [0, 3], using the custom
# networks and the L-BFGS optimizer defined in this notebook.
solver = MySolver(
    ode_system=system_ode,
    conditions=init_vals_pc,
    t_min=0.0,
    t_max=3.0,
    nets=nets_lv,
    optimizer=lbfgs,
)

In [None]:
# Train for at most 100 epochs with both callbacks attached, then fetch the
# solution object that is evaluated when plotting.
solver.fit(
    max_epochs=100,
    callbacks=[monitor_callback, my_callback],
)
solution_pc = solver.get_solution()

In [None]:
# Compare the network outputs with the closed-form curves sin(t) and 1 + t**2
# on 3000 evenly spaced points in [0, 3], one subplot per unknown.
ts = np.linspace(0, 3, 3000)
y1, y2 = solution_pc(ts, to_numpy=True)
fig, axes = plt.subplots(1, 2)
for ax, approx, exact in zip(axes, (y1, y2), (np.sin(ts), 1 + ts**2)):
    ax.plot(ts, approx, label='ANN solution')
    ax.plot(ts, exact, label='analytical solution')
    ax.legend()
plt.show()