# Setup

In [None]:
spark

In [None]:
%info

In [3]:
from functools import reduce
import pyspark.sql.functions as F
from pyspark.sql import Window
from pyspark.sql import DataFrame

In [4]:
def _get_operator(_str):
    if 0 < len(_str):
        op = ' + '
    else:
        op = ''
    return op

## Setup Snowflake Connection

In [None]:
import snowflake.connector
import getpass

# Snowflake credentials and directories

SNOWFLAKE_CREDS_DICT = {
}

SNOWFLAKE_SESH_DICT = {
}

SNOWFLAKE_WRITE_FROM_SPARK_DICT = {
    'sfURL':SNOWFLAKE_CREDS_DICT['account'] + '.snowflakecomputing.com',
    'sfUser':SNOWFLAKE_CREDS_DICT['user'],
    'sfRole':SNOWFLAKE_CREDS_DICT['role'],
    'sfPassword':SNOWFLAKE_CREDS_DICT['password'],
    'sfDatabase':SNOWFLAKE_SESH_DICT['database'],
    'sfSchema':SNOWFLAKE_SESH_DICT['schema'],
    'sfWarehouse':SNOWFLAKE_SESH_DICT['warehouse'],
    'tracing':'All',
}

SNOWFLAKE_CONNECT_FROM_SPARK_DICT = SNOWFLAKE_CREDS_DICT.copy()
SNOWFLAKE_CONNECT_FROM_SPARK_DICT.update(SNOWFLAKE_SESH_DICT)

# write a Spark dataframe to a NEW (non-temporary) table on Snowflake
def sf_write_spark_df_to_snowflake_new_table(spark_df, table):
    (spark_df.write.format('net.snowflake.spark.snowflake')
     .options(**SNOWFLAKE_WRITE_FROM_SPARK_DICT)
     .option('dbtable',table)
     .mode('overwrite')
     .save())

# Load Parquets From s3

In [6]:
# df_patient = spark.read.parquet('s3a://acic-causality-challenge-2022/parquet/patient')
df_patient_year = spark.read.parquet('s3a://acic-causality-challenge-2022/parquet/patient_year')
df_patient_joined = spark.read.parquet('s3a://acic-causality-challenge-2022/patient_joined')
# df_practice = spark.read.parquet('s3a://acic-causality-challenge-2022/parquet/practice')
df_practice_year = spark.read.parquet('s3a://acic-causality-challenge-2022/parquet/practice_year')
df_practice_joined = spark.read.parquet('s3a://acic-causality-challenge-2022/practice_joined')

# Generate New DGPs
Let's just keep it simple and only use the first pretreatment year (1,2) to compute the new values. This will drop some patients (~20%), but should be fine.

## Setup Base Variables

### Practice

In [7]:
df_practice_base = (df_practice_joined
    .withColumn('minYear', F.min('year').over(Window.partitionBy('dataset_num', 'id_practice')))
    .where((F.col('year') == F.col('minYear')) & F.col('year').isin([1,2]))
    .select(['dataset_num', 'id_practice',
             'n_patients', 'X1', 'X2_A', 'X2_B', 'X2_C', 'X3', 'X4_A', 'X4_B', 'X4_C', 'X5', 'X6', 'X7', 'X8', 'X9',
             'V1_avg', 'V2_avg', 'V3_avg', 'V4_avg', 'V5_A_avg', 'V5_B_avg', 'V5_C_avg']
           )
)

df_practice_base.persist();
print(f'{df_practice_base.count():,}')

1,700,000


In [8]:
print(f"{df_practice_joined.select('dataset_num', 'id_practice').distinct().count():,}")

1,700,000


### Patient

