Skip to content

Commit

Permalink
[FLINK-13433][table-planner-blink] Do not fetch data from LookupableTableSource if the JoinKey in left side of LookupJoin contains null value
Browse files Browse the repository at this point in the history

This closes apache#9285
  • Loading branch information
beyond1920 authored and becketqin committed Aug 17, 2019
1 parent 0f07af2 commit 4defd19
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ object LookupJoinCodeGenerator {
: GeneratedFunction[FlatMapFunction[BaseRow, BaseRow]] = {

val ctx = CodeGeneratorContext(config)
val (prepareCode, parameters) = prepareParameters(
val (prepareCode, parameters, nullInParameters) = prepareParameters(
ctx,
typeFactory,
inputType,
Expand All @@ -87,11 +87,17 @@ object LookupJoinCodeGenerator {
s"$lookupFunctionTerm.setCollector($DEFAULT_COLLECTOR_TERM);"
}

// TODO: filter all records when there is any nulls on the join key, because
// "IS NOT DISTINCT FROM" is not supported yet.
val body =
s"""
|$prepareCode
|$setCollectorCode
|$lookupFunctionTerm.eval($parameters);
|if ($nullInParameters) {
| return;
|} else {
| $lookupFunctionTerm.eval($parameters);
| }
""".stripMargin

FunctionCodeGenerator.generateFunction(
Expand All @@ -118,7 +124,7 @@ object LookupJoinCodeGenerator {
: GeneratedFunction[AsyncFunction[BaseRow, AnyRef]] = {

val ctx = CodeGeneratorContext(config)
val (prepareCode, parameters) = prepareParameters(
val (prepareCode, parameters, nullInParameters) = prepareParameters(
ctx,
typeFactory,
inputType,
Expand All @@ -130,11 +136,18 @@ object LookupJoinCodeGenerator {
val lookupFunctionTerm = ctx.addReusableFunction(asyncLookupFunction)
val DELEGATE = className[DelegatingResultFuture[_]]

// TODO: filter all records when there is any nulls on the join key, because
// "IS NOT DISTINCT FROM" is not supported yet.
val body =
s"""
|$prepareCode
|$DELEGATE delegates = new $DELEGATE($DEFAULT_COLLECTOR_TERM);
|$lookupFunctionTerm.eval(delegates.getCompletableFuture(), $parameters);
|if ($nullInParameters) {
| $DEFAULT_COLLECTOR_TERM.complete(java.util.Collections.emptyList());
| return;
|} else {
| $DELEGATE delegates = new $DELEGATE($DEFAULT_COLLECTOR_TERM);
| $lookupFunctionTerm.eval(delegates.getCompletableFuture(), $parameters);
|}
""".stripMargin

FunctionCodeGenerator.generateFunction(
Expand All @@ -156,7 +169,7 @@ object LookupJoinCodeGenerator {
lookupKeyInOrder: Array[Int],
allLookupFields: Map[Int, LookupKey],
isExternalArgs: Boolean,
fieldCopy: Boolean): (String, String) = {
fieldCopy: Boolean): (String, String, String) = {

val inputFieldExprs = for (i <- lookupKeyInOrder) yield {
allLookupFields.get(i) match {
Expand Down Expand Up @@ -195,9 +208,12 @@ object LookupJoinCodeGenerator {
| $newTerm = $assign;
|}
""".stripMargin
(code, newTerm)
(code, newTerm, e.nullTerm)
}
(codeAndArg.map(_._1).mkString("\n"), codeAndArg.map(_._2).mkString(", "))
(
codeAndArg.map(_._1).mkString("\n"),
codeAndArg.map(_._2).mkString(", "),
codeAndArg.map(_._3).mkString("|| "))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ abstract class CommonLookupJoin(
joinType: JoinRelType): Unit = {

// check join on all fields of PRIMARY KEY or (UNIQUE) INDEX
if (allLookupKeys.isEmpty || allLookupKeys.isEmpty) {
if (allLookupKeys.isEmpty) {
throw new TableException(
"Temporal table join requires an equality condition on fields of " +
s"table [${tableSource.explainSource()}].")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,17 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.api.Types
import org.apache.flink.table.planner.runtime.utils.{BatchTableEnvUtil, BatchTestBase, InMemoryLookupableTableSource}

import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.{Before, Test}

class LookupJoinITCase extends BatchTestBase {
import java.lang.Boolean
import java.util

import scala.collection.JavaConversions._

@RunWith(classOf[Parameterized])
class LookupJoinITCase(isAsyncMode: Boolean) extends BatchTestBase {

val data = List(
BatchTestBase.row(1L, 12L, "Julian"),
Expand All @@ -33,6 +41,12 @@ class LookupJoinITCase extends BatchTestBase {
BatchTestBase.row(8L, 11L, "Hello world"),
BatchTestBase.row(9L, 12L, "Hello world!"))

// Left-side input rows for the null-key tests. They are registered in before() as "T1"
// with the fields (id, len, content); two of the rows carry a null id.
val dataWithNull = List(
BatchTestBase.row(null, 15L, "Hello"),
BatchTestBase.row(3L, 15L, "Fabian"),
BatchTestBase.row(null, 11L, "Hello world"),
BatchTestBase.row(9L, 12L, "Hello world!"))

val typeInfo = new RowTypeInfo(LONG_TYPE_INFO, LONG_TYPE_INFO, STRING_TYPE_INFO)

val userData = List(
Expand All @@ -55,19 +69,72 @@ class LookupJoinITCase extends BatchTestBase {
.enableAsync()
.build()

// Lookup-side rows as (age, id, name) tuples, matching the field order declared on the
// sources below; two of the rows carry a null id.
val userDataWithNull = List(
(11, 1L, "Julian"),
(22, null, "Hello"),
(33, 3L, "Fabian"),
(44, null, "Hello world"))

// Lookup source over userDataWithNull, built without enableAsync().
val userWithNullDataTableSource = InMemoryLookupableTableSource.builder()
.data(userDataWithNull)
.field("age", Types.INT)
.field("id", Types.LONG)
.field("name", Types.STRING)
.build()

// Same data and schema as userWithNullDataTableSource, but built with enableAsync().
val userAsyncWithNullDataTableSource = InMemoryLookupableTableSource.builder()
.data(userDataWithNull)
.field("age", Types.INT)
.field("id", Types.LONG)
.field("name", Types.STRING)
.enableAsync()
.build()

// Names of the registered lookup tables that the test queries interpolate.
// Both are assigned in before(), which picks the async or sync table name from isAsyncMode.
var userTable: String = _
var userTableWithNull: String = _

/**
 * Registers the input tables and lookup table sources used by the test cases.
 *
 * "T" and "nullableT" are registered from queries over the collections "T0" (`data`) and
 * "T1" (`dataWithNull`) that add a PROCTIME() column named `proctime`. Both the sync and
 * the async variant of every lookup source are registered; `userTable` and
 * `userTableWithNull` are then set to the name of the variant selected by `isAsyncMode`.
 */
@Before
override def before(): Unit = {
  super.before()
  BatchTableEnvUtil.registerCollection(tEnv, "T0", data, typeInfo, "id, len, content")
  val myTable = tEnv.sqlQuery("SELECT *, PROCTIME() as proctime FROM T0")
  tEnv.registerTable("T", myTable)

  BatchTableEnvUtil.registerCollection(
    tEnv, "T1", dataWithNull, typeInfo, "id, len, content")
  val myTable1 = tEnv.sqlQuery("SELECT *, PROCTIME() as proctime FROM T1")
  tEnv.registerTable("nullableT", myTable1)

  tEnv.registerTableSource("userTable", userTableSource)
  tEnv.registerTableSource("userAsyncTable", userAsyncTableSource)
  userTable = if (isAsyncMode) "userAsyncTable" else "userTable"

  tEnv.registerTableSource("userWithNullDataTable", userWithNullDataTableSource)
  tEnv.registerTableSource("userWithNullDataAsyncTable", userAsyncWithNullDataTableSource)
  userTableWithNull = if (isAsyncMode) "userWithNullDataAsyncTable" else "userWithNullDataTable"

  // Object reuse is disabled for every test in this class.
  // TODO: re-enable object reuse once [FLINK-12351] is fixed.
  env.getConfig.disableObjectReuse()
}

/**
 * LEFT JOIN against the lookup table with extra predicates on both sides in the ON clause.
 * Every left row with id > 1 is expected in the output; only id 3 gets lookup columns,
 * the other rows are padded with nulls.
 */
@Test
def testLeftJoinTemporalTableWithLocalPredicate(): Unit = {
  // The clauses are joined with single spaces into the final query text.
  val selectClause = "SELECT T.id, T.len, T.content, D.name, D.age"
  val joinClause = s"FROM T LEFT JOIN $userTable for system_time as of T.proctime AS D"
  val onClause = "ON T.id = D.id AND T.len > 1 AND D.age > 20 AND D.name = 'Fabian'"
  val whereClause = "WHERE T.id > 1"
  val query = Seq(selectClause, joinClause, onClause, whereClause).mkString(" ")

  val expectedRows = Seq(
    BatchTestBase.row(2, 15, "Hello", null, null),
    BatchTestBase.row(3, 15, "Fabian", "Fabian", 33),
    BatchTestBase.row(8, 11, "Hello world", null, null),
    BatchTestBase.row(9, 12, "Hello world!", null, null))

  checkResult(query, expectedRows, false)
}

@Test
def testJoinTemporalTable(): Unit = {
val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userTable " +
val sql = s"SELECT T.id, T.len, T.content, D.name FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.id = D.id"

val expected = Seq(
Expand All @@ -79,7 +146,7 @@ class LookupJoinITCase extends BatchTestBase {

@Test
def testJoinTemporalTableWithPushDown(): Unit = {
val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userTable " +
val sql = s"SELECT T.id, T.len, T.content, D.name FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.id = D.id AND D.age > 20"

val expected = Seq(
Expand All @@ -90,7 +157,7 @@ class LookupJoinITCase extends BatchTestBase {

@Test
def testJoinTemporalTableWithNonEqualFilter(): Unit = {
val sql = "SELECT T.id, T.len, T.content, D.name, D.age FROM T JOIN userTable " +
val sql = s"SELECT T.id, T.len, T.content, D.name, D.age FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.id = D.id WHERE T.len <= D.age"

val expected = Seq(
Expand All @@ -101,7 +168,7 @@ class LookupJoinITCase extends BatchTestBase {

@Test
def testJoinTemporalTableOnMultiFields(): Unit = {
val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " +
val sql = s"SELECT T.id, T.len, D.name FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.id = D.id AND T.content = D.name"

val expected = Seq(
Expand All @@ -112,7 +179,7 @@ class LookupJoinITCase extends BatchTestBase {

@Test
def testJoinTemporalTableOnMultiFieldsWithUdf(): Unit = {
val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " +
val sql = s"SELECT T.id, T.len, D.name FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON mod(T.id, 4) = D.id AND T.content = D.name"

val expected = Seq(
Expand All @@ -123,7 +190,7 @@ class LookupJoinITCase extends BatchTestBase {

@Test
def testJoinTemporalTableOnMultiKeyFields(): Unit = {
val sql = "SELECT T.id, T.len, D.name FROM T JOIN userTable " +
val sql = s"SELECT T.id, T.len, D.name FROM T JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.content = D.name AND T.id = D.id"

val expected = Seq(
Expand All @@ -134,7 +201,7 @@ class LookupJoinITCase extends BatchTestBase {

@Test
def testLeftJoinTemporalTable(): Unit = {
val sql = "SELECT T.id, T.len, D.name, D.age FROM T LEFT JOIN userTable " +
val sql = s"SELECT T.id, T.len, D.name, D.age FROM T LEFT JOIN $userTable " +
"for system_time as of T.proctime AS D ON T.id = D.id"

val expected = Seq(
Expand All @@ -147,88 +214,50 @@ class LookupJoinITCase extends BatchTestBase {
}

@Test
def testAsyncJoinTemporalTable(): Unit = {
// TODO: enable object reuse until [FLINK-12351] is fixed.
env.getConfig.disableObjectReuse()
val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userAsyncTable " +
"for system_time as of T.proctime AS D ON T.id = D.id"

val expected = Seq(
BatchTestBase.row(1, 12, "Julian", "Julian"),
BatchTestBase.row(2, 15, "Hello", "Jark"),
BatchTestBase.row(3, 15, "Fabian", "Fabian"))
checkResult(sql, expected, false)
}

@Test
def testAsyncJoinTemporalTableWithPushDown(): Unit = {
// TODO: enable object reuse until [FLINK-12351] is fixed.
env.getConfig.disableObjectReuse()
val sql = "SELECT T.id, T.len, T.content, D.name FROM T JOIN userAsyncTable " +
"for system_time as of T.proctime AS D ON T.id = D.id AND D.age > 20"
def testJoinTemporalTableOnMultiKeyFieldsWithNullData(): Unit = {
val sql = s"SELECT T.id, T.len, D.name FROM nullableT T JOIN $userTableWithNull " +
"for system_time as of T.proctime AS D ON T.content = D.name AND T.id = D.id"

val expected = Seq(
BatchTestBase.row(2, 15, "Hello", "Jark"),
BatchTestBase.row(3, 15, "Fabian", "Fabian"))
BatchTestBase.row(3,15,"Fabian"))
checkResult(sql, expected, false)
}

@Test
def testAsyncJoinTemporalTableWithNonEqualFilter(): Unit = {
// TODO: enable object reuse until [FLINK-12351] is fixed.
env.getConfig.disableObjectReuse()
val sql = "SELECT T.id, T.len, T.content, D.name, D.age FROM T JOIN userAsyncTable " +
"for system_time as of T.proctime AS D ON T.id = D.id WHERE T.len <= D.age"

def testLeftJoinTemporalTableOnMultiKeyFieldsWithNullData(): Unit = {
val sql = s"SELECT D.id, T.len, D.name FROM nullableT T LEFT JOIN $userTableWithNull " +
"for system_time as of T.proctime AS D ON T.content = D.name AND T.id = D.id"
val expected = Seq(
BatchTestBase.row(2, 15, "Hello", "Jark", 22),
BatchTestBase.row(3, 15, "Fabian", "Fabian", 33))
BatchTestBase.row(null,15,null),
BatchTestBase.row(3,15,"Fabian"),
BatchTestBase.row(null,11,null),
BatchTestBase.row(null,12,null))
checkResult(sql, expected, false)
}

@Test
def testAsyncLeftJoinTemporalTableWithLocalPredicate(): Unit = {
// TODO: enable object reuse until [FLINK-12351] is fixed.
env.getConfig.disableObjectReuse()
val sql = "SELECT T.id, T.len, T.content, D.name, D.age FROM T LEFT JOIN userAsyncTable " +
"for system_time as of T.proctime AS D ON T.id = D.id " +
"AND T.len > 1 AND D.age > 20 AND D.name = 'Fabian' " +
"WHERE T.id > 1"

val expected = Seq(
BatchTestBase.row(2, 15, "Hello", null, null),
BatchTestBase.row(3, 15, "Fabian", "Fabian", 33),
BatchTestBase.row(8, 11, "Hello world", null, null),
BatchTestBase.row(9, 12, "Hello world!", null, null))
def testJoinTemporalTableOnNullConstantKey(): Unit = {
  // The only join condition compares the lookup key with the null constant,
  // so the inner join is expected to produce no rows.
  val query =
    s"SELECT T.id, T.len, T.content FROM T JOIN $userTable " +
      "for system_time as of T.proctime AS D ON D.id = null"
  val noRows = Seq.empty
  checkResult(query, noRows, false)
}

@Test
def testAsyncJoinTemporalTableOnMultiFields(): Unit = {
// TODO: enable object reuse until [FLINK-12351] is fixed.
env.getConfig.disableObjectReuse()
val sql = "SELECT T.id, T.len, D.name FROM T JOIN userAsyncTable " +
"for system_time as of T.proctime AS D ON T.id = D.id AND T.content = D.name"

val expected = Seq(
BatchTestBase.row(1, 12, "Julian"),
BatchTestBase.row(3, 15, "Fabian"))
def testJoinTemporalTableOnMultiKeyFieldsWithNullConstantKey(): Unit = {
  // One of the two lookup keys is compared with the null constant,
  // so the inner join is expected to produce no rows.
  val query =
    s"SELECT T.id, T.len, D.name FROM T JOIN $userTable " +
      "for system_time as of T.proctime AS D ON T.content = D.name AND null = D.id"
  val noRows = Seq.empty
  checkResult(query, noRows, false)
}
}

@Test
def testAsyncLeftJoinTemporalTable(): Unit = {
// TODO: enable object reuse until [FLINK-12351] is fixed.
env.getConfig.disableObjectReuse()
val sql = "SELECT T.id, T.len, D.name, D.age FROM T LEFT JOIN userAsyncTable " +
"for system_time as of T.proctime AS D ON T.id = D.id"
object LookupJoinITCase {

val expected = Seq(
BatchTestBase.row(1, 12, "Julian", 11),
BatchTestBase.row(2, 15, "Jark", 22),
BatchTestBase.row(3, 15, "Fabian", 33),
BatchTestBase.row(8, 11, null, null),
BatchTestBase.row(9, 12, null, null))
checkResult(sql, expected, false)
@Parameterized.Parameters(name = "isAsyncMode = {0}")
def parameters(): util.Collection[Array[java.lang.Object]] = {
Seq[Array[AnyRef]](
Array(Boolean.TRUE), Array(Boolean.FALSE)
)
}
}

0 comments on commit 4defd19

Please sign in to comment.