In [11]:
%load_ext autoreload
%autoreload 2

from webbrowser import get
import pandas as pd
from sklearn.metrics import accuracy_score, r2_score
from models import *
import logging
from pretty_logger import get_logger
from pathlib import Path


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# define subset (ignore all other data)
ema = [f"Y{i}" for i in range(1, 8, 1)]
physical = [f"P{i}" for i in range(1, 5, 1)]
social = [f"S{i}" for i in range(1, 8, 1)]

In [13]:
df = pd.read_csv(datafile)
df["date"] = pd.to_datetime(df["day_survey"])
df["C"] = df["date"].apply(
    lambda date: (
        COVIDStatus.POST_COVID if date > date_covid else COVIDStatus.PRE_COVID
    )
)

df_head = df.head(5).copy()

df.rename(columns=reverse_ema_dictionary, inplace=True)
df.set_index(["uid", "date"], inplace=True)

subset = ema + physical + social + sleep + ["C"]
df = df[subset]
df.dropna(inplace=True)

sets_df = pd.read_parquet(sets_file, engine="pyarrow")

In [14]:
skip_emas = list(set(ema_dictionary.keys()).difference(set(["Y2", "Y3"])))
skip_treatments = list(
    set(subset).difference(
        set(
            [
                "P1",
                "P4",
                # "S1",
                # "S2",
                "S3",
                # "S4",
                "S5",
                "S6",
                # "Z1",
                # "Z2",
                # "Z3",
            ]
        )
    )
)

skip_treatments

['P2',
 'S4',
 'Y3',
 'Z2',
 'Y2',
 'S1',
 'Z3',
 'Y5',
 'Y6',
 'Y7',
 'S2',
 'C',
 'P3',
 'Y4',
 'S7',
 'Z1',
 'Y1']

In [15]:
# index = 17
# model_row = sets_df.iloc[index]

alphas = [0.2, 1, 2]
n_estimatorss = [400]

for n_estimators in n_estimatorss:
    for alpha in alphas:
        for index, model_row in sets_df.iterrows():
            already_fitted_sets = []

            covariate_set = CovariateSet(
                row=model_row,
                data=df,
                outcomes_to_skip=skip_emas,
                treatments_to_skip=skip_treatments,
            )

            if not covariate_set.valid_set:
                logger.debug(f"Skipping {covariate_set} (no valid set)")
                continue

            if covariate_set.set_to_fit in already_fitted_sets:
                logger.info(
                    f"Skipping {covariate_set} (already did {covariate_set.set_to_fit})"
                )
                continue
            already_fitted_sets.append(covariate_set.set_to_fit)

            logger.info(f"Fitting\n{covariate_set!r}")
            logger.info(
                f"Median of {covariate_set.treatment} {df[covariate_set.treatment].median()}"
            )
            wbm = WBLinearModel(
                data=df,
                treatment=covariate_set.treatment,
                outcome=covariate_set.outcome,
                separating_set=covariate_set.restricted_adjustment_set,
                name=f"row:{index}",
            )
            logger.info(
                f"pre_rsq train={wbm.pre_r_squared[0]}, pre_rsq test={wbm.pre_r_squared[1]}\n"
                f"pre mae test ={wbm.pre_r_squared[2]}\n"
                f"post_rsq train={wbm.post_r_squared[0]}, post_rsq test={wbm.post_r_squared[1]}\n"
            )
            logger.info(
                "-----------------------------------------------------\n"
            )

[[38;5;192m2024-05-28 16:03:31[0m.[33m468[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m736066224.py:[0m [35m30[0m [37m[<module>]:[0m [32m[38;5;222mFitting
treatment: P1:excercise (seconds), outcome: Y2:phq4_score, adjustment set={'P2', 'S4', 'P4', 'S3', 'S1', 'S2', 'S6'}[0m[0m
[[38;5;192m2024-05-28 16:03:31[0m.[33m470[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m736066224.py:[0m [35m31[0m [37m[<module>]:[0m [32m[38;5;222mMedian of P1 10963.4[0m[0m
[[38;5;192m2024-05-28 16:03:31[0m.[33m581[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m736066224.py:[0m [35m41[0m [37m[<module>]:[0m [32m[38;5;222mpre_rsq train=0.011448571829959397, pre_rsq test=0.009545715899704765
pre mae test =1.9240513911251569
post_rsq train=0.016545711004101626, post_rsq test=0.026328570007057128
[0m[0m
[[38;5;192m2024-05-28 16:03:31[0m.[33m582[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m736066224.py:[0m [35m46[0m [3