In [0]:
from functools import reduce
from pyspark.sql import Row
from pyspark import SparkConf, SparkContext, SQLContext
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, FloatType, DoubleType, StructType, StructField
from pyspark.sql.functions import udf, array
from pyspark.sql import functions as F
from pyspark.sql.functions import explode, col, lit, udf, mean as _mean, stddev as _stddev, log, log10, isnan, when, count, sqrt, greatest, least

from pyspark.mllib.tree import DecisionTree
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression

from pyspark.ml.feature import VectorAssembler, StandardScaler
import time

In [0]:
# File location and type
train_features_location = "/FileStore/tables/train_features.csv"
train_scored_location = "/FileStore/tables/train_targets_scored-1.csv"

# CSV options
infer_schema = "true"
first_row_is_header = "true"
delimiter = ","
file_type = "csv"


# The applied options are for CSV files. For other file types, these will be ignored.
features_df = spark.read.format(file_type) \
  .option("inferSchema", infer_schema) \
  .option("header", first_row_is_header) \
  .option("sep", delimiter) \
  .load(train_features_location)

labels_df = spark.read.format(file_type) \
  .option("inferSchema", infer_schema) \
  .option("header", first_row_is_header) \
  .option("sep", delimiter) \
  .load(train_scored_location)

display(features_df)
display(labels_df)

