Permalink
Browse files

[scalac-plugin] Refactor recursive match:

```scala
x match {
  case y => y match {
    case ~(RichPat("x"), _) => ???
  }
}
```
  • Loading branch information...
cchantep
cchantep committed May 25, 2014
1 parent f4acc2d commit 28fff28e89368a64d4882efd7517abebc88c3edd
View
@@ -3,4 +3,5 @@ jdk:
- openjdk6
env: JAVA_OPTS="-Xms32m -Xmx128m"
scala:
- - 2.10.3
+ - 2.10.4
+ - 2.11.0
View
@@ -9,7 +9,7 @@ object Acolyte extends Build with Dependencies
aggregate(jdbcDriver, scalacPlugin, jdbcScala, studio).
settings(
organization in ThisBuild := "org.eu.acolyte",
- version in ThisBuild := "1.0.18-1",
+ version in ThisBuild := "1.0.19",
javaOptions in ThisBuild ++= Seq("-source", "1.6", "-target", "1.6"),
scalaVersion in ThisBuild := "2.10.4",
crossScalaVersions in ThisBuild := Seq("2.10.4", "2.11.0"),
@@ -39,8 +39,10 @@ class AcolytePlugin(val global: Global) extends Plugin {
reporter,
Apply,
Block,
+ Bind,
CaseDef,
DefDef,
+ Ident,
Match,
Position,
Tree,
@@ -49,7 +51,9 @@ class AcolytePlugin(val global: Global) extends Plugin {
import scala.reflect.io.VirtualFile
import scala.reflect.internal.util.BatchSourceFile
- override def transform(tree: Tree): Tree = tree match {
+ override def transform(tree: Tree): Tree = refactor(tree)
+
+ private def refactor(tree: Tree): Tree = tree match {
case m @ Match(_, _) {
val richMatch = refactorMatch(m)
@@ -63,71 +67,71 @@ class AcolytePlugin(val global: Global) extends Plugin {
val tildeTerm = global.newTermName("$tilde")
- @inline private def refactorMatch(orig: Match): Tree =
- orig match {
- case Match(t, cs) {
- import global.{ Bind, Ident }
-
- val vds = ListBuffer[ValDef]()
- val cds = cs.map {
- case ocd @ CaseDef(pat, g, by) {
- val ocp = ocd.pos // g, by
-
- val tx = new global.Transformer {
- override def transform(tree: Tree): Tree = tree match {
- case oa @ Apply(Ident(it), x) if (it == tildeTerm) {
- (x.headOption, x.tail) match {
- case (Some(xt @ Apply(ex, xa)), bs) {
- val xpo: Option[List[Tree]] = bs.headOption match {
- case Some(Apply(_, ua)) Some(ua)
- case Some(bn @ Bind(_, _)) Some(bn :: Nil)
- case Some(id @ Ident(_)) Some(id :: Nil)
- case None Some(Nil)
- case _ None
- }
-
- xpo.fold({
- reporter.error(oa.pos, "Invalid ~ pattern")
- //abort("Invalid ~ pattern")
- oa
- }) { xp
- val (vd, rp) = refactorPattern(xt.pos, ex, xa, xp)
- vds += vd
- rp
- }
- }
- case _
- reporter.error(oa.pos, "Invalid ~ pattern")
- //abort("Invalid ~ pattern")
- oa
-
- }
- }
- case _ super.transform(tree)
- }
- }
+ @inline private def refactorMatch(orig: Match): Tree = orig match {
+ case Match(t, cs) {
+ val vds = ListBuffer[ValDef]()
+ val tx = caseDefTransformer(vds)
- val of = ocp.source.file
- val file = new VirtualFile(of.name,
- s"${of.path}#refactored-match-${ocp.line}")
+ val cds = cs map {
+ case ocd @ CaseDef(pat, g, by) {
+ val ocp = ocd.pos // g, by
- val nc = CaseDef(tx.transform(pat), g, by)
- val cdc = s"${global.show(nc)} // generated from ln ${ocp.line}, col ${ocp.column - 5}"
- val cdp = ocp.withPoint(0).
- withSource(new BatchSourceFile(file, cdc), 0)
+ val of = ocp.source.file
+ val file = new VirtualFile(of.name,
+ s"${of.path}#refactored-match-${ocp.line}")
- global.atPos(cdp)(nc)
- }
- case cd cd
+ val nc = CaseDef(tx.transform(pat), g, refactor(by))
+ val cdc = s"${global show nc} // generated from ln ${ocp.line}, col ${ocp.column - 5}"
+ val cdp = ocp.withPoint(0).
+ withSource(new BatchSourceFile(file, cdc), 0)
+
+ global.atPos(cdp)(nc)
}
+ case cd cd
+ }
+
+ if (vds.isEmpty) Match(t, cds)
+ else Block(vds.toList, Match(t, cds).setPos(orig.pos))
+ }
+ case _
+ reporter.error(orig.pos, "Invalid Match")
+ //abort("Invalid Match")
+ orig
+ }
+
+ private def caseDefTransformer(vds: ListBuffer[ValDef]) =
+ new global.Transformer {
+ override def transform(tree: Tree): Tree = tree match {
+ case oa @ Apply(Ident(it), x) if (it == tildeTerm) {
+ (x.headOption, x.tail) match {
+ case (Some(xt @ Apply(ex, xa)), bs) {
+ val xpo: Option[List[Tree]] = bs.headOption match {
+ case Some(Apply(_, ua)) Some(ua)
+ case Some(bn @ Bind(_, _)) Some(bn :: Nil)
+ case Some(id @ Ident(_)) Some(id :: Nil)
+ case None Some(Nil)
+ case _ None
+ }
+
+ xpo.fold({
+ reporter.error(oa.pos, "Invalid ~ pattern")
+ //abort("Invalid ~ pattern")
+ oa
+ }) { xp
+ val (vd, rp) = refactorPattern(xt.pos, ex, xa, xp)
+ vds += vd
+ rp
+ }
+ }
+ case _
+ reporter.error(oa.pos, "Invalid ~ pattern")
+ //abort("Invalid ~ pattern")
+ oa
- if (vds.isEmpty) orig // revert to original Match
- else Block(vds.toList, Match(t, cds).setPos(orig.pos))
+ }
+ }
+ case _ super.transform(tree)
}
- case _
- reporter.error(orig.pos, "Invalid Match")
- //abort("Invalid Match")
- orig
}
@inline private def refactorPattern[T](xp: Position, ex: Tree, xa: List[Tree], ua: List[Tree]): (ValDef, Apply) = {
@@ -42,6 +42,44 @@ object ExtractorComponentSpec extends org.specs2.mutable.Specification
}
}
+ "Recursive match" >> {
+ "Basic Pattern matching" should {
+ "match extractor: Integer(n)" in {
+ recursivePatternMatching("456") aka "matching" mustEqual List("num-456")
+ }
+
+ "not match" in {
+ recursivePatternMatching("@") aka "matching" mustEqual Nil
+ }
+ }
+
+ "Extractor with unapply" should {
+ "rich match without binding: ~(IntRange(5, 10))" in {
+ recursivePatternMatching("7") aka "matching" mustEqual List("5-to-10")
+ }
+
+ "rich match without binding: ~(IntRange(10, 20), i)" in {
+ recursivePatternMatching("12") aka "matching" mustEqual List("range:12")
+ }
+ }
+
+ "Extractor with unapplySeq" should {
+ "rich match without binding: ~(Regex(re))" in {
+ recursivePatternMatching("abc").
+ aka("matching") mustEqual List("no-binding")
+ }
+
+ "rich match with one binding: ~(Regex(re), a)" in {
+ recursivePatternMatching("# BCD.") aka "matching" mustEqual List("BCD")
+ }
+
+ "rich match with several bindings: ~(Regex(re), (a, b))" in {
+ recursivePatternMatching("123;xyz") aka "matching" mustEqual List(
+ "123;xyz", "xyz", "123")
+ }
+ }
+ }
+
"Partial function #1" >> {
"Basic Pattern matching" should {
"match extractor: Integer(n)" in {
@@ -195,6 +233,18 @@ sealed trait MatchTest {
case str @ ~(Regex("([0-9]+);([a-z]+)"), (a, b)) List(str, b, a)
case x Nil
}
+
+ def recursivePatternMatching(s: String): List[String] = s match {
+ case v v match {
+ case ~(IntRange(5, 10), _) List("5-to-10")
+ case ~(IntRange(10, 20), i) List(s"range:$i")
+ case Integer(n) List(s"num-$n")
+ case ~(Regex("^a.*")) List("no-binding")
+ case ~(Regex("# ([A-Z]+).*"), a) List(a)
+ case str @ ~(Regex("([0-9]+);([a-z]+)"), (a, b)) List(str, b, a)
+ case x Nil
+ }
+ }
}
/**

0 comments on commit 28fff28

Please sign in to comment.