chore: Adding Spark34 support (#2052) (#2116)
* chore: bump to spark 3.4.1

---------

Co-authored-by: Jessica Wang <jessiwang@microsoft.com>
Co-authored-by: Scott Votaw <svotaw@gmail.com>
Co-authored-by: Brendan Walsh <37676373+BrendanWalsh@users.noreply.github.com>
Co-authored-by: JessicaXYWang <108437381+JessicaXYWang@users.noreply.github.com>

fixes

Co-authored-by: Keerthi Yanda <98137159+KeerthiYandaOS@users.noreply.github.com>
mhamilton723 and KeerthiYandaOS committed Nov 1, 2023
1 parent 903dc6b commit c2fdb05
Showing 46 changed files with 288 additions and 262 deletions.
build.sbt: 11 additions & 11 deletions
@@ -3,15 +3,14 @@ import org.apache.commons.io.FileUtils
import sbt.ExclusionRule

import java.io.File
import java.net.URL
import scala.xml.transform.{RewriteRule, RuleTransformer}
import scala.xml.{Node => XmlNode, NodeSeq => XmlNodeSeq, _}

val condaEnvName = "synapseml"
val sparkVersion = "3.2.3"
val sparkVersion = "3.4.1"
name := "synapseml"
ThisBuild / organization := "com.microsoft.azure"
ThisBuild / scalaVersion := "2.12.15"
ThisBuild / scalaVersion := "2.12.17"

val scalaMajorVersion = 2.12

@@ -21,23 +20,24 @@ val excludes = Seq(
)

val coreDependencies = Seq(
"org.apache.spark" %% "spark-core" % sparkVersion % "compile",
// Excluding protobuf-java, as spark-core is bringing the older version transitively.
"org.apache.spark" %% "spark-core" % sparkVersion % "compile" exclude("com.google.protobuf", "protobuf-java"),
"org.apache.spark" %% "spark-mllib" % sparkVersion % "compile",
"org.apache.spark" %% "spark-avro" % sparkVersion % "provided",
"org.apache.spark" %% "spark-avro" % sparkVersion % "compile",
"org.apache.spark" %% "spark-tags" % sparkVersion % "test",
"com.globalmentor" % "hadoop-bare-naked-local-fs" % "0.1.0" % "test",
"org.scalatest" %% "scalatest" % "3.2.14" % "test")
val extraDependencies = Seq(
"commons-lang" % "commons-lang" % "2.6",
"org.scalactic" %% "scalactic" % "3.2.14",
"io.spray" %% "spray-json" % "1.3.5",
"com.jcraft" % "jsch" % "0.1.54",
"org.apache.httpcomponents.client5" % "httpclient5" % "5.1.3",
"org.apache.httpcomponents" % "httpmime" % "4.5.13",
"com.linkedin.isolation-forest" %% "isolation-forest_3.2.0" % "2.0.8",
// Although breeze 1.2 is already provided by Spark, this is needed for Azure Synapse Spark 3.2 pools.
// Otherwise a NoSuchMethodError will be thrown by interpretability code. This problem only happens
// to Azure Synapse Spark 3.2 pools.
"org.scalanlp" %% "breeze" % "1.2"
"com.linkedin.isolation-forest" %% "isolation-forest_3.4.1" % "3.0.3",
// Although breeze 2.1.0 is already provided by Spark, this is needed for Azure Synapse Spark 3.4 pools.
// Otherwise a NoSuchMethodError will be thrown by interpretability code.
"org.scalanlp" %% "breeze" % "2.1.0"
).map(d => d excludeAll (excludes: _*))
val dependencies = coreDependencies ++ extraDependencies

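A note on the spark-core change above: Spark 3.4's spark-core drags in an older protobuf-java transitively, so the dependency now strips it at declaration time. A minimal sbt sketch of the same pattern, with an explicitly pinned replacement shown purely for illustration (the pinned version is an assumption, not something this commit adds):

// Hypothetical build.sbt fragment; versions are examples only.
val sparkVersionExample = "3.4.1"
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % sparkVersionExample % "compile" exclude("com.google.protobuf", "protobuf-java"),
  // If protobuf is still needed at runtime, pin the desired version explicitly (illustrative).
  "com.google.protobuf" % "protobuf-java" % "3.23.2"
)
// `sbt evicted` (or `dependencyTree`, bundled from sbt 1.4 on) can confirm which
// protobuf-java actually lands on the classpath.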
@@ -70,7 +70,7 @@ pomPostProcess := pomPostFunc

val getDatasetsTask = TaskKey[Unit]("getDatasets", "download datasets used for testing")
val datasetName = "datasets-2023-04-03.tgz"
val datasetUrl = new URI(s"https://mmlspark.blob.core.windows.net/installers/$datasetName").toURL()
val datasetUrl = new URI(s"https://mmlspark.blob.core.windows.net/installers/$datasetName").toURL
val datasetDir = settingKey[File]("The directory that holds the dataset")
ThisBuild / datasetDir := {
join((Compile / packageBin / artifactPath).value.getParentFile,
@@ -70,11 +70,11 @@ object PyCodegen {
// There's `Already borrowed` error found in transformers 4.16.2 when using tokenizers
s"""extras_require={"extras": [
| "cmake",
| "horovod==0.25.0",
| "horovod==0.28.1",
| "pytorch_lightning>=1.5.0,<1.5.10",
| "torch==1.11.0",
| "torchvision>=0.12.0",
| "transformers==4.15.0",
| "torch==1.13.1",
| "torchvision>=0.14.1",
| "transformers==4.32.1",
| "petastorm>=0.12.0",
| "huggingface-hub>=0.8.1",
|]},
@@ -18,7 +18,7 @@ object PackageUtils {

val PackageName = s"synapseml_$ScalaVersionSuffix"
val PackageMavenCoordinate = s"$PackageGroup:$PackageName:${BuildInfo.version}"
private val AvroCoordinate = "org.apache.spark:spark-avro_2.12:3.3.1"
private val AvroCoordinate = "org.apache.spark:spark-avro_2.12:3.4.1"
val PackageRepository: String = SparkMLRepository

// If testing onnx package with snapshots repo, make sure to switch to using
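The Avro coordinate above simply tracks the Spark bump. For context, a coordinate like this is what normally gets handed to Spark as a runtime package; an illustrative session setup, not necessarily how PackageUtils itself wires it in:

import org.apache.spark.sql.SparkSession

// Illustrative only: supply the spark-avro coordinate via spark.jars.packages (or --packages).
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("avro-coordinate-example")
  .config("spark.jars.packages", "org.apache.spark:spark-avro_2.12:3.4.1")
  .getOrCreate()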
@@ -3,7 +3,7 @@

package com.microsoft.azure.synapse.ml.exploratory

import breeze.stats.distributions.ChiSquared
import breeze.stats.distributions.{ChiSquared, RandBasis}
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
@@ -261,6 +261,7 @@ private[exploratory] case class DistributionMetrics(numFeatures: Int,

// Calculates left-tailed p-value from degrees of freedom and chi-squared test statistic
def chiSquaredPValue: Column = {
implicit val rand: RandBasis = RandBasis.mt0
val degOfFreedom = numFeatures - 1
val scoreCol = chiSquaredTestStatistic
val chiSqPValueUdf = udf(
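Context for the new RandBasis import and implicit above: with the Breeze 2.x line used here, distribution constructors such as ChiSquared expect a RandBasis in implicit scope. A standalone sketch of the left-tailed p-value computation under that assumption (Breeze 2.1 API):

import breeze.stats.distributions.{ChiSquared, RandBasis}

object ChiSquaredPValueSketch {
  // Left-tailed p-value for a chi-squared test statistic; mt0 is a fixed-seed Mersenne Twister.
  def leftTailedPValue(testStatistic: Double, degreesOfFreedom: Double): Double = {
    implicit val rand: RandBasis = RandBasis.mt0
    ChiSquared(degreesOfFreedom).cdf(testStatistic)
  }

  def main(args: Array[String]): Unit =
    println(leftTailedPValue(testStatistic = 7.81, degreesOfFreedom = 3.0)) // roughly 0.95
}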
@@ -194,7 +194,7 @@ private[ml] class HadoopFileReader(file: PartitionedFile,

private val iterator = {
val fileSplit = new FileSplit(
new Path(new URI(file.filePath)),
new Path(new URI(file.filePath.toString())),
file.start,
file.length,
Array.empty)
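Background for the added toString() here (and the similar calls in PatchedImageFileFormat below): in Spark 3.4, PartitionedFile.filePath is a SparkPath wrapper rather than a plain String. A minimal sketch of the adaptation, assuming only the string form is needed:

import java.net.URI
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.execution.datasources.PartitionedFile

object PartitionedFilePathSketch {
  // file.filePath is no longer a String in Spark 3.4, so convert explicitly before building a URI.
  def fileSplitPath(file: PartitionedFile): Path =
    new Path(new URI(file.filePath.toString))
}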
@@ -3,7 +3,6 @@

package com.microsoft.azure.synapse.ml.nn

import breeze.linalg.functions.euclideanDistance
import breeze.linalg.{DenseVector, norm, _}
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using

@@ -199,17 +199,20 @@ object SparkHelpers {

def flatten(ratings: Dataset[_], num: Int, dstOutputColumn: String, srcOutputColumn: String): DataFrame = {
import ratings.sparkSession.implicits._

val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
.toDF("id", "recommendations")
import org.apache.spark.sql.functions.{collect_top_k, struct}

val arrayType = ArrayType(
new StructType()
.add(dstOutputColumn, IntegerType)
.add("rating", FloatType)
.add(Constants.RatingCol, FloatType)
)
recs.select(col("id").as(srcOutputColumn), col("recommendations").cast(arrayType))

ratings.toDF(srcOutputColumn, dstOutputColumn, Constants.RatingCol).groupBy(srcOutputColumn)
.agg(collect_top_k(struct(Constants.RatingCol, dstOutputColumn), num, false))
.as[(Int, Seq[(Float, Int)])]
.map(t => (t._1, t._2.map(p => (p._2, p._1))))
.toDF(srcOutputColumn, Constants.Recommendations)
.withColumn(Constants.Recommendations, col(Constants.Recommendations).cast(arrayType))
}
}

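The flatten rewrite above swaps the TopByKeyAggregator for collect_top_k, a helper that ships with Spark 3.4 for exactly this top-k-per-key aggregation. For reference, a rough public-API equivalent that yields the same per-user top-k shape; the column names are placeholders, not the constants SynapseML uses:

import org.apache.spark.sql.{DataFrame, functions => F}

object TopKSketch {
  // ratings has columns (user, item, rating); names are assumed for this sketch.
  def topKPerUser(ratings: DataFrame, k: Int): DataFrame =
    ratings
      .groupBy("user")
      .agg(
        F.slice(
          F.sort_array(F.collect_list(F.struct("rating", "item")), asc = false),
          1, k
        ).as("recommendations"))
  // As in the helper above, the struct fields would still need swapping to (item, rating)
  // if callers expect that order.
}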
@@ -98,7 +98,7 @@ class PatchedImageFileFormat extends ImageFileFormat with Serializable with Logg
Iterator(emptyUnsafeRow)
} else {
val origin = file.filePath
val path = new Path(origin)
val path = new Path(origin.toString())
val fs = path.getFileSystem(broadcastedHadoopConf.value.value)
val stream = fs.open(path)
val bytes = try {
@@ -107,11 +107,12 @@ class PatchedImageFileFormat extends ImageFileFormat with Serializable with Logg
IOUtils.close(stream)
}

val resultOpt = catchFlakiness(5)(ImageSchema.decode(origin, bytes)) //scalastyle:ignore magic.number
val resultOpt = catchFlakiness(5)( //scalastyle:ignore magic.number
ImageSchema.decode(origin.toString(), bytes))
val filteredResult = if (imageSourceOptions.dropInvalid) {
resultOpt.toIterator
} else {
Iterator(resultOpt.getOrElse(ImageSchema.invalidImageRow(origin)))
Iterator(resultOpt.getOrElse(ImageSchema.invalidImageRow(origin.toString())))
}

if (requiredSchema.isEmpty) {
@@ -15,6 +15,7 @@ import org.apache.spark.sql.internal.connector.{SimpleTableProvider, SupportsStr
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.sparkproject.dmg.pmml.False

import java.util
import scala.collection.JavaConverters._
@@ -107,8 +108,12 @@ private[streaming] class HTTPDataWriter(val partitionId: Int,
val replyColIndex: Int,
val name: String)
extends DataWriter[InternalRow] with Logging {
logInfo(s"Creating writer on PID:$partitionId")
HTTPSourceStateHolder.getServer(name).commit(epochId - 1, partitionId)
logDebug(s"Creating writer on parition:$partitionId epoch $epochId")

val server = HTTPSourceStateHolder.getServer(name)
if (server.isContinuous) {
server.commit(epochId - 1, partitionId)
}

private val ids: mutable.ListBuffer[(String, Int)] = new mutable.ListBuffer[(String, Int)]()

@@ -68,7 +68,6 @@ class HTTPSourceTable(options: CaseInsensitiveStringMap)
override def readSchema(): StructType = HTTPSourceV2.Schema

override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
logInfo("Creating Microbatch reader")
new HTTPMicroBatchReader(continuous = false, options = options)
}

@@ -136,8 +135,8 @@ private[streaming] object DriverServiceUtils {
host: String,
handler: HttpHandler): HttpServer = {
val port: Int = StreamUtilities.using(new ServerSocket(0))(_.getLocalPort).get
val server = HttpServer.create(new InetSocketAddress(host, port), 100) //scalastyle:ignore magic.number
server.setExecutor(Executors.newFixedThreadPool(100)) //scalastyle:ignore magic.number
val server = HttpServer.create(new InetSocketAddress(host, port), 100) //scalastyle:ignore magic.number
server.setExecutor(Executors.newFixedThreadPool(100)) //scalastyle:ignore magic.number
server.createContext(s"/$path", handler)
server.start()
server
@@ -208,10 +207,10 @@ private[streaming] class HTTPMicroBatchReader(continuous: Boolean, options: Case

val numPartitions: Int = options.getInt(HTTPSourceV2.NumPartitions, 2)
val host: String = options.get(HTTPSourceV2.Host, "localhost")
val port: Int = options.getInt(HTTPSourceV2.Port, 8888) //scalastyle:ignore magic.number
val port: Int = options.getInt(HTTPSourceV2.Port, 8888) //scalastyle:ignore magic.number
val path: String = options.get(HTTPSourceV2.Path)
val name: String = options.get(HTTPSourceV2.NAME)
val epochLength: Long = options.getLong(HTTPSourceV2.EpochLength, 30000) //scalastyle:ignore magic.number
val epochLength: Long = options.getLong(HTTPSourceV2.EpochLength, 30000) //scalastyle:ignore magic.number

val forwardingOptions: collection.Map[String, String] = options.asCaseSensitiveMap().asScala
.filter { case (k, _) => k.startsWith("forwarding") }
@@ -270,8 +269,9 @@ private[streaming] class HTTPMicroBatchReader(continuous: Boolean, options: Case

val config = WorkerServiceConfig(host, port, path, forwardingOptions,
DriverServiceUtils.getDriverHost, driverService.getAddress.getPort, epochLength)

Range(0, numPartitions).map { i =>
HTTPInputPartition(continuous, name, config, startMap(i), endMap.map(_ (i)), i)
HTTPInputPartition(continuous, name, config, startMap(i), endMap.map(_(i)), i)
: InputPartition
}.toArray
}
@@ -318,7 +318,7 @@ private[streaming] class HTTPContinuousReader(options: CaseInsensitiveStringMap)
}

override def planInputPartitions(start: Offset): Array[InputPartition] =
planInputPartitions(start, null) //scalastyle:ignore null
planInputPartitions(start, null) //scalastyle:ignore null

override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory = {
HTTPSourceReaderFactory
@@ -332,7 +332,13 @@ private[streaming] case class HTTPInputPartition(continuous: Boolean,
endValue: Option[Long],
partitionIndex: Int
)
extends InputPartition
extends InputPartition {
if (!HTTPSourceStateHolder.hasServer(name)) {
val client = HTTPSourceStateHolder.getOrCreateClient(name)
HTTPSourceStateHolder.getOrCreateServer(name, startValue - 1, partitionIndex, continuous, client, config)
}

}

object HTTPSourceStateHolder {

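The new constructor body above makes each partition materialize its WorkerServer up front, guarded by the hasServer check. Conceptually this is a check-then-create registry, which has to stay race-safe when several partitions land on the same executor; a small, hypothetical sketch of that idea (not SynapseML's actual holder):

import scala.collection.mutable

// Hypothetical, simplified registry in the spirit of HTTPSourceStateHolder.
class ServerRegistry[T] {
  private val servers = mutable.HashMap[String, T]()

  def has(name: String): Boolean = synchronized(servers.contains(name))

  // getOrElseUpdate under one lock keeps concurrent callers from creating two servers for a name.
  def getOrCreate(name: String)(create: => T): T =
    synchronized(servers.getOrElseUpdate(name, create))
}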
@@ -381,6 +387,10 @@ object HTTPSourceStateHolder {
HTTPSourceStateHolder.Servers(name)
}

private[streaming] def hasServer(name: String): Boolean = {
HTTPSourceStateHolder.Servers.contains(name)
}

private[streaming] def getOrCreateServer(name: String,
epoch: Long,
partitionId: Int,
@@ -487,10 +497,10 @@ private[streaming] class WorkerServer(val name: String,

def registerPartition(localEpoch: Epoch, partitionId: PID): Unit = synchronized {
if (!registeredPartitions.contains(partitionId)) {
logInfo(s"registering $partitionId localEpoch:$localEpoch globalEpoch:$epoch")
logDebug(s"registering $partitionId localEpoch:$localEpoch globalEpoch:$epoch")
registeredPartitions.update(partitionId, localEpoch)
} else {
logInfo(s"re-registering $partitionId localEpoch:$localEpoch globalEpoch:$epoch")
logDebug(s"re-registering $partitionId localEpoch:$localEpoch globalEpoch:$epoch")
val previousEpoch = registeredPartitions(partitionId)
registeredPartitions.update(partitionId, localEpoch)
//there has been a failed partition and we need to rehydrate the queue
@@ -514,14 +524,16 @@ private[streaming] class WorkerServer(val name: String,
@GuardedBy("this")
private val historyQueues = new mutable.HashMap[(Epoch, PID), mutable.ListBuffer[CachedRequest]]

@GuardedBy("this")
private[streaming] val recoveredPartitions = new mutable.HashMap[(Epoch, PID), LinkedBlockingQueue[CachedRequest]]

private class PublicHandler extends HttpHandler {
override def handle(request: HttpExchange): Unit = {
logDebug(s"handling epoch: $epoch")
logDebug(s"handling request epoch: $epoch")
val uuid = UUID.randomUUID().toString
val cReq = new CachedRequest(request, uuid)
requestQueues(epoch).put(cReq)
logDebug(s"handled request epoch: $epoch")
}
}

@@ -540,6 +552,7 @@ private[streaming] class WorkerServer(val name: String,
None
}
.foreach { request =>
logDebug(s"Replying to request")
HTTPServerUtils.respond(request.e, data)
request.e.close()
routingTable.remove(id)
@@ -582,7 +595,7 @@ private[streaming] class WorkerServer(val name: String,
}
try {
val server = HttpServer.create(new InetSocketAddress(InetAddress.getByName(host), startingPort),
100) //scalastyle:ignore magic.number
100) //scalastyle:ignore magic.number
(server, startingPort)
} catch {
case _: java.net.BindException =>
@@ -624,22 +637,24 @@ private[streaming] class WorkerServer(val name: String,
}

timeout.map {
case Left(0L) => Option(queue.poll())
case Right(t) =>
Option(queue.poll(t, TimeUnit.MILLISECONDS)).orElse {
synchronized {
//If the queue times out then we move to the next epoch
epoch += 1
val lbq = new LinkedBlockingQueue[CachedRequest]()
requestQueues.update(epoch, lbq)
epochStart = System.currentTimeMillis()
case Left(0L) => Option(queue.poll())
case Right(t) =>
val polled = queue.poll(t, TimeUnit.MILLISECONDS)
Option(polled).orElse {
synchronized {
//If the queue times out then we move to the next epoch
epoch += 1
val lbq = new LinkedBlockingQueue[CachedRequest]()
requestQueues.update(epoch, lbq)
epochStart = System.currentTimeMillis()
}
None
}
}
case _ => throw new IllegalArgumentException("Should not hit this path")
}
.orElse(Some(Some(queue.take())))
.flatten

case _ => throw new IllegalArgumentException("Should not hit this path")
}
.orElse(Some(Some(queue.take())))
.flatten
}
}

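The reshaped branch above relies on LinkedBlockingQueue.poll returning null once the timeout elapses; wrapping the result in Option turns that into None, which is what lets the reader roll over to the next epoch. A standalone illustration of the idiom (queue type and timeout are arbitrary):

import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

object PollTimeoutSketch {
  def main(args: Array[String]): Unit = {
    val queue = new LinkedBlockingQueue[String]()
    // poll returns null when nothing arrives within the timeout; Option(...) maps that to None.
    val received: Option[String] = Option(queue.poll(100, TimeUnit.MILLISECONDS))
    println(received.getOrElse("timed out, so the caller would advance to the next epoch"))
  }
}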
@@ -650,7 +665,8 @@ private[streaming] class WorkerServer(val name: String,
if (TaskContext.get().attemptNumber() == 0) {
// If the request has never been materialized add it to the cache, otherwise we are in a retry and
// should not modify the history
historyQueues.getOrElseUpdate((localEpoch, partitionIndex), new ListBuffer[CachedRequest]())
historyQueues
.getOrElseUpdate((localEpoch, partitionIndex), new ListBuffer[CachedRequest]())
.append(request)
}
InternalRow(
@@ -702,7 +718,6 @@ private[streaming] class HTTPInputPartitionReader(continuous: Boolean,
val endEpoch: Option[Long],
val partitionIndex: Int)
extends ContinuousPartitionReader[InternalRow] with Logging {

val client: WorkerClient = HTTPSourceStateHolder.getOrCreateClient(name)
val server: WorkerServer = HTTPSourceStateHolder.getOrCreateServer(
name, startEpoch, partitionIndex, continuous, client, config)
@@ -101,7 +101,7 @@ object RTestGen {
| "spark.sql.shuffle.partitions=10",
| "spark.sql.crossJoin.enabled=true")
|
|sc <- spark_connect(master = "local", version = "3.2.4", config = conf)
|sc <- spark_connect(master = "local", version = "3.4.1", config = conf)
|
|""".stripMargin, StandardOpenOption.CREATE)

