Commit

[SW-1658][rel-3.26] Figure out better way of caching MOJOs (v1 & v2) (#…
mn-mikke committed Oct 10, 2019
1 parent d2db204 commit a8c181c
Showing 12 changed files with 146 additions and 53 deletions.
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.h2o.sparkling.ml.models

import org.apache.spark.expose.Logging
import org.apache.spark.sql.SparkSession

import scala.collection.mutable

trait H2OMOJOBaseCache[B, M] extends Logging {
private object Lock

private val pipelineCache = mutable.Map.empty[String, B]
private val lastAccessMap = mutable.Map.empty[String, Long]

private lazy val sparkConf = SparkSession.builder().getOrCreate().sparkContext.getConf
private lazy val cleanupRetryTimeout = sparkConf.getInt("spark.ext.h2o.mojo.destroy.timeout", 10 * 60 * 1000)
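  // Background thread that periodically wakes up and evicts cached MOJO backends that have
  // not been accessed within cleanupRetryTimeout.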
private val cleanerThread = new Thread() {
override def run(): Unit = {
while (!Thread.interrupted()) {
try {
Thread.sleep(cleanupRetryTimeout)
val toDestroy = lastAccessMap.flatMap { case (uid, lastAccess) =>
val currentDiff = System.currentTimeMillis() - lastAccess
if (currentDiff > cleanupRetryTimeout) {
logDebug(s"Removing mojo $uid from cache as it has not been used for $cleanupRetryTimeout ms.")
Some(uid)
} else {
None
}
}
Lock.synchronized {
toDestroy.map { uid =>
lastAccessMap.remove(uid)
pipelineCache.remove(uid)
}
}
} catch {
case _: InterruptedException => Thread.currentThread.interrupt()
}
}
}
  }

  def startCleanupThread(): Unit = {
    if (!cleanerThread.isAlive) {
      cleanerThread.start()
      logDebug("Cleaner thread for unused MOJOs started.")
    }
  }

  // Returns the cached backend for the given model uid, loading it via bytesGetter on a cache
  // miss and refreshing the last-access timestamp on every call.
  def getMojoBackend(uid: String, bytesGetter: () => Array[Byte], model: M): B = Lock.synchronized {
if (!pipelineCache.contains(uid)) {
//println(s"MISS MOJO pipeline: thread=${Thread.currentThread().getName}, this=${this}")
pipelineCache.put(uid, loadMojoBackend(bytesGetter(), model))
}
lastAccessMap.put(uid, System.currentTimeMillis())
pipelineCache(uid)
}

def loadMojoBackend(mojoData: Array[Byte], model: M): B
}
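
A minimal usage sketch (not part of this commit): the cleaner thread's eviction interval is read from the spark.ext.h2o.mojo.destroy.timeout property (milliseconds, default 10 * 60 * 1000), so it can be tuned when the SparkSession is created. The application and master below are hypothetical, added only for illustration.

import org.apache.spark.sql.SparkSession

object MojoCacheTimeoutExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("mojo-cache-timeout-example")  // hypothetical name, not from the commit
      .master("local[*]")                     // assumption: local run for illustration
      .config("spark.ext.h2o.mojo.destroy.timeout", 5 * 60 * 1000L) // evict MOJOs unused for 5 minutes
      .getOrCreate()

    // Any H2OMOJOModel scored after this point keeps its EasyPredictModelWrapper cached
    // until it has been idle for the configured timeout.
    spark.stop()
  }
}
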
@@ -32,7 +32,7 @@ import org.apache.spark.sql._
import scala.collection.JavaConverters._

class H2OMOJOModel(override val uid: String) extends H2OMOJOModelBase[H2OMOJOModel] with H2OMOJOPrediction {

H2OMOJOCache.startCleanupThread()
protected final val modelDetails: NullableStringParam = new NullableStringParam(this, "modelDetails", "Raw details of this model.")

setDefault(
@@ -43,29 +43,6 @@ class H2OMOJOModel(override val uid: String) extends H2OMOJOModelBase[H2OMOJOMod

override protected def outputColumnName: String = getDetailedPredictionCol()

// Some MojoModels are not serializable ( DeepLearning ), so we are reusing the mojoData to keep information about mojo model
@transient protected lazy val easyPredictModelWrapper: EasyPredictModelWrapper = {
val config = new EasyPredictModelWrapper.Config()
config.setModel(Utils.getMojoModel(getMojoData()))
config.setConvertUnknownCategoricalLevelsToNa(getConvertUnknownCategoricalLevelsToNa())
config.setConvertInvalidNumbersToNa(getConvertInvalidNumbersToNa())
if (canGenerateContributions(config.getModel)) {
config.setEnableContributions(getWithDetailedPredictionCol())
}
// always let H2O produce full output, filter later if required
config.setUseExtendedOutput(true)
new EasyPredictModelWrapper(config)
}

private def canGenerateContributions(model: GenModel): Boolean = {
model match {
case _: PredictContributionsFactory =>
val modelCategory = model.getModelCategory
modelCategory == ModelCategory.Regression || modelCategory == ModelCategory.Binomial
case _ => false
}
}

override def copy(extra: ParamMap): H2OMOJOModel = defaultCopy(extra)

override def transform(dataset: Dataset[_]): DataFrame = {
@@ -146,3 +123,28 @@ object H2OMOJOModel extends H2OMOJOReadable[H2OMOJOModel] with H2OMOJOLoader[H2O
model.set(model.featuresCols -> originalFeatures)
}
}

object H2OMOJOCache extends H2OMOJOBaseCache[EasyPredictModelWrapper, H2OMOJOModel] {

private def canGenerateContributions(model: GenModel): Boolean = {
model match {
case _: PredictContributionsFactory =>
val modelCategory = model.getModelCategory
modelCategory == ModelCategory.Regression || modelCategory == ModelCategory.Binomial
case _ => false
}
}

override def loadMojoBackend(mojoData: Array[Byte], model: H2OMOJOModel): EasyPredictModelWrapper = {
val config = new EasyPredictModelWrapper.Config()
config.setModel(Utils.getMojoModel(mojoData))
config.setConvertUnknownCategoricalLevelsToNa(model.getConvertUnknownCategoricalLevelsToNa())
config.setConvertInvalidNumbersToNa(model.getConvertInvalidNumbersToNa())
if (canGenerateContributions(config.getModel)) {
config.setEnableContributions(model.getWithDetailedPredictionCol())
}
// always let H2O produce full output, filter later if required
config.setUseExtendedOutput(true)
new EasyPredictModelWrapper(config)
}
}
@@ -33,14 +33,15 @@ import scala.util.Random

class H2OMOJOPipelineModel(override val uid: String) extends H2OMOJOModelBase[H2OMOJOPipelineModel] {

@transient private lazy val mojoPipeline: MojoPipeline = {
val reader = MojoPipelineReaderBackendFactory.createReaderBackend(new ByteArrayInputStream(getMojoData()))
MojoPipeline.loadFrom(reader)
}
H2OMOJOPipelineCache.startCleanupThread()

// private parameter used to store MOJO output columns
protected final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "OutputCols")

@transient private lazy val mojoPipeline: MojoPipeline = {
H2OMOJOPipelineCache.getMojoBackend(uid, getMojoData, this)
}

case class Mojo2Prediction(preds: List[Double])

private def prepareBooleans(colType: Type, colData: Any): Any = {
@@ -60,7 +61,6 @@ class H2OMOJOPipelineModel(override val uid: String) extends H2OMOJOModelBase[H2
} else {
colData
}

}

private val modelUdf = (names: Array[String]) =>
@@ -204,3 +204,10 @@ object H2OMOJOPipelineModel extends H2OMOJOReadable[H2OMOJOPipelineModel] with H
model
}
}

private object H2OMOJOPipelineCache extends H2OMOJOBaseCache[MojoPipeline, H2OMOJOPipelineModel] {
override def loadMojoBackend(mojoData: Array[Byte], model: H2OMOJOPipelineModel): MojoPipeline = {
val reader = MojoPipelineReaderBackendFactory.createReaderBackend(new ByteArrayInputStream(mojoData))
MojoPipeline.loadFrom(reader)
}
}
@@ -34,7 +34,8 @@ trait H2OMOJOPrediction
self: H2OMOJOModel =>

def extractPredictionColContent(): Column = {
easyPredictModelWrapper.getModelCategory match {
val predictWrapper = H2OMOJOCache.getMojoBackend(uid, getMojoData, this)
predictWrapper.getModelCategory match {
case ModelCategory.Binomial => extractBinomialPredictionColContent()
case ModelCategory.Regression => extractRegressionPredictionColContent()
case ModelCategory.Multinomial => extractMultinomialPredictionColContent()
@@ -43,12 +44,13 @@ trait H2OMOJOPrediction
case ModelCategory.DimReduction => extractDimReductionSimplePredictionColContent()
case ModelCategory.WordEmbedding => extractWordEmbeddingPredictionColContent()
case ModelCategory.AnomalyDetection => extractAnomalyPredictionColContent()
case _ => throw new RuntimeException("Unknown model category " + easyPredictModelWrapper.getModelCategory)
case _ => throw new RuntimeException("Unknown model category " + predictWrapper.getModelCategory)
}
}

def getPredictionUDF(): UserDefinedFunction = {
easyPredictModelWrapper.getModelCategory match {
val predictWrapper = H2OMOJOCache.getMojoBackend(uid, getMojoData, this)
predictWrapper.getModelCategory match {
case ModelCategory.Binomial => getBinomialPredictionUDF()
case ModelCategory.Regression => getRegressionPredictionUDF()
case ModelCategory.Multinomial => getMultinomialPredictionUDF()
@@ -57,12 +59,13 @@ trait H2OMOJOPrediction
case ModelCategory.DimReduction => getDimReductionPredictionUDF()
case ModelCategory.WordEmbedding => getWordEmbeddingPredictionUDF()
case ModelCategory.AnomalyDetection => getAnomalyPredictionUDF()
case _ => throw new RuntimeException("Unknown model category " + easyPredictModelWrapper.getModelCategory)
case _ => throw new RuntimeException("Unknown model category " + predictWrapper.getModelCategory)
}
}

override def getPredictionColSchema(): Seq[StructField] = {
easyPredictModelWrapper.getModelCategory match {
val predictWrapper = H2OMOJOCache.getMojoBackend(uid, getMojoData, this)
predictWrapper.getModelCategory match {
case ModelCategory.Binomial => getBinomialPredictionColSchema()
case ModelCategory.Regression => getRegressionPredictionColSchema()
case ModelCategory.Multinomial => getMultinomialPredictionColSchema()
@@ -71,12 +74,13 @@ trait H2OMOJOPrediction
case ModelCategory.DimReduction => getDimReductionPredictionColSchema()
case ModelCategory.WordEmbedding => getWordEmbeddingPredictionColSchema()
case ModelCategory.AnomalyDetection => getAnomalyPredictionColSchema()
case _ => throw new RuntimeException("Unknown model category " + easyPredictModelWrapper.getModelCategory)
case _ => throw new RuntimeException("Unknown model category " + predictWrapper.getModelCategory)
}
}

override def getDetailedPredictionColSchema(): Seq[StructField] = {
easyPredictModelWrapper.getModelCategory match {
val predictWrapper = H2OMOJOCache.getMojoBackend(uid, getMojoData, this)
predictWrapper.getModelCategory match {
case ModelCategory.Binomial => getBinomialDetailedPredictionColSchema()
case ModelCategory.Regression => getRegressionDetailedPredictionColSchema()
case ModelCategory.Multinomial => getMultinomialDetailedPredictionColSchema()
@@ -85,7 +89,7 @@ trait H2OMOJOPrediction
case ModelCategory.DimReduction => getDimReductionDetailedPredictionColSchema()
case ModelCategory.WordEmbedding => getWordEmbeddingDetailedPredictionColSchema()
case ModelCategory.AnomalyDetection => getAnomalyDetailedPredictionColSchema()
case _ => throw new RuntimeException("Unknown model category " + easyPredictModelWrapper.getModelCategory)
case _ => throw new RuntimeException("Unknown model category " + predictWrapper.getModelCategory)
}
}
}
@@ -28,7 +28,7 @@ trait H2OMOJOPredictionAnomaly extends H2OMOJOPredictionUtils {
logWarning("Starting from the next major release, the content of 'prediction' column will be generated to " +
" 'detailed_prediction' instead. The 'prediction' column will contain directly the predicted score.")
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictAnomalyDetection(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictAnomalyDetection(RowConverter.toH2ORowData(r))
Base(pred.score, pred.normalizedScore)
}
}
@@ -30,7 +30,7 @@ trait H2OMOJOPredictionAutoEncoder {
" 'detailed_prediction' instead. The 'prediction' column will contain directly the original input the" +
" way AutoEncoder model sees it (1-hot encoded categorical values) .")
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictAutoEncoder(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictAutoEncoder(RowConverter.toH2ORowData(r))
Base(pred.original, pred.reconstructed)
}
}
@@ -36,10 +36,10 @@ trait H2OMOJOPredictionBinomial {
def getBinomialPredictionUDF(): UserDefinedFunction = {
logWarning("Starting from the next major release, the content of 'prediction' column will be generated to " +
" 'detailed_prediction' instead. The 'prediction' column will contain directly the predicted label.")
if (supportsCalibratedProbabilities(easyPredictModelWrapper)) {
if (supportsCalibratedProbabilities(H2OMOJOCache.getMojoBackend(uid, getMojoData, this))) {
if (getWithDetailedPredictionCol()) {
udf[WithCalibrationAndContribution, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictBinomial(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictBinomial(RowConverter.toH2ORowData(r))
WithCalibrationAndContribution(
pred.classProbabilities(0),
pred.classProbabilities(1),
@@ -50,7 +50,7 @@ }
}
} else {
udf[WithCalibration, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictBinomial(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictBinomial(RowConverter.toH2ORowData(r))
WithCalibration(
pred.classProbabilities(0),
pred.classProbabilities(1),
@@ -61,7 +61,7 @@ }
}
} else if (getWithDetailedPredictionCol()) {
udf[WithContribution, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictBinomial(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictBinomial(RowConverter.toH2ORowData(r))
WithContribution(
pred.classProbabilities(0),
pred.classProbabilities(1),
@@ -70,7 +70,7 @@ }
}
} else {
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictBinomial(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictBinomial(RowConverter.toH2ORowData(r))
Base(
pred.classProbabilities(0),
pred.classProbabilities(1)
@@ -82,7 +82,7 @@ }
private val baseFields = Seq("p0", "p1").map(StructField(_, DoubleType, nullable = false))

def getBinomialPredictionColSchema(): Seq[StructField] = {
val fields = if (supportsCalibratedProbabilities(easyPredictModelWrapper)) {
val fields = if (supportsCalibratedProbabilities(H2OMOJOCache.getMojoBackend(uid, getMojoData, this))) {
baseFields ++ Seq("p0_calibrated", "p1_calibrated").map(StructField(_, DoubleType, nullable = false))
} else {
baseFields
@@ -92,7 +92,7 @@ }
}

def getBinomialDetailedPredictionColSchema(): Seq[StructField] = {
val fields = if (supportsCalibratedProbabilities(easyPredictModelWrapper)) {
val fields = if (supportsCalibratedProbabilities(H2OMOJOCache.getMojoBackend(uid, getMojoData, this))) {
val base = baseFields ++ Seq("p0_calibrated", "p1_calibrated").map(StructField(_, DoubleType, nullable = false))
if (getWithDetailedPredictionCol()) {
base ++ Seq(StructField("contributions", ArrayType(FloatType)))
@@ -109,7 +109,7 @@ }
}

def extractBinomialPredictionColContent(): Column = {
if (supportsCalibratedProbabilities(easyPredictModelWrapper)) {
if (supportsCalibratedProbabilities(H2OMOJOCache.getMojoBackend(uid, getMojoData, this))) {
extractColumnsAsNested(Seq("p0", "p1", "p0_calibrated", "p1_calibrated"))
} else {
extractColumnsAsNested(Seq("p0", "p1"))
@@ -27,7 +27,7 @@ trait H2OMOJOPredictionClustering {

def getClusteringPredictionUDF(): UserDefinedFunction = {
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictClustering(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictClustering(RowConverter.toH2ORowData(r))
Base(pred.cluster, pred.distances)
}
}
@@ -28,7 +28,7 @@ trait H2OMOJOPredictionDimReduction {
logWarning("Starting from the next major release, the content of 'prediction' column will be generated to " +
" 'detailed_prediction' instead. The 'prediction' column will contain directly the predicted dimensions.")
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictDimReduction(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictDimReduction(RowConverter.toH2ORowData(r))
Base(pred.dimensions)
}
}
@@ -29,7 +29,7 @@ trait H2OMOJOPredictionMultinomial {
logWarning("Starting from the next major release, the content of 'prediction' column will be generated to " +
" 'detailed_prediction' instead. The 'prediction' column will contain directly the predicted label.")
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictMultinomial(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictMultinomial(RowConverter.toH2ORowData(r))
Base(pred.classProbabilities)
}
}
@@ -30,12 +30,12 @@ trait H2OMOJOPredictionRegression {
" 'detailed_prediction' instead. The 'prediction' column will contain directly the predicted value.")
if (getWithDetailedPredictionCol()) {
udf[WithContributions, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictRegression(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictRegression(RowConverter.toH2ORowData(r))
WithContributions(pred.value, pred.contributions)
}
} else {
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictRegression(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictRegression(RowConverter.toH2ORowData(r))
Base(pred.value)
}
}
@@ -31,7 +31,7 @@ trait H2OMOJOPredictionWordEmbedding {
logWarning("Starting from the next major release, the content of 'prediction' column will be generated to " +
" 'detailed_prediction' instead. The 'prediction' column will contain directly the predicted word embeddings.")
udf[Base, Row] { r: Row =>
val pred = easyPredictModelWrapper.predictWord2Vec(RowConverter.toH2ORowData(r))
val pred = H2OMOJOCache.getMojoBackend(uid, getMojoData, this).predictWord2Vec(RowConverter.toH2ORowData(r))
Base(pred.wordEmbeddings)
}
}