Skip to content

Commit

Permalink
fix: [Workaround] CNTKModel does not output correct result (#1076)
Browse files Browse the repository at this point in the history
* fix: [Workaround] CNTKModel does not output correct result for ResNet50 model when running on Databricks

* Add code comment to reference issue 1075

* Only cache for non-streaming df
  • Loading branch information
memoryz committed Jun 9, 2021
1 parent 36ee274 commit 0632f1b
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/main/scala/com/microsoft/ml/spark/cntk/CNTKModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,10 @@ class CNTKModel(override val uid: String) extends Model[CNTKModel] with ComplexP
val droppedDF = outputDF.drop(outputDF.columns.filter(_.startsWith(coercionPrefix)): _*)

val unbatchedDF = if (getBatchInput) {
new FlattenBatch().transform(droppedDF)
// TODO: The cache call is a workaround for issue 1075:
// https://github.com/Azure/mmlspark/issues/1075
val cacheAttempted = if (droppedDF.isStreaming) droppedDF else droppedDF.cache()
new FlattenBatch().transform(cacheAttempted)
} else {
droppedDF
}
Expand Down

0 comments on commit 0632f1b

Please sign in to comment.