In [9]:
df_patient_base = (df_patient_joined
    .withColumn('minYear', F.min('year').over(Window.partitionBy('dataset_num', 'id_patient')))
    .where((F.col('year') == F.col('minYear')) & F.col('year').isin([1,2]))
    .withColumnRenamed('year', 'year_original')
    .withColumnRenamed('Y', 'Y_original')
    .withColumn('year', F.lit(0))
    .join(df_practice_base, ['dataset_num', 'id_practice'], 'left')
    .select(['dataset_num', 'id_practice', 'id_patient',
             'year_original', 'year', 'Y_original',
             'n_patients', 'X1', 'X2_A', 'X2_B', 'X2_C', 'X3', 'X4_A', 'X4_B', 'X4_C', 'X5', 'X6', 'X7', 'X8', 'X9',
             'V1', 'V2', 'V3', 'V4', 'V5_A', 'V5_B', 'V5_C']
           )
)

df_patient_base.persist();
print(f'{df_patient_base.count():,}')

1,139,150,768


In [10]:
n_distinct_patients_all = df_patient_joined.select('dataset_num', 'id_patient').distinct().count()
n_distinct_patients_pre = df_patient_joined.where(F.col('year').isin([1,2])).select('dataset_num', 'id_patient').distinct().count()

In [11]:
print(f'n_distinct_patients_all = {n_distinct_patients_all:,}')
print(f'n_distinct_patients_pre = {n_distinct_patients_pre:,}')
print()
print(f'Dropping {(n_distinct_patients_all-n_distinct_patients_pre)/n_distinct_patients_all:.2%} of original patients, across all realizations')

n_distinct_patients_all = 1,348,851,109
n_distinct_patients_pre = 1,139,150,768

Dropping 15.55% of original patients, across all realizations


In [12]:
df_patient_base.groupBy('year_original').agg(F.count('*').alias('n')).orderBy('year_original').toPandas()

Unnamed: 0,year_original,n
0,1,1042575381
1,2,96575387


## Simulate Z

$Z \sim \mathrm{Bernoulli}(p)$  
where  
$\mathrm{logit}(p) = f(X, V_{\mathrm{Avg}}) \equiv t$  

$\implies$  

$Z = p < U$  
$p = 1/(1+\exp(-t))$  
$U \sim \mathrm{Uniform}(0,1)$  

### Set Parameters

In [13]:
Z_DGP_params = {
    1 : {
        'intercept': -1.5,
        'X1': -0.1,
        'X4_B': -0.1,
        'X4_C': 0.3,
        'X5': 0.2,
        'X7': 0.2,
        'X8': 0.5,
        'X4_B*X8': 0.2,
        'X4_C*X8': -0.6,
    },
    2 : {
        'intercept': -1.3,
        'X1': -0.1,
        'X4_B': -0.2,
        'X4_C': 0.4,
        'X5': 0.2,
        'X7': 0.2,
        'X8': 0.5,
        'X4_B*X8': 0.4,
        'X4_C*X8': -0.6,
    },
}

### Run

In [14]:
df_DFP_Z = df_practice_base.alias('df_DFP_Z')

Z_join_cols = ['dataset_num', 'id_practice']

iseed = 42

Z_cols = []
for iZ_DGP,Z_DGP_param in Z_DGP_params.items():
    # build expr to create t
    t_expr = ''
    for col,weight in Z_DGP_param.items():
        op = _get_operator(t_expr)
        if col == 'intercept':
            t_expr = f'{op}{weight}'
        elif weight != 0:
            t_expr = f'{t_expr}{op}{weight}*{col}'

    print(f'For iZ_DGP = {iZ_DGP}, t = {t_expr}')

    Z_col = f'Z_DGP_{iZ_DGP}'
    df_DFP_Z = (df_DFP_Z
        .withColumn('t', F.expr(t_expr))
        .withColumn(Z_col, F.when(1./(1.+F.exp(-F.col('t'))) < F.rand(seed=iseed), 1).otherwise(0))
        .drop('t')
    )
    Z_cols.append(Z_col)

    iseed += 1

    dfp_Z_counts = df_DFP_Z.groupBy(Z_col).agg(F.count('*').alias('n_patients')).orderBy(Z_col).toPandas()
    dfp_Z_counts['percent'] = 100.*dfp_Z_counts['n_patients'] / dfp_Z_counts['n_patients'].sum()
    print(dfp_Z_counts)

