In [1]:
from typing import cast
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType, ArrayType

# initialise spark session
spark = cast(SparkSession, SparkSession \
    .builder \
    .config("spark.driver.memory", "32g") \
    .config("spark.executor.memory", "32g") \
    .config("spark.memory.fraction", "0.6") \
    .config("spark.memory.storageFraction", "0.5") \
    .config("spark.driver.cores", "4") \
    .config("spark.executor.cores", "4") \
    .config("spark.driver.maxResultSize", "4g")  \
    .config("spark.network.timeout", "1500s") \
    .master("local[*]") \
    .appName("Team4-Project-Hsin-Pao-Huang") \
    .getOrCreate()
)

sc = spark.sparkContext

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/12 11:03:12 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


### Data Preprocessing

There are too many columns (>297k columns!) to load into spark.

So we first load the file without sperating the columns, then manually parse the metadata columns from the gene ones, and transform genes data into a column of array of integers.

We also need to do some preliminary data cleaning and transforming, like transforming sex from 1 and 2 to 0 and 1, as well as using 3 to represent missing data (marked as "NA" in the dataset)

In [2]:
column_names = ['FID', 'IID', 'PAT', 'MAT', 'SEX', 'PHENOTYPE', 'DATA']

def parse_data(spark: SparkSession):
    df = spark.read.format("csv") \
        .option("header", "true") \
        .load('../data/JanBDRcount.raw')

    df = df.withColumnRenamed(df.schema.names[0], 'raw_data')
    df = df.select(F.split('raw_data', ' ', limit=7).alias('raw_data'))

    parsed_df = df.withColumns({
        column_names[0]: df['raw_data'][0], # FID
        column_names[1]: df['raw_data'][1], # IID
        column_names[2]: df['raw_data'][2].cast(IntegerType()), # PAT
        column_names[3]: df['raw_data'][3].cast(IntegerType()), # MAT
        column_names[4]: df['raw_data'][4].cast(IntegerType()), # SEX

        # tranform from 1, 2 to 0, 1
        column_names[5]: df['raw_data'][5].cast(IntegerType()) - F.lit(1), # PHENOTYPE

        # clean up the data by using `3` to represent `NA`
        column_names[6]: F.split(F.regexp_replace(df['raw_data'][6], 'NA', '3'), ' ').cast(ArrayType(IntegerType())), # DATA
    })

    parsed_df = parsed_df.drop('PAT', 'MAT', 'IID', 'raw_data')

    return df, parsed_df

In [3]:
from os import path

df_path = '../data/df'
parsed_df_path = '../data/parsed_df'

# load parsed data from previous run
if path.exists(df_path) and path.exists(parsed_df_path):
    df = spark.read.load(df_path)
    parsed_df = spark.read.load(parsed_df_path)
else:
    df, parsed_df = parse_data(spark)

    # save the parsed data to save some time on future runs
    df.write.save(df_path)
    parsed_df.write.save(parsed_df_path)

### Building the Pipeline
The genomes comes in 0, 1, 2 (and 3 for "NA"), so we need to one-hot encode this data first.

Since the data is stored in a single column as an array of integers, it is difficult to make it work with built-in transformers like OneHotEncoder.

So we've opted to write our own UDF for this reason.

In [4]:
from pyspark.ml import Transformer
from pyspark.sql import DataFrame
from pyspark.ml.functions import array_to_vector

# one-hot encodes columns which we want to train and assemble them into a flattened array
class FeaturesTransformer(Transformer):
    def _transform(self, df: DataFrame):
        encoder = F.udf(
            self._one_hot_encode,
            ArrayType(IntegerType(), containsNull=False),
        )

        new_df = df.withColumn('FEATURES', encoder('SEX', 'SLICED_DATA'))
        return new_df.withColumn('FEATURES', array_to_vector(F.col('FEATURES')))

    def _one_hot_encode(self, sex: int, data: list[int]):
        output = [sex]

        for element in data:
            vec = [0, 0, 0, 0]
            vec[element] = 1
            output += vec

        return output

features_transformer = FeaturesTransformer()

In [5]:
# train test split

train, test = parsed_df.randomSplit([0.7, 0.3], seed=42)

train.cache()
test.cache()

DataFrame[FID: string, SEX: int, PHENOTYPE: int, DATA: array<int>]

In [6]:
length = 3000 # number of features per slice
current = 0 # current index, if resuming training, change to (index of last saved model + 1)

# generates the indices used to slice the data, based on the length and current set above
def indices_generator(length: int, current: int):
    arr_length = parsed_df.select(F.array_size('DATA').alias('length')).take(1)[0]['length']

    while current * length <= arr_length:
        yield current, (current * length) + 1
        current += 1

