In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
from dotenv import load_dotenv

# Read key=value pairs from a local .env file into the process environment,
# so machine-specific settings (such as the data location below) can live
# outside the notebook.
load_dotenv()

# Location of the pickled input frame. Set PSA_SEGMENT_DATA_PATH (in the shell
# or in .env) to point at your own copy; the original author's path is kept as
# the fallback so existing setups keep working unchanged.
DATA_PATH = os.environ.get(
    'PSA_SEGMENT_DATA_PATH',
    '/Users/jacob.perius/psa_segment_testing/final_df_with_cvr.pkl',
)

# NOTE(security): unpickling can execute arbitrary code embedded in the file.
# Only load pickles that come from a trusted source.
df = pd.read_pickle(DATA_PATH)

# Preview the first rows as the cell's rich output.
df.head(5)

In [None]:
import scipy
import numpy as np
import pymc as pm

#np.__config__.show()
#scipy.__config__.show()

In [None]:
import seaborn as sns

# Column of `df` whose per-group mean is being compared.
feature_of_interest = 'cvr'

# Fixed seed so that Restart & Run All reproduces the same posterior draws
# (and therefore the same figure).
RANDOM_SEED = 42

# Observed values of the feature for each experiment group.
data_a = df.loc[df['group'] == 'group_a', feature_of_interest]
data_b = df.loc[df['group'] == 'group_b', feature_of_interest]
data_c = df.loc[df['group'] == 'group_c', feature_of_interest]
observed_by_group = {'a': data_a, 'b': data_b, 'c': data_c}

# Fail fast on an empty group: with no observations the likelihood adds no
# information, and the "posterior" for that group would silently just be the
# prior.
empty_groups = [
    f'group_{key}' for key, values in observed_by_group.items() if len(values) == 0
]
if empty_groups:
    raise ValueError(
        f"No rows found for {empty_groups}; check the 'group' labels in df"
    )

with pm.Model() as model:
    # Priors: a wide Normal(0, 100) on each group mean and a HalfNormal(1) on
    # each group standard deviation. The variables are registered in the order
    # mu_a, mu_b, mu_c, sigma_a, sigma_b, sigma_c.
    group_mu = {
        key: pm.Normal(f"mu_{key}", mu=0, sigma=100) for key in observed_by_group
    }
    group_sigma = {
        key: pm.HalfNormal(f"sigma_{key}", sigma=1) for key in observed_by_group
    }

    # Likelihoods: each group's observations are modelled as Normal around
    # that group's own mean and standard deviation.
    for key, values in observed_by_group.items():
        pm.Normal(
            f"obs_{key}",
            mu=group_mu[key],
            sigma=group_sigma[key],
            observed=values,
        )

    # Draw 1000 posterior samples per chain across 4 chains.
    trace = pm.sample(1000, chains=4, random_seed=RANDOM_SEED)

print(trace.posterior)

# Collapse the (chain, draw) dimensions into one flat sample array per mean.
mu_a_samples = trace.posterior['mu_a'].values.flatten()
mu_b_samples = trace.posterior['mu_b'].values.flatten()
mu_c_samples = trace.posterior['mu_c'].values.flatten()
posterior_mu_samples = {
    'mu_a': mu_a_samples,
    'mu_b': mu_b_samples,
    'mu_c': mu_c_samples,
}

# Overlay the posterior densities of the three group means on one set of axes
# so their overlap can be judged visually.
fig, ax = plt.subplots(figsize=(10, 6))
for label, samples in posterior_mu_samples.items():
    sns.kdeplot(samples, fill=True, alpha=0.5, label=label, ax=ax)

ax.set_xlabel("Value")
ax.set_ylabel("Density")
ax.set_title("Overlayed Posterior Distributions for mu_a, mu_b, mu_c")
ax.legend()

plt.show()