Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SW-1658][rel-3.26] Figure out better way of caching MOJOs (v1 & v2) #1568

Merged
merged 1 commit into from Oct 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.models

import org.apache.spark.expose.Logging
import org.apache.spark.sql.SparkSession

import scala.collection.mutable

/**
 * Cache of loaded MOJO backends keyed by model uid, with time-based eviction.
 *
 * A backend is loaded on the first `getMojoBackend` call for a given uid and kept until it has
 * not been requested for longer than the value of `spark.ext.h2o.mojo.destroy.timeout`
 * (milliseconds, default 10 minutes). Eviction is performed by a background cleaner thread that
 * has to be started explicitly via `startCleanupThread`.
 *
 * All reads and writes of the two internal maps happen while holding `Lock`.
 *
 * @tparam B type of the cached backend produced by `loadMojoBackend`
 * @tparam M type of the model passed through to `loadMojoBackend`
 */
trait H2OMOJOBaseCache[B, M] extends Logging {
  private object Lock

  private val pipelineCache = mutable.Map.empty[String, B]
  private val lastAccessMap = mutable.Map.empty[String, Long]

  private lazy val sparkConf = SparkSession.builder().getOrCreate().sparkContext.getConf
  private lazy val cleanupRetryTimeout = sparkConf.getInt("spark.ext.h2o.mojo.destroy.timeout", 10 * 60 * 1000)

  // Guards the one-time start of the cleaner thread; a java.lang.Thread can be started only once.
  private val cleanerStarted = new java.util.concurrent.atomic.AtomicBoolean(false)

  private val cleanerThread = new Thread() {
    // The cleaner is never stopped explicitly, so it must not keep the JVM alive on shutdown.
    setDaemon(true)

    override def run(): Unit = {
      while (!Thread.interrupted()) {
        try {
          Thread.sleep(cleanupRetryTimeout)
          evictExpired()
        } catch {
          // Restore the flag so that the loop condition observes the interrupt and terminates.
          case _: InterruptedException => Thread.currentThread.interrupt()
        }
      }
    }
  }

  /**
   * Removes every entry whose last access is older than the configured timeout.
   *
   * Both the expiry check and the removal run under `Lock`, so the scan never observes the maps
   * while `getMojoBackend` is mutating them, and an entry touched by `getMojoBackend` cannot be
   * evicted based on a stale timestamp.
   */
  private def evictExpired(): Unit = {
    // Resolve the lazy configuration value before taking the cache lock.
    val timeout = cleanupRetryTimeout
    Lock.synchronized {
      val now = System.currentTimeMillis()
      // Materialize the expired uids first; the maps must not be modified while being iterated.
      val expired = lastAccessMap.collect {
        case (uid, lastAccess) if now - lastAccess > timeout => uid
      }.toList
      expired.foreach { uid =>
        logDebug(s"Removing mojo $uid from cache as it has not been used for $timeout ms.")
        lastAccessMap.remove(uid)
        pipelineCache.remove(uid)
      }
    }
  }

  /**
   * Starts the background cleaner thread. Safe to call repeatedly and from multiple threads;
   * only the first call starts the thread, all later calls are no-ops.
   */
  def startCleanupThread(): Unit = {
    if (cleanerStarted.compareAndSet(false, true)) {
      cleanerThread.start()
      logDebug("Cleaner thread for unused MOJOs started.")
    }
  }

  /**
   * Returns the backend cached under `uid`, loading it via `loadMojoBackend` on a cache miss,
   * and records the current time as the entry's last access.
   *
   * @param uid         cache key
   * @param bytesGetter supplier of the raw MOJO bytes; invoked only on a cache miss
   * @param model       model instance handed over to `loadMojoBackend` on a cache miss
   * @return the cached or freshly loaded backend
   */
  def getMojoBackend(uid: String, bytesGetter: () => Array[Byte], model: M): B = Lock.synchronized {
    val backend = pipelineCache.getOrElseUpdate(uid, loadMojoBackend(bytesGetter(), model))
    lastAccessMap.put(uid, System.currentTimeMillis())
    backend
  }

  /**
   * Builds a backend from raw MOJO bytes. Called by `getMojoBackend` while the cache lock is held.
   *
   * @param mojoData raw MOJO bytes
   * @param model    model instance for which the backend is created
   * @return the newly created backend
   */
  def loadMojoBackend(mojoData: Array[Byte], model: M): B
}
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.sql._
import scala.collection.JavaConverters._

