In [1]:
import os
from time import time
import math
import random
import gc
import warnings

import numpy as np
from scipy import stats
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from prep import load_data

# NOTE(review): this silences every warning (including sklearn ConvergenceWarning and
# pandas alignment warnings) for the whole notebook — consider narrowing the filter.
warnings.filterwarnings('ignore')
np.random.seed(123)
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8 removed the
# old names, so plt.style.use('seaborn-dark') raises there. Prefer the new name when present.
_dark_style = 'seaborn-v0_8-dark' if 'seaborn-v0_8-dark' in plt.style.available else 'seaborn-dark'
plt.style.use(_dark_style)

In [None]:
# Project id and the two pickle inputs passed to prep.load_data (base table + rank features).
PROJECT = 'npf-brave-220ac3f'
input_dir = '../input'
pkl_file, rank_pkl_file = (os.path.join(input_dir, name) for name in ('pvp.pkl', 'pvp_rank.pkl'))

print("Loading dataset...")
t0 = time()
df = load_data(PROJECT, pkl_file)
rank_df = load_data(PROJECT, rank_pkl_file, is_rank=True)
print(f'End in {time() - t0 :.2f}s.')

print("Merging dataframes...")
t0 = time()
# Left join keeps every user of the base table; users missing from rank_df get NaN -> 0.
# NOTE(review): fillna(0) also zero-fills NaNs that already existed in the base table — confirm.
df = df.merge(rank_df, how='left', on='appUserId').fillna(0)
del rank_df
gc.collect()
print(f'End in {time() - t0 :.2f}s.')

print(f'shape: {df.shape}')
# NOTE(review): head() shows appUserId values — clear this output before sharing the notebook.
display(df.head(2))

Loading dataset...
End in 1.40s.
Merging dataframes...


# Check dataset

In [None]:
print(df.keys())  # DataFrame.keys() is the column Index, i.e. the same object as df.columns

In [None]:
treat_col = "is_play"  # binary treatment flag: rows are compared against 1 (treated) and 0 (control)

In [None]:
# Number of rows in each treatment arm.
is_treated = df[treat_col] == 1
is_control = df[treat_col] == 0
cnt_tr = int(is_treated.sum())
cnt_ct = int(is_control.sum())

print(f'# of samples (treatment): {cnt_tr}')
print(f'# of samples (control):   {cnt_ct}')

In [None]:
# Histogram of every feature, treatment vs. control, before any adjustment.
vis_cols = [c for c in df.columns if c not in ['appUserId']]
n_grid_cols = 6
# Derive the grid from the feature count: a fixed 10x6 grid raises once there are > 60 features.
n_grid_rows = max(1, math.ceil(len(vis_cols) / n_grid_cols))
plt.figure(figsize=(24, 2.4 * n_grid_rows))
tr_all = df[df[treat_col] == 1]
ct_all = df[df[treat_col] == 0]
# DataFrame.sample raises when n exceeds the number of rows, so cap at the group size.
tmp1 = tr_all.sample(min(1000, len(tr_all)), random_state=123)
tmp0 = ct_all.sample(min(2000, len(ct_all)), random_state=123)
for idx, col in enumerate(vis_cols):
    plt.subplot(n_grid_rows, n_grid_cols, idx + 1)
    # density=True: the two samples have different sizes, so raw counts are not comparable.
    plt.hist(tmp1[col], bins=10, alpha=0.5, density=True, label='treat')
    plt.hist(tmp0[col], bins=10, alpha=0.5, density=True, label='control')
    plt.legend(loc='best')
    plt.title(col)
plt.tight_layout()
plt.show()

In [None]:
# prev_grade distribution by treatment arm, restricted to rows with prev_grade > 0.
prev_tr = df[(df[treat_col] == 1) & (df['prev_grade'] > 0)]
prev_ct = df[(df[treat_col] == 0) & (df['prev_grade'] > 0)]
plt.figure(figsize=(6, 6))
# Cap the sample size at the group size so sample() cannot raise.
tmp1 = prev_tr.sample(min(3000, len(prev_tr)), random_state=123)
tmp0 = prev_ct.sample(min(600, len(prev_ct)), random_state=123)
print(len(tmp1), len(tmp0))
# density=True: the two samples differ in size, so compare shapes rather than counts.
plt.hist(tmp1['prev_grade'], bins=10, alpha=0.5, density=True, label='treat')
plt.hist(tmp0['prev_grade'], bins=10, alpha=0.5, density=True, label='control')
plt.legend(loc='best')
# Previously plt.title(col), reading a loop variable leaked from another cell.
plt.title('prev_grade')