df_DFP_Z_complete = df_DFP_Z.select(Z_join_cols+Z_cols)

df_DFP_Z_complete.persist();
print(f'{df_DFP_Z_complete.count():,}')

For iZ_DGP = 1, t = -1.5 + -0.1*X1 + -0.1*X4_B + 0.3*X4_C + 0.2*X5 + 0.2*X7 + 0.5*X8 + 0.2*X4_B*X8 + -0.6*X4_C*X8
   Z_DGP_1  n_patients    percent
0        0      993830  58.460588
1        1      706170  41.539412
For iZ_DGP = 2, t = -1.3 + -0.1*X1 + -0.2*X4_B + 0.4*X4_C + 0.2*X5 + 0.2*X7 + 0.5*X8 + 0.4*X4_B*X8 + -0.6*X4_C*X8
   Z_DGP_2  n_patients    percent
0        0     1061194  62.423176
1        1      638806  37.576824
1,700,000


In [15]:
df_DFP_Z_complete.limit(5).toPandas()

Unnamed: 0,dataset_num,id_practice,Z_DGP_1,Z_DGP_2
0,3,22,0,0
1,3,89,0,1
2,4,457,1,0
3,6,117,0,0
4,10,264,1,0


## Simulate Y

$Y_{1} = Y_{min(1,2)}^{\mathrm{Original}} + u \log(U_{i})$  
then  
$Y_{i} = T\,Y_{i-1} + (1-T) \, y_{i} + u \log(U_{i})$  
for $i=2,3,4$  
$T \sim \mathrm{Bernoulli}(p_{T})$  

$y_{i} = -\exp(\eta_{i}) \log(U_{i})$  
$U \sim \mathrm{Uniform}(0,1)$  
$\eta_{i} \equiv f(i,X,V,\mathrm{post})$  

$\mathrm{post} = 1$ if $Z = 1$ and $i = 3, 4$, and $0$ otherwise  
Use all $Z$'s from prior step  

### Set Parameters

In [16]:
Y_DGP_params = {
    1 : {
        'u_noise': -20.,
        'p_T': 0.5,
        'Z': [0, 0, -0.05, -0.05],
        'eta': {
            'intercept': [6.6, 6.6, 6.8, 6.8],
            'X8': [-0.2, -0.2, -0.2, -0.2],
            'V4': [0.3, 0.3, 0.5, 0.5],
        },
    },
    2 : {
        'u_noise': -20.,
        'p_T': 0.5,
        'Z': [0, 0, -0.1, -0.1],
        'eta': {
            'intercept': [6.6, 6.6, 6.8, 6.8],
            'X8': [-0.2, -0.2, -0.2, -0.2],
            'V4': [0.3, 0.3, 0.5, 0.5],
        },
    },
}

### Run

In [17]:
df_DFP_Y = df_patient_base.join(df_DFP_Z_complete, Z_join_cols, 'left')

Y_join_cols = ['dataset_num', 'id_practice', 'id_patient', 'year'] # id_practice is not really required, but nice to have
Y_features_X_V = ['n_patients', 'X1', 'X2_A', 'X2_B', 'X2_C', 'X3', 'X4_A', 'X4_B', 'X4_C', 'X5', 'X6', 'X7', 'X8', 'X9', 'V1', 'V2', 'V3', 'V4', 'V5_A', 'V5_B', 'V5_C']

iseed - 1042

