In [42]:
%load_ext autoreload
%autoreload 2

from webbrowser import get
import pandas as pd
from sklearn.metrics import accuracy_score, r2_score
from models import *
import logging
from pretty_logger import get_logger
from pathlib import Path


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [43]:
# Column groups that make up the analysis subset; every other column in the data is ignored.
# NOTE(review): adjustment sets in `sets_df` also reference columns outside these groups
# (e.g. D1, S4, S5, S7 in the logged output) — those get dropped downstream; confirm intended.
def _numbered_columns(prefix: str, count: int) -> list:
    """Return ``[prefix1, ..., prefix<count>]``, e.g. ``("Y", 3) -> ["Y1", "Y2", "Y3"]``."""
    return [f"{prefix}{number}" for number in range(1, count + 1)]


ema = _numbered_columns("Y", 7)       # Y1..Y7
physical = _numbered_columns("P", 4)  # P1..P4
social = _numbered_columns("S", 3)    # S1..S3

In [44]:
# Load the raw survey data, tag each row with its COVID period, and reduce it to the
# analysis subset. `datafile`, `date_covid`, `COVIDStatus`, `reverse_ema_dictionary` and
# `sets_file` are presumably provided by `from models import *` — they are not defined here.
df_raw = pd.read_csv(datafile)
df_raw["date"] = pd.to_datetime(df_raw["day_survey"])
# Per-row comparison kept as-is so it works for whatever type `date_covid` is.
df_raw["C"] = df_raw["date"].apply(
    lambda date: (
        COVIDStatus.POST_COVID if date > date_covid else COVIDStatus.PRE_COVID
    )
)

# Snapshot of the first rows with the ORIGINAL column names (taken before renaming).
df_head = df_raw.head(5).copy()

subset = ema + physical + social + ["C"]

# Build the analysis frame without in-place mutation: `dropna(inplace=True)` on the
# result of a column selection can trigger SettingWithCopyWarning, and in-place steps
# make the cell harder to reason about. `df_raw` stays available for inspection.
df = (
    df_raw
    .rename(columns=reverse_ema_dictionary)
    .set_index(["uid", "date"])
    .loc[:, subset]
    .dropna()
)

sets_df = pd.read_parquet(sets_file, engine="pyarrow")

In [45]:
class CovariateSet:
    """One row of an adjustment-set table (treatment, outcome, adjustment set) checked against a frame.

    The row's adjustment set is intersected with the columns actually present in
    ``data``. If the treatment or outcome is not a column of ``data``, there is no
    restricted set at all (``None``).
    NOTE(review): dropping variables from an adjustment set may mean it no longer
    blocks all confounding paths — confirm this is acceptable for the analysis.
    """

    @staticmethod
    def model_row_series_valid(row: pd.Series, data: pd.DataFrame) -> "set | None":
        """Restrict a row's adjustment set to the columns available in ``data``.

        Args:
            row: mapping with ``"sets"`` (array-like exposing ``.tolist()``),
                ``"outcome"`` and ``"treatment"`` entries.
            data: frame whose columns define which variables are usable.

        Returns:
            The members of the row's adjustment set that are columns of ``data``
            (possibly empty), or ``None`` when the treatment or the outcome is not
            a column of ``data``.
        """
        adjustment_set = row["sets"].tolist()
        outcome = row["outcome"]
        treatment = row["treatment"]

        if treatment not in data.columns or outcome not in data.columns:
            # `logger` is presumably provided by `from models import *`.
            logger.debug(
                f"treatment or outcome not in columns: {treatment} or {outcome} (columns:{data.columns})"
            )
            # Return None explicitly instead of relying on logger.debug's return value.
            return None
        return set(data.columns).intersection(set(adjustment_set))

    @staticmethod
    def row_string(outcome, treatment, adjustment_set):
        """Human-readable description of a treatment/outcome/adjustment-set triple.

        Looks up labels in ``full_dictionary`` (presumably from ``from models import *``;
        maps codes like ``"P2"`` to labels like ``"studying"``), so an unknown code raises KeyError.
        """
        return (
            f"treatment: {treatment}:{full_dictionary[treatment]}, "
            f"outcome: {outcome}:{full_dictionary[outcome]}, "
            f"adjustment set={adjustment_set}"
        )

    def __init__(self, row: pd.Series, data: pd.DataFrame) -> None:
        # adjustment set exactly as given in the row
        self.original_adjustment_set = row["sets"].tolist()
        # adjustment set restricted to columns present in `data` (None if treatment/outcome missing)
        self.restricted_adjustment_set = CovariateSet.model_row_series_valid(
            row=row, data=data
        )
        self.treatment = row["treatment"]
        self.outcome = row["outcome"]

    @property
    def set_to_fit(self) -> "tuple[str, set] | None":
        """``(outcome, {treatment, *restricted adjustment set})`` for valid sets, otherwise ``None``.

        Two rows with the same outcome and the same predictor columns produce equal
        values, which makes this usable as a "was this model already fitted" key.
        """
        if not self.valid_set:
            return None
        return (
            self.outcome,
            {self.treatment, *self.restricted_adjustment_set},
        )

    @property
    def valid_set(self) -> bool:
        """True when the restricted set exists, is non-empty, and the treatment is not a demographic (``D*``)."""
        return (
            (self.restricted_adjustment_set is not None)
            and (len(self.restricted_adjustment_set) != 0)
            # exclude demographics
            and (not self.treatment.startswith("D"))
        )

    def __repr__(self):
        return CovariateSet.row_string(
            self.outcome, self.treatment, self.original_adjustment_set
        )

    def __str__(self):
        # Note: for invalid sets this renders as "None".
        return f"{self.set_to_fit}"

