# Parameter inference for the SEIQR model.

This model contains 5 species (Susceptible, Exposed, Infected, Quarantined and Recovered).

In [None]:
import torch as tn
import torchtt as tntt
import TTCME
import matplotlib.pyplot as plt 
import datetime
import numpy as np
import pickle 

# Use double precision for every tensor created without an explicit dtype.
# torch.set_default_tensor_type is deprecated (PyTorch >= 2.1); on CPU,
# set_default_dtype(float64) has the same effect as DoubleTensor.
tn.set_default_dtype(tn.float64)
# When True, the operator and the joint PDF are converted to QTT format
# before time integration.
qtt = True

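Define the model. The rate coefficients of the first 4 reactions ($c_1,\dots,c_4$) are the unknown parameters; the rates of the remaining 5 reactions are fixed. `mdl` is the parametric model, and `mdl_true` uses the true rates to generate synthetic data.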

In [None]:
# True rate constants of all 9 reactions; the first four are the unknowns c1..c4.
rates = np.array([0.04, 0.4, 0.4, 0.004, 0.12, 0.8765, 0.01, 0.01, 0.01])
species = ['S', 'E', 'I', 'Q', 'R']


def _reaction(equation, rate, params=()):
    """Build a TTCME.ChemicalReaction over a fresh copy of the species list."""
    return TTCME.ChemicalReaction(list(species), equation, rate, params=list(params))


# Parametric reactions: the rate is a symbolic parameter.
r1m = _reaction('S+I->E+I', 'c1', ['c1'])
r2m = _reaction('E->I', 'c2', ['c2'])
r3m = _reaction('I->Q', 'c3', ['c3'])
r4m = _reaction('I->', 'c4', ['c4'])
# Reactions with fixed (known) rates, shared by both systems below.
r5m = _reaction('I->R', rates[4])
r6m = _reaction('Q->R', rates[5])
r7m = _reaction('I->S', rates[6])
r8m = _reaction('Q->S', rates[7])
r9m = _reaction('->S', rates[8])

# The first four reactions again, with the true rates plugged in (data generation).
r1 = _reaction('S+I->E+I', rates[0])
r2 = _reaction('E->I', rates[1])
r3 = _reaction('I->Q', rates[2])
r4 = _reaction('I->', rates[3])

fixed_reactions = [r5m, r6m, r7m, r8m, r9m]
mdl = TTCME.ReactionSystem(list(species), [r1m, r2m, r3m, r4m] + fixed_reactions, params=['c1', 'c2', 'c3', 'c4'])
mdl_true = TTCME.ReactionSystem(list(species), [r1, r2, r3, r4] + fixed_reactions, params=[])

IC = [90, 0, 4, 0, 0]          # initial copy numbers of S, E, I, Q, R
N = [128, 64, 64, 32, 32]      # state-space truncation per species

Generate the synthetic noisy measurements and the observation time grid.

In [None]:
No = 45      # number of observation intervals (No + 1 time points including t = 0)
Tend = 5     # final observation time
time_observation = np.linspace(0, Tend, No + 1)
dT = np.diff(time_observation)[0]   # uniform spacing of the observation grid

# Log-normal noise level for each species (S, E, I, Q, R).
sigmas = [0.1, 0.1, 0.1, 0.01, 0.01]

# Simulate one SSA trajectory of the true system, sample it on the observation
# grid and corrupt it with noise. The seed presumably fixes the SSA draws
# (assumes ssa_single/add_noise use NumPy's global RNG — TODO confirm).
np.random.seed(13212)
trajectory = mdl_true.ssa_single(IC, time_observation[-1])
reaction_time, reaction_jumps, reaction_indices = trajectory
observations = mdl_true.jump_process_to_states(time_observation, reaction_time, reaction_jumps)
observations_noise = TTCME.pdf.LogNormalObservation.add_noise(observations, sigmas)


Plot the observations and the true trajectory.

In [None]:
plt.figure()
# Duplicate every jump time/state so the polyline is a right-continuous step curve.
step_times = np.repeat(reaction_time, 2)[1:]
for col in range(5):
    plt.plot(step_times, np.repeat(reaction_jumps[:, col], 2)[:-1])
# Noisy observations as black crosses (drawn after all trajectory lines).
for col in range(5):
    plt.scatter(time_observation, observations_noise[:, col], c='k', marker='x', s=20)
plt.legend(['S', 'E', 'I', 'Q', 'R'])
plt.xlabel(r'$t$')
plt.ylabel(r'#individuals')
# Optional TikZ export:
# import tikzplotlib
# tikzplotlib.save('seiqr_trajectory.tex')

Define the basis of the parameter space.

In [None]:
Nl = 64     # dimension of each parameter's B-spline basis
mult = 5    # each parameter domain is [0, mult * true rate]

param_range = []
for rc in rates[:4]:
    param_range.append([0, rc * mult])

# Quadratic B-spline basis on each parameter interval.
basis_param = [TTCME.basis.BSplineBasis(Nl, [lo, hi], deg=2) for lo, hi in param_range]

Define the prior distribution: an independent gamma distribution for each parameter.

In [None]:
# Prior means are 1.5x the true value for c1..c3 and exact for c4;
# variances are proportional to the true rates.
mean_scale = np.array([1.5, 1.5, 1.5, 1.0])
var_scale = np.array([0.025, 0.1, 0.25, 0.0001])
mu = rates[:4] * mean_scale
var = rates[:4] * var_scale

# Moment matching for Gamma(alpha, beta): mean = alpha/beta, variance = alpha/beta^2.
beta_prior = mu / var
alpha_prior = mu**2 / var

prior = TTCME.GammaPDF(alpha_prior, beta_prior, basis_param, ['c1', 'c2', 'c3', 'c4'])


Define the initial condition and the joint PDF.

In [None]:
# PMF of the initial state IC over the truncated state space.
p_ic = TTCME.pdf.SingularPMF(N, IC, ['S', 'E', 'I', 'Q', 'R'])
# Combine state PMF and parameter prior into one joint PDF
# (the `**` operator presumably forms a tensor/Kronecker product — TODO confirm).
p0 = p_ic ** prior
p0.normalize()
print(p0)

# Copy of the raw TT degrees of freedom; this is the tensor the integrator evolves.
p = p0.dofs.clone()

Instantiate the observation operator.

In [None]:
# Log-normal observation model on the truncated state space, one noise level per species.
obs_operator = TTCME.pdf.LogNormalObservation(
    N,
    sigmas,
)

Compute the CME operator for the parameter dependent case.

In [None]:
# Galerkin discretization over the parameter basis: stiffness, mass and inverse mass.
galerkin_ops = mdl.generator_tt_galerkin(N, basis_param)
Stt, Mtt, Mtt_inv = galerkin_ops
# Left-multiply by the inverse mass matrix to obtain the operator passed to the solver.
Att = Mtt_inv @ Stt

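Instantiate the TT time-domain solver. If the `qtt` flag is `True`, the operator and the initial tensor are converted to the QTT format and the integration is performed in that format (Chebyshev method). Otherwise, the Crank–Nicolson scheme is used in the TT format.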

In [None]:
# Number of sub-intervals per observation step; passed to fwd_int.solve as
# `intervals` (same value for both formats).
Nbs = 1
if not qtt:
    # NOTE(review): the method string uses an en-dash (U+2013), not a hyphen;
    # it is kept verbatim — confirm it matches what TTInt expects.
    fwd_int = TTCME.TimeIntegrator.TTInt(Att, epsilon=1e-5, N_max=64, dt_max=1.0, method='crank–nicolson')
else:
    # QTT path: convert both the operator and the initial tensor.
    A_qtt = Att.to_qtt()
    fwd_int = TTCME.TimeIntegrator.TTInt(A_qtt, epsilon=1e-5, N_max=8, dt_max=1.0, method='cheby')
    p = p.to_qtt()

Perform the parameter inference.

In [None]:
posterior_list = []        # parameter posterior after each assimilated observation
joint_pdf = p0.copy()      # joint (state, parameter) PDF; its dofs are replaced every step

# Observation 0 is the initial condition at t = 0. time_observation has No + 1
# points, so assimilate observations 1..No inclusive. The previous
# range(1, No) silently dropped the last observation at t = Tend.
for i in range(1, No + 1):
    y = observations_noise[i, :]

    # Likelihood of y on the state grid, extended by ones along the 4 parameter modes.
    po = obs_operator.likelihood(y)
    po = po ** tntt.ones([Nl] * 4)
    if qtt:
        po = po.to_qtt()

    print('new observation ', i, '/', No, ' at time ', time_observation[i], ' ', y)

    # Prediction: propagate the joint PDF over one observation interval.
    tme = datetime.datetime.now()
    p = fwd_int.solve(p, dT, intervals=Nbs, qtt=qtt, verb=False, rounding=True, device=None)
    tme = datetime.datetime.now() - tme
    print('\tmax rank ', max(p.R))

    # Update: pointwise product with the likelihood (unnormalized Bayes' rule), then recompress.
    p_pred = p
    p_post = po * p_pred
    p_post = p_post.round(1e-10)
    print('\tmax rank (after observation) ', max(p_post.R))

    if qtt:
        # Reshape back to the TT layout of joint_pdf, but keep the QTT tensor for the next solve.
        joint_pdf.dofs = tntt.reshape(p_post, joint_pdf.dofs.N)
        Z = joint_pdf.Z          # presumably the normalization constant — read before normalize()
        joint_pdf.normalize()
        p = p_post / Z
    else:
        joint_pdf.dofs = p_post.clone()
        joint_pdf.normalize()
        p = joint_pdf.dofs.clone()

    # Marginalize out the five species dimensions to get the parameter posterior.
    posterior_pdf = joint_pdf.marginal([0, 1, 2, 3, 4])
    posterior_pdf.round(1e-10)
    posterior_list.append(posterior_pdf.copy())

    E = posterior_pdf.expected_value()
    print('\tExpected value computed posterior ', E)
    print('\tposterior size ', tntt.numel(p) * 8 / 1000000, ' MB')
    print('\telapsed ', tme)

Calculate the posterior mean and covariance and display the marginals $p(\theta_1), \dots, p(\theta_4)$.

In [None]:
# Final parameter posterior (after the last assimilated observation).
posterior_pdf.normalize()

E = posterior_pdf.expected_value()
C = posterior_pdf.covariance_matrix()
V = np.diag(C)   # marginal posterior variances

print()
# Only the first four rates are inferred; print those so the rows line up with E and V.
print('Exact rates:                      ', rates[:4])
print('')
print('Expected value computed posterior ', E)
print('Variance computed posterior       ', V)
print('')
# Analytic prior moments of Gamma(alpha, beta): mean alpha/beta, variance alpha/beta^2.
print('Expected value prior              ', alpha_prior / beta_prior)
print('Variance computed prior           ', alpha_prior / beta_prior / beta_prior)
print('')

# 1D posterior marginals p(theta_i): integrate out the other three parameters.
post_1 = posterior_pdf.marginal([1, 2, 3])
post_2 = posterior_pdf.marginal([0, 2, 3])
post_3 = posterior_pdf.marginal([0, 1, 3])
post_4 = posterior_pdf.marginal([0, 1, 2])
post_marginals = [post_1, post_2, post_3, post_4]

for i in range(4):
    pr = prior.marginal([j for j in range(4) if i != j])
    po = post_marginals[i]   # reuse the marginal computed above instead of recomputing it
    x = np.linspace(param_range[i][0], param_range[i][1], 1000)

    plt.figure()
    plt.plot(x, po[x].numpy(), label='posterior')
    plt.axvline(rates[i], c='r', linestyle=':', label='true value')
    plt.plot(x, pr[x].numpy(), label='prior')
    plt.xlabel(r'$\theta_' + str(i + 1) + '$')
    plt.legend()

Plot the marginalized posteriors $p(\theta_i,\theta_j)$, $i,j=1,\dots,4$.

In [None]:
plt.figure(figsize=[6, 6])

for i in range(4):
    for j in range(4):
        k = 4 * i + j + 1          # row-major subplot index in the 4x4 grid
        plt.subplot(4, 4, k)

        if i == j:
            # Diagonal: 1D marginals of theta_i; prior and posterior are both scaled
            # by the posterior maximum so their relative heights are preserved.
            others = [m for m in range(4) if m != i]
            theta = np.linspace(param_range[i][0], param_range[i][1], 1000)
            pr = prior.marginal(others)[theta].numpy()
            po = posterior_pdf.marginal(others)[theta].numpy()

            plt.plot(theta, po / np.max(po) * np.max(theta))
            plt.axvline(rates[i], c='r', linestyle=':')
            plt.plot(theta, pr / np.max(po) * np.max(theta), 'g:')
        else:
            # Off-diagonal: 2D marginal of (theta_lo, theta_hi). The marginal is evaluated
            # with the lower-index parameter first; grid sizes 127/128 differ on purpose.
            lo, hi = min(i, j), max(i, j)
            theta1 = np.linspace(param_range[lo][0], param_range[lo][1], 127)
            theta2 = np.linspace(param_range[hi][0], param_range[hi][1], 128)
            T1, T2 = np.meshgrid(theta1, theta2)

            po = posterior_pdf.marginal([m for m in range(4) if m != i and m != j])
            po.normalize()
            po = po[theta1, theta2].numpy().T

            # Either way theta_j goes on the x-axis and theta_i on the y-axis.
            if j < i:
                plt.contourf(T1, T2, po, cmap='gray_r', levels=12)
            else:
                plt.contourf(T2, T1, po, cmap='gray_r', levels=32)
            plt.axvline(rates[j], c='r', linestyle=':', linewidth=1)
            plt.axhline(rates[i], c='r', linestyle=':', linewidth=1)

        # Axis labels only on the outer row/column; ticks hidden on inner panels.
        if i == 3:
            plt.xlabel(r'$\theta_' + str(j + 1) + '$')
        if j == 0:
            plt.ylabel(r'$\theta_' + str(i + 1) + '$')
        if j > 0:
            plt.yticks([])
        if i < 3:
            plt.xticks([])

plt.savefig('seiqr_post.eps')