# COVID-19 Epidemiology Modelling Tutorial

---

## Overview

This tutorial focuses on modelling the spread of disease during epidemics using Physics-Informed Neural Networks (PINNs). Modelling in epidemiology allows upcoming spikes and dips in the number of cases of a disease to be predicted, and these predictions can be used to guide public health policy decisions.

We will start by applying a PINN to predict the motion of a damped harmonic oscillator (a spring or pendulum subject to friction or air resistance), building on the procedure described [here](https://benmoseley.blog/my-research/so-what-is-a-physics-informed-neural-network/). Then, you will apply the same method to COVID-19 data from the US in 2020 to try to predict the spikes and dips in COVID cases that occurred.

To start, run the following cell to import the necessary libraries:

In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from IPython.display import clear_output

You won't be doing any coding for the first half of this tutorial, but run the following cell to get the continuous feedback system set up:

In [None]:
# @title
import torch
import torch.nn as nn
from IPython.core.magic import register_cell_magic

# Reference implementation
class ReferenceCOVID_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 64)
        self.fc4 = nn.Linear(64, 64)
        self.fc5 = nn.Linear(64, 1)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = torch.tanh(self.fc3(x))
        x = torch.tanh(self.fc4(x))
        x = self.fc5(x)
        return x

# Function to perform checks
def perform_checks_covid_nn(student_model):
    torch.manual_seed(123)
    reference_model = ReferenceCOVID_NN()

    # Check if the models have the same architecture
    ref_layers = list(reference_model.children())
    student_layers = list(student_model.children())

    if len(ref_layers) != len(student_layers):
        print(f"❌ Model has incorrect number of layers: Expected {len(ref_layers)}, Got {len(student_layers)}")
        return

    for i, (ref_layer, student_layer) in enumerate(zip(ref_layers, student_layers)):
        if not isinstance(ref_layer, type(student_layer)):
            print(f"❌ Layer {i+1} is of incorrect type: Expected {type(ref_layer)}, Got {type(student_layer)}")
            return
        if isinstance(ref_layer, nn.Linear):
            if ref_layer.in_features != student_layer.in_features or ref_layer.out_features != student_layer.out_features:
                print(f"❌ Layer {i+1} has incorrect dimensions: Expected ({ref_layer.in_features}, {ref_layer.out_features}), Got ({student_layer.in_features}, {student_layer.out_features})")
                return

    # Check the forward pass with a random input
    test_input = torch.randn(1, 1)
    reference_output = reference_model(test_input)
    student_output = student_model(test_input)

    if torch.allclose(reference_output, student_output, atol=1e-5):
        print("✅ This model is correct")
    else:
        print(f"❌ Model produces incorrect output: Expected {reference_output.detach().numpy()}, Got {student_output.detach().numpy()}")

# Define the cell magic function
@register_cell_magic
def check_covid_nn(line, cell):
    # Execute the student's code and retrieve the model
    exec(cell, globals())

    # Check if 'COVID_NN' is defined in the global namespace
    if 'COVID_NN' not in globals():
        print("❌ COVID_NN class is not defined")
        return

    # Instantiate the student's model
    torch.manual_seed(123)
    student_model = COVID_NN()

    # Perform the checks on the model
    perform_checks_covid_nn(student_model)

# Function to check the model and optimiser
def check_model_and_optimiser():
    # Check if 'COVID_model' is defined and correct
    if 'COVID_model' not in globals():
        print("❌ COVID_model is not defined.")
        return
    if not isinstance(COVID_model, COVID_NN):
        print("❌ COVID_model is not an instance of the expected model class.")
        return
    print("✅ COVID_model is correctly initialized.")

    # Check if 'COVID_optimiser' is defined and correct
    if 'COVID_optimiser' not in globals():
        print("❌ COVID_optimiser is not defined.")
        return
    if not isinstance(COVID_optimiser, torch.optim.Adam):
        print("❌ COVID_optimiser is not an instance of torch.optim.Adam.")
        return
    if COVID_optimiser.param_groups[0]['lr'] != 1e-2:
        print(f"❌ Learning rate is incorrect: Expected 1e-2, Got {COVID_optimiser.param_groups[0]['lr']}")
        return

    # Check that the optimiser is using the correct parameters
    reference_params = list(COVID_model.parameters())
    student_params = COVID_optimiser.param_groups[0]['params']
    if not all(p1.shape == p2.shape for p1, p2 in zip(reference_params, student_params)):
        print("❌ Optimiser parameters do not match the model parameters.")
        return

    print("✅ COVID_optimiser is correctly initialised with the correct learning rate and parameters.")

# Define the cell magic function
@register_cell_magic
def check_model_and_optimiser_cell(line, cell):
    # Execute the student's code
    exec(cell, globals())
    # Run checks
    check_model_and_optimiser()

# Reference implementation for comparison
def reference_get_R(I):
    R = torch.zeros_like(I, requires_grad=True)
    end = np.floor(len(R)/2).astype(int)
    for count in range(1, end):
        ones = torch.zeros_like(I)
        ones[:len(R) - 2*count] = 1.0
        R = R + torch.roll(I * ones, 2*count)
    return R

# Function to check the student's implementation
def check_get_R():
    # Check if 'get_R' is defined
    if 'get_R' not in globals():
        print("❌ The function 'get_R' is not defined.")
        return

    # Check if 'get_R' is a function
    if not callable(get_R):
        print("❌ 'get_R' is not callable.")
        return

    # Test the function with a sample input
    # An even-length input, so the reference loop covers every ten-day shift that contributes
    test_input = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], requires_grad=True)
    try:
        student_output = get_R(test_input)
    except Exception as e:
        print(f"❌ Error when running 'get_R': {e}")
        return

    # Check if the output is a tensor
    if not isinstance(student_output, torch.Tensor):
        print("❌ 'get_R' should return a tensor.")
        return

    # Check if the output requires gradients
    if not student_output.requires_grad:
        print("❌ Output tensor 'R' should require gradients.")
        return

    # Compare the student's output to the reference implementation
    reference_output = reference_get_R(test_input)
    if torch.allclose(student_output, reference_output, atol=1e-6):
        print("✅ 'get_R' produces the correct output.")
    else:
        print("❌ 'get_R' does not produce the correct output.")

# Define the cell magic function
@register_cell_magic
def check_get_R_cell(line, cell):
    # Execute the student's code
    exec(cell, globals())
    # Run checks
    check_get_R()

# Reference setup for comparison. It is built lazily, when the check runs,
# because COVID_NN and t_physics are only defined later in the notebook.
def build_PINN_COVID_references():
    torch.manual_seed(111)
    reference_PINN_COVID_model = COVID_NN()
    reference_optimiser = torch.optim.Adam(reference_PINN_COVID_model.parameters(), lr=1e-4)
    reference_beta = torch.ones_like(t_physics, requires_grad=True)
    reference_beta_optimiser = torch.optim.Adam([reference_beta], lr=5e-4)
    return reference_optimiser, reference_beta_optimiser

# Check if the optimiser is using the correct model parameters
def check_optimiser_params(optimiser, reference_optimiser):
    student_params = list(optimiser.param_groups[0]['params'])
    reference_params = list(reference_optimiser.param_groups[0]['params'])

    if len(student_params) != len(reference_params):
        return False

    for student_param, reference_param in zip(student_params, reference_params):
        if not torch.equal(student_param, reference_param):
            return False
    return True

# Function to check the student's implementation
def check_PINN_COVID_setup():
    # Build the reference optimisers now that COVID_NN and t_physics exist
    reference_optimiser, reference_beta_optimiser = build_PINN_COVID_references()
    # Check if 'PINN_COVID_model' is defined
    if 'PINN_COVID_model' not in globals():
        print("❌ 'PINN_COVID_model' is not defined.")
        return

    # Check if 'PINN_COVID_model' is an instance of 'COVID_NN'
    if not isinstance(PINN_COVID_model, COVID_NN):
        print("❌ 'PINN_COVID_model' should be an instance of 'COVID_NN'.")
        return
    else:
        print("✅ 'PINN_COVID_model' is correctly initialized as an instance of 'COVID_NN'.")

    # Check if 'optimiser' is defined
    if 'optimiser' not in globals():
        print("❌ 'optimiser' is not defined.")
        return

    # Check if 'optimiser' is an instance of 'torch.optim.Adam'
    if not isinstance(optimiser, torch.optim.Adam):
        print("❌ 'optimiser' should be an instance of 'torch.optim.Adam'.")
        return

    # Check the learning rate of the optimiser
    student_lr = optimiser.param_groups[0]['lr']
    if student_lr == 1e-4:
        print("✅ 'optimiser' learning rate is correct.")
    else:
        print(f"❌ 'optimiser' learning rate is incorrect: Expected 1e-4, Got {student_lr}")

    # Check if the optimiser is using the correct model parameters
    if check_optimiser_params(optimiser, reference_optimiser):
        print("✅ 'optimiser' is correctly linked to 'PINN_COVID_model' parameters.")
    else:
        print("❌ 'optimiser' is not correctly linked to 'PINN_COVID_model' parameters.")

    # Check if 'beta' is defined
    if 'beta' not in globals():
        print("❌ 'beta' is not defined.")
        return

    # Check if 'beta_optimiser' is defined
    if 'beta_optimiser' not in globals():
        print("❌ 'beta_optimiser' is not defined.")
        return

    # Check if 'beta_optimiser' is an instance of 'torch.optim.Adam'
    if not isinstance(beta_optimiser, torch.optim.Adam):
        print("❌ 'beta_optimiser' should be an instance of 'torch.optim.Adam'.")
        return

    # Check the learning rate of the beta optimiser
    student_beta_lr = beta_optimiser.param_groups[0]['lr']
    if student_beta_lr == 5e-4:
        print("✅ 'beta_optimiser' learning rate is correct.")
    else:
        print(f"❌ 'beta_optimiser' learning rate is incorrect: Expected 5e-4, Got {student_beta_lr}")

    # Check if the beta optimiser is linked to 'beta'
    if check_optimiser_params(beta_optimiser, reference_beta_optimiser):
        print("✅ 'beta_optimiser' is correctly linked to 'beta'.")
    else:
        print("❌ 'beta_optimiser' is not correctly linked to 'beta'.")


# Define the cell magic function
@register_cell_magic
def check_PINN_COVID_setup_cell(line, cell):
    # Execute the student's code
    exec(cell, globals())
    # Run checks
    check_PINN_COVID_setup()

In order to run a neural network, we need a dataset to train and make predictions on. We will use the analytic solution of the damped harmonic oscillator for this purpose, training our neural network on the initial motion of the oscillator then testing it by asking it to predict the trajectory after the training period.

The following cell sets up the required data and plots the trajectory along with the points along the trajectory that the neural network will be trained on:

In [None]:
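def oscillator(d, w0, t):
    """Analytic solution of the under-damped harmonic oscillator with x(0) = 1.

    d is the damping rate and w0 is the undamped angular frequency.
    """
    assert d < w0  # under-damped case only
    w = np.sqrt(w0**2-d**2)  # damped angular frequency
    phi = np.arctan(-d/w)  # phase chosen so that the initial velocity is zero
    A = 1/(2*np.cos(phi))  # amplitude chosen so that x(0) = 1
    cos = torch.cos(phi+w*t)
    exp = torch.exp(-d*t)
    y  = exp*2*A*cos
    return y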

d, w0 = 2, 20

# get the analytical solution over the full domain
t = torch.linspace(0,1,500).view(-1,1)
x = oscillator(d, w0, t).view(-1,1)

# slice out a small number of points from the LHS of the domain
t_data = t[0:300:25]
x_data = x[0:300:25]

plt.figure()
plt.plot(t, x, label="Exact solution")
plt.xlabel('time t')
plt.ylabel('position x')
plt.scatter(t_data, x_data, color="tab:orange", label="Training data")
plt.legend()
plt.show()

Our aim is to train a neural network using the orange points, in hopes that it will output the entire blue curve as the trajectory after it has been trained. To do this, we will create a simple 'feed forward' neural network with one input (a time value) and one output (a position value) using PyTorch in the following cell:

In [None]:
class ff_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1,32)
        self.fc2 = nn.Linear(32,32)
        self.fc3 = nn.Linear(32,32)
        self.fc4 = nn.Linear(32,32)
        self.fc5 = nn.Linear(32, 1)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = torch.tanh(self.fc3(x))
        x = torch.tanh(self.fc4(x))
        x = self.fc5(x)
        return x

Now, we will train the model on the first portion of the analytic trajectory using the mean squared error as the loss function and the Adam optimiser algorithm.

In [None]:
torch.manual_seed(123) # This is to ensure the same random numbers are used each run, so the results are reproducible
model = ff_NN()
optimiser = torch.optim.Adam(model.parameters(),lr=1e-3)

for i in range(2500):
    optimiser.zero_grad()
    x_nn = model(t_data)
    loss = torch.mean((x_nn-x_data)**2)
    loss.backward()
    optimiser.step()

    # plot the result as training progresses
    if (i+1) % 100 == 0 or i == 0:
        clear_output(wait=True)
        x_result = model(t).detach()
        plt.figure(figsize=(8,4))
        plt.plot(t,x, color="grey", linewidth=2, alpha=0.8, label="Exact solution")
        plt.plot(t,x_result, color="tab:blue", linewidth=2, alpha=0.8, label="Neural network prediction")
        plt.scatter(t_data, x_data, s=60, color="tab:orange", alpha=0.4, label='Training data')
        l = plt.legend(loc=(1.01,0.34), frameon=False, fontsize="large")
        plt.setp(l.get_texts(), color="k")
        plt.xlim(-0.05, 1.05)
        plt.ylim(-1.1, 1.1)
        plt.text(1.065,0.7,"Training step: %i"%(i+1),fontsize="xx-large",color="k")
        plt.axis("off")
        plt.show()

Clearly, the neural network correctly learned to interpolate the trajectory of the oscillator from the sampled points in the period on which it was trained. However, the neural network's prediction for the trajectory of the oscillator after this period is incorrect and physically nonsensical.


## Introducing PINNs
This is where PINNs come in. If we know some equations which physical reasoning dictates the system should obey, we can include the extent to which these equations are violated in the loss function used during training. Then, the neural network's output should not only be able to interpolate the trajectory in the period on which it was trained, but it should be able to make a physically-motivated guess about the subsequent motion by combining interpolation of training datapoints with the 'physical intuition' which we have granted it.

For a damped harmonic oscillator, the equation of this type that we can impose is a second-order differential equation, which is simply $F=ma$ for this system:

\begin{equation}
m\frac{d^2x}{{dt}^2} +\mu \frac{dx}{dt}+kx=0
\end{equation}

$\mu$ and $k$ are constants that define the motion of the system. We set $m=1$ and treat $\mu$ and $k$ as unknowns, learning their values alongside the weights of the neural network during training on the initial motion of the oscillator:

In [None]:
# Assemble a set of points across the entire period of time that we want to solve the trajectory for
t_physics = torch.linspace(0,1,30).view(-1,1).requires_grad_(True)

# Initialise the model and parameters
torch.manual_seed(123) # This is to ensure the same random numbers are used each run, so the results are reproducible
PINN_model = ff_NN()
mu = torch.tensor([5.0], requires_grad=True)
k = torch.tensor([360.0], requires_grad=True)
PINN_optimiser = torch.optim.Adam(PINN_model.parameters(),lr=5e-4)
params_optimiser = torch.optim.Adam([mu,k],lr=3e-3)

# Run training loop
for i in range(30000):
    PINN_optimiser.zero_grad()
    params_optimiser.zero_grad()

    # compute the "data loss"
    x_nn = PINN_model(t_data)
    loss1 = torch.mean((x_nn-x_data)**2)# use mean squared error

    # compute the "physics loss"
    x_nn_physics = PINN_model(t_physics)
    dx  = torch.autograd.grad(x_nn_physics, t_physics, torch.ones_like(x_nn_physics), create_graph=True)[0]  # computes dx/dt
    dx2 = torch.autograd.grad(dx,  t_physics, torch.ones_like(dx),  create_graph=True)[0]  # computes d^2x/dt^2
    physics = dx2 + mu*dx + k*x_nn_physics  # residual of the damped harmonic oscillator equation (with m = 1)
    loss2 = (1e-4)*torch.mean(physics**2)

    # backpropagate joint loss
    loss = loss1 + loss2# add two loss terms together
    loss.backward()
    PINN_optimiser.step()
    # only start optimising mu and k after the interpolation has begun
    if i > 5000:
        params_optimiser.step()


    # plot the result as training progresses
    if (i+1) % 500 == 0 or i == 0:
        clear_output(wait=True)
        x_results = PINN_model(t).detach()
        plt.figure(figsize=(8,4))
        plt.plot(t.detach().numpy(),x, color="grey", linewidth=2, alpha=0.8, label="Exact solution")
        plt.plot(t.detach().numpy(),x_results, color="tab:blue", linewidth=2, alpha=0.8, label="Neural network prediction")
        plt.scatter(t_data.detach().numpy(), x_data, s=60, color="tab:orange", alpha=0.4, label='Training data')
        l = plt.legend(loc=(1.01,0.34), frameon=False, fontsize="large")
        plt.setp(l.get_texts(), color="k")
        plt.xlim(-0.05, 1.05)
        plt.ylim(-1.1, 1.1)
        plt.text(1.065,0.7,"Training step: %i"%(i+1),fontsize="xx-large",color="k")
        plt.axis("off")
        plt.show()
        print(f'mu={mu.item()}, k={k.item()}')

Clearly, the PINN framework is much better at predicting the trajectory of the oscillator. Further, the correct values of the parameters in the model were successfully learned to within about a 1% error!
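
For reference, the values that $\mu$ and $k$ should converge to can be worked out from the oscillator cell: with $m=1$, the damping rate `d` and frequency `w0` correspond to $\mu = 2d$ and $k = w_0^2$. The sketch below is a plain-Python sanity check of that statement, not part of the original tutorial code; the helper names `x_exact` and `residual` are illustrative, and the derivatives are approximated by finite differences.

```python
import math

d, w0 = 2.0, 20.0          # same values as in the oscillator cell
mu_true = 2 * d            # damping coefficient when m = 1
k_true = w0 ** 2           # spring constant when m = 1

def x_exact(t):
    """Analytic under-damped solution used to generate the training data."""
    w = math.sqrt(w0 ** 2 - d ** 2)
    phi = math.atan(-d / w)
    A = 1 / (2 * math.cos(phi))
    return math.exp(-d * t) * 2 * A * math.cos(phi + w * t)

def residual(t, h=1e-4):
    """x'' + mu*x' + k*x, with derivatives from central finite differences."""
    x0, xp, xm = x_exact(t), x_exact(t + h), x_exact(t - h)
    dx = (xp - xm) / (2 * h)
    d2x = (xp - 2 * x0 + xm) / h ** 2
    return d2x + mu_true * dx + k_true * x0

# The residual should be close to zero everywhere on the domain
worst = max(abs(residual(0.05 * i)) for i in range(1, 20))
print(mu_true, k_true, worst)
```

The residual is tiny compared with the size of the individual terms (of order $k\,x \sim 400$), which is what the physics loss pushes the network towards.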

All of this is promising - we want to make predictions about the dynamics of an epidemic and we have just seen that PINNs can combine information from past data with models for how we expect the system to react to get good predictions.

## Predicting COVID-19 Trends

Before we jump into applying the PINN framework to the COVID-19 prediction problem, let's first see how well a simple neural network trained to fit the data does in predicting trends.

The following cell imports US national statistics on the COVID-19 pandemic from a [New York Times dataset](https://github.com/nytimes/covid-19-data/blob/master/us.csv):

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

url = 'https://raw.githubusercontent.com/nytimes/covid-19-data/master/us.csv'
# Load the CSV file into a DataFrame
df = pd.read_csv(url)

# Cumulative cases minus cumulative deaths
df['Currently Infected'] = df['cases'] - df['deaths']
# Change over the previous ten days, used as the estimate of the number currently infected
df['Infected'] = df['Currently Infected'].diff(periods=10)
# Recovered count: computed here with the same ten-day difference; this column is not used below
df['Recovered'] = df['Currently Infected'].diff(periods=10)

x = np.array(df['Infected'][60:300]).astype(np.float32)/1e6
t = np.arange(len(df))[60:300]
t = (t/239).astype(np.float32)

x = torch.from_numpy(x).view(-1,1)
t = torch.from_numpy(t).view(-1,1)

# slice out a small number of points from the LHS of the domain
t_data = t[0:110:5]
x_data = x[0:110:5]

plt.figure()
plt.plot(t, x, label="Exact solution")
plt.scatter(t_data, x_data, color="tab:orange", label="Training data")
plt.ylabel('Number of people with COVID (Millions)')
plt.legend()
plt.show()
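
The `diff(periods=10)` call in the cell above subtracts the value ten rows (days) earlier from each entry, so applied to a cumulative count it gives the number added over the previous ten days. Using that as the number currently infected amounts to assuming an infectious period of about ten days; this reading of the preprocessing is an interpretation rather than something the dataset states. A plain-Python sketch of the same operation on made-up numbers (the helper `ten_day_difference` is hypothetical, and `None` stands in for the `NaN` that pandas produces):

```python
def ten_day_difference(cumulative, periods=10):
    """Mimic diff(periods=10) on a plain list: x[i] - x[i - periods]."""
    return [None] * periods + [
        cumulative[i] - cumulative[i - periods]
        for i in range(periods, len(cumulative))
    ]

# Hypothetical cumulative counts growing by 5 per day
cumulative = [5 * day for day in range(15)]
result = ten_day_difference(cumulative)
print(result)
```

The first ten entries are undefined because there is no value ten days earlier to subtract, which is one reason the cell above only uses rows from index 60 onwards.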

Again, our aim will be to train a simple feed-forward neural network with tanh as the activation function on the orange datapoints. Then, we will ask the neural network for a prediction across the entire range of time plotted above.

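Set up a new neural network class like the ```ff_NN()``` class defined above, but with 64 nodes per hidden layer. Keep the same number of layers and the tanh activations so that the automatic check can compare your model with the reference.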

In [None]:
%%check_covid_nn # Leave this here - it will give you feedback on your code when you run the cell

class COVID_NN(nn.Module):
    # Put your code here:

Now, initialise the model, define the optimiser and write the code for the training loop. Use the Adam optimiser with a learning rate of 0.01.

In [None]:
%%check_model_and_optimiser_cell # Leave this here - it will give you feedback on your code when you run the cell

torch.manual_seed(123) # This is to ensure the same random numbers are used each run, so the results are reproducible
COVID_model =  # Initialise the model
COVID_optimiser =  # Define the optimiser

In [None]:
# Training loop
for i in range(5000):
    # Zero the gradient stored in the optimiser
    x_nn = # Call the neural network on t_data
    loss =  # Use the mean squared error as the loss
    # Get gradient of loss
    # Take a step with the optimiser

    # plot the result as training progresses
    if (i+1) % 1000 == 0 or i == 0:
        clear_output(wait=True)
        x_results = COVID_model(t).detach()
        plt.figure(figsize=(6,4))
        plt.plot(t.detach().numpy(),x, color="grey", linewidth=2, alpha=0.8, label="Exact solution")
        plt.plot(t.detach().numpy(),x_results, color="tab:blue", linewidth=2, alpha=0.8, label="Neural network prediction")
        plt.scatter(t_data.detach().numpy(), x_data, s=60, color="tab:orange", alpha=0.4, label='Training data')
        l = plt.legend(loc=(1.01,0.34), frameon=False, fontsize="large")
        plt.setp(l.get_texts(), color="k")
        plt.xlim(0.2, 1.3)
        plt.ylim(0, 1.4)
        plt.text(1.3,1.0,"Training step: %i"%(i+1),fontsize="xx-large",color="k")
        plt.axis("off")
        plt.show()

Just like in the case of the oscillator, this naive approach is successful in interpolating the behaviour between the training datapoints, but is not of any real value for making predictions about the future.

To apply a PINN, we need an analogue of the $F=ma$ differential equation for this system. There are plenty of models for epidemiology that can supply us with differential equations to use for this purpose - for this reason, several academic papers applying different epidemiological models in this way have been written. We will use the simplest model, which is known as the SIR model.

SIR stands for Susceptible, Infected, Removed - those are the three classes that we split the population up into in this model. The susceptible group is the set of all people who haven't had the virus yet, the infected group is the set of people who are currently infected by the virus so are transmitters of it, and the removed group is the set of all people who can no longer be infected or transmit the virus (i.e. the set of people who recovered from or died due to the virus). We define the number of people in each group as $S(t)$, $I(t)$, and $R(t)$. Assuming a very large population (a valid assumption for our dataset), the population is approximately constant, so $S(t)+I(t)+R(t)=N$.

Taking $\gamma$ as the rate of recovery or death (about 0.1 days$^{-1}$ in the case of COVID-19) and $\beta$ as the transmission rate, the movement of people from one group to another can be described by the following differential equations:

$$\frac{dS}{dt}=-\frac{\beta S I}{N}$$
$$\frac{dI}{dt}=\frac{\beta S I}{N} - \gamma I$$
$$\frac{dR}{dt}= \gamma I$$

The second equation simply says that the rate of change of the number of infected people is the rate of new infections (proportional to the number of infected people times the fraction of the population that is susceptible) minus the rate of removal through recovery or death. We are only learning and predicting $I(t)$, so we would like to capture this information in a differential equation in terms of $I(t)$ only. We can do this by using the constant-population condition and the last differential equation to eliminate $S(t)$ and $R(t)$.
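
To see the kind of behaviour these three equations produce, here is a minimal forward-Euler sketch in plain Python. It is an illustration only: the function `simulate_sir`, the step size, and the parameter values (in particular $\beta = 0.3$ days$^{-1}$ and the initial 100 infections) are assumptions chosen for the example, not values taken from the data.

```python
def simulate_sir(beta, gamma, N, I0, days, dt=0.1):
    """Integrate the SIR equations with forward Euler; returns daily (S, I, R)."""
    S, I, R = N - I0, I0, 0.0
    steps_per_day = round(1 / dt)
    history = [(S, I, R)]
    for _ in range(days):
        for _ in range(steps_per_day):
            new_infections = beta * S * I / N * dt   # people moving S -> I
            new_removals = gamma * I * dt            # people moving I -> R
            S -= new_infections
            I += new_infections - new_removals
            R += new_removals
        history.append((S, I, R))
    return history

# Illustrative parameters only: beta = 0.3 per day is a hypothetical value
history = simulate_sir(beta=0.3, gamma=0.1, N=1_000_000, I0=100, days=200)
peak_day = max(range(len(history)), key=lambda day: history[day][1])
print(peak_day, history[peak_day][1])
```

With a constant $\beta$ the model produces a single rise and fall in $I(t)$, and $S+I+R$ stays equal to $N$ throughout because every person leaving one group enters another.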

Using $S(t)+I(t)+R(t)=N$ and integrating $\frac{dR}{dt}= \gamma I$ from $R(0)=0$, we get
$$S(t) = N- I(t)-\gamma\int\limits_{0}^{t}I(t')dt'$$
Substituting this into the second differential equation, we get a condition for the system to obey if it satisfies the SIR model:
$$ \frac{\beta I(t)}{N}\left(N-I(t)-\gamma\int\limits_{0}^{t}I(t')dt'\right) - \gamma I(t) - \frac{dI}{dt}=0$$

Using the PINN framework, the mean square of the left-hand side of this expression, evaluated at a set of points in time, will be included in our loss function.
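
As a hedged sanity check of this condition, the sketch below generates $I(t)$ from the full SIR system and then evaluates the left-hand side using only the $I$ values. Everything here is illustrative: the function `sir_residual`, the parameter values, the trapezoid rule for the integral and the central difference for $dI/dt$ are assumptions made for the example, and the PINN itself uses automatic differentiation instead.

```python
beta, gamma, N = 0.3, 0.1, 1_000_000.0   # hypothetical illustrative values
dt, n_steps = 0.01, 10_000               # 100 days in steps of 0.01 days

# Generate I(t) by integrating the full SIR system with forward Euler
S, I = N - 100.0, 100.0
I_series = [I]
for _ in range(n_steps):
    dS = -beta * S * I / N
    dI = beta * S * I / N - gamma * I
    S, I = S + dS * dt, I + dI * dt
    I_series.append(I)

def sir_residual(I_series, dt, beta, gamma, N):
    """Left-hand side of the I-only condition at interior points.

    The integral is a running trapezoid sum and dI/dt a central difference.
    Assumes R(0) = 0, as in the derivation.
    """
    residuals = []
    integral = 0.0
    for j in range(1, len(I_series) - 1):
        integral += 0.5 * (I_series[j - 1] + I_series[j]) * dt
        S_est = N - I_series[j] - gamma * integral
        dI_dt = (I_series[j + 1] - I_series[j - 1]) / (2 * dt)
        residuals.append(beta * I_series[j] * S_est / N - gamma * I_series[j] - dI_dt)
    return residuals

res = sir_residual(I_series, dt, beta, gamma, N)
scale = gamma * max(I_series)            # typical size of one term in the equation
print(max(abs(r) for r in res) / scale)
```

The residual is small relative to the individual terms when the correct $\beta$ is used, and becomes comparable to them if a wrong $\beta$ is passed in, which is what allows the training process to learn $\beta$ from the data.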

The integral for $R$ can be approximated from our data by simply adding together the entries in the array of $I$ values from ten, twenty, thirty, ... days earlier; each entry stands in for ten days of removals, since $\gamma \times 10\text{ days} = 1$. This means that for the first ten days, $R(t)=0$, on day eleven $R=I(\text{day 1})$, and in general on day $n$ where $n>10$,
$$R(t)=\sum\limits_{i=1}^{\text{floor}(n/10)}I(\text{day }n-10i)$$
Write a function that does this below, taking every fifth day's $I(t)$ value as an input (so ten days corresponds to a shift of two entries) and outputting a PyTorch tensor that contains the value of $R(t)$ on the same days. The process must support backpropagation for gradient descent when optimising the neural network, so you can't write into slices of a tensor in place, but you can get around this by multiplying the tensor that you want to slice by a tensor of ones and zeros. You may also find the ```torch.roll()``` function useful here.

(This is quite a tricky one, so in case you get stuck the answer is in the hidden code in the cell that follows.)
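
To make the target values concrete, here is a plain-Python illustration of the sum on a made-up list of samples. It is deliberately not a solution to the exercise: it uses ordinary lists and loops, so it cannot be backpropagated through, and the helper name `removed_from_samples` is hypothetical.

```python
def removed_from_samples(I_samples, stride=2):
    """R at each sample: the sum of the I values 1, 2, 3, ... strides earlier.

    With one sample every five days, stride=2 corresponds to ten days.
    """
    R = []
    for j in range(len(I_samples)):
        earlier = range(j - stride, -1, -stride)  # j-2, j-4, ... while >= 0
        R.append(sum((I_samples[i] for i in earlier), 0.0))
    return R

print(removed_from_samples([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))
# → [0.0, 0.0, 1.0, 2.0, 4.0, 6.0]
```

The first two entries are zero because no sample lies ten days before them, and each later entry adds the samples two, four, ... positions earlier.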

In [None]:
%%check_get_R_cell # Leave this here - it will give you feedback on your code when you run the cell
def get_R(I):
    R = torch.zeros_like(I,requires_grad=True)
    # Start coding here:

    return R

In [None]:
# @title
# My solution is:

#def get_R(I):
#    R = torch.zeros_like(I,requires_grad=True)
#    end = np.floor(len(R)/2).astype(int)
#    for count in range(1,end):
#        ones = torch.zeros_like(I)
#        ones[:len(R)-2*count] = 1.0
#        R = R + torch.roll(I*ones, 2*count)
#    return R

Now, we will use the function you just defined and the PINN framework to make predictions on the dynamics of COVID-19. Fill in the required fields to set up the training process and run the training loop. Use the Adam optimiser for both sets of parameters, with a learning rate of 0.0001 for the neural network and 0.0005 for the values of beta at each point of time.
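
The setup cell below uses `g = 23.9` for $\gamma$ rather than 0.1. A short derivation of where this comes from, assuming the rescaled time $\tau = t/(239\text{ days})$ that was introduced when the data were loaded (the symbols $\tau$ and $\gamma_\tau$ are notation introduced here for illustration):

$$\tau = \frac{t}{239\text{ days}} \quad\Rightarrow\quad \frac{dI}{d\tau} = 239\text{ days}\times\frac{dI}{dt}, \qquad \gamma_\tau = 239\text{ days}\times 0.1\text{ days}^{-1} = 23.9$$

The same factor multiplies $\beta$, so the values of `beta` learned during training are also expressed per unit of rescaled time rather than per day.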

In [None]:
%%check_PINN_COVID_setup_cell # Leave this here - it will give you feedback on your code when you run the cell

t_physics = t[::5].requires_grad_(True) # These are the points in time where we will evaluate the physics loss
g = 23.9 # This is the value of gamma that you get by rescaling 1/(10 days) with 239 days = 1 unit

torch.manual_seed(111) # This is to ensure the same random numbers are used each run, so the results are reproducible
PINN_COVID_model =
optimiser =
beta = torch.ones_like(t_physics, requires_grad=True)
beta_optimiser =

In [None]:
# Training loop:
for i in range(9500):
    # Zero the gradient stored in the optimiser

    # compute the "data loss"
    x_nn = # Pass t_data to the model
    loss1 = # Use mean squared error

    # compute the "physics loss"
    x_nn_physics = # Pass t_physics to the model
    dx = # Compute dI/dt, the derivative of x_nn_physics with respect to t_physics
    physics = # Compute the residual of the equation
    loss2 = (1e-4)*torch.mean(physics**2)

    # backpropagate the joint loss - no need to wait for a certain number of steps before starting to optimise beta
    loss = # add the two loss terms together
    # Get gradient of loss
    # Take a step with the neural network optimiser
    # Take a step with the beta optimiser

    # plot the result as training progresses
    if (i+1) % 500 == 0:
        clear_output(wait=True)
        x_results = PINN_COVID_model(t).detach()
        plt.figure(figsize=(6,4))
        plt.plot(t.detach().numpy(),x, color="grey", linewidth=2, alpha=0.8, label="Exact solution")
        plt.plot(t.detach().numpy(),x_results, color="tab:blue", linewidth=2, alpha=0.8, label="Neural network prediction")
        plt.scatter(t_data.detach().numpy(), x_data, s=60, color="tab:orange", alpha=0.4, label='Training data')
        l = plt.legend(loc=(1.01,0.34), frameon=False, fontsize="large")
        plt.setp(l.get_texts(), color="k")
        plt.xlim(0.2, 1.3)
        plt.ylim(0, 1.4)
        plt.text(1.3,1,"Training step: %i"%(i+1),fontsize="xx-large",color="k")
        plt.axis("off")
        plt.show()

Although the PINN clearly didn't manage to capture the spike at the end of the time period we're looking at, it was much more successful than the 'physics-less' neural network in predicting that there would be a peak and subsequent fall-off after the end of the training period.

We trained the model for long enough for it to fit the data, and had we run it for longer, it's possible that the output would change and no longer be quite so correct. This may seem like we're cheating and getting the correct answer because we know what it is. In a sense it's true that knowing the correct prediction is very helpful in knowing when to stop training, but this is a general flaw of using neural networks for this type of predictive task. The question of when to stop training when working with neural networks is more of an art than an exact science; train too long and you will 'overfit' the data (i.e. fit the training data really well, but not generalise well beyond the training set), but train too little and you risk losing important information hidden in the training set.

The problem would be much simpler to solve if we weren't restricted to data on just one COVID-19 epidemic. If we had data on thousands of different epidemics, we could abandon the PINN idea and completely change the architecture of the neural network, so that its job is to predict the cases during the second half of the time period based on knowledge of the first half of the period. In this case, overfitting would be much easier to monitor without ever checking the output of the model for the epidemic that we want to find out information about.

Using more detailed models of the epidemic dynamics such as the SAIRD model (which differentiates between symptomatic and asymptomatic, recovered and dead) could help make more accurate, longer-term predictions with our PINN - although we did adjust the SIR model by giving $\beta$ the ability to vary with time, it is still quite a crude model of a rather complex system. With that said, there is a chance it could be improved simply by making the time dependence of $\beta$ better informed, for example by including information on current affairs such as changes in public health policy in the dataset (using newspaper headlines perhaps).

As mentioned earlier, there are several academic papers and preprints out there on the topic of using PINNs to retrospectively predict COVID-19 spikes and dips - here are some links if you feel like having a read or trying to recreate their results:

[Approaching epidemiological dynamics of COVID-19 with physics-informed neural networks](https://arxiv.org/abs/2302.08796)

[PINN Training using Biobjective Optimization: The Trade-off between Data Loss and Residual Loss](https://arxiv.org/abs/2302.01810) - [code repository](https://git.uni-wuppertal.de/heldmann/covid-prediction-pinn)