# Pairwise relationships and correlations among the session-rank features.
cols = ['sess_cnt_rank', 'sess_time_total_rank', 'sess_time_min_rank',
        'sess_time_mean_rank', 'sess_time_median_rank', 'sess_time_max_rank']
sns.pairplot(df[cols].sample(min(1000, len(df)), random_state=123))
plt.show()
# Without ax=, sns.heatmap draws on plt.gca(), which after pairplot is the last panel of the
# pair grid — the heatmap would overwrite it. Give the heatmap its own figure.
fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(df[cols].corr(), annot=True, cmap="Reds", vmax=1, vmin=0, center=0, ax=ax);

# Pairwise relationships and correlations among the raw purchase / spend features.
cols = ['purchase_amt', 'purchase_dates',
        'spend_paid_amt', 'spend_paid_dates', 'spend_free_amt', 'spend_free_dates',]
sns.pairplot(df[cols].sample(min(1000, len(df)), random_state=123))
plt.show()
# A separate figure keeps sns.heatmap (which defaults to plt.gca()) off the pairplot grid.
fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(df[cols].corr(), annot=True, cmap="Reds", vmax=1, vmin=0, center=0, ax=ax);

# Calculate propensity score

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.calibration import calibration_curve

# Columns excluded from the propensity model: the user id, raw counterparts of *_rank
# features, outcome-related columns (pay_amt is the outcome analysed below; is_pay /
# event_time presumably post-treatment too — confirm), and near-duplicate rank features.
drop_cols = ['appUserId', 'is_ios', 'elapsed_date',   'is_jp',
             'sess_cnt', 'sess_time_total', 'sess_time_min', 'sess_time_mean', 'sess_time_median', 'sess_time_max',
             'purchase_amt', 'purchase_dates', 'spend_paid_amt', 'spend_paid_dates', 'spend_free_amt', 'spend_free_dates',
             'vc_possession', 'unit_possession',
             'is_played_sc_0050', 'is_played_sc_0051', 'is_played_sc_0052', 'is_played_sc_0053',
             'cnt_sc_0050', 'cnt_sc_0051', 'cnt_sc_0052', 'cnt_sc_0053',
             'prev_grade_rank',
             # Similar covariates are strongly correlated, so keep only a representative one.
             'sess_time_total_rank', 'sess_time_max_rank',
             'sess_time_mean_rank', 'sess_time_median_rank',
             #'sess_time_min_rank',
             #'purchase_amt', 'purchase_dates',
             'spend_paid_amt_rank', 'spend_paid_dates_rank',
             #'spend_free_amt', 'spend_free_dates',
             'cnt_sc_0050_rank', 'cnt_sc_0051_rank', 'cnt_sc_0052_rank', 'cnt_sc_0053_rank',
             'is_play', 'pay_amt', 'is_pay', 'event_time']
# The treatment column must stay in merged_df. Filtering (rather than list.remove) keeps the
# cell working when treat_col is switched to a column that is not in the list above.
drop_cols = [c for c in drop_cols if c != treat_col]
use_cols = [c for c in df.columns if c not in drop_cols]
print('use_cols: ', use_cols)
merged_df = df[use_cols]    # still contains treat_col
assert treat_col in merged_df.columns, f'{treat_col} missing from df'

# Restrict to users who played the previous event (prev_grade > 0).
merged_df = merged_df[merged_df['prev_grade'] > 0].reset_index(drop=True)

# Alternatives tried earlier: downsampling controls to len(treated), or duplicating treated rows.
merged_x = merged_df.drop(treat_col, axis=1).values
merged_y = merged_df[treat_col].values

train_df, test_df = train_test_split(merged_df, test_size=0.25, random_state=123)
train_x = train_df.drop([treat_col], axis=1).values
train_y = train_df[treat_col].values
test_x = test_df.drop([treat_col], axis=1).values
test_y = test_df[treat_col].values

In [None]:
import statsmodels.api as sm

# Unpenalised logit (statsmodels) on the propensity features, for coefficients and p-values.
y = merged_y
X = sm.add_constant(merged_x)
model = sm.Logit(y, X)
res = model.fit()

# The summary labels regressors generically (presumably const, x1, x2, ... in column order);
# this prints the feature names in that order.
# NOTE(review): if propensity/outcome columns were already added to merged_df (out-of-order
# re-run), this list no longer matches merged_x.
print(merged_df.drop(treat_col, axis=1).columns)
res.summary()

