/
HamOrSpamDemo.scala
193 lines (168 loc) · 6.86 KB
/
HamOrSpamDemo.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.h2o
import hex.ModelMetricsBinomial
import hex.deeplearning.{DeepLearning, DeepLearningModel}
import hex.deeplearning.DeepLearningModel.DeepLearningParameters
import org.apache.spark.h2o._
import org.apache.spark.mllib.feature.{HashingTF, IDF, IDFModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
import org.apache.spark.{SparkConf, SparkContext, mllib}
import water.support.{H2OFrameSupport, ModelMetricsSupport, SparkContextSupport}
/**
 * Demo for NYC meetup and MLConf 2015.
 *
 * It predicts spam text messages: SMS texts are tokenized, turned into
 * tf-idf feature vectors with Spark MLlib, and classified with an H2O
 * Deep Learning binomial model.
 * Training dataset is available in the file smalldata/smsData.txt.
 */
object HamOrSpamDemo extends SparkContextSupport with ModelMetricsSupport with H2OFrameSupport {

  /** Name of the training data file, registered as a Spark file in main(). */
  val DATAFILE = "smsData.txt"

  /** Messages scored after training to demonstrate the detector. */
  val TEST_MSGS = Seq(
    "Michal, beer tonight in MV?",
    "We tried to contact you re your reply to our offer of a Video Handset? 750 anytime any networks mins? UNLIMITED TEXT?")

  def main(args: Array[String]): Unit = {
    val conf: SparkConf = configure("Sparkling Water Meetup: Ham or Spam (spam text messages detector)")
    // Create SparkContext to execute application on Spark cluster
    val sc = new SparkContext(conf)
    // Register input file as Spark file
    addFiles(sc, TestUtils.locate("smalldata/" + DATAFILE))
    // Initialize H2O context
    implicit val h2oContext = H2OContext.getOrCreate(sc)
    import h2oContext.implicits._
    // Initialize SQL context
    implicit val sqlContext = SparkSession.builder().getOrCreate().sqlContext
    import sqlContext.implicits._

    // Data load: each record is Array(label, messageText)
    val data = load(sc, DATAFILE)
    // Extract response (spam or ham) and the message body
    val hamSpam = data.map(r => r(0))
    val message = data.map(r => r(1))
    // Tokenize message content
    val tokens = tokenize(message)
    // Build IDF model and the tf-idf vectors for the training corpus
    val (hashingTF, idfModel, tfidf) = buildIDFModel(tokens)

    // Merge response with extracted vectors and publish as an H2O frame
    val resultRDD: DataFrame = hamSpam.zip(tfidf).map(v => SMS(v._1, v._2)).toDF
    val table: H2OFrame = resultRDD

    // Transform target column into categorical so H2O treats this as classification
    H2OFrameSupport.withLockAndUpdate(table) { fr =>
      fr.replace(fr.find("target"), fr.vec("target").toCategoricalVec).remove()
    }

    // Split table 80/20 into training and validation frames
    val keys = Array[String]("train.hex", "valid.hex")
    val ratios = Array[Double](0.8)
    val frs = split(table, keys, ratios)
    val (train, valid) = (frs(0), frs(1))
    table.delete()

    // Build a model
    val dlModel = buildDLModel(train, valid)

    // Collect model metrics
    val trainMetrics = modelMetrics[ModelMetricsBinomial](dlModel, train)
    val validMetrics = modelMetrics[ModelMetricsBinomial](dlModel, valid)
    println(
      s"""
         |AUC on train data = ${trainMetrics.auc}
         |AUC on valid data = ${validMetrics.auc}
       """.stripMargin)

    // Detect spam messages
    TEST_MSGS.foreach(msg => {
      println(
        s"""
           |"$msg" is ${if (isSpam(msg, sc, dlModel, hashingTF, idfModel)) "SPAM" else "HAM"}
         """.stripMargin)
    })

    // Shutdown Spark cluster and H2O
    h2oContext.stop(stopSparkContext = true)
  }

  /** Data loader.
   *
   * Reads the tab-separated message file and keeps only records with a
   * non-empty label column.
   *
   * @param sc       Spark context
   * @param dataFile name of a file previously registered via addFiles
   * @return RDD of Array(label, messageText)
   */
  def load(sc: SparkContext, dataFile: String): RDD[Array[String]] = {
    sc.textFile(enforceLocalSparkFile(dataFile))
      .map(l => l.split("\t", 2))
      .filter(r => r(0).nonEmpty)
  }

  /** Text message tokenizer.
   *
   * Produce a bag of words representing given message.
   *
   * @param data RDD of text messages
   * @return RDD of bag of words
   */
  def tokenize(data: RDD[String]): RDD[Seq[String]] = {
    val ignoredWords = Seq("the", "a", "", "in", "on", "at", "as", "not", "for")
    val ignoredChars = Seq(',', ':', ';', '/', '<', '>', '"', '.', '(', ')', '?', '-', '\'', '!', '0', '1')
    data.map { sms =>
      // Lower-case the message and blank out punctuation / noise characters
      val cleaned = ignoredChars.foldLeft(sms.toLowerCase)((txt, c) => txt.replace(c, ' '))
      // Keep distinct words longer than 2 characters that are not stop words
      cleaned.split(" ").filter(w => !ignoredWords.contains(w) && w.length > 2).distinct.toSeq
    }
  }

  /** Build tf-idf model representing a text message.
   *
   * @param tokens        RDD of tokenized messages (bags of words)
   * @param minDocFreq    minimum number of documents a term must appear in
   * @param hashSpaceSize dimension of the hashed feature space (default 1024)
   * @return the hashing transformer, the fitted IDF model, and tf-idf vectors
   */
  def buildIDFModel(tokens: RDD[Seq[String]],
                    minDocFreq: Int = 4,
                    hashSpaceSize: Int = 1 << 10): (HashingTF, IDFModel, RDD[mllib.linalg.Vector]) = {
    // Hash strings into the given space
    val hashingTF = new HashingTF(hashSpaceSize)
    val tf = hashingTF.transform(tokens)
    // Build term frequency-inverse document frequency
    val idfModel = new IDF(minDocFreq = minDocFreq).fit(tf)
    val expandedText = idfModel.transform(tf)
    (hashingTF, idfModel, expandedText)
  }

  /** Builds DeepLearning model.
   *
   * @param train  training frame
   * @param valid  validation frame
   * @param epochs number of training epochs
   * @param l1     L1 regularization strength
   * @param l2     L2 regularization strength
   * @param hidden hidden layer sizes
   * @return the trained model (trainModel.get blocks until training finishes)
   */
  def buildDLModel(train: Frame, valid: Frame,
                   epochs: Int = 10, l1: Double = 0.001, l2: Double = 0.0,
                   hidden: Array[Int] = Array[Int](200, 200))
                  (implicit h2oContext: H2OContext): DeepLearningModel = {
    import h2oContext.implicits._
    // Build a model
    val dlParams = new DeepLearningParameters()
    dlParams._train = train
    dlParams._valid = valid
    dlParams._response_column = 'target
    dlParams._epochs = epochs
    dlParams._l1 = l1
    // BUGFIX: the l2 parameter was previously accepted but never applied
    dlParams._l2 = l2
    dlParams._hidden = hidden
    // Create a job
    val dl = new DeepLearning(dlParams, water.Key.make("dlModel.hex"))
    dl.trainModel.get
  }

  /** Spam detector.
   *
   * Vectorizes a single message with the trained tf-idf pipeline and scores
   * it with the Deep Learning model.
   *
   * @param msg          message to classify
   * @param sc           Spark context used to parallelize the single message
   * @param dlModel      trained Deep Learning model
   * @param hashingTF    hashing transformer fitted on the training corpus
   * @param idfModel     IDF model fitted on the training corpus
   * @param hamThreshold classify as spam when the score in prediction column 1
   *                     falls below this value (presumably p("ham") — classes
   *                     are ordered alphabetically; verify against H2O output)
   * @return true when the message is classified as spam
   */
  def isSpam(msg: String,
             sc: SparkContext,
             dlModel: DeepLearningModel,
             hashingTF: HashingTF,
             idfModel: IDFModel,
             hamThreshold: Double = 0.5)
            (implicit sqlContext: SQLContext, h2oContext: H2OContext): Boolean = {
    import h2oContext.implicits._
    import sqlContext.implicits._
    val msgRdd = sc.parallelize(Seq(msg))
    val msgVector: DataFrame = idfModel.transform(
      hashingTF.transform(
        tokenize(msgRdd))).map(v => SMS("?", v)).toDF
    val msgTable: H2OFrame = msgVector
    msgTable.remove(0) // remove first column (the dummy "?" target)
    val prediction = dlModel.score(msgTable)
    prediction.vecs()(1).at(0) < hamThreshold
  }
}
/** Training message representation: `target` is the class label and `fv` is
 * the tf-idf feature vector of the message text. */
case class SMS(target: String, fv: mllib.linalg.Vector)