Commit 5cde6e9

Merge pull request #1324 from datastax/SPARKC-646

SPARKC-646 Spark 3.1 support on master branch

jtgrabowski committed Jul 13, 2021
2 parents 64dd2d7 + 3741b88
Showing 4 changed files with 58 additions and 53 deletions.
CHANGES.txt: 4 changes (4 additions, 0 deletions)
@@ -1,3 +1,7 @@
+3.1.0
+ * Updated Spark to 3.1.1 and commons-lang to 3.10
+ * Fixed crash in the DirectJoin caused by internal changes in Spark 3.1 (SPARKC-626)
+
 3.0.1
 * Fix: repeated metadata refresh with the Spark connector (SPARKC-633)
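The remaining three files adapt the connector to two binary-incompatible relocations in Spark 3.1, both visible in the import and pattern-match changes below: BuildLeft/BuildRight/BuildSide moved from org.apache.spark.sql.execution.joins to org.apache.spark.sql.catalyst.optimizer, and DataSourceV2ScanRelation now wraps a DataSourceV2Relation instead of holding the Table directly. A minimal sketch of the new locations, assuming Spark 3.1.1 on the classpath (neither import compiles against 3.0.x):

```scala
// Spark 3.0.x location (removed below):
//   import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
// Spark 3.1.x locations (added below):
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
```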
Expand Up @@ -9,7 +9,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, BindReferences
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildSide}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildSide}
import org.apache.spark.sql.execution.metric.SQLMetrics

 /**
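For this exec-node file the only change is the import; the type itself behaves the same. As a hedged illustration of how such a node consumes the relocated type (a hypothetical fragment, not the connector's actual code):

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}

// BuildSide is sealed with exactly these two members; matching on it is how a
// join node selects which child's keys form the build (hash) side.
def buildSideKeys(side: BuildSide,
                  leftKeys: Seq[Expression],
                  rightKeys: Seq[Expression]): Seq[Expression] = side match {
  case BuildLeft  => leftKeys
  case BuildRight => rightKeys
}
```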
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Exp
import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation, DataSourceV2Strategy}
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation, DataSourceV2Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}


@@ -28,38 +28,38 @@ case class CassandraDirectJoinStrategy(spark: SparkSession) extends Strategy wit
     case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
       if hasValidDirectJoin(joinType, leftKeys, rightKeys, condition, left, right) =>

-      val (otherBranch, joinTargetBranch, joinKeys, buildType) = {
+      val (otherBranch, joinTargetBranch, buildType) = {
         if (leftValid(joinType, leftKeys, rightKeys, condition, left, right)) {
-          (right, left, leftKeys, BuildLeft)
+          (right, left, BuildLeft)
         } else {
-          (left, right, rightKeys, BuildRight)
+          (left, right, BuildRight)
         }
       }

       /* We want to take advantage of all of our pushed filter code which happens in
         full table scans. Unfortunately the pushdown code itself is private within the
         DataSourceV2Strategy class. To work around this we will invoke DataSourceV2Strategy on
         our target branch. This will let us know all of the pushable filters that we can
         use in the direct join.
        */
-      val dataSourceOpitimzedPlan = new DataSourceV2Strategy(spark)(joinTargetBranch).head
-      val cassandraScanExec = getScanExec(dataSourceOpitimzedPlan).get
+      val dataSourceOptimizedPlan = new DataSourceV2Strategy(spark)(joinTargetBranch).head
+      val cassandraScanExec = getScanExec(dataSourceOptimizedPlan).get

       joinTargetBranch match {
-        case PhysicalOperation(attributes, _, DataSourceV2ScanRelation(_: CassandraTable, _, _)) =>
+        case PhysicalOperation(attributes, _, DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _)) =>
           val directJoin =
             CassandraDirectJoinExec(
               leftKeys,
               rightKeys,
               joinType,
               buildType,
               condition,
               planLater(otherBranch),
               aliasMap(attributes),
               cassandraScanExec
             )

-          val newPlan = reorderPlan(dataSourceOpitimzedPlan, directJoin) :: Nil
+          val newPlan = reorderPlan(dataSourceOptimizedPlan, directJoin) :: Nil
           val newOutput = (newPlan.head.outputSet, newPlan.head.output.map(_.name))
           val oldOutput = (plan.outputSet, plan.output.map(_.name))
           val noMissingOutput = oldOutput._1.subsetOf(newPlan.head.outputSet)
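The pattern change above tracks Spark 3.1's reshaped DSv2 plan node: DataSourceV2ScanRelation now carries a DataSourceV2Relation (table plus catalog, identifier, and options, hence the five wildcards) rather than the Table itself. A minimal sketch of extracting the table through the new wrapper, assuming Spark 3.1.1:

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}

// In Spark 3.0 the Table was the scan relation's first field; in 3.1 it sits
// one level deeper, inside DataSourceV2Relation.
def tableOf(plan: LogicalPlan): Option[Table] = plan.collectFirst {
  case DataSourceV2ScanRelation(DataSourceV2Relation(table, _, _, _, _), _, _) => table
}
```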
@@ -79,8 +79,8 @@ case class CassandraDirectJoinStrategy(spark: SparkSession) extends Strategy wit
       condition: Option[Expression],
       left: LogicalPlan,
       right: LogicalPlan): Boolean = {
-    leftValid(joinType,leftKeys, rightKeys, condition, left, right) ||
-      rightValid(joinType,leftKeys, rightKeys, condition, left, right)
+    leftValid(joinType, leftKeys, rightKeys, condition, left, right) ||
+      rightValid(joinType, leftKeys, rightKeys, condition, left, right)
   }


@@ -147,7 +147,7 @@ object CassandraDirectJoinStrategy extends Logging {
     */
   def getScanExec(plan: SparkPlan): Option[BatchScanExec] = {
     plan.collectFirst {
-      case exec@BatchScanExec(_, _: CassandraScan) => exec
+      case exec @ BatchScanExec(_, _: CassandraScan) => exec
     }
   }

@@ -170,7 +170,7 @@ object CassandraDirectJoinStrategy extends Logging {
   def getDSV2CassandraRelation(plan: LogicalPlan): Option[DataSourceV2ScanRelation] = {
     val children = plan.collectLeaves()
     if (children.length == 1) {
-      plan.collectLeaves().collectFirst { case ds@DataSourceV2ScanRelation(_: CassandraTable, _, _) => ds }
+      plan.collectLeaves().collectFirst { case ds @ DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _) => ds }
     } else {
       None
     }
@@ -183,7 +183,7 @@ object CassandraDirectJoinStrategy extends Logging {
   def getCassandraTable(plan: LogicalPlan): Option[CassandraTable] = {
     val children = plan.collectLeaves()
     if (children.length == 1) {
-      children.collectFirst{ case DataSourceV2ScanRelation(table: CassandraTable, _, _) => table}
+      children.collectFirst { case DataSourceV2ScanRelation(DataSourceV2Relation(table: CassandraTable, _, _, _, _), _, _) => table }
     } else {
       None
     }
@@ -192,7 +192,7 @@ object CassandraDirectJoinStrategy extends Logging {
   def getCassandraScan(plan: LogicalPlan): Option[CassandraScan] = {
     val children = plan.collectLeaves()
     if (children.length == 1) {
-      plan.collectLeaves().collectFirst { case DataSourceV2ScanRelation(_: CassandraTable, cs: CassandraScan, _) => cs }
+      plan.collectLeaves().collectFirst { case DataSourceV2ScanRelation(_: DataSourceV2Relation, cs: CassandraScan, _) => cs }
     } else {
       None
     }
@@ -204,7 +204,7 @@ object CassandraDirectJoinStrategy extends Logging {
    */
   def hasCassandraChild[T <: QueryPlan[T]](plan: T): Boolean = {
     plan.children.size == 1 && plan.children.exists {
-      case DataSourceV2ScanRelation(_: CassandraTable, _, _) => true
+      case DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _) => true
       case BatchScanExec(_, _: CassandraScan) => true
       case _ => false
     }
@@ -237,7 +237,7 @@ object CassandraDirectJoinStrategy extends Logging {
       //This may be the only node in the Plan
       case BatchScanExec(_, _: CassandraScan) => directJoin
       // Plan has children
-      case normalPlan => normalPlan.transform{
+      case normalPlan => normalPlan.transform {
         case penultimate if hasCassandraChild(penultimate) =>
           penultimate.withNewChildren(Seq(directJoin))
       }
@@ -250,20 +250,21 @@ object CassandraDirectJoinStrategy extends Logging {
       concern here as columns which may have been non-nullable previously, become nullable in
       a left/right join
      */
-    reordered.transform{
+    reordered.transform {
       case ProjectExec(projectList, child) =>
         val aliases = projectList.collect {
-          case a@Alias(child: AttributeReference, _) => (child.toAttribute.exprId, a)
+          case a @ Alias(child: AttributeReference, _) => (child.toAttribute.exprId, a)
         }.toMap

-        val aliasedOutput = directJoin.output.map{
+        val aliasedOutput = directJoin.output.map {
           case attr if aliases.contains(attr.exprId) =>
             val oldAlias = aliases(attr.exprId)
-            oldAlias.copy(child = attr)(oldAlias.exprId, oldAlias.qualifier, oldAlias.explicitMetadata)
+            oldAlias.copy(child = attr)(oldAlias.exprId, oldAlias.qualifier,
+              oldAlias.explicitMetadata, oldAlias.nonInheritableMetadataKeys)
           case other => other
         }

-        ProjectExec (aliasedOutput, child)
+        ProjectExec(aliasedOutput, child)
     }
   }
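The extra argument tracks a new fourth member of Alias's second parameter list in Spark 3.1 (nonInheritableMetadataKeys); the old three-argument call no longer compiles. A minimal sketch of the same copy idiom, assuming Spark 3.1.1:

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute}

// Re-point an Alias at a new child while preserving its ExprId, qualifier,
// and metadata. Case-class copy only regenerates the first parameter list,
// so the curried second list must be passed through explicitly.
def retarget(alias: Alias, newChild: Attribute): Alias =
  alias.copy(child = newChild)(
    alias.exprId, alias.qualifier, alias.explicitMetadata, alias.nonInheritableMetadataKeys)
```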

@@ -274,7 +274,7 @@ object CassandraDirectJoinStrategy extends Logging {
   def validJoinBranch(plan: LogicalPlan, keys: Seq[Expression]): Boolean = {
     val safePlan = containsSafePlans(plan)
     val pkConstrained = allPartitionKeysAreJoined(plan, keys)
-    if (containsSafePlans(plan)){
+    if (containsSafePlans(plan)) {
       logDebug(s"Plan was safe")
     }
     if (pkConstrained) {
@@ -291,7 +291,7 @@ object CassandraDirectJoinStrategy extends Logging {
     plan match {
       case PhysicalOperation(
         attributes, _,
-        DataSourceV2ScanRelation(cassandraTable: CassandraTable, _, _)) =>
+        DataSourceV2ScanRelation(DataSourceV2Relation(cassandraTable: CassandraTable, _, _, _, _), _, _)) =>

         val joinKeysExprId = joinKeys.collect{ case attributeReference: AttributeReference => attributeReference.exprId }
@@ -312,7 +312,7 @@ object CassandraDirectJoinStrategy extends Logging {
   /**
     * Map Source Cassandra Column Names to ExpressionIds referring to them
     */
-  def aliasMap(aliases: Seq[NamedExpression]) = aliases.map {
+  def aliasMap(aliases: Seq[NamedExpression]): Map[String, ExprId] = aliases.map {
     case a @ Alias(child: AttributeReference, _) => child.name -> a.exprId
     case attributeReference: AttributeReference => attributeReference.name -> attributeReference.exprId
   }.toMap
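aliasMap is what lets the direct join translate plan-level expression IDs back to source Cassandra column names even when the query aliases them. A small self-contained sketch of the same mapping, with hypothetical column names:

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId, NamedExpression}
import org.apache.spark.sql.types.StringType

// "SELECT key AS k": the map records the *source* column name "key" against
// the ExprId that downstream operators use for the alias "k".
val key = AttributeReference("key", StringType)()
val projection: Seq[NamedExpression] = Seq(Alias(key, "k")())
val byName: Map[String, ExprId] = projection.map {
  case a @ Alias(child: AttributeReference, _) => child.name -> a.exprId
  case attr: AttributeReference => attr.name -> attr.exprId
}.toMap
```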
@@ -323,7 +323,7 @@ object CassandraDirectJoinStrategy extends Logging {
     */
   def containsSafePlans(plan: LogicalPlan): Boolean = {
     plan match {
-      case PhysicalOperation(_, _, DataSourceV2ScanRelation(_: CassandraTable, scan: CassandraScan, _))
+      case PhysicalOperation(_, _, DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), scan: CassandraScan, _))
         if getDirectJoinSetting(scan.consolidatedConf) != AlwaysOff => true
       case _ => false
     }
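For experimentation, a custom Strategy like this one can be registered on a session by hand; the connector normally wires it in through its own session extensions, so treat this as an illustrative sketch only:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("direct-join-sketch")
  .getOrCreate()

// Extra strategies are consulted before Spark's built-in planning strategies.
spark.experimental.extraStrategies ++= Seq(CassandraDirectJoinStrategy(spark))
```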
project/Versions.scala: 4 changes (2 additions, 2 deletions)
@@ -2,7 +2,7 @@ object Versions {

   val CommonsExec = "1.3"
   val CommonsIO = "2.6"
-  val CommonsLang3 = "3.9"
+  val CommonsLang3 = "3.10"
   val Paranamer = "2.8"

   val DataStaxJavaDriver = "4.12.0"
@@ -21,7 +21,7 @@ object Versions {
   // and install in a local Maven repository. This is all done automatically, however it will work
   // only on Unix/OSX operating system. Windows users have to build and install Spark manually if the
   // desired version is not yet published into a public Maven repository.
-  val ApacheSpark = "3.0.1"
+  val ApacheSpark = "3.1.1"
   val SparkJetty = "9.3.27.v20190418"
   val SolrJ = "8.3.0"
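Downstream builds that compile against this branch need matching coordinates. A hypothetical build.sbt fragment pinning the same versions (artifact names and scopes assumed, not taken from this repository):

```scala
// Hypothetical build.sbt fragment for a project consuming these versions.
libraryDependencies ++= Seq(
  "org.apache.spark"  %% "spark-sql"      % "3.1.1" % "provided",
  "org.apache.commons" % "commons-lang3"  % "3.10"
)
```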
