-
Notifications
You must be signed in to change notification settings - Fork 310
/
XGBoostClassificationModelOp.scala
91 lines (78 loc) · 4.54 KB
/
XGBoostClassificationModelOp.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
package ml.dmlc.xgboost4j.scala.spark.mleap
import java.nio.file.Files
import ml.combust.bundle.BundleContext
import ml.combust.bundle.dsl.{Model, NodeShape, Value}
import ml.combust.bundle.op.OpModel
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}
import org.apache.spark.ml.bundle._
import org.apache.spark.ml.linalg.Vector
import scala.util.Using
/**
  * Bundle op that stores and loads an [[XGBoostClassificationModel]].
  *
  * The booster is written to a file named `xgboost.model` inside the bundle; the model's
  * parameters are stored next to it as named bundle values.
  *
  * Created by hollinwilkins on 9/16/17.
  */
class XGBoostClassificationModelOp extends SimpleSparkOp[XGBoostClassificationModel] {

  /** Type class for the underlying model. */
  override val Model: OpModel[SparkBundleContext, XGBoostClassificationModel] =
    new OpModel[SparkBundleContext, XGBoostClassificationModel] {
      override val klazz: Class[XGBoostClassificationModel] = classOf[XGBoostClassificationModel]

      override def opName: String = "xgboost.classifier"

      /**
        * Writes the booster and the model parameters into the bundle.
        *
        * @param model   bundle model to attach the parameter values to
        * @param obj     classification model being serialized
        * @param context bundle context; its Spark context must carry a sample dataset
        * @return `model` with the parameter values added
        */
      override def store(model: Model, obj: XGBoostClassificationModel)
                        (implicit context: BundleContext[SparkBundleContext]): Model = {
        assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz))

        // Thresholds are optional: only persist them when explicitly set on the model.
        val thresholds = if (obj.isSet(obj.thresholds)) Some(obj.getThresholds) else None

        // Using.resource closes the stream even when saveModel throws; the stream was
        // previously left open after writing.
        Using.resource(Files.newOutputStream(context.file("xgboost.model"))) { out =>
          obj._booster.saveModel(out)
        }

        // The feature count is read from the first row of the sample dataset
        // (whose presence is asserted above).
        val numFeatures =
          context.context.dataset.get.select(obj.getFeaturesCol).first.getAs[Vector](0).size

        model
          .withValue("thresholds", thresholds.map(_.toSeq).map(Value.doubleList))
          .withValue("num_classes", Value.int(obj.numClasses))
          .withValue("num_features", Value.int(numFeatures))
          .withValue("tree_limit", Value.int(obj.getOrDefault(obj.treeLimit)))
          .withValue("missing", Value.float(obj.getOrDefault(obj.missing)))
          .withValue("infer_batch_size", Value.int(obj.getOrDefault(obj.inferBatchSize)))
          .withValue("use_external_memory", Value.boolean(obj.getOrDefault(obj.useExternalMemory)))
          .withValue("allow_non_zero_for_missing",
            Value.boolean(obj.getOrDefault(obj.allowNonZeroForMissing)))
          .withValue("objective", Value.string(obj.getOrDefault(obj.objective)))
      }

      /**
        * Reads the booster from `xgboost.model` and rebuilds the classification model.
        *
        * Apart from `num_classes`, every stored value is treated as optional and is applied
        * only when present in the bundle.
        *
        * @param model   bundle model holding the stored parameter values
        * @param context bundle context used to resolve the booster file
        * @return a model with an empty uid wrapping the loaded booster
        */
      override def load(model: Model)
                       (implicit context: BundleContext[SparkBundleContext]): XGBoostClassificationModel = {
        val booster = Using.resource(Files.newInputStream(context.file("xgboost.model"))) { in =>
          SXGBoost.loadModel(in)
        }

        val xgb = new XGBoostClassificationModel("", model.value("num_classes").getInt, booster)

        // foreach rather than map: the setters are invoked purely for their side effect.
        model.getValue("tree_limit").foreach(v => xgb.setTreeLimit(v.getInt))
        model.getValue("thresholds").foreach(v => xgb.setThresholds(v.getDoubleList.toArray))
        model.getValue("missing").foreach(v => xgb.setMissing(v.getFloat))
        model.getValue("allow_non_zero_for_missing")
          .foreach(v => xgb.setAllowNonZeroForMissing(v.getBoolean))
        model.getValue("infer_batch_size").foreach(v => xgb.setInferBatchSize(v.getInt))
        model.getValue("use_external_memory").foreach(v => xgb.set(xgb.useExternalMemory, v.getBoolean))
        model.getValue("objective").foreach(v => xgb.set(xgb.objective, v.getString))

        xgb
      }
    }

  /**
    * Re-creates the loaded model under the given uid.
    *
    * Only parameters that are explicitly set on `model` are copied onto the new instance.
    *
    * @param uid   uid for the returned model
    * @param shape node shape (not used here)
    * @param model model produced by `Model.load`
    * @return a new model sharing `model`'s booster and class count
    */
  override def sparkLoad(uid: String,
                         shape: NodeShape,
                         model: XGBoostClassificationModel): XGBoostClassificationModel = {
    val xgb = new XGBoostClassificationModel(uid, model.numClasses, model._booster)

    if (model.isSet(model.thresholds)) xgb.setThresholds(model.getOrDefault(model.thresholds))
    if (model.isSet(model.missing)) xgb.setMissing(model.getOrDefault(model.missing))
    if (model.isSet(model.allowNonZeroForMissing)) {
      xgb.setAllowNonZeroForMissing(model.getOrDefault(model.allowNonZeroForMissing))
    }
    if (model.isSet(model.inferBatchSize)) xgb.setInferBatchSize(model.getOrDefault(model.inferBatchSize))
    if (model.isSet(model.treeLimit)) xgb.setTreeLimit(model.getOrDefault(model.treeLimit))
    if (model.isSet(model.useExternalMemory)) {
      xgb.set(xgb.useExternalMemory, model.getOrDefault(model.useExternalMemory))
    }
    if (model.isSet(model.objective)) xgb.set(xgb.objective, model.getOrDefault(model.objective))

    xgb
  }

  /** Maps the bundle input port `features` to the model's features column param. */
  override def sparkInputs(obj: XGBoostClassificationModel): Seq[ParamSpec] = {
    Seq("features" -> obj.featuresCol)
  }

  /** Maps each bundle output port to the corresponding output column param of the model. */
  override def sparkOutputs(obj: XGBoostClassificationModel): Seq[SimpleParamSpec] = {
    Seq(
      "raw_prediction" -> obj.rawPredictionCol,
      "prediction" -> obj.predictionCol,
      "probability" -> obj.probabilityCol,
      "leaf_prediction" -> obj.leafPredictionCol,
      "contrib_prediction" -> obj.contribPredictionCol
    )
  }
}