/
Eval.scala
222 lines (199 loc) · 8.6 KB
/
Eval.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
package sbt
package compiler
import scala.reflect.Manifest
import scala.tools.nsc.{ast, interpreter, io, reporters, util, CompilerCommand, Global, Phase, Settings}
import interpreter.AbstractFileClassLoader
import io.{AbstractFile, PlainFile, VirtualDirectory}
import ast.parser.Tokens
import reporters.{ConsoleReporter, Reporter}
import util.BatchSourceFile
import Tokens.EOF
import java.io.File
import java.nio.ByteBuffer
import java.net.URLClassLoader
// TODO: provide a way to cleanup backing directory
final class EvalImports(val strings: Seq[(String,Int)], val srcName: String)
final class EvalResult(val tpe: String, val value: Any, val generated: Seq[File], val enclosingModule: String)
final class EvalException(msg: String) extends RuntimeException(msg)
// not thread safe, since it reuses a Global instance
final class Eval(optionsNoncp: Seq[String], classpath: Seq[File], mkReporter: Settings => Reporter, parent: ClassLoader, backing: Option[File])
{
def this(mkReporter: Settings => Reporter, backing: Option[File]) = this(Nil, IO.classLocationFile[ScalaObject] :: Nil, mkReporter, getClass.getClassLoader, backing)
def this() = this(s => new ConsoleReporter(s), None)
backing.foreach(IO.createDirectory)
val classpathString = Path.makeString(classpath)
val options = "-cp" +: classpathString +: optionsNoncp
lazy val settings =
{
val s = new Settings(println)
val command = new CompilerCommand(options.toList, s)
s
}
lazy val reporter = mkReporter(settings)
lazy val global: Global = new Global(settings, reporter)
import global._
import definitions._
def eval(expression: String, imports: EvalImports = noImports, tpeName: Option[String] = None, srcName: String = "<setting>", line: Int = DefaultStartLine): EvalResult =
{
import Eval._
val hash = Hash.toHex(Hash(bytes( bytes(expression) :: optBytes(backing)(fileExistsBytes) :: seqBytes(options)(bytes) ::
seqBytes(classpath)(fileModifiedBytes) :: seqBytes(imports.strings.map(_._1))(bytes) :: optBytes(tpeName)(bytes) :: Nil)))
val moduleName = makeModuleName(hash)
lazy val unit = {
reporter.reset
mkUnit(srcName, line, expression)
}
lazy val run = new Run {
override def units = (unit :: Nil).iterator
}
def unlinkAll(): Unit = for( (sym, _) <- run.symSource ) unlink(sym)
def unlink(sym: Symbol) = sym.owner.info.decls.unlink(sym)
val (tpe, value) =
(tpeName, backing) match {
case (Some(tpe), Some(back)) if classExists(back, moduleName) => (tpe, loadPlain(back, moduleName))
case _ => try { eval0(expression, imports, tpeName, run, unit, backing, moduleName) } finally { unlinkAll() }
}
val classFiles = getClassFiles(backing, moduleName)
new EvalResult(tpe, value, classFiles, moduleName)
}
def eval0(expression: String, imports: EvalImports, tpeName: Option[String], run: Run, unit: CompilationUnit, backing: Option[File], moduleName: String): (String, Any) =
{
val dir = backing match { case None => new VirtualDirectory("<virtual>", None); case Some(dir) => new PlainFile(dir) }
settings.outputDirs setSingleOutput dir
val importTrees = parseImports(imports)
val (parser, tree) = parseExpr(unit)
val tpt: Tree = tpeName match {
case Some(tpe) => parseType(tpe)
case None => TypeTree(NoType)
}
unit.body = augment(parser, importTrees, tree, tpt, moduleName)
def compile(phase: Phase): Unit =
{
globalPhase = phase
if(phase == null || phase == phase.next || reporter.hasErrors)
()
else
{
reporter.withSource(unit.source) {
atPhase(phase) { phase.run }
}
compile(phase.next)
}
}
compile(run.namerPhase)
checkError("Type error in expression")
val tpe = atPhase(run.typerPhase.next) { (new TypeExtractor).getType(unit.body) }
(tpe, load(dir, moduleName))
}
def load(dir: AbstractFile, moduleName: String): Any = getValue(moduleName, new AbstractFileClassLoader(dir, parent))
def loadPlain(dir: File, moduleName: String): Any = getValue(moduleName, new URLClassLoader(Array(dir.toURI.toURL), parent))
val WrapValName = "$sbtdef"
//wrap tree in object objectName { def WrapValName = <tree> }
def augment(parser: global.syntaxAnalyzer.UnitParser, imports: Seq[Tree], tree: Tree, tpt: Tree, objectName: String): Tree =
{
val emptyTypeName = nme.EMPTY.toTypeName
def emptyPkg = parser.atPos(0, 0, 0) { Ident(nme.EMPTY_PACKAGE_NAME) }
def emptyInit = DefDef(
NoMods,
nme.CONSTRUCTOR,
Nil,
List(Nil),
TypeTree(),
Block(List(Apply(Select(Super(This(emptyTypeName), emptyTypeName), nme.CONSTRUCTOR), Nil)), Literal(Constant(())))
)
def method = DefDef(NoMods, WrapValName, Nil, Nil, tpt, tree)
def moduleBody = Template(List(gen.scalaScalaObjectConstr), emptyValDef, List(emptyInit, method))
def moduleDef = ModuleDef(NoMods, objectName, moduleBody)
parser.makePackaging(0, emptyPkg, (imports :+ moduleDef).toList)
}
def getValue[T](objectName: String, loader: ClassLoader): T =
{
val clazz = Class.forName(objectName + "$", true, loader)
val module = clazz.getField("MODULE$").get(null)
val accessor = module.getClass.getMethod(WrapValName)
val value = accessor.invoke(module)
value.asInstanceOf[T]
}
final class TypeExtractor extends Traverser {
private[this] var result = ""
def getType(t: Tree) = { result = ""; traverse(t); result }
override def traverse(tree: Tree): Unit = tree match {
case d: DefDef if d.symbol.nameString == WrapValName => result = d.symbol.tpe.finalResultType.toString
case _ => super.traverse(tree)
}
}
private[this] def classExists(dir: File, name: String) = (new File(dir, name + ".class")).exists
// TODO: use the code from Analyzer
private[this] def getClassFiles(backing: Option[File], moduleName: String): Seq[File] =
backing match {
case None => Nil
case Some(dir) => dir listFiles moduleClassFilter(moduleName)
}
private[this] def moduleClassFilter(moduleName: String) = new java.io.FilenameFilter { def accept(dir: File, s: String) =
(s contains moduleName) && (s endsWith ".class")
}
private[this] def parseExpr(unit: CompilationUnit) =
{
val parser = new syntaxAnalyzer.UnitParser(unit)
val tree: Tree = parser.expr()
parser.accept(EOF)
checkError("Error parsing expression.")
(parser, tree)
}
private[this] def parseType(tpe: String): Tree =
{
val tpeParser = new syntaxAnalyzer.UnitParser(mkUnit("<expected-type>", DefaultStartLine, tpe))
val tpt0: Tree = tpeParser.typ()
tpeParser.accept(EOF)
checkError("Error parsing expression type.")
tpt0
}
private[this] def parseImports(imports: EvalImports): Seq[Tree] =
imports.strings flatMap { case (s, line) => parseImport(mkUnit(imports.srcName, line, s)) }
private[this] def parseImport(importUnit: CompilationUnit): Seq[Tree] =
{
val parser = new syntaxAnalyzer.UnitParser(importUnit)
val trees: Seq[Tree] = parser.importClause()
parser.accept(EOF)
checkError("Error parsing imports for expression.")
trees
}
val DefaultStartLine = 0
private[this] def makeModuleName(hash: String): String = "$" + Hash.halve(hash)
private[this] def noImports = new EvalImports(Nil, "")
private[this] def mkUnit(srcName: String, firstLine: Int, s: String) = new CompilationUnit(new EvalSourceFile(srcName, firstLine, s))
private[this] def checkError(label: String) = if(reporter.hasErrors) throw new EvalException(label)
private[this] final class EvalSourceFile(name: String, startLine: Int, contents: String) extends BatchSourceFile(name, contents)
{
override def lineToOffset(line: Int): Int = super.lineToOffset((line - startLine) max 0)
override def offsetToLine(offset: Int): Int = super.offsetToLine(offset) + startLine
}
}
private object Eval
{
def optBytes[T](o: Option[T])(f: T => Array[Byte]): Array[Byte] = seqBytes(o.toSeq)(f)
def seqBytes[T](s: Seq[T])(f: T => Array[Byte]): Array[Byte] = bytes(s map f)
def bytes(b: Seq[Array[Byte]]): Array[Byte] = bytes(b.length) ++ b.flatten.toArray[Byte]
def bytes(b: Boolean): Array[Byte] = Array[Byte](if(b) 1 else 0)
def filesModifiedBytes(fs: Array[File]): Array[Byte] = if(fs eq null) filesModifiedBytes(Array[File]()) else seqBytes(fs)(fileModifiedBytes)
def fileModifiedBytes(f: File): Array[Byte] =
(if(f.isDirectory) filesModifiedBytes(f listFiles classDirFilter) else bytes(f.lastModified)) ++
bytes(f.getAbsolutePath)
def fileExistsBytes(f: File): Array[Byte] =
bytes(f.exists) ++
bytes(f.getAbsolutePath)
def bytes(s: String): Array[Byte] = s getBytes "UTF-8"
def bytes(l: Long): Array[Byte] =
{
val buffer = ByteBuffer.allocate(8)
buffer.putLong(l)
buffer.array
}
def bytes(i: Int): Array[Byte] =
{
val buffer = ByteBuffer.allocate(4)
buffer.putInt(i)
buffer.array
}
private val classDirFilter: FileFilter = DirectoryFilter || GlobFilter("*.class")
}