In [None]:
# Hold-out discrimination check for the L1 logistic propensity model.
# NOTE(review): features are unscaled and saga is scale-sensitive; convergence warnings are
# hidden by the global warnings filter.
print('Start logistic regression....')
clf = LogisticRegression(penalty='l1', C=1, solver='saga', random_state=123)
clf.fit(train_x, train_y)
# ROC AUC needs a continuous score. Hard labels from predict() collapse the curve to a single
# threshold and misreport discrimination.
pred = clf.predict_proba(test_x)[:, 1]
auc = roc_auc_score(test_y, pred)
print(f'test-AUC(logistic regression): {auc :.6f}')

In [None]:
# Hold-out discrimination check for the random-forest propensity model.
print('\nStart random forest....')
clf = RandomForestClassifier(
    max_depth=8,
    # ~0.1% of rows per leaf; max(1, ...) avoids the invalid value 0 for fewer than 1000 rows.
    min_samples_leaf=max(1, int(len(merged_df)/1000)),
    n_jobs=-1, random_state=123)
clf.fit(train_x, train_y)
# Score with class-1 probabilities; AUC computed on hard 0/1 predictions is degenerate.
pred = clf.predict_proba(test_x)[:, 1]
auc = roc_auc_score(test_y, pred)
print(f'test-AUC(random forest): {auc :.6f}')

In [None]:
#TODO: scalerやる
#scaler = StandardScaler()
#merged_x = scaler.fit_transform(merged_x)

In [None]:
# Fit the L1 logistic model on all rows of merged_df and store P(treat = 1 | x) as propensity_lr.
# NOTE(review): C=0.1 here vs C=1 in the hold-out check, and inputs are unscaled — confirm intent.
clf = LogisticRegression(penalty='l1', C=0.1, solver='saga', random_state=123)
merged_df['propensity_lr'] = clf.fit(merged_x, merged_y).predict_proba(merged_x)[:, 1]
clf.coef_

In [None]:
# Random-forest propensity model on all rows of merged_df; stored as propensity_rf.
# min_samples_leaf ~1% of rows; max(1, ...) keeps it valid (sklearn requires >= 1) below 100 rows.
# NOTE(review): scores are predicted in-sample, which pushes forest probabilities toward 0/1;
# out-of-fold predictions would be safer for weighting.
clf = RandomForestClassifier(max_depth=8,
                             min_samples_leaf=max(1, int(len(merged_df)/100)),
                             n_jobs=-1, random_state=123)
clf.fit(merged_x, merged_y)
merged_df['propensity_rf'] = clf.predict_proba(merged_x)[:, 1]

In [None]:
# Feature importances of the RandomForestClassifier fitted just above (clf).
# NOTE(review): assumes merged_df holds only features + treatment + propensity columns; if the
# outcome column has been added (out-of-order re-run) the lengths no longer match.
drop_cols = [treat_col, 'propensity_lr', 'propensity_rf']
use_cols = [c for c in merged_df.columns if c not in drop_cols]
importance_df = pd.DataFrame({
    'feature': merged_df[use_cols].columns,
    'importance': clf.feature_importances_,
}).sort_values(by='importance', ascending=False)
# barh draws bottom-up, so plot in ascending order to put the largest bar at the top.
ascending_imp = importance_df.sort_values(by='importance')
plt.barh(ascending_imp['feature'], ascending_imp['importance'])

In [None]:
def _plot_ps_by_group(treat_ps, control_ps, title):
    """Draw propensity-score histograms side by side: treated (left) and control (right)."""
    plt.figure(figsize=(12, 4))
    for pos, (values, label) in enumerate(((treat_ps, 'treat'), (control_ps, 'control')), start=1):
        plt.subplot(1, 2, pos)
        plt.hist(values, bins=10, alpha=0.5, label=label)
        plt.xlim(0, 1)
        plt.title(title)
        plt.legend(loc='best')
    plt.show()


# Scores per arm, in merged_df row order (reused by the weighting cells below).
treat_ps_lr = merged_df.loc[merged_df[treat_col] == 1, 'propensity_lr'].values
control_ps_lr = merged_df.loc[merged_df[treat_col] == 0, 'propensity_lr'].values
_plot_ps_by_group(treat_ps_lr, control_ps_lr, 'logistic regression')

treat_ps_rf = merged_df.loc[merged_df[treat_col] == 1, 'propensity_rf'].values
control_ps_rf = merged_df.loc[merged_df[treat_col] == 0, 'propensity_rf'].values
_plot_ps_by_group(treat_ps_rf, control_ps_rf, 'random forest')

