In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from scipy.stats import multivariate_normal as mvnorm
from scipy.stats import uniform, invgamma, bernoulli, poisson, norm
import statsmodels.api as sm
import pandas as pd
from src.preprocessing import prepare_data, prepare_data_no_standardizing, MAR_data_deletion
from src.model_code import Gibbs_MH
import pymc3 as pm
import arviz as az
from statsmodels.tsa.stattools import acf 
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')



In [3]:
df = pd.read_csv('student-mat.csv', sep=",")
features = ["age", "sex", "failures", "higher", "Medu", 
            "absences", "G2", "G3"]
df = df[features]
df.head()

Unnamed: 0,age,sex,failures,higher,Medu,absences,G2,G3
0,18,F,0,yes,4,6,6,6
1,17,F,0,yes,1,4,5,6
2,15,F,3,yes,1,10,8,10
3,15,F,0,yes,4,2,14,15
4,16,F,0,yes,3,4,10,10


In [4]:
X_df, y_df = prepare_data_no_standardizing(df)
X_df.head()

Unnamed: 0,age,failures,Medu,absences,G2,sex_M,higher_yes,intercept
0,18,0,4,6,6,0,1,1
1,17,0,1,4,5,0,1,1
2,15,3,1,10,8,0,1,1
3,15,0,4,2,14,0,1,1
4,16,0,3,4,10,0,1,1


In [9]:
X_df_missing = MAR_data_deletion(X_df, 0.05, 0.05, 'higher_yes', 'G2')
X_df_missing

Unnamed: 0,age,failures,Medu,absences,G2,sex_M,higher_yes,intercept
0,18,0,4,6,,0,1.0,1
1,17,0,1,4,5.0,0,1.0,1
2,15,3,1,10,8.0,0,1.0,1
3,15,0,4,2,14.0,0,1.0,1
4,16,0,3,4,10.0,0,1.0,1
...,...,...,...,...,...,...,...,...
390,20,2,2,11,9.0,1,1.0,1
391,17,0,3,3,16.0,1,1.0,1
392,21,3,1,3,8.0,1,1.0,1
393,18,0,3,0,12.0,1,1.0,1


In [37]:
idx = X_df_missing['G2'].isna() 

In [50]:
X_df_missing['G2'][idx == False]

1       5.0
2       8.0
3      14.0
4      10.0
5      15.0
       ... 
390     9.0
391    16.0
392     8.0
393    12.0
394     9.0
Name: G2, Length: 378, dtype: float64

In [10]:
X = X_df_missing
y = y_df.to_numpy()
B = 5000
n = len(y)
higher_yes_col = 6
absences_col = 4
age_col = 0
taus = [2.9, 0.9, 0.7, 1]# [3.5, 0.8, 0.7, 1]
thin = 10

In [28]:
(betas, sigmas2, higher_yes_sim, absences_sim, alphas0, alphas1, gammas0, gammas1,
 accepts_alpha0, accepts_alpha1, accepts_gamma0, accepts_gamma1) = Gibbs_MH(X, y, B, n, higher_yes_col, absences_col, age_col, taus, thin)

  0%|          | 92/99999 [00:00<01:48, 919.02it/s]

g0 16.67797640187838
g1 0.6397027620019322
g0 -0.6541764986866141
g1 -0.4022572234626158
g0 17.559270177124763
g1 0.6846287131413195
g0 -1.5675555265714807
g1 -0.5117546132881541
g0 19.938807539396862
g1 0.7805276838343093
g0 -4.055879837591023
g1 -0.5143919266581891
g0 22.429257637300328
g1 1.0006526369301896
g0 -9.457016439646186
g1 -0.7268627803049527
g0 -1869.671433552477
g1 -162.73301360229803
g0 2764.860329668496
g1 117.60945796359199
19.49411246707582
g0 -1931.4162867562172
g1 -170.60003979647607
17.697183241060085
g0 3114.752552435447
g1 148.0859549952676
18.593012027856282
g0 -2355.83594389372
g1 -200.40854057147888
18.700029256565006
g0 3388.693770044489
g1 148.46384196387126
16.855971268374823
g0 -2295.578929573242
g1 -198.42593032818863
15.95048782338867
g0 3439.7486014880888
g1 149.41641193756888
19.874677377338703
g0 -2649.078286739853
g1 -216.18944249722378
19.677520672679027
g0 3539.6750920416907
g1 182.80136429739417
18.47608512743958
g0 -2932.5691092026054
g1 -229.640

  0%|          | 184/99999 [00:00<02:16, 732.65it/s]

 -559824.262264278
