# Bayesian SIR Model with Change Points

In [1]:
import numpy as np
from scipy.stats import binom, gamma, beta, expon, poisson, uniform, bernoulli
from joblib import Parallel, delayed
from scipy.special import gamma as gammaFunc
import random

import pandas as pd

import matplotlib.pyplot as plt

import tqdm

import itertools

## Data Simulation

In [2]:
def sim_dataset(chg_pt, scenarios, T, S0, I0, R0, n_datasets):

    N = S0 + I0 + R0
    n_sc = scenarios.shape[0]

    # create array of transmission and removal rate parameters at each time step
    beta  = np.array([[scenarios[i,(chg_pt <= t+1).sum(),0] for t in range(T)] for i in range(n_sc)])
    gamma = np.array([[scenarios[i,(chg_pt <= t+1).sum(),1] for t in range(T)] for i in range(n_sc)])


    Delta_I = np.zeros(shape=(n_datasets, n_sc, T), dtype=np.int32)
    Delta_R = np.zeros(shape=(n_datasets, n_sc, T), dtype=np.int32)
    S       = np.zeros(shape=(n_datasets, n_sc, T), dtype=np.int32)
    I       = np.zeros(shape=(n_datasets, n_sc, T), dtype=np.int32)
    R       = np.zeros(shape=(n_datasets, n_sc, T), dtype=np.int32)

    Delta_I[:,:,0] = binom.rvs(S0, 1-np.exp(-beta[:,0]*I0/N), size=(n_datasets, n_sc))
    Delta_R[:,:,0] = binom.rvs(I0, gamma[:,0], size=(n_datasets, n_sc))
    S[:,:,0]       = S0 - Delta_I[:,:,0]
    I[:,:,0]       = I0 + Delta_I[:,:,0] - Delta_R[:,:,0]
    R[:,:,0]       = R0 + Delta_R[:,:,0]

    for t in range(1, T):
        Delta_I[:,:,t] = binom.rvs(S[:,:,t-1], 1-np.exp(-beta[:,t]*I[:,:,t-1]/N))
        Delta_R[:,:,t] = binom.rvs(I[:,:,t-1], gamma[:,t])
        S[:,:,t]       = S[:,:,t-1] - Delta_I[:,:,t]
        I[:,:,t]       = I[:,:,t-1] + Delta_I[:,:,t] - Delta_R[:,:,t]
        R[:,:,t]       = R[:,:,t-1] + Delta_R[:,:,t]

    return Delta_I, Delta_R, S, I, R

In [3]:
def plot_SRI(S, I, R, sc=0, d=None, start_cond=(999_950, 50, 0), tot=1_000_000, time=100):

    if d is None:
        S = np.expand_dims(np.mean(S, axis=0), 0)
        I = np.expand_dims(np.mean(I, axis=0), 0)
        R = np.expand_dims(np.mean(R, axis=0), 0)
        d = 0

    S = np.concatenate([[start_cond[0]], S[d,sc]])
    I = np.concatenate([[start_cond[1]], I[d,sc]])
    R = np.concatenate([[start_cond[2]], R[d,sc]])

    # plot
    fig, ax = plt.subplots()

    y = np.vstack([S, I, R])
    ax.stackplot(np.arange(time+1), y/tot, labels=["S","I","R"], alpha=0.8)

    ax.set_xlabel("Day")
    ax.set_ylabel("Proportion")
    ax.set_xticks(np.concatenate([[0], np.arange(25, time+1, 25)]))
    ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1])
    plt.legend(loc="upper right")

    plt.show()

### Example

In [4]:
T = 100
N = 1_000_000

chg_pt = np.array([26, 51, 76])

sc_1 = [(0.3, 0.05), (0.4, 0.15), (0.25, 0.2),  (0.2,  0.25)]
sc_2 = [(0.4, 0.1),  (0.4, 0.25), (0.25, 0.25), (0.25, 0.4) ]
sc_3 = [(0.5, 0.1),  (0.3, 0.3),  (0.4,  0.2),  (0.2,  0.4) ]
scenarios = np.array([sc_1, sc_2, sc_3])

S0 = N-50
I0 = 50
R0 = 0

n_datasets = 100

In [5]:
Delta_I, Delta_R, S, I, R =  sim_dataset(chg_pt, scenarios, T, S0, I0, R0, n_datasets)
#for i in range(len(scenarios)):
 #   print(f"Scenario {i+1}:")
#    plot_SRI(S, I, R, sc=i)

## Gibbs Sampling

In [6]:
S_obs = np.concatenate(([S0],S[0,0]))
I_obs = np.concatenate(([I0],I[0,0]))
R_obs = np.concatenate(([R0],R[0,0]))
PI_obs = I_obs/N
N_infect_obs = np.concatenate(([I0],Delta_I[0,0]))
N_recovery_obs = np.concatenate(([R0],Delta_R[0,0]))

data = pd.DataFrame({
        'susceptible': S_obs,
        'infects': I_obs,
        'recovered': R_obs,
        'PI': PI_obs,
        'deltaI': N_infect_obs,
        'deltaR': N_recovery_obs
    })

In [7]:
data

Unnamed: 0,susceptible,infects,recovered,PI,deltaI,deltaR
0,999950,50,0,0.000050,50,0
1,999930,68,2,0.000068,20,2
2,999913,81,6,0.000081,17,4
3,999882,107,11,0.000107,31,5
4,999848,137,15,0.000137,34,4
...,...,...,...,...,...,...
96,156618,22,843360,0.000022,0,6
97,156617,17,843366,0.000017,1,6
98,156615,14,843371,0.000014,2,5
99,156615,11,843374,0.000011,0,3


In [8]:
def calculate_contact_hat_parallel(pp_lambda_t):
    pp, lambda_t = pp_lambda_t
    return np.sum(poisson.ppf(pp, lambda_t))