# treatment有無での各特徴量のヒストグラム
# Covariate distributions by treatment arm: unweighted, and under several propensity-score
# weighting schemes (scores from the logistic-regression model).
plt.rcParams["font.size"] = 10

# This cell used to read treat_df / control_df, which are never defined in this notebook,
# and synthetic columns f0..f11 that do not exist here, so it failed on a fresh kernel.
# Build both frames from merged_df with the same row filter/order as treat_ps_lr / control_ps_lr.
cov_cols = [c for c in merged_df.columns
            if c not in (treat_col, 'propensity_lr', 'propensity_rf', 'pay_amt')]
treat_df = merged_df.loc[merged_df[treat_col] == 1, cov_cols].reset_index(drop=True)
control_df = merged_df.loc[merged_df[treat_col] == 0, cov_cols].reset_index(drop=True)


def plot_weighted_hists(tr, ct, title, w_tr=None, w_ct=None,
                        keep_tr=None, keep_ct=None, n_grid_cols=6):
    """Overlay per-column histograms of treated (tr) vs. control (ct) rows.

    Row weights are passed through plt.hist(weights=...), which reweights each row's
    contribution. Multiplying the covariate values by the weight (as done previously)
    moves the values themselves and does not show the reweighted distribution.

    Args:
        tr, ct: frames with identical columns.
        title: figure suptitle.
        w_tr, w_ct: per-row weights aligned with tr / ct; None means weight 1 for every row.
        keep_tr, keep_ct: optional boolean masks selecting which rows to plot.
        n_grid_cols: subplot columns; the number of rows follows from len(tr.columns).
    """
    w_tr = np.ones(len(tr)) if w_tr is None else np.asarray(w_tr, dtype=float)
    w_ct = np.ones(len(ct)) if w_ct is None else np.asarray(w_ct, dtype=float)
    keep_tr = np.ones(len(tr), dtype=bool) if keep_tr is None else np.asarray(keep_tr, dtype=bool)
    keep_ct = np.ones(len(ct), dtype=bool) if keep_ct is None else np.asarray(keep_ct, dtype=bool)
    n_grid_rows = max(1, math.ceil(len(tr.columns) / n_grid_cols))
    plt.figure(figsize=(18, 2 * n_grid_rows))
    plt.suptitle(title, fontsize=16)
    for idx, col in enumerate(tr.columns):
        plt.subplot(n_grid_rows, n_grid_cols, idx + 1)
        # density=True so arms with different sizes / total weights share one scale.
        plt.hist(tr[col].to_numpy()[keep_tr], bins=10, alpha=0.5,
                 weights=w_tr[keep_tr], density=True)
        plt.hist(ct[col].to_numpy()[keep_ct], bins=10, alpha=0.5,
                 weights=w_ct[keep_ct], density=True)
        plt.title(col)
    plt.tight_layout()
    plt.show()


# Original data, biased (unweighted).
plot_weighted_hists(treat_df, control_df, 'unweighted')

# IPW: weight treated rows by 1/e(x) and control rows by 1/(1 - e(x)).
inv_tr_ps = 1 / treat_ps_lr
inv_ct_ps = 1 / (1 - control_ps_lr)
plot_weighted_hists(treat_df, control_df, 'IPW', inv_tr_ps, inv_ct_ps)

# Truncated IPW: clip the probabilities into [ps_min, ps_max] before inverting.
ps_min, ps_max = 0.05, 0.95
inv_tr_ps = 1 / np.clip(treat_ps_lr, ps_min, ps_max)
inv_ct_ps = 1 / np.clip(1 - control_ps_lr, ps_min, ps_max)
plot_weighted_hists(treat_df, control_df, 'IPW, truncated', inv_tr_ps, inv_ct_ps)

# Discarded IPW: drop rows whose probability lies outside (ps_min, ps_max), plain IPW otherwise.
is_valid_tr = (treat_ps_lr > ps_min) & (treat_ps_lr < ps_max)
is_valid_ct = (1 - control_ps_lr > ps_min) & (1 - control_ps_lr < ps_max)
inv_tr_ps = 1 / treat_ps_lr
inv_ct_ps = 1 / (1 - control_ps_lr)
plot_weighted_hists(treat_df, control_df, 'IPW, discarded', inv_tr_ps, inv_ct_ps,
                    is_valid_tr, is_valid_ct)

