Skip to content

Commit

Permalink
Fix serialization fuzzing error
Browse files Browse the repository at this point in the history
  • Loading branch information
mhamilton723 committed Jul 5, 2019
1 parent f6df907 commit 07316a8
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 3 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Expand Up @@ -19,7 +19,7 @@ def getPythonVersion(baseVersion: String): String = {

val baseVersion = "0.17.1"
val condaEnvName = "mmlspark"
name := "mmlspark-build"
name := "mmlspark"
organization := "com.microsoft.ml.spark"
version := getVersion(baseVersion)
scalaVersion := "2.11.12"
Expand Down
Expand Up @@ -12,6 +12,7 @@ object Config {

val topDir = BuildInfo.baseDirectory
val version = BuildInfo.version
val packageName = BuildInfo.name
val targetDir = new File(topDir, "target/scala-2.11")
val scalaSrcDir = "src/main/scala"

Expand Down
Expand Up @@ -53,7 +53,7 @@ abstract class PySparkWrapperParamsTest(entryPoint: Params,
|spark = SparkSession.builder \\
| .master("local[*]") \\
| .appName("$entryPointName") \\
| .config("spark.jars.packages", "com.microsoft.ml.spark:mmlspark-build_2.11:$version") \\
| .config("spark.jars.packages", "com.microsoft.ml.spark:${packageName}_2.11:$version") \\
| .config("spark.executor.heartbeatInterval", "60s") \\
| .getOrCreate()
|
Expand Down
2 changes: 1 addition & 1 deletion src/test/python/mmlspark/recommendation/test_ranking.py
Expand Up @@ -18,7 +18,7 @@
spark = SparkSession.builder \
.master("local[*]") \
.appName("_FindBestModel") \
.config("spark.jars.packages", "com.microsoft.ml.spark:mmlspark-build_2.11:" + os.environ["MML_VERSION"]) \
.config("spark.jars.packages", "com.microsoft.ml.spark:mmlspark_2.11:" + os.environ["MML_VERSION"]) \
.config("spark.executor.heartbeatInterval", "60s") \
.getOrCreate()

Expand Down
Expand Up @@ -10,6 +10,8 @@ import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
import org.apache.commons.compress.utils.IOUtils
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.DataFrame
import org.scalactic.Equality
import org.scalatest.Assertion

trait SpeechKey {
lazy val speechKey = sys.env.getOrElse("SPEECH_API_KEY", Secrets.speechApiKey)
Expand All @@ -35,6 +37,11 @@ class SpeechToTextSuite extends TransformerFuzzing[SpeechToText]
Tuple1(audioBytes)
).toDF("audio")

// Custom DataFrame equality used by the transformer-fuzzing serialization tests:
// before delegating to the inherited baseDfEq, the raw "audio" bytes column is
// dropped from BOTH frames, so equality is judged only on the remaining
// (transcription output) columns. Presumably the binary audio payload does not
// round-trip identically through save/load, which broke the serialization
// fuzzing check this commit fixes — TODO confirm against TransformerFuzzing.
// NOTE(review): `b` arrives as Any per the Equality SAM contract; the
// asInstanceOf cast assumes the fuzzer only ever compares DataFrames here.
override lazy val dfEq = new Equality[DataFrame] {
override def areEqual(a: DataFrame, b: Any): Boolean =
baseDfEq.areEqual(a.drop("audio"), b.asInstanceOf[DataFrame].drop("audio"))
}

test("Basic Usage") {
val toObj = SpeechResponse.makeFromRowConverter
val result = toObj(stt.setFormat("simple")
Expand Down

0 comments on commit 07316a8

Please sign in to comment.