From 0319650f275c8f4539c1ab14d4ac0660352ae32e Mon Sep 17 00:00:00 2001
From: Mark Hamilton
Date: Wed, 1 Jul 2020 12:11:32 -0400
Subject: [PATCH] build: make python test loop easier

---
 build.sbt                                     | 31 ++++++++++---------
 .../ml/spark/codegen/CodegenConfig.scala      |  3 ++
 src/main/python/setup.py                      | 18 ++++++++++-
 src/test/python/mmlsparktest/spark.py         |  3 +-
 4 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/build.sbt b/build.sbt
index efcb3ad64e..b351cbaff0 100644
--- a/build.sbt
+++ b/build.sbt
@@ -60,8 +60,12 @@ cleanCondaEnvTask := {
     new File(".")) ! s.log
 }
 
+def isWindows: Boolean = {
+  sys.props("os.name").toLowerCase.contains("windows")
+}
+
 def osPrefix: Seq[String] = {
-  if (sys.props("os.name").toLowerCase.contains("windows")) {
+  if (isWindows) {
     Seq("cmd", "/C")
   } else {
     Seq()
@@ -98,6 +102,15 @@ generatePythonDoc := {
 
 }
 
+val pythonizedVersion = settingKey[String]("Pythonized version")
+pythonizedVersion := {
+  if (version.value.contains("-")) {
+    version.value.split("-").head + ".dev1"
+  } else {
+    version.value
+  }
+}
+
 def uploadToBlob(source: String, dest: String, container: String,
                  log: ManagedLogger, accountName: String="mmlspark"): Int = {
@@ -161,14 +174,6 @@ publishR := {
   singleUploadToBlob(rPackage.toString,rPackage.getName, "rrr", s.log)
 }
 
-def pythonizeVersion(v: String): String = {
-  if (v.contains("-")){
-    v.split("-".head).head + ".dev1"
-  }else{
-    v
-  }
-}
-
 packagePythonTask := {
   val s = streams.value
   (run in IntegrationTest2).toTask("").value
@@ -175,8 +180,7 @@ packagePythonTask := {
   Process(
     activateCondaEnv ++ Seq(s"python", "setup.py", "bdist_wheel", "--universal", "-d",
       s"${pythonPackageDir.absolutePath}"),
-    pythonSrcDir,
-    "MML_PY_VERSION" -> pythonizeVersion(version.value)) ! s.log
+    pythonSrcDir) ! s.log
 }
 
 val installPipPackageTask = TaskKey[Unit]("installPipPackage", "install python sdk")
@@ -192,7 +196,7 @@ installPipPackageTask := {
   packagePythonTask.value
   Process(
     activateCondaEnv ++ Seq("pip", "install",
-      s"mmlspark-${pythonizeVersion(version.value)}-py2.py3-none-any.whl"),
+      s"mmlspark-${pythonizedVersion.value}-py2.py3-none-any.whl"),
     pythonPackageDir) ! s.log
 }
 
@@ -211,7 +215,6 @@ testPythonTask := {
       "mmlsparktest"
     ),
     new File("target/scala-2.11/generated/test/python/"),
-    "MML_VERSION" -> version.value
   ) ! s.log
 }
 
@@ -319,7 +322,7 @@ val settings = Seq(
     logBuffered in Test := false,
     buildInfoKeys := Seq[BuildInfoKey](
       name, version, scalaVersion, sbtVersion,
-      baseDirectory, datasetDir),
+      baseDirectory, datasetDir, pythonizedVersion),
     parallelExecution in Test := false,
     test in assembly := {},
     assemblyMergeStrategy in assembly := {
diff --git a/src/it/scala/com/microsoft/ml/spark/codegen/CodegenConfig.scala b/src/it/scala/com/microsoft/ml/spark/codegen/CodegenConfig.scala
index 05e547b5f3..6f190cabb2 100644
--- a/src/it/scala/com/microsoft/ml/spark/codegen/CodegenConfig.scala
+++ b/src/it/scala/com/microsoft/ml/spark/codegen/CodegenConfig.scala
@@ -59,6 +59,9 @@ object Config {
         |CNTK library, images, and text.
         |"\""
         |
+        |__version__ = "${BuildInfo.pythonizedVersion}"
+        |__spark_package_version__ = "${BuildInfo.version}"
+        |
         |$importString
         |""".stripMargin
 }
diff --git a/src/main/python/setup.py b/src/main/python/setup.py
index 47d6e16422..acd8bcc75a 100644
--- a/src/main/python/setup.py
+++ b/src/main/python/setup.py
@@ -3,10 +3,26 @@
 
 import os
 from setuptools import setup, find_packages
+import codecs
+import os.path
+
+def read(rel_path):
+    here = os.path.abspath(os.path.dirname(__file__))
+    with codecs.open(os.path.join(here, rel_path), 'r') as fp:
+        return fp.read()
+
+def get_version(rel_path):
+    for line in read(rel_path).splitlines():
+        if line.startswith('__version__'):
+            delim = '"' if '"' in line else "'"
+            return line.split(delim)[1]
+    else:
+        raise RuntimeError("Unable to find version string.")
+
 
 setup(
     name="mmlspark",
-    version=os.environ["MML_PY_VERSION"],
+    version=get_version("mmlspark/__init__.py"),
     description="Microsoft ML for Spark",
     long_description="Microsoft ML for Apache Spark contains Microsoft's open source "
                      + "contributions to the Apache Spark ecosystem",
diff --git a/src/test/python/mmlsparktest/spark.py b/src/test/python/mmlsparktest/spark.py
index 25c356c915..dd097e0c3a 100644
--- a/src/test/python/mmlsparktest/spark.py
+++ b/src/test/python/mmlsparktest/spark.py
@@ -3,11 +3,12 @@
 
 from pyspark.sql import SparkSession, SQLContext
 import os
+import mmlspark
 
 spark = SparkSession.builder \
     .master("local[*]") \
     .appName("PysparkTests") \
-    .config("spark.jars.packages", "com.microsoft.ml.spark:mmlspark_2.11:" + os.environ["MML_VERSION"]) \
+    .config("spark.jars.packages", "com.microsoft.ml.spark:mmlspark_2.11:" + mmlspark.__spark_package_version__) \
     .config("spark.executor.heartbeatInterval", "60s") \
     .config("spark.sql.shuffle.partitions", 10) \
     .config("spark.sql.crossJoin.enabled", "true") \
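
Note on the new version flow: codegen now writes __version__ (a PEP 440 version derived
by pythonizedVersion) and __spark_package_version__ (the raw sbt version) into the
generated mmlspark/__init__.py. setup.py then parses __version__ back out of that file,
and the test fixture reads __spark_package_version__ for spark.jars.packages, so no
environment variable hand-off is needed. Below is a minimal, self-contained sketch of
that parsing contract; the header literal and version strings are illustrative (not real
build output), and this get_version takes the file text directly instead of a path so
the demo can run standalone.

# Illustrative generated header; the real one is emitted by CodegenConfig.
sample_init = '''\
__version__ = "1.0.0.dev1"              # pythonizedVersion: the "-rc1" suffix becomes ".dev1"
__spark_package_version__ = "1.0.0-rc1" # raw sbt version, used for spark.jars.packages
'''

def get_version(text):
    # Same rule as the setup.py helper: take the first quoted token on the
    # line that starts with __version__; the for-else raises only when the
    # whole file is scanned without finding a match.
    for line in text.splitlines():
        if line.startswith('__version__'):
            delim = '"' if '"' in line else "'"
            return line.split(delim)[1]
    else:
        raise RuntimeError("Unable to find version string.")

assert get_version(sample_init) == "1.0.0.dev1"

With both values baked into the package, the edit/test loop reduces to
`sbt installPipPackage` followed by running the Python tests directly;
MML_PY_VERSION and MML_VERSION no longer need to be exported.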