In [None]:
# Step 1: Import necessary libraries
import numpy as np
import pandas as pd
from econml.drlearner import DRLearner
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:


# Step 2: Load the dataset
# For this example, we'll assume the data is in a CSV file named 'startup_data.csv'
data = pd.read_csv('startup_data.csv')

# Step 3: Preprocess the data

# Drop incomplete rows BEFORE any integer cast: astype(int) raises on NaN, so
# casting first (as this script originally did) crashes on exactly the files
# that dropna() is meant to clean up.
data = data.dropna()

# Binary treatment variables and categorical flag columns are stored as integers.
treatment_columns = ['Tech Support', 'Discount']
categorical_features = ['Global Flag', 'Major Flag', 'SMC Flag', 'Commercial Flag']
data = data.astype({col: int for col in treatment_columns + categorical_features})

# Define features for W (controls), T (treatments), and Y (outcome).
# W is later passed to the learner as X, i.e. the features along which the
# treatment effect is allowed to vary.
W = data[categorical_features + ['IT Spend', 'Employee Count', 'PC Count', 'Size']]
T = data[treatment_columns]
Y = data['Revenue']

# Step 4: Split the data into training and testing sets
# (80/20 split; fixed seed so the split is reproducible)
X_train, X_test, T_train, T_test, Y_train, Y_test = train_test_split(
    W, T, Y, test_size=0.2, random_state=42
)

# Step 5: Initialize models for the DR Learner
# Random Forests are used for both nuisance models: a regressor for the
# outcome and a classifier for the treatment propensity.
model_y = RandomForestRegressor(n_estimators=100, random_state=42)
model_t = RandomForestClassifier(n_estimators=100, random_state=42)

# Step 6: Initialize the DR Learner
# NOTE(review): newer EconML releases expose DRLearner under `econml.dr`
# rather than `econml.drlearner` - confirm the import at the top of the file
# matches the installed version.
dr_learner = DRLearner(
    model_regression=model_y,
    model_propensity=model_t,
    random_state=42
)

# Step 7: Fit the model for each treatment separately and analyze heterogeneous treatment effects

def _cate_feature_importances(dr_learner, n_features):
    """Return one heterogeneity score per feature for a fitted learner, or None.

    Looks, in order, at ``feature_importances_`` on the learner itself, then
    at ``feature_importances_`` and ``coef_`` on the fitted final CATE model
    (``dr_learner.model_cate(T=1)``) when the learner offers that accessor.
    Coefficients are converted to absolute values so they can be ranked like
    importances; they are scale-dependent and only a rough proxy.

    A candidate is accepted only if it flattens to exactly ``n_features``
    values; otherwise the next one is tried. ``None`` means nothing usable
    was found.
    """
    # (object, attribute name, take absolute value?)
    candidates = [(dr_learner, 'feature_importances_', False)]

    model_cate = getattr(dr_learner, 'model_cate', None)
    if callable(model_cate):
        final_model = model_cate(T=1)
        candidates.append((final_model, 'feature_importances_', False))
        candidates.append((final_model, 'coef_', True))

    for source, attr_name, use_abs in candidates:
        values = getattr(source, attr_name, None)
        if values is None:
            continue
        if callable(values):
            # Some estimators expose the importances as a method instead of
            # a plain attribute.
            values = values()
        values = np.ravel(np.asarray(values, dtype=float))
        if use_abs:
            values = np.abs(values)
        if values.shape[0] == n_features:
            return values
    return None


# Define a function to fit and plot treatment effects
def analyze_treatment(dr_learner, T_column, T_name):
    """Fit the learner for one treatment, report its effect and plot diagnostics.

    The function works on the module-level train/test split: it reads
    ``Y_train``, ``T_train``, ``X_train``, ``X_test`` and ``Y_test``.

    Parameters
    ----------
    dr_learner : estimator
        Learner exposing ``fit(Y, T, X=...)`` and ``effect(X)``. It is
        (re)fitted in place by this call.
    T_column : str
        Column of ``T_train`` holding the treatment to analyse.
    T_name : str
        Human-readable treatment name used in the printed message and the
        plot titles.

    Returns
    -------
    pandas.DataFrame
        One row per test sample with columns ``'CATE'`` (estimated effect)
        and ``'Revenue'`` (observed outcome), indexed 0..n-1.
    """
    # Fit the model
    # NOTE(review): the other treatment column is neither a feature nor a
    # control here, so it may confound this estimate - consider passing it
    # through the `W=` argument of fit(); confirm with the analysis owner.
    dr_learner.fit(Y_train, T_train[T_column], X=X_train)

    # Estimate the Conditional Average Treatment Effect (CATE).
    # Flatten so the DataFrame below always receives a 1-D column, even if
    # the learner returns an (n, 1) array.
    cate = np.ravel(dr_learner.effect(X_test))

    # Feature importances for heterogeneity.
    # The generic DRLearner does not expose `feature_importances_` directly
    # (per the EconML API docs), so resolve them via the fitted final model
    # instead of raising AttributeError after a successful fit.
    feature_names = X_train.columns
    feature_importances = _cate_feature_importances(dr_learner, len(feature_names))

    # Create a DataFrame for CATE
    cate_df = pd.DataFrame({
        'CATE': cate,
        'Revenue': Y_test.reset_index(drop=True)
    })

    # Print average treatment effect
    ate = np.mean(cate)
    print(f"\nAverage Treatment Effect (ATE) for {T_name}: {ate:.2f}")

    # Plot CATE distribution
    plt.figure(figsize=(10, 6))
    plt.hist(cate, bins=30, edgecolor='k', alpha=0.7)
    plt.title(f'Distribution of CATE for {T_name}')
    plt.xlabel('CATE')
    plt.ylabel('Frequency')
    plt.show()

    # Plot feature importances (skipped when the learner offers none)
    if feature_importances is None:
        print(f"Feature importances are not available for {T_name}; skipping plot.")
        return cate_df

    importances_df = pd.DataFrame({
        'Feature': feature_names,
        'Importance': feature_importances
    }).sort_values(by='Importance', ascending=False)

    plt.figure(figsize=(10, 6))
    plt.barh(importances_df['Feature'], importances_df['Importance'], color='skyblue')
    plt.gca().invert_yaxis()  # most important feature at the top
    plt.title(f'Feature Importances for Heterogeneous Effect of {T_name}')
    plt.xlabel('Importance')
    plt.ylabel('Feature')
    plt.show()

    return cate_df

# Run the per-treatment analysis: Tech Support first, then Discount.
cate_tech_support = analyze_treatment(dr_learner, 'Tech Support', 'Tech Support')
cate_discount = analyze_treatment(dr_learner, 'Discount', 'Discount')

# Step 8: Interpret the results

# Attach the customer features to each CATE table. The test features are
# re-indexed from 0 so they line up row-by-row with the CATE frames.
test_features = X_test.reset_index(drop=True)
results_tech_support = pd.concat([test_features, cate_tech_support], axis=1)
results_discount = pd.concat([test_features, cate_discount], axis=1)

# Example: compare the estimated effects of both treatments against customer size
plt.figure(figsize=(10,6))
for series_label, results in (
    ('Tech Support', results_tech_support),
    ('Discount', results_discount),
):
    plt.scatter(results['Size'], results['CATE'], alpha=0.7, label=series_label)
plt.title('CATE vs. Customer Size')
plt.xlabel('Customer Size')
plt.ylabel('CATE')
plt.legend()
plt.show()
