In [None]:
%load_ext autoreload
%autoreload 2

from webbrowser import get
import pandas as pd
from sklearn.metrics import accuracy_score, r2_score
from models import *
import logging
from pretty_logger import get_logger
from pathlib import Path


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Column groups kept for modelling; every other column in the data is ignored.
ema = list(map("Y{}".format, range(1, 8)))       # Y1 .. Y7
physical = list(map("P{}".format, range(1, 5)))  # P1 .. P4
social = list(map("S{}".format, range(1, 8)))    # S1 .. S7

In [None]:
# Load the survey data, tag each row with its COVID period ("C"), and keep only the
# modelling columns listed in `subset`, dropping rows with any missing value.
# NOTE(review): `datafile`, `date_covid`, `COVIDStatus`, `reverse_ema_dictionary`,
# `sleep` and `sets_file` are not defined in this notebook — presumably they come
# from `from models import *`; confirm.
df_raw = pd.read_csv(datafile)
df_raw["date"] = pd.to_datetime(df_raw["day_survey"])
# Vectorised comparison instead of a row-wise lambda: dates strictly after
# `date_covid` are POST_COVID, everything else (NaT compares False) is PRE_COVID.
df_raw["C"] = (
    df_raw["date"]
    .gt(date_covid)
    .map({True: COVIDStatus.POST_COVID, False: COVIDStatus.PRE_COVID})
)

# Preview of the raw columns, captured before renaming / re-indexing.
df_head = df_raw.head(5).copy()

subset = ema + social + physical + sleep + ["C"]
# Build the modelling frame as one chain rather than a sequence of inplace mutations,
# so re-running the cell always starts from `df_raw` and `df` has a clear lineage.
df = (
    df_raw
    .rename(columns=reverse_ema_dictionary)
    .set_index(["uid", "date"])
    .loc[:, subset]
    .dropna()
)

sets_df = pd.read_parquet(sets_file, engine="pyarrow")

In [None]:
# Outcomes and treatments to exclude; these are passed to CovariateSet below as
# `outcomes_to_skip` / `treatments_to_skip`. Both are currently empty.
# To restrict the sweep, populate them, e.g.
#   skip_emas = list(set(ema_dictionary.keys()).difference({"Y2", "Y3"}))
#   skip_treatments = list(
#       set(subset).difference({"P1", "P4", "S1", "S2", "S3", "S4", "S5", "S6"})
#   )
skip_emas, skip_treatments = list(), list()
skip_treatments

[]

In [55]:
# Grid-search the random-forest hyper-parameters over every covariate set in `sets_df`.
# For each (n_estimators, ccp_alpha) combination, every distinct set is fitted once;
# the pre/post train/test R² scores are logged and collected into `fit_results_df`.
ccp_alphas = [1e-3, 1e-5]
n_estimatorss = [800, 400]

fit_results = []  # one record per fitted model, turned into a DataFrame at the end

for n_estimators in n_estimatorss:
    for ccp_alpha in ccp_alphas:
        # Reset once per hyper-parameter combination, outside the row loop: if it were
        # reset for every row it would always be empty at the check below, and
        # duplicate covariate sets would be refitted.
        already_fitted_sets = []

        for index, model_row in sets_df.iterrows():
            covariate_set = CovariateSet(
                row=model_row,
                data=df,
                outcomes_to_skip=skip_emas,
                treatments_to_skip=skip_treatments,
            )

            if not covariate_set.valid_set:
                logger.debug(f"Skipping {covariate_set} (no valid set)")
                continue

            # `set_to_fit` may be unhashable, so a list (not a set) is used for lookup.
            if covariate_set.set_to_fit in already_fitted_sets:
                logger.warning(
                    f"Skipping {covariate_set} (already did {covariate_set.set_to_fit})"
                )
                continue
            already_fitted_sets.append(covariate_set.set_to_fit)

            logger.info(f"Fitting\n{covariate_set!r}")
            logger.info(
                f"Median of {covariate_set.treatment} {df[covariate_set.treatment].median()}"
            )
            wbm = WBRandomForestModel(
                data=df,
                n_estimators=n_estimators,
                ccp_alpha=ccp_alpha,
                treatment=covariate_set.treatment,
                outcome=covariate_set.outcome,
                separating_set=covariate_set.restricted_adjustment_set,
                name=f"row:{index}",
            )
            # Element 0 / 1 of the *_r_squared pairs are reported as train / test.
            logger.info(
                f"\nccp_alpha ={ccp_alpha}, n_estimators={n_estimators}, "
                f"pre_rsq train={wbm.pre_r_squared[0]}, pre_rsq test={wbm.pre_r_squared[1]}\n"
                f"post_rsq train={wbm.post_r_squared[0]}, post_rsq test={wbm.post_r_squared[1]}\n"
            )
            logger.info(
                "-----------------------------------------------------\n"
            )

            fit_results.append(
                {
                    "row": index,
                    "treatment": covariate_set.treatment,
                    "outcome": covariate_set.outcome,
                    "n_estimators": n_estimators,
                    "ccp_alpha": ccp_alpha,
                    "pre_rsq_train": wbm.pre_r_squared[0],
                    "pre_rsq_test": wbm.pre_r_squared[1],
                    "post_rsq_train": wbm.post_r_squared[0],
                    "post_rsq_test": wbm.post_r_squared[1],
                }
            )

# Build the summary table once from the collected records (not row-by-row).
fit_results_df = pd.DataFrame(fit_results)
fit_results_df

[[38;5;192m2024-05-28 17:02:50[0m.[33m957[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m1147895444.py:[0m [35m30[0m [37m[<module>]:[0m [32m[38;5;222mFitting
treatment: P2:studying (hours), outcome: Y1:pam, adjustment set={'S5', 'S7', 'S4', 'P1', 'S2'}[0m[0m
[[38;5;192m2024-05-28 17:02:50[0m.[33m958[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m1147895444.py:[0m [35m31[0m [37m[<module>]:[0m [32m[38;5;222mMedian of P2 2.719055555553[0m[0m


[[38;5;192m2024-05-28 17:03:14[0m.[33m583[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m1147895444.py:[0m [35m43[0m [37m[<module>]:[0m [32m[38;5;222m
ccp_alpha =0.001, n_estimators=800, pre_rsq train=0.8653600981108766, pre_rsq test=0.043665024235382544
post_rsq train=0.8749683461363316, post_rsq test=0.03563651088335329
[0m[0m
[[38;5;192m2024-05-28 17:03:14[0m.[33m584[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m1147895444.py:[0m [35m48[0m [37m[<module>]:[0m [32m[38;5;222m-----------------------------------------------------
[0m[0m
[[38;5;192m2024-05-28 17:03:14[0m.[33m585[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m1147895444.py:[0m [35m30[0m [37m[<module>]:[0m [32m[38;5;222mFitting
treatment: P3:in house (hours), outcome: Y1:pam, adjustment set={'S7', 'S4', 'P4', 'S1', 'S3', 'S2', 'P2'}[0m[0m
[[38;5;192m2024-05-28 17:03:14[0m.[33m586[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m1147895