sig_id,5-alpha_reductase_inhibitor,11-beta-hsd1_inhibitor,acat_inhibitor,acetylcholine_receptor_agonist,acetylcholine_receptor_antagonist,acetylcholinesterase_inhibitor,adenosine_receptor_agonist,adenosine_receptor_antagonist,adenylyl_cyclase_activator,adrenergic_receptor_agonist,adrenergic_receptor_antagonist,akt_inhibitor,aldehyde_dehydrogenase_inhibitor,alk_inhibitor,ampk_activator,analgesic,androgen_receptor_agonist,androgen_receptor_antagonist,anesthetic_-_local,angiogenesis_inhibitor,angiotensin_receptor_antagonist,anti-inflammatory,antiarrhythmic,antibiotic,anticonvulsant,antifungal,antihistamine,antimalarial,antioxidant,antiprotozoal,antiviral,apoptosis_stimulant,aromatase_inhibitor,atm_kinase_inhibitor,atp-sensitive_potassium_channel_antagonist,atp_synthase_inhibitor,atpase_inhibitor,atr_kinase_inhibitor,aurora_kinase_inhibitor,autotaxin_inhibitor,bacterial_30s_ribosomal_subunit_inhibitor,bacterial_50s_ribosomal_subunit_inhibitor,bacterial_antifolate,bacterial_cell_wall_synthesis_inhibitor,bacterial_dna_gyrase_inhibitor,bacterial_dna_inhibitor,bacterial_membrane_integrity_inhibitor,bcl_inhibitor,bcr-abl_inhibitor,benzodiazepine_receptor_agonist,beta_amyloid_inhibitor,bromodomain_inhibitor,btk_inhibitor,calcineurin_inhibitor,calcium_channel_blocker,cannabinoid_receptor_agonist,cannabinoid_receptor_antagonist,carbonic_anhydrase_inhibitor,casein_kinase_inhibitor,caspase_activator,catechol_o_methyltransferase_inhibitor,cc_chemokine_receptor_antagonist,cck_receptor_antagonist,cdk_inhibitor,chelating_agent,chk_inhibitor,chloride_channel_blocker,cholesterol_inhibitor,cholinergic_receptor_antagonist,coagulation_factor_inhibitor,corticosteroid_agonist,cyclooxygenase_inhibitor,cytochrome_p450_inhibitor,dihydrofolate_reductase_inhibitor,dipeptidyl_peptidase_inhibitor,diuretic,dna_alkylating_agent,dna_inhibitor,dopamine_receptor_agonist,dopamine_receptor_antagonist,egfr_inhibitor,elastase_inhibitor,erbb2_inhibitor,estrogen_receptor_agonist,estrogen_receptor_antagonist,faah_inhibitor,farnesyltransferase_inhibitor,fatty_acid_receptor_agonist,fgfr_inhibitor,flt3_inhibitor,focal_adhesion_kinase_inhibitor,free_radical_scavenger,fungal_squalene_epoxidase_inhibitor,gaba_receptor_agonist,gaba_receptor_antagonist,gamma_secretase_inhibitor,glucocorticoid_receptor_agonist,glutamate_inhibitor,glutamate_receptor_agonist,glutamate_receptor_antagonist,gonadotropin_receptor_agonist,gsk_inhibitor,hcv_inhibitor,hdac_inhibitor,histamine_receptor_agonist,histamine_receptor_antagonist,histone_lysine_demethylase_inhibitor,histone_lysine_methyltransferase_inhibitor,hiv_inhibitor,hmgcr_inhibitor,hsp_inhibitor,igf-1_inhibitor,ikk_inhibitor,imidazoline_receptor_agonist,immunosuppressant,insulin_secretagogue,insulin_sensitizer,integrin_inhibitor,jak_inhibitor,kit_inhibitor,laxative,leukotriene_inhibitor,leukotriene_receptor_antagonist,lipase_inhibitor,lipoxygenase_inhibitor,lxr_agonist,mdm_inhibitor,mek_inhibitor,membrane_integrity_inhibitor,mineralocorticoid_receptor_antagonist,monoacylglycerol_lipase_inhibitor,monoamine_oxidase_inhibitor,monopolar_spindle_1_kinase_inhibitor,mtor_inhibitor,mucolytic_agent,neuropeptide_receptor_antagonist,nfkb_inhibitor,nicotinic_receptor_agonist,nitric_oxide_donor,nitric_oxide_production_inhibitor,nitric_oxide_synthase_inhibitor,norepinephrine_reuptake_inhibitor,nrf2_activator,opioid_receptor_agonist,opioid_receptor_antagonist,orexin_receptor_antagonist,p38_mapk_inhibitor,p-glycoprotein_inhibitor,parp_inhibitor,pdgfr_inhibitor,pdk_inhibitor,phosphodiesterase_inhibitor,phospholipase_inhibitor,pi3k_inhibitor,pkc_inhibitor,potassium_channel_activator,potassium_channel_antagonist,ppar_receptor_agonist,ppar_receptor_antagonist,progesterone_receptor_agonist,progesterone_receptor_antagonist,prostaglandin_inhibitor,prostanoid_receptor_antagonist,proteasome_inhibitor,protein_kinase_inhibitor,protein_phosphatase_inhibitor,protein_synthesis_inhibitor,protein_tyrosine_kinase_inhibitor,radiopaque_medium,raf_inhibitor,ras_gtpase_inhibitor,retinoid_receptor_agonist,retinoid_receptor_antagonist,rho_associated_kinase_inhibitor,ribonucleoside_reductase_inhibitor,rna_polymerase_inhibitor,serotonin_receptor_agonist,serotonin_receptor_antagonist,serotonin_reuptake_inhibitor,sigma_receptor_agonist,sigma_receptor_antagonist,smoothened_receptor_antagonist,sodium_channel_inhibitor,sphingosine_receptor_agonist,src_inhibitor,steroid,syk_inhibitor,tachykinin_antagonist,tgf-beta_receptor_inhibitor,thrombin_inhibitor,thymidylate_synthase_inhibitor,tlr_agonist,tlr_antagonist,tnf_inhibitor,topoisomerase_inhibitor,transient_receptor_potential_channel_antagonist,tropomyosin_receptor_kinase_inhibitor,trpv_agonist,trpv_antagonist,tubulin_inhibitor,tyrosine_kinase_inhibitor,ubiquitin_specific_protease_inhibitor,vegfr_inhibitor,vitamin_b,vitamin_d_receptor_agonist,wnt_inhibitor
id_000644bb2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
id_000779bfc,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
id_000a6266a,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
id_0015fd391,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
id_001626bd3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
id_001762a82,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
id_001bd861f,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
id_0020d0484,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
id_00224bf20,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
id_0023f063e,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [0]:
feature_rows = features_df.count()
label_rows = labels_df.count()
print('# of feature rows: ', feature_rows)
print('# of label rows: ', label_rows)

