In [1]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

sc = SparkContext.getOrCreate(SparkConf().setMaster("local[*]"))
spark = SparkSession.builder.getOrCreate()

In [2]:
from pyspark.sql.types import StructType, StructField, StringType

schema = StructType().add("class", StringType(), True).add("message", StringType(), True)

In [3]:
df = spark.read.schema(schema).csv("spam.csv")

In [4]:
df = df.filter(df.message.isNotNull())

In [5]:
from pyspark.ml.feature import StringIndexer
import pyspark.sql.functions as f

indexer = StringIndexer(inputCol="class", outputCol="label").setHandleInvalid("keep")
indexed = indexer.fit(df).transform(df)
indexed = indexed.select(f.col("label"), f.col("message"))
indexed = indexed.filter(indexed.label != 2.0)
indexed.show()

+-----+--------------------+
|label|             message|
+-----+--------------------+
|  0.0|Go until jurong p...|
|  0.0|Ok lar... Joking ...|
|  1.0|Free entry in 2 a...|
|  0.0|U dun say so earl...|
|  0.0|Nah I don't think...|
|  1.0|FreeMsg Hey there...|
|  0.0|Even my brother i...|
|  0.0|As per your reque...|
|  1.0|WINNER!! As a val...|
|  1.0|Had your mobile 1...|
|  0.0|I'm gonna be home...|
|  1.0|SIX chances to wi...|
|  1.0|URGENT! You have ...|
|  0.0|I've been searchi...|
|  0.0|I HAVE A DATE ON ...|
|  1.0|XXXMobileMovieClu...|
|  0.0|Oh k...i'm watchi...|
|  0.0|Eh u remember how...|
|  0.0|Fine if that��s t...|
|  1.0|England v Macedon...|
+-----+--------------------+
only showing top 20 rows



In [6]:
import pyspark.sql.functions as f

replaced = indexed.select(f.col("label"), f.col("message"), 
                          f.regexp_replace(f.col("message"), "[,.\-\!\?\$]", "").alias("replaced"))

In [7]:
from pyspark.ml.feature import Tokenizer

tokenizer = Tokenizer(inputCol="replaced", outputCol="words")
wordsData = tokenizer.transform(replaced)

In [8]:
from pyspark.ml.feature import StopWordsRemover

remover = StopWordsRemover(inputCol="words", outputCol="filtered")
filteredData = remover.transform(wordsData)

In [9]:
dataset = filteredData.select(f.col("label"), f.col("filtered"))
training, test = dataset.randomSplit([0.8, 0.2], seed = 0)

In [10]:
# https://stackoverflow.com/questions/32231049/how-to-use-spark-naive-bayes-classifier-for-text-classification-with-idf

from pyspark.mllib.feature import HashingTF, IDF
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.classification import NaiveBayes   

# Split data into labels and features, transform
# preservesPartitioning is not really required
# since map without partitioner shouldn't trigger repartitiong
labels = training.rdd.map(
    lambda doc: doc["label"],  # Standard Python dict access 
    preservesPartitioning=True # This is obsolete.
)

tf = HashingTF(numFeatures=20000).transform( ## Use much larger number in practice
    training.rdd.map(lambda doc: doc["filtered"], 
    preservesPartitioning=True))

idf = IDF().fit(tf)
tfidf = idf.transform(tf)

# Combine using zip
training = labels.zip(tfidf).map(lambda x: LabeledPoint(x[0], x[1]))

# Train and check
model = NaiveBayes.train(training)
labels_and_preds = labels.zip(model.predict(tfidf)).map(
    lambda x: {"actual": x[0], "predicted": float(x[1])})

In [11]:
from pyspark.mllib.evaluation import MulticlassMetrics
from operator import itemgetter

metrics = MulticlassMetrics(
    labels_and_preds.map(itemgetter("actual", "predicted")))

metrics.accuracy

0.9860611510791367

In [12]:
labels = test.rdd.map(
    lambda doc: doc["label"],  # Standard Python dict access 
    preservesPartitioning=True # This is obsolete.
)

tf = HashingTF(numFeatures=20000).transform( ## Use much larger number in practice
    test.rdd.map(lambda doc: doc["filtered"], 
    preservesPartitioning=True))

idf = IDF().fit(tf)
tfidf = idf.transform(tf)

# Combine using zip
training = labels.zip(tfidf).map(lambda x: LabeledPoint(x[0], x[1]))

# Train and check
labels_and_preds = labels.zip(model.predict(tfidf)).map(
    lambda x: {"actual": x[0], "predicted": float(x[1])})

In [13]:
metrics = MulticlassMetrics(
    labels_and_preds.map(itemgetter("actual", "predicted")))

metrics.accuracy

0.9741992882562278