In [1]:
import numpy as np
import pandas as pd
import torch
from pyabc import Distribution, RV
from sbi.utils import BoxUniform

  mod = _original_import(name, globals, locals, fromlist, level)


In [2]:
from src.utils import set_seed
from src.inference import SBIEngine
from models.epidemic_models import simulate_seir  

In [3]:
# 1. Global setup: fix the random seed via the project helper before anything
#    stochastic runs, then choose the torch device.
SEED = 0
set_seed(SEED)

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

[#] Seed has been fixed to: 0
Using device: cpu


In [4]:
# 2a. Prior bounds (lower / upper) for the three parameters, as float64 numpy
#     arrays. Same numbers as the sbi BoxUniform prior in the next cell.
low_linear = np.full(3, 0.01)
high_linear = np.array([1.5] + [0.5] * 2)

In [5]:
# 2b. sbi prior: a BoxUniform over the three parameters, built from the same
#     `low_linear` / `high_linear` arrays defined in the previous cell so the
#     bounds live in exactly one place and cannot drift apart.
# NOTE(review): an earlier comment mentioned a second, pyabc-compatible prior,
# but none is defined here (pyabc's Distribution/RV are imported but unused).
# TODO: add it or drop the pyabc import.
sbi_prior = BoxUniform(
    low=torch.as_tensor(low_linear, dtype=torch.float32),
    high=torch.as_tensor(high_linear, dtype=torch.float32),
)

In [7]:
import pickle

In [8]:
# Load the pre-simulated Model 1 training set: simulator outputs and the
# parameters that generated them. These are local project files. Never point
# pickle.load at untrusted input, because unpickling can execute code.
def _load_pickle(path):
    """Read and return a single pickled object from `path`."""
    with open(path, "rb") as fh:
        return pickle.load(fh)

xs_train = _load_pickle("./data/M1_dataset.pkl")
thetas_train = _load_pickle("./data/M1_params.pkl")

In [9]:
# Observed ("true") Model 1 data. Entry [0]['poisson'] is used as obs_data by
# the NPE runs below.
with open('./data/model1.pkl', mode='rb') as fh:
    true_dataset = pickle.load(fh)

In [10]:
import src.utils
import time

In [11]:
# Add Poisson observation noise to the training simulations.
# Keep an un-noised reference so re-running this cell re-noises the clean data
# instead of stacking a second round of noise on an already-noisy xs_train.
# NOTE(review): if add_poisson_noise mutates its input in place, the clean
# reference would be affected too. Confirm it returns a new array.
if "xs_train_clean" not in globals():
    xs_train_clean = xs_train
xs_train = src.utils.add_poisson_noise(xs_train_clean)

In [12]:
# Engine for the first series of runs: 'maf' density estimator
# (trained below with use_lstm=True; results saved under NPE_LSTM/).
engine = SBIEngine(
    density_estimator='maf',
)

In [14]:
# Train NPE (MAF engine, use_lstm=True) on the first 1,000 simulations and
# report wall-clock time.
n_sims = 1000
t0 = time.time()
npe_post_1k, npe_samples_1k = engine.run_npe(
    obs_data=true_dataset[0]['poisson'],
    prior=sbi_prior,
    thetas=torch.tensor(thetas_train[:n_sims], dtype=torch.float32),
    xs=torch.tensor(xs_train[:n_sims], dtype=torch.float32),
    use_lstm=True,
    batch_size=64,
)
elapsed = time.time() - t0
print(f"Done in {elapsed:.2f} seconds")

  thetas = torch.tensor(thetas_train[:1000],dtype=torch.float32),


[*] Running NPE (use_lstm=True) with batch size 64...
 Neural network successfully converged after 106 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 106
        Best validation performance: -3.3468
        -------------------------
        


10740it [00:00, 138712.74it/s]           

Done in 97.97 seconds





In [15]:
# Persist the 1k MAF/LSTM posterior so it can be reloaded without retraining.
maf_1k_path = "./results/Model1/NPE_LSTM/M1_1k_maf_posterior.pkl"
with open(maf_1k_path, "wb") as out_file:
    pickle.dump(npe_post_1k, out_file)

In [16]:
# Train NPE (MAF engine, use_lstm=True) on the first 10,000 simulations and
# report wall-clock time.
start_time = time.time()
npe_post_10k, npe_samples_10k = engine.run_npe(
    obs_data=true_dataset[0]['poisson'],
    prior=sbi_prior,
    thetas=torch.tensor(thetas_train[:10000], dtype=torch.float32),
    xs=torch.tensor(xs_train[:10000], dtype=torch.float32),
    use_lstm=True,
    batch_size=128,
)
end_time = time.time()
print(f"Done in {end_time - start_time:.2f} seconds")

# Persist the 10k posterior under the same naming scheme as the 1k and 100k
# MAF runs. Previously this posterior was never written to disk, so it had to
# be retrained after every kernel restart.
with open("./results/Model1/NPE_LSTM/M1_10k_maf_posterior.pkl", "wb") as handle:
    pickle.dump(npe_post_10k, handle)

[*] Running NPE (use_lstm=True) with batch size 128...


  thetas = torch.tensor(thetas_train[:10000],dtype=torch.float32),


 Neural network successfully converged after 201 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 201
        Best validation performance: -6.0694
        -------------------------
        


10039it [00:00, 164954.51it/s]           

Done in 1682.96 seconds





In [None]:
# Train NPE (MAF engine, use_lstm=True) on the first 100,000 simulations and
# report wall-clock time.
n_sims = 100000
t0 = time.time()
npe_post_100k, npe_samples_100k = engine.run_npe(
    obs_data=true_dataset[0]['poisson'],
    prior=sbi_prior,
    thetas=torch.tensor(thetas_train[:n_sims], dtype=torch.float32),
    xs=torch.tensor(xs_train[:n_sims], dtype=torch.float32),
    use_lstm=True,
    batch_size=256,
)
elapsed = time.time() - t0
print(f"Done in {elapsed:.2f} seconds")

[*] Running NPE (use_lstm=True) with batch size 256...


  thetas = torch.tensor(thetas_train[:100000],dtype=torch.float32),


 Training neural network. Epochs trained: 1

In [None]:
# Persist the 100k MAF/LSTM posterior so it can be reloaded without retraining.
maf_100k_path = "./results/Model1/NPE_LSTM/M1_100k_maf_posterior.pkl"
with open(maf_100k_path, "wb") as out_file:
    pickle.dump(npe_post_100k, out_file)

In [None]:
# Rebind `engine` for the second series of runs: 'nsf' density estimator
# (trained below with use_lstm=False; results saved under NPE/).
engine = SBIEngine(
    density_estimator='nsf',
)

In [None]:
# Train NPE (NSF engine, use_lstm=False) on the first 1,000 simulations and
# report wall-clock time.
# Samples get an `_nsf` suffix. The previous code unpacked them into
# `npe_samples_1k`, silently overwriting the MAF/LSTM samples from the
# earlier 1k run.
start_time = time.time()
npe_post_1k_nsf, npe_samples_1k_nsf = engine.run_npe(
    obs_data=true_dataset[0]['poisson'],
    prior=sbi_prior,
    thetas=torch.tensor(thetas_train[:1000], dtype=torch.float32),
    xs=torch.tensor(xs_train[:1000], dtype=torch.float32),
    use_lstm=False,
    batch_size=64,
)
end_time = time.time()
print(f"Done in {end_time - start_time:.2f} seconds")

In [None]:
# Persist the 1k NSF posterior (no LSTM) so it can be reloaded without retraining.
nsf_1k_path = "./results/Model1/NPE/M1_1k_nsf_posterior.pkl"
with open(nsf_1k_path, "wb") as out_file:
    pickle.dump(npe_post_1k_nsf, out_file)

In [18]:
# Train NPE (NSF engine, use_lstm=False) on the first 10,000 simulations and
# report wall-clock time.
# Samples get an `_nsf` suffix. The previous code unpacked them into
# `npe_samples_10k`, silently overwriting the MAF/LSTM samples from the
# earlier 10k run.
start_time = time.time()
npe_post_10k_nsf, npe_samples_10k_nsf = engine.run_npe(
    obs_data=true_dataset[0]['poisson'],
    prior=sbi_prior,
    thetas=torch.tensor(thetas_train[:10000], dtype=torch.float32),
    xs=torch.tensor(xs_train[:10000], dtype=torch.float32),
    use_lstm=False,
    batch_size=64,
)
end_time = time.time()
print(f"Done in {end_time - start_time:.2f} seconds")

  thetas = torch.tensor(thetas_train[:10000],dtype=torch.float32),


[*] Running NPE (use_lstm=False) with batch size 64...
 Neural network successfully converged after 212 epochs.
        -------------------------
        ||||| ROUND 1 STATS |||||:
        -------------------------
        Epochs trained: 212
        Best validation performance: -7.5712
        -------------------------
        


100%|██████████| 10000/10000 [00:00<00:00, 172547.59it/s]

Done in 288.64 seconds





In [None]:
# Persist the 10k NSF posterior (no LSTM) so it can be reloaded without retraining.
nsf_10k_path = "./results/Model1/NPE/M1_10k_nsf_posterior.pkl"
with open(nsf_10k_path, "wb") as out_file:
    pickle.dump(npe_post_10k_nsf, out_file)

In [None]:
# Train NPE (NSF engine, use_lstm=False) on the first 100,000 simulations and
# report wall-clock time.
# Samples get an `_nsf` suffix. The previous code unpacked them into
# `npe_samples_100k`, silently overwriting the MAF/LSTM samples from the
# earlier 100k run.
# NOTE(review): batch_size=64 here vs 256 for the 100k MAF run. Confirm the
# difference is intentional.
start_time = time.time()
npe_post_100k_nsf, npe_samples_100k_nsf = engine.run_npe(
    obs_data=true_dataset[0]['poisson'],
    prior=sbi_prior,
    thetas=torch.tensor(thetas_train[:100000], dtype=torch.float32),
    xs=torch.tensor(xs_train[:100000], dtype=torch.float32),
    use_lstm=False,
    batch_size=64,
)
end_time = time.time()
print(f"Done in {end_time - start_time:.2f} seconds")

[*] Running NPE (use_lstm=False) with batch size 64...


  thetas = torch.tensor(thetas_train[:100000],dtype=torch.float32),


 Training neural network. Epochs trained: 10

In [None]:
# Persist the 100k NSF posterior (no LSTM) so it can be reloaded without retraining.
nsf_100k_path = "./results/Model1/NPE/M1_100k_nsf_posterior.pkl"
with open(nsf_100k_path, "wb") as out_file:
    pickle.dump(npe_post_100k_nsf, out_file)