In [1]:
import os

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml import Pipeline

# Reuse the active Spark session if one exists; otherwise start "HeartDisease".
spark = SparkSession.builder.appName("HeartDisease").getOrCreate()

# Step 1: Load processed.cleveland.data. The file has no header row, so the
# 14 column names below are assigned positionally with toDF().
columns = [
    "age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
    "thalach", "exang", "oldpeak", "slope", "ca", "thal", "target"
]

# Avoid depending on one machine's directory layout: set HEART_DATA_PATH to
# point at the file (e.g. a renamed "cleveland.csv"); otherwise the original
# location is used.
file_p = os.environ.get(
    "HEART_DATA_PATH",
    "/Users/mehdiamian/Desktop/Sohrab/heart+disease/processed.cleveland.data",
)

# inferSchema types numeric columns as numbers; columns that contain the "?"
# missing-value placeholder (cleaned in Step 2) are inferred as strings.
df = spark.read.csv(file_p, inferSchema=True)
df = df.toDF(*columns)

# Step 2: Drop rows whose `ca` or `thal` value is the "?" missing-value
# placeholder. Rows are filtered out, not converted to null. Rows where the
# value is already null are dropped as well (a null compared with "?" is null,
# which a filter treats as false); isNotNull() states that explicitly.
for col_name in ["ca", "thal"]:
    # Cast to string first: if Spark inferred a numeric type, comparing it
    # with "?" would coerce "?" to null and the filter would drop every row.
    df = df.withColumn(col_name, col(col_name).cast("string"))
    df = df.filter(col(col_name).isNotNull() & (col(col_name) != "?"))

# Convert the cleaned columns back to a numeric type for VectorAssembler.
for col_name in ["ca", "thal"]:
    df = df.withColumn(col_name, col(col_name).cast("float"))

# Binary label: 0 = no disease, target >= 1 = disease.
df = df.withColumn("label", (col("target") > 0).cast("integer"))

# Step 3: Assemble features.
# Every raw column except the `target` is a predictor. Deriving the list from
# `columns` keeps the same 13 names in the same order as the loaded data.
feature_cols = [name for name in columns if name != "target"]

assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
final_df = (
    assembler
    .transform(df)
    .select("features", "label")
)

# Step 4: Hold out 20% of rows for testing, using a fixed split seed.
train_df, test_df = final_df.randomSplit([0.8, 0.2], seed=42)

# Step 5: Fit a logistic-regression classifier on the training split.
lr = LogisticRegression(labelCol="label", featuresCol="features")
model = lr.fit(train_df)

# Step 6: Score the held-out split and report the area under the ROC curve.
predictions = model.transform(test_df)
evaluator = BinaryClassificationEvaluator(
    metricName="areaUnderROC",
    labelCol="label",
)
auc = evaluator.evaluate(predictions)
print(f"AUC on test set: {auc:.4f}")

# Optional: peek at a few predictions next to their class probabilities.
predictions.select("label", "prediction", "probability").show(n=5)


Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/10/26 18:08:27 WARN Utils: Your hostname, Mehdis-MacBook-Pro.local, resolves to a loopback address: 127.0.0.1; using 192.168.1.101 instead (on interface en0)
25/10/26 18:08:27 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/10/26 18:08:32 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/10/26 18:08:40 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
25/10/26 18:08:42 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS


AUC on test set: 0.8441
+-----+----------+--------------------+
|label|prediction|         probability|
+-----+----------+--------------------+
|    0|       0.0|[0.93402440734486...|
|    0|       0.0|[0.99646799893433...|
|    0|       0.0|[0.96521858609493...|
|    0|       0.0|[0.97931070942687...|
|    1|       0.0|[0.60724918287314...|
+-----+----------+--------------------+
only showing top 5 rows


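Next, a more flexible classifier — a random forest — is trained on the same train/test split and evaluated with the same AUC metric for a direct comparison with logistic regression.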

In [2]:
from pyspark.ml.classification import RandomForestClassifier

# Random forest trained on the same train/test split as the logistic regression
# and scored with the same AUC evaluator, so the two numbers are comparable.
# An explicit seed (matching the split seed) makes the bootstrap sampling and
# feature subsetting reproducible. PySpark's default seed is derived from
# Python's hash of the class name, which changes between interpreter sessions
# unless PYTHONHASHSEED is fixed.
rf = RandomForestClassifier(
    featuresCol="features", labelCol="label", numTrees=100, seed=42
)
rf_model = rf.fit(train_df)

rf_predictions = rf_model.transform(test_df)

auc_rf = evaluator.evaluate(rf_predictions)
print(f"AUC (Random Forest) on test set: {auc_rf:.4f}")


AUC (Random Forest) on test set: 0.8460


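Computing the accuracy of the logistic-regression predictions. `MulticlassClassificationEvaluator` is used because it provides the accuracy metric, and it also handles binary labels. The task itself is still binary; extending it to multi-class would mean using the original multi-valued `target` column as the label.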

In [3]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Accuracy of the logistic-regression `predictions`. The "multiclass" evaluator
# is used only because it exposes the accuracy metric; the label is still the
# binary 0/1 column.
accuracy_evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="label",
)
acc = accuracy_evaluator.evaluate(predictions)
print(f"Accuracy: {acc:.4f}")


Accuracy: 0.7826


In [4]:
from pyspark.sql import SparkSession

# Smoke test: obtain a Spark session and display ids 0-4. If a session is
# already running it is reused, and Spark warns that only runtime SQL
# configurations take effect.
session_builder = SparkSession.builder.appName("Test")
spark = session_builder.getOrCreate()
spark.range(5).show()


+---+
| id|
+---+
|  0|
|  1|
|  2|
|  3|
|  4|
+---+



25/10/26 18:10:24 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [5]:
from pyspark.sql import SparkSession

# Smoke test: get (or reuse) a Spark session named "Test" and show ids 0-4.
spark = (
    SparkSession.builder
    .appName("Test")
    .getOrCreate()
)

spark.range(5).show()


+---+
| id|
+---+
|  0|
|  1|
|  2|
|  3|
|  4|
+---+



In [8]:
# Disabled alternative data source, kept for reference: downloads a
# heart-disease CSV from GitHub to /tmp/heart.csv and reads it with a header row.
# NOTE(review): if re-enabled, this reassigns `df`, replacing whatever it holds
# at that point. That CSV's column layout presumably differs from the headerless
# 14-column Cleveland file loaded elsewhere — verify before reusing later cells.
# import urllib.request

# url = "https://raw.githubusercontent.com/datasciencedojo/datasets/master/Heart%20Disease/heart-disease.csv"
# local_file = "/tmp/heart.csv"

# urllib.request.urlretrieve(url, local_file)

# # Read with Spark
# df = spark.read.csv(local_file, header=True, inferSchema=True)
# df.show(5)


In [7]:
# Reload the raw Cleveland file (no "?" cleaning, no label) and preview it.
# Note: this rebinds `df` to the raw data.
columns = [
    "age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
    "thalach", "exang", "oldpeak", "slope", "ca", "thal", "target",
]

# Point this at a renamed copy (e.g. "cleveland.csv") if needed.
file_path = "/Users/mehdiamian/Desktop/Sohrab/heart+disease/processed.cleveland.data"

df = (
    spark.read
    .csv(file_path, inferSchema=True)
    .toDF(*columns)
)

df.show(5)


+----+---+---+--------+-----+---+-------+-------+-----+-------+-----+---+----+------+
| age|sex| cp|trestbps| chol|fbs|restecg|thalach|exang|oldpeak|slope| ca|thal|target|
+----+---+---+--------+-----+---+-------+-------+-----+-------+-----+---+----+------+
|63.0|1.0|1.0|   145.0|233.0|1.0|    2.0|  150.0|  0.0|    2.3|  3.0|0.0| 6.0|     0|
|67.0|1.0|4.0|   160.0|286.0|0.0|    2.0|  108.0|  1.0|    1.5|  2.0|3.0| 3.0|     2|
|67.0|1.0|4.0|   120.0|229.0|0.0|    2.0|  129.0|  1.0|    2.6|  2.0|2.0| 7.0|     1|
|37.0|1.0|3.0|   130.0|250.0|0.0|    0.0|  187.0|  0.0|    3.5|  3.0|0.0| 3.0|     0|
|41.0|0.0|2.0|   130.0|204.0|0.0|    2.0|  172.0|  0.0|    1.4|  1.0|0.0| 3.0|     0|
+----+---+---+--------+-----+---+-------+-------+-----+-------+-----+---+----+------+
only showing top 5 rows
