In [1]:
import numpy as np
import pandas as pd
from scipy.stats import truncnorm

In [7]:
# --- SIR model configuration ---------------------------------------------
# contact_rate scales the force of infection (contact_rate * I / N) and
# recovery_rate sets the per-step recovery probability in update_particles.
contact_rate, recovery_rate = 0.3, 0.1
total_pop = 1000

# Starting compartment sizes: everyone not initially infected is susceptible.
inf_init = 12
rec_init = 0
suscept_init = total_pop - inf_init

# Observed infectious counts, one per filtering step (scored against the
# infectious column of each particle in the main loop).
observations = [12, 157, 982, 532, 89]

# Size of the particle ensemble.
n_particles = 100

# Updating particles
def update_particles(particles, contact_rate, recovery_rate, total_pop):
    """Advance every particle one step through stochastic SIR dynamics.

    Each row of ``particles`` is one particle holding the compartment sizes
    ``[susceptible, infectious, recovered]``. New infections and new
    recoveries are drawn as binomials over the compartment sizes at the
    start of the step, so both transitions see the same starting state.

    Args:
        particles: 2-D array, one row per particle, columns (S, I, R).
            Counts are truncated to ``int`` before binomial sampling, so
            they are expected to be whole numbers.
        contact_rate: A susceptible becomes infected with probability
            ``1 - exp(-contact_rate * I / total_pop)``.
        recovery_rate: An infectious individual recovers with probability
            ``1 - exp(-recovery_rate)``.
        total_pop: Population size used to scale the force of infection.

    Returns:
        A new array (same dtype as the input) with the propagated states.
        The input array is not modified.
    """
    # Work on a copy: the column slices below are views, so updating them in
    # place would otherwise overwrite the caller's array.
    particles = np.array(particles, copy=True)
    suscept = particles[:, 0]
    infect = particles[:, 1]
    recovered = particles[:, 2]

    # Both probabilities use the pre-step infectious count.
    p_infect = 1.0 - np.exp(-contact_rate * infect / total_pop)
    p_recover = 1.0 - np.exp(-recovery_rate)
    new_infections = np.random.binomial(suscept.astype(int), p_infect)
    new_recoveries = np.random.binomial(infect.astype(int), p_recover)

    # These in-place updates write through the views into the copied array;
    # the population total of each row is conserved.
    suscept -= new_infections
    infect += new_infections - new_recoveries
    recovered += new_recoveries
    return particles

# Calculation of importance weights
# Standard deviation of the observation noise around a particle's value.
target_sd = 50.0


def get_importance(p_vals, mean, sd=None):
    """Return the likelihood of the observation ``mean`` under each particle.

    The observation model is a normal distribution centred on each
    particle's value ``p_vals[i]`` with standard deviation ``sd``,
    truncated to ``[0, inf)``.

    Args:
        p_vals: 1-D array of per-particle values (the main loop passes the
            infectious counts).
        mean: The observed value being scored. Despite the name, it is the
            point at which each density is evaluated, not a distribution mean.
        sd: Observation-noise standard deviation. When ``None`` (default),
            the module-level ``target_sd`` is used, looked up at call time.

    Returns:
        Array of unnormalised importance weights, one per entry of ``p_vals``.
    """
    if sd is None:
        sd = target_sd
    p_vals = np.asarray(p_vals, dtype=float)
    # truncnorm takes its bounds in standardised units: the lower limit 0
    # becomes (0 - loc) / scale for each particle.
    lower = -p_vals / sd
    # A scalar observation broadcasts against p_vals, so the result length
    # follows p_vals rather than the global particle count.
    return truncnorm.pdf(mean, lower, np.inf, loc=p_vals, scale=sd)

# Initialise particles: every row starts from the same (S, I, R) state,
# stored as float64 so the filter can update it in place.
old_particles = np.tile(
    np.array([suscept_init, inf_init, rec_init], dtype=float),
    (n_particles, 1),
)

# Main loop: bootstrap particle filter (predict -> weight -> resample -> estimate).
for obs in observations:

    # Prediction: propagate every particle one step through the SIR dynamics.
    new_particles = update_particles(old_particles, contact_rate, recovery_rate, total_pop)

    # Importance: score each particle's infectious count against the observation.
    weights = get_importance(new_particles[:, 1], obs)
    total_weight = np.sum(weights)
    if np.isfinite(total_weight) and total_weight > 0:
        norm_weights = weights / total_weight
    else:
        # Every weight underflowed to zero (or is invalid): the observation is
        # too far from all particles. Normalising would produce NaNs and make
        # np.random.choice raise, so fall back to uniform weights instead.
        print(f"Warning: degenerate weights for observation {obs}; using uniform weights")
        norm_weights = np.full(n_particles, 1.0 / n_particles)

    # Resampling: multinomial draw of particle indices proportional to weight.
    indices = np.random.choice(n_particles, size=n_particles, p=norm_weights)
    new_particles = new_particles[indices]

    # State estimation: mean of each compartment over the resampled particles.
    suscept_est, infect_est, rec_est = new_particles.mean(axis=0)

    print(f"Observation: {obs}")
    print(f"Estimated Susceptible: {suscept_est:.2f}")
    print(f"Estimated Infectious: {infect_est:.2f}")
    print(f"Estimated Recovered: {rec_est:.2f}\n")

    old_particles = new_particles

Observation: 12
Estimated Susceptible: 984.77
Estimated Infectious: 13.94
Estimated Recovered: 1.29

Observation: 157
Estimated Susceptible: 980.15
Estimated Infectious: 17.47
Estimated Recovered: 2.38

Observation: 982
Estimated Susceptible: 960.85
Estimated Infectious: 34.78
Estimated Recovered: 4.37

Observation: 532
Estimated Susceptible: 944.44
Estimated Infectious: 47.23
Estimated Recovered: 8.33

Observation: 89
Estimated Susceptible: 930.82
Estimated Infectious: 55.88
Estimated Recovered: 13.30

