<a href="https://colab.research.google.com/github/lab-jianghao/spark_ml_sample/blob/main/02_rf_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!apt-get install openjdk-17-jdk-headless

!wget https://dlcdn.apache.org/spark/spark-3.5.0/spark-3.5.0-bin-hadoop3.tgz /content
!tar xf spark-3.5.0-bin-hadoop3.tgz

In [2]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-17-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.5.0-bin-hadoop3"

In [None]:
!pip install pyspark==3.5.0

In [4]:
from pyspark.sql import SparkSession

spark = SparkSession.builder\
        .master("local[*]")\
        .appName("Colab")\
        .getOrCreate()

In [5]:
from functools import wraps

def spark_sql_initializer(func):
    @wraps(func)
    def wrapper(*args, **kwargs):

        spark = SparkSession.builder\
            .appName("Colab_DT")\
            .master("local[*]")\
            .getOrCreate()

        spark.sparkContext.setLogLevel("WARN")

        func(spark,*args, **kwargs)

        spark.stop()

    return wrapper

In [6]:
import seaborn as sns

penguins = sns.load_dataset("penguins")
penguins.head()

Unnamed: 0,species,island,bill_length_mm,bill_depth_mm,flipper_length_mm,body_mass_g,sex
0,Adelie,Torgersen,39.1,18.7,181.0,3750.0,Male
1,Adelie,Torgersen,39.5,17.4,186.0,3800.0,Female
2,Adelie,Torgersen,40.3,18.0,195.0,3250.0,Female
3,Adelie,Torgersen,,,,,
4,Adelie,Torgersen,36.7,19.3,193.0,3450.0,Female


In [18]:
from functools import reduce

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import Imputer, StringIndexer, VectorAssembler, IndexToString
from pyspark.ml.evaluation import MulticlassClassificationEvaluator


from pyspark.sql.functions import col

@spark_sql_initializer
def train_with_NaN(spark, df):
    penguins_df = spark.createDataFrame(df)
    # penguins_df.show()

    index_columns = penguins_df.columns[:6]
    penguins_df = penguins_df.withColumnRenamed("sex", "label")

    feature_indexes = [StringIndexer(inputCol=col, outputCol=f"indexed_{col}", handleInvalid="keep") for col in index_columns[:2]]
    for indexer in feature_indexes:
        penguins_df = indexer.fit(penguins_df).transform(penguins_df)

    feature_imputers = [Imputer(inputCol=col, outputCol=f"indexed_{col}", strategy="median") for col in index_columns[2:]]
    for imputer in feature_imputers:
        penguins_df = imputer.fit(penguins_df).transform(penguins_df)

    vector_assembler = VectorAssembler(
        inputCols=list(map(lambda idx_col: f"indexed_{idx_col}", index_columns)), outputCol="features")
    # penguins_df = vector_assembler.transform(penguins_df)
    penguins_df = penguins_df.drop(*list(map(lambda idx_col: f"indexed_{idx_col}", index_columns)))
    penguins_df.show()

    label_indexer = StringIndexer(inputCol="label", outputCol="indexed_label").fit(penguins_df)
    label_converter = IndexToString(inputCol="prediction", outputCol="predictedLabel", labels=label_indexer.labels)

    rf_classifier = RandomForestClassifier(labelCol="indexed_label", featuresCol="features", numTrees=10)

    training_data, test_data = penguins_df.randomSplit([0.7, 0.3])

    dt_pipeline = Pipeline(
        stages=feature_indexes + feature_imputers + [vector_assembler, label_indexer, rf_classifier, label_converter])
    dt_model = dt_pipeline.fit(training_data)

    dt_prediction = dt_model.transform(test_data)
    evaluator = MulticlassClassificationEvaluator(
        labelCol="indexed_label", predictionCol="prediction", metricName="accuracy")

    print("Test Error = %g " % (1.0 - evaluator.evaluate(dt_prediction)))

    dt_prediction.select(*index_columns, "label", "predictedLabel").show()

    dt_model.write().overwrite().save("file:///content/model/RandomForest")


In [19]:
train_with_NaN(penguins)

+-------+---------+--------------+-------------+-----------------+-----------+------+
|species|   island|bill_length_mm|bill_depth_mm|flipper_length_mm|body_mass_g| label|
+-------+---------+--------------+-------------+-----------------+-----------+------+
| Adelie|Torgersen|          39.1|         18.7|            181.0|     3750.0|  Male|
| Adelie|Torgersen|          39.5|         17.4|            186.0|     3800.0|Female|
| Adelie|Torgersen|          40.3|         18.0|            195.0|     3250.0|Female|
| Adelie|Torgersen|           NaN|          NaN|              NaN|        NaN|   NaN|
| Adelie|Torgersen|          36.7|         19.3|            193.0|     3450.0|Female|
| Adelie|Torgersen|          39.3|         20.6|            190.0|     3650.0|  Male|
| Adelie|Torgersen|          38.9|         17.8|            181.0|     3625.0|Female|
| Adelie|Torgersen|          39.2|         19.6|            195.0|     4675.0|  Male|
| Adelie|Torgersen|          34.1|         18.1|      