def gibbs_sampling(data, samples=10000, T_max=100, burnin=1000, thinning=10):
    
    ############# data
    # Data is expected to be a dataframe with 101 rows (100 steps + initial one)
    # OBS: the initial row should display the starting values 
    #      for Susceptible (S), Infected (I), and Recovered (R), with zero values in the 
    #      columns representing changes (deltas) over time.
    
    I_obs = data["infects"].values
    S_obs = data["susceptible"].values
    PI_obs = data["PI"].values
    N_infect_obs = data["deltaI"].values[1:]
    N_recovery_obs = data["deltaR"].values[1:]
    

    #######################
    ##  HYPERPARAMETERS  ##
    #######################
    
    p_a = 1/T_max**4
    p_b = 2 - p_a
    
    b_shape = 0.1
    b_rate = 0.1
    r_shape = 0.1
    r_rate = 0.1
    
    gamma_b_shape = gammaFunc(b_shape)
    gamma_r_shape = gammaFunc(r_shape)
    

    ######################
    ##  INITIALIZATION  ##
    ######################
    
    ##### delta
    Delta_hat = np.zeros(shape=T_max, dtype=int)
    Delta_hat[0] = 1
    Stage_hat = np.cumsum(Delta_hat, dtype=int)-1
    K_hat     = np.sum(Delta_hat, dtype=int)

    ##### b and r
    b_hat    = gamma.rvs(a=b_shape, scale=1/b_rate, size=K_hat)
    r_hat    = gamma.rvs(a=r_shape, scale=1/r_rate, size=K_hat)
    
    ##### beta and gamma
    beta_hat    = expon.rvs(scale=1/b_hat[Stage_hat[0]], size=T_max)
    gamma_hat   = beta.rvs(a=r_hat[Stage_hat[0]], b=1, size=T_max)
    lambda_t    = beta_hat*PI_obs[:-1]
    p_upper     = 1 - poisson.cdf(0, lambda_t)
    pp          = [uniform.rvs(size=N_infect_obs[t]) * p_upper[t] + (1 - p_upper[t]) for t in range(T_max)]
    contact_hat = np.array(Parallel(n_jobs=-1)(delayed(calculate_contact_hat_parallel)((pp[t], lambda_t[t],)) for t in range(T_max)))
    
    
    beta_hat    = gamma.rvs(a=contact_hat + 1, scale=1/(b_hat[Stage_hat] + PI_obs[:-1]*S_obs[:-1]))
    gamma_hat   = beta.rvs(a=N_recovery_obs + r_hat[Stage_hat], b=1 + I_obs[:-1] - N_recovery_obs)
    lambda_t    = beta_hat * PI_obs[:-1]
    p_upper     = 1 - poisson.cdf(0, lambda_t)
    pp          = [uniform.rvs(size=N_infect_obs[t]) * p_upper[t] + (1 - p_upper[t]) for t in range(T_max)]
    contact_hat = np.array(Parallel(n_jobs=-1)(delayed(calculate_contact_hat_parallel)((pp[t], lambda_t[t],)) for t in range(T_max)))
    
    
    print("Initialization:\n")
    print("Delta_hat:",Delta_hat)
    print("b_hat:",b_hat)
    print("r_hat:",r_hat)
    print("beta_hat:",beta_hat)
    print("gamma_hat:",gamma_hat)
    print("----------------------------------------------------")
    print("----------------------------------------------------")
    

    ################
    ##  SAMPLING  ##
    ################
    
    Delta_all = []
    Stage_all = []
    b_all     = []
    r_all     = []
    beta_all  = []
    gamma_all = []
    
    for step in tqdm.tqdm(range(samples)):
        
        ###### delta_hat sampling
        ## add–delete–swap - proposal step
        # -1 delete
        # 0 swap
        # +1 add
                
        if K_hat==1:
            change_type = 1
        
        elif K_hat==T_max:
            change_type = -1

        else:
            change_type = np.random.choice([-1, 0, 1])
            
        Delta_hat_candidate = Delta_hat.copy()
        print(change_type)
        if change_type != 0:
            
            if change_type == 1:
                possible_change_indices = np.where(Delta_hat[1:] == 0)[0]+1
            elif change_type == -1:
                possible_change_indices = np.where(Delta_hat[1:] == 1)[0]+1
            index_to_change = np.random.choice(possible_change_indices)

            Delta_hat_candidate[index_to_change] = 1 - Delta_hat_candidate[index_to_change]
            Stage_hat_candidate = np.cumsum(Delta_hat_candidate)-1

            if change_type == 1:
                possible_change_indices_candidate = np.where(Delta_hat_candidate[1:] == 1)[0] + 1
                phase_original  = np.array([Stage_hat[index_to_change]])
                phase_candidate = np.array([Stage_hat[index_to_change], Stage_hat_candidate[index_to_change]])
            elif change_type == -1:
                possible_change_indices_candidate = np.where(Delta_hat_candidate[1:] == 0)[0] + 1
                phase_original  = np.array([Stage_hat[index_to_change], Stage_hat_candidate[index_to_change]])
                phase_candidate = np.array([Stage_hat_candidate[index_to_change]])
        
            logp_original = 0
            logp_candidate = 0

            for i in phase_original:
                L_i_original = np.where(Stage_hat == i)[0]

                logp_original += np.sum(np.log(np.arange(1, len(L_i_original)) - 1 + b_shape)) \
                               - np.log(gamma_b_shape) + b_shape * np.log(b_rate) \
                               - (b_shape + len(L_i_original)) * np.log(b_rate + np.sum(beta_hat[L_i_original]))
                logp_original += np.sum(np.log(np.arange(1, len(L_i_original)) - 1 + r_shape)) \
                               - np.log(gamma_r_shape) + r_shape * np.log(r_rate) \
                               - (r_shape + len(L_i_original)) * np.log(r_rate + np.sum(-np.log(gamma_hat[L_i_original])))
                logp_original -= np.sum(np.log(np.arange(1, len(L_i_original))))

            for i in phase_candidate:
                L_i_candidate = np.where(Stage_hat_candidate == i)[0]

                logp_candidate += np.sum(np.log(np.arange(1, len(L_i_candidate)) - 1 + b_shape)) \
                                - np.log(gamma_b_shape) + b_shape * np.log(b_rate) \
                                - (b_shape + len(L_i_candidate)) * np.log(b_rate + np.sum(beta_hat[L_i_candidate]))
                logp_candidate += np.sum(np.log(np.arange(1, len(L_i_candidate)) - 1 + r_shape)) \
                                - np.log(gamma_r_shape) + r_shape * np.log(r_rate) \
                                - (r_shape + len(L_i_candidate)) * np.log(r_rate + np.sum(-np.log(gamma_hat[L_i_candidate])))
                logp_candidate -= np.sum(np.log(np.arange(1, len(L_i_candidate))))

            logp_candidate += np.log(p_a / p_b) + np.log((3 - 2 * (K_hat == 1)) \
                            * len(possible_change_indices) / (3 - 2 * ((K_hat + 1) == T_max)) / len(possible_change_indices_candidate))

            ratio = np.exp(min([0, logp_candidate - logp_original]))
            print(np.exp(logp_candidate - logp_original))
        elif change_type == 0:
            
            possible_change_indices = np.where(np.abs(Delta_hat[1:-1] - Delta_hat[2:]) == 1)[0]+1
            index_to_change = np.random.choice(possible_change_indices, 1)

            Delta_hat_candidate[index_to_change + np.array([0, 1])] = Delta_hat_candidate[index_to_change + np.array([1, 0])]
            Stage_hat_candidate = np.cumsum(Delta_hat_candidate)-1
   
            possible_change_indices_candidate = np.where(np.abs(Delta_hat_candidate[1:-1] - Delta_hat_candidate[2:]) == 1)[0]+1
        
            phase = np.array([Stage_hat[index_to_change], Stage_hat_candidate[index_to_change]])
    
            logp_original = 0
            logp_candidate = 0
        
            for i in phase:
                L_i_original = np.where(Stage_hat == i)[0]
                logp_original += np.sum(np.log(np.arange(1, len(L_i_original)) - 1 + b_shape)) \
                               - np.log(gamma_b_shape) + b_shape * np.log(b_rate) \
                               - (b_shape + len(L_i_original)) * np.log(b_rate + np.sum(beta_hat[L_i_original]))
                logp_original += np.sum(np.log(np.arange(1, len(L_i_original)) - 1 + r_shape)) \
                               - np.log(gamma_b_shape) + r_shape * np.log(r_rate) \
                               - (r_shape + len(L_i_original)) * np.log(r_rate + np.sum(-np.log(gamma_hat[L_i_original])))
                logp_original -= np.sum(np.log(np.arange(1, len(L_i_original))))

                L_i_candidate = np.where(Stage_hat_candidate == i)[0]
                logp_candidate += np.sum(np.log(np.arange(1, len(L_i_candidate)) - 1 + b_shape)) \
                                - np.log(gamma_b_shape) + b_shape * np.log(b_rate) \
                                - (b_shape + len(L_i_candidate)) * np.log(b_rate + np.sum(beta_hat[L_i_candidate]))
                logp_candidate += np.sum(np.log(np.arange(1, len(L_i_candidate)) - 1 + r_shape)) \
                                - np.log(gamma_b_shape) + r_shape * np.log(r_rate) \
                                - (r_shape + len(L_i_candidate)) * np.log(r_rate + np.sum(-np.log(gamma_hat[L_i_candidate])))
                logp_candidate -= np.sum(np.log(np.arange(1, len(L_i_candidate))))

            ratio = np.exp(min([0, logp_candidate - logp_original + np.log(len(possible_change_indices) / len(possible_change_indices_candidate))]))
            print(np.exp(logp_candidate - logp_original + np.log(len(possible_change_indices) / len(possible_change_indices_candidate))))
        cxx = np.random.binomial(1, ratio)      
        if cxx == 1:
            Delta_hat = Delta_hat_candidate
            Stage_hat = Stage_hat_candidate
        
        
        ##### b and r sampling
            
        K_hat = np.sum(Delta_hat)
        
        b_hat = np.zeros(K_hat)
        r_hat = np.zeros(K_hat)
        for i in range(K_hat):
            L_i = np.where(Stage_hat == i)[0]
            b_hat[i] = gamma.rvs(size=1, a=(b_shape + len(L_i)), scale=1/(b_rate + np.sum(beta_hat[L_i])))
            r_hat[i] = gamma.rvs(size=1, a=(r_shape + len(L_i)), scale=1/(r_rate + np.sum(-np.log(gamma_hat[L_i]))))
    
        
        ##### beta and gamma sampling

        beta_hat    = gamma.rvs(a=contact_hat + 1, scale=1/(b_hat[Stage_hat] + PI_obs[:-1]*S_obs[:-1]))
        gamma_hat   = beta.rvs(a=N_recovery_obs + r_hat[Stage_hat], b=1 + I_obs[:-1] - N_recovery_obs)
        lambda_t    = beta_hat * PI_obs[:-1]
        p_upper     = 1 - poisson.cdf(0, lambda_t)
        pp          = [uniform.rvs(size=N_infect_obs[t]) * p_upper[t] + (1 - p_upper[t]) for t in range(T_max)]
        contact_hat = np.array(Parallel(n_jobs=-1)(delayed(calculate_contact_hat_parallel)((pp[t], lambda_t[t],)) for t in range(T_max)))
    
       


        if step % 100 == 0 and step != 0:
            print("\nStep:",step)
            #print("Delta_hat:",Delta_hat)
            #print("b_hat:",b_hat)
            #print("r_hat:",r_hat)
            #print("beta_hat:",beta_hat)
            #print("gamma_hat:",gamma_hat)
            print("----------------------------------------------------")
            
        Delta_all.append(Delta_hat)
        Stage_all.append(Stage_hat)
        b_all.append(b_hat)
        r_all.append(r_hat)
        beta_all.append(beta_hat)
        gamma_all.append(gamma_hat)
  
    # Creazione del DataFrame
    MCMC_chain = {
        'Delta': Delta_all[burnin::thinning],
        'Stage': Stage_all[burnin::thinning],
        'b': b_all[burnin::thinning],
        'r': r_all[burnin::thinning],
        'beta': beta_all[burnin::thinning],
        'gamma': gamma_all[burnin::thinning]
            }
    
    return MCMC_chain

