Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Updated syntax : added way to warn / fail build, added conditional ac…

…tions (`when` construct that matches trees)
  • Loading branch information...
commit b1891bb3286185caeace5ce13e12981eb475c9c9 1 parent 30bf227
@ochafik ochafik authored
View
40 ...alaxy/components/ReplacementsParser.scala → ...y/components/MatchActionDefinitions.scala
@@ -1,12 +1,13 @@
package scalaxy; package components
-object ReplacementDefinitions {
+object MatchActionDefinitions {
import scala.reflect.api._
import scala.reflect.runtime._
import scala.reflect.runtime.Mirror._
import definitions._
- private lazy val ReplacementClass = staticClass("scalaxy.Replacement")//definitions.getClass(newTypeName("scalaxy.Replacement"))
+ //private lazy val ReplacementClass = staticClass("scalaxy.Replacement")//definitions.getClass(newTypeName("scalaxy.Replacement"))
+ private lazy val MatchActionClass = staticClass("scalaxy.MatchAction")
private lazy val defaultValues = Map(
IntClass.tpe -> 0,
@@ -34,7 +35,7 @@ object ReplacementDefinitions {
invoke(target, method)(params)
}
}
- def getReplacementDefinitions(holder: AnyRef): Seq[(String, Replacement)] = {
+ def getMatchActionDefinitions(holder: AnyRef): Seq[(String, MatchAction[Any])] = {
//val holder = Example
//val holderSym = staticClass("scalaxy.Example")
//val holder = companionInstance(holderSym)
@@ -51,13 +52,22 @@ object ReplacementDefinitions {
def unapply(t: Type): Option[Type] = t match {
case MethodType(_, r) =>
unapply(r)
+ case PolyType(_, mt) =>
+ unapply(mt)
case _ =>
Some(t)
}
}
- holderType.members.filter(_.isMethod).map(m => (m, m.tpe)).collect {
- case (m, PolyType(paramsyms, mt @ FollowMethodResult(result)))
- if result == ReplacementClass.tpe =>
+ val methods = holderType.members.filter(_.isMethod)
+ //println("Scanning holder " + holder + " : " + methods.size + " methods")
+ //for (m <- methods) println("\t" + m.name)
+
+ methods.map(m => (m, m.tpe)).collect({
+ //case (m, PolyType(paramsyms, mt @ FollowMethodResult(result)))
+ case (m, mt @ FollowMethodResult(result))
+ =>
+ // TODO
+ //if result.stat_<:<(MatchActionClass.tpe) => // == ReplacementClass.tpe =>
def getParamTypes(t: Type): List[Type] = t match {
case MethodType(tt, r) =>
tt.map(_.tpe) ++ getParamTypes(r)
@@ -68,9 +78,19 @@ object ReplacementDefinitions {
val defaultParams =
actualParamTypes.map(getDefaultValue(_).asInstanceOf[AnyRef])
- val r = invokeMethod(holder, m, defaultParams)
- //println("r = " + r)
- (holder + "." + m.name, r.asInstanceOf[Replacement])
- }
+ try {
+ val r = invokeMethod(holder, m, defaultParams)
+ //println("r = " + r + " : " + r.getClass.getName)
+ if (r.isInstanceOf[MatchAction[_]])
+ Some((holder + "." + m.name, r.asInstanceOf[MatchAction[Any]]))
+ else
+ None
+ } catch { case ex =>
+ None
+ }
+ case (m, t) =>
+ println("Unable to parse method '" + m.name + "' : " + dbgStr(t))
+ None
+ }).flatten
}
}
View
64 ...xy/components/ReplacementsComponent.scala → ...xy/components/MatchActionsComponent.scala
@@ -15,7 +15,7 @@ import scala.reflect._
//import scala.tools.nsc.typechecker.Contexts._
-object ReplacementsComponent {
+object MatchActionsComponent {
val runsAfter = List[String](
"typer"
)
@@ -25,7 +25,7 @@ object ReplacementsComponent {
val phaseName = "scalaxy-rewriter"
}
-class ReplacementsComponent(val global: Global, val options: PluginOptions, val replacementHolders: AnyRef*)
+class MatchActionsComponent(val global: Global, val options: PluginOptions, val matchActionHolders: AnyRef*)
extends PluginComponent
with Transform
with TypingTransformers
@@ -43,21 +43,21 @@ extends PluginComponent
import typer.typed
import analyzer.{SearchResult, ImplicitSearch, UnTyper}
- override val runsAfter = ReplacementsComponent.runsAfter
- override val runsBefore = ReplacementsComponent.runsBefore
- override val phaseName = ReplacementsComponent.phaseName
+ override val runsAfter = MatchActionsComponent.runsAfter
+ override val runsBefore = MatchActionsComponent.runsBefore
+ override val phaseName = MatchActionsComponent.phaseName
- import ReplacementDefinitions._
+ import MatchActionDefinitions._
- case class ConvertedReplacement(pattern: Tree, replacement: Bindings => Tree)
+ case class ConvertedMatchAction(pattern: Tree, matchAction: MatchAction[Any])
- val replacements = replacementHolders.filter(_ != null).flatMap(getReplacementDefinitions(_)).map {
- case (n, r) =>
- val conv = mirrorToGlobal(r.pattern, EmptyBindings)
- println("Registered replacement '" + n + "'")
- (n, ConvertedReplacement(conv, bindings => {
+ val matchActions = matchActionHolders.filter(_ != null).flatMap(getMatchActionDefinitions(_)).map {
+ case (n, m) =>
+ val conv = mirrorToGlobal(m.pattern, EmptyBindings)
+ println("Registered match action '" + n + "'")// = " + m)
+ (n, ConvertedMatchAction(conv, m))/*bindings => {
mirrorToGlobal(r.replacement, bindings)
- }))
+ }))*/
}
def newTransformer(unit: CompilationUnit) = new TypingTransformer(unit) {
@@ -65,14 +65,42 @@ extends PluginComponent
val sup = super.transform(tree)
var expanded = sup
- for ((n, r) <- replacements) {
+ for ((n, convertedMatchAction) <- matchActions) {
try {
- val bindings @ Bindings(nameBindings, typeBindings) = matchAndResolveBindings(r.pattern, expanded)
+ val bindings @ Bindings(nameBindings, typeBindings) =
+ matchAndResolveBindings(convertedMatchAction.pattern, expanded)
+
println("Bindings for '" + n + "':\n\t" + (nameBindings ++ typeBindings).mkString("\n\t"))
- val replacement = r.replacement(bindings)
- println("Replacement '" + n + "':\n\t" + replacement.toString.replaceAll("\n", "\n\t"))
- expanded = replacement
+ convertedMatchAction.matchAction match {
+ case r: Replacement[_] =>
+ val replacement = mirrorToGlobal(r.replacement, bindings)
+ println("Replacement '" + n + "':\n\t" + replacement.toString.replaceAll("\n", "\n\t"))
+ expanded = replacement
+ case MatchWarning(_, message) =>
+ unit.warning(tree.pos, message)
+ case MatchError(_, message) =>
+ unit.error(tree.pos, message)
+ case ConditionalAction(_, when, then) =>
+ val treesToTest: Seq[mirror.Tree] =
+ when.map(n => {
+ globalToMirror(nameBindings(global.newTermName(n)))
+ })
+
+ if (then.isDefinedAt(treesToTest)) {
+ then.apply(treesToTest) match {
+ case r: ReplaceBy[_] =>
+ val replacement = mirrorToGlobal(r.replacement, bindings)
+ println("Replace by '" + n + "':\n\t" + replacement.toString.replaceAll("\n", "\n\t"))
+ expanded = replacement
+ case Warning(message) =>
+ unit.warning(tree.pos, message)
+ case Error(message) =>
+ unit.error(tree.pos, message)
+ case null =>
+ }
+ }
+ }
} catch {
case NoTypeMatchException(expected, found, msg) =>
case NoTreeMatchException(expected, found, msg) =>
View
36 Core/src/main/scala/scalaxy/components/MirrorConversions.scala
@@ -13,10 +13,10 @@ extends Replacements
import global.definitions._
- def newImporter(bindings: Bindings) = {
+
+ def newMirrorToGlobalImporter(bindings: Bindings) = {
new global.Importer {
val from = mirror.asInstanceOf[scala.reflect.internal.SymbolTable]
-
override def importTree(tree: from.Tree): global.Tree = {
tree match {
case from.Ident(n) =>
@@ -46,13 +46,31 @@ extends Replacements
}
}
}
+ def newGlobalToMirrorImporter = {
+ val mm = mirror.asInstanceOf[scala.reflect.internal.SymbolTable]
+ new mm.Importer {
+ val from = global
+ //val from = global.asInstanceOf[scala.reflect.internal.SymbolTable]
+ override def importTree(tree: from.Tree): mm.Tree = {
+ tree match {
+ case from.Ident(n) =>
+ val in = importName(n)
+ val imp = mm.Ident(in)
+ imp.tpe = importType(tree.tpe)
+ imp
+ case _ =>
+ super.importTree(tree)
+ }
+ }
+ }
+ }
/**
* TODO report missing API : scala.reflect.api.SymbolTable
* (scala.reflect.mirror does not extend scala.reflect.internal.SymbolTable publicly !)
*/
def mirrorToGlobal(m: mirror.Tree, bindings: Bindings): global.Tree = {
- val importer = newImporter(bindings)
+ val importer = newMirrorToGlobalImporter(bindings)
new mirror.Traverser {
override def traverse(t: mirror.Tree) = {
val tpe = t.tpe
@@ -69,10 +87,20 @@ extends Replacements
}
implicit def mirrorToGlobal(m: mirror.Name, bindings: Bindings): global.Name = {
- val importer = newImporter(bindings)
+ val importer = newMirrorToGlobalImporter(bindings)
importer.importName(m.asInstanceOf[importer.from.Name])
}
+ def globalToMirror(t: global.Name): mirror.Name = {
+ val importer = newGlobalToMirrorImporter
+ importer.importName(t.asInstanceOf[importer.from.Name]).asInstanceOf[mirror.Name]
+ }
+
+ def globalToMirror(t: global.Tree): mirror.Tree = {
+ val importer = newGlobalToMirrorImporter
+ importer.importTree(t.asInstanceOf[importer.from.Tree]).asInstanceOf[mirror.Tree]
+ }
+
/*
def mirrorNodeToString(tree: mirror.Tree) = {
new mirror.Traverser {
View
7 Core/src/main/scala/scalaxy/components/Replacements.scala
@@ -160,13 +160,13 @@ extends TypingTransformers
case (global.Ident(n), _) =>
if (internalDefs.contains(n))
EmptyBindings
- else tree match {
+ else /*tree match {
case global.Ident(nn) if n.toString == nn.toString =>
EmptyBindings
- case _ =>
+ case _ =>*/
//println("GOT BINDING " + pattern + " -> " + tree + " (tree is " + tree.getClass.getName + ")")
Bindings(Map(n -> tree), Map())
- }
+ //}
case (global.ValDef(mods, name, tpt, rhs), global.ValDef(mods2, name2, tpt2, rhs2))
if mods.modifiers == mods2.modifiers =>
@@ -189,6 +189,7 @@ extends TypingTransformers
matchAndResolveTreeBindings((v, v2) :: l.zip(l2), depth + 1)(internalDefs ++ getNamesDefinedIn(l))
case (global.Select(a, n), global.Select(a2, n2)) if n == n2 =>
+ //println("Matched select " + a + " vs. " + a2)
matchAndResolveTreeBindings(a, a2, depth + 1)
// TODO
View
45 IDEAS
@@ -0,0 +1,45 @@
+- //inside[_ <: javax.swing.JComponent].
+ // BOF : insidePackage("scala.lang").
+
+- activer remplacements dans un scope
+
+ replacing(..., ..., ...) {
+
+ }
+
+- activer ou desactiver
+
+- ajouter dependence des rewrites dans sbt
+
+- plugin sbt pour generer remplacement pour sa librairie et les publier (replacements.properties...)
+
+- Refactoring : hyper utile pour migrer � nouvelle version code
+ -> ...
+
+- object MesExemples extends Rock {
+
+ def enabled(compilerContext) = {
+ if (compC.scalaVersion < ...)
+ false
+
+ }
+
+ def context1(....) =
+ replace(patt, rep)
+
+ def context1(...) =
+ warn(msg) { patt }
+
+ def context1(...) =
+ fail(msg) { patt }
+
+ def context1(...) =
+ when(patt)(id1, id2...) {
+ case ...: Tree =>
+ replacement(rep)
+ case ... =>
+ error(rep)
+ case ... =>
+ warning(rep)
+ }
+}
View
81 Macros/src/main/scala/scalaxy/Macros.scala
@@ -1,28 +1,81 @@
package scalaxy
import scala.reflect._
-class Replacement(val pattern: mirror.Tree, val replacement: mirror.Tree) {
- override def toString =
- "Replacement(" + pattern + ", " + replacement + ")"
+
+sealed trait Action[T]
+trait MatchAction[T] extends Action[T] {
+ def pattern: mirror.Tree
}
-//sealed trait AnalysisResult
-//case class Warning(pos: mirror.Position, msg: String) extends AnalysisResult
-//case class Error(pos: mirror.Position, msg: String) extends AnalysisResult
+case class ReplaceBy[T](replacement: mirror.Tree) extends Action[T]
+case class Error[T](message: String) extends Action[T]
+case class Warning[T](message: String) extends Action[T]
+
+case class Replacement[T](
+ pattern: mirror.Tree,
+ replacement: mirror.Tree
+) extends MatchAction[T]
+case class MatchError[T](pattern: mirror.Tree, message: String) extends MatchAction[T]
+case class MatchWarning[T](pattern: mirror.Tree, message: String) extends MatchAction[T]
+
+case class ConditionalAction[T](
+ pattern: mirror.Tree,
+ when: Seq[String],
+ then: PartialFunction[Seq[mirror.Tree], Action[T]]
+) extends MatchAction[T]
-//class Analysis(val pattern: mirror.Tree, f: )
+package object macros {
-object Macros {
- implicit def tree2pos(tree: mirror.Tree) =
- tree.pos
-
- def macro Replacement[T](pattern: T, replacement: T): Replacement = {
+ def macro fail[T](message: String)(pattern: T): MatchAction[T] = {
+ New(
+ Select(Ident(newTermName("scalaxy")), newTypeName("MatchError")),
+ List(List(reify(pattern), message))
+ )
+ }
+
+ def macro warn[T](message: String)(pattern: T): MatchAction[T] = {
+ New(
+ Select(Ident(newTermName("scalaxy")), newTypeName("MatchWarning")),
+ List(List(reify(pattern), message))
+ )
+ }
+
+ def macro replace[T](pattern: T, replacement: T): Replacement[T] = {
New(
Select(Ident(newTermName("scalaxy")), newTypeName("Replacement")),
List(List(reify(pattern), reify(replacement)))
)
}
- def macro tree(v: Any): mirror.Tree =
- reify(v)
+ def macro when[T](pattern: T)(identifiers: Any*)(then: PartialFunction[Seq[mirror.Tree], Action[T]])
+ : ConditionalAction[T] =
+ {
+ val scalaCollection =
+ Select(Ident(newTermName("scala")), newTermName("collection"))
+
+ New(
+ Select(Ident(newTermName("scalaxy")), newTypeName("ConditionalAction")),
+ List(List(
+ reify(pattern),
+ Apply(
+ Select(Select(scalaCollection, newTermName("Seq")), newTermName("apply")),
+ identifiers.toList.map { case Ident(n) => Literal(Constant(n.toString)) }
+ ),
+ then
+ ))
+ )
+ }
+
+ def error[T](message: String): Action[T] =
+ Error[T](message)
+
+ def warning[T](message: String): Action[T] =
+ Warning[T](message)
+
+ def macro replacement[T](replacement: T): ReplaceBy[T] = {
+ New(
+ Select(Ident(newTermName("scalaxy")), newTypeName("ReplaceBy")),
+ List(List(reify(replacement)))
+ )
+ }
}
View
38 Macros/src/main/scala/scalaxy/Matchers.scala
@@ -0,0 +1,38 @@
+package scalaxy.matchers
+
+import scala.reflect.mirror._
+
+object PositiveConstant {
+ def unapply(tree: Tree): Option[Int] =
+ Option(tree) collect {
+ case Literal(Constant(v: Int)) if v > 0 =>
+ v
+ }
+}
+object NegativeConstant {
+ def unapply(tree: Tree): Option[Int] =
+ Option(tree) collect {
+ case Literal(Constant(v: Int)) if v < 0 =>
+ v
+ }
+}
+
+object True {
+ def unapply(tree: Tree): Boolean =
+ tree match {
+ case Literal(Constant(true)) =>
+ true
+ case _ =>
+ false
+ }
+}
+
+object False {
+ def unapply(tree: Tree): Boolean =
+ tree match {
+ case Literal(Constant(false)) =>
+ true
+ case _ =>
+ false
+ }
+}
View
4 Rewrites/Example.scala
@@ -1,10 +1,10 @@
package scalaxy; package rewrites
-import Macros._
+import macros._
object Example {
- def intToStringQuoter[U](i: Int) = Replacement(
+ def intToStringQuoter[U](i: Int) = replace(
i.toString,
"'" + i.toString + "'"
)
View
44 Rewrites/ForLoops.scala
@@ -1,10 +1,13 @@
package scalaxy; package rewrites
-import Macros._
+import macros._
+import matchers._
+//import scala.reflect.mirror._
object ForLoops {
- def simpleForeachUntil[U](start: Int, end: Int, body: U) = Replacement(
- for (i <- start until end) body,
+ def simpleForeachUntil[U](start: Int, end: Int, body: U) = replace(
+ for (i <- start until end)
+ body,
{
var ii = start
while (ii < end) {
@@ -14,7 +17,7 @@ object ForLoops {
}
}
)
- def simpleForeachTo[U](start: Int, end: Int, body: U) = Replacement(
+ def simpleForeachTo[U](start: Int, end: Int, body: U) = replace(
for (i <- start to end) body,
{
var ii = start
@@ -25,8 +28,37 @@ object ForLoops {
}
}
)
+
+ def rgForeachUntilBy[U](start: Int, end: Int, step: Int, body: U) =
+ when(
+ for (i <- start until end by step)
+ body
+ )(
+ step
+ ) {
+ case Seq(PositiveConstant(_)) =>
+ replacement {
+ var ii = start
+ while (ii < end) {
+ val i = ii
+ body
+ ii = ii + step
+ }
+ }
+ case Seq(NegativeConstant(_)) =>
+ replacement {
+ var ii = start
+ while (ii > end) {
+ val i = ii
+ body
+ ii = ii - step
+ }
+ }
+ case _ =>
+ warning("Cannot optimize : step is not constant")
+ }
/*
- def simpleForeachUntil[U](start: Int, end: Int, body: Int => U) = Replacement(
+ def simpleForeachUntil[U](start: Int, end: Int, body: Int => U) = replace(
for (i <- start until end) body(i),
{
var ii = start
@@ -37,7 +69,7 @@ object ForLoops {
}
}
)
- def simpleForeachTo[U](start: Int, end: Int, body: Int => U) = Replacement(
+ def simpleForeachTo[U](start: Int, end: Int, body: Int => U) = replace(
for (i <- start to end) body(i),
{
var ii = start
View
29 Rewrites/Java.scala
@@ -0,0 +1,29 @@
+package scalaxy.rewrites
+
+object Java {
+ import scalaxy._
+ import matchers._
+ import macros._
+
+
+ def warnAccessibleField(f: java.lang.reflect.Field, b: Boolean) =
+ when(f.setAccessible(b))(f, b) {
+ case Seq(_, True()) =>
+ warning("You shouldn't do that")
+ case r =>
+ println("Failed to match case in warnAccessibleField : " + r)
+ null
+ }
+
+ def forbidThreadStop(t: Thread) =
+ fail("You must NOT call Thread.stop() !") {
+ t.stop
+ }
+
+ //def replaceAccessibleField(f: java.lang.reflect.Field, b: Boolean) =
+ // replace(f.setAccessible(b), f.setAccessible(false))
+
+
+
+ println("Auto warn " + warnAccessibleField(null, false))
+}
View
20 Rewrites/Numeric.scala
@@ -1,47 +1,47 @@
package scalaxy; package rewrites
-import Macros._
+import macros._
object Numeric {
import math.Numeric.Implicits._
import Ordering.Implicits._
- def plus[T](a: T, b: T)(implicit n: Numeric[T]) = Replacement(
- a + b,
+ def plus[T](a: T, b: T)(implicit n: Numeric[T]) = replace(
+ a + b, // Numeric.Implicits.infixNumericOps[T](a)(n).+(b)
n.plus(a, b)
)
- def minus[T](a: T, b: T)(implicit n: Numeric[T]) = Replacement(
+ def minus[T](a: T, b: T)(implicit n: Numeric[T]) = replace(
a - b,
n.minus(a, b)
)
- def times[T](a: T, b: T)(implicit n: Numeric[T]) = Replacement(
+ def times[T](a: T, b: T)(implicit n: Numeric[T]) = replace(
a * b,
n.times(a, b)
)
- def negate[T](a: T)(implicit n: Numeric[T]) = Replacement(
+ def negate[T](a: T)(implicit n: Numeric[T]) = replace(
- a,
n.negate(a)
)
- def gt[T](a: T, b: T)(implicit n: Numeric[T]) = Replacement(
+ def gt[T](a: T, b: T)(implicit n: Numeric[T]) = replace(
a > b,
n.gt(a, b)
)
- def gteq[T](a: T, b: T)(implicit n: Numeric[T]) = Replacement(
+ def gteq[T](a: T, b: T)(implicit n: Numeric[T]) = replace(
a >= b,
n.gteq(a, b)
)
- def lt[T](a: T, b: T)(implicit n: Numeric[T]) = Replacement(
+ def lt[T](a: T, b: T)(implicit n: Numeric[T]) = replace(
a < b,
n.lt(a, b)
)
- def lteq[T](a: T, b: T)(implicit n: Numeric[T]) = Replacement(
+ def lteq[T](a: T, b: T)(implicit n: Numeric[T]) = replace(
a <= b,
n.lteq(a, b)
)
View
4 Rewrites/Streams.scala
@@ -1,10 +1,10 @@
package scalaxy; package rewrites
-import Macros._
+import macros._
object Streams {
// TODO add conditions macro + isSideEffectFree(f)
- def mapMap[A, B, C](col: Seq[A], f: A => B, g: B => C) = Replacement(
+ def mapMap[A, B, C](col: Seq[A], f: A => B, g: B => C) = replace(
col.map(f).map(g),
col.map(a => {
val b = f(a)
View
25 Test/test.scala
@@ -1,6 +1,24 @@
object RunMe extends App {
//class RunMe {
+ //import scalaxy._//macros._
+ //println(tree(Seq(1, 2, 3)))
+
+ {
+ val th = new Thread(new Runnable { override def run = println("...") })
+ th.start
+ th.stop
+ }
+
+ {
+ class C {
+ private val i = 10
+ }
+ val f = classOf[C].getField("i")
+ f.setAccessible(true)
+ }
+
+ /*
{
import math.Numeric.Implicits._
import Ordering.Implicits._
@@ -32,6 +50,10 @@ object RunMe extends App {
for (i <- 1 until 10)
println("i = " + i + " // v = " + v)
}
+ println(trans(Seq(1, 2, 3)))
+ println(trans(Seq(2, 3, 4), 10))
+ */
+
/*def transManual(col: Seq[Int]) = {
col.map(a => {
val b = ((a:Int) => a + 1)(a)
@@ -46,7 +68,4 @@ object RunMe extends App {
run
*/
- println(trans(Seq(1, 2, 3)))
- println(trans(Seq(2, 3, 4), 10))
-
}
View
7 project/ScalaxyBuild.scala
@@ -12,7 +12,8 @@ object ScalaxyBuild extends Build
infoSettings ++
compilationSettings ++
mavenSettings ++
- scalaSettings
+ scalaSettings ++
+ commonDepsSettings
lazy val infoSettings = Seq(
organization := "com.nativelibs4java",
@@ -43,6 +44,10 @@ object ScalaxyBuild extends Build
resolvers += Resolver.sonatypeRepo("snapshots")
//exportJars := true, // use jars in classpath
)
+ lazy val commonDepsSettings = Seq(
+ libraryDependencies <+= scalaVersion(v => "org.scalatest" % ("scalatest_2.9.1"/* + v*/) % "1.7.1" % "test")
+ //libraryDependencies += "org.scala-tools.testing" %% "scalacheck" % "1.9" % "test"
+ )
lazy val scalaxy =
Project(id = "scalaxy", base = file("."), settings = standardSettings ++ Seq(
View
3  src/main/scala/scalaxy/plugin/ScalaxyPlugin.scala
@@ -60,9 +60,10 @@ object ScalaxyPluginDef extends PluginDef {
override def createComponents(global: Global, options: PluginOptions): List[PluginComponent] =
List(
- new ReplacementsComponent(global, options,
+ new MatchActionsComponent(global, options,
//rewrites.Example,
//rewrites.Streams,
+ rewrites.Java,
rewrites.Numeric,
rewrites.ForLoops
)
View
475 src/test/scala/scalaxy/BaseTestUtils.scala
@@ -0,0 +1,475 @@
+package scalaxy ; package test
+//import plugin._
+import pluginBase._
+
+import java.io.BufferedReader
+import java.io.ByteArrayInputStream
+import java.io.ByteArrayOutputStream
+import java.io.File
+import java.io.IOException
+import java.io.InputStreamReader
+import java.io.OutputStreamWriter
+import java.io.PrintWriter
+import scala.collection.mutable.HashMap
+import scala.concurrent.ops
+import scala.io.Source
+
+import java.net.URI
+import java.net.URLClassLoader
+import javax.tools.DiagnosticCollector
+import javax.tools.FileObject
+import javax.tools.ForwardingJavaFileManager
+import javax.tools.JavaCompiler
+import javax.tools.JavaFileManager
+import javax.tools.JavaFileObject
+import javax.tools.ToolProvider
+import org.junit.Assert._
+import scala.tools.nsc.Settings
+import scala.actors.Futures._
+import Function.{tupled, untupled}
+
+object Results {
+ import java.io._
+ import java.util.Properties
+ def getPropertiesFileName(n: String) = n + ".perf.properties"
+ val logs = new scala.collection.mutable.HashMap[String, (String, PrintStream, Properties)]
+ def getLog(key: String) = {
+ logs.getOrElseUpdate(key, {
+ val logName = getPropertiesFileName(key)
+ //println("Opening performance log file : " + logName)
+
+ val logRes = getClass.getClassLoader.getResourceAsStream(logName)
+ val properties = new java.util.Properties
+ if (logRes != null) {
+ println("Reading " + logName)
+ properties.load(logRes)
+ }
+ (logName, new PrintStream(logName), properties)
+ })
+ }
+ Runtime.getRuntime.addShutdownHook(new Thread { override def run {
+ for ((_, (logName, out, _)) <- logs) {
+ println("Wrote " + logName)
+ out.close
+ }
+ }})
+}
+object BaseTestUtils {
+ private var _nextId = 1
+ def nextId = BaseTestUtils synchronized {
+ val id = _nextId
+ _nextId += 1
+ id
+ }
+}
+trait BaseTestUtils {
+ import BaseTestUtils._
+
+ implicit val baseOutDir = new File("target/testSnippetsClasses")
+ baseOutDir.mkdirs
+
+ def pluginDef: PluginDef
+
+ object SharedCompilerWithPlugins extends SharedCompiler(true, pluginDef)
+ object SharedCompilerWithoutPlugins extends SharedCompiler(false, pluginDef)
+
+ lazy val options: PluginOptions = {
+ val o = pluginDef.createOptions(null)
+ o.test = true
+ o
+ }
+ /*def compile(src: String, outDir: String) = {
+ outDir.mkdirs
+
+ val srcFile = File.createTempFile("temp", ".scala")
+ val out = new PrintWriter(srcFile)
+ out.println(src)
+ out.close
+ //srcFile.delete
+
+ }*/
+ def getSnippetBytecode(className: String, source: String, subDir: String, compiler: SharedCompiler) = {
+ val src = "class " + className + " { def invoke(): Unit = {\n" + source + "\n}}"
+ val outDir = new File(baseOutDir, subDir)
+ outDir.mkdirs
+ val srcFile = new File(outDir, className + ".scala")
+ val out = new PrintWriter(srcFile)
+ out.println(src)
+ out.close
+ new File(outDir, className + ".class").delete
+
+ compiler.compile(
+ Array(
+ "-d",
+ outDir.getAbsolutePath,
+ srcFile.getAbsolutePath
+ ) ++
+ getAdditionalClassPath.map(
+ cp => Seq("-cp", cp.mkString(File.pathSeparator))
+ ).getOrElse(Seq())
+ )
+
+ val f = new File(outDir, className + ".class")
+ if (!f.exists())
+ throw new RuntimeException("Class file " + f + " not found !")
+
+
+ val byteCodeSource = getClassByteCode(className, outDir.getAbsolutePath)
+ val byteCode = byteCodeSource.mkString//("\n")
+ /*
+ println("COMPILED :")
+ println("\t" + source.replaceAll("\n", "\n\t"))
+ println("BYTECODE :")
+ println("\t" + byteCode.replaceAll("\n", "\n\t"))
+ */
+
+ byteCode.
+ replaceAll("scala/reflect/ClassManifest", "scala/reflect/Manifest").
+ replaceAll("#\\d+", "")
+ }
+
+ def ensurePluginCompilesSnippet(source: String) = {
+ val (_, testMethodName) = testClassInfo
+ assertNotNull(getSnippetBytecode(testMethodName, source, "temp", SharedCompilerWithPlugins))
+ }
+ def ensurePluginCompilesSnippetsToSameByteCode(sourcesAndReferences: Traversable[(String, String)]): Unit = {
+ def flatten(s: Traversable[String]) = s.map("{\n" + _ + "\n};").mkString("\n")
+ ensurePluginCompilesSnippetsToSameByteCode(flatten(sourcesAndReferences.map(_._1)), flatten(sourcesAndReferences.map(_._2)))
+ }
+ def ensurePluginCompilesSnippetsToSameByteCode(source: String, reference: String, allowSameResult: Boolean = false) = {
+ val (_, testMethodName) = testClassInfo
+
+ import scala.concurrent.ops._
+ implicit val runner = new scala.concurrent.ThreadRunner
+
+ /*
+ val expectedFut = future { getSnippetBytecode(className, reference, "expected", SharedCompilerWithoutPlugins1) }
+ val withoutPluginFut = future { getSnippetBytecode(className, source, "withoutPlugin", SharedCompilerWithoutPlugins2) }
+ val withPluginFut = future { getSnippetBytecode(className, source, "withPlugin", SharedCompilerWithPlugins) }//ScalaxyTestUtils.compilerWithPlugin) }
+ val (expected, withoutPlugin, withPlugin) = (expectedFut(), withoutPluginFut(), withPluginFut())
+ */
+ val enableFuture = true
+
+ def futEx[V](b: => V): () => V = if (!enableFuture) () => b else {
+ val f = future { try { Right(b) } catch { case ex => Left(ex) } }
+ () => f() match {
+ case Left(ex) =>
+ ex.printStackTrace
+ assertTrue(ex.toString, false)
+ error("")
+ case Right(v) =>
+ v
+ }
+ }
+
+ val withPluginFut = futEx { getSnippetBytecode(testMethodName, source, "withPlugin", SharedCompilerWithPlugins) }
+ val expected = getSnippetBytecode(testMethodName, reference, "expected", SharedCompilerWithoutPlugins)
+ val withoutPlugin = if (allowSameResult) null else getSnippetBytecode(testMethodName, source, "withoutPlugin", SharedCompilerWithoutPlugins)
+ val withPlugin = withPluginFut()
+
+ if (!allowSameResult)
+ assertTrue("Expected result already found without any plugin !!! (was the Scala compiler improved ?)", expected != withoutPlugin)
+
+ if (expected != withPlugin) {
+ def trans(tit: String, s: String) =
+ println(tit + " :\n\t" + s.replaceAll("\n", "\n\t"))
+
+ trans("EXPECTED", expected)
+ trans("FOUND", withPlugin)
+
+ assertEquals(expected, withPlugin)
+ }
+ }
+ def getClassByteCode(className: String, classpath: String) = {
+ val args = Array("-c", "-classpath", classpath, className)
+ val p = Runtime.getRuntime.exec("javap " + args.mkString(" "))//"javap", args)
+
+ var err = new StringBuffer
+ ops.spawn {
+ import scala.util.control.Exception._
+ val inputStream = new BufferedReader(new InputStreamReader(p.getErrorStream))
+ var str: String = null
+ //ignoring(classOf[IOException]) {
+ while ({ str = inputStream.readLine; str != null }) {
+ //err.synchronized {
+ println(str)
+ err.append(str).append("\n")
+ //}
+ //}
+ }
+ }
+
+ val out = Source.fromInputStream(p.getInputStream).toList
+ if (p.waitFor != 0) {
+ Thread.sleep(100)
+ error("javap (args = " + args.mkString(" ") + ") failed with :\n" + err.synchronized { err.toString } + "\nAnd :\n" + out)
+ }
+ out
+ }
+
+ import java.io.File
+ /*val outputDirectory = {
+ val f = new File(".")//target/classes")
+ if (!f.exists)
+ f.mkdirs
+ f
+ }*/
+
+ import java.io._
+
+ val packageName = "tests"
+
+ case class Res(withPlugin: Boolean, output: AnyRef, time: Double)
+ type TesterGen = Int => (Boolean => Res)
+
+ def fail(msg: String) = {
+ println(msg)
+ println()
+ assertTrue(msg, false)
+ }
+
+ trait RunnableMethod {
+ def apply(args: Any*): Any
+ }
+ abstract class RunnableCode(val pluginOptions: PluginOptions) {
+ def newInstance(constructorArgs: Any*): RunnableMethod
+ }
+
+ protected def compileCodeWithPlugin(decls: String, code: String) =
+ compileCode(withPlugin = true, code, "", decls, "")
+
+ def getAdditionalClassPath: Option[Seq[String]] =
+ None
+
+ protected def compileCode(withPlugin: Boolean, code: String, constructorArgsDecls: String = "", decls: String = "", methodArgsDecls: String = ""): RunnableCode = {
+ val (testClassName, testMethodName) = testClassInfo
+
+ val suffixPlugin = (if (withPlugin) "Optimized" else "Normal")
+ val className = "Test_" + testMethodName + "_" + suffixPlugin + "_" + nextId
+ val src = "package " + packageName + "\nclass " + className + "(" + constructorArgsDecls + """) {
+ """ + (if (decls == null) "" else decls) + """
+ def """ + testMethodName + "(" + methodArgsDecls + ")" + """ = {
+ """ + code + """
+ }
+ }"""
+
+ val outputDirectory = new File("tmpTestClasses" + suffixPlugin)
+ def del(dir: File): Unit = {
+ val fs = dir.listFiles
+ if (fs != null)
+ fs foreach del
+
+ dir.delete
+ }
+
+ del(outputDirectory)
+ outputDirectory.mkdirs
+
+ var tmpFile = new File(outputDirectory, testMethodName + ".scala")
+ val pout = new PrintStream(tmpFile)
+ pout.println(src)
+ //println("Source = \n\t" + src.replaceAll("\n", "\n\t"))
+ pout.close
+ //println(src)
+ val compileArgs = Array(
+ "-d",
+ outputDirectory.getAbsolutePath,
+ tmpFile.getAbsolutePath
+ ) ++ getAdditionalClassPath.map(
+ cp => Seq("-cp", cp.mkString(File.pathSeparator))
+ ).getOrElse(Seq())
+
+ //println("Compiling '" + tmpFile.getAbsolutePath + "' with args '" + compileArgs.mkString(" ") +"'")
+ val pluginOptions = (
+ if (withPlugin)
+ SharedCompilerWithPlugins
+ else
+ SharedCompilerWithoutPlugins
+ ).compile(compileArgs)
+
+ //println("CLASS LOADER WITH PATH = '" + outputDirectory + "'")
+ val loader = new URLClassLoader(Array(
+ outputDirectory.toURI.toURL,
+ new File(CompilerMain.bootClassPath).toURI.toURL
+ ))
+
+ val parent =
+ if (packageName == "")
+ outputDirectory
+ else
+ new File(outputDirectory, packageName.replace('.', File.separatorChar))
+
+ val f = new File(parent, className + ".class")
+ if (!f.exists())
+ throw new RuntimeException("Class file " + f + " not found !")
+
+ //compileFile(tmpFile, withPlugin, outputDirectory)
+
+ val testClass = loader.loadClass(packageName + "." + className)
+ val testMethod = testClass.getMethod(testMethodName)//, classOf[Int])
+ val testConstructor = testClass.getConstructors.first
+
+ new RunnableCode(pluginOptions) {
+ override def newInstance(constructorArgs: Any*) = new RunnableMethod {
+ val inst =
+ testConstructor.newInstance(constructorArgs.map(_.asInstanceOf[AnyRef]):_*).asInstanceOf[AnyRef]
+
+ assert(inst != null)
+
+ override def apply(args: Any*): Any = {
+ try {
+ testMethod.invoke(inst, args.map(_.asInstanceOf[AnyRef]):_*)
+ } catch { case ex =>
+ ex.printStackTrace
+ throw ex
+ }
+ }
+ }
+ }
+ }
+
+
+ private def getTesterGen(withPlugin: Boolean, decls: String, code: String) = {
+ val runnableCode = compileCode(withPlugin, code, "n: Int", decls, "")
+
+ (n: Int) => {
+ val i = runnableCode.newInstance(n)
+ (isWarmup: Boolean) => {
+ if (isWarmup) {
+ i()
+ null
+ } else {
+ System.gc
+ Thread.sleep(50)
+ val start = System.nanoTime
+ val o = i().asInstanceOf[AnyRef]
+ val time: Double = System.nanoTime - start
+ Res(withPlugin, o, time)
+ }
+ }
+ }
+ }
+ def testClassInfo = {
+ val testTrace = new RuntimeException().getStackTrace.filter(se => se.getClassName.endsWith("Test")).last
+ val testClassName = testTrace.getClassName
+ val methodName = testTrace.getMethodName
+ (testClassName, methodName)
+ }
+
+ val defaultExpectedFasterFactor = Option(System.getenv(pluginDef.envVarPrefix + "MIN_PERF")).map(_.toDouble).getOrElse(0.95)
+ val perfRuns = Option(System.getenv(pluginDef.envVarPrefix + "PERF_RUNS")).map(_.toInt).getOrElse(4)
+
+ def ensureCodeWithSameResult(code: String): Unit = {
+ val (testClassName, testMethodName) = testClassInfo
+
+ val gens @ Array(genWith, genWithout) = Array(getTesterGen(true, "", code), getTesterGen(false, "", code))
+
+ val testers @ Array(testerWith, testerWithout) = gens.map(_(-1))
+
+ val firstRun = testers.map(_(false))
+ val Array(optimizedOutput, normalOutput) = firstRun.map(_.output)
+
+ val pref = "[" + testClassName + "." + testMethodName + "] "
+ if (normalOutput != optimizedOutput) {
+ fail(pref + "ERROR: Output is not the same !\n" + pref + "\t Normal output = " + normalOutput + "\n" + pref + "\tOptimized output = " + optimizedOutput)
+ }
+ }
+ def ensureFasterCodeWithSameResult(decls: String, code: String, params: Seq[Int] = Array(2, 10, 1000, 100000)/*10000, 100, 20, 2)*/, minFaster: Double = 1.0, nRuns: Int = perfRuns): Unit = {
+
+ //println("Ensuring faster code with same result :\n\t" + (decls + "\n#\n" + code).replaceAll("\n", "\n\t"))
+ val (testClassName, methodName) = testClassInfo
+
+ val gens @ Array(genWith, genWithout) = Array(getTesterGen(true, decls, code), getTesterGen(false, decls, code))
+
+ def run = params.toList.sorted.map(param => {
+ //println("Running with param " + param)
+ val testers @ Array(testerWith, testerWithout) = gens.map(_(param))
+
+ val firstRun = testers.map(_(false))
+ val Array(optimizedOutput, normalOutput) = firstRun.map(_.output)
+
+ val pref = "[" + testClassName + "." + methodName + ", n = " + param + "] "
+ if (normalOutput != optimizedOutput) {
+ fail(pref + "ERROR: Output is not the same !\n" + pref + "\t Normal output = " + normalOutput + "\n" + pref + "\tOptimized output = " + optimizedOutput)
+ }
+
+ val runs: List[Res] = firstRun.toList ++ (1 until nRuns).toList.flatMap(_ => testers.map(_(false)))
+ def calcTime(list: List[Res]) = {
+ val times = list.map(_.time)
+ times.sum / times.size.toDouble
+ }
+ val (runsWithPlugin, runsWithoutPlugin) = runs.partition(_.withPlugin)
+ val (timeWithPlugin, timeWithoutPlugin) = (calcTime(runsWithPlugin), calcTime(runsWithoutPlugin))
+
+ (param, timeWithoutPlugin / timeWithPlugin)
+ }).toMap
+
+ val (logName, log, properties) = Results.getLog(testClassName)
+
+ //println("Cold run...")
+ val coldRun = run
+
+ //println("Warming up...");
+ // Warm up the code being benchmarked :
+ {
+ val testers = gens.map(_(5))
+ (0 until 2500).foreach(_ => testers.foreach(_(true)))
+ };
+
+ //println("Warm run...")
+ val warmRun = run
+
+
+ val errors = coldRun.flatMap { case (param, coldFactor) =>
+ val warmFactor = warmRun(param)
+ //println("coldFactor = " + coldFactor + ", warmFactor = " + warmFactor)
+
+ def f2s(f: Double) = ((f * 10).toInt / 10.0) + ""
+ def printFacts(warmFactor: Double, coldFactor: Double) = {
+ val txt = methodName + "\\:" + param + "=" + Array(warmFactor, coldFactor).map(f2s).mkString(";")
+ //println(txt)
+ log.println(txt)
+ }
+ //def printFact(factor: Double) = log.println(methodName + "\\:" + param + "=" + f2s(factor))
+ val (expectedWarmFactor, expectedColdFactor) = {
+ //val expectedColdFactor = {
+ val p = Option(properties.getProperty(methodName + ":" + param)).map(_.split(";")).orNull
+ if (p != null && p.length == 2) {
+ //val Array(c) = p.map(_.toDouble)
+ //val c = p.toDouble; printFact(c); c
+ //log.print("# Test result (" + (if (actualFasterFactor >= f) "succeeded" else "failed") + "): ")
+ val Array(w, c) = p.map(_.toDouble)
+ printFacts(w, c)
+ (w, c)
+ } else {
+ //printFact(coldFactor - 0.1); 1.0
+ printFacts(warmFactor - 0.1, coldFactor - 0.1)
+ (defaultExpectedFasterFactor, defaultExpectedFasterFactor)
+ }
+ }
+
+ def check(warm: Boolean, factor: Double, expectedFactor: Double) = {
+ val pref = "[" + testClassName + "." + methodName + ", n = " + param + ", " + (if (warm) "warm" else "cold") + "] "
+
+ if (factor >= expectedFactor) {
+ println(pref + " OK (" + factor + "x faster, expected > " + expectedFactor + "x)")
+ Nil
+ } else {
+ val msg = "ERROR: only " + factor + "x faster (expected >= " + expectedFactor + "x)"
+ println(pref + msg)
+ List(msg)
+ }
+ }
+
+ check(false, coldFactor, expectedColdFactor) ++
+ check(true, warmFactor, expectedWarmFactor)
+ }
+ try {
+ if (!errors.isEmpty)
+ assertTrue(errors.mkString("\n"), false)
+ } finally {
+ println()
+ }
+ }
+
+}
Please sign in to comment.
Something went wrong with that request. Please try again.