Y_results = {}
for iY_DGP,Y_DGP_param in Y_DGP_params.items():
    u_noise = Y_DGP_param.get('u_noise', 0.)

    for iZ_col,Z_col in enumerate(Z_cols):
        Y_col = f'Y_DGP_{iY_DGP}_with_{Z_col}'

        df_lagged = df_DFP_Y.withColumn(Y_col, F.col('Y_original'))

        Z_col_with_post = [Z_col, f'post_{Z_col}']
        Y_results_cols = [Y_col, f'{Y_col}_counterfactual', f'eta_{Y_col}']
        cols_of_df_year = Y_join_cols+Z_col_with_post+Y_results_cols+Y_features_X_V

        df_years = []
        for year in range(1,5):

            # build expr to create eta
            eta_expr_base = ''
            for col,weight_array in Y_DGP_param.get('eta', {}).items():
                op = _get_operator(eta_expr_base)
                if col == 'intercept':
                    eta_expr_base = f'{op}{weight_array[year-1]}'
                elif weight != 0:
                        eta_expr_base = f'{eta_expr_base}{op}{weight_array[year-1]}*{col}'

            # add the Z component
            Z_weight = Y_DGP_param.get('Z', [0,0,0,0])[year-1]
            op = _get_operator(eta_expr_base)

            eta_expr = f'{eta_expr_base}{op}{Z_weight}*post_{Z_col}'
            eta_expr_counterfactual = f'{eta_expr_base}{op}{Z_weight}*post_{Z_col}_counterfactual'

            print(f'For Y_col = {Y_col}, year = {year}, u_noise = {u_noise}, eta = {eta_expr}')

            df_year = (df_lagged
                .where(F.col('year') == year-1)
                .withColumnRenamed(Y_col, 'Y_lag_1')
                .withColumnRenamed(f'{Y_col}_counterfactual', 'Y_lag_1_counterfactual')
                .withColumn('year', F.lit(year))
                .withColumn(f'post_{Z_col}', F.col(Z_col) if year in [3,4] else F.lit(0) )
                .withColumn(f'post_{Z_col}_counterfactual',
                    F.when((F.col('year').isin([3,4])) & (F.col(f'post_{Z_col}') == 0), F.lit(1))
                    .when(F.col('year').isin([3,4]) & (F.col(f'post_{Z_col}') == 1), F.lit(0))
                    .otherwise(F.lit(0))
                )
                .withColumn(f'eta_{Y_col}', F.expr(eta_expr))
                .withColumn(f'eta_{Y_col}_counterfactual', F.expr(eta_expr_counterfactual))
            )

            # compute Y columns
            # note that Y and Y_counterfactual can have different values, even in year 1 and 2, due to their different random noise seeds
            if year == 1:
                df_year = (df_year
                    .withColumn(Y_col, F.col('Y_lag_1') + u_noise*F.log(F.rand(seed=iseed)))
                    .withColumn(f'{Y_col}_counterfactual', F.col('Y_lag_1') + u_noise*F.log(F.rand(seed=iseed+1)))
                )
                iseed += 2

            else:
                df_year = (df_year
                    .withColumn('T', F.when( Y_DGP_param.get('p_T', 0.5) < F.rand(seed=iseed), 1).otherwise(0))
                    .withColumn('T_counterfactual', F.when( Y_DGP_param.get('p_T', 0.5) < F.rand(seed=iseed+1), 1).otherwise(0))

                    .withColumn('y', -F.exp(F.col(f'eta_{Y_col}'))*F.log(F.rand(seed=iseed+2)) )
                    .withColumn('y_counterfactual', -F.exp(F.col(f'eta_{Y_col}_counterfactual'))*F.log(F.rand(seed=iseed+3)) )

                    .withColumn(Y_col, F.col('T')*F.col('Y_lag_1') + (1-F.col('T'))*F.col('y') + u_noise*F.log(F.rand(seed=iseed+4)))
                    .withColumn(f'{Y_col}_counterfactual', F.col('T_counterfactual')*F.col('Y_lag_1_counterfactual') + (1-F.col('T_counterfactual'))*F.col('y_counterfactual') + u_noise*F.log(F.rand(seed=iseed+5)))
                )
                iseed += 6

            # done with this year, save results
            df_years.append(df_year.select(cols_of_df_year))
            df_lagged = reduce(DataFrame.unionAll, df_years)

        # done with this Y, save results
        Y_results[Y_col] = {'Z_col_with_post': Z_col_with_post,
                            'Y_results_cols': Y_results_cols,
                           'df': df_lagged.select([col for col in cols_of_df_year if col not in Y_features_X_V]),
                          }

