# Causal inference on Burnout Risk for [YOUR COMPANY]

## Research question

Among Copilot users, what is the causal effect of increasing Copilot usage (e.g., from 5 to 10 actions per person-week) on burnout risk, i.e. after-hours collaboration hours?

We frame this as a continuous-treatment (dose–response) problem: Copilot usage is the treatment "dose," and we study both the marginal effect of an additional action and policy-relevant contrasts between usage levels (e.g., 5 → 10 actions).

---

## Methods (summary)

We use Double Machine Learning (DML) on person-level aggregated data:
- Aggregate longitudinal data by person, taking means across all observed weeks for treatment, outcome, and time-varying controls.
- Fit LinearDML with a spline featurizer on treatment to learn the marginal dose-response curve; compare to a baseline without featurization.
- Estimate heterogeneity with CausalForestDML to identify which subgroups show different treatment effects.

This cross-sectional approach compares individuals with different average Copilot usage levels, with DML providing robust adjustment for observed confounders. Deliverables include dose–response plots, marginal effects with CIs, and subgroup-level effect estimates.

## 1. Setup and Imports

We begin by importing the necessary libraries for our Double Machine Learning analysis:
- **EconML**: Advanced causal inference estimators (LinearDML, CausalForestDML)
- **Scikit-learn**: Feature transformers and ML models
- **Data processing**: pandas, numpy for data manipulation
- **Visualization**: matplotlib for plotting treatment effects

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
import os
import sys
import vivainsights as vi
from pathlib import Path
from datetime import datetime
import json

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

# EconML imports for causal inference
from econml.dml import LinearDML, CausalForestDML
from econml.cate_interpreter import SingleTreeCateInterpreter

# Scikit-learn imports for feature engineering and ML models
from sklearn.preprocessing import PolynomialFeatures, SplineTransformer
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score

# Generate subgroup combinations
from itertools import combinations

# Import custom modules
script_dir = os.getcwd()
sys.path.append(os.path.join(script_dir, 'modules'))

from modules.data_processor import DataProcessor
from modules.estimator import TreatmentEffectEstimator
from modules.output_manager import OutputManager
from modules.subgroup_analysis import (
    create_subgroup_definition,
    create_transition_matrix,
    run_ate_for_subgroup,
    identify_top_subgroups
)
from modules.sensitivity_analysis import (
    calculate_evalue,
    rosenbaum_bounds_approximation,
    run_sensitivity_analysis
)

# ==================== ANALYSIS CONFIGURATION ====================
# Direction of the subgroup search.
#   True  -> rank subgroups by the MOST NEGATIVE effects (reductions in
#            after-hours collaboration, usually the case of interest)
#   False -> rank subgroups by the MOST POSITIVE effects (increases)
FIND_NEGATIVE_EFFECTS = True

_RULE = "=" * 60
if FIND_NEGATIVE_EFFECTS:
    _target_lines = (
        "üîç Target: Subgroups with MOST NEGATIVE effects",
        "   (i.e., largest reductions in after-hours collaboration)",
    )
else:
    _target_lines = (
        "üîç Target: Subgroups with MOST POSITIVE effects",
        "   (i.e., largest increases in after-hours collaboration)",
    )

# Banner: rule, title, rule, the two target lines, rule, then a blank line.
for _line in (_RULE, "ANALYSIS CONFIGURATION", _RULE, *_target_lines, _RULE, ""):
    print(_line)
# ================================================================

# Input and output locations, both resolved relative to the working directory.
_parent_dir = os.path.join(script_dir, '..')
data_file_path = os.path.join(_parent_dir, 'data', 'PersonQuery.Csv')  # Update path
output_base_dir = os.path.join(
    _parent_dir, 'output', 'Subgroup Analysis - [YOUR COMPANY]'
)  # Update path
os.makedirs(output_base_dir, exist_ok=True)

# Echo the resolved locations so a wrong path is visible before the load below.
print("‚úì All imports successful!")
print(f"Working directory: {script_dir}")
print(f"Data file path: {data_file_path}")
_data_file_found = os.path.exists(data_file_path)
print(f"Data file exists: {_data_file_found}")

# Seed NumPy's global RNG and reset matplotlib to its default style.
np.random.seed(123)
plt.style.use('default')

# Load the person query export.
data = vi.import_query(data_file_path)

In [None]:
# Quick schema check: list every column, then the HR attributes that
# vivainsights detects in the loaded query.
print("Columns in the dataset:", list(data.columns), sep="\n")

print("HR attributes in the dataset:")
hrvar_str = vi.extract_hr(data)
print(hrvar_str)

In [None]:
# Constant column so that organisation-wide trends can be requested through
# the same `hrvar` interface as any other grouping.
data['Total'] = 'Total'

# Core analysis variables (referenced by every later cell).
OUTCOME_VAR = 'After_hours_collaboration_hours'
TREATMENT_VAR = 'Total_Copilot_actions_taken'
PERSON_ID_VAR = 'PersonId'


def _overall_trend(metric):
    """Line chart of `metric` over time for the whole population."""
    return vi.create_line(data=data, metric=metric, hrvar='Total')


# Which organisations are highest on the outcome?
ech_organization = vi.create_bar(data=data, metric=OUTCOME_VAR, hrvar='Organization')

# Outcome trend over time, then treatment trend over time.
ech_time = _overall_trend(OUTCOME_VAR)
tch_time = _overall_trend(TREATMENT_VAR)


## 2. Data Filtering and Configuration

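This step defines the subgroup, network, and collaboration variable lists used throughout the analysis, then narrows the person-week data in three stages: (1) keep only rows whose `MetricDate` falls inside the configured start/end window, (2) keep only person-weeks with at least one Copilot action, and (3) cap (winsorize) the treatment at its 95th percentile to limit the influence of extreme usage. It finishes by reporting any configured variables that are missing from the data.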

In [None]:
# Variable groups used by the rest of the notebook.
# NOTE: the subgroup search needs at least two subgroup variables.
SUBGROUP_VARS = ['FunctionType', 'IsManager', 'LevelDesignation', 'Organization']
NETWORK_VARS = ['Internal_network_size', 'External_network_size', 'Strong_ties', 'Diverse_ties']
COLLABORATION_VARS = [
    'Collaboration_hours',
    'Available_to_focus_hours',
    'Active_connected_hours',
    'Uninterrupted_hours',
]

print("=== Data Filtering Phase ===\n")

# Size of the data before any filter is applied.
print("üìä Original Dataset:")
print(f"   ‚Ä¢ Shape: {data.shape}")
print(f"   ‚Ä¢ Unique individuals: {data[PERSON_ID_VAR].nunique()}")

# --- Date window (inclusive at both ends) ---
# Set either string to '' / None to skip the date filter.
start_date_str = '2025-03-01'
end_date_str = '2025-06-30'

# Unparseable dates become NaT and therefore never fall inside the window.
data['MetricDate'] = pd.to_datetime(data['MetricDate'], errors='coerce')
print(f"   ‚Ä¢ Date range available: {data['MetricDate'].min()} to {data['MetricDate'].max()}")
if not (start_date_str and end_date_str):
    print("   ‚Ä¢ No start/end date provided; skipping date filter")
else:
    start_date = pd.to_datetime(start_date_str)
    end_date = pd.to_datetime(end_date_str)
    mask = data['MetricDate'].between(start_date, end_date)
    kept = int(mask.sum())
    total = len(data)
    print(f"   ‚Ä¢ Applying date filter: {start_date.date()} to {end_date.date()} (kept {kept}/{total}, {kept/total:.1%})")
    data = data.loc[mask].copy()

# Optional population filter (disabled). Uncomment and adapt to restrict the
# analysis to a specific population, e.g. Sales individual contributors:
# pop_mask = (data['JobFunction'] == 'Sales') & (data['ismanager'] == 'N')
# print(f"   - Records after filtering: {pop_mask.sum():,} ({pop_mask.mean():.1%})")
# data = data.loc[pop_mask].copy()

# Keep only person-weeks with at least one Copilot action.
copilot_mask = data[TREATMENT_VAR] > 0
print(f"   ‚Ä¢ Copilot users (person-weeks): {copilot_mask.sum():,} ({copilot_mask.mean():.1%})")
data = data.loc[copilot_mask].copy()
print("   ‚úì Filtered to Copilot users only")

# Upper-tail winsorization: values above the 95th percentile are set to it.
treatment_95th = data[TREATMENT_VAR].quantile(0.95)
print(f"   ‚Ä¢ 95th percentile of {TREATMENT_VAR}: {treatment_95th}")
above_cap = data[TREATMENT_VAR] > treatment_95th
data.loc[above_cap, TREATMENT_VAR] = treatment_95th
print(f"   ‚úì Winsorized {TREATMENT_VAR} at the 95th percentile (upper cap)")

# Size of the data after all filters.
print("\nüìä Filtered Dataset:")
print(f"   ‚Ä¢ Shape: {data.shape}")
print(f"   ‚Ä¢ Unique individuals: {data[PERSON_ID_VAR].nunique()}")

# Report configured variables that are absent from the data.
present_columns = set(data.columns)
missing_vars = [
    name
    for var_list in (SUBGROUP_VARS, NETWORK_VARS, COLLABORATION_VARS)
    for name in set(var_list) - present_columns
]
if not missing_vars:
    print("   ‚Ä¢ All required variables are present in 'data'")
else:
    print(f"   ‚Ä¢ Missing variables in 'data': {missing_vars}")


In [None]:
# Collapse the person-week panel into a cross-sectional snapshot: one row per
# unique combination of person id and subgroup attributes, holding the mean of
# every numeric analysis variable over the retained weeks.

# dict.fromkeys de-duplicates while keeping declaration order. The previous
# list(set(...)) made the column order depend on string hashing, so it could
# differ from one kernel session to the next.
GROUPING_VARS = [
    var
    for var in dict.fromkeys([PERSON_ID_VAR] + SUBGROUP_VARS)
    if var in data.columns  # only keep existing columns
]
AGG_VARS = [
    var
    for var in dict.fromkeys(
        NETWORK_VARS + COLLABORATION_VARS + [OUTCOME_VAR, TREATMENT_VAR]
    )
    if var in data.columns  # only keep existing columns
]

# GroupBy.mean already skips missing values, so no per-column lambda is needed
# and the result never has MultiIndex columns.
# NOTE: rows whose grouping keys contain NaN are dropped by groupby's default
# dropna=True, exactly as before.
data_snapshot = data.groupby(GROUPING_VARS, as_index=False)[AGG_VARS].mean()

# A person whose subgroup attributes changed inside the window yields several
# snapshot rows; surface that, because later cells treat rows as individuals.
_n_people = data_snapshot[PERSON_ID_VAR].nunique()
if _n_people != len(data_snapshot):
    print(
        f"NOTE: {len(data_snapshot) - _n_people} extra snapshot rows come from "
        "individuals whose subgroup attributes changed during the window"
    )

print(f"‚úì Data snapshot created with shape: {data_snapshot.shape}")
print(f"‚úì Grouping variables: {GROUPING_VARS}")
print(f"‚úì Aggregated variables: {AGG_VARS}")
print(f"‚úì Missing values ignored during mean calculation")

data_snapshot


## 3. CATE Analysis to Identify Top Subgroups

First, we'll use CausalForestDML to identify subgroups with heterogeneous treatment effects, then select the top 5 groups for detailed ATE analysis.

In [None]:
# Enumerate every pair of subgroup variables and keep the value combinations
# that contain enough snapshot rows to be analysed.
print("\n=== Generating Subgroup Combinations ===")

min_group_size = 50  # minimum snapshot rows required per subgroup

# Only subgroup variables that survived into the snapshot can be crossed.
available_group_vars = [var for var in SUBGROUP_VARS if var in data_snapshot.columns]
print(f"Available grouping variables: {available_group_vars}")
print(f"Using cross-sectional data_snapshot with shape: {data_snapshot.shape}")

subgroup_combinations = []
for var1, var2 in combinations(available_group_vars, 2):
    # Row count for each observed (val1, val2) pair of this variable pair.
    pair_sizes = data_snapshot.groupby([var1, var2]).size()
    big_enough = pair_sizes[pair_sizes >= min_group_size]

    for (val1, val2), count in big_enough.items():
        # Spaces and slashes are replaced because the name is later used in paths.
        label = f"{var1}_{val1}__and__{var2}_{val2}"
        group_name = label.replace(' ', '_').replace('/', '_')
        subgroup_combinations.append(
            dict(name=group_name, var1=var1, val1=val1, var2=var2, val2=val2, size=count)
        )

print(f"Found {len(subgroup_combinations)} subgroup combinations with >= {min_group_size} observations")

# Preview the ten largest subgroups.
if subgroup_combinations:
    sorted_combos = sorted(subgroup_combinations, key=lambda entry: entry['size'], reverse=True)
    print("\nTop 10 largest subgroups:")
    for combo in sorted_combos[:10]:
        print(f"  {combo['name']}: n={combo['size']}")


In [None]:
# Build the arrays for the CATE model from the cross-sectional snapshot:
#   X - effect modifiers (subgroup attributes + first two network variables)
#   W - additional controls (first three collaboration variables)
#   T - treatment, Y - outcome
print("\n=== Preparing CATE Analysis ===")


def _is_usable_feature(series):
    """Return True when the column can be fed to get_dummies / the estimator.

    Accepts any numeric or boolean dtype plus object, string and categorical
    columns. The previous whitelist of four dtype names silently dropped
    boolean flags and non-64-bit numerics.
    """
    return (
        pd.api.types.is_numeric_dtype(series)
        or pd.api.types.is_bool_dtype(series)
        or pd.api.types.is_object_dtype(series)
        or pd.api.types.is_string_dtype(series)
        or isinstance(series.dtype, pd.CategoricalDtype)
    )


X_vars = SUBGROUP_VARS + NETWORK_VARS[:2]  # key demographic and network variables
available_X_vars = [
    var
    for var in X_vars
    if var in data_snapshot.columns and _is_usable_feature(data_snapshot[var])
]
print(f"Variables for heterogeneity analysis: {available_X_vars}")

# One-hot encode categoricals; cast everything to float so the estimator gets
# a homogeneous numeric matrix (get_dummies may emit bool columns).
X_data = pd.get_dummies(data_snapshot[available_X_vars], drop_first=True).astype(float)
print(f"Feature matrix shape after encoding: {X_data.shape}")

# Fail with an explicit message if a control column is absent.
control_vars = COLLABORATION_VARS[:3]
missing_controls = [var for var in control_vars if var not in data_snapshot.columns]
if missing_controls:
    raise KeyError(f"Control variables missing from data_snapshot: {missing_controls}")

T = data_snapshot[TREATMENT_VAR].values
Y = data_snapshot[OUTCOME_VAR].values
# NOTE(review): missing controls are imputed with 0 rather than dropped, so
# the W term in the validity mask below can never trigger - confirm this
# imputation is intended.
W = data_snapshot[control_vars].fillna(0).values

# nanmin/nanmax keep the printout meaningful when NaNs are still present.
print(f"Treatment variable range: {np.nanmin(T):.2f} to {np.nanmax(T):.2f}")
print(f"Outcome variable range: {np.nanmin(Y):.2f} to {np.nanmax(Y):.2f}")
print(f"Control variables shape: {W.shape}")

# Drop rows with a missing treatment, outcome, feature or control.
# DataFrame.isna is used for X because np.isnan is not defined for every dtype.
valid_mask = ~(
    pd.isna(T) | pd.isna(Y) | X_data.isna().any(axis=1) | np.isnan(W).any(axis=1)
)
_keep_rows = valid_mask.to_numpy()  # plain boolean array for the NumPy inputs
T_clean = T[_keep_rows]
Y_clean = Y[_keep_rows]
X_clean = X_data[valid_mask]
W_clean = W[_keep_rows]
data_clean = data_snapshot[valid_mask].copy()

print(f"Clean data shape: {len(T_clean)} observations (cross-sectional)")
print(f"Removed {len(T) - len(T_clean)} rows due to missing values")

### CATE Effect Interpretation

**Understanding `mean_effect` in Subgroup Analysis:**

The `mean_effect` calculated for each subgroup represents the **average treatment effect for moving from 0 to the mean treatment level** within that subgroup, based on CausalForestDML estimates.

**Key Details:**
- **Treatment Comparison**: Effects are calculated using `T0=0` (baseline) vs `T1=T_clean.mean()` (dataset average)
- **Units**: Hours of after-hours collaboration change by moving from 0 to average Copilot usage
- **Interpretation**: 
  - If `mean_effect = -1.5` → **1.5 hours REDUCTION** in after-hours collaboration per week
  - If `mean_effect = +1.5` → **1.5 hours INCREASE** in after-hours collaboration per week
- **Direction**: Based on the `FIND_NEGATIVE_EFFECTS` toggle:
  - `True` → Find subgroups with most negative effects (largest reductions)
  - `False` → Find subgroups with most positive effects (largest increases)

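**Note**: This is a contrast between two treatment levels (0 vs. the dataset mean), not a per-unit marginal effect. As fitted here (without a treatment featurizer), CausalForestDML lets the effect vary across individuals through `X` but treats it as linear in the treatment, so each contrast equals that individual's per-unit effect multiplied by `T_clean.mean()`. Because the sample is restricted to Copilot users, the `T0=0` baseline lies below the observed usage range and is an extrapolation.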

In [None]:
# Fit a causal forest on the cleaned snapshot and score every row with its
# estimated effect of moving from zero to average Copilot usage.
print("\n=== Fitting CATE Model ===")


def _nuisance_forest():
    """Fresh random forest used for a first-stage (nuisance) regression."""
    return RandomForestRegressor(n_estimators=100, random_state=123)


# Separate forest instances for the treatment and outcome models;
# 3-fold cross-fitting.
cate_estimator = CausalForestDML(
    model_t=_nuisance_forest(),
    model_y=_nuisance_forest(),
    cv=3,
    random_state=123,
)

print("Fitting CausalForestDML...")
cate_estimator.fit(Y_clean, T_clean, X=X_clean, W=W_clean)
print("‚úì CATE model fitted successfully")

# Contrast per row: baseline T0=0 versus T1=the sample-average treatment.
print("Estimating individual treatment effects...")
treatment_effects = cate_estimator.effect(X_clean, T0=0, T1=T_clean.mean())

_effect_low, _effect_high = treatment_effects.min(), treatment_effects.max()
print(f"Individual treatment effects range: {_effect_low:.3f} to {_effect_high:.3f}")
print(f"Mean treatment effect: {treatment_effects.mean():.3f}")

# Attach the estimates to the cleaned snapshot for the subgroup summaries.
data_clean['individual_treatment_effect'] = treatment_effects
print("‚úì Individual treatment effects calculated")

In [None]:
# Summarise the individual effect estimates within every candidate subgroup,
# keep the subgroups whose mean differs from zero (p < 0.05), rank them in the
# configured direction and select up to five for the detailed ATE analysis.
from scipy import stats  # cell-local: scipy is not in the notebook-level imports

if FIND_NEGATIVE_EFFECTS:
    print("\n=== Identifying Subgroups with MOST NEGATIVE Effects ===")
    print("(Largest reductions in after-hours collaboration)")
else:
    print("\n=== Identifying Subgroups with MOST POSITIVE Effects ===")
    print("(Largest increases in after-hours collaboration)")

# Explicit schema so the frame has its columns even when no subgroup qualifies
# (an empty record list would otherwise raise KeyError further down).
SUBGROUP_EFFECT_COLUMNS = [
    'name', 'var1', 'val1', 'var2', 'val2',
    'mean_effect', 'std_effect', 'p_value',
    'n_observations', 'n_users', 'significant',
]

subgroup_effects = []

for combo in subgroup_combinations:
    # Rows of the cleaned snapshot that belong to this subgroup.
    mask = ((data_clean[combo['var1']] == combo['val1']) &
            (data_clean[combo['var2']] == combo['val2']))

    if mask.sum() < min_group_size:  # too small after cleaning
        continue

    subgroup_data = data_clean[mask]
    mean_effect = subgroup_data['individual_treatment_effect'].mean()
    std_effect = subgroup_data['individual_treatment_effect'].std()
    n_obs = len(subgroup_data)
    n_users = subgroup_data[PERSON_ID_VAR].nunique()

    # One-sample t-test of the subgroup mean against zero.
    # NOTE(review): this treats the forest's point predictions as independent
    # observations, so it reflects their spread, not the estimator's
    # uncertainty, and no multiple-comparison correction is applied across
    # subgroups. Consider the estimator's own inference (e.g. EconML
    # `effect_inference(...).population_summary()`) - TODO confirm.
    if n_obs > 1 and std_effect > 0:
        t_stat = mean_effect / (std_effect / np.sqrt(n_obs))
        # sf() is the upper-tail probability; unlike 1 - cdf() it does not
        # round very small p-values down to exactly 0.
        p_value = 2 * stats.t.sf(abs(t_stat), n_obs - 1)
    else:
        p_value = 1.0

    subgroup_effects.append({
        'name': combo['name'],
        'var1': combo['var1'],
        'val1': combo['val1'],
        'var2': combo['var2'],
        'val2': combo['val2'],
        'mean_effect': mean_effect,
        'std_effect': std_effect,
        'p_value': p_value,
        'n_observations': n_obs,
        'n_users': n_users,
        'significant': p_value < 0.05
    })

# Rank the significant subgroups in the configured direction.
subgroup_effects_df = pd.DataFrame(subgroup_effects, columns=SUBGROUP_EFFECT_COLUMNS)
# Recomputed from p_value so the boolean mask is well-typed on an empty frame.
_is_significant = subgroup_effects_df['p_value'].astype(float) < 0.05
significant_subgroups = subgroup_effects_df[_is_significant].sort_values(
    'mean_effect',
    ascending=FIND_NEGATIVE_EFFECTS  # True -> most negative first, False -> most positive first
)

print(f"Found {len(significant_subgroups)} statistically significant subgroups (p < 0.05)")
if FIND_NEGATIVE_EFFECTS:
    print(f"Bottom 10 subgroups by treatment effect (most negative):")
else:
    print(f"Top 10 subgroups by treatment effect (most positive):")
print(significant_subgroups[['name', 'mean_effect', 'p_value', 'n_observations', 'n_users']].head(10))

# Up to five subgroups go forward; report the real count rather than "5".
top_5_subgroups = significant_subgroups.head(5)
if FIND_NEGATIVE_EFFECTS:
    print(f"\nüìä Selected {len(top_5_subgroups)} subgroups with MOST NEGATIVE effects for ATE analysis:")
else:
    print(f"\nüìä Selected {len(top_5_subgroups)} subgroups with MOST POSITIVE effects for ATE analysis:")

for idx, (_, subgroup_info) in enumerate(top_5_subgroups.iterrows(), 1):
    effect_direction = "‚Üì" if subgroup_info['mean_effect'] < 0 else "‚Üë"
    print(f"  {idx}. {subgroup_info['name']}")
    print(f"     Effect: {effect_direction} {subgroup_info['mean_effect']:.4f} hours (p={subgroup_info['p_value']:.4f})")

# Persist the full ranked list; the timestamp is reused by later cells.
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
significant_subgroups_path = os.path.join(output_base_dir, f"significant_subgroups_{timestamp}.csv")
significant_subgroups.to_csv(significant_subgroups_path, index=False)
print(f"\n‚úì Full significant subgroups list saved to: {significant_subgroups_path}")

significant_subgroups

In [None]:
# ==================== Visualization of Selected Subgroups ====================
# Bar chart of the mean treatment effect for every significant subgroup,
# saved as a PNG in the output directory.
from matplotlib.patches import Patch

POSITIVE_COLOR = '#0078D4'  # increase in after-hours collaboration
NEGATIVE_COLOR = '#107C10'  # decrease (zero is drawn in this colour too)

plot_data = significant_subgroups.copy()
n_subgroups = len(plot_data)

if plot_data.empty:
    # Nothing to draw; a row-wise apply on an empty frame would also fail.
    print("No significant subgroups to visualize; skipping plot.")
else:
    # Two-line tick label: value of the first variable over the second.
    plot_data['short_name'] = (
        plot_data['val1'].astype(str) + "\n" + plot_data['val2'].astype(str)
    )

    # Widen the figure with the number of bars (at least 16 inches).
    fig_width = max(16, n_subgroups * 0.5)
    fig_height = 8
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    x_pos = np.arange(n_subgroups)
    bars = ax.bar(
        x_pos, plot_data['mean_effect'],
        color=[POSITIVE_COLOR if effect > 0 else NEGATIVE_COLOR
               for effect in plot_data['mean_effect']],
        alpha=0.7, edgecolor='black', linewidth=1.5,
    )

    # x in data coordinates, y in axes fraction: pins the p-value markers to
    # the bottom edge whatever the sign or scale of the effects.
    marker_transform = ax.get_xaxis_transform()
    for i, (effect, p_value) in enumerate(
        zip(plot_data['mean_effect'], plot_data['p_value'])
    ):
        # Value label sits just outside the bar tip: above positive bars and
        # below negative ones (a fixed upward offset drew it inside the latter).
        points_up = effect >= 0
        ax.annotate(
            f"{effect:.4f}", xy=(i, effect),
            xytext=(0, 3 if points_up else -3), textcoords='offset points',
            ha='center', va='bottom' if points_up else 'top',
            fontsize=8, fontweight='bold',
        )
        # Mark only p < 0.01 to avoid clutter.
        if p_value < 0.01:
            ax.text(i, 0.01, "p<.01", transform=marker_transform,
                    ha='center', va='bottom', fontsize=6, style='italic', color='gray')

    # Axis labels and title.
    ax.set_xlabel('Subgroup', fontsize=12, fontweight='bold')
    ax.set_ylabel('Mean Treatment Effect (hours)', fontsize=12, fontweight='bold')
    ax.set_title(f'Treatment Effects for All {n_subgroups} Significant Subgroups\n(Impact of Copilot on After-Hours Collaboration)',
                 fontsize=14, fontweight='bold', pad=20)

    ax.set_xticks(x_pos)
    ax.set_xticklabels(plot_data['short_name'], rotation=90, ha='center', fontsize=7)

    # Zero reference line and light horizontal grid.
    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.8, alpha=0.5)
    ax.grid(axis='y', alpha=0.3, linestyle='--')

    legend_elements = [
        Patch(facecolor=POSITIVE_COLOR, alpha=0.7, edgecolor='black', label='Positive Effect (Increase)'),
        Patch(facecolor=NEGATIVE_COLOR, alpha=0.7, edgecolor='black', label='Negative Effect (Decrease)')
    ]
    ax.legend(handles=legend_elements, loc='upper right', fontsize=10)

    # Summary box: counts by sign and the average effect.
    n_positive = (plot_data['mean_effect'] > 0).sum()
    n_negative = (plot_data['mean_effect'] < 0).sum()
    avg_effect = plot_data['mean_effect'].mean()
    textstr = f'Total: {n_subgroups}\nPositive: {n_positive}\nNegative: {n_negative}\nAvg Effect: {avg_effect:.4f}h'
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax.text(0.02, 0.98, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', bbox=props)

    plt.tight_layout()

    # Save before showing: some backends release the figure once shown.
    plot_filename = f"all_significant_subgroups_effects_{timestamp}.png"
    plot_path = os.path.join(output_base_dir, plot_filename)
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\n‚úì Plot saved to: {plot_path}")
    print(f"‚úì Visualized {n_subgroups} significant subgroups")


### Important Note on EconML X vs W Parameters

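**How LinearDML and CausalForestDML use `X` and `W`:**

- **LinearDML**: 
  - `X` = Effect modifiers (variables that create heterogeneous treatment effects); the effect is modelled as a linear function of `X`
  - `W` = Confounders (variables for backdoor adjustment/controlling bias)
  
- **CausalForestDML**: 
  - `X` = Used for both heterogeneity AND confounding control; the effect is modelled non-parametrically in `X` by the forest
  - `W` = Additional controls (optional)

In both estimators `X` and `W` are passed together to the first-stage (nuisance) models, so `X` is always adjusted for as well; the practical difference is that only `X` can drive effect heterogeneity.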

**Our approach:**
- `X` (effect modifiers): Network variables (expect different effects by network size)
- `W` (confounders): Demographics + collaboration controls (need to adjust for these)
- **Featurized model**: `τ(X) = f(network_vars)` controlling for demographics
- **Baseline model**: `œÑ = constant` controlling for demographics

**Model Comparison:**
- **ATE (Featurized)**: Uses `treatment_featurizer` (SplineTransformer) to capture non-linear dose-response relationships. Can model diminishing returns, thresholds, and varying effects across treatment levels.
- **ATE (Baseline)**: Uses `treatment_featurizer=None` for linear effects only. Assumes constant effect per unit treatment increase across all dose levels. Serves as benchmark to test if non-linear modeling adds value.

In [None]:
# For each selected subgroup: run the ATE analysis, then write its results
# table, a text definition, a transition matrix and a dose-response plot into
# a dedicated sub-directory of the output folder.
print("\n=== Running ATE Analysis for Top Subgroups ===")

# Characters that are invalid in directory names on common filesystems.
# Category values flow into the subgroup name, so they are not guaranteed safe.
_UNSAFE_PATH_CHARS = str.maketrans({ch: '_' for ch in '\\/:*?"<>|'})

successful_analyses = []
n_selected = len(top_5_subgroups)  # may be fewer than five

for idx, (_, subgroup_info) in enumerate(top_5_subgroups.iterrows(), 1):
    print(f"\n{'='*50}")
    print(f"SUBGROUP {idx}/{n_selected}: {subgroup_info['name']}")
    print(f"{'='*50}")

    # Rows of the cleaned snapshot that belong to this subgroup.
    mask = ((data_clean[subgroup_info['var1']] == subgroup_info['val1']) &
            (data_clean[subgroup_info['var2']] == subgroup_info['val2']))
    subgroup_data = data_clean[mask].copy()

    # NOTE(review): subgroup_data still carries 'individual_treatment_effect';
    # presumably run_ate_for_subgroup ignores it - verify in the helper.
    ate_analysis = run_ate_for_subgroup(subgroup_data, subgroup_info, treatment_var=TREATMENT_VAR, outcome_var=OUTCOME_VAR)

    if ate_analysis is None:
        # Make the skip visible instead of silently moving on.
        print(f"Skipping subgroup (no ATE result returned): {subgroup_info['name']}")
        continue

    ate_results = ate_analysis['ate_results']
    subgroup_clean = ate_analysis['subgroup_clean']
    n_obs = len(subgroup_clean)
    n_users = subgroup_clean[PERSON_ID_VAR].nunique()

    # Subgroup-specific output directory (name sanitised for the filesystem).
    safe_name = str(subgroup_info['name']).translate(_UNSAFE_PATH_CHARS)
    subgroup_dir = os.path.join(output_base_dir, f"Subgroup_{idx}_{safe_name}")
    os.makedirs(subgroup_dir, exist_ok=True)

    # 1. ATE results table
    ate_results_path = os.path.join(subgroup_dir, f"ate_results_{TREATMENT_VAR}_{timestamp}.csv")
    ate_results.to_csv(ate_results_path, index=False)
    print(f"‚úì ATE results saved: {ate_results_path}")

    # 2. Subgroup definition
    definition_path = os.path.join(subgroup_dir, "subgroup_definition.txt")
    definition = create_subgroup_definition(
        subgroup_info['var1'], subgroup_info['val1'],
        subgroup_info['var2'], subgroup_info['val2']
    )
    # Explicit encoding: category values may contain non-ASCII characters.
    with open(definition_path, 'w', encoding='utf-8') as f:
        f.write(definition)
    print(f"‚úì Subgroup definition saved: {definition_path}")

    # 3. Transition matrix
    transition_matrix = create_transition_matrix(subgroup_clean, TREATMENT_VAR, PERSON_ID_VAR)
    if not transition_matrix.empty:
        transition_path = os.path.join(subgroup_dir, f"transition_matrix_{TREATMENT_VAR}_{timestamp}.csv")
        transition_matrix.to_csv(transition_path, index=False)
        print(f"‚úì Transition matrix saved: {transition_path}")
    else:
        print("‚ö†Ô∏è Transition matrix empty - insufficient data for buckets")

    # 4. Dose-response plot: featurized vs baseline ATE with confidence bands
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))

    treatment_grid = ate_results['Treatment']

    ax.plot(treatment_grid, ate_results['ATE_Featurized'], 'b-', linewidth=2, label='ATE (Featurized)')
    ax.fill_between(treatment_grid, ate_results['CI_Lower_Featurized'],
                    ate_results['CI_Upper_Featurized'], alpha=0.3, color='blue')

    ax.plot(treatment_grid, ate_results['ATE_Baseline'], 'r--', linewidth=2, label='ATE (Baseline)')
    ax.fill_between(treatment_grid, ate_results['CI_Lower_Baseline'],
                    ate_results['CI_Upper_Baseline'], alpha=0.3, color='red')

    # Zero reference line
    ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)

    ax.set_xlabel(f'{TREATMENT_VAR}')
    ax.set_ylabel(f'Average Treatment Effect on {OUTCOME_VAR}')
    ax.set_title(f'ATE Analysis: {subgroup_info["name"]}\n(n={n_obs} obs, {n_users} users)')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Save and close this exact figure rather than whichever one is "current".
    plot_path = os.path.join(subgroup_dir, f"ate_plot_{TREATMENT_VAR}_{timestamp}.png")
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"‚úì ATE plot saved: {plot_path}")

    # Record what was produced for the summary cell.
    successful_analyses.append({
        'subgroup': subgroup_info['name'],
        'directory': subgroup_dir,
        'mean_effect': subgroup_info['mean_effect'],
        'p_value': subgroup_info['p_value'],
        'n_observations': n_obs,
        'n_users': n_users
    })

    print(f"‚úÖ Completed analysis for subgroup: {subgroup_info['name']}")

print(f"\nüéâ ANALYSIS COMPLETE!")
print(f"Successfully analyzed {len(successful_analyses)} subgroups")
print(f"Results saved to: {output_base_dir}")

## 4. Summary and Results

Let's review the results of our subgroup analysis and show the key findings.

In [None]:
# Print a human-readable recap of the subgroup analysis: one entry per
# analysed subgroup, the files written for each, and the run configuration.
print("="*60)
print("SUBGROUP ANALYSIS SUMMARY")
print("="*60)

# The wording follows the configured search direction; it was previously
# hard-coded to "positive" even when the most negative effects were targeted.
effect_direction_label = "negative" if FIND_NEGATIVE_EFFECTS else "positive"

if successful_analyses:
    summary_df = pd.DataFrame(successful_analyses)

    print(f"\nüìä Successfully analyzed {len(successful_analyses)} subgroups with the most {effect_direction_label} treatment effects:")
    print("\nSubgroup Details:")
    for i, analysis in enumerate(successful_analyses, 1):
        print(f"\n{i}. {analysis['subgroup']}")
        print(f"   ‚Ä¢ Mean treatment effect: {analysis['mean_effect']:.4f}")
        print(f"   ‚Ä¢ Statistical significance: p = {analysis['p_value']:.4f}")
        print(f"   ‚Ä¢ Sample size: {analysis['n_observations']} observations ({analysis['n_users']} users)")
        print(f"   ‚Ä¢ Output directory: {analysis['directory']}")

    print(f"\nüìÅ All results saved to: {output_base_dir}")

    print(f"\nüìã Generated files for each subgroup:")
    print(f"   ‚Ä¢ ate_results_{TREATMENT_VAR}_[timestamp].csv - ATE estimates with confidence intervals")
    print(f"   ‚Ä¢ subgroup_definition.txt - Logical definition of the subgroup")
    print(f"   ‚Ä¢ transition_matrix_{TREATMENT_VAR}_[timestamp].csv - Treatment level transitions")
    print(f"   ‚Ä¢ ate_plot_{TREATMENT_VAR}_[timestamp].png - Visualization of treatment effects")

    print(f"\nüéØ KEY FINDINGS:")
    print(f"   ‚Ä¢ Treatment variable: {TREATMENT_VAR}")
    print(f"   ‚Ä¢ Outcome variable: {OUTCOME_VAR}")
    print(f"   ‚Ä¢ Analysis method: CATE (CausalForestDML) ‚Üí ATE (LinearDML)")
    print(f"   ‚Ä¢ Significance threshold: p < 0.05")
    print(f"   ‚Ä¢ Total subgroups tested: {len(subgroup_combinations)}")
    print(f"   ‚Ä¢ Subgroups analyzed (most {effect_direction_label} effects): {len(successful_analyses)}")

else:
    print(f"‚ùå No subgroups with significant {effect_direction_label} treatment effects were found.")
    print("Consider:")
    print("   ‚Ä¢ Relaxing the significance threshold")
    print("   ‚Ä¢ Using different subgroup definitions")
    print("   ‚Ä¢ Checking data quality and sample sizes")

print(f"\n‚úÖ Analysis completed successfully!")

## 5. Sensitivity Analysis for Unobserved Confounding

To assess the robustness of our causal conclusions, we conduct sensitivity analysis to determine how strong unobserved confounding would need to be to explain away our findings. We implement two complementary approaches:

### Methods Overview:
- **Rosenbaum Bounds**: Tests sensitivity to hidden bias by examining how strong an unmeasured confounder would need to be to overturn significant results
- **E-values**: Quantifies the minimum strength of association an unmeasured confounder would need to have with both treatment and outcome to explain away the observed effect

Both methods help us understand the degree of unobserved confounding that would be required to nullify our treatment effect estimates.

### 5.1 Overall CATE Model Sensitivity Analysis

First, we assess the sensitivity of our overall CATE findings from the CausalForestDML model.

In [None]:
# Sensitivity analysis for overall CATE model
#
# Reads the per-person effects stored in data_clean['individual_treatment_effect'],
# summarises them, and passes them to calculate_evalue and
# rosenbaum_bounds_approximation. Defines three notebook-level names either way:
# sensitivity_results_overall, overall_evalue, and overall_rosenbaum.
print("=== CATE Model Sensitivity Analysis ===\n")

# Check if we have the required column
if 'individual_treatment_effect' not in data_clean.columns:
    print("‚ö†Ô∏è WARNING: 'individual_treatment_effect' column not found in data_clean")
    print("Please ensure you have run the CATE analysis cells (Cell 22) before running sensitivity analysis.")
    print("\nThe CATE analysis must be executed first to generate individual treatment effects.")
    print("This cell cannot proceed without those effects.\n")

    # Create placeholder variables to prevent downstream errors
    sensitivity_results_overall = {
        'error': 'individual_treatment_effect column not found - run CATE analysis first'
    }
    overall_evalue = {'evalue_point': None, 'evalue_ci': None}
    overall_rosenbaum = {'critical_gamma': None, 'original_p_value': None}
else:
    # Get overall treatment effects from CATE model
    overall_effects = data_clean['individual_treatment_effect'].values
    mean_overall_effect = np.mean(overall_effects)
    std_overall_effect = np.std(overall_effects)

    print(f"Overall CATE Results:")
    print(f"  ‚Ä¢ Mean treatment effect: {mean_overall_effect:.4f} hours")
    print(f"  ‚Ä¢ Standard deviation: {std_overall_effect:.4f} hours")
    print(f"  ‚Ä¢ Sample size: {len(overall_effects)} individuals")

    # Calculate E-value for overall effect.
    # The interval is a normal approximation (mean +/- 1.96 * SE) where the SE is
    # the spread of the stored effect values divided by sqrt(n).
    # NOTE(review): this uses only the dispersion of the stored values; whether it
    # also reflects the model's own estimation uncertainty should be confirmed.
    se_overall = std_overall_effect / np.sqrt(len(overall_effects))
    ci_lower_overall = mean_overall_effect - 1.96 * se_overall
    ci_upper_overall = mean_overall_effect + 1.96 * se_overall

    overall_evalue = calculate_evalue(
        estimate=mean_overall_effect,
        confidence_interval_lower=ci_lower_overall,
        confidence_interval_upper=ci_upper_overall
    )

    print(f"\nüìä E-value Analysis (Overall CATE):")
    print(f"  ‚Ä¢ Point estimate E-value: {overall_evalue['evalue_point']:.2f}")
    print(f"  ‚Ä¢ Confidence interval E-value: {overall_evalue['evalue_ci']:.2f}")
    print(f"\nInterpretation:")
    print(f"  An unmeasured confounder would need to be associated with both")
    print(f"  Copilot usage and after-hours collaboration by a risk ratio of")
    print(f"  {overall_evalue['evalue_point']:.1f} to fully explain away the observed effect.")

    # Rosenbaum bounds analysis for overall effects
    overall_rosenbaum = rosenbaum_bounds_approximation(overall_effects)

    print(f"\nüìä Rosenbaum Bounds Analysis (Overall CATE):")
    print(f"  ‚Ä¢ Original p-value: {overall_rosenbaum['original_p_value']:.2e}")
    if overall_rosenbaum['critical_gamma']:
        print(f"  ‚Ä¢ Critical Gamma (Œì): {overall_rosenbaum['critical_gamma']:.1f}")
    else:
        # NOTE(review): the later sensitivity cells print ">3.0" for this same
        # condition; confirm the range rosenbaum_bounds_approximation actually tests.
        print(f"  ‚Ä¢ Critical Gamma (Œì): >5.0 (beyond tested range)")
    print(overall_rosenbaum['interpretation'])

    # Save sensitivity results
    sensitivity_results_overall = {
        'analysis_type': 'Overall_CATE',
        'sample_size': len(overall_effects),
        'mean_effect': mean_overall_effect,
        'standard_error': se_overall,
        'evalue_point': overall_evalue['evalue_point'],
        'evalue_ci': overall_evalue['evalue_ci'],
        'rosenbaum_critical_gamma': overall_rosenbaum['critical_gamma'],
        'original_p_value': overall_rosenbaum['original_p_value']
    }

    # Report completion only when the analysis really ran; printing it after the
    # warning branch would contradict the "cannot proceed" message above.
    print(f"\n‚úì Overall CATE sensitivity analysis completed")

### 5.2 Top Subgroups Sensitivity Analysis

Now we examine the sensitivity of our top-performing subgroups to unobserved confounding.

In [None]:
# Sensitivity analysis for top subgroups
#
# For each row of top_5_subgroups, selects the matching people in data_clean,
# computes an E-value and a Rosenbaum critical gamma for that subgroup, and
# collects one summary dict per subgroup in subgroup_sensitivity_results.
# A summary table sorted by robustness score is printed at the end.
print("=== Top Subgroups Sensitivity Analysis ===\n")

# Check if we have the required data
if 'individual_treatment_effect' not in data_clean.columns:
    print("‚ö†Ô∏è WARNING: 'individual_treatment_effect' column not found in data_clean")
    print("Skipping subgroup sensitivity analysis.\n")
    subgroup_sensitivity_results = []
elif 'top_5_subgroups' not in dir():
    print("‚ö†Ô∏è WARNING: 'top_5_subgroups' not defined")
    print("Please run the CATE subgroup identification cells first.\n")
    subgroup_sensitivity_results = []
else:
    subgroup_sensitivity_results = []

    for idx, (_, subgroup_info) in enumerate(top_5_subgroups.iterrows(), 1):
        print(f"--- Subgroup {idx}: {subgroup_info['name']} ---")

        # Get subgroup data and individual treatment effects
        mask = ((data_clean[subgroup_info['var1']] == subgroup_info['val1']) &
                (data_clean[subgroup_info['var2']] == subgroup_info['val2']))
        subgroup_data = data_clean[mask]

        if len(subgroup_data) < 10:  # Skip if too small
            print(f"‚ö†Ô∏è Skipping - insufficient sample size ({len(subgroup_data)})")
            continue

        subgroup_effects = subgroup_data['individual_treatment_effect'].values
        # Mean and spread come from the precomputed subgroup table; only the
        # sample size is taken from the freshly selected rows.
        mean_effect = subgroup_info['mean_effect']
        std_effect = subgroup_info['std_effect']
        se_effect = std_effect / np.sqrt(len(subgroup_effects))

        print(f"  ‚Ä¢ Mean effect: {mean_effect:.4f} hours")
        print(f"  ‚Ä¢ Standard error: {se_effect:.4f} hours") 
        print(f"  ‚Ä¢ P-value: {subgroup_info['p_value']:.4f}")
        print(f"  ‚Ä¢ Sample size: {len(subgroup_effects)} individuals")

        # Calculate E-value from a normal-approximation interval (mean +/- 1.96 * SE)
        ci_lower_sub = mean_effect - 1.96 * se_effect
        ci_upper_sub = mean_effect + 1.96 * se_effect

        subgroup_evalue = calculate_evalue(
            estimate=mean_effect,
            confidence_interval_lower=ci_lower_sub,
            confidence_interval_upper=ci_upper_sub
        )

        print(f"  ‚Ä¢ E-value (point): {subgroup_evalue['evalue_point']:.2f}")
        print(f"  ‚Ä¢ E-value (CI): {subgroup_evalue['evalue_ci']:.2f}")

        # Rosenbaum bounds
        subgroup_rosenbaum = rosenbaum_bounds_approximation(subgroup_effects)
        critical_gamma = subgroup_rosenbaum['critical_gamma']

        if critical_gamma:
            print(f"  ‚Ä¢ Rosenbaum Œì: {critical_gamma:.1f}")
        else:
            print(f"  ‚Ä¢ Rosenbaum Œì: >3.0")

        # Store results. The robustness score is the weaker of the two metrics,
        # with a missing gamma counted as 3.0.
        subgroup_sensitivity_results.append({
            'subgroup_name': subgroup_info['name'],
            'rank': idx,
            'mean_effect': mean_effect,
            'p_value': subgroup_info['p_value'],
            'sample_size': len(subgroup_effects),
            'evalue_point': subgroup_evalue['evalue_point'],
            'evalue_ci': subgroup_evalue['evalue_ci'],
            'rosenbaum_gamma': critical_gamma,
            'robustness_score': min(subgroup_evalue['evalue_ci'], critical_gamma if critical_gamma else 3.0)
        })

        print()

# Create summary dataframe
if subgroup_sensitivity_results:
    sensitivity_df = pd.DataFrame(subgroup_sensitivity_results)
    sensitivity_df = sensitivity_df.sort_values('robustness_score', ascending=False)

    print("üìä SENSITIVITY SUMMARY (Top 5 Subgroups)")
    print("="*60)
    print(f"{'Rank':<4} {'Subgroup':<35} {'E-val':<6} {'Œì':<6} {'Robust':<6}")
    print("-"*60)

    for _, row in sensitivity_df.iterrows():
        gamma_value = row['rosenbaum_gamma']
        # When only some subgroups lack a critical gamma, pandas stores the
        # missing entries as NaN, and NaN is truthy. Test for missingness
        # explicitly so those rows show ">3.0" rather than "nan".
        gamma_str = f"{gamma_value:.1f}" if pd.notna(gamma_value) and gamma_value else ">3.0"
        print(f"{row['rank']:<4} {row['subgroup_name'][:34]:<35} {row['evalue_ci']:<6.1f} {gamma_str:<6} {row['robustness_score']:<6.1f}")

    print(f"\nMost robust subgroup: {sensitivity_df.iloc[0]['subgroup_name']}")
    print(f"Robustness score: {sensitivity_df.iloc[0]['robustness_score']:.1f}")

    print(f"\n‚úì Subgroup sensitivity analysis completed")
else:
    print("‚ö†Ô∏è No subgroup sensitivity results to display")

### 5.3 Robustness Summary and Interpretation

Let's create visualizations and provide comprehensive interpretation of our sensitivity analysis results.

---

## 📖 How to Interpret Sensitivity Analysis Metrics

### **E-value: Robustness to Unmeasured Confounding**

The **E-value** quantifies how strong an unmeasured confounder would need to be to completely explain away the observed effect.

**What it measures:** The minimum strength of association (risk ratio) that an unmeasured confounder would need to have with BOTH the treatment (Copilot usage) AND the outcome (after-hours collaboration) to fully explain away the observed effect.

**Interpretation Guidelines:**
- **E-value < 1.5**: ⚠️ **Potentially fragile** - Relatively weak confounding could explain away the effect
- **E-value 1.5-2.0**: 🔶 **Moderate robustness** - Would require moderate confounding
- **E-value > 2.0**: ✅ **Strong robustness** - Would require substantial confounding  
- **E-value > 3.0**: ✅✅ **Very robust** - Highly unlikely unmeasured confounders are this strong

**Example:** An E-value of 1.30 means an unmeasured confounder would need to increase the likelihood of both Copilot usage AND after-hours collaboration by at least 30% each to completely nullify the observed effect.

**Plausible Confounders to Consider:**
- Job role complexity (high-complexity roles → more tool adoption + more collaboration)
- Team culture and norms
- Individual motivation/proactivity
- Manager support and expectations

---

### **Rosenbaum's Γ (Gamma): Hidden Bias in Treatment Assignment**

**Rosenbaum bounds** test how much hidden bias can exist in treatment assignment before the result becomes statistically non-significant.

**What it measures:** The largest odds ratio by which unobserved factors could shift treatment assignment between otherwise similar individuals while the result remains statistically significant.

**Interpretation Guidelines:**
- **Γ < 1.5**: ⚠️ **Fragile** - Small amounts of hidden bias could eliminate significance
- **Γ = 1.5-2.0**: 🔶 **Moderate robustness** - Tolerates moderate hidden bias
- **Γ > 2.0**: ✅ **Strong robustness** - Results remain significant with substantial hidden bias
- **Γ > 3.0**: ✅✅ **Very robust** - Highly resistant to hidden bias

**Example:** A Γ of 2.0 means that even if two similar individuals differed by a factor of 2 in their odds of receiving treatment (due to unobserved factors), the finding would still be statistically significant.

**Key Distinction:** Rosenbaum Γ tests whether statistical **significance** survives hidden bias, while E-value tests whether the effect **magnitude** could be explained by confounding. Both matter!

---

### **How to Report Findings:**

**Strong Results (E-value > 2.0, Γ > 2.0):**
> "The findings are robust to unmeasured confounding. An unmeasured confounder would need to be associated with both treatment and outcome by a risk ratio of [E-value] to fully explain the effect, and the statistical significance is maintained even under substantial hidden bias (Γ = [value])."

**Moderate Results (E-value 1.5-2.0, Γ > 2.0):**
> "While the statistical significance is robust to hidden bias (Γ > [value]), the effect magnitude could be explained by a moderately strong unmeasured confounder (E-value = [value]). Plausible confounders such as [examples] should be considered."

**Fragile Results (E-value < 1.5):**
> "The findings are potentially sensitive to unmeasured confounding. A confounder with relatively modest associations (E-value = [value]) with both treatment and outcome could explain the observed effect. Caution is warranted in causal interpretation."

---

In [None]:
# Create visualization and comprehensive summary
#
# Draws a two-panel figure from subgroup_sensitivity_results (E-value vs effect,
# and Rosenbaum gamma per subgroup), prints a robustness interpretation, and
# writes the combined results to a timestamped JSON file under output_base_dir.
print("=== COMPREHENSIVE SENSITIVITY ANALYSIS SUMMARY ===\n")

# Check if we have sensitivity results to visualize
if not subgroup_sensitivity_results:
    print("‚ö†Ô∏è No subgroup sensitivity results available for visualization")
    print("Please ensure the previous sensitivity analysis cells completed successfully.\n")
else:
    # Create visualization of sensitivity results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Plot 1: E-values vs Treatment Effects
    effects = [r['mean_effect'] for r in subgroup_sensitivity_results]
    evalues = [r['evalue_ci'] for r in subgroup_sensitivity_results]
    # Shortened subgroup labels; not drawn on the figure, kept as a notebook variable
    names = [r['subgroup_name'][:20] + '...' if len(r['subgroup_name']) > 20 else r['subgroup_name'] 
             for r in subgroup_sensitivity_results]
    ranks = [r['rank'] for r in subgroup_sensitivity_results]

    ax1.scatter(effects, evalues, s=100, alpha=0.7, c='steelblue')
    # Label each point with the subgroup's rank, the same identifier used on the
    # x-axis of the bar chart. List position would drift from the rank whenever
    # a subgroup was skipped for insufficient sample size.
    for rank, effect, evalue in zip(ranks, effects, evalues):
        ax1.annotate(f'{rank}', (effect, evalue), 
                    xytext=(5, 5), textcoords='offset points', fontsize=10)

    ax1.set_xlabel('Mean Treatment Effect (Hours)')
    ax1.set_ylabel('E-value (Confidence Interval)')
    ax1.set_title('Treatment Effect vs E-value\n(Higher E-value = More Robust)')
    ax1.grid(True, alpha=0.3)
    ax1.axhline(y=2.0, color='red', linestyle='--', alpha=0.5, label='E-value = 2.0')
    ax1.legend()

    # Plot 2: Rosenbaum Gamma values (a missing gamma is drawn at 3.0)
    gammas = [r['rosenbaum_gamma'] if r['rosenbaum_gamma'] else 3.0 
              for r in subgroup_sensitivity_results]

    colors = ['green' if g >= 2.0 else 'orange' if g >= 1.5 else 'red' for g in gammas]

    bars = ax2.bar(ranks, gammas, color=colors, alpha=0.7)
    ax2.set_xlabel('Subgroup Rank')
    ax2.set_ylabel('Critical Gamma (Œì)')
    ax2.set_title('Rosenbaum Bounds by Subgroup\n(Higher Œì = More Robust)')
    ax2.set_xticks(ranks)
    ax2.axhline(y=2.0, color='red', linestyle='--', alpha=0.5, label='Œì = 2.0 threshold')
    ax2.axhline(y=1.5, color='orange', linestyle='--', alpha=0.5, label='Œì = 1.5 threshold')
    ax2.grid(True, alpha=0.3)
    ax2.legend()

    plt.tight_layout()
    sensitivity_plot_path = os.path.join(output_base_dir, f"sensitivity_analysis_{timestamp}.png")
    plt.savefig(sensitivity_plot_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"‚úì Sensitivity plot saved: {sensitivity_plot_path}")

# Comprehensive interpretation
print(f"\nüéØ ROBUSTNESS INTERPRETATION:")
print(f"="*50)

# Check if overall results exist
if 'overall_evalue' in dir() and overall_evalue.get('evalue_ci') is not None:
    print(f"\nüìà Overall CATE Model:")
    evalue_ci_str = f"{overall_evalue['evalue_ci']:.1f}" if overall_evalue['evalue_ci'] else "N/A"
    print(f"  ‚Ä¢ E-value: {evalue_ci_str}")

    if 'overall_rosenbaum' in dir() and overall_rosenbaum.get('critical_gamma'):
        print(f"  ‚Ä¢ Rosenbaum Œì: {overall_rosenbaum['critical_gamma']:.1f}")
    else:
        # NOTE(review): the overall CATE sensitivity cell prints ">5.0 (beyond
        # tested range)" for this same condition; confirm which bound is correct.
        print(f"  ‚Ä¢ Rosenbaum Œì: >3.0")
else:
    print(f"\nüìà Overall CATE Model:")
    print(f"  ‚Ä¢ Results not available (run Cell 28 first)")

# Process subgroup robustness if we have results
if subgroup_sensitivity_results:
    robustness_levels = {
        'Very Robust': [],
        'Moderately Robust': [],
        'Potentially Fragile': []
    }

    # Bucket each subgroup by its robustness score: >= 2.0, >= 1.5, otherwise fragile
    for result in subgroup_sensitivity_results:
        score = result['robustness_score']
        name = result['subgroup_name'][:30]

        if score >= 2.0:
            robustness_levels['Very Robust'].append((name, score))
        elif score >= 1.5:
            robustness_levels['Moderately Robust'].append((name, score))
        else:
            robustness_levels['Potentially Fragile'].append((name, score))

    print(f"\nüìä Subgroup Robustness Categories:")
    for category, subgroups in robustness_levels.items():
        print(f"\n{category} (n={len(subgroups)}):")
        for name, score in subgroups:
            print(f"  ‚Ä¢ {name}: {score:.1f}")

    print(f"\nüí° KEY INSIGHTS:")
    print(f"  ‚Ä¢ Results requiring Œì > 2.0 or E-value > 2.0 are considered robust")
    print(f"  ‚Ä¢ An unmeasured confounder would need substantial associations")
    print(f"    with both treatment and outcome to explain away these effects")
    print(f"  ‚Ä¢ Higher-ranked subgroups generally show stronger robustness")

    # Save comprehensive sensitivity results
    sensitivity_summary = {
        'overall_analysis': sensitivity_results_overall if 'sensitivity_results_overall' in dir() else {},
        'subgroup_analysis': subgroup_sensitivity_results,
        'robustness_categories': robustness_levels,
        'timestamp': timestamp,
        'interpretation': {
            'very_robust_count': len(robustness_levels['Very Robust']),
            'moderate_robust_count': len(robustness_levels['Moderately Robust']),
            'fragile_count': len(robustness_levels['Potentially Fragile']),
            'most_robust_subgroup': sensitivity_df.iloc[0]['subgroup_name'] if 'sensitivity_df' in dir() and len(sensitivity_df) > 0 else None
        }
    }

    def convert_numpy(obj):
        """Convert numpy scalars and arrays to native Python types for JSON.

        Raises TypeError for anything else, as json's ``default`` hook expects;
        returning the object unchanged would make json report a misleading
        "Circular reference detected" error.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    sensitivity_results_path = os.path.join(output_base_dir, f"sensitivity_analysis_results_{timestamp}.json")
    # Serialize before opening the file so a serialization error cannot leave a
    # truncated results file on disk.
    serialized_summary = json.dumps(sensitivity_summary, indent=2, default=convert_numpy)
    with open(sensitivity_results_path, 'w', encoding='utf-8') as f:
        f.write(serialized_summary)

    print(f"\n‚úÖ SENSITIVITY ANALYSIS COMPLETE!")
    print(f"üìÅ Results saved to: {sensitivity_results_path}")
    if 'sensitivity_plot_path' in dir():
        print(f"üìä Visualization saved to: {sensitivity_plot_path}")

    # Overall verdict: at least two "Very Robust" subgroups wins, then at least
    # two "Moderately Robust", otherwise the findings are flagged as sensitive.
    if len(robustness_levels['Very Robust']) >= 2:
        overall_verdict = 'robust'
    elif len(robustness_levels['Moderately Robust']) >= 2:
        overall_verdict = 'moderately robust'
    else:
        overall_verdict = 'potentially sensitive'

    print(f"\nüèÅ CONCLUSION:")
    print(f"The sensitivity analysis suggests that our findings are")
    print(overall_verdict)
    print(f"to unobserved confounding, with {len(robustness_levels['Very Robust'])} subgroups")
    print(f"showing very high robustness to hidden bias.")
else:
    print(f"\n‚ö†Ô∏è No subgroup sensitivity results available for interpretation")
    print(f"Please run the previous sensitivity analysis cells first.")