In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Load the dataset to be stratified
df = pd.read_csv('../datasets/Dataset1.csv')

# Names of the treatment indicator and of the two covariates used for matching
treatedcol = 'treated'
vname1 = 'cov_1'
vname2 = 'cov_2'
col_names = [vname1, vname2]
cols = [0, 1]  # positional indices of the covariates inside tr / ct

# A unit is flagged as treated (1) when its exposure is at least 0.25, else 0
df['treated'] = (df['exposure'] >= 0.25).astype(int)

# Covariate matrices, one row per unit, for the treated and the control group
tr = df.loc[df['treated'] == 1, col_names].to_numpy()
ct = df.loc[df['treated'] == 0, col_names].to_numpy()
n_tr = len(tr)
n_ct = len(ct)

# Column names that are passed on to the R helper functions as columns to drop
dropcols = ['id', 'risk', 'exposure', 'outcome', 'Unnamed..0']

In [None]:
import rpy2.robjects as robjects
from rpy2.robjects import pandas2ri
from rpy2.robjects.conversion import localconverter

# Source the R script 'call_cem.r'; the R functions fetched from
# robjects.globalenv below are expected to be defined by it.
r = robjects.r
r['source']('call_cem.r')

# Placeholder defaults. Every one of them is overwritten inside the
# conversion block below: stats_* from the overall imbalance table,
# delta_* and max_*_unmatched from the histogram-binning result.
delta_1 = 0
delta_2 = 0
stats_1 = np.inf
stats_2 = np.inf
max_tr_unmatched = 0
max_ct_unmatched = 0

# Run the R calls with the pandas <-> R conversion rules active
with localconverter(robjects.default_converter + pandas2ri.converter):
  # Convert the pandas frame once; df_r is reused for every R call
  df_r = robjects.conversion.py2rpy(df)

  # Call R-script for overall imbalance and statistics before stratification
  overall_imbalance = robjects.globalenv['overall_imbalance']
  # NOTE(review): `vars` shadows the Python builtin of the same name
  vars = robjects.StrVector([i for i in col_names])
  overall_res = overall_imbalance(df_r, treatedcol, vars)
  print('Overall imbalance')
  print(overall_res)

  # Per-covariate value of the 'statistic' column of the imbalance table
  overall_statistics = overall_res.rx2('tab')['statistic']
  stats_1 = overall_statistics[vars[0]]
  stats_2 = overall_statistics[vars[1]]

  # Call R-script for histogram binning
  call_plain_cem = robjects.globalenv['call_plain_cem']
  plain_mat = call_plain_cem(df_r, treatedcol, dropcols)
  print('Histogram binning')
  print(plain_mat)
  print(plain_mat.rx2('breaks'))

  # Widest bin per covariate: the largest gap between consecutive breaks
  delta_1 = np.max(np.diff(plain_mat.rx2('breaks').rx2(vname1)))
  delta_2 = np.max(np.diff(plain_mat.rx2('breaks').rx2(vname2)))  
  # Row index 2 of 'tab' -- presumably the unmatched counts, with the treated
  # group in column 1 and the controls in column 0; verify against the R output
  max_tr_unmatched = plain_mat.rx2('tab')[2, 1]
  max_ct_unmatched = plain_mat.rx2('tab')[2, 0]


In [None]:
# Selects the parameter setting: 1 or 2 are P1/P2 from the paper, 3 is the
# additional P3 setting that is not part of the paper.
P = 2

# The allowed number of unmatched units is capped at the size of each group
max_tr_unmatched = min(max_tr_unmatched, len(tr))
max_ct_unmatched = min(max_ct_unmatched, len(ct))

if P == 3:
    # The smaller group may leave one unit unmatched, the larger group as many
    # as the (rounded) size ratio between the groups
    if n_tr < n_ct:
        max_ct_unmatched = round(n_ct / n_tr)
        max_tr_unmatched = 1
    else:
        max_tr_unmatched = round(n_tr / n_ct)
        max_ct_unmatched = 1
