Skip to content

Commit

Permalink
Merge pull request #6116 from dotty-staging/fix-ext-overload
Browse files Browse the repository at this point in the history
Strengthen overloading resolution to deal with extension methods
  • Loading branch information
odersky committed Mar 21, 2019
2 parents 87a6ce4 + 1762d21 commit 8755bdf
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 24 deletions.
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,11 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
case dummyTreeOfType(tp) :: Nil if !(tp isRef defn.NullClass) => "null: " ~ toText(tp)
case _ => toTextGlobal(args, ", ")
}
return "FunProto(" ~ (Str("given ") provided tp.isContextual) ~ argsText ~ "):" ~ toText(resultType)
return "[applied to " ~ (Str("given ") provided tp.isContextual) ~ "(" ~ argsText ~ ") returning " ~ toText(resultType) ~ "]"
case IgnoredProto(ignored) =>
return "?" ~ (("(ignored: " ~ toText(ignored) ~ ")") provided ctx.settings.verbose.value)
case tp @ PolyProto(targs, resType) =>
return "PolyProto(" ~ toTextGlobal(targs, ", ") ~ "): " ~ toText(resType)
return "[applied to [" ~ toTextGlobal(targs, ", ") ~ "] returning " ~ toText(resType)
case _ =>
}
super.toText(tp)
Expand Down
59 changes: 49 additions & 10 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Trees.Untyped
import Contexts._
import Flags._
import Symbols._
import Denotations.Denotation
import Types._
import Decorators._
import ErrorReporting._
Expand Down Expand Up @@ -1204,8 +1205,12 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
* result matching `resultType`?
*/
def hasExtensionMethod(tp: Type, name: TermName, argType: Type, resultType: Type)(implicit ctx: Context) = {
val mbr = tp.memberBasedOnFlags(name, required = ExtensionMethod)
mbr.exists && isApplicable(tp.select(name, mbr), argType :: Nil, resultType)
def qualifies(mbr: Denotation) =
mbr.exists && isApplicable(tp.select(name, mbr), argType :: Nil, resultType)
tp.memberBasedOnFlags(name, required = ExtensionMethod) match {
case mbr: SingleDenotation => qualifies(mbr)
case mbr => mbr.hasAltWith(qualifies(_))
}
}

/** Compare owner inheritance level.
Expand Down Expand Up @@ -1627,16 +1632,50 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
}
else compat
}

/** For each candidate `C`, a proxy termref paired with `C`.
* The proxy termref has as symbol a copy of the original candidate symbol,
* with an info that strips the first value parameter list away.
* @param argTypes The types of the arguments of the FunProto `pt`.
*/
def advanceCandidates(argTypes: List[Type]): List[(TermRef, TermRef)] = {
def strippedType(tp: Type): Type = tp match {
case tp: PolyType =>
val rt = strippedType(tp.resultType)
if (rt.exists) tp.derivedLambdaType(resType = rt) else rt
case tp: MethodType =>
tp.instantiate(argTypes)
case _ =>
NoType
}
def cloneCandidate(cand: TermRef): List[(TermRef, TermRef)] = {
val strippedInfo = strippedType(cand.widen)
if (strippedInfo.exists) {
val sym = cand.symbol.asTerm.copy(info = strippedInfo)
(TermRef(NoPrefix, sym), cand) :: Nil
}
else Nil
}
overload.println(i"look at more params: ${candidates.head.symbol}: ${candidates.map(_.widen)}%, % with $pt, [$targs%, %]")
candidates.flatMap(cloneCandidate)
}

val found = narrowMostSpecific(candidates)
if (found.length <= 1) found
else {
val noDefaults = alts.filter(!_.symbol.hasDefaultParams)
if (noDefaults.length == 1) noDefaults // return unique alternative without default parameters if it exists
else {
val deepPt = pt.deepenProto
if (deepPt ne pt) resolveOverloaded(alts, deepPt, targs)
else alts
}
else pt match {
case pt @ FunProto(_, resType: FunProto) =>
// try to narrow further with snd argument list
val advanced = advanceCandidates(pt.typedArgs.tpes)
resolveOverloaded(advanced.map(_._1), resType, Nil) // resolve with candidates where first params are stripped
.map(advanced.toMap) // map surviving result(s) back to original candidates
case _ =>
val noDefaults = alts.filter(!_.symbol.hasDefaultParams)
if (noDefaults.length == 1) noDefaults // return unique alternative without default parameters if it exists
else {
val deepPt = pt.deepenProto
if (deepPt ne pt) resolveOverloaded(alts, deepPt, targs)
else alts
}
}
}

