Add back missing files
longvu-db committed May 26, 2024
1 parent 53a0ae0 commit 9593178
Showing 3 changed files with 222 additions and 0 deletions.
@@ -0,0 +1,79 @@
/*
* Copyright (2024) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta.connect.server

import scala.collection.JavaConverters._

import com.google.protobuf
import io.delta.connect.proto
import io.delta.tables.DeltaTable

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.plugin.CommandPlugin

/**
* Planner plugin for command extensions using [[proto.DeltaCommand]].
*/
class DeltaCommandPlugin extends CommandPlugin with DeltaPlannerBase {
override def process(raw: Array[Byte], planner: SparkConnectPlanner): Boolean = {
val command = protobuf.Any.parseFrom(raw)
if (command.is(classOf[proto.DeltaCommand])) {
process(command.unpack(classOf[proto.DeltaCommand]), planner)
true
} else {
false
}
}

private def process(command: proto.DeltaCommand, planner: SparkConnectPlanner): Unit = {
command.getCommandTypeCase match {
case proto.DeltaCommand.CommandTypeCase.CLONE_TABLE =>
processCloneTable(planner.session, command.getCloneTable)
case _ =>
        throw InvalidPlanInput(s"Unknown DeltaCommand ${command.getCommandTypeCase}")
}
}

private def processCloneTable(spark: SparkSession, cloneTable: proto.CloneTable): Unit = {
val deltaTable = transformDeltaTable(spark, cloneTable.getTable)
if (cloneTable.hasVersion) {
deltaTable.cloneAtVersion(
cloneTable.getVersion,
cloneTable.getTarget,
cloneTable.getIsShallow,
cloneTable.getReplace,
cloneTable.getPropertiesMap.asScala.toMap
)
} else if (cloneTable.hasTimestamp) {
deltaTable.cloneAtTimestamp(
cloneTable.getTimestamp,
cloneTable.getTarget,
cloneTable.getIsShallow,
cloneTable.getReplace,
cloneTable.getPropertiesMap.asScala.toMap
)
} else {
deltaTable.clone(
cloneTable.getTarget,
cloneTable.getIsShallow,
cloneTable.getReplace,
cloneTable.getPropertiesMap.asScala.toMap
)
}
}
}
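
For reference, a minimal sketch of how a client might assemble the CloneTable command this plugin unpacks. The builder names (setTable, setTarget, setIsShallow, setReplace, putProperties, setCloneTable) are inferred from the getters used above and from standard protobuf code generation; they are assumptions, not confirmed against the Delta Connect proto definitions.

// Hypothetical client-side construction of a CloneTable command; builder
// method names are assumed from the getters used by the plugin above.
import com.google.protobuf
import io.delta.connect.proto

val cloneTable = proto.CloneTable.newBuilder()
  .setTable(proto.DeltaTable.newBuilder().setTableOrViewName("source_table"))
  .setTarget("/tmp/delta/cloned_table")
  .setIsShallow(true)
  .setReplace(false)
  .putProperties("created.by", "example") // surfaces via getPropertiesMap
  .build()

val command = proto.DeltaCommand.newBuilder().setCloneTable(cloneTable).build()

// The plugin receives the command as the serialized bytes of a protobuf.Any,
// which is why process() starts with protobuf.Any.parseFrom(raw).
val raw: Array[Byte] = protobuf.Any.pack(command).toByteArray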
@@ -0,0 +1,36 @@
/*
* Copyright (2024) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta.connect.server

import io.delta.connect.proto
import io.delta.tables.DeltaTable

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.common.InvalidPlanInput

/**
* Base trait for the planner plugins of Delta Connect.
*/
trait DeltaPlannerBase {
protected def transformDeltaTable(
spark: SparkSession, deltaTable: proto.DeltaTable): DeltaTable = {
deltaTable.getAccessTypeCase match {
case proto.DeltaTable.AccessTypeCase.PATH =>
DeltaTable.forPath(spark, deltaTable.getPath.getPath, deltaTable.getPath.getHadoopConfMap)
case proto.DeltaTable.AccessTypeCase.TABLE_OR_VIEW_NAME =>
        DeltaTable.forName(spark, deltaTable.getTableOrViewName)
      case _ =>
        throw InvalidPlanInput(s"Unknown DeltaTable access type ${deltaTable.getAccessTypeCase}")
    }
}
}
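
To make the two access types concrete, here is a sketch of the messages transformDeltaTable resolves, again assuming the protobuf-generated builders implied by the getters above, including a nested Path message inferred from getPath.getPath and getPath.getHadoopConfMap.

// Assumed builders, inferred from the getters in transformDeltaTable.
import io.delta.connect.proto

// By catalog name: resolved with DeltaTable.forName.
val byName = proto.DeltaTable.newBuilder()
  .setTableOrViewName("my_schema.my_table")
  .build()

// By path, with per-table Hadoop configuration: resolved with DeltaTable.forPath.
val byPath = proto.DeltaTable.newBuilder()
  .setPath(
    proto.DeltaTable.Path.newBuilder()
      .setPath("/data/events")
      .putHadoopConf("fs.s3a.endpoint", "http://localhost:9000"))
  .build()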
@@ -0,0 +1,107 @@
/*
* Copyright (2024) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta.connect.server

import java.util.Optional

import com.google.protobuf
import com.google.protobuf.{ByteString, InvalidProtocolBufferException}
import io.delta.connect.proto

import org.apache.spark.SparkEnv
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.plugin.RelationPlugin
import org.apache.spark.sql.delta.connect.server.DeltaRelationPlugin.{parseAnyFrom, parseRelationFrom}

/**
* Planner plugin for relation extensions using [[proto.DeltaRelation]].
*/
class DeltaRelationPlugin extends RelationPlugin with DeltaPlannerBase {
override def transform(raw: Array[Byte], planner: SparkConnectPlanner): Optional[LogicalPlan] = {
val relation = parseAnyFrom(raw,
SparkEnv.get.conf.get(Connect.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT))
if (relation.is(classOf[proto.DeltaRelation])) {
Optional.of(
transform(
parseRelationFrom(relation.getValue,
SparkEnv.get.conf.get(Connect.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT)),
planner
))
} else {
Optional.empty()
}
  }

  private def transform(
relation: proto.DeltaRelation, planner: SparkConnectPlanner): LogicalPlan = {
relation.getRelationTypeCase match {
case proto.DeltaRelation.RelationTypeCase.SCAN =>
transformScan(planner.session, relation.getScan)
case _ =>
throw InvalidPlanInput(s"Unknown DeltaRelation ${relation.getRelationTypeCase}")
}
}

private def transformScan(spark: SparkSession, scan: proto.Scan): LogicalPlan = {
val deltaTable = transformDeltaTable(spark, scan.getTable)
deltaTable.toDF.queryExecution.analyzed
}
}

object DeltaRelationPlugin {
private def parseAnyFrom(ba: Array[Byte], recursionLimit: Int): protobuf.Any = {
val bs = ByteString.copyFrom(ba)
val cis = bs.newCodedInput()
cis.setSizeLimit(Integer.MAX_VALUE)
cis.setRecursionLimit(recursionLimit)
val plan = protobuf.Any.parseFrom(cis)
try {
// If the last tag is 0, it means the message is correctly parsed.
// If the last tag is not 0, it means the message is not correctly
// parsed, and we should throw an exception.
cis.checkLastTagWas(0)
plan
} catch {
case e: InvalidProtocolBufferException =>
e.setUnfinishedMessage(plan)
throw e
}
}

private def parseRelationFrom(bs: ByteString, recursionLimit: Int): proto.DeltaRelation = {
val cis = bs.newCodedInput()
cis.setSizeLimit(Integer.MAX_VALUE)
cis.setRecursionLimit(recursionLimit)
val plan = proto.DeltaRelation.parseFrom(cis)
try {
// If the last tag is 0, it means the message is correctly parsed.
// If the last tag is not 0, it means the message is not correctly
// parsed, and we should throw an exception.
cis.checkLastTagWas(0)
plan
} catch {
case e: InvalidProtocolBufferException =>
e.setUnfinishedMessage(plan)
throw e
}
}
}
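
Both plugins are discovered through Spark Connect's extension configs, spark.connect.extensions.command.classes and spark.connect.extensions.relation.classes. A minimal sketch of wiring them up follows; in practice these are typically passed as --conf flags when the Connect server is started, and the builder-based form here is for illustration only.

// Sketch: registering the Delta Connect plugins on the server side.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .config("spark.connect.extensions.command.classes",
    "org.apache.spark.sql.delta.connect.server.DeltaCommandPlugin")
  .config("spark.connect.extensions.relation.classes",
    "org.apache.spark.sql.delta.connect.server.DeltaRelationPlugin")
  .getOrCreate()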