### Model Training Part 1
Slice the DATA array, and train a model on each slice

In [7]:
# %%script false --no-raise-error
# uncomment the line above to skip training part 1

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator(labelCol='PHENOTYPE')

# collect results for analysis
results = list[tuple[int, float]]()

for i, idx in indices_generator(length, current):
    sliced_train = train.withColumn('SLICED_DATA', F.slice('DATA', idx, length)).drop('DATA')
    sliced_test = test.withColumn('SLICED_DATA', F.slice('DATA', idx, length)).drop('DATA')

    classifier = RandomForestClassifier(
        labelCol='PHENOTYPE',
        featuresCol='FEATURES',
        seed=24,
    )

    # parameter tuning, could use more improvments
    paramGrid = ParamGridBuilder() \
        .addGrid(classifier.maxDepth, [5, 7]) \
        .addGrid(classifier.numTrees, [10, 20]) \
        .build()

    validator = CrossValidator(
        estimator=classifier,
        estimatorParamMaps=paramGrid,
        evaluator=evaluator,
    )

    pipeline = Pipeline(stages=[features_transformer, validator])
    model = pipeline.fit(sliced_train)

    # evaluate the model
    predictions = model.transform(sliced_test)
    result = evaluator.evaluate(predictions)

    # save the best model, so that there's no need to retrain it every time
    model.stages[1].bestModel.write().overwrite().save(f'../models/model_{i}')

    # unpersist temporary dataframes to reduce memory usage
    sliced_train.unpersist()
    sliced_test.unpersist()

    results.append((i, result))

    print(f'Slice {i} completed, result: {result}')

                                                                                

Slice 0 completed, result: 0.5327680621798269


                                                                                

Slice 1 completed, result: 0.5732202791026321


                                                                                

Slice 2 completed, result: 0.5779014308426074


                                                                                

Slice 3 completed, result: 0.48277689454160044


                                                                                

Slice 4 completed, result: 0.602543720190779


                                                                                

Slice 5 completed, result: 0.5460166048401343


                                                                                

Slice 6 completed, result: 0.5937113584172408


                                                                                

Slice 7 completed, result: 0.5765765765765766


                                                                                

Slice 8 completed, result: 0.409468291821233


                                                                                

Slice 9 completed, result: 0.5099805688040981


                                                                                

Slice 10 completed, result: 0.44744744744744747


                                                                                

Slice 11 completed, result: 0.5905317081787669


                                                                                

Slice 12 completed, result: 0.47465112170994517


                                                                                

Slice 13 completed, result: 0.46511217099452395


                                                                                

Slice 14 completed, result: 0.5073308602720368


                                                                                

Slice 15 completed, result: 0.518459636106695


                                                                                

Slice 16 completed, result: 0.43985161632220454


                                                                                

Slice 17 completed, result: 0.5758699876346935


                                                                                

Slice 18 completed, result: 0.46970499911676383


                                                                                

Slice 19 completed, result: 0.5972443031266561


                                                                                

Slice 20 completed, result: 0.5472531354884297


                                                                                

Slice 21 completed, result: 0.5841724077018194


                                                                                

Slice 22 completed, result: 0.5195195195195195


                                                                                

Slice 23 completed, result: 0.5313548842960607


                                                                                

Slice 24 completed, result: 0.5908850026497086


                                                                                

Slice 25 completed, result: 0.46846846846846846


                                                                                

Slice 26 completed, result: 0.5267620561738209


                                                                                

Slice 27 completed, result: 0.555908850026497


                                                                                

Slice 28 completed, result: 0.5156332803391627


                                                                                

Slice 29 completed, result: 0.4857798975446035


                                                                                

Slice 30 completed, result: 0.4857798975446034


                                                                                

Slice 31 completed, result: 0.4317258434905494


                                                                                

Slice 32 completed, result: 0.4674085850556438


                                                                                

Slice 33 completed, result: 0.5274686451157039


                                                                                

Slice 34 completed, result: 0.6225048577989754


                                                                                

Slice 35 completed, result: 0.48330683624801274


                                                                                

Slice 36 completed, result: 0.5322381204734146


                                                                                

Slice 37 completed, result: 0.49284578696343406


                                                                                

Slice 38 completed, result: 0.5110404522169228


                                                                                

Slice 39 completed, result: 0.5424836601307189


                                                                                

Slice 40 completed, result: 0.6074898427839603


                                                                                

Slice 41 completed, result: 0.555202261084614


                                                                                

Slice 42 completed, result: 0.5919448860625331


                                                                                

Slice 43 completed, result: 0.5467231937820173


                                                                                

Slice 44 completed, result: 0.47359123829712063


                                                                                

