In [0]:
import pyspark

In [0]:
# --- Transaction source -----------------------------------------------------------
# Table read by the notebook; it must contain TRANSACTION_ID_COLUMN, every column in
# LHS_COLUMNS_LIST and RHS_COLUMN (all are referenced by the count queries below).
TRANSACTION_TABLE = 'tran'
TRANSACTION_ID_COLUMN = 'tid'

# --- Rule sides ---------------------------------------------------------------------
# Columns that together form the rule antecedent (LHS).
# Two-column example: LHS_COLUMNS_LIST = ['diagcode', 'servprov']
LHS_COLUMNS_LIST = ['diagcode']
# Column that forms the rule consequent (RHS).
RHS_COLUMN = 'servcode'

# --- Description lookups ------------------------------------------------------------
# Parallel lists with one entry per LHS column (same order as LHS_COLUMNS_LIST):
# the lookup table, its key column (joined to the matching LHS column), and the
# description column that is attached to the final output.
LHS_TABLES = ['diag']
LHS_TABLES_JOIN_COLUMNS = ['dimDiagnosisID']
LHS_TABLES_DESC_COLUMNS = ['DiagnosisShortDesc']
# Matching lookups for the two-column LHS example above:
# LHS_TABLES = ['diag', 'provcleaned']
# LHS_TABLES_JOIN_COLUMNS = ['dimDiagnosisID', 'dimProviderID']
# LHS_TABLES_DESC_COLUMNS = ['DiagnosisShortDesc', 'ProviderName']

# Lookup for the RHS description: RHS_TABLE_JOIN_COLUMN is matched to RHS_COLUMN.
RHS_TABLE = 'serv'
RHS_TABLE_JOIN_COLUMN = 'dimServiceCodeID'
RHS_TABLE_DESC_COLUMN = 'ServiceCodeShortDesc'

# Fail fast on an inconsistent configuration instead of deep inside a later SQL join.
assert LHS_COLUMNS_LIST, 'LHS_COLUMNS_LIST must contain at least one column'
assert (len(LHS_TABLES) == len(LHS_TABLES_JOIN_COLUMNS)
        == len(LHS_TABLES_DESC_COLUMNS) == len(LHS_COLUMNS_LIST)), (
    'LHS_TABLES, LHS_TABLES_JOIN_COLUMNS and LHS_TABLES_DESC_COLUMNS '
    'must have one entry per LHS column')

In [0]:
# Load the transaction table and drop exact duplicate rows. Uses the SparkSession
# (`spark`) like every other cell instead of the deprecated SQLContext.
# Every downstream metric uses count(distinct tid), so this dedup reduces data volume
# without changing any computed count.
transactions_df = spark.table(TRANSACTION_TABLE).distinct()
# Exposed under this name for the SQL count queries in the following cells.
transactions_df.createOrReplaceTempView('transactions_df_sql')

In [0]:
# N: number of distinct transactions overall (single-row result). It is the
# denominator of support and expected confidence in the lift query.
tid_count_total = spark.sql(
    f'select count(distinct {TRANSACTION_ID_COLUMN}) as tid_count_total '
    'from transactions_df_sql'
)

In [0]:
# n(LHS): number of distinct transactions per combination of LHS column values.
# Works for any number of LHS columns; an empty list is a configuration error and
# raises instead of leaving tid_count_lhs undefined (or stale from an earlier run).
if not LHS_COLUMNS_LIST:
    raise ValueError('LHS_COLUMNS_LIST must contain at least one column')
lhs_columns_sql = ', '.join(LHS_COLUMNS_LIST)
tid_count_lhs = spark.sql(
    f'select {lhs_columns_sql}, count(distinct {TRANSACTION_ID_COLUMN}) as tid_count_lhs '
    'from transactions_df_sql '
    f'group by {lhs_columns_sql}'
)

In [0]:
# n(RHS): number of distinct transactions per RHS value.
tid_count_rhs = spark.sql(
    f'select {RHS_COLUMN}, count(distinct {TRANSACTION_ID_COLUMN}) as tid_count_rhs '
    'from transactions_df_sql '
    f'group by {RHS_COLUMN}'
)