# Check for number of duplicate rows
duplicates = features_df.groupBy(features_df.columns).count().where(F.col('count') > 1).select(F.sum('count').alias('Duplicate Row Count')).collect()

# Check for number of rows with for missing value columns
dropped_na_feature_df = features_df.dropna(how="any", thresh=None, subset=None)
num_missing_rows = features_df.count() - dropped_na_feature_df.count()

num_duplicate_rows = 0 if duplicates[0].__getitem__("Duplicate Row Count")==None else duplicates[0].__getitem__("Duplicate Row Count")

print('Number of duplicate rows in features dataframe: ', num_duplicate_rows)
print('Number of rows with missing value in features dataframe: ', num_missing_rows)

In [0]:
'''
Data Processing: 
  - Clean Dataframe(drop useless rows: cp_type=ctrl) 
  - Feature Selection (based on importance of RF model) can use other features
'''
features_df = features_df.filter((features_df.cp_type == 'trt_cp'))
top_18_features = ['sig_id', 'c-98', 'c-65', 'c-17', 'c-70', 'g-392', 'g-100', 'c-53', 'g-628', 'g-91', 'c-6', 'c-32', 'c-85', 'g-75', 'g-175', 'g-385', 'cp_dose', 'cp_time']
df = features_df.select([c for c in features_df.columns if c in top_18_features])
display(df)

sig_id,cp_time,cp_dose,g-75,g-91,g-100,g-175,g-385,g-392,g-628,c-6,c-17,c-32,c-53,c-65,c-70,c-85,c-98
id_000644bb2,24,D1,-0.6332,-1.123,-0.5364,-1.413,1.76,-0.1301,-0.8976,0.2888,-0.2798,0.3817,0.0471,0.1403,-0.0553,0.1799,0.3801
id_000779bfc,72,D1,-0.5857,-0.3854,0.7129,0.2697,-0.1022,0.0452,-0.0585,0.2835,-0.2274,0.1888,0.2111,0.4151,-0.0765,0.442,0.6077
id_000a6266a,48,D1,-1.357,0.4286,-0.7646,1.187,-1.112,-0.228,0.662,-0.7513,0.0643,0.1171,0.7047,-2.364,0.1682,0.1172,-1.408
id_0015fd391,48,D1,-3.321,4.916,1.039,3.458,1.317,5.391,2.825,-2.583,-0.5252,0.067,-1.367,-1.493,-1.187,-1.539,-0.3876
id_001626bd3,72,D2,-0.7868,0.8025,0.0968,0.5282,-1.166,-0.5027,-0.6423,-0.0852,-0.5176,0.5227,-0.229,0.3973,0.1813,0.0698,-0.3786
id_001762a82,24,D1,-0.2243,0.4944,-1.522,-0.0131,0.6992,-0.8619,0.0522,1.334,1.436,1.185,1.494,2.002,0.6596,0.698,0.7848
id_001bd861f,24,D2,3.375,1.58,4.491,1.624,-4.608,4.843,1.978,1.048,0.9935,1.335,-0.3832,0.5135,0.2662,-0.423,0.1351
id_0020d0484,48,D1,2.021,1.232,0.7365,2.385,-1.884,2.807,1.533,-0.2375,0.2757,-0.2019,0.6213,0.2156,-0.1342,0.1194,-0.9622
id_00224bf20,48,D1,1.179,-0.7994,-0.2113,-0.1636,1.063,-0.1412,-0.0318,-0.4887,-0.9109,-0.8075,-0.4742,-0.4082,-0.6872,-0.009,-0.642
id_0023f063e,48,D2,0.1832,0.2758,-0.0175,-0.1438,0.1179,0.2637,-0.8548,-0.8989,-0.4817,-0.5435,-0.9559,-0.9295,0.1205,-0.3377,0.0189


In [0]:
# engineer something with cp_time and cp_dose (maybe as combination?)
'''
Feature Engineering: create statistic features
 - row sum, mean, max, min, std, for g, c, gc
'''
g_features = [feature for feature in df.columns if 'g-' in feature]
c_features = [feature for feature in df.columns if 'c-' in feature]
gc_features = g_features + c_features

