In [10]:
import pandas as pd

In [27]:
import jax
from jax import vmap
import jax.numpy as np
import jax.scipy as sp
import jax.random as random

import numpyro
import numpyro.distributions as nd
from numpyro.infer import MCMC, NUTS

In [4]:
import optax
import haiku as hk

In [8]:
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az

import palettes

palettes.set_theme()

In [12]:
D = pd.read_csv("../data/elections.csv")
D.loc[:, "proportion"] = D.dem / (D.dem + D.rep)
D

Unnamed: 0,year,dem,rep,region,state,proportion
0,1976,44058,71555,Mountain West,AK,0.381082
1,1980,41842,86112,Mountain West,AK,0.327008
2,1984,62007,138377,Mountain West,AK,0.309441
3,1988,72584,119251,Mountain West,AK,0.378367
4,1992,78294,102000,Mountain West,AK,0.434257
...,...,...,...,...,...,...
545,2000,60481,147947,Mountain West,WY,0.290177
546,2004,70776,167629,Mountain West,WY,0.296873
547,2008,82868,164958,Mountain West,WY,0.334380
548,2012,69286,170962,Mountain West,WY,0.288394


In [13]:
years = pd.to_datetime(D.year)
years = (years - years.min()) / pd.Timedelta(1)

D.loc[:, ("year_numerical")] = years
D.loc[:, ("region_idxs")] = D["region"].apply(
    lambda x: list(D.region.unique()).index(x)
)
D.loc[:, ("state_idxs")] = D["state"].apply(
    lambda x: list(D.state.unique()).index(x)
)
D.loc[:, ("time_idxs")] = D["year_numerical"].apply(
    lambda x: list(D.year_numerical.unique()).index(x)
)
D = D.sort_values(["region", "state", "year_numerical"])
D

Unnamed: 0,year,dem,rep,region,state,proportion,year_numerical,region_idxs,state_idxs,time_idxs
154,1976,1014714,1183958,Border South,IN,0.461512,0.0,9,14,0
155,1980,844197,1255656,Border South,IN,0.402027,4.0,9,14,1
156,1984,841481,1377230,Border South,IN,0.379266,8.0,9,14,2
157,1988,860643,1297763,Border South,IN,0.398740,12.0,9,14,3
158,1992,848420,989375,Border South,IN,0.461651,16.0,9,14,4
...,...,...,...,...,...,...,...,...,...,...
512,2000,1247652,1108864,West Coast,WA,0.529448,24.0,3,46,6
513,2004,1510201,1304894,West Coast,WA,0.536465,28.0,3,46,7
514,2008,1750848,1229216,West Coast,WA,0.587520,32.0,3,46,8
515,2012,1755396,1290670,West Coast,WA,0.576283,36.0,3,46,9


In [80]:
D[["state", "year_numerical", "proportion"]]

Unnamed: 0,state,year_numerical,proportion
154,IN,0.0,0.461512
155,IN,4.0,0.402027
156,IN,8.0,0.379266
157,IN,12.0,0.398740
158,IN,16.0,0.461651
...,...,...,...
512,WA,24.0,0.529448
513,WA,28.0,0.536465
514,WA,32.0,0.587520
515,WA,36.0,0.576283


In [158]:
Y = D[["state", "year_numerical", "proportion"]].pivot_table(
    index="state", values="proportion", columns="year_numerical"
)

In [159]:
Y.head()

year_numerical,0.0,4.0,8.0,12.0,16.0,20.0,24.0,28.0,32.0,36.0,40.0
state,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
AK,0.381082,0.327008,0.309441,0.378367,0.434257,0.395715,0.320631,0.367737,0.389352,0.426847,0.416143
AL,0.566667,0.493237,0.387366,0.402544,0.461789,0.462661,0.424145,0.371022,0.391091,0.387838,0.356259
AR,0.650228,0.496803,0.387724,0.428084,0.599923,0.593528,0.471993,0.450643,0.398283,0.378456,0.357149
AZ,0.413867,0.317879,0.328833,0.392565,0.486981,0.512238,0.467174,0.44725,0.456861,0.453866,0.4811
CA,0.490822,0.405291,0.417755,0.481927,0.585167,0.572163,0.56203,0.550413,0.622784,0.618728,0.661282


In [160]:
X = np.tile(np.array(Y.columns), (Y.shape[0], 1))
Y = Y.values