Expand Down
22 changes: 14 additions & 8 deletions compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,16 @@ object Implicits {
/** A "massaging" function for displayed types to give better info in error diagnostics */
def clarify(tp: Type)(implicit ctx: Context): Type = tp

final protected def qualify(implicit ctx: Context): String =
if (expectedType.exists)
if (argument.isEmpty) em"match type ${clarify(expectedType)}"
else em"convert from ${argument.tpe} to ${clarify(expectedType)}"
else
final protected def qualify(implicit ctx: Context): String = expectedType match {
case SelectionProto(name, mproto, _, _) if !argument.isEmpty =>
em"provide an extension method `$name` on ${argument.tpe}"
case NoType =>
if (argument.isEmpty) em"match expected type"
else em"convert from ${argument.tpe} to expected type"
case _ =>
if (argument.isEmpty) em"match type ${clarify(expectedType)}"
else em"convert from ${argument.tpe} to ${clarify(expectedType)}"
}

/** An explanation of the cause of the failure as a string */
def explanation(implicit ctx: Context): String
Expand Down Expand Up @@ -425,9 +428,12 @@ object Implicits {
class AmbiguousImplicits(val alt1: SearchSuccess, val alt2: SearchSuccess, val expectedType: Type, val argument: Tree) extends SearchFailureType {
def explanation(implicit ctx: Context): String =
em"both ${err.refStr(alt1.ref)} and ${err.refStr(alt2.ref)} $qualify"
override def whyNoConversion(implicit ctx: Context): String =
"\nNote that implicit conversions cannot be applied because they are ambiguous;" +
"\n" + explanation
override def whyNoConversion(implicit ctx: Context): String = {
val what = if (expectedType.isInstanceOf[SelectionProto]) "extension methods" else "conversions"
i"""
|Note that implicit $what cannot be applied because they are ambiguous;
|$explanation"""
}
}

class MismatchedImplicit(ref: TermRef,
Expand Down
11 changes: 8 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/RefChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ object RefChecks {
* 1.8.1 M's type is a subtype of O's type, or
* 1.8.2 M is of type []S, O is of type ()T and S <: T, or
* 1.8.3 M is of type ()S, O is of type []T and S <: T, or
* 1.9 If M or O are erased, they must be both erased
* 1.9.1 If M or O are erased, they must both be erased
* 1.9.2 If M or O are extension methods, they must both be extension methods
* 1.10 If M is an inline or Scala-2 macro method, O cannot be deferred unless
* there's also a concrete method that M overrides.
* 1.11. If O is a Scala-2 macro, M must be a Scala-2 macro.
Expand Down Expand Up @@ -391,10 +392,14 @@ object RefChecks {
overrideError("may not override a non-lazy value")
} else if (other.is(Lazy) && !other.isRealMethod && !member.is(Lazy)) {
overrideError("must be declared lazy to override a lazy value")
} else if (member.is(Erased) && !other.is(Erased)) { // (1.9)
} else if (member.is(Erased) && !other.is(Erased)) { // (1.9.1)
overrideError("is erased, cannot override non-erased member")
} else if (other.is(Erased) && !member.is(Erased)) { // (1.9)
} else if (other.is(Erased) && !member.is(Erased)) { // (1.9.1)
overrideError("is not erased, cannot override erased member")
} else if (member.is(Extension) && !other.is(Extension)) { // (1.9.2)
overrideError("is an extension method, cannot override a normal method")
} else if (other.is(Extension) && !member.is(Extension)) { // (1.9.2)
overrideError("is a normal method, cannot override an extension method")
} else if ((member.isInlineMethod || member.is(Scala2Macro)) && other.is(Deferred) &&
member.extendedOverriddenSymbols.forall(_.is(Deferred))) { // (1.10)
overrideError("is an inline method, must override at least one concrete method")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class SignatureHelpTest {
}""".withSource
.signatureHelp(m1, List(sig0, sig1), None, 0)
.signatureHelp(m2, List(sig0, sig1), None, 0)
.signatureHelp(m3, List(sig0, sig1), Some(1), 1)
.signatureHelp(m3, List(), Some(1), 1) // TODO: investigate we do not get help at $m3
}

@Test def multipleParameterLists: Unit = {
Expand Down
14 changes: 14 additions & 0 deletions tests/neg/extmethod-overload.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
object Test {
implied A {
def (x: Int) |+| (y: Int) = x + y
}
implied B {
def (x: Int) |+| (y: String) = x + y.length
}
assert((1 |+| 2) == 3) // error ambiguous

locally {
import B.|+|
assert((1 |+| "2") == 2) // OK
}
}
8 changes: 8 additions & 0 deletions tests/neg/extmethod-override.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class A {
def f(x: Int)(y: Int): Int = 0
def (x: Int) g (y: Int): Int = 1
}
class B extends A {
override def (x: Int) f (y: Int): Int = 1 // error
override def g(x: Int)(y: Int): Int = 0 // error
}
122 changes: 122 additions & 0 deletions tests/run/extmethod-overload.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
object Test extends App {
// warmup
def f(x: Int)(y: Int) = y
def f(x: Int)(y: String) = y.length
assert(f(1)(2) == 2)
assert(f(1)("two") == 3)

def g[T](x: T)(y: Int) = y
def g[T](x: T)(y: String) = y.length
assert(g[Int](1)(2) == 2)
assert(g[Int](1)("two") == 3)
assert(g(1)(2) == 2)
assert(g(1)("two") == 3)

def h[T](x: T)(y: T)(z: Int) = z
def h[T](x: T)(y: T)(z: String) = z.length
assert(h[Int](1)(1)(2) == 2)
assert(h[Int](1)(1)("two") == 3)
assert(h(1)(1)(2) == 2)
assert(h(1)(1)("two") == 3)

// Test with extension methods in implied object
object test1 {

implied Foo {
def (x: Int) |+| (y: Int) = x + y
def (x: Int) |+| (y: String) = x + y.length

def (xs: List[T]) +++ [T] (ys: List[T]): List[T] = xs ++ ys ++ ys
def (xs: List[T]) +++ [T] (ys: Iterator[T]): List[T] = xs ++ ys ++ ys
}

assert((1 |+| 2) == 3)
assert((1 |+| "2") == 2)

val xs = List(1, 2)
assert((xs +++ xs).length == 6)
assert((xs +++ xs.iterator).length == 4, xs +++ xs.iterator)
}
test1

// Test with imported extension methods
object test2 {
import test1.Foo._

assert((1 |+| 2) == 3)
assert((1 |+| "2") == 2)

val xs = List(1, 2)
assert((xs +++ xs).length == 6)
assert((xs +++ xs.iterator).length == 4, xs +++ xs.iterator)
}
test2

// Test with implied extension methods coming from base class
object test3 {
class Foo {
def (x: Int) |+| (y: Int) = x + y
def (x: Int) |+| (y: String) = x + y.length

def (xs: List[T]) +++ [T] (ys: List[T]): List[T] = xs ++ ys ++ ys
def (xs: List[T]) +++ [T] (ys: Iterator[T]): List[T] = xs ++ ys ++ ys
}
implied Bar for Foo

assert((1 |+| 2) == 3)
assert((1 |+| "2") == 2)

val xs = List(1, 2)
assert((xs +++ xs).length == 6)
assert((xs +++ xs.iterator).length == 4, xs +++ xs.iterator)
}
test3

// Test with implied extension methods coming from implied alias
object test4 {
implied for test3.Foo = test3.Bar

assert((1 |+| 2) == 3)
assert((1 |+| "2") == 2)

val xs = List(1, 2)
assert((xs +++ xs).length == 6)
assert((xs +++ xs.iterator).length == 4, xs +++ xs.iterator)
}
test4

class C {
def xx (x: Any) = 2
}
def (c: C) xx (x: Int) = 1

val c = new C
assert(c.xx(1) == 2) // member method takes precedence

object D {
def (x: Int) yy (y: Int) = x + y
}

implied {
def (x: Int) yy (y: Int) = x - y
}

import D._
assert((1 yy 2) == 3) // imported extension method takes precedence

trait Rectangle {
def a: Long
def b: Long
}

case class GenericRectangle(a: Long, b: Long) extends Rectangle
case class Square(a: Long) extends Rectangle {
def b: Long = a
}

def (rectangle: Rectangle) area: Long = 0
def (square: Square) area: Long = square.a * square.a
val rectangles = List(GenericRectangle(2, 3), Square(5))
val areas = rectangles.map(_.area)
assert(areas.sum == 0)
}

0 comments on commit 8755bdf

Please sign in to comment.