# combine results from multiple Y_results into one df
# need these steps to get the columns right due to duplicate Y_join_cols, Z_col_with_post
df_DFP_Y_complete = Y_results[list(Y_results.keys())[0]]['df'].select(Y_join_cols)

Z_col_with_post_flat = []
Y_results_cols_flat = []
for Y_col,_dict in Y_results.items():
    Z_col_with_post = _dict['Z_col_with_post']
    if set(Z_col_with_post).issubset(set(df_DFP_Y_complete.columns)):
        # we already have these Z cols, so drop them
        Z_col_with_post = []
    Z_col_with_post_flat += Z_col_with_post

    Y_results_cols = _dict['Y_results_cols']
    Y_results_cols_flat += Y_results_cols

    df_DFP_Y_complete = df_DFP_Y_complete.join(_dict['df'].select(Y_join_cols+Z_col_with_post+Y_results_cols), Y_join_cols, 'left')

df_DFP_Y_complete = df_DFP_Y_complete.select(Y_join_cols+Z_col_with_post_flat+Y_results_cols_flat)

df_DFP_Y_complete.persist();
print(f'{df_DFP_Y_complete.count():,}')


For Y_col = Y_DGP_1_with_Z_DGP_1, year = 1, u_noise = -20.0, eta = 6.6 + -0.2*X8 + 0.3*V4 + 0*post_Z_DGP_1
For Y_col = Y_DGP_1_with_Z_DGP_1, year = 2, u_noise = -20.0, eta = 6.6 + -0.2*X8 + 0.3*V4 + 0*post_Z_DGP_1
For Y_col = Y_DGP_1_with_Z_DGP_1, year = 3, u_noise = -20.0, eta = 6.8 + -0.2*X8 + 0.5*V4 + -0.05*post_Z_DGP_1
For Y_col = Y_DGP_1_with_Z_DGP_1, year = 4, u_noise = -20.0, eta = 6.8 + -0.2*X8 + 0.5*V4 + -0.05*post_Z_DGP_1
For Y_col = Y_DGP_1_with_Z_DGP_2, year = 1, u_noise = -20.0, eta = 6.6 + -0.2*X8 + 0.3*V4 + 0*post_Z_DGP_2
For Y_col = Y_DGP_1_with_Z_DGP_2, year = 2, u_noise = -20.0, eta = 6.6 + -0.2*X8 + 0.3*V4 + 0*post_Z_DGP_2
For Y_col = Y_DGP_1_with_Z_DGP_2, year = 3, u_noise = -20.0, eta = 6.8 + -0.2*X8 + 0.5*V4 + -0.05*post_Z_DGP_2
For Y_col = Y_DGP_1_with_Z_DGP_2, year = 4, u_noise = -20.0, eta = 6.8 + -0.2*X8 + 0.5*V4 + -0.05*post_Z_DGP_2
For Y_col = Y_DGP_2_with_Z_DGP_1, year = 1, u_noise = -20.0, eta = 6.6 + -0.2*X8 + 0.3*V4 + 0*post_Z_DGP_1
For Y_col = Y_DGP_2_w

In [18]:
df_DFP_Y_complete.where(F.col('year') == 1).limit(5).toPandas()