df = df.withColumn("g_mean", reduce(lambda x,y: x+y, (col(x) for x in g_features)) / len(g_features))
df = df.withColumn("c_mean", reduce(lambda x,y: x+y, (col(x) for x in c_features)) / len(c_features))
df = df.withColumn("gc_mean", reduce(lambda x,y: x+y, (col(x) for x in gc_features)) / len(gc_features))

df = df.withColumn("g_std", sqrt(reduce(lambda x,y: x-col('g_mean') + y-col('g_mean'), (col(x) for x in g_features))**2 / len(g_features)))
df = df.withColumn("c_std", sqrt(reduce(lambda x,y: x-col('c_mean') + y-col('c_mean'), (col(x) for x in c_features))**2 / len(c_features)))
df = df.withColumn("gc_std", sqrt(reduce(lambda x,y: x-col('gc_mean') + y-col('gc_mean'), (col(x) for x in gc_features))**2 / len(gc_features)))

df = df.withColumn("g_sum", reduce(lambda x,y: x+y, (col(x) for x in g_features)))
df = df.withColumn("c_sum", reduce(lambda x,y: x+y, (col(x) for x in c_features)))
df = df.withColumn("gc_sum", reduce(lambda x,y: x+y, (col(x) for x in gc_features)))

df = df.withColumn("g_max", reduce(lambda x,y: greatest(x,y), g_features))
df = df.withColumn("c_max", reduce(lambda x,y: greatest(x,y), c_features))
df = df.withColumn("gc_max", reduce(lambda x,y: greatest(x,y), gc_features))

df = df.withColumn("g_min", reduce(lambda x,y: least(x,y), g_features))
df = df.withColumn("c_min", reduce(lambda x,y: least(x,y), c_features))
df = df.withColumn("gc_min", reduce(lambda x,y: least(x,y), gc_features))


engineered_features = ["g_mean", "c_mean", "gc_mean", "g_std", "c_std", "gc_std", "g_sum", "c_sum", "gc_sum", "g_max", "c_max", "gc_max", "g_min", "c_min", "gc_min",]

"""
One-Hot-Encoder
  for cp_dose; create the one-hot-vector then drop the original categorical vector
"""
def cp_dose_map(val):
  if val == 'D1':
    return 1
  if val == 'D2':
    return 0
  
dose_udf = udf(cp_dose_map)

df = df.withColumn("dose_one_hot", dose_udf("cp_dose"))
df = df.withColumn("dose_one_hot", df["dose_one_hot"].cast(DoubleType()))
df = df.drop('cp_dose')

In [0]:
# Labels are assembled into one vector. Each cell is a tuple: (length of labels, index of existing value, value)
# assembled_labels_df.select('labels').show()
features = df.columns
features.remove('sig_id')

'''
Feature Vectors Assembler (required)
 - assembles all features into one vector. Each row in the assembled vector is a tuple of all the features
 - using also the one-hot-encoded features in the assembly
'''
features_assembler = VectorAssembler(inputCols=features, outputCol="assembledFeatures")
df = features_assembler.transform(df)

In [0]:
# create a 'train' dataframe: both features and labels
trainingData = df.join(labels_df, on=['sig_id'], how='inner')

'''
Label Vectors Assembler
'''
labels_assembler = VectorAssembler(inputCols=labels_df.columns[1:], outputCol="assembledLabels")
trainingData = labels_assembler.transform(trainingData)

'''
Assemble both feature and label vectors for clustering
  Only for KNN cluster algorithm
  Drop after KNN
'''
feature_label_assembler = VectorAssembler(inputCols=["assembledFeatures", "assembledLabels"], outputCol="assemebled")
trainingData = feature_label_assembler.transform(trainingData)


'''
KMeans clustering - engineer new feature based on cluster results
Used K=6 to fit the assembled features, KNN returned 4 cluster classes.

Include the labels and see what happens - same thing

What does the cluster numbers mean? IE: How to test the validity of the cluster results
'''
from pyspark.ml.clustering import KMeans
# kmeans = KMeans(k=6, featuresCol='assembledFeatures', predictionCol='clusterPrediction', initMode='k-means||', initSteps=2, tol=0.0001, maxIter=20, seed=None, distanceMeasure='euclidean', weightCol=None)
kmeans = KMeans(k=6, featuresCol='assemebled', predictionCol='clusterPrediction', distanceMeasure='euclidean',)
model = kmeans.fit(trainingData)
transformed = model.transform(trainingData).select("sig_id", "clusterPrediction")
# rows = transformed.collect()
summary = model.summary


