Skip to content

Commit

Permalink
bug fix: DLModel prediction (#2194)
Browse files Browse the repository at this point in the history
* bug fix: DLModel prediction (#4)

Make sure DLModel.train=False when predicting in pipeline API

* 1. broadcast transformer in DLModel.transform ; 2. remove useless ut
  • Loading branch information
sperlingxx authored and qiuxin2012 committed Feb 7, 2018
1 parent 0b3e8e5 commit 43f023f
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions spark/dl/src/main/scala/org/apache/spark/ml/DLEstimator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -387,11 +387,13 @@ class DLModel[@specialized(Float, Double) T: ClassTag](
val featureColIndex = dataFrame.schema.fieldIndex($(featuresCol))
val featureFunc = getConvertFunc(featureType)
val sc = dataFrame.sqlContext.sparkContext
val modelBroadCast = ModelBroadcast[T]().broadcast(sc, model)
val modelBroadCast = ModelBroadcast[T]().broadcast(sc, model.evaluate())
val localBatchSize = $(batchSize)
val transformerBC = sc.broadcast(SampleToMiniBatch[T](localBatchSize))

val resultRDD = dataFrame.rdd.mapPartitions { rowIter =>
val localModel = modelBroadCast.value()
val transformer = transformerBC.value.cloneTransformer()
rowIter.grouped(localBatchSize).flatMap { rowBatch =>
val samples = rowBatch.map { row =>
val features = featureFunc(row, featureColIndex)
Expand All @@ -401,7 +403,7 @@ class DLModel[@specialized(Float, Double) T: ClassTag](
}
Sample(Tensor(featureBuffer.toArray, featureSize))
}.toIterator
val predictions = SampleToMiniBatch(localBatchSize).apply(samples).flatMap { batch =>
val predictions = transformer(samples).flatMap { batch =>
val batchResult = localModel.forward(batch.getInput())
batchResult.toTensor.split(1).map(outputToPrediction)
}
Expand Down

0 comments on commit 43f023f

Please sign in to comment.