In [46]:
# Fit one WBRandomForestModel per distinct (outcome, predictor columns) combination in
# `sets_df`, recording the R² metrics of every fit. `WBRandomForestModel` and `logger`
# are presumably provided by `from models import *`.
already_fitted_sets = []  # set_to_fit keys already fitted (tuples containing sets are unhashable)
ccp_alpha = 1e-5          # cost-complexity pruning strength passed to the model
n_estimatorss = [800]     # candidate tree counts; only the first is currently used
n_estimators = n_estimatorss[0]  # loop-invariant, so chosen once up front
fit_results = []          # one record per fit; partial results survive an interrupted run

for index, model_row in sets_df.iterrows():
    covariate_set = CovariateSet(row=model_row, data=df)
    label = (
        f"row:{index} (treatment={covariate_set.treatment}, "
        f"outcome={covariate_set.outcome})"
    )

    if not covariate_set.valid_set:
        # str(covariate_set) is always "None" here, so log identifying fields instead.
        logger.debug(f"Skipping {label} (no valid set)")
        continue

    fit_key = covariate_set.set_to_fit
    if fit_key in already_fitted_sets:
        logger.debug(f"Skipping {label} (already did {fit_key})")
        continue

    already_fitted_sets.append(fit_key)
    logger.info(f"Fitting\n{covariate_set!r}")
    wbm = WBRandomForestModel(
        data=df,
        n_estimators=n_estimators,
        ccp_alpha=ccp_alpha,
        treatment=covariate_set.treatment,
        outcome=covariate_set.outcome,
        separating_set=covariate_set.restricted_adjustment_set,
        name=f"row:{index}",
    )
    logger.info(
        f"\nccp_alpha ={ccp_alpha}, n_estimators={n_estimators}, "
        f"pre_rsq={wbm.pre_r_squared} post_r_sq={wbm.post_r_squared}"
    )
    logger.info("-----------------------------------------------------\n")
    fit_results.append(
        {
            "row": index,
            "treatment": covariate_set.treatment,
            "outcome": covariate_set.outcome,
            "adjustment_set": sorted(covariate_set.restricted_adjustment_set),
            "n_estimators": n_estimators,
            "ccp_alpha": ccp_alpha,
            "pre_r_squared": wbm.pre_r_squared,
            "post_r_squared": wbm.post_r_squared,
        }
    )

# Build the summary once from the collected records (not row-by-row).
fit_results_df = pd.DataFrame(fit_results)
fit_results_df

[[38;5;192m2024-05-28 12:00:08[0m.[33m534[0m] - [32mmodelslog-[0m [32mINFO[0m [35mN/A-[0m [34m827537744.py:[0m [35m23[0m [37m[<module>]:[0m [32m[38;5;222mFitting
treatment: P2:studying, outcome: Y1:pam, adjustment set=['D1', 'D3', 'D4', 'P1', 'S2', 'S4', 'S5', 'S7'][0m[0m


KeyboardInterrupt: 