g1 -37264.366722045634
18.740730679642283
g0 567174.2880687702
g1 31353.301676385887
15.832776820833969
g0 -606029.2287271286
g1 -31867.2989468322
20.037977970134502
g0 516463.0236386022
g1 37095.76607241818
17.751119980465056
g0 -590118.8841528494
g1 -41511.95924186288
18.452014984513664
g0 648434.6255466429
g1 38768.92371262116
17.142416317072776
g0 -669960.169961674
g1 -47619.12911846459
20.824420574436925
g0 787893.6977742334
g1 45673.11396337863
18.04484087211105
g0 -760305.359235003
g1 -46634.200403461065
18.977596381449803
g0 778573.3101396705
g1 53103.99788398297
17.69712549561486
g0 -908741.46311628
g1 -61100.75307066402
21.024617363212254
g0 1190873.5595583809
g1 58534.19615863991
17.118618364600163
g0 -1181723.8041102765
g1 -68155.3252764449
18.12447979465616
g0 1276702.0741744705
g1 82830.80492751715
20.631457636488427
g0 -1507763.9315012153
g1 -80814.38234949244
16.601587663487166
g0 1176968.106912849
g1 96588.86771621929
18.759303753554146
g0 -1669013.1

  0%|          | 330/99999 [00:00<02:30, 662.84it/s]

g0 -11808135921.43784
g1 -476670100.1716839
22.466251040041687
g0 7574209258.7841015
g1 755114591.4343771
54.07634819954171
g0 -12645192958.048843
g1 -495346151.6202432
47.31677422344183
g0 7327432750.809945
g1 883236148.5088171
323.0040502481282
g0 -15051229305.580814
g1 -478434789.6329051
96.3862666691678
g0 9687563549.081774
g1 1009525274.772813
1430.1249443376437
g0 -15991233732.354034
g1 -768636248.7110702
431.56365953488114
g0 13647719002.206764
g1 1016912926.6922355
6518.576270355698
g0 -16567075108.231527
g1 -868176729.1636182
4225.304996023391
g0 17965567984.47743
g1 1179317071.575927
83347.65695408855
g0 -21209998482.224613
g1 -1264351999.2490344
159.525241520577
g0 23596041306.44827
g1 1585570987.5910416
1680979.7597392038
g0 -26525468218.734108
g1 -1291809170.018089
1066.7680275952152
g0 18404497902.33947
g1 1915805905.8037472
46528490.15508872
g0 -32428280756.551502
g1 -1347481805.1250598
26622.23154313443
g0 23193317700.936565
g1 2066552644.2896922
822805466.705655
g0 -34

  0%|          | 397/99999 [00:00<02:27, 677.56it/s]

g0 119968822799733.56
g1 7003963361556.891
1.1417429056668689e+226
g0 -143968190066159.84
g1 -7431693041090.617
7.543683382686453e+227
g0 136161686386097.3
g1 10203573593477.951
1.2821716200775383e+234
g0 -210144940113896.6
g1 -8755947760769.349
7.322115441676472e+235
g0 160949131507560.9
g1 13889037229929.084
1.1228278156056887e+242
g0 -258813584222363.88
g1 -10209995052403.945
1.287955746367229e+245
g0 207374856576051.53
g1 14727711872017.498
5.306389190968516e+250
g0 -262078410217969.1
g1 -13855786240417.49
7.11052311050516e+251
g0 272705189077480.44
g1 16054826587539.426
2.067742479699599e+259
g0 -295469540812432.4
g1 -18827515676319.324
3.6380066806783734e+260
g0 287715529717201.2
g1 18480992221153.18
1.9645624336678416e+268
g0 -321577629454379.1
g1 -17114473973258.738
8.244752828613284e+269
g0 279531765128087.34
g1 20507623914444.61
7.925143792877129e+277
g0 -341295980011298.25
g1 -16997682942375.922
8.982116793014879e+279
g0 281944257603387.56
g1 20552402310496.05
1.912849272325




ValueError: array must not contain infs or NaNs

In [None]:
print(accepts_alpha0/(2*B*thin), accepts_alpha1/(2*B*thin), accepts_gamma0/(2*B*thin), accepts_gamma1/(2*B*thin))

