# Training the Decision Tree Model

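We train a multiclass Decision Tree Classifier with PySpark to predict a patient's drug class (the `Drug` column) from their age, sex, blood pressure (`BP`), cholesterol level, and sodium-to-potassium ratio (`Na_to_K`).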


In [10]:
from pyspark.ml.feature import StringIndexer

# (input column, output column) for every string column that needs a numeric index.
# The first three are categorical features; "Drug" is the target and becomes "label".
# No stringOrderType / handleInvalid is passed, so StringIndexer's defaults apply.
indexer_specs = [
    ("Sex", "Sex_idx"),
    ("BP", "BP_idx"),
    ("Cholesterol", "Cholesterol_idx"),
    ("Drug", "label"),
]
sex_indexer, bp_indexer, chol_indexer, label_indexer = (
    StringIndexer(inputCol=source_col, outputCol=index_col)
    for source_col, index_col in indexer_specs
)


In [15]:
from pyspark.ml.feature import VectorAssembler

# Model inputs: the two numeric columns as-is, plus the StringIndexer outputs
# (*_idx) for the categorical columns. Order here is the order in the vector.
FEATURE_COLS = ["Age", "Na_to_K", "Sex_idx", "BP_idx", "Cholesterol_idx"]

# Pack the inputs into the single vector column "features" that Spark ML estimators expect.
assembler = VectorAssembler(outputCol="features", inputCols=FEATURE_COLS)


In [19]:
from pyspark.ml import Pipeline

# Preprocessing pipeline: index the categorical columns and the target ("Drug" -> "label"),
# then assemble the feature vector.
pipeline = Pipeline(stages=[sex_indexer, bp_indexer, chol_indexer, label_indexer, assembler])

# Keep the fitted PipelineModel instead of discarding it: it holds the learned
# category -> index mappings (including label index -> Drug value), which are needed
# to encode new rows the same way and to map the tree's numeric predictions back to drugs.
# TODO: persist `preprocessing_model` alongside the trained tree.
# NOTE(review): the indexers are fit on the full `df`, before the train/test split;
# fitting on the training split only would be the stricter setup.
preprocessing_model = pipeline.fit(df)
data = preprocessing_model.transform(df)

# Show processed features
data.select("features", "label").show(5, truncate=False)


+-------------------------+-----+
|features                 |label|
+-------------------------+-----+
|[23.0,25.355,1.0,0.0,0.0]|0.0  |
|[47.0,13.093,0.0,1.0,0.0]|4.0  |
|[47.0,10.114,0.0,1.0,0.0]|4.0  |
|[28.0,7.798,1.0,2.0,0.0] |1.0  |
|[61.0,18.043,1.0,1.0,0.0]|0.0  |
+-------------------------+-----+
only showing top 5 rows



In [22]:
# Hold out 30% of rows for evaluation; the fixed seed keeps the split
# reproducible across Restart & Run All.
TRAIN_FRACTION, TEST_FRACTION = 0.7, 0.3
train_data, test_data = data.randomSplit([TRAIN_FRACTION, TEST_FRACTION], seed=42)

# Peek at a couple of training rows to confirm the indexed and assembled columns.
train_data.show(n=2)

+---+---+----+-----------+-------+-----+-------+------+---------------+-----+--------------------+
|Age|Sex|  BP|Cholesterol|Na_to_K| Drug|Sex_idx|BP_idx|Cholesterol_idx|label|            features|
+---+---+----+-----------+-------+-----+-------+------+---------------+-----+--------------------+
| 15|  F|HIGH|     NORMAL| 16.725|drugY|    1.0|   0.0|            1.0|  0.0|[15.0,16.725,1.0,...|
| 15|  M|HIGH|     NORMAL| 17.206|drugY|    0.0|   0.0|            1.0|  0.0|[15.0,17.206,0.0,...|
+---+---+----+-----------+-------+-----+-------+------+---------------+-----+--------------------+
only showing top 2 rows



In [23]:
from pyspark.ml.classification import DecisionTreeClassifier

# Single decision tree on the assembled feature vector; only the input/target
# columns are set, every other hyperparameter is left at Spark's default.
dt = DecisionTreeClassifier(labelCol="label", featuresCol="features")

# Train on the training split only.
model = dt.fit(dataset=train_data)


In [24]:
# Persist the fitted tree. Plain `model.save(path)` raises if the path already
# exists; `write().overwrite()` replaces the previous run's model, so re-running
# the notebook does not fail here.
# NOTE(review): only the tree is saved; the fitted StringIndexer mappings
# (feature encodings and label -> Drug value) are not persisted by this cell.
model.write().overwrite().save("../models/decision_tree_model")


### Summary

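- Encoded the categorical columns (`Sex`, `BP`, `Cholesterol`) and the target (`Drug`) with `StringIndexer`.
- Assembled `Age`, `Na_to_K`, and the indexed columns into a single `features` vector.
- Split the data 70/30 (seed 42) and trained a Decision Tree on the training portion; the remaining 30% is held out as `test_data`.
- Saved the trained model to `../models/decision_tree_model`.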
