In [1]:
from pyspark.sql import Row


In [2]:
def load_dataframe(path):
    rdd = sc.textFile(path)\
        .map(lambda line: line.split())\
        .map(lambda words: Row(label = words[0], words = words[1:]))
    return spark.createDataFrame(rdd)

In [3]:
train_data = load_dataframe("20ng-train-all-terms.txt")
test_data = load_dataframe("20ng-test-all-terms.txt")

                                                                                

In [4]:
from pyspark.ml.feature import CountVectorizer
vectorizer = CountVectorizer(
    inputCol="words", 
    outputCol = "bag_of_words"
)

In [5]:
vectorizer_transformer = vectorizer.fit(train_data)

                                                                                

In [6]:
train_bag_of_words = vectorizer_transformer.transform(train_data)
test_bag_of_words = vectorizer_transformer.transform(train_data)

In [7]:
train_data.select("label")\
    .distinct()\
    .sort("label")\
    .show(truncate=False)

[Stage 6:>                                                          (0 + 1) / 1]

+------------------------+
|label                   |
+------------------------+
|alt.atheism             |
|comp.graphics           |
|comp.os.ms-windows.misc |
|comp.sys.ibm.pc.hardware|
|comp.sys.mac.hardware   |
|comp.windows.x          |
|misc.forsale            |
|rec.autos               |
|rec.motorcycles         |
|rec.sport.baseball      |
|rec.sport.hockey        |
|sci.crypt               |
|sci.electronics         |
|sci.med                 |
|sci.space               |
|soc.religion.christian  |
|talk.politics.guns      |
|talk.politics.mideast   |
|talk.politics.misc      |
|talk.religion.misc      |
+------------------------+



                                                                                

In [8]:
from pyspark.ml.feature import StringIndexer

label_indexer = StringIndexer(
    inputCol="label", 
    outputCol="label_index"
)
label_indexer_transformer = label_indexer.fit(train_bag_of_words)


train_bag_of_words = label_indexer_transformer.transform(train_bag_of_words)
test_bag_of_words = label_indexer_transformer.transform(test_bag_of_words)

                                                                                

In [18]:
from pyspark.ml.classification import NaiveBayes

classifier = NaiveBayes(
    labelCol="label_index",
    featuresCol="bag_of_words",
    predictionCol="label_index_predicted"
)

classifier_transformer = classifier.fit(train_bag_of_words)

                                                                                

In [20]:
test_predicted = classifier_transformer.transform(train_bag_of_words)

In [21]:
test_predicted.select("label_index", "label_index_predicted")\
    .limit(10)\
    .show()

22/11/03 03:05:24 WARN DAGScheduler: Broadcasting large task binary with size 12.0 MiB



[Stage 27:>                                                         (0 + 1) / 1]

22/11/03 03:05:28 WARN PythonRunner: Detected deadlock while completing task 0.0 in stage 27 (TID 20): Attempting to kill Python Worker
+-----------+---------------------+
|label_index|label_index_predicted|
+-----------+---------------------+
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
+-----------+---------------------+




                                                                                

In [22]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
evaluator = MulticlassClassificationEvaluator(
    labelCol="label_index", 
    predictionCol="label_index_predicted", 
    metricName="accuracy"
)
accuracy = evaluator.evaluate(test_predicted)
print("Accuracy = {:.2f}".format(accuracy))

22/11/03 03:07:02 WARN DAGScheduler: Broadcasting large task binary with size 12.0 MiB



[Stage 30:>                                                         (0 + 1) / 1]

Accuracy = 0.96



                                                                                

In [24]:
vectorizer = CountVectorizer(inputCol="words", outputCol="bag_of_words")
label_indexer = StringIndexer(inputCol="label", outputCol="label_index")
classifier = NaiveBayes(
    labelCol="label_index", featuresCol="bag_of_words", predictionCol="label_index_predicted",
)
pipeline = Pipeline(stages=[vectorizer, label_indexer, classifier])
pipeline_model = pipeline.fit(train_data)

test_predicted = pipeline_model.transform(test_data)

                                                                                

In [25]:
test_predicted.select("label_index", "label_index_predicted")\
    .limit(10)\
    .show()

22/11/03 03:12:14 WARN DAGScheduler: Broadcasting large task binary with size 12.0 MiB
+-----------+---------------------+
|label_index|label_index_predicted|
+-----------+---------------------+
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 19.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
|       17.0|                 17.0|
+-----------+---------------------+



In [26]:
accuracy = evaluator.evaluate(test_predicted)
print("Accuracy = {:.2f}".format(accuracy))

22/11/03 03:13:31 WARN DAGScheduler: Broadcasting large task binary with size 12.0 MiB



[Stage 45:>                                                         (0 + 1) / 1]

Accuracy = 0.80



                                                                                