Slice 45 completed, result: 0.5642112700936232


                                                                                

Slice 46 completed, result: 0.5438968380144851


                                                                                

Slice 47 completed, result: 0.5788729906376965


                                                                                

Slice 48 completed, result: 0.5447800741918389


                                                                                

Slice 49 completed, result: 0.5214626391096979


                                                                                

Slice 50 completed, result: 0.44303126656067837


                                                                                

Slice 51 completed, result: 0.5297650591768239


                                                                                

Slice 52 completed, result: 0.5790496378731673


                                                                                

Slice 53 completed, result: 0.4518636283342165


                                                                                

Slice 54 completed, result: 0.509273979862215


                                                                                

Slice 55 completed, result: 0.5036212683271506


                                                                                

Slice 56 completed, result: 0.5188129305776363


                                                                                

Slice 57 completed, result: 0.5380674792439499


                                                                                

Slice 58 completed, result: 0.46511217099452395


                                                                                

Slice 59 completed, result: 0.5249955838191132


                                                                                

Slice 60 completed, result: 0.5458399576046635


                                                                                

Slice 61 completed, result: 0.5910616498851794


                                                                                

Slice 62 completed, result: 0.5089206853912736


                                                                                

Slice 63 completed, result: 0.5486663133721957


                                                                                

Slice 64 completed, result: 0.5117470411588059


                                                                                

Slice 65 completed, result: 0.5292351174704115


                                                                                

Slice 66 completed, result: 0.510863804981452


                                                                                

Slice 67 completed, result: 0.6101395513160219


                                                                                

Slice 68 completed, result: 0.5264087617028793


                                                                                

Slice 69 completed, result: 0.6046634870164281


                                                                                

Slice 70 completed, result: 0.5234057586998763


                                                                                

Slice 71 completed, result: 0.5892951775304716


                                                                                

Slice 72 completed, result: 0.5433668963080727


                                                                                

Slice 73 completed, result: 0.5617382087970324


                                                                                

Slice 74 completed, result: 0.48171701112877585


                                                                                

Slice 75 completed, result: 0.48383677795442503


                                                                                

Slice 76 completed, result: 0.5251722310545841


                                                                                

Slice 77 completed, result: 0.5283518812930577


                                                                                

Slice 78 completed, result: 0.4329623741388447


                                                                                

Slice 79 completed, result: 0.4474474474474474


                                                                                

Slice 80 completed, result: 0.5151033386327505


                                                                                

Slice 81 completed, result: 0.5294117647058824


                                                                                

Slice 82 completed, result: 0.516163222045575


                                                                                

Slice 83 completed, result: 0.4864864864864865


                                                                                

Slice 84 completed, result: 0.5143967496908673


                                                                                

Slice 85 completed, result: 0.5352411234764177


                                                                                

Slice 86 completed, result: 0.5400105988341283


                                                                                

Slice 87 completed, result: 0.5133368662780426


                                                                                

Slice 88 completed, result: 0.5574986751457339


                                                                                

Slice 89 completed, result: 0.4942589648472001


                                                                                

Slice 90 completed, result: 0.5020314432079137


                                                                                

Slice 91 completed, result: 0.4935523759053171


                                                                                

Slice 92 completed, result: 0.5945945945945946


                                                                                

Slice 93 completed, result: 0.6437025260554673


                                                                                

Slice 94 completed, result: 0.5290584702349408


                                                                                

Slice 95 completed, result: 0.5200494612259318


                                                                                

Slice 96 completed, result: 0.5887652358240594


                                                                                

Slice 97 completed, result: 0.4773008302420067


                                                                                

Slice 98 completed, result: 0.5066242713301538


                                                                                

Slice 99 completed, result: 0.5693340399222752


In [8]:
results_df = spark.createDataFrame(results, schema=('model_index', 'model_auroc'))
results_df.write.save('../data/part_1_results_df')

                                                                                

### Model Training part 2
Generate predictions using the models from the previous part, and aggregate those predictions
Then, train another model using those aggregated predictions to generate the final output

In [9]:
from pyspark.ml.classification import RandomForestClassificationModel
from pyspark.ml import PipelineModel

# aggregated predictions
predictions_df = parsed_df.drop('DATA')
predictions_df.cache()

# load each model from part 1
for i, idx in indices_generator(length, current):
    sliced = parsed_df.withColumn('SLICED_DATA', F.slice('DATA', idx, length)).drop('DATA')

    # load the model and generate predictions
    rf_model = RandomForestClassificationModel.load(f'../models/model_{i}')
    model = PipelineModel(stages=[features_transformer, rf_model])

    predictions = model.transform(sliced) \
        .select('FID', 'prediction') \
        .withColumnRenamed('prediction', f'model_{i}')

    # add predictions to predictions_df as a new column
    predictions_df = predictions_df.join(predictions, on='FID')

    # unpersist temporary dataframes to reduce memory usage
    sliced.unpersist()
    predictions.unpersist()


