In [1]:
// 导入需要的包
import org.apache.spark.ml.{Model, Pipeline}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{Tokenizer, Word2Vec}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{DataFrame, SQLContext}

import scala.io.Source

In [2]:
// 常量定义
final val LABEL_POSITIVE = 1.0
final val LABEL_NEGATIVE = 0.0

final val CLASS_SPAM = "spam"

final val TRAIN_DATA = "https://s3.amazonaws.com/workflowexecutor/examples/data/SMSSpamCollection.csv"

In [3]:
// 辅助方法定义

/**
 * 数据格式转换
 */
def loadData(filePath: String): (Vector[(String, Double)], Vector[(String, Double)]) = {
    def reformat(sms: String, label: Double): (String, Double) = (sms.split("\t").last, label)

    val file = Source.fromURL(filePath,"UTF-8").getLines().toVector.tail
    val (spam, ham) = file.partition(_.contains(CLASS_SPAM))
    val spamData = spam.map(x => reformat(x, LABEL_POSITIVE))
    val hamData = ham.map(x => reformat(x, LABEL_NEGATIVE))
    (spamData, hamData)
}

/**
 * 计算精确度
 */
def precision(model:Model[_], test:DataFrame):Double = {
    val testResult = model.transform(test)
    val total = testResult.count()
    val corrects = testResult.filter("prediction = label").count()

    corrects.asInstanceOf[Double] / total
}

In [5]:
// 数据集加载
val sqlContext = new SQLContext(sc)

import sqlContext.implicits._

val (spam, ham) = loadData(TRAIN_DATA)
// 随机选取训练数据和测试数据
val (trainSpan, testSpan) = spam.partition(_._1.length % 10 < 8)
val (trainHam, testHam) = ham.partition(_._1.length % 10 < 8)

val negTest = sc.parallelize(testSpan).toDF("text", "label").cache()
val posTest = sc.parallelize(testHam).toDF("text", "label").cache()
val train = sc.parallelize(trainHam.union(trainSpan)).toDF("text", "label").cache()
val test = negTest.unionAll(posTest)

In [7]:
// 构造流水线
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val word2vec = new Word2Vec().setInputCol(tokenizer.getOutputCol).setOutputCol("features")
val lr = new LogisticRegression()
val pipeline = new Pipeline().setStages(Array(tokenizer,word2vec,lr))

In [8]:
// 构造参数网络
val paramGrid = new ParamGridBuilder().addGrid(word2vec.vectorSize, Array(50, 100, 200)).addGrid(lr.regParam, Array(0.00001, 0.001, 0.1)).build()

In [11]:
// 训练模型（包括Word2Vec特征抽取和LR分类模型）
val cv = new CrossValidator().setEstimator(pipeline).setEvaluator(new BinaryClassificationEvaluator()).setEstimatorParamMaps(paramGrid).setNumFolds(5)

val model = cv.fit(train)

In [12]:
// 评估模型结果：使用AUC指标
val testResult = model.transform(test)
val evaluator = new BinaryClassificationEvaluator()
println(s"test metrics: ${evaluator.evaluate(testResult)}")

test metrics: 0.963439565363504


In [13]:
// 评估模型结果：负样本与正样本准确率
val negPrecision = precision(model,negTest)
val posPrecision = precision(model,posTest)
println(s"neg-precision = $negPrecision, pos-precision = $posPrecision")

neg-precision = 0.7531645569620253, pos-precision = 0.9789823008849557
