In [22]:
import pandas as pd
import numpy as np
import pyro
import torch
from simple_model.SimpleModel import cols, SimpleModel
from pyro.infer import NUTS, MCMC
from max_likelihood.utils.ObservationalDataset import ObservationalDataset
from torch.utils.data.sampler import SubsetRandomSampler

In [19]:
increment_factor = (20,20)
path = "../model-data/simple-observational-data-40000-0.csv"
simulator_model = SimpleModel(learn_initial_state=False, increment_factor=increment_factor,
                                  phy_pol=True, rho=None)
nuts_kernel = NUTS(simulator_model.model, jit_compile=False)
observational_dataset = ObservationalDataset(path, columns=cols)
validation_split = 0.20
test_split = 0.20
shuffle_dataset = True
dataset_size = len(observational_dataset)
indices = list(range(dataset_size))
split_val = int(np.floor(validation_split * dataset_size))
split_test = int(np.floor(test_split * dataset_size))
if shuffle_dataset:
    np.random.shuffle(indices)
train_indices, val_indices, test_indices = indices[split_val + split_test:], indices[:split_val], indices[
                                                                                                  split_val:split_val + split_test]
# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
test_sampler = SubsetRandomSampler(test_indices)

In [23]:
train_loader = torch.utils.data.DataLoader(observational_dataset, batch_size=16, sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(observational_dataset, batch_size=16, sampler=valid_sampler)
test_loader = torch.utils.data.DataLoader(observational_dataset, batch_size=16, sampler=test_sampler)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=1000, num_chains=1)
for x in test_loader:
    mcmc.run(x.float())
    break
hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}
for site, values in hmc_samples.items():
    print("Site: {}".format(site))
    print(values, "\n")

Sample: 100%|██████████| 2000/2000 [07:01,  4.74it/s, step size=6.54e-02, acc. prob=0.837]

Site: s_0
[[[ 66.610275 110.9284  ]
  [ 69.52102  107.605354]
  [ 76.731514 114.0084  ]
  ...
  [ 69.534676 111.44975 ]
  [ 66.88491   98.42541 ]
  [ 64.25114  100.04121 ]]

 [[ 67.5177   106.00569 ]
  [ 70.65374  106.590256]
  [ 75.662926 117.38046 ]
  ...
  [ 69.138306 105.85377 ]
  [ 68.19785  101.15019 ]
  [ 63.53312   97.110504]]

 [[ 66.90245  105.22755 ]
  [ 69.573364 111.33925 ]
  [ 77.12877  113.67287 ]
  ...
  [ 70.02957  111.44688 ]
  [ 69.11413  101.95673 ]
  [ 62.533054 105.00507 ]]

 ...

 [[ 67.08204  111.70907 ]
  [ 70.150345 110.684616]
  [ 77.02178  116.05847 ]
  ...
  [ 68.486595 114.34624 ]
  [ 69.6109   108.9204  ]
  [ 63.07623   99.313446]]

 [[ 66.6338   104.19663 ]
  [ 72.08465  109.65518 ]
  [ 76.222206 109.57951 ]
  ...
  [ 68.3846   111.802734]
  [ 68.13355  101.62384 ]
  [ 64.59784  108.34359 ]]

 [[ 68.07424  110.644615]
  [ 69.41639  106.81993 ]
  [ 75.43365  118.14082 ]
  ...
  [ 67.72654  102.89365 ]
  [ 68.46876  110.482315]
  [ 62.896885 112.209755]]] 