# Overlap weights: 1 - e(x) for treated, e(x) for control (clipped as in the original cell).
inv_tr_ps = np.clip(1 - treat_ps_lr, ps_min, ps_max)
inv_ct_ps = np.clip(control_ps_lr, ps_min, ps_max)
plot_weighted_hists(treat_df, control_df, 'overlap weight', inv_tr_ps, inv_ct_ps)

# Estimate ATE/ATT/ATC

In [None]:
# Propensity column used for matching/weighting, and the outcome being compared.
ps_col, outcome_col = 'propensity_lr', 'pay_amt'

In [None]:
# Comparison without any adjustment: difference of mean outcome between the arms.
# NOTE(review): this uses the full df, while the adjusted estimates below use merged_df
# (prev_grade > 0 only), so the two are not on the same population. Also outcome_col is
# pay_amt, so "CVR" here is a mean amount rather than a conversion rate.
cvr_tr, cvr_ct = (df.loc[df[treat_col] == flag, outcome_col].mean() for flag in (1, 0))

print('Unadjusted estimation:')
print(f'CVR(treatment): {cvr_tr :.6f}')
print(f'CVR(control):   {cvr_ct :.6f}')
print(f'ATE(biased):    {cvr_tr - cvr_ct :.6f}')

# Matching

In [None]:
"""
https://microsoft.github.io/dowhy/_modules/dowhy/causal_estimators/propensity_score_matching_estimator.html
"""
from sklearn.neighbors import NearestNeighbors

def _nearest_neighbour(ref_ps, query_ps):
    """For each query score, return (distance, index) of its nearest reference score (1-NN)."""
    nn = NearestNeighbors(n_neighbors=1, algorithm='ball_tree')
    nn.fit(np.asarray(ref_ps).reshape(-1, 1))
    distances, indices = nn.kneighbors(np.asarray(query_ps).reshape(-1, 1))
    return distances.reshape(-1), indices.reshape(-1)


def _report_effect(label, treated_outcome, control_outcome):
    """Print arm means, the difference and a Welch t-test p-value; return the difference.

    Returns NaN (and says so) when no pairs survived the caliper.
    """
    if len(treated_outcome) == 0:
        print(f'\nNo pairs within the caliper; {label} is undefined.')
        return np.nan
    effect = treated_outcome.mean() - control_outcome.mean()
    # NOTE(review): matched samples are dependent (and controls may repeat), so an
    # independent-samples Welch test is only an approximation here.
    _, pvalue = stats.ttest_ind(treated_outcome, control_outcome, equal_var=False)
    print(f'\nTreatment outcome: {treated_outcome.mean() :.6f}')
    print(f'Control outcome:   {control_outcome.mean() :.6f}')
    print(f'{label}(matching):     {effect :.6f}, (p-value={pvalue :.6f})')
    return effect


def matching(df, t_col, y_col, ps_col, caliper=None):
    """1-nearest-neighbour propensity-score matching (with replacement) under a caliper.

    Adapted from DoWhy's propensity-score matching estimator.
    * ATT: each treated row is paired with its nearest control row.
    * ATC: each control row is paired with its nearest treated row.
    Pairs whose score distance exceeds ``caliper`` are dropped. ATE is the average of ATT and
    ATC weighted by the number of retained pairs in each direction. Results are printed.

    Args:
        df: frame with t_col (1 = treated, 0 = control), y_col and ps_col.
        t_col: treatment indicator column.
        y_col: outcome column.
        ps_col: score column used as the 1-D matching distance.
        caliper: maximum absolute score difference per pair. If None, it is set to the median
            treated->control nearest distance and reused for the ATC direction.

    Returns:
        (att_pair_tr_df, att_pair_ct_df, atc_pair_tr_df, atc_pair_ct_df): row-aligned matched
        treated / control frames (all columns of df) for the ATT and ATC directions.

    Raises:
        ValueError: if either arm has no rows.
    """
    print('Matching....')
    t0 = time()
    treated = df.loc[df[t_col] == 1].reset_index(drop=True)
    control = df.loc[df[t_col] == 0].reset_index(drop=True)
    if treated.empty or control.empty:
        raise ValueError('matching() needs at least one treated and one control row')
    tr_ps = treated[ps_col].to_numpy()
    ct_ps = control[ps_col].to_numpy()

    # ATT: nearest control for every treated row.
    distances, ct_idx = _nearest_neighbour(ct_ps, tr_ps)
    if caliper is None:
        caliper = np.median(distances)
        print(f'Caliper is setted to the median of distances, caliper = {caliper}')
    tr_idx = np.flatnonzero(distances <= caliper)
    ct_idx = ct_idx[tr_idx]
    num_treated = len(tr_idx)
    # iloc keeps column dtypes (building from .values would coerce to one common dtype).
    att_pair_tr_df = treated.iloc[tr_idx].reset_index(drop=True)
    att_pair_ct_df = control.iloc[ct_idx].reset_index(drop=True)
    att = _report_effect('ATT', att_pair_tr_df[y_col].to_numpy(), att_pair_ct_df[y_col].to_numpy())
    print(f'# of matched pairs: {num_treated} ({num_treated / len(treated) :.3f} of all records)')

    # ATC: nearest treated row for every control row, same caliper.
    distances, tr_idx = _nearest_neighbour(tr_ps, ct_ps)
    ct_idx = np.flatnonzero(distances <= caliper)
    tr_idx = tr_idx[ct_idx]
    num_control = len(ct_idx)
    atc_pair_tr_df = treated.iloc[tr_idx].reset_index(drop=True)
    atc_pair_ct_df = control.iloc[ct_idx].reset_index(drop=True)
    atc = _report_effect('ATC', atc_pair_tr_df[y_col].to_numpy(), atc_pair_ct_df[y_col].to_numpy())
    print(f'# of matched pairs: {num_control} ({num_control / len(control) :.3f} of all records)')

    # ATE: pair-count-weighted mix of the directions that produced pairs
    # (a direction with no pairs would otherwise turn the whole ATE into NaN).
    parts = [(eff, n) for eff, n in ((att, num_treated), (atc, num_control)) if n > 0]
    ate = sum(eff * n for eff, n in parts) / sum(n for _, n in parts) if parts else np.nan
    print(f'\nATE(matching): {ate :.6f}')

    print(f'\nEnd in {time() - t0 :.2f}s.')

    return att_pair_tr_df, att_pair_ct_df, atc_pair_tr_df, atc_pair_ct_df