In [0]:
# n(LHS, RHS): number of distinct transactions per (LHS combination, RHS value) pair.
# Works for any number of LHS columns; an empty list raises instead of leaving
# tid_count_lhs_rhs undefined (or stale from an earlier run).
if not LHS_COLUMNS_LIST:
    raise ValueError('LHS_COLUMNS_LIST must contain at least one column')
lhs_columns_sql = ', '.join(LHS_COLUMNS_LIST)
tid_count_lhs_rhs = spark.sql(
    f'select {lhs_columns_sql}, {RHS_COLUMN}, '
    f'count(distinct {TRANSACTION_ID_COLUMN}) as tid_count_lhs_rhs '
    'from transactions_df_sql '
    f'group by {lhs_columns_sql}, {RHS_COLUMN}'
)

In [0]:
# Register every count table as a temp view so the lift query can join them in SQL.
count_views = [
    ('tid_count_total_sql', tid_count_total),
    ('tid_count_lhs_sql', tid_count_lhs),
    ('tid_count_rhs_sql', tid_count_rhs),
    ('tid_count_lhs_rhs_sql', tid_count_lhs_rhs),
]
for view_name, count_df in count_views:
    count_df.createOrReplaceTempView(view_name)

In [0]:
# Association-rule metrics for every (LHS combination, RHS value) pair that occurs
# together in at least one transaction. With n(.) the distinct-transaction counts
# from the cells above and N = tid_count_total:
#   support             = n(LHS, RHS) / N
#   confidence          = n(LHS, RHS) / n(LHS)
#   expected_confidence = n(RHS) / N
#   lift                = n(LHS, RHS) * N / (n(LHS) * n(RHS))
# The pair counts drive the query and every join has an explicit condition. The
# single-row total is attached with an explicit CROSS JOIN, so the query does not
# rely on implicit cartesian products.
if not LHS_COLUMNS_LIST:
    raise ValueError('LHS_COLUMNS_LIST must contain at least one column')
lhs_select_sql = ', '.join(f'l.{col}' for col in LHS_COLUMNS_LIST)
lhs_join_sql = ' and '.join(f'lr.{col} = l.{col}' for col in LHS_COLUMNS_LIST)
lift_df = spark.sql(
    f'select {lhs_select_sql}, r.{RHS_COLUMN}, '
    'lr.tid_count_lhs_rhs / t.tid_count_total as support, '
    'lr.tid_count_lhs_rhs / l.tid_count_lhs as confidence, '
    'r.tid_count_rhs / t.tid_count_total as expected_confidence, '
    '(lr.tid_count_lhs_rhs * t.tid_count_total) / (l.tid_count_lhs * r.tid_count_rhs) as lift '
    'from tid_count_lhs_rhs_sql as lr '
    f'join tid_count_lhs_sql as l on {lhs_join_sql} '
    f'join tid_count_rhs_sql as r on lr.{RHS_COLUMN} = r.{RHS_COLUMN} '
    'cross join tid_count_total_sql as t'
)

In [0]:
# Register the rule metrics and the RHS description lookup for the final join.
# Uses the SparkSession (`spark`) instead of the deprecated SQLContext.
lift_df.createOrReplaceTempView('lift_df_sql')
RHS_df = spark.table(RHS_TABLE)
RHS_df.createOrReplaceTempView('RHS_df_sql')

In [0]:
# Peek at the first rules (codes only; descriptions are attached further down).
lift_preview = spark.sql('select * from lift_df_sql')
lift_preview.show(n=4)

