-
Notifications
You must be signed in to change notification settings - Fork 401
/
BenchmarkAlgorithm.scala
120 lines (102 loc) · 3.74 KB
/
BenchmarkAlgorithm.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package com.databricks.spark.sql.perf.mllib
import com.typesafe.scalalogging.slf4j.{LazyLogging => Logging}
import org.apache.spark.ml.attribute.{NominalAttribute, NumericAttribute}
import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
/**
 * The description of a benchmark for an ML algorithm. It follows a simple, standard procedure:
* - generate some test and training data
* - generate a model against the training data
* - score the model against the training data
* - score the model against the test data
*
* You should not assume that your implementation can carry state around. If some state is needed,
* consider adding it to the context.
*
* It is assumed that the implementation is going to be an object.
*/
trait BenchmarkAlgorithm extends Logging {

  /** The dataset used to fit the model under benchmark. */
  def trainingDataSet(ctx: MLBenchContext): DataFrame

  /** The held-out dataset used when scoring the fitted model. */
  def testDataSet(ctx: MLBenchContext): DataFrame

  /**
   * Create an [[Estimator]] with params set from the given [[MLBenchContext]].
   */
  def getEstimator(ctx: MLBenchContext): Estimator[_]

  /**
   * The unnormalized score of the training procedure on a dataset. The normalization is
   * performed by the caller.
   */
  @throws[Exception]("if scoring fails")
  def score(
      ctx: MLBenchContext,
      testSet: DataFrame,
      model: Transformer): Double = -1.0 // Not NaN, because NaN is not valid JSON.

  /**
   * A human-readable name for the benchmark, derived from the implementing object's
   * fully-qualified class name with the object-marker '$' characters removed.
   */
  def name: String = this.getClass.getCanonicalName.filterNot(_ == '$')
}
/**
 * Mixin that delegates scoring to a Spark ML [[Evaluator]].
 */
trait ScoringWithEvaluator {
  self: BenchmarkAlgorithm =>

  /** The evaluator (configured from the context) used to compute the score. */
  protected def evaluator(ctx: MLBenchContext): Evaluator

  final override def score(
      ctx: MLBenchContext,
      testSet: DataFrame,
      model: Transformer): Double = {
    // Apply the model to the test set, then let the evaluator turn the
    // resulting predictions into a single metric.
    val predictions = model.transform(testSet)
    evaluator(ctx).evaluate(predictions)
  }
}
/**
 * Builds the training set from an initial dataset and an initial model. Useful for validating a
 * trained model against a given model.
 */
trait TrainingSetFromTransformer {
  self: BenchmarkAlgorithm =>

  /** The raw (unlabeled) dataset whose labels will be produced by the true model. */
  protected def initialData(ctx: MLBenchContext): DataFrame

  /** The reference model whose predictions become the training labels. */
  protected def trueModel(ctx: MLBenchContext): Transformer

  final override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
    val featuresCol = col("features")
    // Special case for the trees: when numClasses is set, the label column metadata must
    // record the number of classes (0 is treated as a continuous/numeric label).
    val labelCol = ctx.params.numClasses match {
      case None =>
        col("prediction").as("label")
      case Some(numClasses) =>
        val attr =
          if (numClasses == 0) NumericAttribute.defaultAttr.withName("label")
          else NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
        col("prediction").as("label", attr.toMetadata())
    }
    trueModel(ctx).transform(initialData(ctx)).select(featuresCol, labelCol)
  }
}
/**
 * The test data is drawn by the same generator as the training data (re-run under a
 * shifted seed and the test-set size when a seed is configured).
 */
trait TestFromTraining {
  self: BenchmarkAlgorithm =>

  final override def testDataSet(ctx: MLBenchContext): DataFrame = {
    // Derive a context whose seed differs from the training context's, then reuse
    // the training-data generator with it.
    val shiftedCtx = ctx.params.randomSeed match {
      case None =>
        // No explicit seed: a full copy resets the internal seed.
        ctx.copy()
      case Some(seed) =>
        // Also set the number of examples to the number of test examples.
        assert(ctx.params.numTestExamples.nonEmpty, "You must specify test examples")
        val shiftedParams = ctx.params.copy(
          randomSeed = Some(seed + 1),
          numExamples = ctx.params.numTestExamples)
        ctx.copy(params = shiftedParams)
    }
    self.trainingDataSet(shiftedCtx)
  }
}