In [1]:
import os
import sys
sys.path.append('/Users/samrelins/Documents/LIDA/ace_project/')

from IPython.display import Image
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pymc3 as pm
from src.train_test import *
from src.data_prep import *

# Imported after the star imports so they cannot be shadowed by them;
# np was previously only reachable implicitly through the star imports.
import numpy as np
from scipy import stats


# Load and prep data: read the raw ACE csv, split it into train / test sets,
# then encode + scale the feature matrices via the project helper
# encode_and_scale (one-hot categorical encoding and scaling, judging by the
# keyword arguments -- its implementation lives in src.data_prep).
data_dir = "/Users/samrelins/Documents/LIDA/ace_project/data/"
data_path = os.path.join(data_dir, "ace_data_extra.csv")
ace_dat = pd.read_csv(data_path)

X_train, y_train, X_test, y_test = return_train_test(ace_dat)
X_train, X_test = encode_and_scale(X_train, y_train, X_test,
                                   cat_encoder="one_hot", scaled=True)


def sample_distribution(X_train, y_train, features, *, draws=5000, tune=500,
                        cores=14, target_accept=0.90, init="adapt_diag",
                        **sample_kwargs):
    """
    Helper function to sample a Bayesian logistic regression model.

    Builds a pyMC3 GLM with an intercept and a Binomial family over the
    selected feature columns, then draws posterior samples with
    ``pm.sample``. The sampler settings default to the values this notebook
    was originally run with, but can now be overridden per call (e.g. fewer
    ``cores`` on a smaller machine, or a ``random_seed`` for reproducibility).

    :param X_train: (object: pandas DataFrame) input training data dataframe
    :param y_train: (object: pandas Series) training labels
    :param features: (list) list of features / columns to use
        in logistic regression model
    :param draws: (int) number of posterior draws passed to ``pm.sample``
    :param tune: (int) number of tuning steps passed to ``pm.sample``
    :param cores: (int) number of worker processes passed to ``pm.sample``
    :param target_accept: (float) target acceptance rate for the sampler
    :param init: (str) sampler initialisation method
    :param sample_kwargs: any further keyword arguments, forwarded unchanged
        to ``pm.sample`` (e.g. ``chains``, ``random_seed``)

    :return: (tuple: (object: pyMC3 Model, object: pyMC3 Trace))
        pyMC3 model and trace objects from sampler
    """
    with pm.Model() as model:
        pm.glm.GLM(
            x=X_train[features],
            y=y_train,
            intercept=True,
            family=pm.glm.families.Binomial(),
        )
        # NOTE: when ``chains`` is not given, pm.sample derives the number of
        # chains from ``cores``, so changing ``cores`` also changes how many
        # chains (and total samples) end up in the trace.
        trace = pm.sample(
            draws,
            tune=tune,
            cores=cores,
            target_accept=target_accept,
            init=init,
            **sample_kwargs
        )
    return model, trace
        
    
def return_kde_plot(samples, step=0.001, color="red", *, min_points=200):
    """
    Return a plotly filled-line Scatter of a Gaussian KDE of ``samples``.

    The KDE is evaluated on an evenly spaced grid covering
    ``[samples.min(), samples.max()]`` with BOTH ends included. The grid
    spacing is at most ``step``. If that would give fewer than
    ``min_points`` points (e.g. a posterior narrower than ``step``), the
    grid is refined to ``min_points`` points so the curve is still drawn
    smoothly instead of collapsing to a single point.

    :param samples: (array-like) 1-D array of samples
    :param step: (float) maximum spacing between KDE evaluation points; must
        be positive
    :param color: (str) line and fill colour of the returned trace
    :param min_points: (int) minimum number of KDE evaluation points

    :raises ValueError: if ``step`` is not positive
    :return: plotly ``go.Scatter`` trace of the KDE, filled to y=0
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step!r}")

    samples = np.asarray(samples)
    kde = stats.gaussian_kde(samples)
    lo, hi = samples.min(), samples.max()

    # Intervals needed so the spacing is <= step, plus one for the endpoint;
    # never fewer than min_points so narrow distributions stay smooth.
    n_points = max(int(np.ceil((hi - lo) / step)) + 1, min_points)
    xx = np.linspace(lo, hi, n_points)
    yy = kde.evaluate(xx)

    kde_plot = go.Scatter(x=xx,
                          y=yy,
                          line=dict(color=color,
                                    shape="spline"),
                          mode="lines",
                          fill="tozeroy")
    return kde_plot
        
        
def return_trace_fig(pymc_trace, *, palette=None):
    """
    Helper function to plot trace and KDE plots from pyMC3 samples.

    Produces one row per variable in ``pymc_trace.varnames``:
    - left column: the sampled values against their sample index;
    - right column: a KDE of the same samples.
    The samples are taken as returned by ``pymc_trace[varname]``; for a pyMC3
    MultiTrace this presumably concatenates all chains end to end (verify
    against the pyMC3 version in use). Each row gets one colour, cycled
    through ``palette``.

    :param pymc_trace: (object: pyMC3 Trace object) input trace
    :param palette: (list of str, optional) colours cycled across variables;
        defaults to ``plotly.express.colors.qualitative.Plotly``

    :raises ValueError: if the trace has no variables to plot
    :return: plotly Figure
    """
    if palette is None:
        palette = px.colors.qualitative.Plotly

    varnames = list(pymc_trace.varnames)
    if not varnames:
        raise ValueError("pymc_trace has no variables to plot")
    n_vars = len(varnames)

    # Title only the left (trace) column of each row; right column is blank.
    subplot_titles = [title for name in varnames for title in (name, "")]
    fig = make_subplots(rows=n_vars,
                        cols=2,
                        subplot_titles=subplot_titles)

    for row, feature in enumerate(varnames, start=1):
        samples = pymc_trace[feature]
        # Cycle by the palette's actual length instead of assuming 10 colours.
        color = palette[(row - 1) % len(palette)]

        trace_plot = go.Scatter(x=np.arange(len(samples)),
                                y=samples,
                                line=dict(color=color))
        fig.add_trace(trace_plot, row=row, col=1)

        kde_plot = return_kde_plot(samples, color=color)
        fig.add_trace(kde_plot, row=row, col=2)

    fig.update_layout(height=200 * n_vars,
                      showlegend=False)
    return fig



In [None]:
# Example: fit the model on the single feature "ox_sat", then render the
# trace / KDE figure as a static PNG (the bare expression on the last line
# is what the notebook displays).
eg_model, eg_trace = sample_distribution(X_train, y_train, features=["ox_sat"])
trace_fig = return_trace_fig(pymc_trace=eg_trace)
Image(data=trace_fig.to_image(format="png"))