elif P == 2:
    # P2 places no limit on the stratum widths
    delta_1 = delta_2 = np.inf

# Maximum allowed stratum width per covariate, in covariate order
deltas = [delta_1, delta_2]

print(f'Max unmatched G0: {max_ct_unmatched} G1: {max_tr_unmatched}')
print(f'Max widths {vname1}: {delta_1}, {vname2}: {delta_2}')


In [None]:
from scipy import stats
from numba import njit
from numba.core import types
from numba.typed import Dict, List

# Numba type of a stratum cache key: four float64 values, i.e. the
# (left, right) edge pair for each of the two covariates
four_floats_tuple = types.UniTuple(types.float64, 4)
# Numba type of a cached result: two float64 values, i.e. the number of
# treated and of control units found in a stratum
two_floats_tuple = types.UniTuple(types.float64, 2)

def potential_edges(tr, ct, cols, single_value_margin=0.5):
    '''Gets the potential stratum edges for the given column indices.

    For every column the distinct values of the treated and control units are
    pooled and sorted. An edge is placed halfway between each pair of adjacent
    values, plus one edge to the left of the smallest value and one to the
    right of the largest value, each at half the distance to its neighbour.

    Args:
        tr: 2-D array with the covariates of the treated units.
        ct: 2-D array with the covariates of the control units.
        cols: Iterable of column indices to create edges for.
        single_value_margin: Distance from the value to the two edges that
            are created when a column holds only a single distinct value, so
            that there is no neighbour to measure a gap against.

    Returns:
        A list of (col, edge) tuples, ascending in edge within each column.

    Raises:
        ValueError: If a column contains no values at all.
    '''
    edges = []
    data = np.concatenate((tr, ct))
    for col in cols:
        col_data = np.unique(data[:, col])  # distinct values, sorted ascending

        if col_data.size == 0:
            raise ValueError(f'Column {col} has no values to derive edges from')

        if col_data.size == 1:
            # No neighbouring value exists, so enclose the single value in one
            # stratum of width 2 * single_value_margin
            edges.append((col, col_data[0] - single_value_margin))
            edges.append((col, col_data[0] + single_value_margin))
            continue

        # Half of the gap between each pair of adjacent values
        half_gaps = np.abs(np.diff(col_data)) / 2.

        # Leftmost edge: half a gap to the left of the smallest value
        edges.append((col, col_data[0] - half_gaps[0]))

        # Inner edges: the midpoint between every two adjacent values
        edges.extend((col, edge) for edge in col_data[1:] - half_gaps)

        # Rightmost edge: half a gap to the right of the largest value
        edges.append((col, col_data[-1] + half_gaps[-1]))
    return edges


