In [20]:
from pyspark import SparkContext

from pyspark.sql.types import *
from pyspark.sql import Row
from pyspark.sql import SQLContext

from pyspark.ml.linalg import Vector
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import HashingTF,StopWordsRemover,IDF,Tokenizer

In [3]:
%%time
sc = SparkContext("local[2]","Application")
sqlContext = SQLContext(sc)

Wall time: 31 s


In [4]:
%%time
path="mini_newsgroups\\*" #"https://kdd.ics.uci.edu/databases/20newsgroups/mini_newsgroups.tar.gz"
newsGroupRowData=sc.wholeTextFiles(path)
print "Number of documents read in is:",newsGroupRowData.count()

Number of documents read in is: 2000
Wall time: 2min 9s


In [5]:
newsGroupRowData.takeSample(False,1, 10L)

[(u'file:/C:/Users/6910P/Documents/mini_newsgroups/sci.med/59627',
  u'Newsgroups: sci.med\nPath: cantaloupe.srv.cs.cmu.edu!crabapple.srv.cs.cmu.edu!sjha\nFrom: sjha+@cs.cmu.edu (Somesh Jha)\nSubject: What is intersection syndrome and Feldene?\nMessage-ID: <C6EpA9.EFz.1@cs.cmu.edu>\nSender: news@cs.cmu.edu (Usenet News System)\nNntp-Posting-Host: gs73.sp.cs.cmu.edu\nOrganization: School of Computer Science, Carnegie Mellon\nDate: Sun, 2 May 1993 15:49:16 GMT\nLines: 17\n\n\nHi:\n\nI went to the orthopedist on Tuesday. He diagnosed me as having\n"intersection syndrome". He prescribed Feldene for me. I want\nto know more about the disease and the drug.\n\nThanks\n\n\nSomesh\n\n\n\n\n\n\n')]

In [7]:
%%time
filepaths = newsGroupRowData.map(lambda (filepath, text): filepath)
print filepaths.takeSample(False,5, 10L)

[u'file:/C:/Users/6910P/Documents/mini_newsgroups/comp.sys.mac.hardware/52296', u'file:/C:/Users/6910P/Documents/mini_newsgroups/alt.atheism/53542', u'file:/C:/Users/6910P/Documents/mini_newsgroups/rec.autos/101592', u'file:/C:/Users/6910P/Documents/mini_newsgroups/rec.motorcycles/104430', u'file:/C:/Users/6910P/Documents/mini_newsgroups/rec.motorcycles/104302']
Wall time: 1min 5s


In [8]:
%%time
text = newsGroupRowData.map(lambda (filepath, text): text)
print text.takeSample(False,1, 10L)

[u'Newsgroups: sci.med\nPath: cantaloupe.srv.cs.cmu.edu!crabapple.srv.cs.cmu.edu!sjha\nFrom: sjha+@cs.cmu.edu (Somesh Jha)\nSubject: What is intersection syndrome and Feldene?\nMessage-ID: <C6EpA9.EFz.1@cs.cmu.edu>\nSender: news@cs.cmu.edu (Usenet News System)\nNntp-Posting-Host: gs73.sp.cs.cmu.edu\nOrganization: School of Computer Science, Carnegie Mellon\nDate: Sun, 2 May 1993 15:49:16 GMT\nLines: 17\n\n\nHi:\n\nI went to the orthopedist on Tuesday. He diagnosed me as having\n"intersection syndrome". He prescribed Feldene for me. I want\nto know more about the disease and the drug.\n\nThanks\n\n\nSomesh\n\n\n\n\n\n\n']
Wall time: 1min 6s


In [9]:
%%time
id = filepaths.map(lambda filepath: filepath.split("/")[-1])
print id.take(5)

[u'51121', u'51126', u'51127', u'51131', u'51139']
Wall time: 1.64 s


In [12]:
%%time
topics = filepaths.map(lambda filepath: filepath.split("/")[-2])
print topics.take(5)
print topics.distinct().take(20)

