Skip to content

Commit

Permalink
[SPARK-35288][SQL] StaticInvoke should find the method without exact …
Browse files Browse the repository at this point in the history
…argument classes match

### What changes were proposed in this pull request?

This patch proposes to make StaticInvoke able to find method with given method name even the parameter types do not exactly match to argument classes.

### Why are the changes needed?

Unlike `Invoke`, `StaticInvoke` only tries to get the method with exact argument classes. If the calling method's parameter types are not exactly matched with the argument classes, `StaticInvoke` cannot find the method.

`StaticInvoke` should be able to find the method under the cases too.

### Does this PR introduce _any_ user-facing change?

Yes. `StaticInvoke` can find a method even the argument classes are not exactly matched.

### How was this patch tested?

Unit test.

Closes apache#32413 from viirya/static-invoke.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
(cherry picked from commit 33fbf56)
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
  • Loading branch information
viirya authored and dongjoon-hyun committed May 9, 2021
1 parent 4295996 commit 373454a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,34 @@ trait InvokeLike extends Expression with NonSQLExpression {
}
}
}

final def findMethod(cls: Class[_], functionName: String, argClasses: Seq[Class[_]]): Method = {
// Looking with function name + argument classes first.
try {
cls.getMethod(functionName, argClasses: _*)
} catch {
case _: NoSuchMethodException =>
// For some cases, e.g. arg class is Object, `getMethod` cannot find the method.
// We look at function name + argument length
val m = cls.getMethods.filter { m =>
m.getName == functionName && m.getParameterCount == arguments.length
}
if (m.isEmpty) {
sys.error(s"Couldn't find $functionName on $cls")
} else if (m.length > 1) {
// More than one matched method signature. Exclude synthetic one, e.g. generic one.
val realMethods = m.filter(!_.isSynthetic)
if (realMethods.length > 1) {
// Ambiguous case, we don't know which method to choose, just fail it.
sys.error(s"Found ${realMethods.length} $functionName on $cls")
} else {
realMethods.head
}
} else {
m.head
}
}
}
}

/**
Expand Down Expand Up @@ -232,7 +260,7 @@ case class StaticInvoke(
override def children: Seq[Expression] = arguments

lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
@transient lazy val method = cls.getDeclaredMethod(functionName, argClasses : _*)
@transient lazy val method = findMethod(cls, functionName, argClasses)

override def eval(input: InternalRow): Any = {
invoke(null, method, arguments, input, dataType)
Expand Down Expand Up @@ -319,31 +347,7 @@ case class Invoke(

@transient lazy val method = targetObject.dataType match {
case ObjectType(cls) =>
// Looking with function name + argument classes first.
try {
Some(cls.getMethod(encodedFunctionName, argClasses: _*))
} catch {
case _: NoSuchMethodException =>
// For some cases, e.g. arg class is Object, `getMethod` cannot find the method.
// We look at function name + argument length
val m = cls.getMethods.filter { m =>
m.getName == encodedFunctionName && m.getParameterCount == arguments.length
}
if (m.isEmpty) {
sys.error(s"Couldn't find $encodedFunctionName on $cls")
} else if (m.length > 1) {
// More than one matched method signature. Exclude synthetic one, e.g. generic one.
val realMethods = m.filter(!_.isSynthetic)
if (realMethods.length > 1) {
// Ambiguous case, we don't know which method to choose, just fail it.
sys.error(s"Found ${realMethods.length} $encodedFunctionName on $cls")
} else {
Some(realMethods.head)
}
} else {
Some(m.head)
}
}
Some(findMethod(cls, encodedFunctionName, argClasses))
case _ => None
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -638,8 +638,22 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val clsType = ObjectType(classOf[ConcreteClass])
val obj = new ConcreteClass

val input = (1, 2)
checkObjectExprEvaluation(
Invoke(Literal(obj, clsType), "testFunc", IntegerType, Seq(Literal(1))), 0)
Invoke(Literal(obj, clsType), "testFunc", IntegerType,
Seq(Literal(input, ObjectType(input.getClass)))), 2)
}

test("SPARK-35288: static invoke should find method without exact param type match") {
val input = (1, 2)

checkObjectExprEvaluation(
StaticInvoke(TestStaticInvoke.getClass, IntegerType, "func",
Seq(Literal(input, ObjectType(input.getClass)))), 3)

checkObjectExprEvaluation(
StaticInvoke(TestStaticInvoke.getClass, IntegerType, "func",
Seq(Literal(1, IntegerType))), -1)
}
}

Expand All @@ -652,10 +666,22 @@ class TestBean extends Serializable {
assert(i != null, "this setter should not be called with null.")
}

object TestStaticInvoke {
def func(param: Any): Int = param match {
case pair: Tuple2[_, _] =>
pair.asInstanceOf[Tuple2[Int, Int]]._1 + pair.asInstanceOf[Tuple2[Int, Int]]._2
case _ => -1
}
}

abstract class BaseClass[T] {
def testFunc(param: T): T
def testFunc(param: T): Int
}

class ConcreteClass extends BaseClass[Int] with Serializable {
override def testFunc(param: Int): Int = param - 1
class ConcreteClass extends BaseClass[Product] with Serializable {
override def testFunc(param: Product): Int = param match {
case _: Tuple2[_, _] => 2
case _: Tuple3[_, _, _] => 3
case _ => 4
}
}

0 comments on commit 373454a

Please sign in to comment.