# Spike and Slab examples

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
import os,sys,time

sys.path.insert(1, '../../src/')
import madmix
import madmix_aux
import gibbs
from concrete import *
import dequantization
import meanfield
import aux

# Global matplotlib defaults for every figure in this notebook:
# silence the "too many open figures" warning, use wide figures and a large font.
plt.rcParams.update({
    'figure.max_open_warning': 0,
    'figure.figsize': (15, 7.5),
    'font.size': 40,
})
# Directory prefix for figures (presumably consumed by later plotting cells).
# Points at the manuscript repository; switch to 'fig/' to keep figures local.
fig_path = '../../../madmix-tex/fig/'

## Prostate cancer data

In [2]:
# Download the tab-separated prostate data file hosted on the ElemStatLearn site
# and show it as the cell output.
prst_dat_raw = pd.read_csv(
    'https://hastie.su.domains/ElemStatLearn/datasets/prostate.data',
    sep='\t',
)
prst_dat_raw

Unnamed: 0.1,Unnamed: 0,lcavol,lweight,age,lbph,svi,lcp,gleason,pgg45,lpsa,train
0,1,-0.579818,2.769459,50,-1.386294,0,-1.386294,6,0,-0.430783,T
1,2,-0.994252,3.319626,58,-1.386294,0,-1.386294,6,0,-0.162519,T
2,3,-0.510826,2.691243,74,-1.386294,0,-1.386294,7,20,-0.162519,T
3,4,-1.203973,3.282789,58,-1.386294,0,-1.386294,6,0,-0.162519,T
4,5,0.751416,3.432373,62,-1.386294,0,-1.386294,6,0,0.371564,T
...,...,...,...,...,...,...,...,...,...,...,...
92,93,2.830268,3.876396,68,-1.386294,1,1.321756,7,60,4.385147,T
93,94,3.821004,3.896909,44,-1.386294,1,2.169054,7,40,4.684443,T
94,95,2.907447,3.396185,52,-1.386294,1,2.463853,7,10,5.143124,F
95,96,2.882564,3.773910,68,1.558145,1,1.558145,7,80,5.477509,T


In [3]:
# Predictors: every column except the row id ('Unnamed: 0'), the response ('lpsa')
# and the train/test flag ('train'). Index.difference returns the remaining labels
# sorted, so the columns of prst_dat_x are in alphabetical order.
prst_dat_x = np.array(prst_dat_raw[prst_dat_raw.columns.difference(['Unnamed: 0','lpsa','train'])])
# Standardize each predictor (column) to zero mean and unit standard deviation
# across observations (axis=0). Standardizing along axis=1 would instead rescale
# every observation across its own predictors, mixing variables on different
# scales (e.g. age vs. the 0/1 svi indicator).
# NOTE(review): any results cached under results/ that were produced with a
# different scaling must be regenerated (RUN=True) to stay consistent with this data.
prst_dat_x = (prst_dat_x-np.mean(prst_dat_x,axis=0))/np.std(prst_dat_x,axis=0)
# Response: lpsa, standardized to zero mean and unit standard deviation.
prst_dat_y = np.array(prst_dat_raw['lpsa'])
prst_dat_y = (prst_dat_y-np.mean(prst_dat_y))/np.std(prst_dat_y)

### Gibbs sampling

In [4]:
# Gibbs sampler settings, forwarded to gibbs.gibbs_sas in the next cell.
# steps      -> `steps` argument
# burnin_pct -> `burnin_pct` argument (presumably the fraction of the chain
#               spent on burn-in; confirm against gibbs.gibbs_sas)
# seed       -> `seed` argument
steps, burnin_pct, seed = 5000, 0.8, 2023

In [5]:
# Run the spike-and-slab Gibbs sampler on the standardized prostate data.
# The five returned objects are unpacked in the order gibbs_sas returns them.
(pis_prst,
 gbetas_prst,
 gthetas_prst,
 gsigmas2_prst,
 gtaus2_prst) = gibbs.gibbs_sas(
    y=prst_dat_y,
    x=prst_dat_x,
    steps=steps,
    burnin_pct=burnin_pct,
    seed=seed,
)

Burn-in: 627/20000

  betas[t+1,:]=np.random.multivariate_normal(mean=ridge_hat@(x.T@y)/sigmas2[t+1],cov=ridge_hat,size=1)


Sampling: 5000/50000

### MAD Mix

In [10]:
# MAD Mix settings.
# size, N, L, epsilon and xi are forwarded to madmix.randqN / madmix.lqN below;
# nu2 is forwarded to the reference-distribution generators (sas_gen_lq0 / sas_gen_randq0).
size, N, L = 100, 500, 15
epsilon = 1e-3
xi = np.pi / 16
nu2 = 10

# Callables built from the standardized data by the madmix_aux generators.
# Target: lp and grad_lp (passed to madmix.randqN / madmix.lqN as lp, grad_lp).
lp      = madmix_aux.sas_gen_lp(prst_dat_x, prst_dat_y)
grad_lp = madmix_aux.sas_gen_grad_lp(prst_dat_x, prst_dat_y)
# Reference q0: lq0 is passed to madmix.lqN and randq0 to madmix.randqN; both take nu2.
lq0     = madmix_aux.sas_gen_lq0(prst_dat_x, prst_dat_y, nu2)
randq0  = madmix_aux.sas_gen_randq0(prst_dat_x, prst_dat_y, nu2)

In [11]:
RUN = False  # True: sample and evaluate from scratch; False: import saved results
madmix_elbos = -np.inf*np.ones(2)

if not RUN:
    # Reuse the samples and log densities saved by an earlier run.
    mad_prst_results = aux.pkl_load('results/mad_prst_results')
    mad_lq_prst = aux.pkl_load('results/mad_prst_lq')
    xd_, ud_, xc_, rho_, uc_ = madmix_aux.sas_unpack(mad_prst_results, K=prst_dat_x.shape[1])
else:
    # NOTE: madmix_cput is only defined on this branch (i.e. when RUN is True).
    madmix_cput = np.inf*np.ones(2)

    # Draw samples and record the wall-clock time of the sampling call.
    print('Sampling')
    t0 = time.perf_counter()
    xd_, ud_, xc_, rho_, uc_ = madmix.randqN(size, N, randq0, L, epsilon, lp, grad_lp, xi)
    madmix_cput[0] = time.perf_counter() - t0

    # Pack the five sample arrays into one object and save it.
    mad_prst_results = madmix_aux.sas_pack(xd_, ud_, xc_, rho_, uc_)
    print('Done!')
    print('Saving sampling results')
    aux.pkl_save(mad_prst_results, 'results/mad_prst_results')

    # Evaluate the log density at the samples; xd_ is cast to int first.
    print('Evaluating log density')
    mad_lq_prst = madmix.lqN(xd_.astype(int), ud_, xc_, rho_, uc_, N, lq0, L, epsilon, lp, grad_lp, xi)
    print('Done!')
    print('Saving log density results')
    aux.pkl_save(mad_lq_prst, 'results/mad_prst_lq')
# end if

Sampling
Done!ing 500/500
Saving sampling results
Evaluating log density
Done!
Saving log density results
