In [None]:
import torch 
import torch.nn as nn
from torchdiffeq import odeint as odeint
import pylab as plt
from torch.utils.data import Dataset, DataLoader
from typing import Callable, List, Tuple, Union, Optional
from pathlib import Path  

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class LotkaVolterra(nn.Module):
    """
    Lotka-Volterra predator-prey system as a learnable ODE right-hand side.

    With prey population ``x`` and predator population ``y`` the vector field is

        dx/dt = alpha * x - beta * x * y
        dy/dt = gamma * x * y - delta * y

    so ``gamma`` scales predator growth from prey encounters and ``delta`` is the
    predator decay rate. Many textbook statements of the model use the opposite
    names for these two coefficients, so check the equations above rather than
    the letters when comparing values.

    The four coefficients are stored in one trainable parameter, ``model_params``,
    in the order ``[alpha, beta, delta, gamma]``.
    """
    def __init__(self,
                 alpha: float = 3.0,  # prey growth rate (coefficient of x in dx/dt)
                 beta: float = 0.6,   # prey loss per x*y encounter (coefficient of -x*y in dx/dt)
                 gamma: float = 0.5,  # predator gain per x*y encounter (coefficient of x*y in dy/dt)
                 delta: float = 4.0   # predator decay rate (coefficient of -y in dy/dt)
                 ) -> None:
        super().__init__()
        # The order [alpha, beta, delta, gamma] is relied upon when unpacking in
        # forward() and when reading fitted values back out of model_params.
        self.model_params = torch.nn.Parameter(torch.tensor([alpha, beta, delta, gamma]))

    def forward(self, t: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the time derivative of ``state``.

        Args:
            t: current time. It is unused because the system is autonomous, but it is
               required by the ``odeint`` ``f(t, y)`` calling convention.
            state: tensor whose last dimension holds ``(x, y)`` (assumed to have
               size 2). Any leading dimensions are treated as a batch.

        Returns:
            Tensor with the same leading dimensions as ``state`` and
            ``(dx/dt, dy/dt)`` in the last dimension.
        """
        x = state[..., 0]
        y = state[..., 1]
        alpha, beta, delta, gamma = self.model_params
        dxdt = alpha * x - beta * x * y
        dydt = -delta * y + gamma * x * y
        # Stack the components instead of writing into a zeros_like buffer:
        # this avoids in-place writes and keeps the autograd graph simple.
        return torch.stack((dxdt, dydt), dim=-1)

In [None]:
# Generate the ground-truth trajectory by integrating the Lotka-Volterra model.
lv_model = LotkaVolterra().to(device)
ts = torch.arange(0.0, 4.0, .08, device=device)
y0 = torch.tensor([10., 3.], device=device)
with torch.no_grad():  # target data only: no autograd graph through the solver
    y_true = odeint(lv_model, y0, ts, method='dopri5')
# y_true has one row per time point in ts and columns (x, y), so the number of
# samples is the first dimension (shape[1] would be the state size, 2).
print(f'Dataset length: {y_true.shape[0]}')

t_np = ts.cpu().numpy()
true_np = y_true.cpu().numpy()

fig, ax = plt.subplots(figsize=(8, 2))
ax.scatter(t_np, true_np[:, 0], label='x', marker='.')
ax.scatter(t_np, true_np[:, 1], label='y', marker='.')
ax.set(xlabel='Time', ylabel='Population')
ax.legend()
plt.show()

In [None]:
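# Disabled experiment: recover the Lotka-Volterra coefficients by fitting a
# deliberately mis-initialised LotkaVolterra model to y_true with Adam.
# model = LotkaVolterra(alpha=1.0, beta=1.5, gamma=0.8, delta=3.0).to(device)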
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# criterion = nn.MSELoss()

# for epoch in range(100):
#     optimizer.zero_grad()
#     y_pred = odeint(model, y0, ts, method='dopri5')
#     loss = criterion(y_pred, y_true)
#     loss.backward()
#     optimizer.step()
#     if epoch % 10 == 0:
#         print(f'Epoch {epoch}, Loss {loss.item()}')

# # Get model parameters values
# alpha, beta, delta, gamma = model.model_params
# print(f'alpha: {alpha:.2f}, beta: {beta:.2f}, delta: {delta:.2f}, gamma: {gamma:.2f}')

In [None]:
class NeuralDiffeq(nn.Module):
    """
    Basic neural ODE vector field: an MLP mapping a state to its time derivative.

    The output has the same trailing size as the input state, so the module can
    be passed directly to ``odeint`` as the right-hand side ``f(t, y)``.

    Args:
        dim: size of the state vector (input and output width of the MLP).
        hidden: widths of the hidden layers. Each hidden layer is followed by a
            ``Tanh``. The default ``(8, 16, 32)`` reproduces the original
            architecture and parameter layout (``func.0``, ``func.2``, ...).
    """
    def __init__(self, dim: int = 2, hidden: Tuple[int, ...] = (8, 16, 32)) -> None:
        super().__init__()
        layers: List[nn.Module] = []
        in_features = dim
        # Build Linear/Tanh pairs in order so that, with a fixed seed, the weight
        # initialisation matches the original hand-written Sequential.
        for width in hidden:
            layers += [nn.Linear(in_features, width), nn.Tanh()]
            in_features = width
        layers.append(nn.Linear(in_features, dim))
        self.func = nn.Sequential(*layers)

    def forward(self, t, state):
        # Autonomous field: `t` is accepted for the odeint signature but unused.
        return self.func(state)

# Define the model, optimizer and loss function.
# Seed first so the random MLP initialisation, and hence the training run, is reproducible.
SEED = 0
LEARNING_RATE = 0.01
torch.manual_seed(SEED)
model = NeuralDiffeq(dim=2).to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

In [None]:
# Fit the neural ODE to y_true by backpropagating through a fixed-step RK4 solve
# (the data and the final evaluation use the adaptive dopri5 solver).
N_EPOCHS = 2000
LOG_EVERY = 100  # print progress every LOG_EVERY epochs rather than flooding the output
total_loss = []  # per-epoch training loss (MSE against y_true)
for epoch in range(N_EPOCHS):
    optimizer.zero_grad()
    y_pred = odeint(model, y0, ts, method='rk4')
    loss = criterion(y_pred, y_true)
    loss.backward()
    optimizer.step()
    loss_value = loss.item()  # one device-to-host sync per epoch
    total_loss.append(loss_value)
    if epoch % LOG_EVERY == 0 or epoch == N_EPOCHS - 1:
        print(f'Epoch {epoch}, Loss {loss_value}')

In [None]:
# Training curve on a log scale so late-epoch improvements stay visible.
fig, ax = plt.subplots(figsize=(8, 2))
ax.plot(total_loss)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_yscale('log')
ax.set_title('Training loss')
plt.show()

In [None]:
# Plot the results: the trained neural ODE (solved with adaptive dopri5) against the data.
with torch.no_grad():  # inference only, so no autograd graph through the solver
    y_pred = odeint(model, y0, ts, method='dopri5')

t_np = ts.cpu().numpy()
true_np = y_true.cpu().numpy()
pred_np = y_pred.cpu().numpy()

fig, ax = plt.subplots(figsize=(8, 2))
ax.scatter(t_np, true_np[:, 0], label='x', marker='.')
ax.scatter(t_np, true_np[:, 1], label='y', marker='.')
ax.plot(t_np, pred_np[:, 0], label='x_pred')
ax.plot(t_np, pred_np[:, 1], label='y_pred')
ax.set(xlabel='Time', ylabel='Population', title='Neural ODE fit vs. Lotka-Volterra data')
ax.legend()
plt.show()