In [None]:
MCMC_chain=gibbs_sampling(data, samples=200, burnin=0, thinning=1)

Initialization:

Delta_hat: [1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
b_hat: [0.00121858]
r_hat: [4.72039065e-23]
beta_hat: [2.87185123e-01 3.14597787e-01 5.04175711e-01 3.50660502e-01
 3.10818622e-01 5.02809436e-01 3.50428623e-01 3.27376657e-01
 3.26183159e-01 3.33902073e-01 2.83252419e-01 3.53332073e-01
 4.01043455e-01 5.89561122e-01 1.24694194e+00 3.09552588e-01
 5.60615859e-01 3.66422994e-01 4.63505753e-01 4.99158293e-01
 2.03876230e+00 7.43523679e+00 1.47394391e+00 7.45092419e+00
 3.10038181e+00 4.17967983e+00 4.21683331e+00 3.39417582e+00
 6.03238433e+00 2.05503714e+01 8.40460013e+00 3.77106248e+00
 2.40743679e+01 5.59954905e+00 2.38068298e+01 1.03601139e+02
 1.47240725e+01 1.95281803e+00 4.37628447e+01 3.22728395e+00
 2.65019633e+01 9.70065362e+01 1.14000766e+02 9.51155882e+01
 2.47756072e+02 6.46096234e+01 4.69415460e+

  0%|                                                        | 0/200 [00:00<?, ?it/s]

1
1.558284347072993e-05


  0%|▏                                               | 1/200 [00:01<03:44,  1.13s/it]

1
1.2429924685006875e+23


  1%|▍                                               | 2/200 [00:01<03:00,  1.10it/s]

0
2.119794714335619


  2%|▋                                               | 3/200 [00:02<02:47,  1.17it/s]

1
160523.41101498203


  2%|▉                                               | 4/200 [00:03<03:07,  1.04it/s]

0
0.3034568510837373


  2%|█▏                                              | 5/200 [00:05<03:28,  1.07s/it]

-1
5.623539546229162e-19


  3%|█▍                                              | 6/200 [00:06<03:47,  1.17s/it]

-1
7.100980180151153e-23


  4%|█▋                                              | 7/200 [00:07<03:33,  1.11s/it]

-1
5.358242071438587e-19


  4%|█▉                                              | 8/200 [00:08<03:12,  1.00s/it]

1
2.250717573115289e-07


  4%|██▏                                             | 9/200 [00:08<02:58,  1.07it/s]

0
1.5657146395075727


  5%|██▎                                            | 10/200 [00:09<02:59,  1.06it/s]

-1
3.0753320567865685e-19


  6%|██▌                                            | 11/200 [00:10<02:53,  1.09it/s]

1
9.794017158729014e-06


  6%|██▊                                            | 12/200 [00:11<02:56,  1.06it/s]

0
0.3275450136159221


  6%|███                                            | 13/200 [00:12<02:54,  1.07it/s]

1
9.286888434421816e-08


  7%|███▎                                           | 14/200 [00:13<02:58,  1.04it/s]

0
2.984359130653778


  8%|███▌                                           | 15/200 [00:14<03:05,  1.00s/it]

0
0.639446523939219


  8%|███▊                                           | 16/200 [00:16<03:22,  1.10s/it]

-1
5.287788796112621e-19


  8%|███▉                                           | 17/200 [00:17<03:19,  1.09s/it]

-1
4.239499872712927e-23


  9%|████▏                                          | 18/200 [00:18<03:22,  1.11s/it]

-1
6.418802164343983e-19


 10%|████▍                                          | 19/200 [00:19<03:18,  1.10s/it]

0
2.9098630281897093


 10%|████▋                                          | 20/200 [00:20<03:07,  1.04s/it]

-1
8.929176936061281e-24


 10%|████▉                                          | 21/200 [00:21<02:57,  1.01it/s]

1
2.8915649099486693e-07


 11%|█████▏                                         | 22/200 [00:22<02:45,  1.07it/s]

1
3.0729513374334466e-07


 12%|█████▍                                         | 23/200 [00:22<02:36,  1.13it/s]

1
4.3703010184910246e-05


 12%|█████▋                                         | 24/200 [00:23<02:30,  1.17it/s]

-1
2.205566983374098e-19


 12%|█████▉                                         | 25/200 [00:24<02:31,  1.16it/s]

1
40.27653346497102


 13%|██████                                         | 26/200 [00:25<02:30,  1.16it/s]

1
4.88528860427099e-08


 14%|██████▎                                        | 27/200 [00:26<02:31,  1.14it/s]

1
1.047115340136985e-07


 14%|██████▌                                        | 28/200 [00:27<02:30,  1.14it/s]

0
0.6075187522644775


 14%|██████▊                                        | 29/200 [00:27<02:29,  1.14it/s]

1
6.150746770262345e-06


 15%|███████                                        | 30/200 [00:29<02:36,  1.09it/s]

1
7.059582494476018e-08


 16%|███████▎                                       | 31/200 [00:29<02:32,  1.11it/s]

-1
5.421926301734657e-19


 16%|███████▌                                       | 32/200 [00:30<02:25,  1.15it/s]

1
1.354112712742025e-06


 16%|███████▊                                       | 33/200 [00:31<02:33,  1.09it/s]

1
3.6996355928089154e-07


 17%|███████▉                                       | 34/200 [00:32<02:26,  1.13it/s]

-1
6.745855416228602e-19


 18%|████████▏                                      | 35/200 [00:33<02:39,  1.04it/s]

-1
6.2690841142477385e-19


 18%|████████▍                                      | 36/200 [00:35<03:02,  1.12s/it]

-1
6.643125929520034e-19


 18%|████████▋                                      | 37/200 [00:36<03:06,  1.15s/it]

-1
4.145730687473214e-19


 19%|████████▉                                      | 38/200 [00:37<03:09,  1.17s/it]

1
4.495789186307094e-05


 20%|█████████▏                                     | 39/200 [00:39<03:22,  1.26s/it]

0
0.6084440750984292


 20%|█████████▍                                     | 40/200 [00:39<03:05,  1.16s/it]

-1
4.61564133198272e-19


 20%|█████████▋                                     | 41/200 [00:40<02:46,  1.05s/it]

-1
5.95420352306178e-19


 21%|█████████▊                                     | 42/200 [00:41<02:34,  1.02it/s]

0
0.7806724714568503


 22%|██████████                                     | 43/200 [00:42<02:41,  1.03s/it]

-1
6.057631188067799e-19


 22%|██████████▎                                    | 44/200 [00:43<02:37,  1.01s/it]

1
3.3529160602449796e-06


 22%|██████████▌                                    | 45/200 [00:44<02:29,  1.04it/s]

0
1.635875077468825


 23%|██████████▊                                    | 46/200 [00:45<02:20,  1.10it/s]

-1
3.142650929048589e-19


 24%|███████████                                    | 47/200 [00:46<02:14,  1.14it/s]

0
0.6140167077937732


 24%|███████████▎                                   | 48/200 [00:46<02:09,  1.17it/s]

-1
8.699264286043786e-19


 24%|███████████▌                                   | 49/200 [00:47<02:08,  1.17it/s]

-1
5.761558688014479e-19


 25%|███████████▊                                   | 50/200 [00:48<02:05,  1.19it/s]

0
0.7322247649467146


 26%|███████████▉                                   | 51/200 [00:49<02:03,  1.21it/s]

1
1.3717527651990385e-07


 26%|████████████▏                                  | 52/200 [00:50<02:01,  1.22it/s]

0
0.5443917284340193


 26%|████████████▍                                  | 53/200 [00:50<01:59,  1.23it/s]

0
0.5404548716906585


 27%|████████████▋                                  | 54/200 [00:51<01:58,  1.24it/s]

0
0.5105194434473702


 28%|████████████▉                                  | 55/200 [00:52<01:57,  1.24it/s]

-1
6.834089266311554e-19


 28%|█████████████▏                                 | 56/200 [00:53<01:56,  1.24it/s]

1
7.159224857320426e-08


 28%|█████████████▍                                 | 57/200 [00:54<01:55,  1.24it/s]

1
8.953248258847489e-09


 29%|█████████████▋                                 | 58/200 [00:55<02:00,  1.18it/s]

0
1.454112035446203


 30%|█████████████▊                                 | 59/200 [00:55<01:58,  1.19it/s]

0
0.4823582057737577


 30%|██████████████                                 | 60/200 [00:56<01:57,  1.19it/s]

0
0.44476685491727364


 30%|██████████████▎                                | 61/200 [00:57<01:55,  1.20it/s]

1
2.2309634058870274e-08


 31%|██████████████▌                                | 62/200 [00:58<01:53,  1.21it/s]

-1
1.5685772274846577e-17


 32%|██████████████▊                                | 63/200 [00:59<01:55,  1.19it/s]

0
0.6726782260741819


 32%|███████████████                                | 64/200 [01:00<01:56,  1.17it/s]

-1
1.4885374506748357e-17


 32%|███████████████▎                               | 65/200 [01:00<01:53,  1.19it/s]

0
2.0813648111327243


 33%|███████████████▌                               | 66/200 [01:01<01:50,  1.21it/s]

-1
3.4571750360308607e-19


 34%|███████████████▋                               | 67/200 [01:02<01:49,  1.22it/s]

-1
5.1449171258597095e-18


 34%|███████████████▉                               | 68/200 [01:03<01:47,  1.23it/s]

-1
1.035062122327952e-17


 34%|████████████████▏                              | 69/200 [01:04<01:45,  1.24it/s]

-1
5.094676976760867e-18


 35%|████████████████▍                              | 70/200 [01:04<01:44,  1.24it/s]

1
1.587572580996352e-06


 36%|████████████████▋                              | 71/200 [01:05<01:43,  1.25it/s]

-1
5.375457433586406e-18


 36%|████████████████▉                              | 72/200 [01:06<01:41,  1.26it/s]

1
1.57832153200931e-08


 36%|█████████████████▏                             | 73/200 [01:07<01:41,  1.25it/s]

1
8.961332428006734e-08


 37%|█████████████████▍                             | 74/200 [01:08<01:40,  1.26it/s]

0
0.46832067884661477


 38%|█████████████████▋                             | 75/200 [01:08<01:39,  1.26it/s]

0
0.692370950619849


 38%|█████████████████▊                             | 76/200 [01:09<01:37,  1.27it/s]

-1
1.8012806575946663e-17


 38%|██████████████████                             | 77/200 [01:10<01:37,  1.26it/s]

0
0.43964714464143245


 39%|██████████████████▎                            | 78/200 [01:11<01:36,  1.26it/s]

0
0.6800184144214207


 40%|██████████████████▌                            | 79/200 [01:12<01:36,  1.26it/s]

-1
2.633246165270571e-17


 40%|██████████████████▊                            | 80/200 [01:12<01:35,  1.26it/s]

-1
5.944360610768784e-19


 40%|███████████████████                            | 81/200 [01:13<01:34,  1.26it/s]

-1
5.4074487970809835e-17


 41%|███████████████████▎                           | 82/200 [01:14<01:33,  1.26it/s]

1
1.763445587364538e-08


 42%|███████████████████▌                           | 83/200 [01:15<01:32,  1.26it/s]

1
1.9505871588765413e-06


 42%|███████████████████▋                           | 84/200 [01:16<01:31,  1.26it/s]

1
1.8527081726745054e-08


 42%|███████████████████▉                           | 85/200 [01:16<01:30,  1.27it/s]

1
3.636235428251825e-05


 43%|████████████████████▏                          | 86/200 [01:17<01:30,  1.26it/s]

0
0.42412700124693825


 44%|████████████████████▍                          | 87/200 [01:18<01:29,  1.26it/s]

1
0.0002170741903894074


 44%|████████████████████▋                          | 88/200 [01:19<01:28,  1.26it/s]

0
2.2362107276023737


 44%|████████████████████▉                          | 89/200 [01:20<01:28,  1.26it/s]

0
0.6741688188172623


 45%|█████████████████████▏                         | 90/200 [01:20<01:27,  1.26it/s]

-1
2.438076687426992e-17


 46%|█████████████████████▍                         | 91/200 [01:21<01:25,  1.27it/s]

-1
8.212503866975864e-18


 46%|█████████████████████▌                         | 92/200 [01:22<01:25,  1.27it/s]

1
0.0036417235781142005


 46%|█████████████████████▊                         | 93/200 [01:23<01:24,  1.27it/s]

-1
2.479801790946255e-17


 47%|██████████████████████                         | 94/200 [01:23<01:23,  1.27it/s]

-1
1.016576267654148e-18


 48%|██████████████████████▎                        | 95/200 [01:24<01:22,  1.27it/s]

-1
8.146575266493104e-18


 48%|██████████████████████▌                        | 96/200 [01:25<01:22,  1.26it/s]

1
0.00011061408479581216


 48%|██████████████████████▊                        | 97/200 [01:26<01:21,  1.27it/s]

1
3.046322855948712e-07


 49%|███████████████████████                        | 98/200 [01:27<01:21,  1.26it/s]

0
0.4431137626448037


 50%|███████████████████████▎                       | 99/200 [01:27<01:20,  1.26it/s]

1
3.0556521985045554e-08


 50%|███████████████████████                       | 100/200 [01:28<01:19,  1.26it/s]

0
0.669593650871921


 50%|███████████████████████▏                      | 101/200 [01:29<01:19,  1.24it/s]


Step: 100
----------------------------------------------------
1
0.00017923491273889817


 51%|███████████████████████▍                      | 102/200 [01:30<01:19,  1.24it/s]

0
0.4333519770158698


 52%|███████████████████████▋                      | 103/200 [01:31<01:18,  1.24it/s]

1
1.5461501571013396e-07


 52%|███████████████████████▉                      | 104/200 [01:31<01:17,  1.24it/s]

1
0.004158413990480382


 52%|████████████████████████▏                     | 105/200 [01:32<01:16,  1.24it/s]

1
0.00022558615113396898


 53%|████████████████████████▍                     | 106/200 [01:33<01:15,  1.24it/s]

-1
1.7484768942285407e-18


 54%|████████████████████████▌                     | 107/200 [01:34<01:14,  1.24it/s]

0
0.4360002387967741


 54%|████████████████████████▊                     | 108/200 [01:35<01:14,  1.24it/s]

0
2.1751166184269333


 55%|█████████████████████████                     | 109/200 [01:35<01:13,  1.24it/s]

1
1.5073495168153538e-08


 55%|█████████████████████████▎                    | 110/200 [01:36<01:12,  1.25it/s]

1
2.0421645081989396e-07


 56%|█████████████████████████▌                    | 111/200 [01:37<01:11,  1.24it/s]

0
1.539489034759304


 56%|█████████████████████████▊                    | 112/200 [01:38<01:10,  1.25it/s]

0
0.647132939564254


 56%|█████████████████████████▉                    | 113/200 [01:39<01:09,  1.25it/s]

1
3.9320408021697117e-07


 57%|██████████████████████████▏                   | 114/200 [01:39<01:09,  1.24it/s]

-1
1.210330403004728e-18


 57%|██████████████████████████▍                   | 115/200 [01:40<01:08,  1.24it/s]

-1
1.858412080001532e-18


 58%|██████████████████████████▋                   | 116/200 [01:41<01:09,  1.21it/s]

-1
2.7750942387889647e-18


 58%|██████████████████████████▉                   | 117/200 [01:42<01:08,  1.22it/s]

0
0.4839419228872142


 59%|███████████████████████████▏                  | 118/200 [01:43<01:07,  1.22it/s]

-1
2.6886419829711397e-18


 60%|███████████████████████████▎                  | 119/200 [01:44<01:05,  1.23it/s]

1
3.4479279105219594e-05


 60%|███████████████████████████▌                  | 120/200 [01:44<01:04,  1.24it/s]

0
1.5231424862737795


 60%|███████████████████████████▊                  | 121/200 [01:45<01:03,  1.24it/s]

-1
1.0013400149572021e-18


 61%|████████████████████████████                  | 122/200 [01:46<01:02,  1.25it/s]

1
4.357286774786164e-07


 62%|████████████████████████████▎                 | 123/200 [01:47<01:01,  1.25it/s]

-1
2.780007902904546e-18


 62%|████████████████████████████▌                 | 124/200 [01:48<01:00,  1.26it/s]

-1
7.244082794305402e-18


 62%|████████████████████████████▊                 | 125/200 [01:48<00:59,  1.26it/s]

0
0.4703190210179267


 63%|████████████████████████████▉                 | 126/200 [01:49<00:58,  1.27it/s]

1
0.00012460756947742257


 64%|█████████████████████████████▏                | 127/200 [01:50<00:57,  1.26it/s]

1
1.0962779100282463e-07


 64%|█████████████████████████████▍                | 128/200 [01:51<00:56,  1.27it/s]

0
0.44131279872149876


 64%|█████████████████████████████▋                | 129/200 [01:52<00:56,  1.26it/s]

0
0.6834083725233026


 65%|█████████████████████████████▉                | 130/200 [01:52<00:55,  1.27it/s]

-1
4.797475471522512e-18


 66%|██████████████████████████████▏               | 131/200 [01:53<00:54,  1.26it/s]

-1
3.8862360167661095e-18


 66%|██████████████████████████████▎               | 132/200 [01:54<00:53,  1.27it/s]

0
1.6098396550395369


 66%|██████████████████████████████▌               | 133/200 [01:55<00:52,  1.26it/s]

1
4.3523482339110644e-07


 67%|██████████████████████████████▊               | 134/200 [01:55<00:52,  1.27it/s]

1
2.4005884499699538e-05


 68%|███████████████████████████████               | 135/200 [01:56<00:51,  1.25it/s]

-1
9.950315930185197e-18


 68%|███████████████████████████████▎              | 136/200 [01:57<00:51,  1.25it/s]

0
2.1553053970375133


 68%|███████████████████████████████▌              | 137/200 [01:58<00:50,  1.26it/s]

0
2.000165112260519


 69%|███████████████████████████████▋              | 138/200 [01:59<00:49,  1.26it/s]

1
8.801880318907906e-08


 70%|███████████████████████████████▉              | 139/200 [01:59<00:48,  1.26it/s]

0
1.882925242990665


 70%|████████████████████████████████▏             | 140/200 [02:00<00:47,  1.26it/s]

0
0.7256519659717295


 70%|████████████████████████████████▍             | 141/200 [02:01<00:46,  1.26it/s]

1
6.796853935954174e-05


 71%|████████████████████████████████▋             | 142/200 [02:02<00:45,  1.26it/s]

0
1.283599512032607


 72%|████████████████████████████████▉             | 143/200 [02:03<00:45,  1.26it/s]

1
0.0011935701902459313


 72%|█████████████████████████████████             | 144/200 [02:03<00:44,  1.27it/s]

1
1.8177303166771012e-07


 72%|█████████████████████████████████▎            | 145/200 [02:04<00:43,  1.27it/s]

0
0.6552685881820368


 73%|█████████████████████████████████▌            | 146/200 [02:05<00:42,  1.26it/s]

1
1.4235042447774701e-05


 74%|█████████████████████████████████▊            | 147/200 [02:06<00:41,  1.27it/s]

1
3.781838478924072e-05


 74%|██████████████████████████████████            | 148/200 [02:07<00:41,  1.26it/s]

1
3.454567787840035e-08


 74%|██████████████████████████████████▎           | 149/200 [02:07<00:40,  1.26it/s]

-1
6.118577796114578e-20


 75%|██████████████████████████████████▌           | 150/200 [02:08<00:39,  1.27it/s]

-1
4.477615560816147e-19


 76%|██████████████████████████████████▋           | 151/200 [02:09<00:38,  1.27it/s]

0
0.5605282342887852


 76%|██████████████████████████████████▉           | 152/200 [02:10<00:37,  1.27it/s]

1
2.823619125120404e-08


 76%|███████████████████████████████████▏          | 153/200 [02:10<00:36,  1.28it/s]

1
1.5764813169825233e-06


 77%|███████████████████████████████████▍          | 154/200 [02:11<00:36,  1.27it/s]

1
7.716389125498405e-05


 78%|███████████████████████████████████▋          | 155/200 [02:12<00:35,  1.27it/s]

-1
5.058522782056964e-19


 78%|███████████████████████████████████▉          | 156/200 [02:13<00:35,  1.23it/s]

0
0.6280208811731417


 78%|████████████████████████████████████          | 157/200 [02:14<00:43,  1.00s/it]

-1
6.673661526095355e-20


 79%|████████████████████████████████████▎         | 158/200 [02:16<00:45,  1.09s/it]

-1
4.04281259458003e-17


 80%|████████████████████████████████████▌         | 159/200 [02:17<00:47,  1.17s/it]

-1
4.365530968231077e-17


 80%|████████████████████████████████████▊         | 160/200 [02:18<00:47,  1.18s/it]

1
1.8856690949827298e-05


 80%|█████████████████████████████████████         | 161/200 [02:19<00:42,  1.08s/it]

1
2.396876680671839e-07


 81%|█████████████████████████████████████▎        | 162/200 [02:20<00:37,  1.00it/s]

1
3.9438056788954475e-06


 82%|█████████████████████████████████████▍        | 163/200 [02:21<00:34,  1.07it/s]

0
0.5569340907825969


 82%|█████████████████████████████████████▋        | 164/200 [02:21<00:32,  1.11it/s]

-1
7.682675560782072e-20


 82%|█████████████████████████████████████▉        | 165/200 [02:22<00:30,  1.15it/s]

1
1.058120972424101e-07


 83%|██████████████████████████████████████▏       | 166/200 [02:23<00:28,  1.18it/s]

-1
2.1096734111335562e-19


 84%|██████████████████████████████████████▍       | 167/200 [02:24<00:27,  1.20it/s]

0
1.3410099182618918


 84%|██████████████████████████████████████▋       | 168/200 [02:25<00:26,  1.22it/s]

1
1.5409505726951773e-06


 84%|██████████████████████████████████████▊       | 169/200 [02:25<00:25,  1.23it/s]

1
1.571651878633635e-08


 85%|███████████████████████████████████████       | 170/200 [02:26<00:24,  1.23it/s]

-1
1.0946745884279364e-16


 86%|███████████████████████████████████████▎      | 171/200 [02:27<00:24,  1.20it/s]

0
1.6210160567874823


 86%|███████████████████████████████████████▌      | 172/200 [02:28<00:26,  1.04it/s]

-1
1.526479543542131e-20


 86%|███████████████████████████████████████▊      | 173/200 [02:29<00:24,  1.09it/s]

1
1.6210988868787224e-06


 87%|████████████████████████████████████████      | 174/200 [02:30<00:22,  1.13it/s]

-1
1.6807530590754576e-20


 88%|████████████████████████████████████████▎     | 175/200 [02:31<00:21,  1.17it/s]

-1
8.448877582649892e-17


 88%|████████████████████████████████████████▍     | 176/200 [02:32<00:20,  1.19it/s]

-1
4.835135804275094e-20


 88%|████████████████████████████████████████▋     | 177/200 [02:32<00:19,  1.21it/s]

1
0.0009192644465110379


 89%|████████████████████████████████████████▉     | 178/200 [02:33<00:18,  1.21it/s]

-1
1.6629658984167598e-20


 90%|█████████████████████████████████████████▏    | 179/200 [02:34<00:17,  1.23it/s]

1
0.0008760980476555501


 90%|█████████████████████████████████████████▍    | 180/200 [02:35<00:16,  1.24it/s]

1
3.534508232861855e-05


 90%|█████████████████████████████████████████▋    | 181/200 [02:36<00:15,  1.25it/s]

0
1.5988264263379692


### Estimators

In [None]:
def compute_loss(delta, Q_tt):
    stage = np.cumsum(delta, dtype=int)-1
    Q_candidate = (stage[:, np.newaxis] == stage)
    return np.sum(np.abs(Q_candidate - Q_tt))

def get_estimators(chain, T_max=100, samples=1000):

    ##### delta

    delta = np.array(MCMC_chain["Delta"])
    stage = np.array(MCMC_chain["Stage"])
    masks = (stage[:, :, np.newaxis] == stage[:, np.newaxis, :])
    Q_ttprime = np.sum(masks, axis=0)/samples

    Delta_final = np.zeros(shape=T_max, dtype=int)
    Delta_final[0] = 1
    Stage_final = np.cumsum(Delta_final, dtype=int)-1 
    current_loss = compute_loss(Delta_final, Q_ttprime)

    check_add_drop = True
    check_swap = True
    while(check_add_drop==True or check_swap==True):
        # check add or drop
        candidates_loss = np.zeros(shape=(T_max))
        for i in range(1, T_max):
            Delta_candidate = Delta_final.copy()
            Delta_candidate[i] = 1-Delta_candidate[i]
            candidates_loss[i] = compute_loss(Delta_candidate, Q_ttprime)
        index_min = np.argmin(candidates_loss[1:])+1
        if candidates_loss[index_min] < current_loss:
            current_loss = candidates_loss[index_min]
            Delta_final[index_min] = 1-Delta_final[index_min]
            check_add_drop = True
        else:
            check_add_drop = False

        possible_change_indices = np.where(np.abs(Delta_final[1:-1] - Delta_final[2:]) == 1)[0]+1
        candidates_loss = np.zeros(shape=(len(possible_change_indices)))
        for i, idx in enumerate(possible_change_indices):
            Delta_candidate = Delta_final.copy()
            Delta_candidate[idx+np.array([0,1])] = Delta_candidate[idx+np.array([1,0])]
            candidates_loss[i] = compute_loss(Delta_candidate, Q_ttprime)
        index_min = np.argmin(candidates_loss)
        if candidates_loss[index_min] < current_loss:
            current_loss = candidates_loss[index_min]
            Delta_candidate[possible_change_indices[index_min]+np.array([0,1])] = Delta_candidate[possible_change_indices[index_min]+np.array([1,0])]
            check_swap = True
        else:
            check_swap = False

        # check swap
        if np.sum(Delta_final) in np.arange(1,T_max-1):
            possible_change_indices = np.where(np.abs(Delta_final[1:-1] - Delta_final[2:]) == 1)[0]+1
            candidates_loss = np.zeros(shape=(len(possible_change_indices)))
            for i, idx in enumerate(possible_change_indices):
                Delta_candidate = Delta_final.copy()
                Delta_candidate[idx+np.array([0,1])] = Delta_candidate[idx+np.array([1,0])]
                candidates_loss[i] = compute_loss(Delta_candidate, Q_ttprime)
            index_min = np.argmin(candidates_loss)
            if candidates_loss[index_min] < current_loss:
                current_loss = candidates_loss[index_min]
                Delta_candidate[possible_change_indices[index_min]+np.array([0,1])] = Delta_candidate[possible_change_indices[index_min]+np.array([1,0])]
                check_swap = True
            else:
                check_swap = False


    ##### beta and gamma
                
    b = np.array([[MCMC_chain["b"][s][stage[s][t]] for t in range(T_max)] for s in range(samples)])
    r = np.array([[MCMC_chain["r"][s][stage[s][t]] for t in range(T_max)] for s in range(samples)])
    beta_final  = np.sum(1/b, axis=0)/samples
    gamma_final = np.sum(r/(1+r), axis=0)/samples


    return Delta_final, beta_final, gamma_final
        

In [None]:
delta_est, beta_est, gamma_est = get_estimators(MCMC_chain, 100, 200)
stage_est = np.cumsum(delta_est)-1

In [None]:
delta_true = np.array([(t in chg_pt or t==1) for t in range(1, T_max+1)]).astype(int)
stage_true = np.cumsum(delta_true)-1
print(f"True change points={np.where(delta_true==1)[0]}")
print(f"Predicted change points={np.where(delta_est==1)[0]}")

In [None]:
n_sc=3
true_beta  = np.array([[scenarios[i,(chg_pt <= t+1).sum(),0] for t in range(T)] for i in range(n_sc)])
true_gamma = np.array([[scenarios[i,(chg_pt <= t+1).sum(),1] for t in range(T)] for i in range(n_sc)])

fig, ax = plt.subplots()
plt.plot(beta_est, label="Estimated Beta", color='dodgerblue')
plt.plot(true_beta[0], label="True Beta", color='dodgerblue', linestyle='dashed', alpha=0.7)
plt.plot(gamma_est, label="Estimated Gamma", color='darkorange')
plt.plot(true_gamma[0], label="True Gamma", color='darkorange', linestyle='dashed', alpha=0.7)
ax.set_xlabel("Day")
plt.legend(loc="best")

### Agreement with True Values

In [None]:
def comp_ARI(true, estim, T_max=100):

    comb = np.array(list(itertools.combinations(range(T_max), 2))).T
    true_mask = (true[comb[0]] == true[comb[1]]).astype(int)
    estim_mask = (estim[comb[0]] == estim[comb[1]]).astype(int)

    TP = np.mean(true_mask*estim_mask)
    FP = np.mean((1-true_mask)*estim_mask)
    FN = np.mean(true_mask*(1-estim_mask))
    TN = np.mean((1-true_mask)*(1-estim_mask))
    num = TP+TN-(TP+FP)*(TP+FN)-(TN+FP)*(TN+FN)
    den = 1-(TP+FP)*(TP+FN)-(TN+FP)*(TN+FN)

    ARI = num/den

    return ARI


In [None]:
def comp_MI(true, estim, T_max=100):
    n_kkprime = np.histogram2d(true, estim, bins=(np.max(true)+1, np.max(estim)+1))[0]
    n_k = np.sum(n_kkprime, axis=1)
    n_kprime = np.sum(n_kkprime, axis=0)
    
    MI = np.sum(n_kkprime/T_max*np.log((n_kkprime+(n_kkprime==0))*T_max/np.outer(n_k, n_kprime)))
    return MI

In [None]:
ARI = comp_ARI(stage_true, stage_est)
MI = comp_MI(stage_true, stage_est)
print(f"ARI = {ARI} out of 1\nMI = {MI} out of 1.386")