# MCMC Example notebook

Example code that runs the EBS MCMC functionality on the sample inputs shipped with the package.

In [None]:
import warnings
import time
import os
import yaml

import corner
import ebs

import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from ebs.error_budget import ErrorBudget


# Silence every warning category for the remainder of the notebook session.
warnings.simplefilter('ignore')

### Load the config and define the ErrorBudget object

In [None]:
# Create the ErrorBudget object from the sample config shipped with the package,
# after pointing its input/output/temporary paths at folders in the package root.
package_path = Path(ebs.__file__)
root_dir = package_path.parent.parent

config_file = os.path.join(root_dir, "inputs", "parameters.yml")

with open(config_file, 'r', encoding="utf-8") as file:
    data = yaml.safe_load(file)

# Update paths to point to sample inputs, output, and temporary folder
data["paths"]["output"] = os.path.join(root_dir, "output")
data["paths"]["input"] = os.path.join(root_dir, "inputs")
data["paths"]["temporary"] = os.path.join(root_dir, "temp")

# Overwrite the original file: ErrorBudget is handed the config *path*, so the
# updated paths presumably have to be on disk for it to see them.
# NOTE: this rewrites the packaged parameters.yml in place and any YAML comments
# in it are lost; sort_keys=False at least preserves the original key order.
with open(config_file, 'w', encoding="utf-8") as file:
    yaml.dump(data, file, sort_keys=False)

# MCMC settings, read here only for the summary message below.
n_walkers = data["mcmc"]["nwalkers"]
n_steps = data["mcmc"]["nsteps"]
n_cpu = data["mcmc"]["ncpu"]

# Build the ErrorBudget once, after the config on disk has been updated.
eb = ErrorBudget(config_file)
print(f"Running MCMC run with {n_walkers} walkers, {n_steps} steps, and using {n_cpu} cores")

### Run the MCMC code

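Note: this step can take a long time and is computationally expensive. Its cost depends on the `nwalkers`, `nsteps`, and `ncpu` values in the `mcmc` section of the config.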

In [None]:
# Time the MCMC run. perf_counter is monotonic and high-resolution, so the
# elapsed time is not skewed by system clock adjustments during a long run.
start = time.perf_counter()
mcmc_res = eb.run_mcmc()
stop = time.perf_counter()

In [None]:
print(f"MCMC run took {stop-start} seconds")

# Number of initial steps to discard as burn-in; shared by both calls below so
# the flattened and unflattened chains always cover the same steps.
# NOTE(review): if this is not smaller than the number of steps actually run,
# get_chain presumably returns an empty array and the plots below will fail.
burn_in = 200

# Flattened samples (all walkers merged) for the corner plot, and the
# per-walker chain for the trace plots.
chain = mcmc_res.get_chain(flat=True, discard=burn_in)
samples = mcmc_res.get_chain(discard=burn_in)

In [None]:
# samples has shape (nsteps, nwalkers, nparameters); the last axis is the
# number of fitted parameters.
ndim = samples.shape[-1]

ax_labels = ["Dark Current", "WFS&C Factor", "Contrast", "Throughput"]
# The labels are hard-coded, so fail loudly if the model's parameter count
# changes instead of mislabeling (or running out of) corner-plot axes.
if len(ax_labels) != ndim:
    raise ValueError(f"Expected {ndim} axis labels, got {len(ax_labels)}")

# Quantiles 0.16/0.5/0.84 mark the median and the central 68% interval.
fig = corner.corner(
    chain,
    labels=ax_labels,
    quantiles=[0.16, 0.5, 0.84],
    show_titles=True,
    title_fmt=".1E",
)

# Shrink the per-panel titles so they fit above the subplots.
for ax in fig.axes:
    ax.title.set_size(8)

In [None]:
# One trace plot per parameter: each walker's post-burn-in chain is drawn as a
# separate line, which is useful for visually checking walker mixing.
# samples has shape (nsteps, nwalkers, nparameters), so take the parameter
# count from the data instead of hard-coding it.
ndim = samples.shape[-1]

for i in range(ndim):
    fig, ax = plt.subplots()
    ax.plot(samples[:, :, i], alpha=0.3)
    ax.set_xlim(0, len(samples))
    ax.set_xlabel("Step (after burn-in)")
    ax.set_ylabel(f"Parameter {i}")
    plt.show()