[u'alt.atheism', u'alt.atheism', u'alt.atheism', u'alt.atheism', u'alt.atheism']
[u'sci.crypt', u'comp.sys.mac.hardware', u'sci.med', u'comp.windows.x', u'misc.forsale', u'talk.politics.guns', u'comp.os.ms-windows.misc', u'sci.space', u'rec.sport.baseball', u'rec.motorcycles', u'talk.politics.misc', u'soc.religion.christian', u'comp.graphics', u'talk.religion.misc', u'talk.politics.mideast', u'comp.sys.ibm.pc.hardware', u'alt.atheism', u'rec.sport.hockey', u'sci.electronics', u'rec.autos']
Wall time: 37.2 s


In [13]:
%%time
sqlContext = SQLContext(sc)

# The schema is encoded in a string.
schemaString = "id text topic"
fields = [StructField(field_name, StringType(), True) for field_name in schemaString.split()]
schema = StructType(fields)

# Apply the schema to the RDD.
newsgroups = newsGroupRowData.map(lambda (filepath, text): (filepath.split("/")[-1],text,filepath.split("/")[-2]))
df = sqlContext.createDataFrame(newsgroups, schema)

#print schema
df.printSchema()

# Creates a temporary view using the DataFrame
df.createOrReplaceTempView("newsgroups")

# SQL can be run over DataFrames that have been registered as a table.
results = sqlContext.sql("SELECT id,topic,text FROM newsgroups limit 5")
results.show()

root
 |-- id: string (nullable = true)
 |-- text: string (nullable = true)
 |-- topic: string (nullable = true)

+-----+-----------+--------------------+
|   id|      topic|                text|
+-----+-----------+--------------------+
|51121|alt.atheism|Xref: cantaloupe....|
|51126|alt.atheism|Path: cantaloupe....|
|51127|alt.atheism|Path: cantaloupe....|
|51131|alt.atheism|Path: cantaloupe....|
|51139|alt.atheism|Path: cantaloupe....|
+-----+-----------+--------------------+

Wall time: 23.2 s


In [14]:
results = sqlContext.sql("select distinct topic, count(*) as cnt from newsgroups group by topic order by cnt desc limit 5")
results.show()

+--------------------+---+
|               topic|cnt|
+--------------------+---+
|        misc.forsale|100|
|      comp.windows.x|100|
|    rec.sport.hockey|100|
|  rec.sport.baseball|100|
|comp.os.ms-window...|100|
+--------------------+---+



In [22]:
%%time
result_list = df[df.topic.like("comp%")].collect()
new_df = sc.parallelize(result_list).toDF()
new_df.dropDuplicates().show()

+-----+--------------------+--------------------+
|   id|                text|               topic|
+-----+--------------------+--------------------+
|38907|Path: cantaloupe....|       comp.graphics|
|60691|Path: cantaloupe....|comp.sys.ibm.pc.h...|
|51996|Path: cantaloupe....|comp.sys.mac.hard...|
|38758|Xref: cantaloupe....|       comp.graphics|
|38904|Xref: cantaloupe....|       comp.graphics|
| 9622|Xref: cantaloupe....|comp.os.ms-window...|
|67386|Path: cantaloupe....|      comp.windows.x|
|52059|Xref: cantaloupe....|comp.sys.mac.hard...|
|66400|Newsgroups: comp....|      comp.windows.x|
|60841|Path: cantaloupe....|comp.sys.ibm.pc.h...|
|10094|Path: cantaloupe....|comp.os.ms-window...|
| 9911|Xref: cantaloupe....|comp.os.ms-window...|
| 9943|Path: cantaloupe....|comp.os.ms-window...|
|60992|Newsgroups: comp....|comp.sys.ibm.pc.h...|
|52010|Path: cantaloupe....|comp.sys.mac.hard...|
|38750|Path: cantaloupe....|       comp.graphics|
| 9485|Xref: cantaloupe....|comp.os.ms-window...|


In [24]:
labeledNewsGroups = df.withColumn("label",df.topic.like("comp%").cast("double"))
labeledNewsGroups.sample(False,0.003,10L).show(5)