Unnamed: 0,dataset_num,id_practice,id_patient,year,Z_DGP_1,post_Z_DGP_1,Z_DGP_2,post_Z_DGP_2,Y_DGP_1_with_Z_DGP_1,Y_DGP_1_with_Z_DGP_1_counterfactual,eta_Y_DGP_1_with_Z_DGP_1,Y_DGP_1_with_Z_DGP_2,Y_DGP_1_with_Z_DGP_2_counterfactual,eta_Y_DGP_1_with_Z_DGP_2,Y_DGP_2_with_Z_DGP_1,Y_DGP_2_with_Z_DGP_1_counterfactual,eta_Y_DGP_2_with_Z_DGP_1,Y_DGP_2_with_Z_DGP_2,Y_DGP_2_with_Z_DGP_2_counterfactual,eta_Y_DGP_2_with_Z_DGP_2
0,1,2,363,1,1,0,0,0,551.420774,500.990259,6.7766,504.830658,506.944063,6.7766,509.490303,521.381308,6.7766,508.807436,524.141194,6.7766
1,1,6,3009,1,1,0,0,0,1766.877638,1766.373945,6.4648,1763.159127,1810.959684,6.4648,1833.61048,1774.728645,6.4648,1774.277601,1768.725348,6.4648
2,1,7,4038,1,1,0,0,0,201.993679,174.952658,6.0723,185.112281,178.443863,6.0723,179.53033,176.826448,6.0723,179.761145,187.268608,6.0723
3,1,8,4300,1,1,0,0,0,71.271274,82.434823,6.7185,67.911397,84.312916,6.7185,82.959929,82.484712,6.7185,67.79099,75.240347,6.7185
4,1,8,4540,1,1,0,0,0,388.489843,415.920122,6.9855,407.402459,393.457063,6.9855,365.182255,385.593401,6.9855,417.690311,358.786817,6.9855


In [19]:
df_DFP_Y_complete.where((F.col('year') == 3) & (F.col('Z_DGP_1') == 0)).limit(5).toPandas()

Unnamed: 0,dataset_num,id_practice,id_patient,year,Z_DGP_1,post_Z_DGP_1,Z_DGP_2,post_Z_DGP_2,Y_DGP_1_with_Z_DGP_1,Y_DGP_1_with_Z_DGP_1_counterfactual,eta_Y_DGP_1_with_Z_DGP_1,Y_DGP_1_with_Z_DGP_2,Y_DGP_1_with_Z_DGP_2_counterfactual,eta_Y_DGP_1_with_Z_DGP_2,Y_DGP_2_with_Z_DGP_1,Y_DGP_2_with_Z_DGP_1_counterfactual,eta_Y_DGP_2_with_Z_DGP_1,Y_DGP_2_with_Z_DGP_2,Y_DGP_2_with_Z_DGP_2_counterfactual,eta_Y_DGP_2_with_Z_DGP_2
0,1,15,7976,3,0,0,1,1,286.892119,336.072983,6.1305,471.822968,1021.87341,6.0805,609.037232,555.601966,6.1305,589.201346,240.426331,6.0305
1,1,15,8778,3,0,0,1,1,474.2303,403.763624,6.264,1280.007226,237.180569,6.214,574.417118,93.991173,6.264,277.648564,65.609711,6.164
2,1,17,330897,3,0,0,0,0,1035.0191,50.945003,6.9109,162.243129,250.302923,6.9109,520.193596,305.077055,6.9109,450.106231,278.48477,6.9109
3,1,18,12272,3,0,0,0,0,1862.115972,1331.013941,6.5819,319.623643,233.02379,6.5819,654.749362,441.042663,6.5819,955.829184,97.757503,6.5819
4,1,21,14122,3,0,0,0,0,782.824724,815.632554,6.3546,790.812844,816.637356,6.3546,467.590869,811.395452,6.3546,1824.192779,248.358486,6.3546


In [20]:
df_DFP_Y_complete.where((F.col('year') == 3) & (F.col('Z_DGP_1') == 1)).limit(5).toPandas()