In [None]:
betas_df = pd.DataFrame(betas.T, columns=[f"beta_{i}" for i in X_df.columns])
higher_yes_df = pd.DataFrame(higher_yes_sim.T, columns=[f"missing_higher_yes_{i}" for i in range(0, higher_yes_sim.shape[0])])
absences_df = pd.DataFrame(absences_sim.T, columns=[f"missing_absences_{i}" for i in range(0, absences_sim.shape[0])])
rest_df = pd.DataFrame({"sigmas2": sigmas2.T,
              "alpha0": alphas0.T,
              "alphas1": alphas1.T,
              "gammas0": gammas0.T, 
              "gammas1": gammas1.T
             })
              
results = pd.concat([betas_df , higher_yes_df, absences_df, rest_df], axis=1) 
results.head()
results.to_csv("results/03results_not_thinned.csv")
              

In [None]:
def MCMC_diagnostics(chain, param):
    plt.subplot(411)
    plt.plot(chain)
    plt.title(f'Trace Plot {param}')

    plt.subplot(412)
    plt.hist(chain, bins=60)
    plt.title(f'Histogram {param}')

    plt.subplot(413)
    gw_plot = pm.geweke(chain)
    plt.scatter(gw_plot[:,0],gw_plot[:,1])
    plt.axhline(-1.98, c='r')
    plt.axhline(1.98, c='r')
    
    plt.ylim(-2.5,2.5)
    plt.title(f'Geweke Plot Comparing first 10% and Slices of the Last 50% of Chain {param}')

    plt.subplot(414)
    acf_values = acf(chain)
    plt.scatter(range(0, len(acf_values)), acf_values)
    plt.title(f'ACF {param}')
    
    plt.tight_layout()
    plt.show()

In [None]:
MCMC_diagnostics(alphas0, "alpha0") # need to thin alphas and gammas

In [None]:
MCMC_diagnostics(alphas1, 'alpha1')

In [None]:
MCMC_diagnostics(gammas0, "gamma0")

In [None]:
MCMC_diagnostics(gammas1, "gamma1")

In [None]:
MCMC_diagnostics(sigmas2, "sigma2")

In [None]:
for i, beta in enumerate(betas):
    MCMC_diagnostics(beta, X_df.columns[i])

In [None]:
# thin
betas_thin = betas[:, ::thin]
higher_yes_sim_thin = higher_yes_sim[:, ::thin]
absences_sim_thin = absences_sim[:, ::thin]
alphas0_thin = alphas0[::thin]
alphas1_thin = alphas1[::thin]
gammas0_thin = gammas0[::thin]
gammas1_thin = gammas1[::thin]
sigmas2_thin = sigmas2[::thin]

In [None]:
MCMC_diagnostics(alphas0_thin, "alpha0")
MCMC_diagnostics(alphas1_thin, "alpha1")
MCMC_diagnostics(gammas0_thin, "gamma0")
MCMC_diagnostics(gammas1_thin, "gamma1")

In [None]:
plt.subplot(121)
plt.hist(higher_yes_sim.flatten())
plt.title(f'Simulated missing absences')


plt.subplot(122)
plt.hist(X[:, higher_yes_col])
plt.title(f'Observed absences')

plt.show()

In [None]:
plt.subplot(121)
plt.hist(absences_sim.flatten(), bins=30)
plt.title(f'Simulated missing higher_yes')

plt.subplot(122)
plt.hist(X[:, absences_col], bins=30)
plt.title(f'Observed higher_yes')
plt.show()



In [None]:
betas_df_thinned = pd.DataFrame(betas_thin.T, columns=[f"beta_{i}" for i in X_df.columns])
higher_yes_df_thinned = pd.DataFrame(higher_yes_sim_thin.T, columns=[f"missing_higher_yes_{i}" for i in range(0, higher_yes_sim.shape[0])])
absences_df_thinned = pd.DataFrame(absences_sim_thin.T, columns=[f"missing_absences_{i}" for i in range(0, absences_sim.shape[0])])
rest_df_thinned = pd.DataFrame({"sigmas2": sigmas2_thin.T,
              "alpha0": alphas0_thin.T,
              "alphas1": alphas1_thin.T,
              "gammas0": gammas0_thin.T, 
              "gammas1": gammas1_thin.T
             })
              
results_thinned = pd.concat([betas_df_thinned, higher_yes_df_thinned, absences_df_thinned, rest_df_thinned], axis=1) 
results_thinned.to_csv("results/03results_thinned.csv")
              

In [None]:
from scipy.stats import norm

In [None]:
norm.rvs(np.array([0,1,2]), 1)