update for miss
liuxiaocs7 committed Apr 7, 2024
1 parent 7d755a8 commit e869d57
Showing 6 changed files with 23 additions and 57 deletions.
@@ -185,14 +185,13 @@ class ExecuteStatement(
       // Rename all col name to avoid duplicate columns
       val colName = range(0, result.schema.size).map(x => "col" + x)
 
-      val codec = if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") "zstd" else "zlib"
       // df.write will introduce an extra shuffle for the outermost limit, and hurt performance
       if (resultMaxRows > 0) {
         result.toDF(colName: _*).limit(resultMaxRows).write
-          .option("compression", codec).format("orc").save(saveFileName.get)
+          .option("compression", "zstd").format("orc").save(saveFileName.get)
       } else {
         result.toDF(colName: _*).write
-          .option("compression", codec).format("orc").save(saveFileName.get)
+          .option("compression", "zstd").format("orc").save(saveFileName.get)
       }
       info(s"Save result to ${saveFileName.get}")
       fetchOrcStatement = Some(new FetchOrcStatement(spark))
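Note (not part of the commit): since Spark 3.2 is now the minimum supported engine runtime, the ORC result file can always be compressed with zstd, so the removed zlib fallback for older runtimes is no longer needed. A minimal standalone sketch of the same write path, with a hypothetical output directory:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    // Write a small result set as zstd-compressed ORC, mirroring the save above.
    spark.range(10).toDF("col0")
      .write
      .option("compression", "zstd")
      .format("orc")
      .save("/tmp/kyuubi-orc-result-demo") // hypothetical path, not from the diff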
@@ -36,7 +36,6 @@ import org.apache.spark.sql.execution.datasources.orc.OrcDeserializer
 import org.apache.spark.sql.types.StructType
 
 import org.apache.kyuubi.KyuubiException
-import org.apache.kyuubi.engine.spark.KyuubiSparkUtil.SPARK_ENGINE_RUNTIME_VERSION
 import org.apache.kyuubi.operation.{FetchIterator, IterableFetchIterator}
 import org.apache.kyuubi.util.reflect.DynConstructors
 
@@ -76,27 +75,13 @@ class FetchOrcStatement(spark: SparkSession) {
 
   private def getOrcDeserializer(orcSchema: StructType, colId: Array[Int]): OrcDeserializer = {
     try {
-      if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") {
-        // SPARK-34535 changed the constructor signature of OrcDeserializer
-        DynConstructors.builder()
-          .impl(classOf[OrcDeserializer], classOf[StructType], classOf[Array[Int]])
-          .build[OrcDeserializer]()
-          .newInstance(
-            orcSchema,
-            colId)
-      } else {
-        DynConstructors.builder()
-          .impl(
-            classOf[OrcDeserializer],
-            classOf[StructType],
-            classOf[StructType],
-            classOf[Array[Int]])
-          .build[OrcDeserializer]()
-          .newInstance(
-            new StructType,
-            orcSchema,
-            colId)
-      }
+      // SPARK-34535 changed the constructor signature of OrcDeserializer
+      DynConstructors.builder()
+        .impl(classOf[OrcDeserializer], classOf[StructType], classOf[Array[Int]])
+        .build[OrcDeserializer]()
+        .newInstance(
+          orcSchema,
+          colId)
     } catch {
       case e: Throwable =>
         throw new KyuubiException("Failed to create OrcDeserializer", e)
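Note (not part of the commit): only the post-SPARK-34535 two-argument OrcDeserializer(StructType, Array[Int]) constructor is kept, so the reflective lookup has a single target. A minimal sketch of the DynConstructors pattern used above, with java.lang.StringBuilder as a stand-in target class and assuming Kyuubi's reflection helper is on the classpath:

    import org.apache.kyuubi.util.reflect.DynConstructors

    // Bind the StringBuilder(String) constructor reflectively, the same pattern
    // used above to bind OrcDeserializer(StructType, Array[Int]) on Spark 3.2+.
    val ctor = DynConstructors.builder()
      .impl(classOf[java.lang.StringBuilder], classOf[String])
      .build[java.lang.StringBuilder]()
    val sb = ctor.newInstance("kyuubi")
    assert(sb.toString == "kyuubi")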
@@ -344,6 +344,6 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
       largeVarTypes)
   }
 
-  // IpcOption.DEFAULT was introduced in ARROW-11081(ARROW-4.0.0), add this for adapt Spark 3.1/3.2
+  // IpcOption.DEFAULT was introduced in ARROW-11081(ARROW-4.0.0), add this for adapt Spark 3.2
   final private val ARROW_IPC_OPTION_DEFAULT = new IpcOption()
 }
@@ -55,8 +55,8 @@ object SparkDatasetHelper extends Logging {
         doCollectLimit(collectLimit)
       case collectLimit: CollectLimitExec if collectLimit.limit < 0 =>
         executeArrowBatchCollect(collectLimit.child)
-      case command: CommandResultExec =>
-        doCommandResultExec(command)
+      case commandResult: CommandResultExec =>
+        doCommandResultExec(commandResult)
       case localTableScan: LocalTableScanExec =>
         doLocalTableScan(localTableScan)
       case plan: SparkPlan =>
@@ -184,26 +184,21 @@ object SparkDatasetHelper extends Logging {
     result.toArray
   }
 
-  private lazy val commandResultExecRowsMethod = DynMethods.builder("rows")
-    .impl("org.apache.spark.sql.execution.CommandResultExec")
-    .build()
-
-  private def doCommandResultExec(command: SparkPlan): Array[Array[Byte]] = {
-    val spark = SparkSession.active
-    val rows = command.asInstanceOf[CommandResultExec].rows
-    command.longMetric("numOutputRows").add(rows.size)
-    sendDriverMetrics(spark.sparkContext, command.metrics)
+  private def doCommandResultExec(commandResult: CommandResultExec): Array[Array[Byte]] = {
+    val spark = commandResult.session
+    commandResult.longMetric("numOutputRows").add(commandResult.rows.size)
+    sendDriverMetrics(spark.sparkContext, commandResult.metrics)
     KyuubiArrowConverters.toBatchIterator(
-      rows.iterator,
-      command.schema,
+      commandResult.rows.iterator,
+      commandResult.schema,
       spark.sessionState.conf.arrowMaxRecordsPerBatch,
       maxBatchSize,
       -1,
       spark.sessionState.conf.sessionLocalTimeZone).toArray
   }
 
   private def doLocalTableScan(localTableScan: LocalTableScanExec): Array[Array[Byte]] = {
-    val spark = SparkSession.active
+    val spark = localTableScan.session
     localTableScan.longMetric("numOutputRows").add(localTableScan.rows.size)
     sendDriverMetrics(spark.sparkContext, localTableScan.metrics)
     KyuubiArrowConverters.toBatchIterator(
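Note (not part of the commit): deriving the session from the plan itself (commandResult.session, localTableScan.session) instead of SparkSession.active binds metrics and Arrow settings to the session that produced the plan rather than to whatever session happens to be active on the calling thread. A minimal sketch of why the ambient value is fragile, using only public API:

    import org.apache.spark.sql.SparkSession

    val base = SparkSession.builder().master("local[1]").getOrCreate()
    val other = base.newSession()
    other.conf.set("spark.sql.session.timeZone", "UTC")

    // SparkSession.active is thread-scoped state: once it is switched, the "active"
    // session (and its conf) may no longer be the one that produced an earlier plan.
    SparkSession.setActiveSession(other)
    assert(SparkSession.active.conf.get("spark.sql.session.timeZone") == "UTC")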
@@ -21,7 +21,6 @@ import org.apache.spark.sql.SparkSession
 
 import org.apache.kyuubi.{KyuubiFunSuite, Utils}
 import org.apache.kyuubi.config.KyuubiConf
-import org.apache.kyuubi.engine.spark.KyuubiSparkUtil.SPARK_ENGINE_RUNTIME_VERSION
 
 trait WithSparkSQLEngine extends KyuubiFunSuite {
   protected var spark: SparkSession = _
@@ -35,7 +34,7 @@ trait WithSparkSQLEngine extends KyuubiFunSuite {
   // Affected by such configuration' default value
   // engine.initialize.sql='SHOW DATABASES'
   // SPARK-35378
-  protected lazy val initJobId: Int = if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") 1 else 0
+  protected val initJobId: Int = 1
 
   override def beforeAll(): Unit = {
     startSparkEngine()
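Note (not part of the commit): per the comment above, with eager command execution (SPARK-35378) on Spark 3.2+ the engine's initialize SQL ('SHOW DATABASES') already triggers a Spark job, so the suites can hard-code the first user job id to 1. A minimal listener sketch for observing that offset (FirstJobIdListener is a hypothetical name, not from the diff):

    import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}

    // Records the first job id observed on the SparkContext.
    class FirstJobIdListener extends SparkListener {
      @volatile var firstJobId: Option[Int] = None
      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
        if (firstJobId.isEmpty) firstJobId = Some(jobStart.jobId)
      }
    }

    // Usage: spark.sparkContext.addSparkListener(new FirstJobIdListener)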
@@ -294,11 +294,7 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     val listener = new JobCountListener
     val l2 = new SQLMetricsListener
     val nodeName = spark.sql("SHOW TABLES").queryExecution.executedPlan.getClass.getName
-    if (SPARK_ENGINE_RUNTIME_VERSION < "3.2") {
-      assert(nodeName == "org.apache.spark.sql.execution.command.ExecutedCommandExec")
-    } else {
-      assert(nodeName == "org.apache.spark.sql.execution.CommandResultExec")
-    }
+    assert(nodeName == "org.apache.spark.sql.execution.CommandResultExec")
     withJdbcStatement("table_1") { statement =>
       statement.executeQuery("CREATE TABLE table_1 (id bigint) USING parquet")
       withSparkListener(listener) {
@@ -310,15 +306,8 @@
       }
     }
 
-    if (SPARK_ENGINE_RUNTIME_VERSION < "3.2") {
-      // Note that before Spark 3.2, a LocalTableScan SparkPlan will be submitted, and the issue of
-      // preventing LocalTableScan from triggering a job submission was addressed in [KYUUBI #4710].
-      assert(l2.queryExecution.executedPlan.getClass.getName ==
-        "org.apache.spark.sql.execution.LocalTableScanExec")
-    } else {
-      assert(l2.queryExecution.executedPlan.getClass.getName ==
-        "org.apache.spark.sql.execution.CommandResultExec")
-    }
+    assert(l2.queryExecution.executedPlan.getClass.getName ==
+      "org.apache.spark.sql.execution.CommandResultExec")
     assert(listener.numJobs == 0)
   }
 
@@ -374,7 +363,6 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
 
   test("post CommandResultExec driver-side metrics") {
     spark.sql("show tables").show(truncate = false)
-    assume(SPARK_ENGINE_RUNTIME_VERSION >= "3.2")
     val expectedMetrics = Map(
       0L -> (("CommandResult", Map("number of output rows" -> "2"))))
     withTables("table_1", "table_2") {
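Note (not part of the commit): with the pre-3.2 branches gone, the suite relies on command output always being wrapped in CommandResultExec, the physical node that Spark 3.2's eager command execution (SPARK-35378) produces. A minimal check, assuming a SparkSession named spark is in scope as in the suite:

    // SHOW TABLES is planned as CommandResultExec on Spark 3.2+.
    val plan = spark.sql("SHOW TABLES").queryExecution.executedPlan
    assert(plan.getClass.getName == "org.apache.spark.sql.execution.CommandResultExec")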