In [26]:
from pyspark.context import  SparkContext
from pyspark.sql import SQLContext

sc = SparkContext('local','test')
sql = SQLContext(sc)

In [27]:
from pyspark.sql.functions import lit
dems_df =  sql.read.text("dems.txt")
gop_df = sql.read.text("gop.txt")

In [28]:
corpus_df = dems_df.select("value", lit(1).alias("label")).union(gop_df.select("value", lit(0).alias("label")))

In [30]:
corpus_df.select("*").limit(20).show()

+--------------------+-----+
|               value|label|
+--------------------+-----+
|This week @senate...|    1|
|Health care profe...|    1|
|RT @SeemaNanda: G...|    1|
|Republicans keep ...|    1|
|RT @SpeakerPelosi...|    1|
|While the preside...|    1|
|You are not alone...|    1|
|RT @DNCWarRoom: W...|    1|
|RT @DNCWarRoom: T...|    1|
|RT @DNCWarRoom: T...|    1|
|LISTEN. TO. HEALT...|    1|
|RT @SeemaNanda: B...|    1|
|This is a HUGE wi...|    1|
|RT @SenSherrodBro...|    1|
|RT @WisDems: Make...|    1|
|RT @DemConvention...|    1|
|Abortion is healt...|    1|
|RT @RepLucyMcBath...|    1|
|Get counted. Get ...|    1|
+--------------------+-----+



In [31]:
train_df, test_df = corpus_df.randomSplit([0.75, 0.25])

In [34]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import CountVectorizer, Tokenizer, StopWordsRemover

tokenizer = Tokenizer(inputCol="value", outputCol="words")
stop_words_remover = StopWordsRemover(inputCol="words", outputCol="words_cleaned")
vectorizer = CountVectorizer(inputCol="words_cleaned", outputCol="features")
cleaning_pipeline = Pipeline(stages = [tokenizer,stop_words_remover,vectorizer])
cleaning_pipeline_model = cleaning_pipeline.fit(corpus_df)
cleaned_training_df = cleaning_pipeline_model.transform(train_df)
cleaned_testing_df = cleaning_pipeline_model.transform(test_df)

In [35]:
cleaned_training_df.show(3)

+--------------------+-----+--------------------+--------------------+--------------------+
|               value|label|               words|       words_cleaned|            features|
+--------------------+-----+--------------------+--------------------+--------------------+
|"....for the thir...|    1|["....for, the, t...|["....for, third,...|(86466,[0,23,80,9...|
|"12 years of educ...|    1|["12, years, of, ...|["12, years, educ...|(86466,[0,46,51,9...|
|"@KatieMcGintyPA ...|    1|["@katiemcgintypa...|["@katiemcgintypa...|(86466,[244,854,1...|
+--------------------+-----+--------------------+--------------------+--------------------+
only showing top 3 rows



In [36]:
from pyspark.ml.classification import NaiveBayes
naive_bayes = NaiveBayes(featuresCol="features", labelCol = "label")

In [37]:
naive_bayes_model = naive_bayes.fit(cleaned_training_df)
predictions_df = naive_bayes_model.transform(cleaned_testing_df)

In [38]:
predictions_df.select("features","label","prediction").limit(20).show()

+--------------------+-----+----------+
|            features|label|prediction|
+--------------------+-----+----------+
|(86466,[85,223,38...|    1|       1.0|
|(86466,[35,121,13...|    1|       0.0|
|(86466,[3,116,204...|    1|       0.0|
|(86466,[7,8,15,23...|    1|       1.0|
|(86466,[13,24,116...|    1|       1.0|
|(86466,[0,30,73,1...|    1|       1.0|
|(86466,[8,2097,29...|    1|       1.0|
|(86466,[0,47,144,...|    1|       1.0|
|(86466,[13,32,47,...|    1|       1.0|
|(86466,[0,66,81,8...|    1|       1.0|
|(86466,[0,13,42,5...|    1|       1.0|
|(86466,[97,125,17...|    1|       1.0|
|(86466,[81,116,11...|    1|       1.0|
|(86466,[12,99,102...|    1|       1.0|
|(86466,[2,3,5,6,1...|    1|       1.0|
|(86466,[16,65,94,...|    1|       1.0|
|(86466,[12,116,19...|    1|       1.0|
|(86466,[4,83,116,...|    1|       1.0|
|(86466,[0,53,159,...|    1|       1.0|
|(86466,[253,273,8...|    1|       0.0|
+--------------------+-----+----------+



In [39]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
eval = MulticlassClassificationEvaluator(labelCol='label',predictionCol='prediction', metricName = 'accuracy')
eval.evaluate(predictions_df)

0.9337789847503656

In [40]:
cleaned_training_df.select("features").show(1)

+--------------------+
|            features|
+--------------------+
|(86466,[0,23,80,9...|
+--------------------+
only showing top 1 row

