# Causal Inference with Double Machine Learning

<center>
<img 
  src="../assets/double_ml.png" 
  alt="Confounding Relationships" 
  style="width:300px;height:auto;"
>
</center>

In [19]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Set style
sns.set_style("whitegrid") 
sns.set_palette('viridis')
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['font.family'] = 'monospace'

## Double ML
from doubleml import DoubleMLData, DoubleMLPLR
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.ensemble import RandomForestClassifier

In [20]:
# Load observational dataset
observational_df = pd.read_pickle('../data/observational_df.pkl')

# Identify columns
customer_features = observational_df.drop(columns=['amu_signup', 'upsell_marketing']).columns.to_list()
target_outcome = 'amu_signup'

print('Customer Features: ', customer_features)

observational_df.head(5)

Customer Features:  ['streaming_tier_prime', 'play_days', 'songs_listened', 'other_subscriptions', 'retail_spending']


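| | amu_signup | upsell_marketing | streaming_tier_prime | play_days | songs_listened | other_subscriptions | retail_spending |
|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 1 | 0 | 0 | 0 | 35.0 |
| 1 | 0 | 1 | 0 | 30 | 120 | 0 | 0.0 |
| 2 | 0 | 1 | 0 | 0 | 0 | 1 | 0.0 |
| 3 | 0 | 0 | 1 | 0 | 0 | 0 | 35.0 |
| 4 | 0 | 0 | 1 | 0 | 0 | 0 | 35.0 |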


## Causal Assumptions
All causal models share the following assumptions about the data:

<br>
<br>

<center>
<img 
  src="../assets/causal_assumptions.png" 
  alt="Causal Assumptions" 
  style="width:750px;height:auto;"
>
</center>

<br>
<br>
<br>

## Double Machine Learning Model Families

The `doubleml` package provides several models for estimating causal effects, each suited to a different set of assumed causal mechanisms in the observational data.

All available model types are listed in the [model documentation](https://docs.doubleml.org/stable/guide/models.html).

We will use the partially linear regression (PLR) model, a common choice when the treatment and the outcome are both confounded by customer features. The causal diagram for this model is shown below.

<br>
<br>

<center>
<img 
  src="../assets/plr_model.png" 
  alt="Confounding Relationships" 
  style="width:550px;height:auto;"
>
</center>
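
For reference, the PLR model can be written out as equations. The notation below follows the `doubleml` documentation rather than this notebook, so the mapping to our columns is an assumption: $Y$ is the outcome (`amu_signup`), $D$ is the treatment (`upsell_marketing`), and $X$ is the set of customer features.

$$
\begin{aligned}
Y &= D\,\theta_0 + g_0(X) + \zeta, & \mathbb{E}[\zeta \mid D, X] &= 0 \\
D &= m_0(X) + V, & \mathbb{E}[V \mid X] &= 0
\end{aligned}
$$

Here $\theta_0$ is the treatment effect we want to estimate, $g_0$ captures how the features drive the outcome, and $m_0$ captures how the features drive treatment assignment. With $\ell_0(X) = \mathbb{E}[Y \mid X]$, the partialling-out estimator regresses the outcome residuals on the treatment residuals:

$$
\hat{\theta} = \frac{\sum_i \big(D_i - \hat{m}(X_i)\big)\big(Y_i - \hat{\ell}(X_i)\big)}{\sum_i \big(D_i - \hat{m}(X_i)\big)^2}
$$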

### Creating DoubleML Datasets

In [21]:
music_dml_data = (
    DoubleMLData(
        data=observational_df,
        y_col='amu_signup',
        d_cols='upsell_marketing',
        x_cols=customer_features,
        use_other_treat_as_covariate=False)
)

print(music_dml_data)


------------------ Data summary      ------------------
Outcome variable: amu_signup
Treatment variable(s): ['upsell_marketing']
Covariates: ['streaming_tier_prime', 'play_days', 'songs_listened', 'other_subscriptions', 'retail_spending']
Instrument variable(s): None
No. Observations: 10000

------------------ DataFrame info    ------------------
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10000 entries, 0 to 9999
Columns: 7 entries, amu_signup to retail_spending
dtypes: float64(1), int64(6)
memory usage: 547.0 KB



### Defining the Various ML Models

<br>
<br>

<center>
<img 
  src="../assets/double_ml_process.png" 
  alt="Double ML Process" 
  style="width:750px;height:auto;"
>
</center>

<br>
<br>
<br>

In [22]:
# Specify the model components
## Set random seed for reproducibility
np.random.seed(314)

# Nuisance models: ml_l predicts the outcome from the covariates (E[Y|X]),
# ml_m predicts the treatment from the covariates (E[D|X])
outcome_model = RandomForestClassifier(random_state=314)

treatment_model = RandomForestClassifier(random_state=314)

# DML model with 5-fold cross-fitting
dml_model = DoubleMLPLR(
    music_dml_data,
    ml_l=outcome_model,
    ml_m=treatment_model,
    n_folds=5)
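
To make the roles of `ml_l`, `ml_m`, and `n_folds` more concrete, here is a minimal standard-library sketch of the partialling-out idea with cross-fitting. It is an illustration under stated assumptions, not how `doubleml` is implemented: the data are synthetic (not the notebook's dataset), per-group means stand in for the random forests, and the helper names `simulate`, `group_means`, `cross_fit_plr`, and `naive_difference` are hypothetical.

```python
import random


def simulate(n, theta, seed=0):
    """Synthetic data: one binary confounder x drives both treatment d and outcome y."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x = rng.randint(0, 1)
        d = 1 if rng.random() < (0.7 if x else 0.2) else 0  # x raises treatment odds
        y = theta * d + 2.0 * x + rng.gauss(0, 1)           # x also raises the outcome
        rows.append((x, d, y))
    return rows


def group_means(rows, col):
    """Toy nuisance learner: mean of column `col` within each level of x."""
    sums, counts = {}, {}
    for row in rows:
        x = row[0]
        sums[x] = sums.get(x, 0.0) + row[col]
        counts[x] = counts.get(x, 0) + 1
    return {x: sums[x] / counts[x] for x in sums}


def cross_fit_plr(rows, n_folds=5):
    """Partialling-out estimate of theta, with nuisance models fit out of fold."""
    num = den = 0.0
    for k in range(n_folds):
        train = [r for i, r in enumerate(rows) if i % n_folds != k]
        test = [r for i, r in enumerate(rows) if i % n_folds == k]
        m_hat = group_means(train, 1)  # stand-in for ml_m: E[D | X]
        l_hat = group_means(train, 2)  # stand-in for ml_l: E[Y | X]
        for x, d, y in test:
            d_res = d - m_hat[x]       # treatment residual
            y_res = y - l_hat[x]       # outcome residual
            num += d_res * y_res
            den += d_res * d_res
    return num / den


def naive_difference(rows):
    """Unadjusted difference in mean outcome between treated and untreated rows."""
    treated = [y for _, d, y in rows if d == 1]
    control = [y for _, d, y in rows if d == 0]
    return sum(treated) / len(treated) - sum(control) / len(control)


rows = simulate(n=20000, theta=0.5, seed=314)
theta_naive = naive_difference(rows)
theta_dml = cross_fit_plr(rows, n_folds=5)
print(f"true effect: 0.50, naive: {theta_naive:.2f}, partialling-out: {theta_dml:.2f}")
```

In this simulation the naive comparison overstates the effect because the confounder raises both the treatment rate and the outcome, while the residual-on-residual estimate should land close to the true value of 0.5.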

In [23]:
# Fit the model
dml_model.fit();

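In [24]:
# View treatment effect estimates
# Show the effect, its standard error, and the interval as percentages;
# keep the t-statistic and p-value as plain numbers
pct_cols = ['coef', 'std err', '2.5 %', '97.5 %']
dml_model.summary.style.format({col: '{:.2%}' for col in pct_cols}, precision=4)

| | coef | std err | t | P>\|t\| | 2.5 % | 97.5 % |
|---|---|---|---|---|---|---|
| upsell_marketing | 2.64% | 0.69% | 3.8364 | 0.0001 | 1.29% | 3.99% |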
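
As a quick sanity check on how to read this table, the confidence interval and test statistic can be approximately reproduced from the reported coefficient and standard error. The sketch below assumes a normal approximation with a 1.96 critical value and uses the rounded values from the table, so small differences from the reported figures are expected.

```python
import math

coef = 0.0264     # point estimate from the summary table (2.64%)
std_err = 0.0069  # standard error from the summary table (0.69%)

z = 1.96  # approximate 97.5th percentile of the standard normal distribution
ci_low = coef - z * std_err
ci_high = coef + z * std_err

t_stat = coef / std_err
# Two-sided tail probability under the normal approximation
p_value = math.erfc(abs(t_stat) / math.sqrt(2))

print(f"95% CI: [{ci_low:.2%}, {ci_high:.2%}]")
print(f"t = {t_stat:.2f}, p = {p_value:.4f}")
```

Because `amu_signup` is a binary outcome, the coefficient reads as a change in signup probability: the estimate suggests that upsell marketing raises the signup rate by roughly 2.6 percentage points, with an interval that excludes zero.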


## Closing Remarks

Why hasn't causal ML taken over the world? 

- We have highly efficient methods for a wide range of estimation problems
- State-of-the-art techniques are still relatively new, especially advanced Double ML methods
- It’s complicated
    - ML, semi-parametric statistical theory, probabilistic graphical models, matrix calculus, ... 🤯

<br>
<br>
<br>
<center>
<img 
  src="../assets/complicated.png" 
  alt="It's complicated" 
  style="width:300px;height:auto;"
>
</center>