+--------+--------+--------------------+--------------------+--------------------+------------------+
|diagcode|servcode|             support|          confidence| expected_confidence|              lift|
+--------+--------+--------------------+--------------------+--------------------+------------------+
|   43753|    3923|3.274322760908952E-4|                 1.0|  0.7062277618912488|1.4159737891385653|
|   32981|     715|2.182881840605968E-5|0.009615384615384616|6.548645521817905E-5| 146.8301282051282|
|   36062|    1527|2.182881840605968E-5|                 0.2|8.731527362423872E-5|           2290.55|
|   37369|    1210|1.746305472484774...| 0.19047619047619047|0.001287900285957...|147.89669087974173|
+--------+--------+--------------------+--------------------+--------------------+------------------+
only showing top 4 rows



In [0]:
# Attach human-readable descriptions to each rule. For LHS column i, the column
# LHS_TABLES_DESC_COLUMNS[i] is looked up in LHS_TABLES[i], matched on
# LHS_TABLES_JOIN_COLUMNS[i]. The RHS description comes from RHS_df_sql.
# All joins are inner joins, so rules whose codes are missing from a lookup table
# are dropped. Works for any number of LHS columns.
if not LHS_COLUMNS_LIST:
    raise ValueError('LHS_COLUMNS_LIST must contain at least one column')
if not (len(LHS_TABLES) == len(LHS_TABLES_JOIN_COLUMNS)
        == len(LHS_TABLES_DESC_COLUMNS) == len(LHS_COLUMNS_LIST)):
    raise ValueError('LHS_TABLES, LHS_TABLES_JOIN_COLUMNS and LHS_TABLES_DESC_COLUMNS '
                     'must have one entry per LHS column')

LHS_dfs = [spark.table(table_name) for table_name in LHS_TABLES]
desc_select = ['lft.*']
desc_joins = []
for i, (lhs_df, join_col, desc_col, lhs_col) in enumerate(
        zip(LHS_dfs, LHS_TABLES_JOIN_COLUMNS, LHS_TABLES_DESC_COLUMNS, LHS_COLUMNS_LIST)):
    # View names/aliases follow the original LHS_<i>_df_sql / lhs<i>df pattern.
    lhs_df.createOrReplaceTempView(f'LHS_{i}_df_sql')
    desc_select.append(f'lhs{i}df.{desc_col}')
    desc_joins.append(f'join LHS_{i}_df_sql lhs{i}df on lhs{i}df.{join_col} = lft.{lhs_col}')
desc_select.append(f'rhsdf.{RHS_TABLE_DESC_COLUMN}')
desc_joins.append(f'join RHS_df_sql rhsdf on rhsdf.{RHS_TABLE_JOIN_COLUMN} = lft.{RHS_COLUMN}')

res_df = spark.sql(
    f"select {', '.join(desc_select)} from lift_df_sql lft {' '.join(desc_joins)}"
)

In [0]:
# Display the first 20 described rules (20 is Spark's default row count for show()).
res_df.show(n=20)

+--------+--------+--------------------+--------------------+--------------------+------------------+--------------------+--------------------+
|diagcode|servcode|             support|          confidence| expected_confidence|              lift|  DiagnosisShortDesc|ServiceCodeShortDesc|
+--------+--------+--------------------+--------------------+--------------------+------------------+--------------------+--------------------+
|   43753|    3923|3.274322760908952E-4|                 1.0|  0.7062277618912488|1.4159737891385653|         PLEURODYNIA|ROUTINE VENIPUNCTURE|
|   32981|     715|2.182881840605968E-5|0.009615384615384616|6.548645521817905E-5| 146.8301282051282|   ACTINIC KERATOSIS|SHAVE SKIN LESION...|
|   36062|    1527|2.182881840605968E-5|                 0.2|8.731527362423872E-5|           2290.55|OTHER SPONDYLOSIS...|ARTHRD ANT INTERB...|
|   37369|    1210|1.746305472484774...| 0.19047619047619047|0.001287900285957...|147.89669087974173|             MYALGIA|INJ TRIGGER PO

In [0]:
# Export the described rules with a header row. Spark writes a directory named
# associations.csv containing part files. The default save mode errors if that path
# already exists, so 'overwrite' is used to keep Restart & Run All re-runnable.
res_df.write.mode('overwrite').option('header', True).csv('associations.csv')