class H2OMOJOModel(override val uid: String) extends H2OMOJOModelBase[H2OMOJOModel] with H2OMOJOPrediction {

H2OMOJOCache.startCleanupThread()
protected final val modelDetails: NullableStringParam = new NullableStringParam(this, "modelDetails", "Raw details of this model.")

setDefault(
Expand All @@ -43,29 +43,6 @@ class H2OMOJOModel(override val uid: String) extends H2OMOJOModelBase[H2OMOJOMod

override protected def outputColumnName: String = getDetailedPredictionCol()

// Some MojoModels are not serializable ( DeepLearning ), so we are reusing the mojoData to keep information about mojo model
@transient protected lazy val easyPredictModelWrapper: EasyPredictModelWrapper = {
val config = new EasyPredictModelWrapper.Config()
config.setModel(Utils.getMojoModel(getMojoData()))
config.setConvertUnknownCategoricalLevelsToNa(getConvertUnknownCategoricalLevelsToNa())
config.setConvertInvalidNumbersToNa(getConvertInvalidNumbersToNa())
if (canGenerateContributions(config.getModel)) {
config.setEnableContributions(getWithDetailedPredictionCol())
}
// always let H2O produce full output, filter later if required
config.setUseExtendedOutput(true)
new EasyPredictModelWrapper(config)
}

private def canGenerateContributions(model: GenModel): Boolean = {
model match {
case _: PredictContributionsFactory =>
val modelCategory = model.getModelCategory
modelCategory == ModelCategory.Regression || modelCategory == ModelCategory.Binomial
case _ => false
}
}

override def copy(extra: ParamMap): H2OMOJOModel = defaultCopy(extra)

override def transform(dataset: Dataset[_]): DataFrame = {
Expand Down Expand Up @@ -146,3 +123,28 @@ object H2OMOJOModel extends H2OMOJOReadable[H2OMOJOModel] with H2OMOJOLoader[H2O
model.set(model.featuresCols -> originalFeatures)
}
}

/**
 * Cache of `EasyPredictModelWrapper` instances created for `H2OMOJOModel`s.
 */
object H2OMOJOCache extends H2OMOJOBaseCache[EasyPredictModelWrapper, H2OMOJOModel] {

  /**
   * Tells whether contributions can be enabled for the given model: it has to be
   * a `PredictContributionsFactory` of the Regression or Binomial category.
   */
  private def canGenerateContributions(model: GenModel): Boolean = model match {
    case _: PredictContributionsFactory =>
      model.getModelCategory match {
        case ModelCategory.Regression | ModelCategory.Binomial => true
        case _ => false
      }
    case _ => false
  }

  /**
   * Creates the prediction wrapper from the MOJO bytes, configured from the model's parameters.
   *
   * @param mojoData raw MOJO bytes
   * @param model    model whose parameters drive the wrapper configuration
   * @return configured prediction wrapper
   */
  override def loadMojoBackend(mojoData: Array[Byte], model: H2OMOJOModel): EasyPredictModelWrapper = {
    val wrapperConfig = new EasyPredictModelWrapper.Config()

    val mojoModel = Utils.getMojoModel(mojoData)
    wrapperConfig.setModel(mojoModel)

    val unknownLevelsToNa = model.getConvertUnknownCategoricalLevelsToNa()
    wrapperConfig.setConvertUnknownCategoricalLevelsToNa(unknownLevelsToNa)

    val invalidNumbersToNa = model.getConvertInvalidNumbersToNa()
    wrapperConfig.setConvertInvalidNumbersToNa(invalidNumbersToNa)

    // Contributions are requested only for models that are able to produce them.
    val contributionsSupported = canGenerateContributions(wrapperConfig.getModel)
    if (contributionsSupported) {
      wrapperConfig.setEnableContributions(model.getWithDetailedPredictionCol())
    }

    // H2O always produces the full output; any filtering happens later if required.
    wrapperConfig.setUseExtendedOutput(true)
    new EasyPredictModelWrapper(wrapperConfig)
  }
}
Expand Up @@ -33,14 +33,15 @@ import scala.util.Random

class H2OMOJOPipelineModel(override val uid: String) extends H2OMOJOModelBase[H2OMOJOPipelineModel] {

@transient private lazy val mojoPipeline: MojoPipeline = {
val reader = MojoPipelineReaderBackendFactory.createReaderBackend(new ByteArrayInputStream(getMojoData()))
MojoPipeline.loadFrom(reader)
}
H2OMOJOPipelineCache.startCleanupThread()

// private parameter used to store MOJO output columns
protected final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "OutputCols")

@transient private lazy val mojoPipeline: MojoPipeline = {
H2OMOJOPipelineCache.getMojoBackend(uid, getMojoData, this)
}

case class Mojo2Prediction(preds: List[Double])

private def prepareBooleans(colType: Type, colData: Any): Any = {
Expand All @@ -60,7 +61,6 @@ class H2OMOJOPipelineModel(override val uid: String) extends H2OMOJOModelBase[H2
} else {
colData
}

}

private val modelUdf = (names: Array[String]) =>
Expand Down Expand Up @@ -204,3 +204,10 @@ object H2OMOJOPipelineModel extends H2OMOJOReadable[H2OMOJOPipelineModel] with H
model
}
}

/**
 * Cache of `MojoPipeline` instances created for `H2OMOJOPipelineModel`s.
 */
private object H2OMOJOPipelineCache extends H2OMOJOBaseCache[MojoPipeline, H2OMOJOPipelineModel] {