In [10]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# train test split for part 2
train2, test2 = predictions_df.randomSplit([0.7, 0.3], seed=24)
train2.cache()
test2.cache()

# train a new model using the aggregated predictions
model_cols = [col for col in predictions_df.columns if col.startswith('model_') ]
assembler = VectorAssembler(inputCols=model_cols, outputCol='FEATURES')

classifier2 = RandomForestClassifier(
    labelCol='PHENOTYPE',
    featuresCol='FEATURES',
    seed=42,
)
paramGrid2 = ParamGridBuilder() \
    .addGrid(classifier2.maxDepth, [5, 7]) \
    .addGrid(classifier2.numTrees, [10, 20]) \
    .build()
evaluator2 = BinaryClassificationEvaluator(labelCol='PHENOTYPE')
validator2 = CrossValidator(
    estimator=classifier2,
    estimatorParamMaps=paramGrid2,
    evaluator=evaluator2,
)

pipeline2 = Pipeline(stages=[assembler, validator2])
model2 = pipeline2.fit(train2)

# evaluate the model
predictions2 = model2.transform(test2)
result = evaluator2.evaluate(predictions2)

24/05/12 11:38:35 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
24/05/12 11:41:22 WARN DAGScheduler: Broadcasting large task binary with size 9.5 MiB
24/05/12 11:41:30 WARN DAGScheduler: Broadcasting large task binary with size 9.6 MiB
24/05/12 11:41:37 WARN DAGScheduler: Broadcasting large task binary with size 9.6 MiB
24/05/12 11:41:45 WARN DAGScheduler: Broadcasting large task binary with size 9.6 MiB
24/05/12 11:41:45 WARN DAGScheduler: Broadcasting large task binary with size 9.6 MiB
24/05/12 11:41:52 WARN DAGScheduler: Broadcasting large task binary with size 9.6 MiB
24/05/12 11:41:59 WARN DAGScheduler: Broadcasting large task binary with size 9.6 MiB
24/05/12 11:42:05 WARN DAGScheduler: Broadcasting large task binary with size 9.6 MiB
24/05/12 11:42:12 WARN DAGScheduler: Broadcasting large task binary with size 9.6 MiB
24/05/12 11:42:19 WARN DAGScheduler: 

In [11]:
print(f'Result: {result}')

Result: 0.933624295326423


In [12]:
predictions2.where('prediction = 1').count()

24/05/12 11:58:23 WARN DAGScheduler: Broadcasting large task binary with size 9.7 MiB
                                                                                

120

In [13]:
# save the final model
model2.stages[1].bestModel.save('../models/final_model')

In [18]:
predictions2.select('FID', 'PHENOTYPE', 'prediction').write.save('../data/predictions2')

24/05/12 12:02:53 WARN DAGScheduler: Broadcasting large task binary with size 9.8 MiB
                                                                                

In [16]:
predictions2.select('FID', 'PHENOTYPE', 'prediction').show()

24/05/12 12:02:34 WARN DAGScheduler: Broadcasting large task binary with size 9.6 MiB
24/05/12 12:02:34 WARN DAGScheduler: Broadcasting large task binary with size 9.6 MiB
24/05/12 12:02:35 WARN DAGScheduler: Broadcasting large task binary with size 9.6 MiB


+-------------------+---------+----------+
|                FID|PHENOTYPE|prediction|
+-------------------+---------+----------+
|201904690143_R02C02|        1|       1.0|
|202136020022_R07C01|        1|       1.0|
|201058890047_R07C01|        1|       1.0|
|201039770046_R03C02|        1|       1.0|
|201904690139_R12C02|        0|       0.0|
|202136020014_R07C01|        1|       1.0|
|201904690030_R01C02|        1|       1.0|
|202136020008_R04C01|        0|       0.0|
|202136020022_R07C02|        0|       0.0|
|201039780020_R11C02|        1|       1.0|
|202062520039_R11C01|        1|       1.0|
|201039770044_R04C02|        1|       1.0|
|201023670027_R04C01|        1|       1.0|
|201023680016_R10C02|        1|       1.0|
|201023680028_R08C02|        1|       1.0|
|201039780023_R06C02|        1|       0.0|
|201039770053_R10C01|        1|       1.0|
|201039770133_R09C02|        1|       1.0|
|201904690142_R12C02|        1|       1.0|
|201039780020_R09C01|        1|       1.0|
+----------

                                                                                