In [326]:
class LSTM(hk.Module):
    def __init__(self, name='lstm'):
        super().__init__(name=name)
        self._w = hk.Linear(4, True, name="w")
        self._u = hk.Linear(4, False, name="u")

    def __call__(self, x):
        outs = [None] * x.shape[-1]
        h, c = np.zeros((x.shape[0], 1)), np.zeros((x.shape[0], 1))
        for i in range(x.shape[-1]):
            o, h, c = self._call(x[:, i, None], h, c)
            outs[i] = o
        return np.hstack(outs)

    def _call(self, x_t, h_t, c_t):                        
        iw, gw, fw, ow = np.split(self._w(x_t), indices_or_sections=4, axis=-1)        
        iu, gu, fu, ou = np.split(self._u(h_t), indices_or_sections=4, axis=-1)        
        i = jax.nn.sigmoid(iw + iu)
        f = jax.nn.sigmoid(fw + fu)
        g = np.tanh(gw + gu)
        o = jax.nn.sigmoid(ow + ou)
        c = f * c_t + i * g
        h =  o * np.tanh(c)
        return o, h, c


def _lstm(x):
    module = LSTM()
    return module(x)

In [327]:
model = hk.transform(_lstm)
model = hk.without_apply_rng(model)

key = jax.random.PRNGKey(42)
params = model.init(key, X[[0], :])

In [328]:
model.apply(x=X[[0, 1], :], params=params)

DeviceArray([[0.5       , 0.9952112 , 0.99997604, 0.9999999 , 1.        ,
              1.        , 1.        , 1.        , 1.        , 1.        ,
              1.        ],
             [0.5       , 0.9952112 , 0.99997604, 0.9999999 , 1.        ,
              1.        , 1.        , 1.        , 1.        , 1.        ,
              1.        ]], dtype=float32)

In [330]:
class BetaLSTM(hk.Module):
    def __init__(self, name='beta_lstm'):
        super().__init__(name=name)
        self._h = LSTM()
        self._mu = hk.get_parameter('mu', [], init=np.ones)
        self._kappa = hk.get_parameter('kappa', [], init=np.zeros)

    def __call__(self, x):
        o = self._h(x)        
        mu = sp.special.expit(self._mu * o)
        kappa = np.exp(self._kappa)
        be = dist.Beta(mu * kappa, (1.0 - mu) * kappa)
        return be
        

def _beta_lstm(x):
    module = BetaLSTM()
    return module(x)

In [332]:
x = X[[0], :]
y = Y[[0], :]

In [336]:
@jax.jit
def nll(params: hk.Params):
    beta = model.apply(x=x, params=params)
    ll = np.sum(beta.log_prob(y))
    return -ll

In [337]:
def update(params, opt_state):
    grads = jax.grad(nll)(params)  
    updates, new_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state

In [350]:
optimizer = optax.adam(0.001)
opt_state = optimizer.init(params)

In [351]:
model = hk.transform(_beta_lstm)
model = hk.without_apply_rng(model)

key = jax.random.PRNGKey(42)
params = model.init(key, x)

In [352]:
params

FlatMapping({
  'beta_lstm': FlatMapping({
                 'mu': DeviceArray(1., dtype=float32),
                 'kappa': DeviceArray(0., dtype=float32),
               }),
  'beta_lstm/~/lstm/~/w': FlatMapping({
                            'w': DeviceArray([[-0.5389954,  0.8341133, -0.8763848,  1.3341686]], dtype=float32),
                            'b': DeviceArray([0., 0., 0., 0.], dtype=float32),
                          }),
  'beta_lstm/~/lstm/~/u': FlatMapping({
                            'w': DeviceArray([[ 0.6433483 , -0.11852746,  0.88966376, -0.33986157]], dtype=float32),
                          }),
})

In [359]:
for step in range(10000):
    params, opt_state = update(params, opt_state)

In [354]:
params

FlatMapping({
  'beta_lstm': FlatMapping({
                 'kappa': DeviceArray(1.0232631, dtype=float32),
                 'mu': DeviceArray(0.3808028, dtype=float32),
               }),
  'beta_lstm/~/lstm/~/u': FlatMapping({
                            'w': DeviceArray([[ 1.9028696 ,  0.03698345,  2.092364  , -1.9447113 ]], dtype=float32),
                          }),
  'beta_lstm/~/lstm/~/w': FlatMapping({
                            'b': DeviceArray([ 1.5069709 ,  1.2371356 ,  1.3585547 , -0.98723483], dtype=float32),
                            'w': DeviceArray([[ 0.68274695,  1.5741028 ,  0.54551226, -0.18318608]], dtype=float32),
                          }),
})

In [355]:
beta = model.apply(x=x, params=params)

In [356]:
y

array([[0.38108171, 0.32700814, 0.30944087, 0.37836683, 0.43425738,
        0.39571497, 0.32063051, 0.36773717, 0.38935215, 0.4268471 ,
        0.41614345]])

In [357]:
ss = beta.sample(random.PRNGKey(1), (1000,))

In [358]:
np.mean(ss, axis=0)

DeviceArray([[0.51533586, 0.51190114, 0.49628186, 0.4987158 , 0.49403563,
              0.49306673, 0.5045116 , 0.5150377 , 0.49857137, 0.47996184,
              0.49744257]], dtype=float32)