In [124]:
from pyspark.sql import SparkSession
from pyspark import SparkConf

In [127]:
# Build the Spark configuration first, then create (or reuse) the session
# for this notebook under the application name 'spam'.
conf = SparkConf()
conf.set('spark.rpc.message.maxSize', '256')
spark = (
    SparkSession.builder
    .config(conf=conf)
    .appName('spam')
    .getOrCreate()
)

In [9]:
# Load the tab-separated SMS file; column types are inferred by Spark.
data = spark.read.csv(
    'SMSSpamCollection',
    sep='\t',
    inferSchema=True,
)

In [11]:
# Replace the positional column names (_c0, _c1) with descriptive ones.
for old_name, new_name in (('_c0', 'class'), ('_c1', 'text')):
    data = data.withColumnRenamed(old_name, new_name)

In [12]:
# Preview the first five rows after renaming.
data.show(n=5)

+-----+--------------------+
|class|                text|
+-----+--------------------+
|  ham|Go until jurong p...|
|  ham|Ok lar... Joking ...|
| spam|Free entry in 2 a...|
|  ham|U dun say so earl...|
|  ham|Nah I don't think...|
+-----+--------------------+
only showing top 5 rows



In [13]:
from pyspark.sql.functions import length

In [14]:
# Add the character length of each message as a numeric column.
text_col = data['text']
data = data.withColumn('length', length(text_col))

In [15]:
# Display the frame (20 rows, the show() default) including the new 'length' column.
data.show(n=20)

+-----+--------------------+------+
|class|                text|length|
+-----+--------------------+------+
|  ham|Go until jurong p...|   111|
|  ham|Ok lar... Joking ...|    29|
| spam|Free entry in 2 a...|   155|
|  ham|U dun say so earl...|    49|
|  ham|Nah I don't think...|    61|
| spam|FreeMsg Hey there...|   147|
|  ham|Even my brother i...|    77|
|  ham|As per your reque...|   160|
| spam|WINNER!! As a val...|   157|
| spam|Had your mobile 1...|   154|
|  ham|I'm gonna be home...|   109|
| spam|SIX chances to wi...|   136|
| spam|URGENT! You have ...|   155|
|  ham|I've been searchi...|   196|
|  ham|I HAVE A DATE ON ...|    35|
| spam|XXXMobileMovieClu...|   149|
|  ham|Oh k...i'm watchi...|    26|
|  ham|Eh u remember how...|    81|
|  ham|Fine if thats th...|    56|
| spam|England v Macedon...|   155|
+-----+--------------------+------+
only showing top 20 rows



In [19]:
# Number of messages per class.
class_counts = data.groupBy('class').count()
class_counts.show()

+-----+-----+
|class|count|
+-----+-----+
|  ham| 4827|
| spam|  747|
+-----+-----+



In [31]:
# Per-class mean of every numeric column.
class_means = data.groupBy('class').mean()
class_means.show()

+-----+-----------------+
|class|      avg(length)|
+-----+-----------------+
|  ham|71.45431945307645|
| spam|138.6706827309237|
+-----+-----------------+



In [180]:
from pyspark.ml.feature import Tokenizer, StopWordsRemover, CountVectorizer, IDF, StringIndexer

In [181]:
# Feature-engineering stages. Each stage reads the column written by the
# previous text stage: text -> tokens -> stop_token -> c_vec -> tf_idf.
tokenizer = Tokenizer(
    inputCol='text',
    outputCol='tokens',
)
stop_remove = StopWordsRemover(
    inputCol='tokens',
    outputCol='stop_token',
)
count_vec = CountVectorizer(
    inputCol='stop_token',
    outputCol='c_vec',
)
idf = IDF(
    inputCol='c_vec',
    outputCol='tf_idf',
)

# Encode the string class column as a numeric 'label' column.
ham_spam_to_num = StringIndexer(
    inputCol='class',
    outputCol='label',
)

In [182]:
from pyspark.ml.feature import VectorAssembler

In [208]:
# Combine the TF-IDF vector and the message length into one 'features' vector.
feature_inputs = ['tf_idf', 'length']
clean_up = VectorAssembler(
    inputCols=feature_inputs,
    outputCol='features',
)

TypeError: 'VectorAssembler' object is not subscriptable

In [184]:
from pyspark.ml.classification import MultilayerPerceptronClassifier

In [185]:
from pyspark.ml import Pipeline

In [186]:
# Chain label encoding and all text-processing stages into one pipeline.
prep_stages = [
    ham_spam_to_num,
    tokenizer,
    stop_remove,
    count_vec,
    idf,
    clean_up,
]
data_prep_pipe = Pipeline(stages=prep_stages)

In [200]:
# Fit every pipeline stage on the data; the last line displays the fitted model.
cleaner = data_prep_pipe.fit(dataset=data)
cleaner

PipelineModel_5b5fdce95597

In [267]:
# Run the fitted preparation pipeline over the data set, which adds the
# 'label' and 'features' columns (among the intermediate ones).
clean_data = cleaner.transform(data)

# The classifier's input layer must match the feature-vector width. Read it
# from a single row with first() instead of collect(), which would ship every
# feature vector to the driver just to inspect the first one.
num_features = clean_data.select('features').first()[0].size

In [268]:
# Keep only the two columns the classifier consumes.
model_columns = ('label', 'features')
clean_data = clean_data.select(*model_columns)

In [269]:
# Split into training (70%) and test (30%) sets. A fixed seed makes the split,
# and therefore the accuracy reported below, repeatable across runs; the value
# matches the seed passed to the classifier.
training, test = clean_data.randomSplit([0.7, 0.3], seed=1234)
# Sanity check: evaluates to a Column reference, confirming 'features' exists.
training['features']

Column<b'features'>

In [270]:
# Layer sizes from input to output: the input width equals the feature-vector
# size, followed by layers of 5, 2 and 2 units.
layer_sizes = [num_features, 5, 2, 2]
mlp = MultilayerPerceptronClassifier(
    labelCol='label',
    featuresCol='features',
    maxIter=100,
    layers=layer_sizes,
    blockSize=128,
    seed=1234,
)

In [271]:
# Fit the multilayer perceptron on the training split.
spam_detector = mlp.fit(dataset=training)

In [272]:
# Score the held-out test split; the transformed frame gains the
# rawPrediction, probability and prediction columns. The last line displays
# the resulting schema.
test_results = spam_detector.transform(dataset=test)
test_results

DataFrame[label: double, features: vector, rawPrediction: vector, probability: vector, prediction: double]

In [273]:
# Keep just the predicted and true labels for evaluation.
eval_columns = ("prediction", "label")
preds = test_results.select(*eval_columns)

In [274]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# Evaluator that reports plain accuracy (fraction of correct predictions).
evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
)

In [275]:
# Evaluate once, then report the score.
test_accuracy = evaluator.evaluate(preds)
print(f'Test set accuracy: {test_accuracy}')

Test set accuracy: 0.9763173475429248