# merged_df was filtered to prev_grade > 0 and re-indexed with reset_index(drop=True), so
# assigning df[outcome_col] directly aligns on the new 0..n-1 index and attaches the outcomes of
# unrelated users. Re-apply the same row filter to df to get values in merged_df's row order.
outcome_vals = df.loc[df['prev_grade'] > 0, outcome_col].to_numpy()
assert len(outcome_vals) == len(merged_df), 'prev_grade filter no longer matches merged_df'
merged_df[outcome_col] = outcome_vals
att_pair_tr_df, att_pair_ct_df, atc_pair_tr_df, atc_pair_ct_df = matching(
    merged_df, t_col=treat_col, y_col=outcome_col, ps_col=ps_col, caliper=1.e-4)

# Validation of adjusted covariates

In [None]:
# standardized difference
def absolute_standardized_difference(treat, control):
    """Return |mean(treat) - mean(control)| / sqrt((var(treat) + var(control)) / 2).

    treat / control are array-likes with .mean() and .var(); note pandas Series use ddof=1
    while numpy arrays use ddof=0.
    """
    pooled_sd = np.sqrt((treat.var() + control.var()) / 2)
    return abs((treat.mean() - control.mean()) / pooled_sd)

# Covariate balance after matching: |SMD| per covariate, raw arms vs. ATT-matched pairs.
vis_cols = [c for c in merged_df.columns
            if c not in [treat_col, outcome_col, 'propensity_lr', 'propensity_rf']]
is_tr = merged_df[treat_col] == 1
is_ct = merged_df[treat_col] == 0
sd_list = []
for col in vis_cols:
    tr_df = merged_df.loc[is_tr, col].reset_index(drop=True)
    ct_df = merged_df.loc[is_ct, col].reset_index(drop=True)
    sd_biased = absolute_standardized_difference(tr_df, ct_df)
    sd_adjusted = absolute_standardized_difference(att_pair_tr_df[col], att_pair_ct_df[col])
    sd_list.append([sd_biased, sd_adjusted])
    print(f'col: {col :<26}, SD(biased): {sd_biased :.6f}, SD(matched): {sd_adjusted :.6f}')

sd_arr = np.array(sd_list)
plt.figure(figsize=(6, 6))
for series_idx, series_label in enumerate(('biased', 'matched')):
    plt.scatter(vis_cols, sd_arr[:, series_idx], label=series_label)
# Dotted guide at |SMD| = 0.1.
plt.hlines(y=0.1, xmin=-1, xmax=len(vis_cols), linestyles='dotted', linewidths=0.5)
plt.xlim(-0.5, len(vis_cols) - 0.5)
plt.xticks(rotation=90)
plt.legend(loc='best')
plt.tight_layout()
plt.show();

