In [2]:
%env SPARK_LOCAL_HOSTNAME=localhost

import pyspark

env: SPARK_LOCAL_HOSTNAME=localhost


In [3]:
from pyspark.sql import SparkSession

In [4]:
# Create the Spark session used throughout the notebook, or reuse one that is
# already running in this kernel.
spark = (
    SparkSession.builder
    .appName('BD')
    .getOrCreate()
)

In [22]:
from pyspark.ml.feature import Tokenizer, RegexTokenizer
from pyspark.ml.functions import *
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import StopWordsRemover
from pyspark.ml.feature import HashingTF, IDF
from pyspark.ml.feature import CountVectorizer
from pyspark.ml.classification import NaiveBayes, NaiveBayesModel

### Reading the Data File into a Spark DataFrame

In [24]:
# Load the labelled headlines CSV: the first row is treated as the header and
# Spark is asked to infer the column types.
reddit_csv = (
    spark.read
    .options(inferSchema=True, header=True)
    .csv('reddit_headlines_labels3.csv')
)
# Preview the first 7 rows without truncating the headline text.
reddit_csv.show(n=7, truncate=False)

+------------------------------------------------------------------------------------------------------------+------+
|headline                                                                                                    |label3|
+------------------------------------------------------------------------------------------------------------+------+
|Unpaid Kentucky coal miners have been blocking a train track for 3 weeks                                    |0.0   |
|New revelations about Herschel Walker show how Democrats could hold the Senate                              |1.0   |
|Amazon referred to DOJ for potential criminal obstruction of Congress                                       |0.0   |
|Swedish navy returns to vast underground HQ amid Russia fears                                               |0.0   |
|Supreme Court strikes down Louisiana law that would have limited state to one abortion clinic               |1.0   |
|The fight over domestic COVID funding is holding back g

In [6]:
# Show any rows of the raw frame whose label is missing.
reddit_csv.filter(reddit_csv["label3"].isNull()).show()

+--------+------+
|headline|label3|
+--------+------+
+--------+------+



In [7]:
# Inspect the column types of the raw frame as (column name, type) pairs.
reddit_csv.dtypes

[('headline', 'string'), ('label3', 'string')]

In [8]:
# Print the raw frame's schema as a tree (column name, type, nullability).
reddit_csv.printSchema()

root
 |-- headline: string (nullable = true)
 |-- label3: string (nullable = true)



In [9]:
# The label column was read as a string; replace it with its integer cast and
# confirm the resulting schema.
data = reddit_csv.withColumn("label3", col("label3").cast("int"))
data.printSchema()

root
 |-- headline: string (nullable = true)
 |-- label3: integer (nullable = true)



In [10]:
# Preview the first 5 rows of the frame with the integer label column.
data.show(5)

+--------------------+------+
|            headline|label3|
+--------------------+------+
|Unpaid Kentucky c...|     0|
|New revelations a...|     1|
|Amazon referred t...|     0|
|Swedish navy retu...|     0|
|Supreme Court str...|     1|
+--------------------+------+
only showing top 5 rows



In [11]:
# Rebuild `data` from the raw frame with an explicit projection: keep the
# headline as-is and replace label3 with its integer cast under the same name.
data = reddit_csv.select(
    reddit_csv["headline"],
    reddit_csv["label3"].cast("Int").alias("label3"),
)
# Show the first 5 rows with the full headline text.
data.show(n=5, truncate=False)

+---------------------------------------------------------------------------------------------+------+
|headline                                                                                     |label3|
+---------------------------------------------------------------------------------------------+------+
|Unpaid Kentucky coal miners have been blocking a train track for 3 weeks                     |0     |
|New revelations about Herschel Walker show how Democrats could hold the Senate               |1     |
|Amazon referred to DOJ for potential criminal obstruction of Congress                        |0     |
|Swedish navy returns to vast underground HQ amid Russia fears                                |0     |
|Supreme Court strikes down Louisiana law that would have limited state to one abortion clinic|1     |
+---------------------------------------------------------------------------------------------+------+
only showing top 5 rows



