From 0231d45757aea3b60b4eb432fa55c3c20e75d44b Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Wed, 22 Sep 2021 16:05:08 +0200 Subject: [PATCH] Fix SAM conversion involving ()-insertion and overloading Without overloading, a function literal `() => 3` whose expected type is a Unit-returning SAM gets typed as `() => { 3; () }` as expected, but with overloading we first type the function literal without an expected type and then compare it against the formal parameter type. This commit makes this check succeed by replacing a result type of `Unit` by `WildcardType` so we end up comparing `() => Int <:< () => ?` instead of `() => Int <:< () => Unit`. Fixes #13549. --- compiler/src/dotty/tools/dotc/core/Types.scala | 12 +++++++++--- .../src/dotty/tools/dotc/typer/Applications.scala | 3 ++- compiler/src/dotty/tools/dotc/typer/Typer.scala | 2 +- tests/pos/i13549.scala | 11 +++++++++++ 4 files changed, 23 insertions(+), 5 deletions(-) create mode 100644 tests/pos/i13549.scala diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index c8162fcca948..490bbdd030f9 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -1779,15 +1779,21 @@ object Types { * @pre this is a method type without parameter dependencies. * @param dropLast The number of trailing parameters that should be dropped * when forming the function type. + * @param unitToWildcard If true and the result type is Unit, use a wildcard + * as the function result type instead. Useful when + * checking if a function literal could be converted into + * this function type. */ - def toFunctionType(isJava: Boolean, dropLast: Int = 0)(using Context): Type = this match { + def toFunctionType(isJava: Boolean, dropLast: Int = 0, unitToWildcard: Boolean = false)(using Context): Type = this match { case mt: MethodType if !mt.isParamDependent => val formals1 = if (dropLast == 0) mt.paramInfos else mt.paramInfos dropRight dropLast val isContextual = mt.isContextualMethod && !ctx.erasedTypes val isErased = mt.isErasedMethod && !ctx.erasedTypes val result1 = mt.nonDependentResultApprox match { - case res: MethodType => res.toFunctionType(isJava) - case res => res + case res: MethodType => res.toFunctionType(isJava, unitToWildcard = unitToWildcard) + case res => + if unitToWildcard && res.isRef(defn.UnitClass) then WildcardType + else res } val funType = defn.FunctionOf( formals1 mapConserve (_.translateFromRepeated(toArray = isJava)), diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index ab7187f6f4d4..a00092d8d46c 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -652,7 +652,8 @@ trait Applications extends Compatibility { def SAMargOK = defn.isFunctionType(argtpe1) && formal.match - case SAMType(sam) => argtpe <:< sam.toFunctionType(isJava = formal.classSymbol.is(JavaDefined)) + case SAMType(sam) => argtpe <:< sam.toFunctionType( + isJava = formal.classSymbol.is(JavaDefined), unitToWildcard = true) case _ => false isCompatible(argtpe, formal) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index b7741c636b06..64b3e79160b4 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -3728,7 +3728,7 @@ class Typer extends Namer if defn.isFunctionType(wtp) && !defn.isFunctionType(pt) => pt match { case SAMType(sam) - if wtp <:< sam.toFunctionType(isJava = pt.classSymbol.is(JavaDefined)) => + if wtp <:< sam.toFunctionType(isJava = pt.classSymbol.is(JavaDefined), unitToWildcard = true) => // was ... && isFullyDefined(pt, ForceDegree.flipBottom) // but this prevents case blocks from implementing polymorphic partial functions, // since we do not know the result parameter a priori. Have to wait until the diff --git a/tests/pos/i13549.scala b/tests/pos/i13549.scala new file mode 100644 index 000000000000..5d2cfea8845a --- /dev/null +++ b/tests/pos/i13549.scala @@ -0,0 +1,11 @@ +@FunctionalInterface +trait Executable { + def execute(): Unit +} + +object Test { + def assertThrows(executable: Executable, message: String): Unit = ??? + def assertThrows(executable: Executable, foo: Int): Unit = ??? + + assertThrows(() => 3, "This is a message") +}