Permalink
Browse files

Loops: filter vals to defs

  • Loading branch information...
1 parent 825e725 commit 9da5fe9c6db36663b3b3a77b185ea435ba819309 @ochafik ochafik committed Aug 19, 2013
@@ -5,6 +5,8 @@ import scalaxy.loops._
import scalaxy.beans._
import scalaxy.reified._
+import scala.language.postfixOps
+
import org.junit._
import org.junit.Assert._
@@ -140,6 +140,14 @@ package loops
(start, end, step, isInclusive)
}
}
+ def newInlineAnnotation = {
+ Apply(
+ Select(
+ New(Ident(typeOf[scala.inline].typeSymbol)),
+ nme.CONSTRUCTOR),
+ Nil)
+ }
+
c.typeCheck(c.prefix.tree) match {
case OptimizedRange(rangeTree, range) =>
if (disabled) {
@@ -169,9 +177,17 @@ package loops
// Body expects a local constant: create a var outside the loop + a val inside it.
val iVar = newIntVar(c.fresh("i"), range.start)
val iVal = newIntVal(param.name, Ident(iVar.name))
- val filterVals = range.filters.map(filter => {
- ValDef(NoMods, c.fresh("filter"): TermName, TypeTree(typeOf[Int => Boolean]), filter)
- })
+ val filterVals = range.filters.map {
+ case Function(vparams, body) =>
+ DefDef(
+ NoMods.mapAnnotations(list => newInlineAnnotation :: list),
+ c.fresh("filter"): TermName,
+ Nil,
+ List(vparams),
+ TypeTree(NoType),
+ body)
+ // ValDef(NoMods, c.fresh("filter"): TermName, TypeTree(typeOf[Int => Boolean]), filter)
+ }
val stepVal = newIntVal(c.fresh("step"), Literal(Constant(step)))
val endVal = newIntVal(c.fresh("end"), range.end)
val loopCondition =
@@ -204,7 +220,7 @@ package loops
Apply(
Select(
Ident(iVar.name),
- encode("+")
+ encode("+"): TermName
),
List(Ident(stepVal.name))
)
@@ -213,7 +229,7 @@ package loops
val iVarRef = c.Expr[Int](Ident(iVar.name))
val stepValRef = c.Expr[Int](Ident(stepVal.name))
- val loop =
+ val loop =
if (filterVals.isEmpty) {
reify {
while (loopConditionExpr.splice) {
@@ -225,7 +241,8 @@ package loops
} else {
val filterApplies: List[Tree] = filterVals.map(filterVal => {
Apply(
- Select(Ident(filterVal.name), "apply": TermName),
+ Ident(filterVal.name),
+ // Select(Ident(filterVal.name), "apply": TermName),
List(
Ident(iVar.name)
)
@@ -254,7 +271,7 @@ package loops
Block(
(iVar :: endVal :: stepVal :: filterVals) :+ loop.tree: _*)
)
- println("res = " + res)
+ // println("res = " + res)
res
case _ =>
c.error(f.tree.pos, s"Unsupported function: $f")
@@ -182,4 +182,15 @@ class LoopsTest
)
)
}
+
+ @Test
+ def simpleFilter {
+ var tot = 0
+ for (i <- 0 to 10 optimized;
+ j <- 0 to 2 optimized;
+ if i != j) {
+ tot += i * 10 + j
+ }
+ assertEquals(1650, tot)
+ }
}

0 comments on commit 9da5fe9

Please sign in to comment.