In [1]:
%load_ext autoreload
%autoreload 2
import pandas as pd
from models import *


In [2]:
# Column-name groups that make up the analysis subset; columns outside these
# groups are ignored later when the frame is restricted to `subset`.
ema = [f"Y{n}" for n in range(1, 8)]  # Y1 .. Y7
physical = [f"P{n}" for n in range(1, 5)]  # P1 .. P4
social = [f"S{n}" for n in range(1, 8)]  # S1 .. S7

In [3]:
# Load the raw survey export and derive a datetime column from `day_survey`.
# `datafile`, `date_covid`, `COVIDStatus`, `reverse_ema_dictionary`, `sleep`,
# `demographic` and `sets_file` are not defined in this notebook — presumably
# they come from `from models import *` (TODO confirm).
df = pd.read_csv(datafile)
df["date"] = pd.to_datetime(df["day_survey"])

# Tag every row with its COVID period: strictly after `date_covid` is
# POST_COVID, everything else (including the cut-off itself) is PRE_COVID.
df["C"] = df["date"].apply(
    lambda date: (
        COVIDStatus.POST_COVID if date > date_covid else COVIDStatus.PRE_COVID
    )
)

# Snapshot of the first rows taken BEFORE columns are renamed and the index is
# set, so it keeps the original column names.
df_head = df.head(5).copy()

# Columns kept for modelling: the EMA, physical, social, sleep and demographic
# groups plus the COVID flag.
subset = ema + physical + social + sleep + demographic + ["C"]

# Build the modelling frame as one non-mutating chain instead of a sequence of
# `inplace=True` calls. In particular, calling `dropna(inplace=True)` on the
# result of `df[subset]` mutates a derived selection, which is the pattern
# pandas warns about (SettingWithCopyWarning); returning a new frame avoids it
# and yields the same data.
df = (
    df.rename(columns=reverse_ema_dictionary)
    .set_index(["uid", "date"])[subset]
    .dropna()  # keep only rows with every selected column present
)

# Table with one row per model to fit; the fitting loop reads the 'outcome',
# 'treatment' and 'sets' fields of each row.
sets_df = pd.read_parquet(sets_file, engine="pyarrow")

In [4]:
# Exclusion lists handed to CovariateSet as `treatments_to_skip` and
# `outcomes_to_skip`; both start empty (two independent list objects).
skip_treatments, skip_outcomes = [], []

In [5]:
# Values forwarded as `alpha` to WBLinearModel, one full pass over `sets_df`
# per value. NOTE(review): presumably a regularisation strength — confirm
# against WBLinearModel.
alphas = [0]

for alpha in alphas:
    # Sets fitted so far for this alpha. This list used to be re-created on
    # every row, so it could never hold more than the current row's set; it is
    # now initialised once per alpha so it actually accumulates.
    already_fitted_sets = []

    for index, model_row in sets_df.iterrows():
        # Build the treatment/outcome/adjustment-set description for this row,
        # honouring the skip lists.
        covariate_set = CovariateSet(
            row=model_row,
            data=df,
            outcomes_to_skip=skip_outcomes,
            treatments_to_skip=skip_treatments,
        )

        if not covariate_set.valid_set:
            # NOTE(review): the message says "Skipping", but the run is
            # aborted by the ValueError below (the former `continue` after the
            # raise was unreachable and has been removed). Confirm whether
            # fail-fast or skip-and-continue is the intended behaviour.
            logger.error(
                f"Skipping {covariate_set} (no valid set for "
                f"outcome:{model_row['outcome']}, "
                f"treatment:{model_row['treatment']}, "
                f"set:{model_row['sets']})."
            )
            raise ValueError("Invalid set")

        already_fitted_sets.append(covariate_set.set_to_fit)

        logger.info(f"Fitting\n{covariate_set!r}")
        logger.info(
            f"Median of {covariate_set.treatment} {df[covariate_set.treatment].median()}"
        )

        # Fit the model for this treatment/outcome pair, adjusting for the
        # restricted adjustment set; the row index is used as the model name.
        wbm = WBLinearModel(
            data=df,
            alpha=alpha,
            treatment=covariate_set.treatment,
            outcome=covariate_set.outcome,
            separating_set=covariate_set.restricted_adjustment_set,
            name=f"row:{index}",
        )

        # NOTE(review): "pre mae test" is read from pre_r_squared[2] — confirm
        # that slot really holds the MAE and not a third R² value.
        logger.info(
            f"pre_rsq train={wbm.pre_r_squared[0]}, pre_rsq test={wbm.pre_r_squared[1]}\n"
            f"pre mae test ={wbm.pre_r_squared[2]}\n"
            f"post_rsq train={wbm.post_r_squared[0]}, post_rsq test={wbm.post_r_squared[1]}\n"
        )
        logger.info("-----------------------------------------------------\n")

[[38;5;192m2024-05-29 11:53:40[0m.[33m470[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m1968370032.py:[0m [35m27[0m [37m[<module>]:[0m [32m[38;5;222mFitting
treatment: D2:race, outcome: Y1:pam, adjustment set={'S2', 'S1', 'P2', 'D1', 'P1', 'D4', 'P4', 'P3', 'D3'}[0m[0m
[[38;5;192m2024-05-29 11:53:40[0m.[33m472[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m1968370032.py:[0m [35m28[0m [37m[<module>]:[0m [32m[38;5;222mMedian of D2 1.0[0m[0m


[[38;5;192m2024-05-29 11:53:40[0m.[33m584[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m1968370032.py:[0m [35m39[0m [37m[<module>]:[0m [32m[38;5;222mpre_rsq train=0.05772819054113065, pre_rsq test=0.05509073823451094
pre mae test =3.50175528575217e-05
post_rsq train=0.06339933431005351, post_rsq test=0.06656585172228846
[0m[0m
[[38;5;192m2024-05-29 11:53:40[0m.[33m585[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m1968370032.py:[0m [35m44[0m [37m[<module>]:[0m [32m[38;5;222m-----------------------------------------------------
[0m[0m
[[38;5;192m2024-05-29 11:53:40[0m.[33m586[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m1968370032.py:[0m [35m27[0m [37m[<module>]:[0m [32m[38;5;222mFitting
treatment: D4:cohort year, outcome: Y1:pam, adjustment set={'S2', 'S7', 'D1', 'P1', 'S6', 'D3'}[0m[0m
[[38;5;192m2024-05-29 11:53:40[0m.[33m587[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m1968370032.py:[0m 