# join the cluster prediction vector to the dataframe
df = df.join(transformed, on=['sig_id'], how='inner')
df = df.drop("assembledFeatures") # get rid of this assembledFeature column, and make another one without one-hot-encoded features for standardization


# Drop the 'assembled' column vector because we no longer need this (only for clustering)
# trainingData = trainingData.drop("assembled")
# trainingData = trainingData.drop("sig_id") # drop id before training


'''
Standardize/Normalization 
  This step is after feature engineering step, before Feature Selection
  Standardize features in df (dataframe with all the necessary features: original + feature engineered)
  **NOTE** - do not standardize one-hot-encoded labels or cluster labels
'''
df_features = df.columns
df_features.remove('sig_id')

# feature_labels = ["cp_time_onehot", "cp_dose_onehot", "clusterPrediction",] # list of features that we do not want to standardize. IE: one-hot-encoded vectors, cluster vector, ...
feature_labels = ["dose_one_hot", "clusterPrediction",] # list of features that we do not want to standardize. IE: one-hot-encoded vectors, cluster vector, ...
features_to_standardize = [feat_col for feat_col in df_features if feat_col not in feature_labels] # list of feature vectors we want to standardize

# assemble before standardizing
features_assembler = VectorAssembler(inputCols=features_to_standardize, outputCol="assembledFeatures")
df = features_assembler.transform(df)

scaler = StandardScaler(inputCol="assembledFeatures", outputCol="assembledScaledFeatures")
df = scaler.fit(df).transform(df)


# Assemble feature_labels and assembledScaledFeatures
feature_labels.append("assembledScaledFeatures")
features_assembler = VectorAssembler(inputCols=feature_labels, outputCol="allFeatures")
df = features_assembler.transform(df)



aasembled_train = df.select("sig_id", "allFeatures")
display(aasembled_train)

"""
Feature Selection goes here
"""


"""
** Do a final round of assembling for feature vectors

For training - Need to assemble: assembledScaledFeatures & clusterPrediction vectors before fitting a learner on assembledLabels
"""
# features_assembler = VectorAssembler(inputCols=["assembledScaledFeatures", "clusterPrediction"], outputCol="allFeatures")
# trainingData = features_assembler.transform(trainingData)

In [0]:
display(aasembled_train)


