Skip to content
Permalink
Browse files

[scalac-plugin] Refactor recursive match:

```scala
x match {
  case y => y match {
    case ~(RichPat("x"), _) => ???
  }
}
```
  • Loading branch information
cchantep
cchantep committed May 25, 2014
1 parent f4acc2d commit 28fff28e89368a64d4882efd7517abebc88c3edd
@@ -3,4 +3,5 @@ jdk:
- openjdk6
env: JAVA_OPTS="-Xms32m -Xmx128m"
scala:
- 2.10.3
- 2.10.4
- 2.11.0
@@ -9,7 +9,7 @@ object Acolyte extends Build with Dependencies
aggregate(jdbcDriver, scalacPlugin, jdbcScala, studio).
settings(
organization in ThisBuild := "org.eu.acolyte",
version in ThisBuild := "1.0.18-1",
version in ThisBuild := "1.0.19",
javaOptions in ThisBuild ++= Seq("-source", "1.6", "-target", "1.6"),
scalaVersion in ThisBuild := "2.10.4",
crossScalaVersions in ThisBuild := Seq("2.10.4", "2.11.0"),
@@ -39,8 +39,10 @@ class AcolytePlugin(val global: Global) extends Plugin {
reporter,
Apply,
Block,
Bind,
CaseDef,
DefDef,
Ident,
Match,
Position,
Tree,
@@ -49,7 +51,9 @@ class AcolytePlugin(val global: Global) extends Plugin {
import scala.reflect.io.VirtualFile
import scala.reflect.internal.util.BatchSourceFile

override def transform(tree: Tree): Tree = tree match {
override def transform(tree: Tree): Tree = refactor(tree)

private def refactor(tree: Tree): Tree = tree match {
case m @ Match(_, _) {
val richMatch = refactorMatch(m)

@@ -63,71 +67,71 @@ class AcolytePlugin(val global: Global) extends Plugin {

val tildeTerm = global.newTermName("$tilde")

@inline private def refactorMatch(orig: Match): Tree =
orig match {
case Match(t, cs) {
import global.{ Bind, Ident }

val vds = ListBuffer[ValDef]()
val cds = cs.map {
case ocd @ CaseDef(pat, g, by) {
val ocp = ocd.pos // g, by

val tx = new global.Transformer {
override def transform(tree: Tree): Tree = tree match {
case oa @ Apply(Ident(it), x) if (it == tildeTerm) {
(x.headOption, x.tail) match {
case (Some(xt @ Apply(ex, xa)), bs) {
val xpo: Option[List[Tree]] = bs.headOption match {
case Some(Apply(_, ua)) Some(ua)
case Some(bn @ Bind(_, _)) Some(bn :: Nil)
case Some(id @ Ident(_)) Some(id :: Nil)
case None Some(Nil)
case _ None
}

xpo.fold({
reporter.error(oa.pos, "Invalid ~ pattern")
//abort("Invalid ~ pattern")
oa
}) { xp
val (vd, rp) = refactorPattern(xt.pos, ex, xa, xp)
vds += vd
rp
}
}
case _
reporter.error(oa.pos, "Invalid ~ pattern")
//abort("Invalid ~ pattern")
oa

}
}
case _ super.transform(tree)
}
}
@inline private def refactorMatch(orig: Match): Tree = orig match {
case Match(t, cs) {
val vds = ListBuffer[ValDef]()
val tx = caseDefTransformer(vds)

val of = ocp.source.file
val file = new VirtualFile(of.name,
s"${of.path}#refactored-match-${ocp.line}")
val cds = cs map {
case ocd @ CaseDef(pat, g, by) {
val ocp = ocd.pos // g, by

val nc = CaseDef(tx.transform(pat), g, by)
val cdc = s"${global.show(nc)} // generated from ln ${ocp.line}, col ${ocp.column - 5}"
val cdp = ocp.withPoint(0).
withSource(new BatchSourceFile(file, cdc), 0)
val of = ocp.source.file
val file = new VirtualFile(of.name,
s"${of.path}#refactored-match-${ocp.line}")

global.atPos(cdp)(nc)
}
case cd cd
val nc = CaseDef(tx.transform(pat), g, refactor(by))
val cdc = s"${global show nc} // generated from ln ${ocp.line}, col ${ocp.column - 5}"
val cdp = ocp.withPoint(0).
withSource(new BatchSourceFile(file, cdc), 0)

global.atPos(cdp)(nc)
}
case cd cd
}

if (vds.isEmpty) Match(t, cds)
else Block(vds.toList, Match(t, cds).setPos(orig.pos))
}
case _
reporter.error(orig.pos, "Invalid Match")
//abort("Invalid Match")
orig
}

private def caseDefTransformer(vds: ListBuffer[ValDef]) =
new global.Transformer {
override def transform(tree: Tree): Tree = tree match {
case oa @ Apply(Ident(it), x) if (it == tildeTerm) {
(x.headOption, x.tail) match {
case (Some(xt @ Apply(ex, xa)), bs) {
val xpo: Option[List[Tree]] = bs.headOption match {
case Some(Apply(_, ua)) Some(ua)
case Some(bn @ Bind(_, _)) Some(bn :: Nil)
case Some(id @ Ident(_)) Some(id :: Nil)
case None Some(Nil)
case _ None
}

xpo.fold({
reporter.error(oa.pos, "Invalid ~ pattern")
//abort("Invalid ~ pattern")
oa
}) { xp
val (vd, rp) = refactorPattern(xt.pos, ex, xa, xp)
vds += vd
rp
}
}
case _
reporter.error(oa.pos, "Invalid ~ pattern")
//abort("Invalid ~ pattern")
oa

if (vds.isEmpty) orig // revert to original Match
else Block(vds.toList, Match(t, cds).setPos(orig.pos))
}
}
case _ super.transform(tree)
}
case _
reporter.error(orig.pos, "Invalid Match")
//abort("Invalid Match")
orig
}

@inline private def refactorPattern[T](xp: Position, ex: Tree, xa: List[Tree], ua: List[Tree]): (ValDef, Apply) = {
@@ -42,6 +42,44 @@ object ExtractorComponentSpec extends org.specs2.mutable.Specification
}
}

"Recursive match" >> {
"Basic Pattern matching" should {
"match extractor: Integer(n)" in {
recursivePatternMatching("456") aka "matching" mustEqual List("num-456")
}

"not match" in {
recursivePatternMatching("@") aka "matching" mustEqual Nil
}
}

"Extractor with unapply" should {
"rich match without binding: ~(IntRange(5, 10))" in {
recursivePatternMatching("7") aka "matching" mustEqual List("5-to-10")
}

"rich match without binding: ~(IntRange(10, 20), i)" in {
recursivePatternMatching("12") aka "matching" mustEqual List("range:12")
}
}

"Extractor with unapplySeq" should {
"rich match without binding: ~(Regex(re))" in {
recursivePatternMatching("abc").
aka("matching") mustEqual List("no-binding")
}

"rich match with one binding: ~(Regex(re), a)" in {
recursivePatternMatching("# BCD.") aka "matching" mustEqual List("BCD")
}

"rich match with several bindings: ~(Regex(re), (a, b))" in {
recursivePatternMatching("123;xyz") aka "matching" mustEqual List(
"123;xyz", "xyz", "123")
}
}
}

"Partial function #1" >> {
"Basic Pattern matching" should {
"match extractor: Integer(n)" in {
@@ -195,6 +233,18 @@ sealed trait MatchTest {
case str @ ~(Regex("([0-9]+);([a-z]+)"), (a, b)) List(str, b, a)
case x Nil
}

def recursivePatternMatching(s: String): List[String] = s match {
case v v match {
case ~(IntRange(5, 10), _) List("5-to-10")
case ~(IntRange(10, 20), i) List(s"range:$i")
case Integer(n) List(s"num-$n")
case ~(Regex("^a.*")) List("no-binding")
case ~(Regex("# ([A-Z]+).*"), a) List(a)
case str @ ~(Regex("([0-9]+);([a-z]+)"), (a, b)) List(str, b, a)
case x Nil
}
}
}

/**

0 comments on commit 28fff28

Please sign in to comment.
You can’t perform that action at this time.