Skip to content

Commit

Permalink
Consistent timestamps in MergeIntoCommand
Browse files Browse the repository at this point in the history
Transform timestamps for MergeIntoCommand during PreprocessTableMerge.

GitOrigin-RevId: ddcb2ecb5a04a2887997d117c2049555eb8816e5
  • Loading branch information
olaky authored and vkorukanti committed Jul 11, 2022
1 parent 5d3d73f commit fadefea
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 5 deletions.
Expand Up @@ -16,6 +16,7 @@

package org.apache.spark.sql.delta

import java.time.{Instant, LocalDateTime}
import java.util.Locale

import scala.collection.mutable
Expand All @@ -25,13 +26,17 @@ import org.apache.spark.sql.delta.commands.MergeIntoCommand

import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, Literal, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CurrentDate, CurrentTimestamp, CurrentTimeZone, Expression, Literal, LocalTimestamp, Now, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
import org.apache.spark.sql.catalyst.trees.TreePatternBits
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types.{DateType, StringType, StructField, StructType, TimestampNTZType, TimestampType}

case class PreprocessTableMerge(override val conf: SQLConf)
extends Rule[LogicalPlan] with UpdateExpressionsSupport {
Expand Down Expand Up @@ -231,15 +236,43 @@ case class PreprocessTableMerge(override val conf: SQLConf)
case DeltaFullTable(index) => index
case o => throw DeltaErrors.notADeltaSourceException("MERGE", Some(o))
}
MergeIntoCommand(
source, target, tahoeFileIndex, condition,
processedMatched, processedNotMatched, finalSchemaOpt)

/**
* Because source and target are not children of MergeIntoCommand they are not processed when
* invoking the [[ComputeCurrentTime]] rule. This is why they need special handling.
*/
val now = Instant.now()
transformTimestamps(
MergeIntoCommand(transformTimestamps(source, now), transformTimestamps(target, now),
tahoeFileIndex, condition, processedMatched, processedNotMatched, finalSchemaOpt),
now)
} else {
DeltaMergeInto(source, target, condition,
processedMatched, processedNotMatched, migrateSchema, finalSchemaOpt)
}
}

/**
 * Replaces the current-time-like expressions in `plan` (including its subqueries) with
 * literals that are all derived from the single `instant` passed in.
 *
 * @param plan    the logical plan to rewrite
 * @param instant the point in time every replaced expression is evaluated against
 * @return the plan with `CurrentDate`, `CurrentTimestamp`, `Now`, `CurrentTimeZone` and
 *         `LocalTimestamp` expressions replaced by literals
 */
private def transformTimestamps(plan: LogicalPlan, instant: Instant): LogicalPlan = {
  // NOTE(review): presumably brings the plan-transformation helpers used below into scope;
  // confirm before removing.
  import org.apache.spark.sql.delta.implicits._

  val nowMicros = instantToMicros(instant)
  val nowLiteral = Literal.create(nowMicros, TimestampType)
  val sessionZoneLiteral = Literal.create(conf.sessionLocalTimeZone, StringType)

  // Maps each supported expression to a literal computed from `instant`. Date and local
  // timestamp values depend on the zone carried by the individual expression.
  val toLiteral: PartialFunction[Expression, Expression] = {
    case CurrentTimestamp() | Now() => nowLiteral
    case CurrentTimeZone() => sessionZoneLiteral
    case currentDate: CurrentDate =>
      val days = DateTimeUtils.microsToDays(nowMicros, currentDate.zoneId)
      Literal.create(days, DateType)
    case local: LocalTimestamp =>
      val wallClock = LocalDateTime.ofInstant(instant, local.zoneId)
      Literal.create(localDateTimeToMicros(wallClock), TimestampNTZType)
  }

  plan.transformUpWithSubqueries {
    case node =>
      // Only expression trees flagged with the CURRENT_LIKE pattern are considered.
      node.transformAllExpressionsUpWithPruning(_.containsPattern(CURRENT_LIKE))(toLiteral)
  }
}

/**
* Resolves every generated column in `allActions` that is not explicitly inserted to its
* corresponding generated expression.
Expand Down
@@ -0,0 +1,167 @@
/*
* Copyright (2021) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta

import java.sql.Timestamp

import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.test.DeltaSQLCommandTest

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.functions.{current_timestamp, lit, timestamp_seconds}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils

/**
 * Concrete suite that runs the tests inherited from
 * [[MergeIntoTimestampConsistencySuiteBase]] without adding or overriding anything.
 */
class MergeIntoTimestampConsistencySuite extends MergeIntoTimestampConsistencySuiteBase


/**
 * Tests that `now()` / `current_timestamp()` resolve to one and the same value everywhere
 * inside a single MERGE statement: in the source, in the ON condition, in the clause
 * conditions and in the UPDATE / INSERT expressions.
 */
