In [2]:
# !pip install pyspark
# !pip install findspark
import findspark
findspark.init()

In [3]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

from pyspark.sql.functions import isnull
from pyspark.sql.functions import col

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics

from sklearn.metrics import classification_report, confusion_matrix

In [4]:
sc.setLogLevel('ERROR')

In [5]:
TRAIN_PATH = '../data/mitbih_train.csv'
TEST_PATH = '../data/mitbih_test.csv'

In [6]:
train_set = spark.read.csv(TRAIN_PATH, header=False, inferSchema=True).cache()
test_set = spark.read.csv(TEST_PATH, header=False, inferSchema=True).cache()

In [14]:
# check for columns with at least one null value
assert len([x for x in train_set.columns if train_set.filter(col(x).isNull()).count() > 0]) == 0
assert len([x for x in test_set.columns if test_set.filter(col(x).isNull()).count() > 0]) == 0

In [7]:
print('Training Data Shape: ', train_set.count(), len(train_set.columns))
print('Test Data Shape: ', test_set.count(), len(test_set.columns))

[Stage 4:>                                                          (0 + 1) / 1]                                                                                

Training Data Shape:  100 188
Test Data Shape:  100 188


In [8]:
# class balance
train_set.groupBy(train_set.columns[187]).count().orderBy('count', ascending=False).show()

+-----+-----+
|_c187|count|
+-----+-----+
|  0.0|  100|
+-----+-----+



In [17]:
# Classes: ['N': 0, 'S': 1, 'V': 2, 'F': 3, 'Q': 4]
# 0=normal (N), 1=SVEB (S), 2=VEB (V), 3=Fusion beat (F), 4=Unknown beat (Q)

"""
• N: Normal beat
• S: Supraventricular premature beat
• V: Premature ventricular contraction
• F: Fusion of ventricular and normal beat
• Q: Unclassifiable beat

N
Normal
Left/Right bundle branch block
Atrial escape
Nodal escape

S
Atrial premature
Aberrant atrial premature
Nodal premature
Supra-ventricular premature

V
Premature ventricular contraction
Ventricular escape

F
Fusion of ventricular and normal

Q
Paced
Fusion of paced and normal
Unclassifiable
""";

In [18]:
feature_cols = train_set.select(train_set.columns[:-1])
label_col = train_set.select(train_set.columns[-1])

In [19]:
indexer = StringIndexer(
    inputCol='_c187', 
    outputCol="label"
)

assembler = VectorAssembler(
    inputCols=feature_cols.columns,
    outputCol="features"
)

In [20]:
layers = [187, 75, 75, 75, 5]

mlp = MultilayerPerceptronClassifier(
    maxIter=300, 
    layers=layers, 
    blockSize=128, 
    seed=42
)

pipeline = Pipeline(stages=[indexer, assembler, mlp])
mlpModel = pipeline.fit(train_set)

                                                                                

In [21]:
test_feature_cols = test_set.select(test_set.columns[:-1])
test_label_col = test_set.select(test_set.columns[-1])

In [22]:
predictions = mlpModel.transform(test_set)

In [23]:
predictions.columns[-5:]

['label', 'features', 'rawPrediction', 'probability', 'prediction']

In [24]:
predictionAndLabels = predictions.select("prediction", "label")

In [25]:
evaluator_f1 = MulticlassClassificationEvaluator(metricName="f1")

In [32]:
print("Test set f1 =", str(evaluator_f1.evaluate(predictionAndLabels)))

Test set f1 = 0.9713789511997215


In [29]:
y_true = predictions.select(['label']).collect()
y_pred = predictions.select(['prediction']).collect()

                                                                                

In [30]:
print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

         0.0       0.98      0.99      0.99     18118
         1.0       0.97      0.96      0.97      1608
         2.0       0.93      0.90      0.92      1448
         3.0       0.90      0.62      0.74       556
         4.0       0.78      0.64      0.70       162

    accuracy                           0.97     21892
   macro avg       0.91      0.82      0.86     21892
weighted avg       0.97      0.97      0.97     21892