In [12]:
# Count the rows whose label is null after the integer cast.
data.filter(data["label3"].isNull()).count()

50

In [13]:
# Discard every row that contains a null in any column.
data = data.dropna()

### Divide the Data into Train and Test Data 

In [14]:
# Divide the data into 70% for training and 30% for testing.
# A fixed seed makes the split -- and therefore every downstream metric --
# reproducible across kernel restarts; without it each run samples a
# different partition of the rows.
# NOTE(review): the recorded outputs show the same headline as the first row
# of both the training and the test features; presumably the CSV contains
# duplicate headlines -- verify, and consider de-duplicating before splitting.
dividedData = data.randomSplit([0.7, 0.3], seed=42)
trainingData = dividedData[0]  # index 0 = training split
testingData = dividedData[1]   # index 1 = testing split
print('Training Data Rows:', trainingData.count(),'; Testing Data Rows:', testingData.count())

Training Data Rows: 3516 ; Testing Data Rows: 1503


In [15]:
# Preview the training split (headline text is truncated in this view).
trainingData.show()

+--------------------+------+
|            headline|label3|
+--------------------+------+
|"""China Is Natio...|     0|
|"""Crisis: Danger...|     1|
|"""Disturbing"" m...|     1|
|"""F-35s Don't He...|     0|
|"""General Mud"" ...|     0|
|"""I want you to ...|     0|
|"""Is that too so...|     0|
|"""Putin wants to...|     0|
|"""QAnon Shaman""...|     0|
|"""They have brok...|     0|
|"""War Justificat...|     0|
|"""What I Know Ab...|     1|
|"""Where the f***...|     0|
|"'Unacceptable': ...|     0|
|"'We Have Got to ...|     1|
|"'Yes, Exactly,' ...|     0|
|"911 call from Br...|     0|
|"After Inlander i...|     0|
|"Amazon blocks sa...|     0|
|"As Trump Says 'V...|     0|
+--------------------+------+
only showing top 20 rows



### Preparing Training Data

<br> Separate each 'headline' into individual words

In [16]:
# Regular-expression tokenizer: the pattern is used as the delimiter, so each
# headline is split on non-word characters into the "headlineWords" column.
tokenizer = RegexTokenizer(
    inputCol="headline",
    outputCol="headlineWords",
    pattern='\\W',
)
tokenizedTrain = tokenizer.transform(trainingData)
# Show 5 tokenized rows in full.
tokenizedTrain.show(n=5, truncate=False)

+----------------------------------------------------------------------------------------------------------------------------+------+------------------------------------------------------------------------------------------------------------------------------------------+
|headline                                                                                                                    |label3|headlineWords                                                                                                                             |
+----------------------------------------------------------------------------------------------------------------------------+------+------------------------------------------------------------------------------------------------------------------------------------------+
|"""China Is National Security Threat No. 1"" -- US Director of National Intelligence John Ratcliffe"                        |0     |[china, is, national, security, threat, no, 1, u

In [17]:
# Compact view of the same 5 rows, with long cells truncated.
tokenizedTrain.show(n=5, truncate=True)

