# Lab 06: Prediction with Decision Trees

# Q1
### Loading Dataset

In [7]:
from pyspark.sql import SparkSession

# Create a Spark session
spark = SparkSession.builder.appName("IncomePrediction").getOrCreate()

# Load dataset into PySpark DataFrame
df = spark.read.csv("adult.csv", header=False, inferSchema=True)

# Rename columns for better understanding
df = df.withColumnRenamed("_c0", "age")\
       .withColumnRenamed("_c1", "workclass")\
       .withColumnRenamed("_c2", "fnlwgt")\
       .withColumnRenamed("_c3", "education")\
       .withColumnRenamed("_c4", "education_num")\
       .withColumnRenamed("_c5", "marital_status")\
       .withColumnRenamed("_c6", "occupation")\
       .withColumnRenamed("_c7", "relationship")\
       .withColumnRenamed("_c8", "race")\
       .withColumnRenamed("_c9", "sex")\
       .withColumnRenamed("_c10", "capital_gain")\
       .withColumnRenamed("_c11", "capital_loss")\
       .withColumnRenamed("_c12", "hours_per_week")\
       .withColumnRenamed("_c13", "native_country")\
       .withColumnRenamed("_c14", "income")

# Show basic statistics and information
df.printSchema()
df.describe().show()

root
 |-- age: integer (nullable = true)
 |-- workclass: string (nullable = true)
 |-- fnlwgt: double (nullable = true)
 |-- education: string (nullable = true)
 |-- education_num: double (nullable = true)
 |-- marital_status: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- relationship: string (nullable = true)
 |-- race: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- capital_gain: double (nullable = true)
 |-- capital_loss: double (nullable = true)
 |-- hours_per_week: double (nullable = true)
 |-- native_country: string (nullable = true)
 |-- income: string (nullable = true)

+-------+------------------+------------+------------------+-------------+-----------------+--------------+-----------------+------------+-------------------+-------+------------------+----------------+------------------+--------------+------+
|summary|               age|   workclass|            fnlwgt|    education|    education_num|marital_status|       occupation|rel

# Q2
### Checking Missing Values

In [8]:
from pyspark.sql.functions import col
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline

# Handling missing values
# Dropping rows with missing values
df_clean = df.dropna()

# Handling categorical features using StringIndexer
categorical_columns = ['workclass', 'education', 'marital_status', 'occupation', 'relationship', 'race', 'sex', 'native_country', 'income']
indexers = [StringIndexer(inputCol=column, outputCol=column+"_index") for column in categorical_columns]

# Assemble features into a feature vector
assembler = VectorAssembler(
    inputCols=['age', 'fnlwgt', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week'] + [col+"_index" for col in categorical_columns[:-1]],
    outputCol='features'
)

# Create a pipeline for preprocessing
pipeline = Pipeline(stages=indexers + [assembler])
df_prepared = pipeline.fit(df_clean).transform(df_clean)

# Show the prepared dataset with features column
df_prepared.select("features", "income_index").show(5)


+--------------------+------------+
|            features|income_index|
+--------------------+------------+
|[39.0,77516.0,13....|         0.0|
|(14,[0,1,2,5,6,7,...|         0.0|
|(14,[0,1,2,5,8,9,...|         0.0|
|(14,[0,1,2,5,7,9,...|         0.0|
|[28.0,338409.0,13...|         0.0|
+--------------------+------------+
only showing top 5 rows



# Q3
### Training Model

In [9]:
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Split the data into training and testing sets
train_data, test_data = df_prepared.randomSplit([0.8, 0.2], seed=42)

# Create and train the Decision Tree model
dt = DecisionTreeClassifier(labelCol="income_index", featuresCol="features", maxBins=45)
dt_model = dt.fit(train_data)

# Make predictions on the test data
predictions = dt_model.transform(test_data)

# Show sample predictions
predictions.select("features", "income_index", "prediction").show(5)


+--------------------+------------+----------+
|            features|income_index|prediction|
+--------------------+------------+----------+
|[17.0,41643.0,7.0...|         0.0|       0.0|
|[17.0,64785.0,6.0...|         0.0|       0.0|
|[17.0,80077.0,7.0...|         0.0|       0.0|
|[17.0,104025.0,7....|         0.0|       0.0|
|[17.0,139183.0,6....|         0.0|       0.0|
+--------------------+------------+----------+
only showing top 5 rows



# Q4
### Performance Check

In [10]:
# Calculate accuracy
accuracy_evaluator = MulticlassClassificationEvaluator(labelCol="income_index", predictionCol="prediction", metricName="accuracy")
accuracy = accuracy_evaluator.evaluate(predictions)

# Calculate precision
precision_evaluator = MulticlassClassificationEvaluator(labelCol="income_index", predictionCol="prediction", metricName="weightedPrecision")
precision = precision_evaluator.evaluate(predictions)

# Calculate recall
recall_evaluator = MulticlassClassificationEvaluator(labelCol="income_index", predictionCol="prediction", metricName="weightedRecall")
recall = recall_evaluator.evaluate(predictions)

# Print accuracy, precision, and recall
print(f"Test Accuracy: {accuracy:.2f}")
print(f"Test Precision: {precision:.2f}")
print(f"Test Recall: {recall:.2f}")


Test Accuracy: 0.84
Test Precision: 0.83
Test Recall: 0.84