abstract class MergeIntoTimestampConsistencySuiteBase extends QueryTest
  with SharedSparkSession with DeltaSQLCommandTest {

  /**
   * Runs `block` with two fixtures in place and drops both afterwards, even if the setup
   * or `block` throws:
   *  - `target`: a Delta table with ids 0 to 4, `updated = false` and fixed timestamps;
   *  - `source`: a temp view with ids 0 to 9, `updated = true` and both timestamp columns
   *    defined as `current_timestamp()`.
   *
   * @param block the test body to run against the fixtures
   */
  private def withTestTables(block: => Unit): Unit = {
    // The setup lives inside the guarded block so that a partially created fixture is
    // still cleaned up. Passing the statements directly as the by-name argument also avoids
    // relying on auto-application of an empty-paren method.
    Utils.tryWithSafeFinally {
      spark.range(0, 5)
        .toDF("id")
        .withColumn("updated", lit(false))
        .withColumn("timestampOne", timestamp_seconds(lit(1)))
        .withColumn("timestampTwo", timestamp_seconds(lit(1337)))
        .write
        .format("delta")
        .saveAsTable("target")
      spark.range(0, 10)
        .toDF("id")
        .withColumn("updated", lit(true))
        .withColumn("timestampOne", current_timestamp())
        .withColumn("timestampTwo", current_timestamp())
        .createOrReplaceTempView("source")

      block
    } {
      sql("DROP VIEW IF EXISTS source")
      sql("DROP TABLE IF EXISTS target")
    }
  }

  test("Consistent timestamps between source and ON condition") {
    withTestTables {
      // Rows only match if the source timestamp equals now() in the ON condition.
      sql("""MERGE INTO target t
            | USING source s
            | ON s.id = t.id AND s.timestampOne = now()
            | WHEN MATCHED THEN UPDATE SET *""".stripMargin)

      assertAllRowsAreUpdated()
    }
  }

  test("Consistent timestamps between source and WHEN MATCHED condition") {
    withTestTables {
      sql("""MERGE INTO target t
            | USING source s
            | ON s.id = t.id
            | WHEN MATCHED AND s.timestampOne = now() AND s.timestampTwo = now()
            | THEN UPDATE SET *""".stripMargin)

      assertAllRowsAreUpdated()
    }
  }

  test("Consistent timestamps between source and UPDATE SET") {
    withTestTables {
      // timestampOne comes from the source, timestampTwo from now() in the SET clause.
      sql(
        """MERGE INTO target t
          | USING source s
          | ON s.id = t.id
          | WHEN MATCHED THEN UPDATE
          | SET updated = s.updated, t.timestampOne = s.timestampOne, t.timestampTwo = now()
          |""".stripMargin)

      assertUpdatedTimestampsInTargetAreAllEqual()
    }
  }

  test("Consistent timestamps between source and WHEN NOT MATCHED condition") {
    withTestTables {
      sql("""MERGE INTO target t
            | USING source s
            | ON s.id = t.id
            | WHEN NOT MATCHED AND s.timestampOne = now() AND s.timestampTwo = now()
            | THEN INSERT *
            |""".stripMargin)

      assertNewSourceRowsInserted()
    }
  }

  test("Consistent timestamps between source and INSERT VALUES") {
    withTestTables {
      // timestampOne comes from the source, timestampTwo from now() in the VALUES list.
      sql(
        """MERGE INTO target t
          | USING source s
          | ON s.id = t.id
          | WHEN NOT MATCHED THEN INSERT (id, updated, timestampOne, timestampTwo)
          | VALUES (s.id, s.updated, s.timestampOne, now())
          |""".stripMargin)

      assertUpdatedTimestampsInTargetAreAllEqual()
    }
  }

  test("Consistent timestamps with subquery in source") {
    withTestTables {
      val sourceWithSubqueryTable = "source_with_subquery"
      withTempView(sourceWithSubqueryTable) {
        // The view filters the source through a subquery that itself calls now().
        sql(
          s"""CREATE OR REPLACE TEMPORARY VIEW $sourceWithSubqueryTable
             | AS SELECT * FROM source WHERE timestampOne IN (SELECT now())
             |""".stripMargin).collect()

        sql(s"""MERGE INTO target t
               | USING $sourceWithSubqueryTable s
               | ON s.id = t.id
               | WHEN MATCHED THEN UPDATE SET *""".stripMargin)

        assertAllRowsAreUpdated()
      }
    }
  }

  /** Asserts that `target` contains no row with `updated = FALSE`. */
  private def assertAllRowsAreUpdated(): Unit = {
    val nonUpdatedRowsCount = sql("SELECT * FROM target WHERE updated = FALSE").count()
    assert(0 === nonUpdatedRowsCount, "Un-updated rows in target table")
  }

  /** Asserts that every id present in `source` is also present in `target`. */
  private def assertNewSourceRowsInserted(): Unit = {
    val numNotInsertedSourceRows =
      sql("SELECT * FROM source s LEFT ANTI JOIN target t ON s.id = t.id").count()
    assert(0 === numNotInsertedSourceRows, "Un-inserted rows in source table")
  }

  /**
   * Asserts that the rows of `target` with `updated = TRUE` share exactly one
   * (timestampOne, timestampTwo) combination and that the two values are equal.
   */
  private def assertUpdatedTimestampsInTargetAreAllEqual(): Unit = {
    import testImplicits._

    val timestampCombinations =
      sql("""SELECT timestampOne, timestampTwo
            | FROM target WHERE updated = TRUE GROUP BY timestampOne, timestampTwo
            |""".stripMargin)
    val rows = timestampCombinations.as[(Timestamp, Timestamp)].collect()
    assert(1 === rows.length, "Multiple combinations of timestamp values in target table")
    assert(rows(0)._1 === rows(0)._2,
      "timestampOne and timestampTwo are not equal in target table")
  }
}

0 comments on commit fadefea

Please sign in to comment.