<a href="https://colab.research.google.com/github/lab-jianghao/spark_ml_sample/blob/main/03_(TF_INF)_nb_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!apt-get install openjdk-17-jdk-headless

!wget https://dlcdn.apache.org/spark/spark-3.5.0/spark-3.5.0-bin-hadoop3.tgz /content
!tar xf spark-3.5.0-bin-hadoop3.tgz

In [29]:
import os
# Export the JDK and Spark distribution locations for this process and
# any child processes it starts.
os.environ.update(
    JAVA_HOME="/usr/lib/jvm/java-17-openjdk-amd64",
    SPARK_HOME="/content/spark-3.5.0-bin-hadoop3",
)

In [None]:
!pip install pyspark==3.5.0

In [31]:
from pyspark.sql import SparkSession

# Create (or reuse) a Spark session named "Colab" running in local mode
# with the "local[*]" master.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("Colab")
    .getOrCreate()
)

In [32]:
from functools import wraps

def spark_sql_initializer(func):
    """Decorator that supplies a local Spark session as ``func``'s first argument.

    The decorated callable is invoked as ``func(spark, *args, **kwargs)``;
    callers pass only ``*args`` / ``**kwargs``.

    Args:
        func: Callable whose first positional parameter receives the session.

    Returns:
        A wrapper that returns whatever ``func`` returns. The session is
        stopped after ``func`` finishes, whether it returns or raises.

    Note:
        The session comes from ``getOrCreate()``, so an already-running session
        may be reused instead of a new one being created; it is stopped on exit
        either way.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        spark = (
            SparkSession.builder
            .appName("Colab_DT")
            .master("local[*]")
            .getOrCreate()
        )

        try:
            # Reduce log noise in notebook output.
            spark.sparkContext.setLogLevel("WARN")

            # Propagate the wrapped function's result to the caller.
            return func(spark, *args, **kwargs)
        finally:
            # Always release the session, even when func raises.
            spark.stop()

    return wrapper

In [73]:
from pyspark.sql.functions import substring

from pyspark.ml import Pipeline
from pyspark.ml.feature import Tokenizer, HashingTF, IDF
from pyspark.ml.classification import NaiveBayes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator


@spark_sql_initializer
def train(spark, train_df, test_df):
    """Fit a TF-IDF + Naive Bayes text classifier and print its test accuracy.

    Args:
        spark: Spark session, supplied by ``spark_sql_initializer``.
        train_df: Training data accepted by ``spark.createDataFrame``; the
            pipeline reads its ``text`` and ``label`` columns.
        test_df: Held-out data with the same columns, used for evaluation.

    Prints the accuracy on ``test_df`` and shows a preview of the predictions
    with each text cut to its first 100 characters.
    """
    # Move both inputs onto the Spark session (training frame first).
    train_sdf, test_sdf = (
        spark.createDataFrame(frame) for frame in (train_df, test_df)
    )

    # tokenize -> hashed term frequencies -> IDF re-weighting -> Naive Bayes
    stages = [
        Tokenizer(inputCol="text", outputCol="words"),
        HashingTF(inputCol="words", outputCol="features"),
        IDF(inputCol="features", outputCol="indexed_features"),
        NaiveBayes(
            featuresCol="indexed_features",
            labelCol="label",
            predictionCol="prediction",
        ),
    ]
    fitted_pipeline = Pipeline(stages=stages).fit(train_sdf)

    scored = fitted_pipeline.transform(test_sdf)
    accuracy = MulticlassClassificationEvaluator(
        labelCol="label",
        predictionCol="prediction",
        metricName="accuracy",
    ).evaluate(scored)
    print("Test Accuracy = {:.2%}".format(accuracy))

    # Keep only the first 100 characters of each document for a readable table.
    preview = scored.withColumn("text_truncated", substring("text", 1, 100))
    preview.select("text_truncated", "label", "prediction").show(truncate=False)


In [74]:
from sklearn.datasets import fetch_20newsgroups

import pandas as pd


def get_data_and_labels(subset):
    """Fetch one 20 Newsgroups split via ``fetch_20newsgroups(return_X_y=True)``.

    Headers and footers are removed from every post.
    """
    return fetch_20newsgroups(
        subset=subset,
        remove=('headers', 'footers'),
        return_X_y=True,
    )


train_data, train_labels = get_data_and_labels('train')
test_data, test_labels = get_data_and_labels('test')

# Two-column frames (text, label): the layout the training pipeline reads.
train_df = pd.DataFrame(dict(text=train_data, label=train_labels))
test_df = pd.DataFrame(dict(text=test_data, label=test_labels))


train(train_df, test_df)

Test Accuracy = 75.09%
+-----------------------------------------------------------------------------------------------------------+-----+----------+
|text_truncated                                                                                             |label|prediction|
+-----------------------------------------------------------------------------------------------------------+-----+----------+
|I am a little confused on all of the models of the 88-89 bonnevilles.\nI have heard of the LE SE LSE       |7    |7.0       |
|I'm not familiar at all with the format of these "X-Face:" thingies, but\nafter seeing them in some f      |5    |1.0       |
|acooper@mac.cc.macalstr.edu (Turin Turambar, ME Department of Utter Misery) writes:\n> Did that FAQ e      |0    |0.0       |
|In article <benali.737307554@alcor> benali@alcor.concordia.ca ( ILYESS B. BDIRA ) writes:\n>It looks       |17   |17.0      |
|In article <1993Apr21.141259.12012@st-andrews.ac.uk>, nrp@st-andrews.ac.uk (Norman R. P