+-----+--------------------+--------------------+-----+
|   id|                text|               topic|label|
+-----+--------------------+--------------------+-----+
|68232|Newsgroups: comp....|      comp.windows.x|  1.0|
|14989|Path: cantaloupe....|           sci.crypt|  0.0|
|15248|Xref: cantaloupe....|           sci.crypt|  0.0|
|15736|Path: cantaloupe....|           sci.crypt|  0.0|
|20738|Path: cantaloupe....|soc.religion.chri...|  0.0|
+-----+--------------------+--------------------+-----+
only showing top 5 rows



In [25]:
train_set, test_set = labeledNewsGroups.randomSplit([0.9, 0.1], 12345)
print "Total document count:",labeledNewsGroups.count()
print "Training-set count:",train_set.count()
print "Test-set count:",test_set.count()

Total document count: 2000
Training-set count: 1779
Test-set count: 221


In [29]:
tokenizer = Tokenizer().setInputCol("text").setOutputCol("words")
remover= StopWordsRemover().setInputCol("words").setOutputCol("filtered").setCaseSensitive(False)
hashingTF = HashingTF().setNumFeatures(1000).setInputCol("filtered").setOutputCol("rawFeatures")
idf = IDF().setInputCol("rawFeatures").setOutputCol("features").setMinDocFreq(0)
lr = LogisticRegression().setRegParam(0.01).setThreshold(0.5)
pipeline=Pipeline(stages=[tokenizer,remover,hashingTF,idf, lr])

In [30]:
print "Logistic regression features column=",lr.getFeaturesCol()
print "logistic regression label column=",lr.getLabelCol()
print "Logistic regression threshold=",lr.getThreshold()

Logistic regression features column= features
logistic regression label column= label
Logistic regression threshold= 0.5


In [31]:
print "Tokenizer:"
print tokenizer.explainParams()
print "***************************"
print "Remover:"
print remover.explainParams()
print "***************************"
print "HashingTF:"
print hashingTF.explainParams()
print "***************************"
print "IDF:"
print idf.explainParams()
print "***************************"
print "LogisticRegression:"
print lr.explainParams()
print "***************************"
print "Pipeline:"
print pipeline.explainParams()

