# Example notebook

This example is based on the aspirin test case. The code here is mostly adapted from the existing test infrastructure.

Here is the test code, copied from the `deepppl/deepppl/tests/inference` directory of the deepppl repository:

```
from .harness import MCMCTest
import numpy as np
from pprint import pprint

# Warning: Generated quantities does not work with numpyro on this example
# (In place mutation of array)

def test_aspirin():
    """Run the aspirin model through the MCMC test harness and return its result.

    The data dict supplies the observations `y` and `s`, their count `N`,
    prior hyperparameters derived from `y` (`mu_loc`, `mu_scale`,
    `tau_scale`), and a fixed `tau_df`. What `MCMCTest.run()` returns is
    defined in `harness.py`, which is not shown here.
    """
    data = {}
    data['y'] = [2.77, 2.50, 1.84, 2.56, 2.31, -1.15]
    data['s'] = [1.65, 1.31, 2.34, 1.67, 1.98, 0.90]
    data['N'] = len(data['y'])
    data['mu_loc'] = np.mean(data['y'])
    data['mu_scale'] = 5 * np.std(data['y'])
    data['tau_scale'] = 2.5 * np.std(data['y'])
    data['tau_df'] = 4
     
    t_aspirin = MCMCTest(
        name='aspirin',
        model_file='deepppl/tests/good/aspirin.stan',
        data=data
    )

    return t_aspirin.run()
    
if __name__ == "__main__":
    pprint(test_aspirin())
```

The file `harness.py` defines the data class that controls the test (`MCMCTest`). That class compiles the Stan model, runs it through Pyro, and compares the results with the vanilla Stan implementation.

The cells below walk through the same workflow step by step, based on what `harness` implements.



## Set up the data

In [None]:
import numpy as np

# Observations for the aspirin model, passed to each backend as keyword arguments.
data = dict(
    y=[2.77, 2.50, 1.84, 2.56, 2.31, -1.15],
    s=[1.65, 1.31, 2.34, 1.67, 1.98, 0.90],
)
# Derived entries: the number of observations and prior hyperparameters scaled
# from the mean / (population, ddof=0) standard deviation of `y`.
data.update(
    N=len(data['y']),
    mu_loc=np.mean(data['y']),
    mu_scale=5 * np.std(data['y']),
    tau_scale=2.5 * np.std(data['y']),
    tau_df=4,
)

## Set up the model

In [None]:
# Configuration

from dataclasses import dataclass, field
import time
@dataclass
class Config:
    """Sampler settings shared by the Pyro and NumPyro runs in this notebook.

    The cells read these defaults directly off the class (``Config.file``,
    ``Config.iterations``, ...) rather than creating an instance.
    """

    iterations: int = field(default=1000)
    warmups: int = field(default=10)
    chains: int = field(default=4)
    thin: int = field(default=2)
    # Relative path to the Stan program, resolved against the working directory.
    file: str = field(default="../deepppl/deepppl/tests/good/aspirin.stan")
    
@dataclass
class TimeIt:
    """Context manager that reports the wall-clock time of its ``with`` body.

    On exit it prints ``"<name> took <seconds>s to complete."``, measured
    with ``time.perf_counter``. The measured duration is also stored on
    ``elapsed``, so ``with TimeIt('label') as t: ...`` lets you read
    ``t.elapsed`` afterwards.
    """

    name: str

    def __enter__(self):
        self.start = time.perf_counter()
        # Return the timer itself so `with TimeIt(...) as t:` binds it (was None).
        return self

    def __exit__(self, *exc_info):
        self.elapsed = time.perf_counter() - self.start
        # Report even when the body raised; returning None lets any exception propagate.
        print(f"{self.name} took {self.elapsed}s to complete.")

import numpyro
import jax

# Request use the GPU
numpyro.set_platform("gpu")
print(f"jax version: {jax.__version__}")
print(f"numpyro version: {numpyro.__version__}")

# The introspection APIs below differ between JAX releases
# (`jax.config.FLAGS` and `jax.lib.xla_bridge` are missing from some versions),
# so fall back to their replacements instead of failing the whole cell.
try:
    backend_target = jax.config.FLAGS.jax_backend_target
except AttributeError:
    backend_target = getattr(jax.config, "jax_backend_target", "unavailable")
print(f"jax target backend: {backend_target}")

if hasattr(jax, "default_backend"):
    # Same value as `get_backend().platform`, via the public API.
    backend_platform = jax.default_backend()
else:
    backend_platform = jax.lib.xla_bridge.get_backend().platform
print(f"jax target device: {backend_platform}")

from deepppl import PyroModel, NumPyroModel

In [None]:
# Display the Stan program used by both the Pyro and NumPyro runs below.
with open(Config.file) as stan_file:
    stan_source = stan_file.read()
print(stan_source)

## Pyro run

In [None]:
# Time constructing the Pyro model object from the Stan file in Config.
with TimeIt(name='Pyro model obj creation'):
    pyro_model = PyroModel(model_file=Config.file)

In [None]:
# Time building the Pyro MCMC runner from the Config sampler settings.
with TimeIt(name='Pyro model configuration'):
    pyro_mcmc = pyro_model.mcmc(
        Config.iterations,
        Config.warmups,
        num_chains=Config.chains,
        thin=Config.thin,
    )

In [None]:
# Time sampling with Pyro; each data entry is passed as a keyword argument.
with TimeIt(name='Pyro model run'):
    pyro_mcmc.run(**data)

In [None]:
# Time retrieving the posterior samples from the Pyro run.
with TimeIt(name='Pyro model get samples'):
    pyro_samples = pyro_mcmc.get_samples()

## Numpyro run

In [None]:
# Time constructing the NumPyro model object from the same Stan file.
with TimeIt(name='Numpyro model obj creation'):
    numpyro_model = NumPyroModel(model_file=Config.file)

In [None]:
# Time building the NumPyro MCMC runner with the same Config settings as Pyro.
with TimeIt(name='Numpyro model configuration'):
    numpyro_mcmc = numpyro_model.mcmc(
        Config.iterations,
        Config.warmups,
        num_chains=Config.chains,
        thin=Config.thin,
    )

In [None]:
# Profile one NumPyro sampling run with cProfile; the next cell prints the stats.
import cProfile

p = cProfile.Profile()
p.enable()
try:
    numpyro_mcmc.run(**data)
finally:
    # Actually call disable() (the original referenced `p.disable` without
    # calling it), and do so even if the run fails, so later notebook
    # activity is not recorded in the profile.
    p.disable()


In [None]:
# Report the profile collected in the previous cell, most expensive
# functions by internal time first.
import pstats
import io

stats = pstats.Stats(p)
stats.sort_stats(pstats.SortKey.TIME)
stats.print_stats()

In [None]:
# Time another NumPyro sampling run, this time measured by TimeIt instead of cProfile.
with TimeIt(name='Numpyro model run'):
    numpyro_mcmc.run(**data)

In [None]:
# Time retrieving the posterior samples from the NumPyro run.
with TimeIt(name='Numpyro model get samples'):
    numpyro_samples = numpyro_mcmc.get_samples()