@njit
def faster_count(edges_0, edges_1, tr, ct, unmatched_counts: Dict):
    '''Count the number of unmatched for the strata between the edges.

    The strata are the cells spanned by each pair of consecutive values in
    edges_0 (compared against column 0 of tr / ct) and each pair of
    consecutive values in edges_1 (compared against column 1). A unit falls
    in a stratum when left <= value < right holds in both columns.

    Args:
        edges_0: Edge values for column 0, walked in the given order.
        edges_1: Edge values for column 1, walked in the given order.
        tr: 2-D array of treated units.
        ct: 2-D array of control units.
        unmatched_counts: Cache from (left_0, right_0, left_1, right_1) to
            the (treated, control) counts of that stratum. Existing entries
            are reused and missing ones are added in place.

    Returns:
        A tuple (tr_unmatched, ct_unmatched, checked_keys): the two unmatched
        totals as float64, and a typed list with the key of every stratum
        that was visited, whether it came from the cache or not.
    '''
    tr_unmatched = np.float64(0.0)
    ct_unmatched = np.float64(0.0)
    n_edges_0 = len(edges_0)
    n_edges_1 = len(edges_1)
    checked_keys = List()

    # Walk the strata as (left, right) pairs of consecutive edges
    left_0 = edges_0[0]
    for i in range(1, n_edges_0):
        right_0 = edges_0[i]
        left_1 = edges_1[0]
        for j in range(1, n_edges_1):
            right_1 = edges_1[j]
            key = (left_0, right_0, left_1, right_1)
            checked_keys.append(key)
            if key in unmatched_counts:
                # Stratum was counted before; reuse the cached counts
                counts = unmatched_counts[key]
            else:
                # Select the units inside the stratum (left-inclusive,
                # right-exclusive in both columns) and cache their counts
                tr_s = tr[np.where(np.logical_and(
                    np.logical_and(tr[:, 0] >= left_0, tr[:, 0] < right_0), 
                    np.logical_and(tr[:, 1] >= left_1, tr[:, 1] < right_1)))]
                ct_s = ct[np.where(np.logical_and(
                    np.logical_and(ct[:, 0] >= left_0, ct[:, 0] < right_0), 
                    np.logical_and(ct[:, 1] >= left_1, ct[:, 1] < right_1)))]
                counts = (len(tr_s), len(ct_s))
                unmatched_counts[key] = counts
            # Units are unmatched when their stratum holds only their own group
            if counts[0] > 0 and counts[1] == 0:
                tr_unmatched += counts[0]
            elif counts[0] == 0 and counts[1] > 0:
                ct_unmatched += counts[1]
            left_1 = right_1
        left_0 = right_0  
  
    return (tr_unmatched, ct_unmatched, checked_keys)


def count_unmatched(df_edges, tr, ct, unmatched_counts: Dict):
    '''Count the unmatched treated and control units for a set of edges.

    Splits the edge table into the edge values of column 0 and column 1,
    hands them to faster_count and converts the two totals to int.

    Returns:
        (treated unmatched, controls unmatched, keys of the checked strata)
    '''
    def edge_values(col_idx):
        return df_edges[df_edges['col'] == col_idx]['val'].to_numpy()

    # Only two covariate columns are supported at the moment
    first_col_edges = edge_values(0)
    second_col_edges = edge_values(1)

    n_tr_unmatched, n_ct_unmatched, visited_keys = faster_count(
        first_col_edges, second_col_edges, tr, ct, unmatched_counts)

    return int(n_tr_unmatched), int(n_ct_unmatched), visited_keys


def df_no_outl(df, c):
    '''Return column c of df without outliers.

    Only the rows whose z-score in column c has an absolute value below 3
    are kept.
    '''
    z_scores = stats.zscore(df[c])
    within_three_sigma = np.abs(z_scores) < 3
    return df[within_three_sigma][c]


In [None]:
# Every candidate edge for both covariates, as a table with one row per edge
candidate_edges = potential_edges(tr, ct, cols)
df_edge_candidates = pd.DataFrame(candidate_edges, columns=['col', 'val'])
# Working copy; the search loop removes edges from this frame
df_edges = df_edge_candidates.copy()

# State of the best candidate found during one pass of the search loop.
# (The second declaration used to repeat cur_tr_match_inc by mistake.)
cur_tr_match_inc: int = -1
cur_ct_match_inc: int = -1
remove_index: int
selected_edge_width: float

# Range of every covariate after dropping outliers; used to put stratum
# widths of different covariates on a comparable scale
col_widths = []
for c in col_names:
    df_hat = df_no_outl(df, c)
    w = np.abs(df_hat.max() - df_hat.min())
    col_widths.append(w)

col_widths = np.array(col_widths, dtype=np.float64)
col_scale_factor = col_widths

# Cache for already counted strata results that can be reused
unmatched_counts = Dict.empty(key_type=four_floats_tuple, value_type=two_floats_tuple)

