In [48]:
# !pip install pyspark
# !pip install findspark
import findspark
findspark.init()

In [49]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

from pyspark.sql.functions import isnull

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [50]:
sc.setLogLevel('ERROR')

In [51]:
TRAIN_PATH = '../data/mitbih_train.csv'
TEST_PATH = '../data/mitbih_test.csv'

In [52]:
train_set = spark.read.csv(TRAIN_PATH, header=False, inferSchema=True).cache()
test_set = spark.read.csv(TEST_PATH, header=False, inferSchema=True).cache()

                                                                                

In [53]:
# check for columns with at least one null value
assert len([x for x in train_set.columns if train_set.filter(col(x).isNull()).count() > 0]) == 0
assert len([x for x in test_set.columns if test_set.filter(col(x).isNull()).count() > 0]) == 0

In [54]:
print('Training Data Shape: ', train_set.count(), len(train_set.columns))
print('Test Data Shape: ', test_set.count(), len(test_set.columns))

Training Data Shape:  87554 188
Test Data Shape:  21892 188


In [55]:
# class balance
train_set.groupBy(train_set.columns[187]).count().orderBy('count', ascending=False).show()

+-----+-----+
|_c187|count|
+-----+-----+
|  0.0|72471|
|  4.0| 6431|
|  2.0| 5788|
|  1.0| 2223|
|  3.0|  641|
+-----+-----+



In [56]:
# Classes: ['N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4]
# 0=normal (N), 1=SVEB (S), 2=VEB (V), 3=Fusion beat (F), 4=Unknown beat (Q)

"""
• N: Normal beat
• S: Supraventricular premature beat
• V: Premature ventricular contraction
• F: Fusion of ventricular and normal beat
• Q: Unclassifiable beat

N
Normal
Left/Right bundle branch block
Atrial escape
Nodal escape

S
Atrial premature
Aberrant atrial premature
Nodal premature
Supra-ventricular premature

V
Premature ventricular contraction
Ventricular escape

F
Fusion of ventricular and normal

Q
Paced
Fusion of paced and normal
Unclassifiable
""";

In [57]:
feature_cols = train_set.select(train_set.columns[:-1])
label_col = train_set.select(train_set.columns[-1])

In [72]:
indexer = StringIndexer(
    inputCol='_c187', 
    outputCol="label"
)

assembler = VectorAssembler(
    inputCols=feature_cols.columns,
    outputCol="features"
)

In [74]:
layers = [187, 75, 75, 75, 5]

mlp = MultilayerPerceptronClassifier(
    maxIter=300, 
    layers=layers, 
    blockSize=128, 
    seed=42
)

pipeline = Pipeline(stages=[indexer, assembler, mlp])
mlpModel = pipeline.fit(train_set)

                                                                                

In [75]:
test_feature_cols = test_set.select(test_set.columns[:-1])
test_label_col = test_set.select(test_set.columns[-1])

In [82]:
predictions = mlpModel.transform(test_set)

In [85]:
predictions.columns[-5:]

['label', 'features', 'rawPrediction', 'probability', 'prediction']

In [90]:
predictionAndLabels = predictions.select("prediction", "label")

In [91]:
evaluator_f1 = MulticlassClassificationEvaluator(metricName="f1")

In [92]:
print("Test set f1 = " + str(evaluator.evaluate(predictionAndLabels)))

Test set f1 = 0.9713789511997215
