In [1]:
# Enable IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
# "complete" is the most inclusive autoreload mode; see the IPython autoreload
# documentation for how it differs from mode 2 / "all".
%autoreload complete

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
from matplotlib_inline.backend_inline import set_matplotlib_formats
from tqdm.notebook import tqdm

# Ask the inline matplotlib backend to emit figures as SVG.
set_matplotlib_formats("svg")

In [5]:
from dowhy import CausalModel
from dowhy.causal_estimators.propensity_score_weighting_estimator import (
    PropensityScoreWeightingEstimator,
)
from sklearn.linear_model import LogisticRegression

from src.features.aggregation import naive_all_regions

# Column roles: outcome (y) and treatment (w). These names, together with
# `x_col` and `df`, are read again by later cells, so they stay global.
y_col, w_col = "media_online_protest", "occ_FFF"

# Only the first element of `data.y` / `data.x` is used in this cell.
data = naive_all_regions()
y = data.y[0][[y_col]]

# Covariates passed as common causes: the holiday flag plus every column whose
# name starts with "weekday".
x_col = [
    "is_holiday",
    *(name for name in data.x[0].columns if name.startswith("weekday")),
]
X = data.x[0][[w_col, *x_col]]

# One frame holding outcome, treatment and covariates, joined on the index.
df = y.join(X)

# No explicit graph is given; the covariates are declared as common causes.
model = CausalModel(
    data=df, treatment=w_col, outcome=y_col, graph=None, common_causes=x_col
)
estimand = model.identify_effect()

# Propensity-score weighting with a default LogisticRegression as the
# propensity model, fitted on the same frame it is later evaluated on.
estimator = PropensityScoreWeightingEstimator(
    estimand,
    propensity_score_model=LogisticRegression(),
).fit(df)

# target_units="att" requests the effect on the treated units.
estimate = estimator.estimate_effect(data=df, target_units="att")
print(estimate)

  0%|          | 0/16 [00:00<?, ?it/s]

*** Causal Estimate ***

## Identified estimand
Estimand type: EstimandType.NONPARAMETRIC_ATE

## Realized estimand
b: media_online_protest~occ_FFF+is_holiday+weekday_Saturday+weekday_Wednesday+weekday_Thursday+weekday_Tuesday+weekday_Monday+weekday_Sunday
Target units: att

## Estimate
Mean value: 9.724385510260781



In [50]:
from sklearn.linear_model import LinearRegression

from src.models.propensity_scores import DowhyWrapper, propensity

# Wrap the frame and column names built in the earlier estimation cell in the
# project's DowhyWrapper.
# NOTE(review): the positional order (frame, outcome, treatment, covariates) is
# read off the variable names only -- confirm against src.models.propensity_scores.
model_ = DowhyWrapper(df, y_col, w_col, x_col)
# Alternative model, left disabled by the author:
# model_ = LinearRegression()
# Hand the wrapper to `propensity` with lags=1. What the lag setting does is
# defined in the project module and is not visible from this notebook.
ts_model = propensity(lags=1, model=model_)

[autoreload of src.models.propensity_scores failed: Traceback (most recent call last):
  File "/Users/david/Repositories/protest-impact/.venv/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 274, in check
    superreload(m, reload, self.old_objects, self.shell)
  File "/Users/david/Repositories/protest-impact/.venv/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 500, in superreload
    update_generic(old_obj, new_obj)
  File "/Users/david/Repositories/protest-impact/.venv/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 397, in update_generic
    update(a, b)
  File "/Users/david/Repositories/protest-impact/.venv/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 365, in update_class
    update_instances(old, new)
  File "/Users/david/Repositories/protest-impact/.venv/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 323, in update_instances
    object.__setattr__(ref, "__class__", new)
Ty

  0%|          | 0/16 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


In [54]:
# Print the effect reported by the estimator stored on `ts_model.model`,
# evaluated on the frame held by that same object (`ts_model.model.df`) with
# target_units="att".
print(
    ts_model.model.estimator.estimate_effect(data=ts_model.model.df, target_units="att")
)

*** Causal Estimate ***

## Identified estimand
Estimand type: EstimandType.NONPARAMETRIC_ATE

## Realized estimand
b: media_online_protest~occ_FFF_futcov_lag0+occ_GP_futcov_lag0+occ_GP_futcov_lag-1+weekday_Thursday_futcov_lag-1+occ_EG_futcov_lag0+weekday_Wednesday_futcov_lag-1+SERIES8_statcov_target_media_online_protest+occ_OTHER_CLIMATE_ORG_futcov_lag0+SERIES9_statcov_target_media_online_protest+SERIES6_statcov_target_media_online_protest+weekday_Saturday_futcov_lag0+weekday_Thursday_futcov_lag0+occ_EG_futcov_lag-1+weekday_Wednesday_futcov_lag0+SERIES12_statcov_target_media_online_protest+occ_FFFX_futcov_lag-1+weekday_Monday_futcov_lag-1+occ_FFF_futcov_lag-1+SERIES5_statcov_target_media_online_protest+weekday_Monday_futcov_lag0+SERIES2_statcov_target_media_online_protest+is_holiday_futcov_lag0+occ_XR_futcov_lag-1+weekday_Tuesday_futcov_lag-1+SERIES11_statcov_target_media_online_protest+SERIES7_statcov_target_media_online_protest+occ_FFFX_futcov_lag0+SERIES1_statcov_target_media_onlin