Unnamed: 0,dataset_num,id_practice,id_patient,year,Z_DGP_1,post_Z_DGP_1,Z_DGP_2,post_Z_DGP_2,Y_DGP_1_with_Z_DGP_1,Y_DGP_1_with_Z_DGP_1_counterfactual,eta_Y_DGP_1_with_Z_DGP_1,Y_DGP_1_with_Z_DGP_2,Y_DGP_1_with_Z_DGP_2_counterfactual,eta_Y_DGP_1_with_Z_DGP_2,Y_DGP_2_with_Z_DGP_1,Y_DGP_2_with_Z_DGP_1_counterfactual,eta_Y_DGP_2_with_Z_DGP_1,Y_DGP_2_with_Z_DGP_2,Y_DGP_2_with_Z_DGP_2_counterfactual,eta_Y_DGP_2_with_Z_DGP_2
0,1,3,729,3,1,1,0,0,6734.282128,665.419107,8.1131,697.071258,2725.85746,8.1631,1053.403617,5198.190114,8.0631,841.901471,1335.778518,8.1631
1,1,3,987,3,1,1,0,0,1218.11507,19691.061162,6.9561,154.355087,1580.283375,7.0061,296.695108,25.85451,6.9061,226.457565,19720.025051,7.0061
2,1,3,1559,3,1,1,0,0,27.121543,1465.200326,7.4456,1380.953642,296.939817,7.4956,1375.178289,6577.506579,7.3956,1426.018128,1640.941526,7.4956
3,1,7,3200,3,1,1,0,0,171.508439,553.336109,6.2452,746.689804,145.263438,6.2952,319.163689,477.806867,6.1952,66.092184,129.603149,6.2952
4,1,7,3521,3,1,1,0,0,630.167057,125.833437,7.1352,2239.442812,391.982381,7.1852,179.315206,1661.515433,7.0852,1724.778374,186.438354,7.1852


In [21]:
df_DFP_Y_complete.where((F.col('year') == 4) & (F.col('Z_DGP_1') == 0)).limit(5).toPandas()

Unnamed: 0,dataset_num,id_practice,id_patient,year,Z_DGP_1,post_Z_DGP_1,Z_DGP_2,post_Z_DGP_2,Y_DGP_1_with_Z_DGP_1,Y_DGP_1_with_Z_DGP_1_counterfactual,eta_Y_DGP_1_with_Z_DGP_1,Y_DGP_1_with_Z_DGP_2,Y_DGP_1_with_Z_DGP_2_counterfactual,eta_Y_DGP_1_with_Z_DGP_2,Y_DGP_2_with_Z_DGP_1,Y_DGP_2_with_Z_DGP_1_counterfactual,eta_Y_DGP_2_with_Z_DGP_1,Y_DGP_2_with_Z_DGP_2,Y_DGP_2_with_Z_DGP_2_counterfactual,eta_Y_DGP_2_with_Z_DGP_2
0,1,10,5432,4,0,0,0,0,131.22052,412.670054,6.1846,164.99301,1456.924866,6.1846,368.100729,876.052051,6.1846,233.517006,45.064861,6.1846
1,1,11,5872,4,0,0,0,0,274.658402,649.867442,7.1726,1050.362231,429.07648,7.1726,2305.795895,727.494708,7.1726,449.859276,1471.112821,7.1726
2,1,15,7249,4,0,0,1,1,1011.462463,424.178811,6.264,207.042178,50.082428,6.214,713.636031,270.32242,6.264,570.482809,1224.135844,6.164
3,1,15,7423,4,0,0,1,1,488.257196,325.464887,6.1305,140.867541,567.482994,6.0805,192.17777,2190.32297,6.1305,200.274585,431.046114,6.0305
4,1,15,330596,4,0,0,1,1,1399.329717,1348.014087,6.175,415.748896,796.739237,6.125,618.012513,23.315425,6.175,226.10064,661.8221,6.075


In [22]:
df_DFP_Y_complete.where((F.col('year') == 4) & (F.col('Z_DGP_1') == 1)).limit(5).toPandas()

