Skip to content

Commit

Permalink
Auto-uncurry n-ary functions.
Browse files Browse the repository at this point in the history
Implements SIP scala#897.
  • Loading branch information
odersky committed Feb 16, 2016
1 parent 5e80233 commit 29104c9
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 19 deletions.
19 changes: 19 additions & 0 deletions src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,25 @@ object desugar {
Function(params, Match(selector, cases))
}

/** Map n-ary function `(p1, ..., pn) => body` where n != 1 to unary function as follows:
*
* x$1 => {
* val p1 = x$1._1
* ...
* val pn = x$1._n
* body
* }
*/
def makeUnaryCaseLambda(params: List[ValDef], body: Tree)(implicit ctx: Context): Tree = {
val param = makeSyntheticParameter()
def selector(n: Int) = Select(refOfDef(param), nme.selectorName(n))
val vdefs =
params.zipWithIndex.map{
case(param, idx) => cpy.ValDef(param)(rhs = selector(idx))
}
Function(param :: Nil, Block(vdefs, body))
}

/** Add annotation with class `cls` to tree:
* tree @cls
*/
Expand Down
56 changes: 37 additions & 19 deletions src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -611,26 +611,44 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
if (protoFormals.length == params.length) protoFormals(i)
else errorType(i"wrong number of parameters, expected: ${protoFormals.length}", tree.pos)

val inferredParams: List[untpd.ValDef] =
for ((param, i) <- params.zipWithIndex) yield
if (!param.tpt.isEmpty) param
else cpy.ValDef(param)(
tpt = untpd.TypeTree(
inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false)))

// Define result type of closure as the expected type, thereby pushing
// down any implicit searches. We do this even if the expected type is not fully
// defined, which is a bit of a hack. But it's needed to make the following work
// (see typers.scala and printers/PlainPrinter.scala for examples).
//
// def double(x: Char): String = s"$x$x"
// "abc" flatMap double
//
val resultTpt = protoResult match {
case WildcardType(_) => untpd.TypeTree()
case _ => untpd.TypeTree(protoResult)
/** Is `formal` a product type which is elementwise compatible with `params`? */
def ptIsCorrectProduct(formal: Type) = {
val pclass = defn.ProductNClass(params.length)
isFullyDefined(formal, ForceDegree.noBottom) &&
formal.derivesFrom(pclass) &&
formal.baseArgTypes(pclass).corresponds(params) {
(argType, param) =>
param.tpt.isEmpty || isCompatible(argType, typedAheadType(param.tpt).tpe)
}
}
typed(desugar.makeClosure(inferredParams, fnBody, resultTpt), pt)

val desugared =
if (protoFormals.length == 1 && params.length != 1 && ptIsCorrectProduct(protoFormals.head)) {
desugar.makeUnaryCaseLambda(params, fnBody)
}
else {
val inferredParams: List[untpd.ValDef] =
for ((param, i) <- params.zipWithIndex) yield
if (!param.tpt.isEmpty) param
else cpy.ValDef(param)(
tpt = untpd.TypeTree(
inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false)))

// Define result type of closure as the expected type, thereby pushing
// down any implicit searches. We do this even if the expected type is not fully
// defined, which is a bit of a hack. But it's needed to make the following work
// (see typers.scala and printers/PlainPrinter.scala for examples).
//
// def double(x: Char): String = s"$x$x"
// "abc" flatMap double
//
val resultTpt = protoResult match {
case WildcardType(_) => untpd.TypeTree()
case _ => untpd.TypeTree(protoResult)
}
desugar.makeClosure(inferredParams, fnBody, resultTpt)
}
typed(desugared, pt)
}
}

Expand Down
1 change: 1 addition & 0 deletions test/dotc/tests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class tests extends CompilerTest {
@Test def neg_abstractOverride() = compileFile(negDir, "abstract-override", xerrors = 2)
@Test def neg_blockescapes() = compileFile(negDir, "blockescapesNeg", xerrors = 1)
@Test def neg_bounds() = compileFile(negDir, "bounds", xerrors = 2)
@Test def neg_functionArity() = compileFile(negDir, "function-arity", xerrors = 5)
@Test def neg_typedapply() = compileFile(negDir, "typedapply", xerrors = 3)
@Test def neg_typedIdents() = compileDir(negDir, "typedIdents", xerrors = 2)
@Test def neg_assignments() = compileFile(negDir, "assignments", xerrors = 3)
Expand Down
22 changes: 22 additions & 0 deletions tests/neg/function-arity.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
object Test {

// From #873:

trait X extends Function1[Int, String]
implicit def f2x(f: Function1[Int, String]): X = ???
({case _ if "".isEmpty => 0} : X) // error: expected String, found Int

// Tests where parameter list cannot be made into a pattern

def unary[T](x: T => Unit) = ???
unary((x, y) => ()) // error

unary[(Int, Int)]((x, y) => ())

unary[(Int, Int)](() => ()) // error
unary[(Int, Int)]((x, y, _) => ()) // error

unary[(Int, Int)]((x: String, y) => ()) // error


}
8 changes: 8 additions & 0 deletions tests/pos/i873.scala → tests/pos/function-arity.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,12 @@ object Test {
({case _ if "".isEmpty => ""} : X) // allowed, implicit view used to adapt

// ({case _ if "".isEmpty => 0} : X) // expected String, found Int

def unary[T](a: T, b: T, f: ((T, T)) => T): T = f((a, b))
unary(1, 2, (x, y) => x)
unary(1, 2, (x: Int, y) => x)
unary(1, 2, (x: Int, y: Float) => x)

val xs = List(1, 2, 3)
xs.zipWithIndex.map(_ + _)
}

0 comments on commit 29104c9

Please sign in to comment.