+--------------------+------+--------------------+
|            headline|label3|       headlineWords|
+--------------------+------+--------------------+
|"""China Is Natio...|     0|[china, is, natio...|
|"""Crisis: Danger...|     1|[crisis, danger, ...|
|"""Disturbing"" m...|     1|[disturbing, memo...|
|"""F-35s Don't He...|     0|[f, 35s, don, t, ...|
|"""General Mud"" ...|     0|[general, mud, pu...|
+--------------------+------+--------------------+
only showing top 5 rows



<br> Removing stop words

In [18]:
# Stop-word filter that reads the tokenizer's output column and writes the
# remaining tokens to "MeaningfulWords".
swr = StopWordsRemover(
    inputCol=tokenizer.getOutputCol(),
    outputCol="MeaningfulWords",
)
SwRemovedTrain = swr.transform(tokenizedTrain)
# Show 5 rows in full to compare the token lists before and after filtering.
SwRemovedTrain.show(n=5, truncate=False)

+----------------------------------------------------------------------------------------------------------------------------+------+------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+
|headline                                                                                                                    |label3|headlineWords                                                                                                                             |MeaningfulWords                                                                                          |
+----------------------------------------------------------------------------------------------------------------------------+------+-------------------------------------------------------------------------------------------------------------

<br> Converting word features into numerical features

In [19]:
# Hash the filtered tokens into a term-frequency vector stored in "features".
hashTF = HashingTF(
    inputCol=swr.getOutputCol(),
    outputCol="features",
)
# Keep only the columns needed for training and inspection.
numericTrainData = (
    hashTF.transform(SwRemovedTrain)
    .select('label3', 'MeaningfulWords', 'features')
)
numericTrainData.show(n=3, truncate=False)

+------+----------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------+
|label3|MeaningfulWords                                                                                     |features                                                                                                                                            |
+------+----------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------+
|0     |[china, national, security, threat, 1, us, director, national, intelligence, john, ratcliffe]       |(262144,[13981,81060,92651,107101,109156,194186,212790,219622,230810,251861],[1.0,1.0,1.0,1.0,1.0,1.0,2.0,1.0,1.0,

### Training the model

In [110]:
# Configure a logistic-regression classifier on the hashed features, with
# 10 iterations and elastic-net regularisation (regParam 0.01, mixing 0.8).
lr = LogisticRegression(
    labelCol="label3",
    featuresCol="features",
    maxIter=10,
    regParam=0.01,
    elasticNetParam=0.8,
)
# Fit the classifier on the prepared training features.
model = lr.fit(numericTrainData)
#print ("Training is done!")

### Preparing Test Data

In [111]:
# Apply the same transformers that prepared the training data (tokenizer,
# stop-word remover, hashing TF) to the held-out test split, so both splits
# share one feature space.
tokenizedTest = tokenizer.transform(testingData)
SwRemovedTest = swr.transform(tokenizedTest)
# Select the label under its real column name, 'label3', exactly as the
# training cell does. The previous 'Label3' spelling only resolved because
# Spark SQL is case-insensitive by default; later case-insensitive references
# to the column keep working.
numericTest = hashTF.transform(SwRemovedTest).select(
    'label3', 'MeaningfulWords', 'features')
numericTest.show(truncate=False, n=2)

+------+---------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------+
|Label3|MeaningfulWords                                                                              |features                                                                                                               |
+------+---------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------+
|0     |[china, national, security, threat, 1, us, director, national, intelligence, john, ratcliffe]|(262144,[13981,81060,92651,107101,109156,194186,212790,219622,230810,251861],[1.0,1.0,1.0,1.0,1.0,1.0,2.0,1.0,1.0,1.0])|
|0     |[socialist, aoc, pleads, second, stimulus, checks, rent, forgiveness, pandemic, rages]       |(26214

### Predict on the testing data and calculate the model's accuracy

In [112]:
# Score the prepared test features with the fitted model.
prediction = model.transform(numericTest)
# Keep only the tokens, the predicted class and the true label.
predictionFinal = prediction.select(
    "MeaningfulWords", "prediction", "Label3")
predictionFinal.show(n=4, truncate = False)
# A row counts as correct when the predicted class equals the true label.
correctPrediction = predictionFinal.filter(
    predictionFinal['prediction'] == predictionFinal['Label3']).count()
totalData = predictionFinal.count()
# Guard the ratio: an empty test split would otherwise raise ZeroDivisionError.
accuracy = correctPrediction / totalData if totalData else float("nan")
print("correct prediction:", correctPrediction, ", total data:", totalData, 
      ", accuracy:", accuracy)

+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------+------+
|MeaningfulWords                                                                                                                                                                                            |prediction|Label3|
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------+------+
|[china, national, security, threat, 1, us, director, national, intelligence, john, ratcliffe]                                                                                                              |0.0       |0     |
|[socialist, aoc, pleads, second, stimulus, checks, rent, forgiveness, pandemic, rages]                 