In [1]:
from src.hierarchical_model import spike_slab_hierarchical_model, horseshoe_hierarchical_model
from jax import random
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs
import jax.numpy as jnp
import pandas as pd
from sklearn.preprocessing import StandardScaler
import numpy as np

np.random.seed(42)

In [2]:
# Load the dataset directly from the Rdatasets repository
url = "https://vincentarelbundock.github.io/Rdatasets/csv/AER/CollegeDistance.csv"
df = pd.read_csv(url)

# Display the first few rows
print(df.head())

   rownames  gender ethnicity      score fcollege mcollege home urban  unemp  \
0         1    male     other  39.150002      yes       no  yes   yes    6.2   
1         2  female     other  48.869999       no       no  yes   yes    6.2   
2         3    male     other  48.740002       no       no  yes   yes    6.2   
3         4    male      afam  40.400002       no       no  yes   yes    6.2   
4         5  female     other  40.480000       no       no   no   yes    5.6   

   wage  distance  tuition  education income region  
0  8.09       0.2  0.88915         12   high  other  
1  8.09       0.2  0.88915         12    low  other  
2  8.09       0.2  0.88915         12    low  other  
3  8.09       0.2  0.88915         12    low  other  
4  8.09       0.4  0.88915         13    low  other  


In [3]:
# Drop the unnecessary index column
df = df.drop(columns=["rownames"])

# Summary of the dataset
df.info()
df.describe()


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4739 entries, 0 to 4738
Data columns (total 14 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   gender     4739 non-null   object 
 1   ethnicity  4739 non-null   object 
 2   score      4739 non-null   float64
 3   fcollege   4739 non-null   object 
 4   mcollege   4739 non-null   object 
 5   home       4739 non-null   object 
 6   urban      4739 non-null   object 
 7   unemp      4739 non-null   float64
 8   wage       4739 non-null   float64
 9   distance   4739 non-null   float64
 10  tuition    4739 non-null   float64
 11  education  4739 non-null   int64  
 12  income     4739 non-null   object 
 13  region     4739 non-null   object 
dtypes: float64(5), int64(1), object(8)
memory usage: 518.5+ KB


Unnamed: 0,score,unemp,wage,distance,tuition,education
count,4739.0,4739.0,4739.0,4739.0,4739.0,4739.0
mean,50.889029,7.597215,9.500506,1.80287,0.814608,13.807765
std,8.70191,2.763581,1.343067,2.297128,0.339504,1.789107
min,28.950001,1.4,6.59,0.0,0.25751,12.0
25%,43.924999,5.9,8.85,0.4,0.48499,12.0
50%,51.189999,7.1,9.68,1.0,0.82448,13.0
75%,57.769999,8.9,10.15,2.5,1.12702,16.0
max,72.809998,24.9,12.96,20.0,1.40416,18.0


In [4]:
# Ensure the target variable `education` is binary
df['education'] = (df['education'] > 16).astype(int)

df["environment"] = np.where(df["distance"] <= (df['distance'].median()).astype(int), 0, 1)
e = df['environment'].values
E = len(df['environment'].unique())

scaler = StandardScaler()
continuous_cols = ['score', 'tuition', 'unemp', 'wage']
df[continuous_cols] = scaler.fit_transform(df[continuous_cols])

# Convert categorical variables to dummies
df = pd.get_dummies(
    df,
    columns=['gender', 'ethnicity', 'fcollege', 'mcollege', 'home', 'urban', 'income', 'region'],
    drop_first=True
)

In [5]:
n_environments = 2

X_cols = [col for col in df.columns if col not in ['education', 'environment', 'distance']]
X = df[X_cols].values
Y = df['education'].values

# Convert to JAX arrays
X_jax = jnp.array(X, dtype=jnp.float32)
Y_jax = jnp.array(Y, dtype=jnp.int32)  # Must be integers for Bernoulli
e_jax = jnp.array(e, dtype=jnp.int32)

# RUN THIS FOR HORSESHOE

In [6]:
kernel = NUTS(horseshoe_hierarchical_model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1)
mcmc.run(
random.PRNGKey(42), N=X_jax.shape[0], D=X_jax.shape[1],
         E=E, e=e_jax, X=X_jax, y=Y_jax
)
posterior_samples = mcmc.get_samples()
mcmc.print_summary()


sample: 100%|██████████| 3000/3000 [00:19<00:00, 157.10it/s, 63 steps of size 6.86e-02. acc. prob=0.87] 



                      mean       std    median      5.0%     95.0%     n_eff     r_hat
       beta[0,0]      0.05      0.01      0.05      0.04      0.06   2315.22      1.00
       beta[0,1]      0.00      0.00      0.00     -0.00      0.01    945.22      1.00
       beta[0,2]     -0.00      0.00     -0.00     -0.01      0.00   2011.16      1.00
       beta[0,3]      0.00      0.00      0.00     -0.01      0.01   1354.22      1.00
       beta[0,4]      0.00      0.01      0.00     -0.01      0.01   1394.94      1.00
       beta[0,5]      0.02      0.01      0.02      0.00      0.04   1377.88      1.00
       beta[0,6]      0.00      0.01      0.00     -0.01      0.01   1021.18      1.00
       beta[0,7]      0.05      0.01      0.05      0.02      0.07   1615.37      1.00
       beta[0,8]      0.04      0.02      0.04      0.02      0.07   1399.71      1.00
       beta[0,9]      0.04      0.01      0.04      0.03      0.06   1262.31      1.00
      beta[0,10]      0.01      0.01      

In [10]:
X_cols[0], X_cols[5], X_cols[7], X_cols[9]

('score', 'ethnicity_hispanic', 'fcollege_yes', 'home_yes')

# RUN THIS FOR SPIKE AND SLAB

In [13]:
# 1) Define a base kernel for the *continuous* parameters
base_kernel = NUTS(spike_slab_hierarchical_model)

# 2) Wrap it with DiscreteHMCGibbs to handle the discrete site z
kernel = DiscreteHMCGibbs(base_kernel, modified=True)
# "modified=True" uses a 'random-propose' Gibbs update (somewhat better mixing),
# or set "modified=False" for standard approach.

mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1)
mcmc.run(random.PRNGKey(42), N=X_jax.shape[0], D=X_jax.shape[1],
         E=E, e=e_jax, X=X_jax, y=Y_jax)
