In [None]:
import os
from google.colab import drive
drive.mount('/content/drive')
# Find the latest version of spark 3.x  from http://www.apache.org/dist/spark/ and enter as the spark version
# For example:
# spark_version = 'spark-3.5.5'
spark_version = 'spark-3.5.5'
os.environ['SPARK_VERSION']=spark_version

# Install Spark and Java
!apt-get update
!apt-get install openjdk-11-jdk-headless -qq > /dev/null
!wget -q http://www.apache.org/dist/spark/$SPARK_VERSION/$SPARK_VERSION-bin-hadoop3.tgz
!tar xf $SPARK_VERSION-bin-hadoop3.tgz
!pip install -q findspark

# Set Environment Variables
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"
os.environ["SPARK_HOME"] = f"/content/{spark_version}-bin-hadoop3"

# Start a SparkSession
import findspark
findspark.init()


0% [Working]            Hit:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
0% [Waiting for headers] [Waiting for headers] [Connected to r2u.stat.illinois.                                                                               Hit:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
0% [Waiting for headers] [Waiting for headers] [Connected to r2u.stat.illinois.                                                                               Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
                                                                               Hit:4 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:5 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:6 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:7 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:8 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:9 https://ppa.l

In [None]:
# Import packages
from pyspark.sql import SparkSession
import time

# Create a SparkSession
spark = SparkSession.builder.appName("Mental Health Analysis").getOrCreate()


In [None]:
csv_path = "mental_health_dataset/Mental Health Dataset.csv"  # Adjust filename if different

df = spark.read.csv(csv_path, header=True, inferSchema=True)
df.show(5)
df.printSchema()


+-------------------+------+-------------+----------+-------------+--------------+---------+------------+--------------+--------------+---------------------+-----------+----------------+-------------+---------------+-----------------------+------------+
|          Timestamp|Gender|      Country|Occupation|self_employed|family_history|treatment|Days_Indoors|Growing_Stress|Changes_Habits|Mental_Health_History|Mood_Swings|Coping_Struggles|Work_Interest|Social_Weakness|mental_health_interview|care_options|
+-------------------+------+-------------+----------+-------------+--------------+---------+------------+--------------+--------------+---------------------+-----------+----------------+-------------+---------------+-----------------------+------------+
|2014-08-27 11:29:31|Female|United States| Corporate|         NULL|            No|      Yes|   1-14 days|           Yes|            No|                  Yes|     Medium|              No|           No|            Yes|                     N

# **Data Preprocessing**

Drop Timestamp - Not useful for modeling treatment

In [None]:
# Count the number of rows
print(f"Total rows: {df.count()}")

# Check for missing values in key columns
df.summary().show()

Total rows: 292364
+-------+------+-------------+----------+-------------+--------------+---------+------------------+--------------+--------------+---------------------+-----------+----------------+-------------+---------------+-----------------------+------------+
|summary|Gender|      Country|Occupation|self_employed|family_history|treatment|      Days_Indoors|Growing_Stress|Changes_Habits|Mental_Health_History|Mood_Swings|Coping_Struggles|Work_Interest|Social_Weakness|mental_health_interview|care_options|
+-------+------+-------------+----------+-------------+--------------+---------+------------------+--------------+--------------+---------------------+-----------+----------------+-------------+---------------+-----------------------+------------+
|  count|292364|       292364|    292364|       287162|        292364|   292364|            292364|        292364|        292364|               292364|     292364|          292364|       292364|         292364|                 292364|   

In [None]:
from pyspark.sql.functions import isnull, count, when

# Check for missing values in each column
df.select([count(when(isnull(c), c)).alias(c) for c in df.columns]).show()

+---------+------+-------+----------+-------------+--------------+---------+------------+--------------+--------------+---------------------+-----------+----------------+-------------+---------------+-----------------------+------------+
|Timestamp|Gender|Country|Occupation|self_employed|family_history|treatment|Days_Indoors|Growing_Stress|Changes_Habits|Mental_Health_History|Mood_Swings|Coping_Struggles|Work_Interest|Social_Weakness|mental_health_interview|care_options|
+---------+------+-------+----------+-------------+--------------+---------+------------+--------------+--------------+---------------------+-----------+----------------+-------------+---------------+-----------------------+------------+
|        0|     0|      0|         0|         5202|             0|        0|           0|             0|             0|                    0|          0|               0|            0|              0|                      0|           0|
+---------+------+-------+----------+-----------

In [None]:
df = df.dropna(subset=['self_employed'])
df.count()

287162

In [None]:
# from pyspark.sql.functions import count, col

# # Group by all columns and count occurrences
# duplicate_counts = df.groupBy(df.columns).agg(count("*").alias("count"))

# # Filter for rows with count > 1 (duplicates)
# duplicate_rows = duplicate_counts.filter(col("count") > 1)

# # Show the duplicate rows and their counts
# duplicate_rows.show()
df = df.dropDuplicates()
df.count()

# Optionally, you can drop the "count" column if you only need the duplicate rows:
# duplicate_rows = duplicate_rows.drop("count")
# duplicate_rows.show()

286808

In [None]:
df = df.drop("Timestamp")
df.count()

286808

# **1. Using LOGISTIC REGRESSION MODEL**

# **Define Target and Categorical Columns**

# LOGISTIC REGRESSION

In [None]:
target_col = "treatment"
categorical_cols = [col for col in df.columns if col != target_col]


# **Encode Categorical Variables**

In [None]:
from pyspark.ml.feature import StringIndexer, VectorAssembler

# Create StringIndexer for each categorical column
indexers = [
    StringIndexer(inputCol=column, outputCol=column + "_index", handleInvalid='keep')
    for column in categorical_cols + [target_col]
]


# **Assemble Features into One Vector**

In [None]:
assembler = VectorAssembler(
    inputCols=[col + "_index" for col in categorical_cols],
    outputCol="features"
)


# **1. Define the Logistic Regression Model**

In [None]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline
lr = LogisticRegression(labelCol="treatment_index", featuresCol="features")


# **Build a Pipeline**

In [None]:
pipeline = Pipeline(stages=indexers + [assembler, lr])

# **Split the Dataset**

In [None]:
train_data, test_data = df.randomSplit([0.8, 0.2], seed=42)

# **Train the Model**

In [None]:
model = pipeline.fit(train_data)


# **Make Predictions**

In [None]:
predictions = model.transform(test_data)


# **Evaluate Accuracy**

In [None]:
predictions.select("treatment_index").distinct().show()

+---------------+
|treatment_index|
+---------------+
|            0.0|
|            1.0|
+---------------+



# **Evaluate Accuracy**

In [None]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Use 'areaUnderROC' or 'areaUnderPR' for raw predictions
evaluator_acc = BinaryClassificationEvaluator(
    labelCol="treatment_index",
    rawPredictionCol="prediction",  # Using rawPredictionCol
    metricName="areaUnderROC"  # Changed metricName to 'areaUnderROC'
)

accuracy = evaluator_acc.evaluate(predictions)

# **R2 Score**

In [None]:
from pyspark.ml.evaluation import RegressionEvaluator # Import RegressionEvaluator

r2_evaluator = RegressionEvaluator(
    labelCol="treatment_index", predictionCol="prediction", metricName="r2"
)
r2_score = r2_evaluator.evaluate(predictions)

# **Final Results**

In [None]:
print(f"Logistic Regression Accuracy (AUC): {accuracy:.4f}")
print(f"R2 Score: {r2_score:.4f}")


Logistic Regression Accuracy (AUC): 0.6816
R2 Score: -0.2755


# **2. Using RANDOM FOREST MODEL**

In [None]:
#RANDOM fOREST
#  4. Convert target 'treatment' column to numerical
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, RandomForestClassificationModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline
from pyspark.ml.functions import vector_to_array
from pyspark.sql.functions import col
label_indexer = StringIndexer(inputCol="treatment", outputCol="label")

# 5. Identify categorical features
categorical_cols = [col for col in df.columns if col != "treatment"]

# Index categorical columns
indexers = [StringIndexer(inputCol=col, outputCol=col + "_Index", handleInvalid='keep') for col in categorical_cols]

# Prepare encoded column names
indexed_cols = [col + "_Index" for col in categorical_cols]
encoded_cols = [col + "_Vec" for col in indexed_cols]

# One-hot encode
encoder = OneHotEncoder(inputCols=indexed_cols, outputCols=encoded_cols)

# Assemble all features into one vector
assembler = VectorAssembler(inputCols=encoded_cols, outputCol="features")

# 8. Random Forest Classifier
rf = RandomForestClassifier(labelCol="label", featuresCol="features", numTrees=100)

# 9. Build pipeline
pipeline = Pipeline(stages=indexers + [encoder, assembler, label_indexer, rf])

# 10. Split data
train_data, test_data = df.randomSplit([0.8, 0.2], seed=42)

# 11. Train the model
model = pipeline.fit(train_data)

# 12. Make predictions
predictions = model.transform(test_data)

# 13. Evaluate Accuracy
accuracy_evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = accuracy_evaluator.evaluate(predictions)
print(f"Accuracy: {accuracy:.4f}")

# 14. Evaluate R2 Score using Binary Classification Evaluator (AUC is better, but for R2 approximation...)
from pyspark.ml.evaluation import RegressionEvaluator

# Convert vector to array and extract probability of class 1
preds_with_prob = predictions.withColumn("probability_array", vector_to_array("probability"))
preds_with_prob = preds_with_prob.withColumn("prob_class_1", col("probability_array")[1])
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType

# UDF to extract the second element (probability of class 1)
get_prob = udf(lambda v: float(v[1]), DoubleType())

# Apply it to extract the predicted probability for class 1
preds_with_prob = predictions.withColumn("prob_class_1", get_prob(predictions["probability"]))

reg_eval = RegressionEvaluator(labelCol="label", predictionCol="prob_class_1", metricName="r2")
r2 = reg_eval.evaluate(preds_with_prob)
print(f"R2 Score: {r2:.4f}")

from pyspark.ml.evaluation import BinaryClassificationEvaluator

auc_evaluator = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
auc = auc_evaluator.evaluate(predictions)
print(f"AUC: {auc:.4f}")


Accuracy: 0.7206
R2 Score: 0.2144
AUC: 0.7922


# **3. Using CATBOOST MODEL**

In [None]:
# CATBOOST
pip install catboost


Collecting catboost
  Downloading catboost-1.2.7-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting numpy<2.0,>=1.16.0 (from catboost)
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Downloading catboost-1.2.7-cp311-cp311-manylinux2014_x86_64.whl (98.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m98.7/98.7 MB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m80.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy, catboost
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mERROR: pip

In [None]:
!pip install --upgrade --force-reinstall numpy


Collecting numpy
  Downloading numpy-2.2.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/62.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-2.2.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.4/16.4 MB[0m [31m57.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
catboost 1.2.7 requires numpy<2.0,>=1.16.0, but you hav

In [None]:
!pip install --upgrade --force-reinstall catboost


Collecting catboost
  Using cached catboost-1.2.7-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting graphviz (from catboost)
  Downloading graphviz-0.20.3-py3-none-any.whl.metadata (12 kB)
Collecting matplotlib (from catboost)
  Downloading matplotlib-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting numpy<2.0,>=1.16.0 (from catboost)
  Using cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Collecting pandas>=0.24 (from catboost)
  Downloading pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.9/89.9 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scipy (from catboost)
  Downloading scipy-1.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m4.5 MB/s

In [None]:
from pyspark.sql.functions import pandas_udf, PandasUDFType, col as pyspark_col, when
from pyspark.sql.types import StringType  # Change to StringType
import catboost as cb
import numpy as np
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
import pandas as pd # Import pandas and alias it as 'pd'

categorical_columns = ['Gender', 'Country', 'Occupation', 'self_employed', 'family_history',
                       'Days_Indoors', 'Growing_Stress', 'Changes_Habits', 'Mental_Health_History',
                       'Mood_Swings', 'Coping_Struggles', 'Work_Interest', 'Social_Weakness',
                       'mental_health_interview', 'care_options']

# Define the target and feature columns
target_col = "treatment"
feature_cols = [col for col in categorical_columns if col != target_col]

# Index the categorical features and label
indexers = [StringIndexer(inputCol=column, outputCol=column+"_Index", handleInvalid="keep") for column in feature_cols]
label_indexer = StringIndexer(inputCol=target_col, outputCol="label", handleInvalid="keep")

# Create the pipeline for indexing
indexing_pipeline = Pipeline(stages=indexers + [label_indexer])
indexed_df = indexing_pipeline.fit(df).transform(df)

# Convert Spark DataFrame to Pandas DataFrame for CatBoost
pandas_df = indexed_df.select([target_col, *[col+"_Index" for col in feature_cols]]).toPandas()

# Cast all columns in the Pandas dataframe to int
# This is necessary to ensure CatBoost interprets them as categorical features
for col in pandas_df.columns:
    if col != target_col:  # Exclude target column
        pandas_df[col] = pandas_df[col].astype(int)

# Define CatBoost model
catboost_model = cb.CatBoostClassifier(iterations=100,  # Adjust parameters as needed
                                      learning_rate=0.1,
                                      depth=6,
                                      loss_function='MultiClass',
                                      verbose=False)

# Define features and target for CatBoost
X = pandas_df[[col+"_Index" for col in feature_cols]]
y = pandas_df[target_col]

# Fit the CatBoost model
catboost_model.fit(X, y, cat_features=list(X.columns))  # Specify categorical features
# assembler = VectorAssembler(inputCols = [col+"_Index" for col in feature_cols], outputCol = "features")
# assembled_df = assembler.transform(indexed_df)
# Create a Pandas UDF for prediction

# Create a Pandas UDF for prediction
@pandas_udf(returnType=StringType(), functionType=PandasUDFType.SCALAR)  # Change to StringType
def predict_udf(*cols) -> pd.Series:
    features = pd.DataFrame(list(zip(*cols)), columns=[col+"_Index" for col in feature_cols])
    for col in features.columns:
        features[col] = features[col].astype(int)
    predictions = catboost_model.predict(features)
    return pd.Series(predictions.flatten())

indexed_df = indexed_df.withColumn("prediction", predict_udf(
    *[indexed_df[col+"_Index"] for col in feature_cols]
))

from pyspark.sql.functions import pandas_udf, PandasUDFType, col, when
from pyspark.sql.types import StringType  # Change to StringType
import catboost as cb
import numpy as np
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
import pandas as pd # Import pandas and alias it as 'pd'

# Define the target and feature columns
target_col = "treatment"
feature_cols = [col for col in categorical_columns if col != target_col]

# Index the categorical features and label
indexers = [StringIndexer(inputCol=column, outputCol=column+"_Index", handleInvalid="keep") for column in feature_cols]
label_indexer = StringIndexer(inputCol=target_col, outputCol="label", handleInvalid="keep")

# Create the pipeline for indexing
indexing_pipeline = Pipeline(stages=indexers + [label_indexer])
indexed_df = indexing_pipeline.fit(df).transform(df)

# Convert Spark DataFrame to Pandas DataFrame for CatBoost
pandas_df = indexed_df.select([target_col, *[col+"_Index" for col in feature_cols]]).toPandas()

# Cast all columns in the Pandas dataframe to int
# This is necessary to ensure CatBoost interprets them as categorical features
for col in pandas_df.columns:
    if col != target_col:  # Exclude target column
        pandas_df[col] = pandas_df[col].astype(int)

# Define CatBoost model
catboost_model = cb.CatBoostClassifier(iterations=100,  # Adjust parameters as needed
                                      learning_rate=0.1,
                                      depth=6,
                                      loss_function='MultiClass',
                                      verbose=False)

# Define features and target for CatBoost
X = pandas_df[[col+"_Index" for col in feature_cols]]
y = pandas_df[target_col]

# Fit the CatBoost model
catboost_model.fit(X, y, cat_features=list(X.columns))  # Specify categorical features
# assembler = VectorAssembler(inputCols = [col+"_Index" for col in feature_cols], outputCol = "features")
# assembled_df = assembler.transform(indexed_df)
# Create a Pandas UDF for prediction

# Create a Pandas UDF for prediction
@pandas_udf(returnType=StringType(), functionType=PandasUDFType.SCALAR)  # Change to StringType
def predict_udf(*cols) -> pd.Series:
    features = pd.DataFrame(list(zip(*cols)), columns=[col+"_Index" for col in feature_cols])
    for col in features.columns:
        features[col] = features[col].astype(int)
    predictions = catboost_model.predict(features)
    return pd.Series(predictions.flatten()).astype(str)

indexed_df = indexed_df.withColumn("prediction", predict_udf(
    *[indexed_df[col+"_Index"] for col in feature_cols]
))


# # Evaluate the model (you might need to adjust this part based on your evaluation metrics)
accuracy = indexed_df.filter(pyspark_col("treatment").isNotNull() & pyspark_col("prediction").isNotNull()).withColumn("correct", when(pyspark_col("Growing_Stress") == pyspark_col("prediction"), 1).otherwise(0)).selectExpr("avg(correct) as accuracy").first()["accuracy"]
print(f"Accuracy: {accuracy}")
# # Show some predictions
indexed_df.select("treatment", "prediction").show(10)



Accuracy: 0.3337633538813422
+---------+----------+
|treatment|prediction|
+---------+----------+
|      Yes|       Yes|
|      Yes|       Yes|
|      Yes|       Yes|
|      Yes|       Yes|
|       No|       Yes|
|      Yes|       Yes|
|      Yes|       Yes|
|      Yes|       Yes|
|      Yes|       Yes|
|       No|       Yes|
+---------+----------+
only showing top 10 rows



# **4. Using XGBOOST Model**

In [None]:
# XGBOOST from pyspark.sql.functions import pandas_udf, PandasUDFType, col as pyspark_col, when
from pyspark.sql.types import StringType
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline
import pandas as pd
import xgboost as xgb
import numpy as np

# Define categorical columns
categorical_columns = ['Gender', 'Country', 'Occupation', 'self_employed', 'family_history',
                       'Days_Indoors', 'Growing_Stress', 'Changes_Habits', 'Mental_Health_History',
                       'Mood_Swings', 'Coping_Struggles', 'Work_Interest', 'Social_Weakness',
                       'mental_health_interview', 'care_options']

# Define target and features
target_col = "treatment"
feature_cols = [col for col in categorical_columns if col != target_col]

# Index categorical columns
indexers = [StringIndexer(inputCol=column, outputCol=column+"_Index", handleInvalid="keep") for column in feature_cols]
label_indexer = StringIndexer(inputCol=target_col, outputCol="label", handleInvalid="keep")

# Create pipeline
indexing_pipeline = Pipeline(stages=indexers + [label_indexer])
indexed_df = indexing_pipeline.fit(df).transform(df)

# Convert to pandas
pandas_df = indexed_df.select([target_col, *[col+"_Index" for col in feature_cols]]).toPandas()

# Ensure all features are numeric
for col in pandas_df.columns:
    if col != target_col:
        pandas_df[col] = pandas_df[col].astype(int)

# Prepare X and y
X = pandas_df[[col+"_Index" for col in feature_cols]]
y = pandas_df[target_col].astype('category').cat.codes  # Convert target to numeric if not already

# Train XGBoost model
xgb_model = xgb.XGBClassifier(n_estimators=100, learning_rate=0.1, max_depth=6, use_label_encoder=False, eval_metric='mlogloss')
xgb_model.fit(X, y)

# Predict function using pandas_udf
@pandas_udf(returnType=StringType(), functionType=PandasUDFType.SCALAR)
def predict_udf(*cols) -> pd.Series:
    features = pd.DataFrame(list(zip(*cols)), columns=[col+"_Index" for col in feature_cols])
    for col in features.columns:
        features[col] = features[col].astype(int)
    predictions = xgb_model.predict(features)
    return pd.Series(predictions.astype(str))

# Apply prediction to Spark DataFrame
indexed_df = indexed_df.withColumn("prediction", predict_udf(*[indexed_df[col+"_Index"] for col in feature_cols]))

# Evaluation (accuracy)
accuracy = indexed_df.filter(
    pyspark_col("treatment").isNotNull() & pyspark_col("prediction").isNotNull()
).withColumn(
    "correct", when(pyspark_col("label") == pyspark_col("prediction").cast("double"), 1).otherwise(0)
).selectExpr("avg(correct) as accuracy").first()["accuracy"]

print(f"Accuracy: {accuracy:.4f}")

# Show predictions
indexed_df.select("treatment", "prediction").show(10)


Parameters: { "use_label_encoder" } are not used.



Accuracy: 0.2163
+---------+----------+
|treatment|prediction|
+---------+----------+
|      Yes|         1|
|      Yes|         1|
|      Yes|         1|
|      Yes|         1|
|       No|         1|
|      Yes|         1|
|      Yes|         1|
|      Yes|         1|
|      Yes|         1|
|       No|         1|
+---------+----------+
only showing top 10 rows

