## PINN for predicting a SIR model
This is the base case of modeling a SIR model using a Feed Forward Neural Network.

In [1]:
# import libraries
import matplotlib
from matplotlib import cm
from matplotlib.ticker import LinearLocator
from scipy.interpolate import griddata
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import r2_score
import math

### Define time, initial conditions and parameters

In [2]:
# define amount of days being predicted and create time tensor
t_days = 500
t_tensor = torch.linspace(0, t_days, t_days+1, requires_grad = True).view(-1, 1)
t_0 = torch.tensor([0], dtype=torch.float32).view(-1,1)

#define initial conditions for SIR model
S_0 = 9999
I_0 = 1
R_0 = 0

#define model parameters
alpha = 1
beta = 1

In [3]:
def loss_ode(net, t):
    SIR = net(t)
    print(SIR.size())
    
    # Calculate derivates
    dSdt = torch.autograd.grad(SIR[:,0], t, grad_outputs=SIR[:,0], create_graph=True)[0]
    dIdt = torch.autograd.grad(SIR[:,1], t, grad_outputs=SIR[:,1], create_graph=True)[0]
    dRdt = torch.autograd.grad(SIR[:,2], t, grad_outputs=SIR[:,2], create_graph=True)[0]
    print(dSdt.size())
    
    # Calculate the residuals using vectorized operations
    S, I, R = SIR[:, 0], SIR[:, 1], SIR[:, 2]
    S_r = dSdt - (- beta * S * I)
    I_r = dIdt - (beta * S * I - alpha * I)
    R_r = dRdt - (alpha * I)
    
    # Combine the residuals into a single tensor and calculate the mean squared error
    residuals = torch.cat((S_r, I_r, R_r))
    loss_ode = torch.mean(residuals**2)
    return loss_ode


#loss function for initial conditoons of S, I and R
def loss_ic(net):
    SIR_t0 = net(t_0)
    loss_ic_squared = (SIR_t0[:, 0]-S_0)**2 + (SIR_t0[:, 1]-I_0)**2 + (SIR_t0[:, 2]-R_0)**2
    loss_ic = torch.sqrt(loss_ic_squared)/3
    return loss_ic

def loss_obs(net, t, SIR_obs):
    ## still has to be done##
    return loss_obs

In [4]:
# define network architecture
input_dim = 1
output_dim = 3
num_hidden = 50

class Net(nn.Module):
    def __init__(self, num_hidden):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_dim, num_hidden),
            nn.ReLU(),
            #nn.Linear(num_hidden, num_hidden),
            #nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, output_dim),
        )
        
        # Apply Kaiming initialization to the layers
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, t):
        SIR = self.linear_relu_stack(t)
        return SIR
    
net = Net(num_hidden)
    
# hyperparameters
learning_rate = 1e-3
batch_size = 20
num_epochs = 50

#initialize lambdas for soft-adaptation
lambda_ode = torch.nn.Parameter(torch.tensor([1.0], requires_grad=True))
lambda_ic = torch.nn.Parameter(torch.tensor([1.0], requires_grad=True))
lambda_obs = torch.nn.Parameter(torch.tensor([1.0], requires_grad=True))

#optimizer: weights updates the net, ode and ic update the lambda for soft-adaptation
optimizer_weights = optim.Adam(net.parameters(), lr=learning_rate)
optimizer_ode = optim.Adam([lambda_ode], lr=learning_rate)
optimizer_ic = optim.Adam([lambda_ic], lr=learning_rate)
optimizer_obs = optim.Adam([lambda_obs], lr=learning_rate)

In [5]:
# train network

#getting epoch sizes
num_samples_train = t_tensor.shape[0]
num_batches_train = num_samples_train // batch_size

# setting up lists for handling loss/accuracy
train_acc, train_loss = [], []
cur_loss = 0
losses, ode_losses, ic_losses, obs_losses = [], [], [], []

get_slice = lambda i, size: range(i * size, (i + 1) * size) #get slices for each batch

for epoch in range(num_epochs):
    # Forward -> Backprob -> Update params
    ## Train
    cur_loss = 0
    ode_loss = 0
    ic_loss = 0
    obs_loss = 0
    net.train()
    for i in range(num_batches_train):
        # Zero the gradients for all optimizers
        optimizer_weights.zero_grad()
        optimizer_ode.zero_grad()
        optimizer_ic.zero_grad()
        
        slce = get_slice(i, batch_size)
        output = net(t_tensor[slce])

        # compute gradients given loss
        batch_loss_ode = loss_ode(net, t_tensor)
        batch_loss_ic = loss_ic(net)
        batch_loss = lambda_ode * batch_loss_ode + lambda_ic * batch_loss_ic + lambda_obs * batch_loss_obs # with soft adaptation
        batch_loss.backward()
        
        #maximize gradients of lambdas by inverting the gradient
        with torch.no_grad():
            lambda_ode.grad *= -1
            lambda_ic.grad *= -1
            lambda_obs.grad *= -1
        
        #update net and lambdas
        optimizer_weights.step()
        optimizer_ode.step()
        optimizer_ic.step()
        optimizer_obs.step()

        cur_loss += batch_loss.detach().numpy()
        ode_loss += batch_loss_ode.detach().numpy()
        ic_loss += batch_loss_ic.detach().numpy()
        obs_loss += batch_loss_ic.detach().numpy()
        
    losses.append(cur_loss / batch_size)
    ode_losses.append(ode_loss / batch_size)
    ic_losses.append(ic_loss / batch_size)
    obs_losses.append(ic_loss / batch_size)

    net.eval()

    if epoch % 5 == 0:
        print(
            f"epoch {epoch+1} : Total loss {np.round(losses[-1].item(), decimals=6)} , "
            f"ode loss: {np.round(ode_losses[-1].item(), decimals=6)} , "
            f"ode lambda: {np.round(lambda_ode.item(), decimals=3)} , "
            f"ic loss: {np.round(ic_losses[-1].item(), decimals=6)} , "
            f"ic lambda: {np.round(lambda_ic.item(), decimals=3)}"
    )
        
        
epoch = np.arange(len(losses))
plt.figure()
plt.plot(epoch, losses, 'r')
plt.plot(epoch, ode_losses, 'g')
plt.plot(epoch, ic_losses, 'b')
plt.legend(['Loss', 'ODE loss', 'IC loss'])
plt.xlabel('Updates'), plt.ylabel('Loss')

torch.Size([501, 3])
torch.Size([501, 1])


NameError: name 'batch_loss_obs' is not defined

## Visualization