# Refutation flow for multi-treatment IRM


This notebook mirrors `refutation_flow.ipynb`, but for the
`causalis.scenarios.multi_unconfoundedness` scenario.

We estimate pairwise average treatment effect (ATE) contrasts against the baseline treatment `d_0`, then run:
- overlap diagnostics,
- score diagnostics,
- a SUTVA checklist,
- unconfoundedness balance checks,
- sensitivity analysis.


In [1]:
from causalis.scenarios.multi_unconfoundedness.dgp import generate_multitreatment_gamma_26

# Simulate 20,000 rows from the built-in multi-treatment DGP; the fixed seed
# keeps the notebook reproducible across runs.
df = generate_multitreatment_gamma_26(
    n=20_000,
    seed=42,
    # NOTE(review): presumably what adds the ground-truth columns visible in the
    # preview (g_d_*, cate_d_*, ...) — confirm against the DGP documentation.
    include_oracle=True,
    # Keep a plain DataFrame here; the MultiCausalData wrapper is built
    # explicitly in the next cell.
    return_causal_data=False,
)
df.head()


Unnamed: 0,y,d_0,d_1,d_2,tenure_months,avg_sessions_week,spend_last_month,premium_user,urban_resident,support_tickets_q,...,m_obs_d_1,tau_link_d_1,m_d_2,m_obs_d_2,tau_link_d_2,g_d_0,g_d_1,g_d_2,cate_d_1,cate_d_2
0,1.724081,1.0,0.0,0.0,27.656605,3.198667,89.609464,0.0,1.0,0.0,...,0.245737,-0.352005,0.220739,0.220739,0.494166,3.279384,2.306314,5.375338,-0.97307,2.095954
1,0.658436,0.0,1.0,0.0,23.798386,3.362415,102.337236,0.0,0.0,3.0,...,0.17864,-0.30736,0.23683,0.23683,0.420278,2.80785,2.064853,4.27463,-0.742997,1.46678
2,3.894951,0.0,1.0,0.0,28.425009,3.391819,102.660712,0.0,1.0,1.0,...,0.209711,-0.320189,0.218158,0.218158,0.502415,3.069919,2.228798,5.073677,-0.841121,2.003758
3,2.363204,0.0,1.0,0.0,18.860066,4.071175,83.593417,0.0,0.0,2.0,...,0.175985,-0.316241,0.237508,0.237508,0.441677,2.716805,1.980234,4.225485,-0.736571,1.50868
4,6.232463,0.0,0.0,1.0,17.853087,3.140075,79.20987,0.0,1.0,1.0,...,0.23159,-0.35013,0.246973,0.246973,0.493624,3.224354,2.271869,5.282273,-0.952485,2.057919


In [2]:
from causalis.data_contracts import MultiCausalData

# Covariates handed to the data contract as confounders.
confounder_columns = [
    'tenure_months',
    'avg_sessions_week',
    'spend_last_month',
    'premium_user',
    'urban_resident',
    'support_tickets_q',
    'discount_eligible',
    'credit_utilization',
]

# Three one-hot treatment columns; d_0 is declared as the control arm.
multi_causaldata = MultiCausalData(
    df=df,
    outcome='y',
    treatment_names=['d_0', 'd_1', 'd_2'],
    control_treatment='d_0',
    confounders=confounder_columns,
)
multi_causaldata


MultiCausalData(df=(20000, 12), treatment_names=['d_0', 'd_1', 'd_2'], control_treatment='d_0')outcome='y', confounders=['tenure_months', 'avg_sessions_week', 'spend_last_month', 'premium_user', 'urban_resident', 'support_tickets_q', 'discount_eligible', 'credit_utilization'], user_id=None, 

# Inference


In [3]:
from causalis.scenarios.multi_unconfoundedness import MultiTreatmentIRM

# Prefer CatBoost when available; fall back to scikit-learn if it is not installed.
# Only the import sits inside `try`, so an unrelated error raised while building
# the learners is never mistaken for a missing dependency.
try:
    from catboost import CatBoostClassifier, CatBoostRegressor
except ImportError:
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.linear_model import LogisticRegression

    # `multi_class` is deprecated since scikit-learn 1.5 and slated for removal.
    # The default lbfgs solver already fits a multinomial model for 3+ classes
    # (scikit-learn >= 0.22), so dropping the argument keeps the same learner
    # while staying compatible with newer releases.
    ml_m = LogisticRegression(max_iter=2000)
    ml_g = RandomForestRegressor(n_estimators=300, random_state=42, n_jobs=-1)
else:
    # Propensity model: multi-class classifier over the treatment arms.
    ml_m = CatBoostClassifier(
        loss_function='MultiClass',
        verbose=False,
        allow_writing_files=False,
        random_seed=42,
    )
    # Outcome model: regressor for y.
    ml_g = CatBoostRegressor(
        verbose=False,
        allow_writing_files=False,
        random_seed=42,
    )

# Fit with 5 folds, un-normalized IPW weights and a 0.01 trimming threshold.
model = MultiTreatmentIRM(
    ml_g=ml_g,
    ml_m=ml_m,
    n_folds=5,
    normalize_ipw=False,
    trimming_threshold=0.01,
    random_state=42,
).fit(multi_causaldata)


In [4]:
# Pairwise ATE estimates for each active treatment against the baseline.
# NOTE(review): diagnostic_data=True presumably attaches the per-observation
# data that the refutation helpers used later in this notebook rely on — confirm.
dml_result = model.estimate(score='ATE', diagnostic_data=True)
dml_result.summary()


Unnamed: 0_level_0,d_1 vs d_0,d_2 vs d_0
field,Unnamed: 1_level_1,Unnamed: 2_level_1
estimand,ATE,ATE
model,MultiTreatmentIRM,MultiTreatmentIRM
value,"-1.2254 (ci_abs: -1.3296, -1.1212)","2.5704 (ci_abs: 2.3562, 2.7846)"
value_relative,"-31.0949 (ci_rel: -33.3772, -28.8125)","65.2240 (ci_rel: 59.3040, 71.1441)"
alpha,0.0500,0.0500
p_value,0.0000,0.0000
is_significant,True,True
n_treated,5003,4877
n_control,10120,10120
treatment_mean,2.9112,6.5755


# Overlap


For multi-treatment IRM, overlap is checked pairwise: baseline `d_0` vs each active treatment.
Key metrics are reported by comparison (`d_0 vs d_1`, `d_0 vs d_2`, ...).


In [5]:
from causalis.scenarios.multi_unconfoundedness.refutation import run_overlap_diagnostics

# Build the pairwise overlap report; 'summary' is the long-format table with
# one row per (comparison, metric) and its traffic-light flag.
rep = run_overlap_diagnostics(multi_causaldata, dml_result)
rep['summary']


Unnamed: 0,comparison,metric,value,flag
0,d_0 vs d_1,edge_0.01_below,0.000066,GREEN
1,d_0 vs d_1,edge_0.01_above,0.0,GREEN
2,d_0 vs d_1,KS,0.100367,GREEN
3,d_0 vs d_1,AUC,0.566421,GREEN
4,d_0 vs d_1,ESS_treated_ratio,0.721938,GREEN
5,d_0 vs d_1,ESS_baseline_ratio,0.896655,GREEN
6,d_0 vs d_1,clip_m_total,0.000066,GREEN
7,d_0 vs d_1,overlap_pass,True,GREEN
8,d_0 vs d_2,edge_0.01_below,0.0,GREEN
9,d_0 vs d_2,edge_0.01_above,0.0,GREEN


## `edge_0.01_below`, `edge_0.01_above`
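Share of observations whose pairwise propensity score lies within 0.01 of the boundary: below 0.01 (`edge_0.01_below`) or above 0.99 (`edge_0.01_above`).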


In [6]:
# Edge-mass columns only, one row per comparison.
rep['overlap']['by_comparison'].loc[
    :,
    ['comparison', 'edge_0.01_below', 'edge_0.01_above', 'flag_edge_001'],
]


Unnamed: 0,comparison,edge_0.01_below,edge_0.01_above,flag_edge_001
0,d_0 vs d_1,6.6e-05,0.0,GREEN
1,d_0 vs d_2,0.0,0.0,GREEN


## `ks`
Kolmogorov–Smirnov distance between the pairwise propensity-score distributions of the baseline group and the active-treatment group.


In [7]:
# KS statistic and its flag, one row per comparison.
rep['overlap']['by_comparison'].loc[:, ['comparison', 'ks', 'flag_ks']]


Unnamed: 0,comparison,ks,flag_ks
0,d_0 vs d_1,0.100367,GREEN
1,d_0 vs d_2,0.056139,GREEN


## `auc`
AUC for separating baseline vs active treatment using pairwise propensity score.


In [8]:
# AUC and its flag, one row per comparison.
rep['overlap']['by_comparison'].loc[:, ['comparison', 'auc', 'flag_auc']]


Unnamed: 0,comparison,auc,flag_auc
0,d_0 vs d_1,0.566421,GREEN
1,d_0 vs d_2,0.533669,GREEN


## `ess_ratio_treated`, `ess_ratio_baseline`
Effective sample size ratios implied by inverse-propensity weights.


In [9]:
# Effective-sample-size ratios for both arms together with their flags.
rep['overlap']['by_comparison'].loc[
    :,
    [
        'comparison',
        'ess_ratio_treated',
        'ess_ratio_baseline',
        'flag_ess_treated',
        'flag_ess_baseline',
    ],
]


Unnamed: 0,comparison,ess_ratio_treated,ess_ratio_baseline,flag_ess_treated,flag_ess_baseline
0,d_0 vs d_1,0.721938,0.896655,GREEN,GREEN
1,d_0 vs d_2,0.814103,0.896655,GREEN,GREEN


## `clip_m_total`
Share of observations affected by propensity trimming in each comparison.


In [10]:
# Trimming share and its flag, one row per comparison.
rep['overlap']['by_comparison'].loc[:, ['comparison', 'clip_m_total', 'flag_clip_m']]


Unnamed: 0,comparison,clip_m_total,flag_clip_m
0,d_0 vs d_1,6.6e-05,GREEN
1,d_0 vs d_2,0.0,GREEN


## Overall overlap verdict


In [11]:
# Aggregate overlap verdict: the report-level flag plus the pass indicator
# stored under rep['overlap'].
{
    'overall_flag': rep['overall_flag'],
    'all_comparisons_pass': rep['overlap']['pass'],
}


{'overall_flag': 'GREEN', 'all_comparisons_pass': True}

# Score


Score diagnostics examine the orthogonal score of each baseline contrast: the tail behavior of its influence values, orthogonality derivative checks, and out-of-sample moment tests.


In [12]:
from causalis.scenarios.multi_unconfoundedness.refutation import run_score_diagnostics

# Build the score-diagnostics report; 'summary' lists every metric with its
# value and flag per contrast.
rep_score = run_score_diagnostics(multi_causaldata, dml_result)
rep_score['summary']


Unnamed: 0,comparison,metric,value,flag
0,d_1 vs d_0,se_plugin,0.05315128,
1,d_1 vs d_0,psi_p99_over_med,10.182,YELLOW
2,d_1 vs d_0,psi_kurtosis,93.74657,RED
3,d_1 vs d_0,max_|t|_gk,8.303548,RED
4,d_1 vs d_0,max_|t|_g0,6.04451,RED
5,d_1 vs d_0,max_|t|_mk,1.08217,RED
6,d_1 vs d_0,max_|t|_m0,1.220473,RED
7,d_1 vs d_0,max_|t|,8.303548,RED
8,d_1 vs d_0,oos_tstat_fold,-1.176496e-15,GREEN
9,d_1 vs d_0,oos_tstat_strict,-1.497106e-15,GREEN


## `psi_p99_over_med`, `psi_kurtosis`
Tail diagnostics of influence values by comparison.


In [13]:
# Plug-in standard error and influence-value tail metrics, one row per contrast.
rep_score['influence_diagnostics']['by_comparison']


Unnamed: 0,comparison,se_plugin,kurtosis,p99_over_med
0,d_1 vs d_0,0.053151,93.74657,10.182
1,d_2 vs d_0,0.109287,287.654766,17.904955


## Top influential observations


In [14]:
# Show at most the first 20 rows of the top-influential-observations table.
rep_score['influence_diagnostics']['top_influential'].head(20)


Unnamed: 0,comparison,i,psi,m_k,residual_k,residual_0
0,d_1 vs d_0,2387,-206.58467,0.010101,-2.105333,-1.495193
1,d_1 vs d_0,13045,188.04722,0.014999,2.817467,1.790805
2,d_1 vs d_0,7609,166.136647,0.059897,9.972243,8.393727
3,d_1 vs d_0,18521,-150.717871,0.451285,21.968336,19.422518
4,d_1 vs d_0,5145,140.704496,0.030226,4.271055,2.447766
5,d_1 vs d_0,15110,112.291193,0.128562,14.566815,12.326445
6,d_1 vs d_0,13471,103.747853,0.186516,20.086296,14.916909
7,d_1 vs d_0,6257,-88.685294,0.412085,23.849319,21.475964
8,d_1 vs d_0,7117,-81.913914,0.52425,14.666899,15.07373
9,d_1 vs d_0,10254,-78.445073,0.328219,17.97709,17.290261


## `max_|t|_gk`, `max_|t|_g0`, `max_|t|_mk`, `max_|t|_m0`
Orthogonality derivative checks by comparison.


In [15]:
# Maximum |t| per derivative check (gk, g0, mk, m0) and overall, by contrast.
rep_score['orthogonality_max_t']


Unnamed: 0,comparison,max_|t|_gk,max_|t|_g0,max_|t|_mk,max_|t|_m0,max_|t|
0,d_1 vs d_0,8.303548,6.04451,1.08217,1.220473,8.303548
1,d_2 vs d_0,8.270394,6.04451,1.596762,1.220473,8.270394


## `oos_tstat_fold`, `oos_tstat_strict`
Out-of-sample moment tests.


In [16]:
# Fold-level and strict out-of-sample t-statistics with their p-values, by contrast.
rep_score['oos_moment_test']['by_comparison']


Unnamed: 0,comparison,oos_tstat_fold,oos_tstat_strict,p_value_fold,p_value_strict
0,d_1 vs d_0,-1.176496e-15,-1.497106e-15,1.0,1.0
1,d_2 vs d_0,-5.201603e-16,-1.248197e-15,1.0,1.0


In [17]:
# Aggregate score verdict: report-level flag, per-metric flags, and the
# per-contrast flag table.
{
    'overall_flag': rep_score['overall_flag'],
    'flags': rep_score['flags'],
    'flags_by_comparison': rep_score['flags_by_comparison'],
}


{'overall_flag': 'RED',
 'flags': {'psi_tail_ratio': 'YELLOW',
  'psi_kurtosis': 'RED',
  'ortho_max_|t|': 'RED',
  'oos_moment': 'GREEN',
  'ortho_max_|t|_gk': 'RED',
  'ortho_max_|t|_g0': 'RED',
  'ortho_max_|t|_mk': 'GREEN',
  'ortho_max_|t|_m0': 'GREEN'},
 'flags_by_comparison':    comparison psi_tail_ratio psi_kurtosis ortho_max_|t| oos_moment  \
 0  d_1 vs d_0         YELLOW          RED           RED      GREEN   
 1  d_2 vs d_0         YELLOW          RED           RED      GREEN   
 
   overall_flag  
 0          RED  
 1          RED  }

# SUTVA


In [18]:
from causalis.shared import print_sutva_questions

# Prints a checklist of SUTVA questions to answer by hand; the call takes no
# data, so nothing is computed from this notebook's dataset.
print_sutva_questions()


1.) Are your clients independent (i). Outcome of ones do not depend on others?
2.) Are all clients have full window to measure metrics?
3.) Do you measure confounders before treatment and outcome after?
4.) Do you have a consistent label of treatment, such as if a person does not receive a treatment, he has a label 0?


# Unconfoundedness


In [19]:
from causalis.scenarios.multi_unconfoundedness.refutation import run_unconfoundedness_diagnostics

# Build the balance report; 'summary' has per-comparison rows plus an
# 'overall' block for each balance metric.
rep_uc = run_unconfoundedness_diagnostics(multi_causaldata, dml_result)
rep_uc['summary']


Unnamed: 0,comparison,metric,value,flag
0,d_0 vs d_1,balance_max_smd,0.05714,GREEN
1,d_0 vs d_1,balance_frac_violations,0.0,GREEN
2,d_0 vs d_1,balance_pass,True,GREEN
3,d_0 vs d_2,balance_max_smd,0.022568,GREEN
4,d_0 vs d_2,balance_frac_violations,0.0,GREEN
5,d_0 vs d_2,balance_pass,True,GREEN
6,overall,balance_max_smd,0.05714,GREEN
7,overall,balance_frac_violations,0.0,GREEN
8,overall,balance_pass,True,GREEN


## `balance_max_smd`, `balance_frac_violations`
Weighted standardized mean difference (SMD) checks for each baseline-vs-treatment comparison.


In [20]:
# Balance metrics, pass indicator and flags, one row per comparison.
rep_uc['balance']['by_comparison'].loc[
    :,
    [
        'comparison',
        'smd_max',
        'frac_violations',
        'pass',
        'flag_max_smd',
        'flag_violations',
        'overall_flag',
    ],
]


Unnamed: 0,comparison,smd_max,frac_violations,pass,flag_max_smd,flag_violations,overall_flag
0,d_0 vs d_1,0.05714,0.0,True,GREEN,GREEN,GREEN
1,d_0 vs d_2,0.022568,0.0,True,GREEN,GREEN,GREEN


## Worst covariates by weighted SMD


In [21]:
# Series of weighted SMD values indexed by covariate name.
rep_uc['balance']['worst_features']


avg_sessions_week     0.057140
tenure_months         0.046042
urban_resident        0.028845
premium_user          0.024638
support_tickets_q     0.016798
spend_last_month      0.013420
discount_eligible     0.007124
credit_utilization    0.004336
dtype: float64

In [22]:
# Aggregate balance verdict: report-level flag, per-metric flags, and the
# pass indicator stored under rep_uc['balance'].
{
    'overall_flag': rep_uc['overall_flag'],
    'flags': rep_uc['flags'],
    'overall_balance_pass': rep_uc['balance']['pass'],
}


{'overall_flag': 'GREEN',
 'flags': {'balance_max_smd': 'GREEN', 'balance_violations': 'GREEN'},
 'overall_balance_pass': True}

## Sensitivity analysis


In [23]:
from causalis.scenarios.multi_unconfoundedness.refutation.unconfoundedness.sensitivity import (
    sensitivity_analysis,
    sensitivity_benchmark,
    get_sensitivity_summary,
)

# Sensitivity scenario at the 5% level; r2_d and rho take one entry per
# contrast (here: d_1 vs d_0 and d_2 vs d_0).
sens = sensitivity_analysis(
    dml_result,
    r2_y=0.01,
    r2_d=[0.01, 0.01],
    rho=[1.0, 1.0],
    alpha=0.05,
)

# NOTE(review): the summary is built from dml_result rather than from `sens`,
# so sensitivity_analysis presumably stores its result on dml_result — confirm.
print(get_sensitivity_summary(dml_result))
sens



------------------ Scenario          ------------------
Significance Level: alpha=0.05
Null Hypothesis: H0=0.0
Sensitivity parameters: cf_y=0.010101010101010102; r2_d=[0.01 0.01], rho=[1. 1.], use_signed_rr=False

               theta        se  max_bias  max_bias_base  bound_width     sigma2       nu2  sampling_ci_l  sampling_ci_u   theta_l   theta_u  bias_aware_ci_l  bias_aware_ci_u        rv       rva
d_1 vs d_0 -1.225401  0.053151  0.071358       7.064466     0.071358  11.726175  4.256007      -1.329575      -1.121226 -1.296759 -1.154043        -1.400830        -1.048932  0.748664  0.713779
d_2 vs d_0  2.570377  0.109287  0.075791       7.503273     0.075791  11.726175  4.801148       2.356178       2.784576  2.494587  2.646168         2.281413         2.861468  0.920747  0.907083


{'theta': array([-1.22540079,  2.57037715]),
 'se': array([0.05315128, 0.10928717]),
 'alpha': 0.05,
 'z': 1.959963984540054,
 'H0': 0.0,
 'sampling_ci': array([[-1.32957538, -1.1212262 ],
        [ 2.35617823,  2.78457606]]),
 'theta_bounds_cofounding': array([[-1.29675903, -1.15404255],
        [ 2.49458652,  2.64616778]]),
 'bias_aware_ci': array([[-1.40082993, -1.0489322 ],
        [ 2.28141319,  2.86146838]]),
 'max_bias_base': array([7.06446594, 7.50327258]),
 'max_bias': array([0.07135824, 0.07579063]),
 'bound_width': array([0.07135824, 0.07579063]),
 'sigma2': 11.726175012924882,
 'nu2': array([4.25600667, 4.8011478 ]),
 'rv': array([0.74866425, 0.92074748]),
 'rva': array([0.71377934, 0.90708265]),
 'contrast_labels': ['d_1 vs d_0', 'd_2 vs d_0'],
 'params': {'cf_y': 0.010101010101010102,
  'r2_d': array([0.01, 0.01]),
  'rho': array([1., 1.]),
  'use_signed_rr': False}}

In [24]:
# Benchmark the sensitivity parameters against an observed covariate.
# NOTE(review): theta_short is presumably the estimate with 'tenure_months'
# left out of the confounders and delta its gap to theta_long — confirm.
sensitivity_benchmark(dml_result, benchmarking_set=['tenure_months'])


Unnamed: 0,cf_y,r2_y,r2_d,rho,theta_long,theta_short,delta
d_1 vs d_0,8.091184e-08,8.091183e-08,4.65316e-06,1.0,-1.225401,-1.243423,0.018023
d_2 vs d_0,8.091184e-08,8.091183e-08,3.465145e-07,1.0,2.570377,2.500075,0.070302
