This notebook constructs an emulator (a neural-network surrogate) for an agent-based model (ABM) from an ensemble of ABM runs over many parameter sets.  It then performs MCMC sampling of the model parameters using the emulator.  Some of these MCMC samples are fed back into the ABM to generate additional, targeted training data in the region of parameter space where the model fits the observed data well.

In [None]:
#%load_ext autoreload
#%autoreload 2

import h5py  
import torch
import numpy as np
from torch.utils import data
import json
from sklearn.preprocessing import normalize
from torch.utils.data import TensorDataset
import matplotlib.pyplot as plt
import pandas as pd
import pandemic as pan
from multiprocessing import Pool
from data_tools import hdf5Dataset_init,normalize_sets,Emulator
from scipy.stats import beta
from scipy.special import gamma
import pickle
from joblib import Parallel, delayed

# Run on the first GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# ABM parameter names, in the same positional order as the entries of a
# parameter vector (they are zipped together when building an ABM input).
keys = [
    'initial_infected_fraction',
    'initial_removed_fraction',
    'incubation_period',
    'serial_interval',
    'symptomatic_fraction',
    'recovery_days',
    'quarantine_days',
    'days_indetectable',
    'R0',
    'contact_rate',
    'npi_factor',
    'contact_tracing_testing_rate',
    'contact_tracing_quarantine_rate',
    'contact_tracing_days',
    'daily_testing_fraction',
    'daily_testing_false_positive',
    'daily_testing_false_negative',
    'class_size_limit',
    'contact_upscale_factor',
    'friendship_contacts',
    'academic_contacts',
    'broad_social_contacts',
    'department_environmental_contacts',
    'broad_environmental_contacts',
    'residential_neighbors',
    'online_transition',
]

# Subset of `keys` whose values must be rounded to integers before they are
# passed to the ABM.
keys_to_round = [
    'incubation_period',
    'serial_interval',
    'recovery_days',
    'quarantine_days',
    'days_indetectable',
    'contact_tracing_days',
    'class_size_limit',
    'friendship_contacts',
    'academic_contacts',
    'broad_social_contacts',
    'department_environmental_contacts',
    'broad_environmental_contacts',
    'residential_neighbors',
    'online_transition',
]

Read in the initial (large) ensemble of ABM runs and split it into normalized training and testing sets.

In [None]:
# Load the initial ensemble of ABM runs.  The 0.1 passed to split_datasets is
# presumably the held-out test fraction, and split_type='parameter'
# presumably splits by parameter set rather than by run (TODO confirm in
# data_tools).
initset = hdf5Dataset_init('full_data.hdf5')   
training, testing = initset.split_datasets(.1,split_type='parameter')
# Normalize both splits.  X_mean/X_std are kept so that MCMC samples can later
# be mapped back from z-normalized to ABM parameter units.
training_,testing_,X_mean,X_std = normalize_sets(training,testing)

Read in the observed COVID-19 data (daily testing counts and active cases).

In [None]:
# Observed testing data.  Bound to `testing_data` (not `data`) so that the
# DataFrame no longer shadows the `torch.utils.data` module imported above.
testing_data = pd.read_csv('datasets/testing.csv',skiprows=9)
active = pd.read_csv('datasets/active.csv',skiprows=1)

# Active cases: column 2 from row 7 onward, paired with a 0-based day index.
# Rows whose count is missing are dropped, so ac[:, 0] is the day index and
# ac[:, 1] the active-case count.
av = active.values[7:,2]
r = np.arange(len(av), dtype=float)
dd = np.vstack((r,av)).T
ac = dd[~np.isnan(dd[:,1].astype(float))]


def _daily_counts(column):
    """Return testing-sheet column *column* from row 79 onward as floats, NaN -> 0."""
    return np.nan_to_num(testing_data.values[79:, column].astype(float))


# Daily tests and positives for the rapid PCR, rapid antigen and state streams.
rapid_pcr_tests = _daily_counts(2)
rapid_pcr_pos = _daily_counts(4)
rapid_antigen_tests = _daily_counts(3)
rapid_antigen_pos = _daily_counts(5)
state_tests = _daily_counts(6)
state_pos = _daily_counts(8)

total_tests = rapid_pcr_tests + rapid_antigen_tests + state_tests
total_pos = rapid_pcr_pos + rapid_antigen_pos + state_pos
# The likelihood compares the model against cumulative counts.
cum_tests = np.cumsum(total_tests)
cum_pos = np.cumsum(total_pos)
# Number of observed days used in the fit (104).  The meaning of the 194 and
# 90 offsets is not recorded here.
n_data = 194 - 90

test_obs = torch.tensor(cum_tests,dtype=torch.float,device=device)[:n_data]
pos_obs = torch.tensor(cum_pos,dtype=torch.float,device=device)[:n_data]

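Find the minimum and maximum value of each parameter over the training set.  These (slightly padded) bounds define the support of the Beta prior used in the posterior below.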

In [None]:
# Shape parameters of the Beta(alpha_b, beta_b) prior placed on the
# parameters after rescaling them to [0, 1] with X_min / X_max.
alpha_b = 3
beta_b = 3

# Every training parameter vector, stacked one row per run.
X = torch.stack([sample[0] for sample in training_])

# Per-parameter bounds, padded by 1e-3 so that no training point lies exactly
# on the edge of the prior's support (where log(X_bar) or log(1 - X_bar)
# would be -inf).
X_min, X_max = (
    torch.tensor(bound, dtype=torch.float32, device=device)
    for bound in (
        X.cpu().numpy().min(axis=0) - 1e-3,
        X.cpu().numpy().max(axis=0) + 1e-3,
    )
)

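Define the negative log-posterior (the potential `V`) used for MCMC sampling with the surrogate, plus an empirical version (`V_empirical`) that scores given model output directly.  The likelihood is a Student-t whose scale grows with the square root of time, and the prior is a Beta distribution on the rescaled parameters.  The parameter `alpha` downweights the likelihood, since we know that the model is misspecified.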

In [None]:
# Weight applied to the log-likelihood; alpha < 1 tempers it because the
# ABM is known to be misspecified.
alpha = 0.1
# Base scales of the Student-t likelihood for cumulative tests and cumulative
# positives.  The posterior functions multiply them by sqrt(day), so the
# tolerated misfit grows over time.
sigma_tes = 100
sigma_pos = 10
# Student-t degrees of freedom (nu = 1 gives a Cauchy likelihood).
nu = 1.

# Day index 1..n_data for each observation, used to scale the noise.
time = torch.linspace(1,n_data,n_data,device=device)
def V(X):
    """Return the surrogate's negative log-posterior (potential) at parameters X.

    The emulator is trained on log2(1 + counts), so its output is mapped back
    to counts and then scored by V_empirical.  That keeps a single copy of the
    likelihood/prior formula instead of two hand-maintained duplicates.

    Parameters
    ----------
    X : torch.Tensor
        Z-normalized parameter vector fed to the emulator `model`.

    Returns
    -------
    torch.Tensor
        Scalar -(alpha * log-likelihood + log-prior).
    """
    return V_empirical(X, 2**model(X) - 1.)

def _student_t_log_likelihood(residual, scale):
    """Return the summed Student-t(nu) log-density of *residual* with per-point *scale*.

    The normalizing constant uses gammaln, which unlike log(gamma(.)) cannot
    overflow for large nu.
    """
    from scipy.special import gammaln

    log_norm = float(gammaln((nu + 1) / 2.) - gammaln(nu / 2.) - 0.5 * np.log(np.pi * nu))
    return torch.sum(
        log_norm
        - torch.log(scale)
        - (nu + 1) / 2. * torch.log(1 + residual**2 / (nu * scale**2))
    )


def V_empirical(X,Y):
    """Return the negative log-posterior of parameters X given model output Y.

    Parameters
    ----------
    X : torch.Tensor
        Z-normalized parameter vector; only used by the prior.
    Y : torch.Tensor
        Predicted counts indexed as Y[day, series].  Column 2 is compared with
        the cumulative tests and column 1 with the cumulative positives over
        the first n_data days.

    Returns
    -------
    torch.Tensor
        Scalar -(alpha * log-likelihood + log-prior).  The likelihood is an
        independent Student-t on each residual with scale sigma * sqrt(day).
        The prior is an (unnormalized) Beta(alpha_b, beta_b) on X rescaled to
        [0, 1] by X_min / X_max; points outside those bounds give NaN.
    """
    test_pred = Y[:n_data, 2]
    pos_pred = Y[:n_data, 1]

    sqrt_t = torch.sqrt(time)
    log_like = (_student_t_log_likelihood(test_pred - test_obs, sqrt_t * sigma_tes)
                + _student_t_log_likelihood(pos_pred - pos_obs, sqrt_t * sigma_pos))

    X_bar = (X - X_min) / (X_max - X_min)
    log_prior = torch.sum((alpha_b - 1) * torch.log(X_bar) + (beta_b - 1) * torch.log(1 - X_bar))

    return -(alpha * log_like + log_prior)

def get_log_like_gradient_and_hessian(V,X,eps=1e-2,compute_hessian=False):
    """Evaluate V at X and, optionally, its gradient and a regularized Hessian.

    Despite the name, the returned ``log_pi`` is simply ``V(X)``.  In this
    notebook V is the *negative* log-posterior.

    Parameters
    ----------
    V : callable
        Scalar-valued function of X, differentiable with torch autograd.
    X : torch.Tensor
        1-D leaf tensor with requires_grad=True.
    eps : float
        Regularization: every Hessian eigenvalue lambda is replaced by
        sqrt(lambda**2 + eps), which makes the result symmetric positive
        definite even where V is not convex.
    compute_hessian : bool
        When False only V(X) is returned.

    Returns
    -------
    log_pi
        If compute_hessian is False.
    (log_pi, g, H, Hinv, log_det_Hinv)
        Otherwise.  g is the gradient, H the regularized Hessian, Hinv its
        inverse and log_det_Hinv = log det(Hinv).
    """
    log_pi = V(X)
    if not compute_hessian:
        return log_pi

    g = torch.autograd.grad(log_pi, X, retain_graph=True, create_graph=True)[0]
    H = torch.stack([torch.autograd.grad(e, X, retain_graph=True)[0] for e in g])
    # The Hessian is symmetric up to round-off: symmetrize it and use the
    # symmetric eigensolver (torch.eig was deprecated in PyTorch 1.9 and has
    # since been removed).
    H = 0.5 * (H + H.T)
    lamda, Q = torch.linalg.eigh(H)
    lamda_prime = torch.sqrt(lamda**2 + eps)
    lamda_prime_inv = 1. / lamda_prime
    H = Q @ torch.diag(lamda_prime) @ Q.T
    Hinv = Q @ torch.diag(lamda_prime_inv) @ Q.T
    log_det_Hinv = torch.sum(torch.log(lamda_prime_inv))
    return log_pi, g, H, Hinv, log_det_Hinv

Initialize the emulator (a neural network), the loss function, and an optimizer.

In [None]:
# Emulator (neural-network surrogate for the ABM), MSE loss and Adam optimizer.
# train_model reads the module-level `criterion` and `optimizer`, so both must
# be rebuilt whenever `model` is replaced; the iterative loop does this.
model = Emulator()
model.to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

This function trains the emulator, giving higher weight to training runs whose outputs closely match the observed data (we care little whether the surrogate works well for poorly fitting parameter sets).

In [None]:
def train_model(model,training_,testing_,epochs=200,batch_size=256,weight_factor=50.):
    """Fit the emulator to the training set with a fidelity-weighted MSE.

    Each training run gets weight proportional to exp(-V_empirical / weight_factor),
    and the weights are normalized to sum to one.  Runs whose ABM output is
    close to the observed data therefore dominate the loss.  The emulator is
    fit in log2(1 + counts) space.

    Parameters
    ----------
    model : torch.nn.Module
        Emulator to train.  NOTE: the update uses the module-level `optimizer`,
        which must wrap this model's parameters.
    training_, testing_ : TensorDataset
        (parameters, outputs) pairs; testing_ is only used for monitoring.
    epochs, batch_size : int
        Training schedule.
    weight_factor : float
        Temperature of the fidelity weights; larger means flatter weights.

    Every 50 epochs the unweighted module-level `criterion` is printed on both
    sets.
    """
    V_vals = torch.stack([V_empirical(x, y.T) for x, y in zip(training_.tensors[0], training_.tensors[1])])
    # softmax(-V/w) == exp(-V/w) / sum(exp(-V/w)), but it cannot underflow to
    # 0/0 = NaN when every V is large.
    run_weights = torch.softmax(-V_vals / weight_factor, dim=0)

    train_loader = torch.utils.data.DataLoader(
        dataset=TensorDataset(training_.tensors[0], training_.tensors[1], run_weights),
        batch_size=batch_size,
        shuffle=True)

    for epoch in range(epochs):
        model.train()
        for d, t, w in train_loader:
            # The emulator predicts log2(1 + counts).
            t = torch.log2(t + 1.)

            optimizer.zero_grad()
            outputs = model(d)

            residual_squared = (outputs - t)**2
            # Broadcast each run's scalar weight over all of its outputs and sum.
            w = w.reshape(-1, *([1] * (residual_squared.dim() - 1)))
            loss = torch.sum(w * residual_squared)

            loss.backward()
            optimizer.step()

        model.eval()

        if epoch % 50 == 0:
            # Monitoring only: no autograd graph over the full data sets.
            with torch.no_grad():
                training_loss = criterion(model(training_.tensors[0]), torch.log2(training_.tensors[1] + 1))
                test_loss = criterion(model(testing_.tensors[0]), torch.log2(testing_.tensors[1] + 1))
            print('finished epoch {}'.format(epoch))
            print('loss', training_loss.item(), test_loss.item())
            print('*'*50)

    print('finished training')

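Define the MCMC sampling procedure: a Metropolis–Hastings sampler whose proposal covariance is preconditioned by a regularized Hessian of the potential at the starting point, with an adaptively tuned step size.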

In [None]:
def draw_sample(mu,cov,eps=1e-10):
    """Draw one sample from N(mu, cov + eps*I).

    The small diagonal jitter keeps the Cholesky factorization defined when
    cov is only positive semi-definite in floating point.  The identity and
    the noise are created with cov's device and dtype rather than the global
    `device`, so float64 inputs and any placement work.

    Parameters
    ----------
    mu : torch.Tensor
        1-D mean vector.
    cov : torch.Tensor
        Square covariance matrix matching mu.
    eps : float
        Diagonal jitter.
    """
    jitter = eps * torch.eye(cov.shape[0], device=cov.device, dtype=cov.dtype)
    L = torch.linalg.cholesky(cov + jitter)
    z = torch.randn(L.shape[0], device=L.device, dtype=L.dtype)
    return mu + L @ z

def get_proposal_likelihood(Y,mu,inverse_cov,log_det_cov):
    """Gaussian log-density of Y under N(mu, cov), up to the -n/2*log(2*pi) constant.

    Parameters
    ----------
    Y, mu : torch.Tensor
        Evaluation point and mean (1-D).
    inverse_cov : torch.Tensor
        Precision matrix (inverse covariance).
    log_det_cov : torch.Tensor or float
        log det(cov).
    """
    diff = Y - mu
    return -0.5 * log_det_cov - 0.5 * diff @ inverse_cov @ diff

def MALA_step(X,h,local_data=None):
    """Take one Metropolis-Hastings step from state X with step size h.

    Parameters
    ----------
    X : torch.Tensor
        Current state, a leaf tensor with requires_grad=True.  On acceptance
        it is updated in place through ``X.data``.
    h : float
        Step size; the proposal is drawn from N(X, 2*h*Hinv).
    local_data : tuple or None
        Cached ``(log_pi, g, H, Hinv, log_det_Hinv)``.  It is computed at X,
        including the regularized Hessian, only when None is passed.  A caller
        that feeds the returned tuple back in therefore keeps the
        preconditioner (H, Hinv) fixed at the first state.

    Returns
    -------
    tuple
        ``(X, local_data, s, log_pi)``.  s is 1 if the proposal was accepted
        and 0 otherwise.  log_pi is local_data[0], the value of V (the negative
        log-posterior, despite the name) at the current state.

    NOTE(review): the proposal mean is X itself, with no gradient drift term.
    This is therefore a Hessian-preconditioned random-walk Metropolis step
    rather than MALA proper; g is unpacked but unused here.
    """
    if local_data is not None:
        pass  
    else:
        # First call: evaluate V, its gradient and the regularized Hessian at X.
        local_data = get_log_like_gradient_and_hessian(V,X,compute_hessian=True)
        
    log_pi,g,H,Hinv,log_det_Hinv = local_data
    
    # Propose X_ ~ N(X, 2*h*Hinv) as a fresh leaf tensor.
    X_ = draw_sample(X,2*h*Hinv).detach()
    # NOTE(review): stale commented-out code below; it refers to X_0, not X_.
    #X_0[0] = X_0[0].round()
    #X_0[1] = X_0[1].round()
    #X_0[-1] = X_0[-1].round()
    X_.requires_grad=True
    
    # V at the proposal (no Hessian needed).
    log_pi_ = get_log_like_gradient_and_hessian(V,X_,compute_hessian=False)

    # Forward and reverse proposal log-densities.  Both use the same precision
    # H/(2*h), and the quadratic form is unchanged when its two points are
    # swapped, so logq == logq_ and they cancel in log_alpha.  (log_det_Hinv
    # is not log det(2*h*Hinv), but that constant offset cancels as well.)
    logq = get_proposal_likelihood(X_,X,H/(2*h),log_det_Hinv)
    logq_ = get_proposal_likelihood(X,X_,H/(2*h),log_det_Hinv)

    # Log acceptance ratio log pi(X_) - log pi(X) plus proposal terms.  The
    # signs are flipped because V returns the negative log-posterior.
    log_alpha = (-log_pi_ + logq_ + log_pi - logq)
    # Acceptance probability min(1, exp(log_alpha)).  This local `alpha` is
    # unrelated to the global likelihood weight `alpha` read by V.
    alpha = torch.exp(min(log_alpha,torch.tensor([0.],device=device)))
    u = torch.rand(1,device=device)
    # A NaN log_alpha (e.g. a proposal outside [X_min, X_max], where the Beta
    # prior's log is NaN) makes `u <= alpha` False, so the proposal is rejected.
    if u <= alpha and log_alpha!=np.inf:
        # Accept: rebind X's underlying data to the proposal's.
        X.data = X_.data
        # Re-evaluate V at the new state and write it into the cached log_pi
        # in place.  g, H, Hinv and log_det_Hinv keep their initial values.
        log_pi_new = get_log_like_gradient_and_hessian(V,X,compute_hessian=False)
        local_data[0].data = log_pi_new.data# = get_log_like_gradient_and_hessian(V,X,compute_hessian=True)
        s = 1
    else:
        s = 0
    return X,local_data,s,local_data[0]

def MALA(X,n_iters=10001,h=0.001,h_max=1.0,acc_target=0.25,k=0.01,beta=0.99,sample_path='./samples/',model_index=0,save_interval=1000,print_interval=50):
    """Run n_iters sampler steps (MALA_step) from X with an adaptive step size.

    After every step the acceptance indicator is folded into an exponential
    moving average (weight beta).  The step size h is then multiplied by
    (1 + k) if that average is above acc_target, or by (1 - k) if it is
    below, and capped at h_max.

    sample_path and save_interval are currently unused (periodic saving is
    disabled); they are kept for backward compatibility.

    Returns
    -------
    X_posterior : torch.Tensor
        One row per iteration: the state after that step.
    pi_posterior : torch.Tensor
        V (the negative log-posterior) at each of those states.
    """
    print('***********************************************')
    print('***********************************************')
    print('Running Metropolis-Adjusted Langevin Algorithm for model index {0}'.format(model_index))
    print('***********************************************')
    print('***********************************************')
    local_data = None
    samples = []
    log_pis = []
    acc = acc_target
    for i in range(n_iters):
        X, local_data, s, log_pi = MALA_step(X, h, local_data=local_data)
        # Detached snapshots.  MALA_step moves X via `X.data = ...` (rebinding
        # its storage) rather than writing into it, so earlier snapshots are
        # expected to keep their values.
        samples.append(X.detach())
        log_pis.append(log_pi.detach())
        acc = beta*acc + (1-beta)*s
        h = min(h*(1+k*np.sign(acc - acc_target)), h_max)
        if i % print_interval == 0:
            params = X.data.cpu().numpy()
            print('===============================================')
            print('sample: {0:d}, acc. rate: {1:4.2f}, log(P): {2:6.1f}'.format(i,acc,local_data[0].item()))
            # One format slot per parameter (previously hard-coded to 26), so
            # the printout is complete for any parameter dimension.
            print('curr. m: '+('{:.4f} '*len(params)).format(*params))
            print('===============================================')
    X_posterior = torch.stack(samples)
    pi_posterior = torch.stack(log_pis)
    return X_posterior, pi_posterior

Define a function that runs the ABM for given parameter values.

In [None]:
def run_abm(run_input):
    """Run the ABM for one parameter vector.

    run_input is a (run_index, parameters) pair, e.g. one item of
    enumerate(samples).  Parameters are matched to `keys` by position, and
    those listed in `keys_to_round` are rounded to ints.  Quarantining, social
    distancing and contact tracing are all switched on.  The disease model's
    multiple_runs is called with 5 and the notebook-level `recorder`.

    NOTE(review): when called through joblib with n_jobs > 1, each worker
    presumably receives its own copy of `recorder`; confirm that it writes
    results directly to its HDF5 file.
    """
    run_index, parameters = run_input
    sample = dict(zip(keys, parameters))
    sample.update({name: int(np.round(sample[name])) for name in keys_to_round})
    sample.update(
        scenario_name='trial_' + str(run_index),
        quarantining=1,
        social_distancing=1,
        contact_tracing=1,
    )
    pan.Disease(sample).multiple_runs(5, recorder)

Iterate between training the surrogate, running MCMC, and running the ABM with samples from the MCMC sampler to augment the training set.

In [None]:
# Number of ABM runs to perform in each iteration
n_abm_samples = 20

for i in range(0,20):
    print('iteration '+str(i))
    # Create a new model at each iteration (training is cheap, and this avoids
    # local minima).  train_model uses the global `criterion` and `optimizer`,
    # so both are rebuilt alongside the model.
    model = Emulator()
    model.to(device)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Train the emulator with the existing training data
    train_model(model,training_,testing_,epochs=500)
    torch.save(model.state_dict(),'iterative_models/trained_{:03d}.h5'.format(i))

    # Evaluate the surrogate on the training parameters and start MCMC from the
    # one whose prediction is closest to the observations (lowest V).  Only the
    # ranking is needed, so no autograd graph is built.
    with torch.no_grad():
        objs = torch.stack([V(X_) for X_ in X])
    # Fresh float32 leaf copy of the best row on `device`.
    X_0 = X[torch.argmin(objs)].detach().clone().to(device=device, dtype=torch.float).requires_grad_(True)

    # Draw 5001 MCMC samples; the first 1000 are treated as burn-in below.
    X_posterior,pi_posterior = MALA(X_0,n_iters=5001,model_index=i,save_interval=1000,print_interval=100)

    # Rescale parameters (the neural net expects them to be z-normalized, but the ABM doesn't)
    X_p_numpy = (X_posterior*X_std + X_mean).detach().cpu().numpy().astype(np.float64)
    with open('mcmc_samples/sample_{:03d}.p'.format(i),'wb') as sample_file:
        pickle.dump(X_p_numpy,sample_file)

    # Initialize a recorder for the ABM
    recorder = pan.analysis.recorder(['tests_performed_total', 'positive_tests_total', 'active_cases'],
                                     'abm_results/runs_{:03d}.hdf5'.format(i))

    # Draw a few random post-burn-in samples from the MCMC results
    samples = X_p_numpy[np.random.choice(range(1000,X_p_numpy.shape[0]),n_abm_samples,replace=False)]

    # Run the ABM on these samples in parallel, then augment the training set with the new ABM runs
    try:
        Parallel(n_jobs=4)(delayed(run_abm)(x) for x in enumerate(samples))
        newset = hdf5Dataset_init('abm_results/runs_{:03d}.hdf5'.format(i))
        train_new, test_new = newset.split_datasets(.1,split_type='parameter')
        train_new_,test_new_,_,_ = normalize_sets(train_new,test_new,X_mean=X_mean,X_std=X_std)
        training_ = TensorDataset(torch.cat((training_.tensors[0],train_new_.tensors[0])),torch.cat((training_.tensors[1],train_new_.tensors[1])))
        testing_ = TensorDataset(torch.cat((testing_.tensors[0],test_new_.tensors[0])),torch.cat((testing_.tensors[1],test_new_.tensors[1])))

        # Recompute the prior bounds over the augmented training set.
        X = torch.stack([t[0] for t in training_])
        X_min = X.cpu().numpy().min(axis=0)-1e-3
        X_max = X.cpu().numpy().max(axis=0)+1e-3
        X_min = torch.tensor(X_min,dtype=torch.float32,device=device)
        X_max = torch.tensor(X_max,dtype=torch.float32,device=device)
    except KeyError as err:
        # Best effort: this iteration's runs are not added, but report it
        # instead of failing silently.
        print('iteration {}: new ABM runs not added to the training set ({!r})'.format(i, err))

Load the model that was trained most recently.

In [None]:
# Reload the weights saved in the last (i = 19) iteration.  map_location lets
# a checkpoint written on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load('iterative_models/trained_019.h5', map_location=device))

Plot the surrogate's predictions for the MCMC samples, the best-fitting training run, and the observed data.

In [None]:
fig,axs = plt.subplots(nrows=3,figsize=(12,12))

# This cell only evaluates the surrogate, so skip autograd bookkeeping.
with torch.no_grad():
    # Lowest-potential (highest posterior) state visited by the sampler
    X_star = X_posterior[np.argmin(pi_posterior.cpu().numpy())]
    pred_star = 2**model(X_star) - 1

    # Potential (negative log-posterior) of each training example
    V_vals = torch.stack([V_empirical(x,y.T) for x,y in zip(training_.tensors[0],training_.tensors[1])])

    # Plot predictions for every 20th MCMC state over the whole chain
    for i in range(0, X_posterior.shape[0], 20):
        pred_opt = 2**model(X_posterior[i]) - 1.
        axs[0].plot(pred_opt[:n_data,2].detach().cpu().numpy(),'k-',alpha=0.02,rasterized=True)
        axs[1].plot(pred_opt[:n_data,1].detach().cpu().numpy(),'k-',alpha=0.02,rasterized=True)
        axs[2].plot(pred_opt[:n_data,0].detach().cpu().numpy(),'k-',alpha=0.02,rasterized=True)

    # Plot the data and the best surrogate prediction
    axs[0].plot(test_obs[:n_data].detach().cpu().numpy(),'r:',lw=3.0)
    axs[1].plot(pos_obs[:n_data].detach().cpu().numpy(),'r:',lw=3.0)
    axs[0].plot(pred_star[:n_data,2].detach().cpu().numpy(),'b--',lw=3.0)
    axs[1].plot(pred_star[:n_data,1].detach().cpu().numpy(),'b--',lw=3.0)
    axs[2].plot(pred_star[:n_data,0].detach().cpu().numpy(),'b--',lw=3.0)

    # Plot the best fitting (n) training example(s): the ABM output (solid)
    # and the surrogate's prediction for the same parameters (dashed).
    for idx in torch.sort(V_vals)[1][0:1]:
        best_train = training_.tensors[1][idx].detach().cpu().numpy()
        # Map back from log2(1 + y) to counts, consistent with pred_star above.
        pred_hat = (2**model(training_.tensors[0][idx]) - 1).detach().cpu().numpy()
        axs[0].plot(best_train[2,:n_data],'g-',alpha=0.5)
        axs[1].plot(best_train[1,:n_data],'g-',alpha=0.5)
        axs[2].plot(best_train[0,:n_data],'g-',alpha=0.5)
        axs[0].plot(pred_hat[:n_data,2],'g--',alpha=0.5)
        axs[1].plot(pred_hat[:n_data,1],'g--',alpha=0.5)
        axs[2].plot(pred_hat[:n_data,0],'g--',alpha=0.5)

axs[2].plot(ac[:,0],ac[:,1],'r:',lw=3.0)
axs[2].set_xlim(0,n_data)
axs[1].set_xlim(0,n_data)
axs[0].set_xlim(0,n_data)
axs[2].set_ylim(0,1500)

axs[0].set_ylabel('# Tests')
axs[1].set_ylabel('# Positive')
axs[2].set_ylabel('Active Cases')
axs[2].set_xlabel('Days since 08/19/2020')

fig.savefig('Covid_surrogate_ModelvsObs.pdf')

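Print two dictionaries: the posterior-mean parameters from the MCMC samples, and the parameters of the best-fitting training run (lowest potential among the training examples).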

In [None]:
# Parameter names come from the `keys` list defined with the imports; it is
# not redefined here, to keep a single source of truth.
X_p_numpy = (X_posterior*X_std + X_mean).detach().cpu().numpy().astype(np.float64)
# Best-fitting training example (lowest V_vals), computed explicitly rather
# than relying on the loop variable `idx` left over from the plotting cell.
idx = torch.argmin(V_vals)
X_ = training_.tensors[0][idx]
X__numpy = (X_*X_std + X_mean).detach().cpu().numpy()

mean_params = dict(zip(keys,X_p_numpy.mean(axis=0)))
map_train = dict(zip(keys,X__numpy))

import pprint
pprint.pprint(mean_params)
pprint.pprint(map_train)