# **Example**
---
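Fit a Stan model to a simulated emission-line spectrum (a Gaussian line on a linear continuum, observed with Poisson noise) using `cmdstanpy` on Google Colab.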

In [None]:
# Load packages used in this notebook
import os
import json
import shutil
import urllib.request
import pandas as pd
import numpy as np

# Mount Google Drive so the Stan program stored under /content/drive/MyDrive
# can be read further down; force_remount refreshes a stale mount on re-run.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive/'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

In [None]:
# Install package CmdStanPy
!pip install --upgrade cmdstanpy

In [None]:
# Install a pre-built CmdStan binary for Colab
# (much faster than compiling from source via cmdstanpy's install_cmdstan()).
CMDSTAN_VERSION = '2.36.0'  # single source of truth for file name, URL and directory
tgz_file = f'colab-cmdstan-{CMDSTAN_VERSION}.tar.gz'
tgz_url = ('https://github.com/stan-dev/cmdstan/releases/download/'
           f'v{CMDSTAN_VERSION}/colab-cmdstan-{CMDSTAN_VERSION}.tgz')
# The archive unpacks into ./cmdstan-<version>; use an absolute path so the
# setting stays valid even if the working directory changes later.
cmdstan_dir = os.path.abspath(f'./cmdstan-{CMDSTAN_VERSION}')

# Guard on the extracted directory (what we actually need), not on the archive:
# otherwise a run interrupted after download but before extraction never unpacks.
if not os.path.isdir(cmdstan_dir):
    if not os.path.exists(tgz_file):
        # Download to a temporary name and rename only on success, so an
        # interrupted download cannot leave a truncated archive behind that
        # later runs would mistake for a complete one.
        partial_file = tgz_file + '.part'
        urllib.request.urlretrieve(tgz_url, partial_file)
        os.replace(partial_file, tgz_file)
    shutil.unpack_archive(tgz_file)

# Specify CmdStan location via environment variable
os.environ['CMDSTAN'] = cmdstan_dir
# Check CmdStan path
from cmdstanpy import CmdStanModel, cmdstan_path
cmdstan_path()

In [None]:
# Ground-truth parameters of the synthetic spectrum:
# a Gaussian emission line on top of a linear continuum.
cont_zp = 700.0      # continuum value at wave = 0
cont_slope = 5.0     # continuum slope (counts per wavelength unit)
amplitude = 150.0    # peak height of the line above the continuum
width = 0.5          # Gaussian standard deviation (sigma) of the line
center = 5.0         # line center
SEED = 123           # fixed so the simulated noise, and hence the fit, is reproducible

# Next, a grid of wavelength channels (assumed to have no uncertainty)
wave = np.linspace(0, 10, 100)

# The 'true' (noiseless) observations
flux = amplitude*np.exp(-0.5*np.power(wave-center,2)/width**2) + \
       cont_zp + cont_slope*wave

# The actual observations = true observations + Poisson noise.
# A seeded Generator replaces the unseeded global np.random state so that
# Restart & Run All reproduces the same data every time.
rng = np.random.default_rng(SEED)
obs_flux = rng.poisson(flux)

In [None]:
%matplotlib inline
from matplotlib.pyplot import subplots,plot,step,xlabel,ylabel,show,subplots
fig,ax = subplots(1,1)
ax.plot(wave, flux, 'r-')
ax.step(wave, obs_flux, color='k')
ax.set_xlabel('Wavelength (Angstroms)')
ax.set_ylabel('Counts')

In [None]:
# Path to the Stan program on the mounted Drive (despite the name, this holds
# the file path, not the program text; later cells pass it as `stan_file`).
model_string = '/content/drive/MyDrive/Github_rep/alphaxbio/stan_compute/example.stan'
# Echo the Stan source so the notebook shows exactly which model is fitted.
with open(model_string) as stan_src:
    stan_code = stan_src.read()
print('\n'.join(stan_code.splitlines()))

In [None]:
# Data handed to Stan; the keys presumably mirror the declarations in the
# Stan program's `data` block (not visible here -- verify against example.stan).
idata = {
    'N': len(wave),
    'wave': wave,
    'flux': obs_flux,
}

In [None]:
import cmdstanpy
from cmdstanpy import CmdStanModel, cmdstan_path

# Build the executable for the Stan program at `model_string`.
# force_compile=True rebuilds even when an existing executable is present.
example_model = CmdStanModel(
    stan_file=model_string,
    force_compile=True,
)

In [None]:
# Run MCMC conditioned on the simulated spectrum in `idata` (wave grid + noisy counts).
# The fixed seed makes the posterior draws reproducible across Restart & Run All.
example_fit = example_model.sample(data=idata, seed=123)

In [None]:
# Tabulate posterior summary statistics for the sampled quantities;
# the table is shown as the cell's rich output.
posterior_summary = example_fit.summary()
posterior_summary

In [None]:
def Gauss(x, amp, center, width, cont, slope):
    """Evaluate a Gaussian line profile on top of a linear continuum.

    Parameters
    ----------
    x : float or numpy.ndarray
        Wavelength(s) at which to evaluate the model.
    amp : peak height of the Gaussian above the continuum.
    center : location of the Gaussian peak.
    width : Gaussian standard deviation (sigma).
    cont : continuum value at x = 0.
    slope : continuum slope.

    Returns
    -------
    Model value(s); x and the parameters are combined with NumPy broadcasting.
    """
    exponent = -0.5 * np.power(x - center, 2) / width**2
    line = amp * np.exp(exponent)
    return line + cont + slope * x

# Overlay the fit on the data figure: the model evaluated at the posterior
# medians (blue) plus a +/-3 sd band from the spread of individual posterior draws.
param_names = ('amp', 'center', 'width', 'cont', 'slope')  # same order as Gauss()'s arguments
# Extract each parameter's draws once, instead of calling stan_variable()
# repeatedly inside the per-draw loop.
draws = {name: example_fit.stan_variable(name) for name in param_names}

mamp, mcenter, mwidth, mcont, mslope = (np.median(draws[name]) for name in param_names)
xx = np.linspace(wave.min(), wave.max(), 100)
yy = Gauss(xx, mamp, mcenter, mwidth, mcont, mslope)
ax.plot(xx, yy, '-', color='b', label='posterior median model')

# stan_variable() concatenates all chains, so a fixed stride over the first
# ~1000 draws would only see the first chain (with the default 1000 draws per
# chain) and would fail outright with fewer draws. Pick up to 100 draws spread
# evenly over the whole posterior instead.
n_draws = len(draws['amp'])
idx = np.linspace(0, n_draws - 1, num=min(100, n_draws)).astype(int)
# Broadcasting: xx is (n_x,), each parameter slice is (k, 1) -> yys is (k, n_x).
yys = Gauss(xx, *(draws[name][idx, np.newaxis] for name in param_names))
sdy = np.std(yys, axis=0)
ax.fill_between(xx, yy - 3*sdy, yy + 3*sdy, facecolor='k', alpha=0.4, zorder=10,
                label='+/-3 sd across posterior draws')
# Keep the legend above the band (the band is drawn at zorder=10).
legend = ax.legend()
legend.set_zorder(20)
fig