# Sensitivity Analysis
- [Sensitivity Analysis without Assumptions, Ding and VanderWeele (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4820664/)

In [None]:
def get_evalue_curve(tr, ct):
    """Print the risk ratio and E-value of two mean outcomes and plot the E-value boundary.

    E-value = RR + sqrt(RR * (RR - 1)) with RR oriented to be >= 1 (Ding & VanderWeele, 2016).
    The dotted curve is RR_UY = RR * (1 - RR_UX) / (RR - RR_UX), i.e. the (RR_UX, RR_UY)
    pairs whose product-over-sum bound equals RR.

    Args:
        tr: mean outcome of the treated arm (must be > 0).
        ct: mean outcome of the control arm (must be > 0).

    Raises:
        ValueError: if either mean is not strictly positive (or is NaN), since the ratio is
            then undefined.
    """
    # The original body ignored tr / ct and read the globals tr_est / ct_est instead.
    if not (tr > 0 and ct > 0):
        raise ValueError(f'E-value needs positive outcome means, got tr={tr}, ct={ct}')
    # risk ratio, oriented so that rr >= 1
    rr = tr / ct if tr / ct >= 1. else ct / tr
    evalue = rr + np.sqrt(rr * (rr - 1))

    print(f'Est. risk ratio: {rr :.4f}')
    print(f'E-value:         {evalue :.4f}')

    # visualize
    x_max = math.ceil(evalue * 3)
    y_max = x_max
    # x at which the boundary curve reaches y_max (the curve is symmetric in x and y).
    x_start = rr * (1 - y_max) / (rr - y_max)
    x = np.arange(x_start, x_max, 0.02)
    y = rr * (1 - x) / (rr - x)

    plt.figure(figsize=(6, 6))
    plt.rcParams["font.size"] = 12
    plt.plot(x, y, 'b:', label='boundary')
    plt.scatter(evalue, evalue, c='r', label='E-value')
    plt.text(evalue * 0.5, evalue * 0.5, 'significant-effect zone',
             horizontalalignment='left', verticalalignment='center')
    plt.text(evalue * 1.5, evalue * 1.5, 'null-effect zone',
             horizontalalignment='center', verticalalignment='center')
    plt.text(evalue, evalue, f'({evalue :.2f}, {evalue :.2f})')
    plt.xlim(0, x_max)
    plt.ylim(0, y_max)
    plt.xlabel('RR(UX)')
    plt.ylabel('RR(UY)')
    plt.grid()
    plt.legend(loc='best')
    plt.show()
    
# E-value for the ATT and ATC matched samples. tr_est / ct_est are (re)assigned as globals
# before each call, exactly as before.
for case, tr_frame, ct_frame in (('ATT', att_pair_tr_df, att_pair_ct_df),
                                 ('ATC', atc_pair_tr_df, atc_pair_ct_df)):
    print(f'\n{case} case:')
    tr_est = tr_frame[outcome_col].mean()
    ct_est = ct_frame[outcome_col].mean()
    get_evalue_curve(tr_est, ct_est)

# IPW, Overlap weight
**Overlap weight**
- https://speakerdeck.com/tomoshige_n/causal-inference-and-data-analysis?slide=37
- http://www2.stat.duke.edu/~fl35/OW/MultiTrt_talk.pdf

In [None]:
# Weighted point estimates with the propensity score clipped into [ps_min, ps_max].
ps_min, ps_max = 0.01, 0.99
#ps_min, ps_max = 0.1, 0.9    # Li et al., 2018
tr = merged_df[treat_col].values
outcome = merged_df[outcome_col].values
ps = np.clip(merged_df[ps_col].values, ps_min, ps_max)

# IPW (normalised / Hajek form): weight 1/e(x) for treated and 1/(1 - e(x)) for control.
ipwe1 = (tr * outcome / ps).sum() / (tr / ps).sum()
ipwe0 = ((1 - tr) * outcome / (1 - ps)).sum() / ((1 - tr) / (1 - ps)).sum()
print(f'ATE(IPW): {ipwe1 - ipwe0 :.6f}')

# Overlap weights: 1 - e(x) for treated and e(x) for control.
# NOTE(review): this targets the overlap population (often called ATO), not the full-population ATE.
owe1 = (tr * outcome * (1 - ps)).sum() / (tr * (1 - ps)).sum()
owe0 = ((1 - tr) * outcome * ps).sum() / ((1 - tr) * ps).sum()
print(f'ATE(overlap weight): {owe1 - owe0 :.6f}')

In [None]:
def show_result(arr, cols, adusted_label):
    """Scatter |SMD| per covariate before ('biased') and after adjustment, with a 0.1 guide.

    Args:
        arr: 2-D array; column 0 = unadjusted values, column 1 = adjusted values, one row per
            entry of cols.
        cols: covariate names used as x-axis categories.
        adusted_label: legend label and title for the adjusted series (name kept for callers).
    """
    n_cols = len(cols)
    plt.scatter(cols, arr[:, 0], label='biased')
    plt.scatter(cols, arr[:, 1], label=adusted_label)
    # Axis extent now follows the cols argument; it previously read the global vis_cols.
    plt.hlines(y=0.1, xmin=-1, xmax=n_cols, linestyles='dotted', linewidths=0.5)
    plt.xlim(-0.5, n_cols - 0.5)
    plt.xticks(rotation=90)
    plt.legend(loc='best')
    plt.title(adusted_label)

def weighted_absolute_standardized_difference(x_tr, x_ct, w_tr, w_ct):
    """|SMD| from weighted means and weighted (ddof=0) variances of each arm.

    Returns NaN when either arm has no rows left (e.g. everything discarded).
    """
    x_tr = np.asarray(x_tr, dtype=float)
    x_ct = np.asarray(x_ct, dtype=float)
    if len(x_tr) == 0 or len(x_ct) == 0:
        return np.nan
    mean_tr = np.average(x_tr, weights=w_tr)
    mean_ct = np.average(x_ct, weights=w_ct)
    var_tr = np.average((x_tr - mean_tr) ** 2, weights=w_tr)
    var_ct = np.average((x_ct - mean_ct) ** 2, weights=w_ct)
    return abs((mean_tr - mean_ct) / np.sqrt((var_tr + var_ct) / 2))


plt.figure(figsize=(12, 12))
vis_cols = [c for c in merged_df.columns
            if c not in [treat_col, outcome_col, 'propensity_lr', 'propensity_rf']]

is_tr = (merged_df[treat_col] == 1).to_numpy()
is_ct = (merged_df[treat_col] == 0).to_numpy()
tr_ps = merged_df.loc[is_tr, ps_col].to_numpy()
ct_ps = merged_df.loc[is_ct, ps_col].to_numpy()
all_tr = np.ones(len(tr_ps), dtype=bool)
all_ct = np.ones(len(ct_ps), dtype=bool)

# Balance must be measured with weighted statistics. The previous code multiplied the covariate
# values by the weights, which rescales the covariate instead of reweighting the rows.
# Entries: (plot label, header text, treated weights, control weights, treated kept, control kept)
weighting_schemes = [
    ('IPW', 'IPW', 1 / tr_ps, 1 / (1 - ct_ps), all_tr, all_ct),
    ('IPW, truncated', 'truncated IPW',
     1 / np.clip(tr_ps, ps_min, ps_max), 1 / np.clip(1 - ct_ps, ps_min, ps_max), all_tr, all_ct),
    # The discard mask used to be computed but never applied, making this identical to IPW.
    ('IPW, discarded', 'discarded IPW', 1 / tr_ps, 1 / (1 - ct_ps),
     (tr_ps > ps_min) & (tr_ps < ps_max), (1 - ct_ps > ps_min) & (1 - ct_ps < ps_max)),
    ('overlap weight', 'overlap weight', 1 - tr_ps, ct_ps, all_tr, all_ct),
]

for pos, (label, desc, w_tr, w_ct, keep_tr, keep_ct) in enumerate(weighting_schemes, start=1):
    print(f'\nStandardized difference for covariates using {desc}')
    sd_list = []
    for col in vis_cols:
        tr_df = merged_df.loc[is_tr, col].reset_index(drop=True)
        ct_df = merged_df.loc[is_ct, col].reset_index(drop=True)
        # Unadjusted SMD keeps the pandas (ddof=1) definition used elsewhere in the notebook.
        sd_biased = absolute_standardized_difference(tr_df, ct_df)
        sd_adjusted = weighted_absolute_standardized_difference(
            tr_df.to_numpy()[keep_tr], ct_df.to_numpy()[keep_ct], w_tr[keep_tr], w_ct[keep_ct])
        sd_list.append([sd_biased, sd_adjusted])
        print(f'col: {col :<26}, SD(biased): {sd_biased :.6f}, SD(adjusted): {sd_adjusted :.6f}')
    plt.subplot(2, 2, pos)
    show_result(np.array(sd_list), vis_cols, label)

plt.tight_layout()
plt.show();