In [None]:
import sys

sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from models import RandomWalkPoissonModel
from experiments import NeuralCoalMiningExperiment
from data import coal_mining_data

# Model and Experiment

In [None]:
# Instantiate the generative model with its default (no-argument) configuration;
# it is handed to the experiment in the next cell.
model = RandomWalkPoissonModel()

In [None]:
# Wrap the model in the neural experiment. Below, `experiment.run` drives training
# and `experiment.amortizer` is used to draw posterior samples for the observed data.
experiment = NeuralCoalMiningExperiment(model)

In [None]:
# Train the network: 10 epochs x 1000 iterations (10,000 steps in total), 32 items per batch.
# NOTE(review): `history` is presumably the training log (e.g. losses); it is not used
# in the cells below.
history = experiment.run(
    epochs=10,
    iterations_per_epoch=1000,
    batch_size=32,
)

# Evaluation

In [None]:
# Condition the trained amortizer on the observed accident series and draw 1000
# posterior samples.
# - np.asarray lets the axis insertion below work whether "disasters" is a NumPy array
#   or a pandas Series (pandas >= 2.0 rejects multi-dimensional indexing such as
#   s[None, :, None] on a Series). It is a no-op for NumPy arrays.
# - [None, :, None] turns the 1-D series of length T into shape (1, T, 1), presumably
#   (batch, time, features) as the network expects. TODO confirm.
# - log1p presumably mirrors the transform applied to simulated data during training.
#   TODO confirm against NeuralCoalMiningExperiment.
posterior_samples = experiment.amortizer.sample(
    np.log1p(np.asarray(coal_mining_data["disasters"])[None, :, None]), 1000
)

In [None]:
# Undo the log1p scale with its inverse, expm1.
# NOTE(review): assumes the network's local parameters live on a log1p scale. Confirm
# against the model/experiment; if they are log-rates, np.exp would be the inverse.
local_samples = np.expm1(posterior_samples["local_samples"])

# Summarise the posterior across draws. This assumes axis 0 indexes the samples, so the
# results align element-wise with the yearly series plotted below.
post_mean, post_std = local_samples.mean(axis=0), local_samples.std(axis=0)

In [None]:
# Overlay the estimated accident rate (posterior mean with a +/- 1 std band)
# on the observed yearly accident counts.
time = coal_mining_data["year"]
fig, ax = plt.subplots(figsize=(14, 8))

# Posterior mean as a line, +/- one standard deviation as a shaded band.
ax.plot(time, post_mean, alpha=0.9, color="maroon")
ax.fill_between(
    time,
    post_mean + post_std,
    post_mean - post_std,
    alpha=0.3,
    label="Neural",
    edgecolor="none",
    color="maroon",
)

# Observed counts as semi-transparent gray bars centred on each year.
ax.bar(
    time,
    coal_mining_data["disasters"],
    align="center",
    facecolor="gray",
    alpha=0.6,
    label="Accident counts",
)

# Large fonts and ticks for publication-style readability.
ax.set_ylabel("Accident rate", fontsize=28)
ax.set_xlabel("Year", fontsize=28)
ax.tick_params(axis="both", which="major", length=10, labelsize=24)
ax.legend(fontsize=24)

# Remove the top/right spines and tighten margins; tight_layout returns None,
# so the cell emits no repr noise.
sns.despine(ax=ax)
fig.tight_layout()