# Implementing the BYM2 Model in Stan

## Notebook Setup

Import the libraries used in this notebook and load the NYC study data.

In [None]:
# import all libraries used in this notebook
import os
import numpy as np
import pandas as pd
import geopandas as gpd
import libpysal as sa
import matplotlib
import splot as splt
from splot.libpysal import plot_spatial_weights 
import plotnine as p9
import arviz as az
%matplotlib inline

from cmdstanpy import CmdStanModel, cmdstan_path, cmdstan_version

# suppress plotnine warnings
import warnings
warnings.filterwarnings('ignore')

# setup plotnine look and feel
p9.theme_set(
  p9.theme_grey() + 
  p9.theme(text=p9.element_text(size=10),
        plot_title=p9.element_text(size=14),
        axis_title_x=p9.element_text(size=12),
        axis_title_y=p9.element_text(size=12),
        axis_text_x=p9.element_text(size=8),
        axis_text_y=p9.element_text(size=8)
       )
)
xlabels_90 = p9.theme(axis_text_x = p9.element_text(angle=90, hjust=1))

map_theme =  p9.theme(figure_size=(7,6),
                 axis_text_x=p9.element_blank(),
                 axis_ticks_x=p9.element_blank(),
                 axis_text_y=p9.element_blank(),
                 axis_ticks_y=p9.element_blank())

In [None]:
nyc_geodata = gpd.read_file(os.path.join('data', 'nyc_study.geojson'))
nyc_geodata.columns

## Disconnected Components (and islands)

New York City is spread across several islands. Only the Bronx is part of the mainland; Brooklyn and Queens are part of Long Island, and there are smaller pieces such as City Island, Roosevelt Island, and the Rockaways.

*This is a problem for the ICAR model, which operates on a connected graph, i.e., one with a single component.*

* For the NYC analysis paper, we hand-edited the map of NYC (in R) to create a fully connected network graph.

* For this notebook, we will restrict our attention to Brooklyn, the largest borough in NYC, which is a single network component.
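
To make "single component" concrete, here is a minimal sketch that counts the connected components of a small neighbor graph with a breadth-first search. The function name `count_components` and the toy edge list are illustrative assumptions; in the notebook itself, `libpysal` reports this number as `n_components`.

```python
from collections import deque

def count_components(n_nodes, edges):
    """Count connected components of an undirected graph.

    n_nodes: number of nodes, labeled 0..n_nodes-1
    edges: iterable of (i, j) pairs, each undirected edge listed once
    """
    neighbors = {i: set() for i in range(n_nodes)}
    for i, j in edges:
        neighbors[i].add(j)
        neighbors[j].add(i)

    seen = set()
    n_components = 0
    for start in range(n_nodes):
        if start in seen:
            continue
        # every unseen start node opens a new component
        n_components += 1
        seen.add(start)
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nb in neighbors[node]:
                if nb not in seen:
                    seen.add(nb)
                    queue.append(nb)
    return n_components

# toy map: regions 0-1-2 share borders, 3 borders 4, region 5 has no neighbors
toy_edges = [(0, 1), (1, 2), (3, 4)]
print(count_components(6, toy_edges))  # 3
```

A map like the toy one, with three components, would need editing or subsetting before it could be used with the ICAR model as described here.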

In [None]:
brooklyn_geodata = nyc_geodata[nyc_geodata['BoroName']=='Brooklyn'].reset_index(drop=True)
brooklyn_nbs = sa.weights.Rook.from_dataframe(brooklyn_geodata, geom_col='geometry')
plot_spatial_weights(brooklyn_nbs, brooklyn_geodata) 

In [None]:
print(f'number of components: {brooklyn_nbs.n_components}')
print(f'islands? {brooklyn_nbs.islands}')
print(f'max number of neighbors per node: {brooklyn_nbs.max_neighbors}')
print(f'mean number of neighbors per node: {brooklyn_nbs.mean_neighbors}')

## From ICAR to BYM2

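* Combines an ICAR component $\phi^*$ and ordinary random effects $\theta^*$ as
$$\sigma \left( \sqrt{\rho / s}\,\phi^* + \sqrt{1-\rho}\,\theta^* \right)$$
where $\sigma$ is the overall standard deviation of the combined effect, $\rho \in [0, 1]$ is the mixing proportion, and $s$ is the scaling factor computed from the neighborhood graph.

* Parameter $\rho$ answers the question: what proportion of the random-effect variance is spatial?

* There is no need to run a separate spatial autocorrelation analysis (e.g., Moran's I) first; the model estimates this from the data.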
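
The combination can be sketched in plain Python to show what $\rho$ does at its two extremes. The function name `bym2_effect` and the toy values of $\phi^*$ and $\theta^*$ are assumptions made for illustration; the Stan program does the equivalent computation on the sampled parameters.

```python
import math

def bym2_effect(phi_star, theta_star, rho, s, sigma):
    """Combine spatial and unstructured effects, one value per region."""
    if not 0.0 <= rho <= 1.0:
        raise ValueError("rho must be in [0, 1]")
    w_spatial = math.sqrt(rho / s)    # weight on the ICAR component
    w_random = math.sqrt(1.0 - rho)   # weight on the ordinary random effects
    return [sigma * (w_spatial * p + w_random * t)
            for p, t in zip(phi_star, theta_star)]

# toy values for three regions (illustrative only)
phi_star = [0.5, -0.2, -0.3]
theta_star = [1.0, 0.0, -1.0]

# rho = 0: only the ordinary random effects remain
all_random = bym2_effect(phi_star, theta_star, rho=0.0, s=0.658, sigma=2.0)
# rho = 1: only the scaled spatial component remains
all_spatial = bym2_effect(phi_star, theta_star, rho=1.0, s=0.658, sigma=2.0)
print(all_random)  # [2.0, 0.0, -2.0]
```

Values of $\rho$ between 0 and 1 blend the two components, which is why the posterior of `rho` can be read as the spatial share.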

## BYM2 Model:  `bym2.stan`

The model is defined in the file `stan/bym2.stan`.

In [None]:
bym2_model_file = os.path.join('stan', 'bym2.stan')

with open(bym2_model_file, 'r') as file:
    contents = file.read()
    print(contents)

## Data Prep

### Get edgeset

- Compute this automatically from the spatial geometry column of `brooklyn_geodata`, using either
  + Python package `libpysal`
  + R package `spdep`

In [None]:
brooklyn_nbs_adj =  brooklyn_nbs.to_adjlist(remove_symmetric=True)
print(type(brooklyn_nbs_adj))
brooklyn_nbs_adj.head(10)

In [None]:
# build a 2 x N_edges np.ndarray from the adjlist columns; add 1 because Stan indexes from 1
j1 = brooklyn_nbs_adj['focal'] + 1
j2 = brooklyn_nbs_adj['neighbor'] + 1
edge_pairs = np.vstack([j1, j2])
edge_pairs
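
To illustrate the two conversions in this step (keeping each undirected edge once, and shifting to 1-based indices), here is a small pure-Python sketch on a toy neighbor structure. The helper `edge_pairs_1_based` is hypothetical; in the notebook, `to_adjlist(remove_symmetric=True)` and the `+ 1` play these roles.

```python
def edge_pairs_1_based(neighbors):
    """Turn a 0-based neighbor dict into two parallel lists of 1-based node ids.

    Each undirected edge is kept once, with the smaller index first.
    """
    edges = sorted({(min(i, j), max(i, j))
                    for i, nbs in neighbors.items() for j in nbs})
    node1 = [i + 1 for i, _ in edges]
    node2 = [j + 1 for _, j in edges]
    return [node1, node2]

# toy neighbor structure: 0 - 1 - 2 in a row, every link stored in both directions
toy_neighbors = {0: [1], 1: [0, 2], 2: [1]}
print(edge_pairs_1_based(toy_neighbors))  # [[1, 2], [2, 3]]
```

The result has the same layout as `edge_pairs`: one row per endpoint and one column per edge.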

### Compute scaling factor `tau`

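The scaling factor is the geometric mean of the marginal variances implied by the ICAR precision matrix. It was computed in R; the value is 0.658.

R script: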
```r
library(Matrix)

# Computes the inverse of a sparse precision matrix Q.
# If A is non-NULL, the result is corrected for a sum-to-zero constraint.
# Sub-optimal implementation; better to use INLA.
q_inv_dense <- function(Q, A = NULL) {
  Sigma <- Matrix::solve(Q)   # needs the sparse matrix solver
  if (is.null(A)) {
    return(Sigma)
  } else {
    # sum-to-zero constraint: a single row of ones (replaces the A passed in)
    A <- matrix(1, 1, nrow(Sigma))
    W <- Sigma %*% t(A)
    Sigma_const <- Sigma - W %*% solve(A %*% W) %*% t(W)
    return(Sigma_const)
  }
}

# adj_list: 2 x N_edges matrix of 1-based node indices, each edge listed once
get_scaling_factor <- function(adj_list) {
  # Number of nodes (every node in a connected graph appears in some edge)
  N <- max(adj_list)
  # Build the symmetric adjacency matrix from the edge list
  adj_matrix <- sparseMatrix(i = adj_list[1, ], j = adj_list[2, ], x = 1,
                             dims = c(N, N), symmetric = TRUE)

  # Create the ICAR precision matrix (diag - adjacency): this is singular
  Q <- Diagonal(N, rowSums(adj_matrix)) - adj_matrix
  # Add a small jitter to the diagonal for numerical stability
  Q_pert <- Q + Diagonal(N) * max(diag(Q)) * sqrt(.Machine$double.eps)

  # Compute the covariance matrix under the sum-to-zero constraint
  Q_inv <- q_inv_dense(Q_pert, adj_matrix)

  # Geometric mean of the variances, which are on the diagonal of Q_inv
  return(exp(mean(log(diag(Q_inv)))))
}
```
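
For readers who prefer to stay in Python, the same steps can be sketched with dense NumPy arrays. This is an illustrative translation under stated assumptions, not the code used to obtain 0.658: the function name `scaling_factor_dense` is hypothetical, the edge array is assumed to be 2 x N_edges with 1-based indices (as in `edge_pairs`), and dense inversion is only practical for small graphs.

```python
import numpy as np

def scaling_factor_dense(edges, n_nodes):
    """Geometric mean of the ICAR marginal variances (dense sketch).

    edges: 2 x n_edges array of 1-based node indices, each edge listed once
    """
    edges = np.asarray(edges)
    i, j = edges[0] - 1, edges[1] - 1          # back to 0-based for NumPy
    adj = np.zeros((n_nodes, n_nodes))
    adj[i, j] = 1.0
    adj[j, i] = 1.0

    # ICAR precision matrix (diag - adjacency): singular
    q = np.diag(adj.sum(axis=1)) - adj
    # small jitter on the diagonal so the matrix can be inverted
    jitter = q.diagonal().max() * np.sqrt(np.finfo(float).eps)
    sigma = np.linalg.inv(q + np.eye(n_nodes) * jitter)

    # correct for the sum-to-zero constraint (a single row of ones)
    a = np.ones((1, n_nodes))
    w = sigma @ a.T
    sigma_const = sigma - w @ np.linalg.inv(a @ w) @ w.T

    # geometric mean of the variances on the diagonal
    return float(np.exp(np.mean(np.log(np.diag(sigma_const)))))

# toy graph: three regions in a row, 1 - 2 - 3
toy_edges = np.array([[1, 2], [2, 3]])
toy_scale = scaling_factor_dense(toy_edges, n_nodes=3)
print(round(toy_scale, 3))  # 0.409
```

For this toy graph the constrained variances are 5/9, 2/9, and 5/9, so the geometric mean is $(50/729)^{1/3} \approx 0.409$, which gives a quick check on the implementation.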

### Assemble the input data

In [None]:
design_vars = np.array(['pct_pubtransit','med_hh_inc', 'traffic', 'frag_index'])

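# predictors as a float matrix, one column per entry in design_vars
design_mat = brooklyn_geodata[design_vars].to_numpy(dtype=float)
# log-transform median household income and traffic
design_mat[:, 1] = np.log(design_mat[:, 1])
design_mat[:, 2] = np.log(design_mat[:, 2])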

pd.DataFrame(data=design_mat).describe()

In [None]:
bym2_data = {"N": brooklyn_geodata.shape[0],
             "y": brooklyn_geodata['count'].astype('int'),
             "E": brooklyn_geodata['kid_pop'].astype('int'),
             "K": design_mat.shape[1],
             "xs": design_mat,
             "N_edges": edge_pairs.shape[1],
             "neighbors": edge_pairs,
             "tau": 0.658  # scaling factor computed in R (see above)
}
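
Before handing the dictionary to Stan, it can help to check that the edge arrays agree with `N` and `N_edges`, since an out-of-range index only surfaces later as a data error. The following is a hedged sketch: `check_edge_data` is a hypothetical helper, and it inspects only the three keys shown.

```python
def check_edge_data(data):
    """Raise ValueError if the edge arrays disagree with N and N_edges."""
    n, n_edges = data["N"], data["N_edges"]
    node1, node2 = data["neighbors"]   # two rows: one per edge endpoint
    if len(node1) != n_edges or len(node2) != n_edges:
        raise ValueError("N_edges does not match the length of the edge arrays")
    for i, j in zip(node1, node2):
        if not (1 <= i <= n and 1 <= j <= n):
            raise ValueError(f"edge ({i}, {j}) has an index outside 1..{n}")
        if i == j:
            raise ValueError(f"node {i} is listed as its own neighbor")
    return True

# toy input with the same layout as the edge entries in bym2_data
toy_data = {"N": 3, "N_edges": 2, "neighbors": [[1, 2], [2, 3]]}
print(check_edge_data(toy_data))  # True
```

The same call should also work on the `bym2_data` dictionary, because a 2 x N_edges NumPy array unpacks into its two rows.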

## Fitting the BYM2 Model on the Brooklyn data

#### Model is compiled (as needed) on instantiation

In [None]:
bym2_mod = CmdStanModel(stan_file=bym2_model_file)

#### Run the NUTS-HMC sampler, summarize results

In [None]:
bym2_fit = bym2_mod.sample(data=bym2_data)

bym2_summary = bym2_fit.summary()
bym2_summary.round(2).loc[
  ['beta_intercept', 'beta0', 'betas[1]', 'betas[2]', 'betas[3]', 'betas[4]', 'sigma', 'rho']]

## Model Comparison: BYM2 vs. ICAR vs. ordinary random effects

#### ICAR model

In [None]:
pois_icar_mod = CmdStanModel(stan_file=os.path.join(
  'stan', 'poisson_icar.stan'))
pois_icar_fit = pois_icar_mod.sample(data=bym2_data)
pois_icar_summary = pois_icar_fit.summary()
pois_icar_summary.round(2).loc[
  ['beta_intercept', 'beta0', 'betas[1]', 'betas[2]', 'betas[3]', 'betas[4]', 'sigma']]

#### Ordinary random effects model

In [None]:
pois_re_mod = CmdStanModel(stan_file=os.path.join(
  'stan', 'poisson_re.stan'))
pois_re_fit = pois_re_mod.sample(data=bym2_data)
pois_re_summary = pois_re_fit.summary()
pois_re_summary.round(2).loc[
  ['beta_intercept', 'beta0', 'betas[1]', 'betas[2]', 'betas[3]', 'betas[4]', 'sigma']]

Which model provides a better fit (on the Brooklyn subset of the data)?

In [None]:
print('BYM2')
print(bym2_summary.round(2).loc[['sigma', 'rho']])

print('\nPoisson ICAR')
print(pois_icar_summary.round(2).loc[['sigma']])

print('\nPoisson Ordinary random effects')
print(pois_re_summary.round(2).loc[['sigma']])

### Visual comparison

#### BYM2 model

In [None]:
idata_bym2 = az.from_cmdstanpy(
    bym2_fit,
    posterior_predictive="y_rep",
    dims={"betas": ["covariates"]},
    coords={"covariates": design_vars},
    observed_data={"y": bym2_data['y']}
)
idata_bym2

In [None]:
az_bym2_ppc_plot = az.plot_ppc(idata_bym2, data_pairs={"y":"y_rep"})
az_bym2_ppc_plot.set_title("BYM2 model posterior predictive check")
az_bym2_ppc_plot

#### ICAR model

In [None]:
idata_pois_icar = az.from_cmdstanpy(
    pois_icar_fit,
    posterior_predictive="y_rep",
    dims={"betas": ["covariates"]},
    coords={"covariates": design_vars},
    observed_data={"y": bym2_data['y']}
)
idata_pois_icar

In [None]:
az_pois_icar_ppc_plot = az.plot_ppc(idata_pois_icar, data_pairs={"y":"y_rep"})
az_pois_icar_ppc_plot.set_title("Poisson ICAR model posterior predictive check")
az_pois_icar_ppc_plot

#### RE model

In [None]:
idata_pois_re = az.from_cmdstanpy(
    pois_re_fit,
    posterior_predictive="y_rep",
    dims={"betas": ["covariates"]},
    coords={"covariates": design_vars},
    observed_data={"y": bym2_data['y']}
)
az_pois_re_ppc_plot = az.plot_ppc(idata_pois_re, data_pairs={"y":"y_rep"})
az_pois_re_ppc_plot.set_title("Poisson RE model posterior predictive check")
az_pois_re_ppc_plot

### Leave-one-out cross-validation (LOO)

In [None]:
az.compare({"bym2":idata_bym2, "poisson_icar":idata_pois_icar, "poisson_re":idata_pois_re})
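
`az.compare` ranks models by their estimated expected log pointwise predictive density, which ArviZ computes from the pointwise log-likelihood stored in each `InferenceData` object. As a rough illustration of what one such pointwise term looks like, the sketch below evaluates a Poisson log-likelihood with an exposure offset. It assumes, without confirmation from the Stan files, that the models use a Poisson likelihood with `E` as a multiplicative exposure; the function and argument names are illustrative.

```python
import math

def poisson_log_lik(y, expected, log_rate):
    """Pointwise Poisson log-likelihood with an exposure offset.

    y: observed count; expected: exposure E; log_rate: log relative rate
    """
    mu = expected * math.exp(log_rate)   # Poisson mean for this region
    return y * math.log(mu) - mu - math.lgamma(y + 1)

# a region with exposure 10, relative rate 1, and no observed events
print(round(poisson_log_lik(y=0, expected=10, log_rate=0.0), 6))  # -10.0
```

One such value per region and per posterior draw is what the comparison consumes.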