Unnamed: 0,dataset_num,id_practice,id_patient,year,Z_DGP_1,post_Z_DGP_1,Z_DGP_2,post_Z_DGP_2,Y_DGP_1_with_Z_DGP_1,Y_DGP_1_with_Z_DGP_1_counterfactual,eta_Y_DGP_1_with_Z_DGP_1,Y_DGP_1_with_Z_DGP_2,Y_DGP_1_with_Z_DGP_2_counterfactual,eta_Y_DGP_1_with_Z_DGP_2,Y_DGP_2_with_Z_DGP_1,Y_DGP_2_with_Z_DGP_1_counterfactual,eta_Y_DGP_2_with_Z_DGP_1,Y_DGP_2_with_Z_DGP_2,Y_DGP_2_with_Z_DGP_2_counterfactual,eta_Y_DGP_2_with_Z_DGP_2
0,1,3,1445,4,1,1,0,0,4069.879631,503.807719,7.1341,533.718722,222.494855,7.1841,1434.066188,1665.788817,7.0841,481.426462,1627.616952,7.1841
1,1,4,1944,4,1,1,0,0,1851.233526,1486.973123,7.3375,4034.555178,1683.808316,7.3875,2001.040926,172.309454,7.2875,331.365757,263.919768,7.3875
2,1,4,2048,4,1,1,0,0,1230.49778,491.845558,6.9815,1023.884186,557.064983,7.0315,362.125098,4445.227304,6.9315,347.459505,1613.602665,7.0315
3,1,6,2830,4,1,1,0,0,116.275164,1368.8131,6.7356,512.320221,564.947905,6.7856,607.72407,360.632933,6.6856,1162.679927,645.173248,6.7856
4,1,7,3622,4,1,1,0,0,1014.352553,666.097118,6.2007,366.435451,145.510554,6.2507,605.841584,47.080749,6.1507,815.435138,401.135091,6.2507


# Write to Snowflake

In [23]:
sf_write_spark_df_to_snowflake_new_table(df_DFP_Y_complete, 'DFP_Y')

# Write to Parquet

In [34]:
df_DFP_Y_complete.write.parquet('s3a://acic-causality-challenge-2022/DGP/DFP_Y', mode='overwrite')

# Convert to Original Format [DEPRECATED]

## Put it all together

In [24]:
# df_practice_year_with_DGPs = df_practice_year.join(df_DFP_Z_complete, Z_join_cols, 'left')
# df_practice_year_with_DGPs.persist();
# df_practice_year_with_DGPs.count()

In [25]:
# df_practice_year_with_DGPs.printSchema()

In [26]:
# df_patient_year_with_DGPs = df_patient_year.join(df_DFP_Y_complete, Y_join_cols, 'left')
# df_patient_year_with_DGPs.persist();
# df_patient_year_with_DGPs.count()

In [27]:
# df_patient_year_with_DGPs.printSchema()

Some of the patients in df_patient_year were not simulated, as they didn't have any pretreatment years. Similarly, some of our simulated patient's years didn't join to df_patient_year, as the patient could be missing year 1 and we used year 2 to start the simulation instead.

In [28]:
# df_patient_year.count()

In [29]:
# df_DFP_Y_complete.count()

In [30]:
# df_patient_year_with_DGPs.count()

## Write to Parquet

In [31]:
# df_practice_year_with_DGPs.write.parquet('s3a://acic-causality-challenge-2022/DGP/practice_year_with_DGPs', mode='overwrite')

In [32]:
# df_patient_year_with_DGPs.write.parquet('s3a://acic-causality-challenge-2022/DGP/patient_year_with_DGPs', mode='overwrite')

## Write to CSV

In [33]:
# # has ugly csv names, but only one file / dir per dataset_num
# df_practice_year_with_DGPs.repartition(3400, 'dataset_num').write.partitionBy('dataset_num').csv('s3a://acic-causality-challenge-2022/DGP/csv/', mode='overwrite', header=True)

# # Would like to write to efs where it's easier to rename files, but there is an error
# df_practice_year_with_DGPs.repartition(3400, 'dataset_num').write.partitionBy('dataset_num').csv('/efs/mepland/acic_DGP_data/csv/practice_year', mode='overwrite', header=True)