In [None]:
import sys
from pathlib import Path

sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import bayesflow as beef

from models import RandomWalkPoissonModel
from experiments import NeuralCoalMiningExperiment, BayesLoopCoalMiningExperiment
from data import coal_mining_data

In [None]:
# GPU setting and checking: enable memory growth on every visible GPU.
# NOTE: TensorFlow only accepts this before the GPUs are initialized, so this
# cell must run before any model/network is created.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    # The original `physical_devices[0]` raised IndexError on CPU-only machines.
    print("No GPU found; running on CPU.")
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, enable=True)
print(tf.config.list_physical_devices('GPU'))

# Neural Experiment

In [None]:
# Seed the global NumPy and TensorFlow RNGs before building the simulator and
# the neural experiment, so a Restart & Run All is repeatable as far as those
# global generators are used.
# NOTE(review): components that create their own unseeded generators
# (e.g. np.random.default_rng()) are not covered — confirm in models/experiments.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

model = RandomWalkPoissonModel()

In [None]:
# Wrap the simulation model in the neural experiment; its `trainer` and
# `amortizer` attributes are used later for the loss plot and posterior sampling.
neural_experiment = NeuralCoalMiningExperiment(model)

In [None]:
# Train the neural experiment: 25 epochs x 1000 iterations, batch size 32.
n_epochs = 25
n_iterations_per_epoch = 1000
n_batch = 32
history = neural_experiment.run(
    epochs=n_epochs,
    iterations_per_epoch=n_iterations_per_epoch,
    batch_size=n_batch,
)

In [None]:
# Plot the loss history recorded by the trainer during the run above.
plottable_losses = neural_experiment.trainer.loss_history.get_plottable()
h = beef.diagnostics.plot_losses(plottable_losses)

# BayesLoop Experiment

In [None]:
# Set up the BayesLoop baseline that the neural posterior is compared against.
bayesloop_experiment = BayesLoopCoalMiningExperiment()

In [None]:
# Fit the BayesLoop baseline on the observed data. The result is unpacked into
# posterior means and standard deviations, which are plotted against
# coal_mining_data["year"] in the evaluation section.
bl_fit = bayesloop_experiment.run(coal_mining_data)
bl_post_means, bl_post_stds = bl_fit

# Evaluation: Neural vs. BayesLoop Posteriors

In [None]:
# Draw 1000 posterior samples for the observed disaster series. The counts get a
# leading batch axis and a trailing feature axis (shape (1, T, 1) for a 1-D
# series) and are log1p-transformed before being passed to the amortizer.
# NOTE(review): the log1p presumably mirrors the transform applied to simulated
# data during training — confirm against NeuralCoalMiningExperiment.
n_posterior_draws = 1000
observed_input = np.log1p(coal_mining_data["disasters"][None, :, None])
posterior_samples = neural_experiment.amortizer.sample(observed_input, n_posterior_draws)

In [None]:
# Map the local samples back with expm1 (presumably inverting a log1p-type
# parameterization — TODO confirm), then summarise over the sample axis (axis 0).
local_samples = np.expm1(posterior_samples["local_samples"])
post_mean, post_std = local_samples.mean(axis=0), local_samples.std(axis=0)

In [None]:
# Plot palette: empirical data, neural posterior, BayesLoop comparison.
# NOTE(review): EMPIRIC_COLOR is not used by the comparison plot, which draws
# the accident-count bars in "gray".
EMPIRIC_COLOR, NEURAL_COLOR, COMPARISON_COLOR = '#1F1F1F', '#852626', '#133a76'

In [None]:
def _plot_posterior_band(ax, x, mean, std, label, color):
    """Draw a posterior-mean line and a shaded mean +/- 1 std band on ``ax``.

    ``mean`` and ``std`` are converted to arrays so that list inputs are added
    element-wise instead of being concatenated.
    """
    mean = np.asarray(mean)
    std = np.asarray(std)
    ax.plot(x, mean, alpha=0.9, color=color)
    ax.fill_between(
        x,
        mean + std,
        mean - std,
        alpha=0.6,
        label=label,
        edgecolor="none",
        color=color,
    )


# Overlay the neural and BayesLoop posteriors (mean +/- 1 std) on the observed
# yearly accident counts, then save the figure for the paper.
time = coal_mining_data["year"]
fig, ax = plt.subplots(figsize=(14, 8))

_plot_posterior_band(ax, time, post_mean, post_std, "Neural", NEURAL_COLOR)
_plot_posterior_band(ax, time, bl_post_means, bl_post_stds, "BayesLoop", COMPARISON_COLOR)

ax.bar(
    time,
    coal_mining_data["disasters"],
    align="center",
    facecolor="gray",
    alpha=0.6,
    label="Accident counts",
)

ax.set_ylabel("Accident rate", fontsize=28)
ax.set_xlabel("Year", fontsize=28)
ax.tick_params(axis="both", which="major", length=10, labelsize=24)

ax.legend(fontsize=24)
sns.despine(ax=ax)
fig.tight_layout()

# savefig raises FileNotFoundError if the target directory is missing
# (e.g. on a fresh checkout), so create it first.
output_path = Path("../plots/coal_mining_benchmark.png")
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=300)
plt.show()