  /**
   * Loads a MOJO pipeline from its raw bytes. The `model` argument is not used here.
   *
   * @param mojoData raw MOJO pipeline bytes
   * @param model    model for which the pipeline is loaded
   * @return the loaded pipeline
   */
  override def loadMojoBackend(mojoData: Array[Byte], model: H2OMOJOPipelineModel): MojoPipeline = {
    val mojoStream = new ByteArrayInputStream(mojoData)
    val readerBackend = MojoPipelineReaderBackendFactory.createReaderBackend(mojoStream)
    MojoPipeline.loadFrom(readerBackend)
  }
}
Expand Up @@ -34,7 +34,8 @@ trait H2OMOJOPrediction
self: H2OMOJOModel =>

def extractPredictionColContent(): Column = {
easyPredictModelWrapper.getModelCategory match {
val predictWrapper = H2OMOJOCache.getMojoBackend(uid, getMojoData, this)
predictWrapper.getModelCategory match {
case ModelCategory.Binomial => extractBinomialPredictionColContent()
case ModelCategory.Regression => extractRegressionPredictionColContent()
case ModelCategory.Multinomial => extractMultinomialPredictionColContent()
Expand All @@ -43,12 +44,13 @@ trait H2OMOJOPrediction
case ModelCategory.DimReduction => extractDimReductionSimplePredictionColContent()
case ModelCategory.WordEmbedding => extractWordEmbeddingPredictionColContent()
case ModelCategory.AnomalyDetection => extractAnomalyPredictionColContent()
case _ => throw new RuntimeException("Unknown model category " + easyPredictModelWrapper.getModelCategory)
case _ => throw new RuntimeException("Unknown model category " + predictWrapper.getModelCategory)
}
}

def getPredictionUDF(): UserDefinedFunction = {
easyPredictModelWrapper.getModelCategory match {
val predictWrapper = H2OMOJOCache.getMojoBackend(uid, getMojoData, this)
predictWrapper.getModelCategory match {
case ModelCategory.Binomial => getBinomialPredictionUDF()
case ModelCategory.Regression => getRegressionPredictionUDF()
case ModelCategory.Multinomial => getMultinomialPredictionUDF()
Expand All @@ -57,12 +59,13 @@ trait H2OMOJOPrediction
case ModelCategory.DimReduction => getDimReductionPredictionUDF()
case ModelCategory.WordEmbedding => getWordEmbeddingPredictionUDF()
case ModelCategory.AnomalyDetection => getAnomalyPredictionUDF()
case _ => throw new RuntimeException("Unknown model category " + easyPredictModelWrapper.getModelCategory)
case _ => throw new RuntimeException("Unknown model category " + predictWrapper.getModelCategory)
}
}

override def getPredictionColSchema(): Seq[StructField] = {
easyPredictModelWrapper.getModelCategory match {
val predictWrapper = H2OMOJOCache.getMojoBackend(uid, getMojoData, this)
predictWrapper.getModelCategory match {
case ModelCategory.Binomial => getBinomialPredictionColSchema()
case ModelCategory.Regression => getRegressionPredictionColSchema()
case ModelCategory.Multinomial => getMultinomialPredictionColSchema()
Expand All @@ -71,12 +74,13 @@ trait H2OMOJOPrediction
case ModelCategory.DimReduction => getDimReductionPredictionColSchema()
case ModelCategory.WordEmbedding => getWordEmbeddingPredictionColSchema()
case ModelCategory.AnomalyDetection => getAnomalyPredictionColSchema()
case _ => throw new RuntimeException("Unknown model category " + easyPredictModelWrapper.getModelCategory)
case _ => throw new RuntimeException("Unknown model category " + predictWrapper.getModelCategory)
}
}

override def getDetailedPredictionColSchema(): Seq[StructField] = {
easyPredictModelWrapper.getModelCategory match {
val predictWrapper = H2OMOJOCache.getMojoBackend(uid, getMojoData, this)
predictWrapper.getModelCategory match {
case ModelCategory.Binomial => getBinomialDetailedPredictionColSchema()
case ModelCategory.Regression => getRegressionDetailedPredictionColSchema()
case ModelCategory.Multinomial => getMultinomialDetailedPredictionColSchema()
Expand All @@ -85,7 +89,7 @@ trait H2OMOJOPrediction
case ModelCategory.DimReduction => getDimReductionDetailedPredictionColSchema()
case ModelCategory.WordEmbedding => getWordEmbeddingDetailedPredictionColSchema()
case ModelCategory.AnomalyDetection => getAnomalyDetailedPredictionColSchema()
case _ => throw new RuntimeException("Unknown model category " + easyPredictModelWrapper.getModelCategory)
case _ => throw new RuntimeException("Unknown model category " + predictWrapper.getModelCategory)
}
}
}
Expand Up @@ -28,7 +28,7 @@ trait H2OMOJOPredictionAnomaly extends H2OMOJOPredictionUtils {
logWarning("Starting from the next major release, the content of 'prediction' column will be generated to " +
" 'detailed_prediction' instead. The 'prediction' column will contain directly the predicted score.")
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictAnomalyDetection(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictAnomalyDetection(RowConverter.toH2ORowData(r))
Base(pred.score, pred.normalizedScore)
}
}
Expand Down
Expand Up @@ -30,7 +30,7 @@ trait H2OMOJOPredictionAutoEncoder {
" 'detailed_prediction' instead. The 'prediction' column will contain directly the original input the" +
" way AutoEncoder model sees it (1-hot encoded categorical values) .")
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictAutoEncoder(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictAutoEncoder(RowConverter.toH2ORowData(r))
Base(pred.original, pred.reconstructed)
}
}
Expand Down
Expand Up @@ -36,10 +36,10 @@ trait H2OMOJOPredictionBinomial {
def getBinomialPredictionUDF(): UserDefinedFunction = {
logWarning("Starting from the next major release, the content of 'prediction' column will be generated to " +
" 'detailed_prediction' instead. The 'prediction' column will contain directly the predicted label.")
if (supportsCalibratedProbabilities(easyPredictModelWrapper)) {
if (supportsCalibratedProbabilities(H2OMOJOCache.getMojoBackend(uid, getMojoData, this))) {
if (getWithDetailedPredictionCol()) {
udf[WithCalibrationAndContribution, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictBinomial(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictBinomial(RowConverter.toH2ORowData(r))
WithCalibrationAndContribution(
pred.classProbabilities(0),
pred.classProbabilities(1),
Expand All @@ -50,7 +50,7 @@ trait H2OMOJOPredictionBinomial {
}
} else {
udf[WithCalibration, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictBinomial(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictBinomial(RowConverter.toH2ORowData(r))
WithCalibration(
pred.classProbabilities(0),
pred.classProbabilities(1),
Expand All @@ -61,7 +61,7 @@ trait H2OMOJOPredictionBinomial {
}
} else if (getWithDetailedPredictionCol()) {
udf[WithContribution, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictBinomial(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictBinomial(RowConverter.toH2ORowData(r))
WithContribution(
pred.classProbabilities(0),
pred.classProbabilities(1),
Expand All @@ -70,7 +70,7 @@ trait H2OMOJOPredictionBinomial {
}
} else {
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictBinomial(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictBinomial(RowConverter.toH2ORowData(r))
Base(
pred.classProbabilities(0),
pred.classProbabilities(1)
Expand All @@ -82,7 +82,7 @@ trait H2OMOJOPredictionBinomial {
private val baseFields = Seq("p0", "p1").map(StructField(_, DoubleType, nullable = false))

def getBinomialPredictionColSchema(): Seq[StructField] = {
val fields = if (supportsCalibratedProbabilities(easyPredictModelWrapper)) {
val fields = if (supportsCalibratedProbabilities(H2OMOJOCache.getMojoBackend(uid, getMojoData, this))) {
baseFields ++ Seq("p0_calibrated", "p1_calibrated").map(StructField(_, DoubleType, nullable = false))
} else {
baseFields
Expand All @@ -92,7 +92,7 @@ trait H2OMOJOPredictionBinomial {
}

def getBinomialDetailedPredictionColSchema(): Seq[StructField] = {
val fields = if (supportsCalibratedProbabilities(easyPredictModelWrapper)) {
val fields = if (supportsCalibratedProbabilities(H2OMOJOCache.getMojoBackend(uid, getMojoData, this))) {
val base = baseFields ++ Seq("p0_calibrated", "p1_calibrated").map(StructField(_, DoubleType, nullable = false))
if (getWithDetailedPredictionCol()) {
base ++ Seq(StructField("contributions", ArrayType(FloatType)))
Expand All @@ -109,7 +109,7 @@ trait H2OMOJOPredictionBinomial {
}

def extractBinomialPredictionColContent(): Column = {
if (supportsCalibratedProbabilities(easyPredictModelWrapper)) {
if (supportsCalibratedProbabilities(H2OMOJOCache.getMojoBackend(uid, getMojoData, this))) {
extractColumnsAsNested(Seq("p0", "p1", "p0_calibrated", "p1_calibrated"))
} else {
extractColumnsAsNested(Seq("p0", "p1"))
Expand Down
Expand Up @@ -27,7 +27,7 @@ trait H2OMOJOPredictionClustering {

def getClusteringPredictionUDF(): UserDefinedFunction = {
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictClustering(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictClustering(RowConverter.toH2ORowData(r))
Base(pred.cluster, pred.distances)
}
}
Expand Down
Expand Up @@ -28,7 +28,7 @@ trait H2OMOJOPredictionDimReduction {
logWarning("Starting from the next major release, the content of 'prediction' column will be generated to " +
" 'detailed_prediction' instead. The 'prediction' column will contain directly the predicted dimensions.")
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictDimReduction(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictDimReduction(RowConverter.toH2ORowData(r))
Base(pred.dimensions)
}
}
Expand Down
Expand Up @@ -29,7 +29,7 @@ trait H2OMOJOPredictionMultinomial {
logWarning("Starting from the next major release, the content of 'prediction' column will be generated to " +
" 'detailed_prediction' instead. The 'prediction' column will contain directly the predicted label.")
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictMultinomial(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictMultinomial(RowConverter.toH2ORowData(r))
Base(pred.classProbabilities)
}
}
Expand Down
Expand Up @@ -30,12 +30,12 @@ trait H2OMOJOPredictionRegression {
" 'detailed_prediction' instead. The 'prediction' column will contain directly the predicted value.")
if (getWithDetailedPredictionCol()) {
udf[WithContributions, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictRegression(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictRegression(RowConverter.toH2ORowData(r))
WithContributions(pred.value, pred.contributions)
}
} else {
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictRegression(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictRegression(RowConverter.toH2ORowData(r))
Base(pred.value)
}
}
Expand Down
Expand Up @@ -31,7 +31,7 @@ trait H2OMOJOPredictionWordEmbedding {
logWarning("Starting from the next major release, the content of 'prediction' column will be generated to " +
" 'detailed_prediction' instead. The 'prediction' column will contain directly the predicted word embeddings.")
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictWord2Vec(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictWord2Vec(RowConverter.toH2ORowData(r))
Base(pred.wordEmbeddings)
}
}
Expand Down