sig_id,allFeatures
id_000644bb2,"List(1, 33, List(), List(1.0, 2.0, 1.2368650097727056, -0.2737066921494381, -0.5734680814668819, -0.24356451261839523, -0.8456710759796818, 0.9322994272077685, -0.06347510666600883, -0.7465676683439648, 0.1251957216140379, -0.1305412989454399, 0.20875095602902383, 0.02606952407800695, 0.06105621528662336, -0.02445615521331749, 0.09256787593636896, 0.2012491488069956, -0.6234982412415624, 0.07038344624763396, -0.12300266548920595, 0.7554046976258204, 0.07488954360851474, 0.133068028474023, -0.6234982412415624, 0.07038344624763396, -0.123002665489206, 0.7396021913775526, 0.2298149895598879, 0.7698510193070086, -0.6801574856910926, -0.12240125853979449, -0.6140053087986103))"
id_000779bfc,"List(1, 33, List(), List(1.0, 0.0, 3.710595029318117, -0.2531743676435974, -0.19680730062095841, 0.3237083166399216, 0.1614136512326399, -0.054136932648087464, 0.022052842592648725, -0.04865664950771162, 0.1228981547007609, -0.10609396490419239, 0.10325433717128556, 0.11684238923285067, 0.18064458279028764, -0.033831751786958186, 0.22743191308435284, 0.32175508479350495, -0.021808703154448775, 0.11988196334919778, 0.11323011835539017, 0.02642252330847392, 0.12755706065495373, 0.122495789449004, -0.02180870315444878, 0.11988196334919778, 0.11323011835539022, 0.2995809103596916, 0.3658856933600836, 0.3118334043545264, -0.2819308134248216, -0.0994783638025349, -0.2545101977093744))"
id_000a6266a,"List(1, 33, List(), List(1.0, 1.0, 2.4737300195454113, -0.5865760916721218, 0.2188676934253834, -0.34718386716634037, 0.7104115832893717, -0.5890437290085446, -0.11124000245849358, 0.5506102901556426, -0.325690947536796, 0.02999930493992775, 0.06404175255697848, 0.39004657362572176, -1.0287732924987711, 0.0743856294191682, 0.06030547559612252, -0.7454848764016042, -0.2482836974506474, -0.2178714768496671, -0.2951153081861626, 0.3008102653580102, 0.2318200704350017, 0.31926472549720886, -0.2482836974506474, -0.2178714768496671, -0.2951153081861627, 0.4988112506620198, 0.42428772109733565, 0.5192120226803518, -0.6532014919198957, -1.034155022116062, -1.0272530431704987))"
id_0015fd391,"List(1, 33, List(), List(1.0, 1.0, 2.4737300195454113, -1.4355336775557233, 2.510390995985033, 0.47178137324853214, 2.069589936827841, 0.6976354236549039, 2.630240584446223, 2.349658715543339, -1.119738742829155, -0.2450332030241066, 0.036642164144471044, -0.7566250406504351, -0.6497286487735472, -0.5249449590995996, -0.791895281078776, -0.20522012648669163, 3.2765479498871324, -0.5859740406660241, 0.4300831628758261, 3.9697300643740774, 0.6234893403417429, 0.46527705994127855, 3.276547949887133, -0.5859740406660241, 0.43008316287582626, 2.265451939611583, 0.0403395449319164, 2.358106161979593, -1.5985867020383004, -1.1299587234034638, -1.4431080187687082))"
id_001626bd3,"List(1, 33, List(), List(0.0, 0.0, 3.710595029318117, -0.3401017457093776, 0.4098024357766453, 0.043954222262230905, 0.3161241771638131, -0.6176483705251465, -0.24526468963107334, -0.5342250594667208, -0.036934471888905925, -0.24148740648377298, 0.2858635701241047, -0.12674991536865374, 0.1728983202663967, 0.0801790405094839, 0.035915718401103684, -0.20045495327105636, -0.35026035460457455, -0.0025545524912560136, -0.11123266697717349, 0.4243609680975373, 0.002718100354464933, 0.12033488573350411, -0.3502603546045746, -0.0025545524912560136, -0.11123266697717356, 0.33723338555709426, 0.3147086587449658, 0.35102581988288317, -0.5612622988788492, -0.22642920450392287, -0.5066738783150598))"
id_001762a82,"List(1, 33, List(), List(1.0, 2.0, 1.2368650097727056, -0.09695579761389603, 0.2524689398728641, -0.6910984120156554, -0.007840262629394077, 0.3703771360816317, -0.4205164829779632, 0.043416702637650365, 0.5782932570399121, 0.6699689252525077, 0.6480740971820624, 0.8269186618374178, 0.8712369422937987, 0.29170488207421724, 0.35915718401103686, 0.41552310440339424, -0.2884410691244641, 0.6235838134762207, 0.5346858528028277, 0.3494632770269789, 0.6635069704967916, 0.5784394346451588, -0.2884410691244641, 0.6235838134762207, 0.5346858528028279, 0.29382377966544587, 1.2053696858760692, 0.8757055344617222, -0.7326254021386008, 0.2885484994026034, -0.6613701910767762))"
id_001bd861f,"List(1, 33, List(), List(0.0, 2.0, 1.2368650097727056, 1.4588756885728895, 0.8068384405322114, 2.0392397952446175, 0.9719531687126703, -2.440929409416703, 2.3628742627477384, 1.6451769696795484, 0.45431134436119025, 0.4635195872133472, 0.7301087930278931, -0.21209854833741532, 0.22346661831561718, 0.11772565131618654, -0.217655428132763, 0.07153054460359143, 2.78543273077445, 0.22653615489252787, 1.090992327526874, 3.374715164485175, 0.24103947952533694, 1.1802687163476677, 2.78543273077445, 0.22653615489252787, 1.0909923275268745, 2.035166711841754, 0.8037804848374388, 2.1184025491499106, -2.2180932017441997, -0.18504550522635121, -2.002361261814576))"
id_0020d0484,"List(1, 33, List(), List(1.0, 1.0, 2.4737300195454113, 0.8735963752906102, 0.6291297207187876, 0.3344244286790606, 1.4274065932141125, -0.9979841595792248, 1.3695205565832957, 1.275053738381571, -0.10295700790628118, 0.12862843502236518, -0.1104187006084881, 0.3438852507360024, 0.0938255168624091, -0.059349295291631236, 0.06143748964314871, -0.5094499631204714, 1.8517476269746125, -0.01974740577210122, 0.5547774809980494, 2.2435008853411382, 0.02101167653146177, 0.6001751697378741, 1.8517476269746127, -0.01974740577210122, 0.5547774809980497, 1.1795814495436305, 0.3740740188985024, 1.227824892724303, -0.9068766475881234, -0.42092384191204524, -0.8186737450648136))"
id_00224bf20,"List(1, 33, List(), List(1.0, 1.0, 2.4737300195454113, 0.5096339072081294, -0.4082193983300315, -0.09594552855381601, -0.09791350886785274, 0.5630876654101465, -0.06889073836464601, -0.026449255629832982, -0.21185300953178787, -0.4249823774460372, -0.4416201126367219, -0.26246641863674935, -0.17764181810406027, -0.3039108474248061, -0.004630966556016234, -0.3399156893819815, 0.18761775684889712, -0.28780641388127903, -0.2298695674019384, 0.22730991927011113, 0.30623239031716, 0.2486798966404244, 0.18761775684889714, -0.28780641388127903, -0.22986956740193853, 0.4954494225193946, -0.005418744841600711, 0.51571269986532, -0.3847968110838354, -0.39848215298033884, -0.395822672175976))"
id_0023f063e,"List(1, 33, List(), List(0.0, 1.0, 2.4737300195454113, 0.07918993367305284, 0.14083926702454677, -0.007946269520547944, -0.08606334092418842, 0.06245346731124767, 0.12865784494870505, -0.7109692991314853, -0.3896760185556049, -0.22473818335246032, -0.2972390479480599, -0.5290840353750921, -0.404502866064978, 0.053290537128476624, -0.17376415621852026, 0.010006863752834034, -0.03680218657313227, -0.2605123530395894, -0.27218040229621543, 0.04458800808304948, 0.27719090586830947, 0.29445304602770717, -0.03680218657313227, -0.2605123530395894, -0.27218040229621554, 0.1158990252170051, 0.07255097260143174, 0.12063915404822328, -0.41146399063605515, -0.4181678450256953, -0.41537698137338397))"


