In [1]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns

## Count the unique feature combinations among the ~3,000 training datapoints

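Across the five seed-1 training splits we see between 1,142 and 1,173 unique feature combinations (about 1,160 on average). Collapsing duplicate rows therefore substantially reduces the number of datapoints.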

In [2]:
# Base directory for this experiment; the per-seed CSVs are read from
# `seed{N}/` subdirectories beneath it.
fp = (
    'data/Warfarin_v2/'
    'rf_balance_proba_white/'
)

In [75]:
# Encoded feature columns; rows sharing the same values in all of these are
# treated as one "feature combination" below. Order matters for grouping output.
cols = (
    ['Age1.2', 'Age3.4', 'Age5.6', 'Age7', 'Age8.9']
    + [f'Height{i}' for i in range(1, 6)]
    + [f'Weight{i}' for i in range(1, 6)]
    + ['Amiodarone..Cordarone.', 'Asian', 'Black.or.African.American', 'White',
       'Enzyme.Inducer', 'Unknown.Cyp2C9', 'Unknown.Race',
       'VKORC1.A.A', 'VKORC1.A.G', 'VKORC1.Missing',
       'X.1..1', 'X.1..3', 'X.2..2', 'X.2..3', 'X.3..3']
)

In [76]:
# Number of distinct feature combinations (rows deduplicated over `cols`) in
# each of the five seed-1 training splits.
unique = []
for seed in [1]:
    for split in range(1, 6):
        train_path = os.path.join(fp, f'seed{seed}', f'data_train_enc_0.33_{split}.csv')
        df = pd.read_csv(train_path)
        # df_dropped holds only the `cols` columns, so deduplicating on all of
        # its columns is the same as deduplicating on `cols`.
        df_dropped = df.loc[:, cols].drop_duplicates()
        unique.append(df_dropped.shape[0])

In [25]:
# Per-split unique counts together with their mean, so the average is computed
# rather than eyeballed.
unique, float(np.mean(unique))

[1173, 1142, 1160, 1165, 1164]

## Aggregate duplicate rows and compute the DR terms for each dataset

In [66]:
def make_dr(row, ml_col, treatment):
    """Sum the DR residual-correction terms of one aggregated group.

    ``row`` provides indexable sequences under ``'y'``, ``'t'``,
    ``'prob_t_pred_tree'`` and ``f'{ml_col}{treatment}'``, with one entry per
    original datapoint (assumed to be equal length). Every position whose
    observed ``t`` equals ``treatment`` contributes ``(y - nu) / mu``, where
    ``nu`` is the ``{ml_col}{treatment}`` entry and ``mu`` is the
    ``prob_t_pred_tree`` entry.

    Only this correction part is returned. The model column itself is not added
    for the non-matching positions. The result is 0 when no position matches.

    NOTE(review): ``prob_t_pred_tree`` is presumably the estimated propensity
    of the observed treatment, and ``{ml_col}{treatment}`` the model's
    predicted outcome under ``treatment``. Confirm this upstream. A zero ``mu``
    at a matching position divides by zero.
    """
    outcomes = row['y']
    observed = row['t']
    propensities = row['prob_t_pred_tree']
    predictions = row[f'{ml_col}{treatment}']

    return sum(
        (outcomes[i] - predictions[i]) / propensities[i]
        for i in range(len(outcomes))
        if treatment == observed[i]
    )

In [None]:
def make_ipw(row, treatment):
    """Sum inverse-propensity-weighted outcomes of one aggregated group.

    ``row`` provides indexable sequences under ``'y'``, ``'t'`` and
    ``'prob_t_pred_tree'``, with one entry per original datapoint (assumed to
    be equal length). Every position whose observed ``t`` equals ``treatment``
    contributes ``y / mu``, with ``mu`` taken from ``prob_t_pred_tree``. The
    result is 0 when no position matches.

    NOTE(review): ``prob_t_pred_tree`` is presumably the estimated propensity
    of the observed treatment. Confirm this upstream. A zero ``mu`` at a
    matching position divides by zero.
    """
    outcomes = row['y']
    observed = row['t']
    propensities = row['prob_t_pred_tree']

    return sum(
        outcomes[i] / propensities[i]
        for i in range(len(outcomes))
        if treatment == observed[i]
    )

In [78]:
# Model output columns, one per (model prefix, treatment):
# 'lr0', 'lr1', 'lr2', 'lrrf0', 'lrrf1', 'lrrf2', 'ml0', 'ml1', 'ml2'.
# make_dr looks these up as f'{prefix}{treatment}'.
#
# The feature columns `cols` are defined once, in the cell that counts unique
# rows earlier in the notebook. They are deliberately not redefined here, so
# the unique-count check and the aggregation always group on the same keys.
ml_cols = [f'{prefix}{treatment}' for prefix in ['lr', 'lrrf', 'ml'] for treatment in range(3)]

In [79]:
# For every dataset variant `r`, seed and split:
# - group the training rows by feature combination (`cols`),
# - compute the per-group DR correction terms,
# - collapse the model columns to per-group sums,
# - write the result next to the input as `*_agg.csv`.
per_row_cols = ['prob_t_pred_tree', 't', 'y']
for r in ['0.33', 'r0.06', 'r0.11']:
    for seed in range(1, 6):
        seed_dir = os.path.join(fp, f'seed{seed}')
        for split in range(1, 6):
            stem = f'data_train_enc_{r}_{split}'
            df = pd.read_csv(os.path.join(seed_dir, stem + '.csv'))
            agg = (
                df[cols + ml_cols + per_row_cols]
                .groupby(cols)
                .agg(list)
                .reset_index()
            )

            # make_dr needs the per-row lists in the model columns, so the DR
            # terms must be computed before those columns are summed below.
            for ml_col in ['ml', 'lr', 'lrrf']:
                for t in range(3):
                    agg[f'DR_{ml_col}{t}'] = agg.apply(
                        lambda row: make_dr(row, ml_col, t), axis=1
                    )

            for ml in ml_cols:
                agg[ml] = [np.sum(cell) for cell in agg[ml]]

            agg = agg.drop(columns=per_row_cols)
            agg.to_csv(os.path.join(seed_dir, stem + '_agg.csv'), index=False)

In [65]:
# Re-group `df` for inspection. When the notebook runs top to bottom, `df` is
# the split loaded in the final iteration of the aggregation loop.
agg = (
    df[cols + ml_cols + ['prob_t_pred_tree', 't', 'y']]
    .groupby(cols)
    .agg(list)
    .reset_index()
)

In [71]:
# IPW sums per group and treatment: y / prob_t_pred_tree summed over rows whose
# observed t matches. They use no outcome model, so the columns are keyed by
# treatment alone and computed with make_ipw.
for t in range(3):
    agg[f'IPW_{t}'] = agg.apply(lambda row: make_ipw(row, t), axis=1)

In [72]:
# Collapse each model column's per-row list into the group total.
for ml in ml_cols:
    agg[ml] = [np.sum(cell) for cell in agg[ml]]

In [73]:
# The per-row inputs are no longer needed once the group-level values exist.
agg = agg.drop(['prob_t_pred_tree', 't', 'y'], axis=1)

In [74]:
# Inspect the aggregated frame: one row per unique feature combination
# (pandas truncates the display of large frames).
agg

Unnamed: 0,Age1.2,Age3.4,Age5.6,Age7,Age8.9,Height1,Height2,Height3,Height4,Height5,...,ml2,DR_ml0,DR_ml1,DR_ml2,DR_lr0,DR_lr1,DR_lr2,DR_lrrf0,DR_lrrf1,DR_lrrf2
0,0,0,0,0,1,0,0,0,0,1,...,0.19,0.000000,0.000000,0.0,-0.036132,0.000000,0.000000,-0.036132,0.000000,0.000000
1,0,0,0,0,1,0,0,0,0,1,...,0.00,0.000000,0.060241,0.0,0.000000,1.586230,0.000000,0.000000,1.291347,0.000000
2,0,0,0,0,1,0,0,0,0,1,...,0.05,-0.029354,0.000000,0.0,-0.167336,0.000000,-0.025412,-0.167336,0.000000,-0.025412
3,0,0,0,0,1,0,0,0,0,1,...,0.04,-0.058708,0.000000,0.0,-0.922434,1.365757,-0.000595,-0.922434,0.991792,-0.000595
4,0,0,0,0,1,0,0,0,0,1,...,0.00,0.645793,0.000000,0.0,0.269873,0.000000,0.000000,0.269873,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1168,1,1,1,1,1,1,1,1,1,1,...,0.00,-0.234834,0.000000,0.0,-0.030603,0.000000,0.000000,-0.030603,0.000000,0.000000
1169,1,1,1,1,1,1,1,1,1,1,...,0.02,-0.117417,0.000000,0.0,-0.026450,0.000000,0.000000,-0.026450,0.000000,0.000000
1170,1,1,1,1,1,1,1,1,1,1,...,0.01,0.000000,0.210843,0.0,0.000000,1.092340,0.000000,0.000000,1.280538,0.000000
1171,1,1,1,1,1,1,1,1,1,1,...,0.00,0.000000,0.000000,0.0,0.000000,0.000000,-0.000235,0.000000,0.000000,-0.000235
