In [None]:
from IPython.display import Image
import numpy as np
import pymc3 as pm
import seaborn as sns
import theano as th
import theano.tensor as tt

import matplotlib as mpl
import matplotlib.pyplot as plt
from functools import reduce

from plotutils import addtxt
# Apply the project's two matplotlib style sheets to every figure in this notebook.
# The paths are relative, so the notebook must be run from the directory that
# contains `scripts/`.
mpl.style.use(['./scripts/theme_bw.mplstyle', './scripts/presentation.mplstyle'])

# Reference: PyMC3 "Getting started" guide
# https://docs.pymc.io/notebooks/getting_started.html

In [None]:
# Sentinel that marks a year with no recorded count in the raw list below.
MISSING = -999

# One disaster count per year, in chronological order (111 entries).
observed_counts = [
    4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
    3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
    2, 2, 3, 4, 2, 1, 3, MISSING, 2, 1, 1, 1, 1, 3, 0, 0,
    1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
    0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
    3, 3, 1, MISSING, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
    0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1,
]

# Masked array: entries equal to the sentinel are masked out instead of being
# treated as real counts.
disaster_data = np.ma.masked_values(observed_counts, value=MISSING)

# Calendar years 1851..1961 inclusive, aligned element-for-element with the counts.
years = np.arange(1851, 1962)

# Quick look at the raw series: one unconnected circle marker per year.
plt.plot(years, disaster_data, marker='o', linestyle='None', markersize=8)
plt.xlabel("Year")
plt.ylabel("Disaster count");

\begin{align}
D_t &\equiv \textrm{number of disasters in year } t &\sim \text{Poisson}(r_t) \\
r_t &\equiv \textrm{rate parameter in year } t &= \begin{cases}
e & t \le s \\
l & t > s
\end{cases}\\
s &\equiv \textrm{switchpoint: the last year governed by the early rate} \\
e, l &\equiv \textrm{accident rates before (early) and after (late) the switchpoint}
\end{align}

In [None]:
with pm.Model() as model:
    # Switchpoint year: any year in the observed range, all equally likely a priori.
    switch = pm.DiscreteUniform(
        'switch',
        lower=years.min(),
        upper=years.max(),
    )

    # Flat priors on [0, 10] for the disaster rate before and after the switchpoint.
    early_rate = pm.Uniform('early_rate', lower=0, upper=10)
    late_rate = pm.Uniform('late_rate', lower=0, upper=10)

    # Per-year rate: years up to and including the switchpoint get the early
    # rate, all later years get the late rate.
    rate = pm.math.switch(switch >= years, early_rate, late_rate)

    # Poisson likelihood on the observed counts. `disaster_data` is a masked
    # array; the masked years are presumably what shows up in the trace as
    # 'disasters_missing' (that name is used in a later plotting cell).
    disasters = pm.Poisson('disasters', mu=rate, observed=disaster_data)
    

In [None]:
# Fixed seed so that Restart & Run All reproduces the same posterior draws;
# without it every run of this cell yields a different trace.
RANDOM_SEED = 42

with model:
    # 10,000 kept draws after 15,000 tuning iterations (tuning draws are
    # discarded by default).
    trace = pm.sample(10000, tune=15000, progressbar=True, random_seed=RANDOM_SEED)

In [None]:
# Draw PyMC3's trace plot for the sampled trace as a visual convergence check.
# The trailing semicolon suppresses the text repr of the call's return value.
pm.traceplot(trace);

In [None]:
# Forest plot restricted to the two rates and the 'disasters_missing' variable.
forest_vars = ['early_rate', 'late_rate', 'disasters_missing']
pm.forestplot(trace, varnames=forest_vars)
plt.show()

In [None]:
# Posterior summary figure: observed counts, posterior-mean switchpoint with its
# HPD interval, and the posterior-mean disaster rate for every year.
fig, ax = plt.subplots(figsize=(9, 9))

# Observed yearly counts.
ax.plot(years, disaster_data, '.')
ax.set_ylabel('Number of accidents', fontsize=18)
# Must be set_xlabel: a second set_ylabel would overwrite the label above and
# leave the x axis unlabelled.
ax.set_xlabel('Year', fontsize=18)

switch_samples = trace['switch']
count_min, count_max = disaster_data.min(), disaster_data.max()

# Vertical line at the posterior mean of the switchpoint.
ax.vlines(switch_samples.mean(), count_min, count_max)

# Posterior-mean rate per year. For each (year, sample) pair pick the early rate
# when year <= switch and the late rate otherwise -- the same convention as the
# model's `switch >= years` -- then average over all posterior samples.
# Broadcasting gives a (n_years, n_samples) array, so no Python loop is needed,
# and the result is a plain float array (np.float was removed in NumPy 1.24).
uses_early_rate = years[:, np.newaxis] <= switch_samples
avg_disasters = np.where(
    uses_early_rate, trace['early_rate'], trace['late_rate']
).mean(axis=1)

# Shade the highest-posterior-density interval of the switchpoint.
sp_hpd = pm.hpd(switch_samples)
ax.fill_betweenx(y=[count_min, count_max], x1=sp_hpd[0], x2=sp_hpd[1], alpha=0.5, color='C1')

ax.plot(years, avg_disasters, 'k--', lw=2);