In [0]:
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier
from pyspark.ml import Pipeline

'''
Binary Relevance - logistic on each target
'''
# loop through each label and train a logistic regression model on the train data that classifies that label
# model_dict = dict()
# start_time = time.time()

# for ind, label in enumerate(labels_df.columns[1:]):
#   sub_start_time = time.time()
#   print(f'Training on label {ind}: {label}')
#   lr = LogisticRegression(featuresCol = 'assembledFeatures', labelCol = label, maxIter=15)
#   pipeline = Pipeline(stages=[lr])
#   sub_model = pipeline.fit(trainingData)
  
#   elapsed_time = time.time() - sub_start_time
#   print(f'Training time for: {label}, {elapsed_time}s')
  
#   model_dict[label] = {
#     'model': lr,
#     'time': elapsed_time
#   }
  
#   lr.save_model()
#   print(f'Total elapsed time: {time.time()-start_time}s\n')

In [0]:
'''
This doesn't work, haven't figured out how to write a trained model to file
'''
import pickle

for key, val in model_dict.items():
  print(key)
  model = val['model']
  file_name = '/tmp/ml_python_model_export/logit_' + key + '.pkl'
  
  with open(file_name, 'wb') as f:
    pickle.dumps(model, f)
    
  ModelExport.exportModel(model, f"/tmp/ml_python_model_export/{file_name}")