Tokenizer:
inputCol: input column name. (current: text)
outputCol: output column name. (default: Tokenizer_4ba4b599d3828120d9e0__output, current: words)
***************************
Remover:
caseSensitive: whether to do a case sensitive comparison over the stop words (default: False, current: False)
inputCol: input column name. (current: words)
outputCol: output column name. (default: StopWordsRemover_41a2b4f9d66ae347ce42__output, current: filtered)
stopWords: The words to be filtered out (default: [u'i', u'me', u'my', u'myself', u'we', u'our', u'ours', u'ourselves', u'you', u'your', u'yours', u'yourself', u'yourselves', u'he', u'him', u'his', u'himself', u'she', u'her', u'hers', u'herself', u'it', u'its', u'itself', u'they', u'them', u'their', u'theirs', u'themselves', u'what', u'which', u'who', u'whom', u'this', u'that', u'these', u'those', u'am', u'is', u'are', u'was', u'were', u'be', u'been', u'being', u'have', u'has', u'had', u'having', u'do', u'does', u'did', u'doing', u'a', u'an'

In [32]:
remover.getStopWords()

[u'i',
 u'me',
 u'my',
 u'myself',
 u'we',
 u'our',
 u'ours',
 u'ourselves',
 u'you',
 u'your',
 u'yours',
 u'yourself',
 u'yourselves',
 u'he',
 u'him',
 u'his',
 u'himself',
 u'she',
 u'her',
 u'hers',
 u'herself',
 u'it',
 u'its',
 u'itself',
 u'they',
 u'them',
 u'their',
 u'theirs',
 u'themselves',
 u'what',
 u'which',
 u'who',
 u'whom',
 u'this',
 u'that',
 u'these',
 u'those',
 u'am',
 u'is',
 u'are',
 u'was',
 u'were',
 u'be',
 u'been',
 u'being',
 u'have',
 u'has',
 u'had',
 u'having',
 u'do',
 u'does',
 u'did',
 u'doing',
 u'a',
 u'an',
 u'the',
 u'and',
 u'but',
 u'if',
 u'or',
 u'because',
 u'as',
 u'until',
 u'while',
 u'of',
 u'at',
 u'by',
 u'for',
 u'with',
 u'about',
 u'against',
 u'between',
 u'into',
 u'through',
 u'during',
 u'before',
 u'after',
 u'above',
 u'below',
 u'to',
 u'from',
 u'up',
 u'down',
 u'in',
 u'out',
 u'on',
 u'off',
 u'over',
 u'under',
 u'again',
 u'further',
 u'then',
 u'once',
 u'here',
 u'there',
 u'when',
 u'where',
 u'why',
 u'how',
 u'all

In [34]:
model=pipeline.fit(train_set)

In [41]:
%%time
predictions = model.transform(test_set)
predictions.select("id","topic","probability","prediction","label").sample(False,0.01,10L).show(5)
predictions.select("id","topic","probability","prediction","label").filter(predictions.topic.like("comp%")).sample(False,0.1,10L).show(5)

+-----+--------------------+--------------------+----------+-----+
|   id|               topic|         probability|prediction|label|
+-----+--------------------+--------------------+----------+-----+
|51186|         alt.atheism|[0.99545605647335...|       0.0|  0.0|
|53459|         alt.atheism|[0.99656409133350...|       0.0|  0.0|
| 9942|comp.os.ms-window...|[0.89795237438370...|       0.0|  1.0|
|20866|soc.religion.chri...|[0.98273679010727...|       0.0|  0.0|
|54633|  talk.politics.guns|[0.47279161297191...|       1.0|  0.0|
+-----+--------------------+--------------------+----------+-----+

+-----+--------------------+--------------------+----------+-----+
|   id|               topic|         probability|prediction|label|
+-----+--------------------+--------------------+----------+-----+
|38271|       comp.graphics|[0.02780255772952...|       1.0|  1.0|
|50473|comp.sys.mac.hard...|[0.14943992966858...|       1.0|  1.0|
|60699|comp.sys.ibm.pc.h...|[0.18058644545815...|       1.0| 

In [42]:
predictions.sample(False,0.01,10L).show(5)

+-----+--------------------+--------------------+-----+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----------+
|   id|                text|               topic|label|               words|            filtered|         rawFeatures|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+-----+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----------+
|51186|Path: cantaloupe....|         alt.atheism|  0.0|[path:, cantaloup...|[path:, cantaloup...|(1000,[19,23,37,4...|(1000,[19,23,37,4...|[5.38940572713749...|[0.99545605647335...|       0.0|
|53459|Newsgroups: alt.a...|         alt.atheism|  0.0|[newsgroups:, alt...|[newsgroups:, alt...|(1000,[1,9,17,25,...|(1000,[1,9,17,25,...|[5.67003203160359...|[0.99656409133350...|       0.0|
| 9942|Xref: cantaloupe....|comp.os

In [43]:
evaluator = BinaryClassificationEvaluator().setMetricName("areaUnderROC")
print "Area under ROC curve:",evaluator.evaluate(predictions)

Area under ROC curve: 0.909709144172


In [59]:
paramGrid = ParamGridBuilder()\
    .addGrid(hashingTF.numFeatures,[1000,10000,100000])\
    .addGrid(idf.minDocFreq,[0,10,100])\
    .build()

In [60]:
cv = CrossValidator().setEstimator(pipeline).setEvaluator(evaluator).setEstimatorParamMaps(paramGrid).setNumFolds(2)

In [61]:
%%time
cvModel = cv.fit(train_set)
print "Area under the ROC curve for best fitted model =",evaluator.evaluate(cvModel.transform(test_set))

<type 'list'>


In [66]:
print "Area under ROC curve for non-tuned model:",evaluator.evaluate(predictions)
print "Area under ROC curve for fitted model:",evaluator.evaluate(cvModel.transform(test_set))
#print "Improvement:%.2f".format(evaluator.evaluate(cvModel.transform(test_set)) - evaluator.evaluate(predictions))*100 / evaluator.evaluate(predictions)) 


 Area under ROC curve for non-tuned model: 0.909709144172
Area under ROC curve for fitted model: 0.96578782172