# Number of unmatched units with every candidate edge in place
cur_tr_unmatched, cur_ct_unmatched, _ = count_unmatched(df_edges, tr, ct, unmatched_counts)

print(f'Initial unmatched: G0: {cur_ct_unmatched} G1: {cur_tr_unmatched}')
print(f'Column maximum widths: {col_widths}')
print(f'Column scale factors: {col_scale_factor}')

def _excess_weight(current_unmatched, allowed_unmatched, group_size):
    '''Share of the removable unmatched units that is still in excess.

    Returns 0 when the group already meets its limit, and also when the limit
    leaves no headroom (allowed_unmatched >= group_size), which would
    otherwise divide by zero.
    '''
    excess = current_unmatched - allowed_unmatched
    headroom = group_size - allowed_unmatched
    if excess <= 0 or headroom <= 0:
        return 0
    return excess / headroom


# Greedy search: in every pass remove the single edge whose removal matches
# the most (weighted) units, until both groups are within their limits.
while (len(df_edges) > 0 and (cur_tr_unmatched > max_tr_unmatched or cur_ct_unmatched > max_ct_unmatched)):
    selected_edge_width = np.inf
    cur_rel_inc = -1
    cur_tr_match_inc = -1
    cur_ct_match_inc = -1
    remove_index = None

    # The group weights only depend on the state at the start of the pass, so
    # they are computed once here instead of once per candidate edge
    tr_w = _excess_weight(cur_tr_unmatched, max_tr_unmatched, n_tr)
    ct_w = _excess_weight(cur_ct_unmatched, max_ct_unmatched, n_ct)

    for col, col_name in enumerate(col_names):
        # Get candidate edges for current column
        df_col_edges = df_edges[df_edges['col'] == col]
        col_edge_vals = df_col_edges['val'].to_numpy()

        for i in range(1, len(df_col_edges) - 1): # Leftmost and rightmost edge will not be removed
            assess_val = col_edge_vals[i]
            left_assess_val = col_edge_vals[i - 1]
            right_assess_val = col_edge_vals[i + 1]

            # Width in the col's dimension of the stratum that removing edge i creates
            width = right_assess_val - left_assess_val
            if width > deltas[col]:
                continue

            # Adjust the width for this variable to be comparable with the others
            width /= col_scale_factor[col]

            # Keep only the edges of this column between the two neighbours
            # of the assessed edge; edges of the other column are all kept
            df_assess = df_edges[
                ~((df_edges.col == col) &
                  ((df_edges.val < left_assess_val) | (df_edges.val > right_assess_val)))
            ]

            # Only units between the two neighbouring edges can be affected
            tr_filtered = tr[(tr[:, col] >= left_assess_val) & (tr[:, col] < right_assess_val)]
            ct_filtered = ct[(ct[:, col] >= left_assess_val) & (ct[:, col] < right_assess_val)]

            # First check strata with the current edge
            tr_unmatched_with, ct_unmatched_with, checked_keys = count_unmatched(
                df_assess, tr_filtered, ct_filtered, unmatched_counts)

            # Then check what happens if we remove the edge (in the middle)
            df_assess = df_assess[~((df_assess['val'] == assess_val) & (df_assess['col'] == col))]
            tr_unmatched_without, ct_unmatched_without, _ = count_unmatched(
                df_assess, tr_filtered, ct_filtered, unmatched_counts)

            # Units that become matched when the edge is removed
            tr_match_inc = tr_unmatched_with - tr_unmatched_without
            ct_match_inc = ct_unmatched_with - ct_unmatched_without

            rel_inc = (tr_match_inc * tr_w) + (ct_match_inc * ct_w)

            # Prefer the largest weighted gain; on a tie prefer the narrower stratum
            if (rel_inc > cur_rel_inc) or (rel_inc == cur_rel_inc and width < selected_edge_width):
                cur_rel_inc = rel_inc
                cur_tr_match_inc = tr_match_inc
                cur_ct_match_inc = ct_match_inc
                remove_index = df_col_edges.index[i]
                selected_edge_width = width
                # Drop the strata that were counted with this edge in place from
                # the cache (presumably because they vanish if the edge is removed)
                for key in checked_keys:
                    unmatched_counts.pop(key)
            checked_keys.clear()

    if remove_index is not None:
        df_edges.drop(remove_index, inplace=True)
    else:
        print('No more improvements found...')
        break

    if ((cur_tr_match_inc > 0) or (cur_ct_match_inc > 0)):
        cur_tr_unmatched -= cur_tr_match_inc
        cur_ct_unmatched -= cur_ct_match_inc
        print(f'edges: {len(df_edges)}\nUnmatched G0: {cur_ct_unmatched}, G1: {cur_tr_unmatched}')

    if ((cur_tr_unmatched <= max_tr_unmatched) and (cur_ct_unmatched <= max_ct_unmatched)):
        print('Finished!')
        break