posterior = mcmc.get_samples()
mcmc.print_summary()


sample: 100%|██████████| 3000/3000 [01:54<00:00, 26.10it/s, 511 steps of size 1.14e-02. acc. prob=0.80] 


                     mean       std    median      5.0%     95.0%     n_eff     r_hat
          mu[0]      0.07      0.21      0.05     -0.19      0.30    298.52      1.00
          mu[1]     -0.01      1.02      0.00     -1.58      1.70   1551.44      1.00
          mu[2]      0.02      1.02      0.05     -1.72      1.59    860.39      1.01
          mu[3]     -0.04      0.98     -0.07     -1.62      1.57   1259.59      1.00
          mu[4]      0.03      1.01      0.03     -1.66      1.69   1381.87      1.00
          mu[5]     -0.04      1.02     -0.08     -1.69      1.55    983.59      1.00
          mu[6]      0.01      0.97      0.01     -1.50      1.65   1300.13      1.00
          mu[7]      0.00      0.34      0.05     -0.61      0.42     93.30      1.05
          mu[8]      0.00      1.00      0.03     -1.62      1.62   1345.10      1.00
          mu[9]     -0.04      0.96     -0.04     -1.80      1.40   1482.46      1.00
         mu[10]     -0.00      0.95     -0.04     -1.




In [7]:
X_cols[0], X_cols[7]


('score', 'fcollege_yes')