# Random Forest

In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.classification import RandomForestClassifier  # Import RandomForestClassifier
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.functions import when

# Fixed seed shared by the train/test split and the forest itself, so the
# reported accuracy can be reproduced from one run to the next.
SEED = 42

# Create a Spark session
spark = SparkSession.builder.appName("Random Forest Classifier").getOrCreate()

# Everything that uses the session lives inside try/finally so the session is
# stopped even when loading, fitting or evaluation raises.
try:
    # Load the dataset (first row is the header; column types are inferred)
    data = spark.read.csv(r"C:\Users\asus\Downloads\Ecommerce_Customers.csv", header=True, inferSchema=True)

    # Create a binary label column ('Churn') based on 'Yearly Amount Spent':
    # 1 when the customer spent less than 500, otherwise 0
    data = data.withColumn("Churn", when(data["Yearly Amount Spent"] < 500, 1).otherwise(0))

    # Assemble the four numeric input columns into a single 'features' vector
    assembler = VectorAssembler(
        inputCols=['Avg Session Length', 'Time on App', 'Time on Website', 'Length of Membership'],
        outputCol='features'
    )

    # Prepare the features and label (target column for classification)
    final_data = assembler.transform(data).select('features', 'Churn')

    # Split the data into training (70%) and test (30%) sets.
    # The seed makes the split repeatable.
    train_data, test_data = final_data.randomSplit([0.7, 0.3], seed=SEED)

    # Initialize Random Forest Classifier model; the seed pins the forest's
    # internal random sampling.
    rf = RandomForestClassifier(labelCol='Churn', featuresCol='features', seed=SEED)

    # Fit the model
    rf_model = rf.fit(train_data)

    # Make predictions on the held-out test set
    pred_data = rf_model.transform(test_data)

    # Evaluate the accuracy
    evaluator = MulticlassClassificationEvaluator(labelCol='Churn', predictionCol='prediction', metricName='accuracy')
    accuracy = evaluator.evaluate(pred_data)
    print("Random Forest Model Accuracy:", accuracy)

    # To see predictions
    pred_data.select("features", "Churn", "prediction").show(5)
finally:
    # Stop Spark session (runs on success and on failure)
    spark.stop()


Random Forest Model Accuracy: 0.9197080291970803
+--------------------+-----+----------+
|            features|Churn|prediction|
+--------------------+-----+----------+
|[30.8364326747734...|    1|       1.0|
|[30.9716756438877...|    1|       0.0|
|[31.0472221394875...|    1|       1.0|
|[31.0662181616375...|    1|       1.0|
|[31.3123495994443...|    1|       1.0|
+--------------------+-----+----------+
only showing top 5 rows