# Recount from scratch with an empty cache to report the final result
unmatched_counts.clear()
cur_tr_unmatched, cur_ct_unmatched, _ = count_unmatched(df_edges, tr, ct, unmatched_counts)
print(f'Num. edges: {len(df_edges)}\nUnmatched G0: {cur_ct_unmatched}, G1: {cur_tr_unmatched}')

In [None]:
# Remaining stratum edges per covariate, in ascending order
cov_1_cutoffs = sorted(df_edges[df_edges['col'] == 0]['val'].to_list())
cov_2_cutoffs = sorted(df_edges[df_edges['col'] == 1]['val'].to_list())

with localconverter(robjects.default_converter + pandas2ri.converter):
  call_cem_autostrata = robjects.globalenv['call_cem_autostrata']

  # The R function receives the columns to drop as an R character vector
  dropcols = robjects.StrVector(dropcols)

  # Cut points keyed by covariate name, wrapped as an R list of numeric vectors
  cutoffs = {vname1: cov_1_cutoffs, vname2: cov_2_cutoffs}
  cps_dict = {name: robjects.FloatVector(points) for name, points in cutoffs.items()}
  cps = robjects.ListVector(cps_dict)

  autostrata_res = call_cem_autostrata(df_r, treatedcol, dropcols, cps)

  print('Autostrata results')
  print(autostrata_res)

  print(autostrata_res.rx2('breaks'))

  print(autostrata_res.rx2('strata'))

In [None]:
import seaborn as sns

%matplotlib inline

def _scatter_with_grid(ax, vertical_cuts, horizontal_cuts):
    '''Draw both groups on ax and overlay the stratum edges as a grid.'''
    ax.scatter(tr[:, 0], tr[:, 1], marker='x', label='Treated', c='0.25')
    ax.scatter(ct[:, 0], ct[:, 1], marker='.', label='Controls', c='0.25')
    # Each grid line spans the full range of the edges of the other covariate
    ax.vlines(vertical_cuts, np.min(horizontal_cuts), np.max(horizontal_cuts), colors=['0.5'], linewidth=2)
    ax.hlines(horizontal_cuts, np.min(vertical_cuts), np.max(vertical_cuts), colors=['0.5'], linewidth=2)


sns.set_theme(style="white")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

# Breaks returned by the histogram-binning call
full_cov_1_cutoffs = plain_mat.rx2('breaks').rx2(vname1)
full_cov_2_cutoffs = plain_mat.rx2('breaks').rx2(vname2)

# Left panel: the histogram-binning breaks
_scatter_with_grid(ax1, full_cov_1_cutoffs, full_cov_2_cutoffs)
ax1.legend(loc=1, framealpha=1, fontsize=14)

# Right panel: the edges that remain after the search
_scatter_with_grid(ax2, cov_1_cutoffs, cov_2_cutoffs)
sns.despine(right = True)
ax2.legend(loc=1, framealpha=1, fontsize=16)

plt.show()