/
TrapExit.scala
246 lines (239 loc) · 9.15 KB
/
TrapExit.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
/* sbt -- Simple Build Tool
* Copyright 2008 Mark Harrah
*
* Partially based on exit trapping in Nailgun by Pete Kirkham,
* copyright 2004, Martian Software, Inc
* licensed under Apache 2.0 License.
*/
package sbt
import scala.collection.Set
import scala.reflect.Manifest
/** This provides functionality to catch System.exit calls to prevent the JVM from terminating.
 * This is useful for executing user code that may call System.exit, but actually exiting is
 * undesirable. This object handles the call to exit by disposing all top-level windows and interrupting
 * all user started threads. It does not stop the threads and does not call shutdown hooks. It is
 * therefore inappropriate to use this with code that requires shutdown hooks or creates threads that
 * do not terminate. This category of code should only be called by forking the JVM. */
object TrapExit
{
  /** Executes the given thunk in a context where System.exit(code) throws
   * a custom SecurityException, which is then caught and the exit code returned.
   *
   * The thunk runs on a new thread named "run-main" inside a dedicated thread group. This method
   * blocks until every non-daemon thread started during the run has terminated. It then interrupts
   * the remaining (daemon) threads and restores the previous SecurityManager.
   *
   * @return whichever is recorded first: the code passed to a trapped System.exit call, or 1 if
   *         `execute` itself threw; 0 when neither happened. Exceptions thrown by `execute` are not
   *         rethrown here. They are logged at trace level and passed to the default uncaught
   *         exception handler, or printed to System.err when no default handler is installed.
   *
   * NOTE(review): relies on System.setSecurityManager, which recent JDKs reject unless the JVM is
   * started with -Djava.security.manager=allow — confirm for the target JDK. */
  def apply(execute: => Unit, log: Logger): Int =
  {
    log.debug("Starting sandboxed run...")
    // Snapshot of the threads that existed before execution, used to identify the threads created by 'execute'.
    val originalThreads = allThreads
    val code = new ExitCode
    def executeMain(): Unit =
      try { execute }
      catch
      {
        case e: TrapExitSecurityException => throw e
        case x: Throwable =>
          code.set(1) // exceptions in the main thread cause the exit code to be 1
          throw x
      }
    val customThreadGroup = new ExitThreadGroup(new ExitHandler(Thread.getDefaultUncaughtExceptionHandler, originalThreads, code, log))
    val executionThread = new Thread(customThreadGroup, "run-main") { override def run(): Unit = executeMain() }
    val originalSecurityManager = System.getSecurityManager
    try
    {
      val newSecurityManager = new TrapExitSecurityManager(originalSecurityManager, customThreadGroup)
      System.setSecurityManager(newSecurityManager)
      executionThread.start()
      log.debug("Waiting for threads to exit or System.exit to be called.")
      waitForExit(originalThreads, log)
      log.debug("Interrupting remaining threads (should be all daemons).")
      interruptAll(originalThreads) // should only be daemon threads left now
      log.debug("Sandboxed run complete..")
      code.value.getOrElse(0)
    }
    finally { System.setSecurityManager(originalSecurityManager) }
  }

  /** Blocks until every non-daemon thread that is not in `originalThreads` has terminated.
   * The scan repeats after each round of joins, because the joined threads may have started new ones. */
  @scala.annotation.tailrec
  private def waitForExit(originalThreads: Set[Thread], log: Logger): Unit =
  {
    var daemonsOnly = true
    processThreads(originalThreads, thread =>
      if(!thread.isDaemon)
      {
        daemonsOnly = false
        waitOnThread(thread, log)
      }
    )
    if(!daemonsOnly)
      waitForExit(originalThreads, log)
  }

  /** Waits for the given thread to exit. */
  private def waitOnThread(thread: Thread, log: Logger): Unit =
  {
    log.debug("Waiting for thread " + thread.getName + " to exit")
    thread.join()
    log.debug("\tThread " + thread.getName + " exited.")
  }

  /** Returns the exit code of the System.exit call that caused the given exception.
   * If no TrapExitSecurityException is found in the cause chain, throws the root cause of `e`.*/
  private def exitCode(e: Throwable): Int =
    withCause[TrapExitSecurityException, Int](e)(exited => exited.exitCode)(other => throw other)

  /** Recurses into the causes of the given exception looking for a cause of type CauseType. If one is found, `withType` is called with that cause.
   * If not, `notType` is called with the root cause. A TrapExitSecurityException counts as a root cause, because
   * its `getCause` override throws the exception itself instead of returning a cause.*/
  private def withCause[CauseType <: Throwable, T](e: Throwable)(withType: CauseType => T)(notType: Throwable => T)(implicit mf: Manifest[CauseType]): T =
  {
    val clazz = mf.erasure
    if(clazz.isInstance(e))
      withType(e.asInstanceOf[CauseType])
    else
    {
      val cause: Throwable = e match
      {
        case _: TrapExitSecurityException => null // never call its throwing getCause
        case _ => e.getCause
      }
      if(cause == null)
        notType(e)
      else
        withCause(cause)(withType)(notType)(mf)
    }
  }

  /** Returns all threads currently known to the JVM, excluding those that `isSystemThread` classifies as
   * JVM/AWT implementation threads.*/
  private def allThreads: Set[Thread] =
  {
    import scala.collection.JavaConverters._
    Thread.getAllStackTraces.keySet.asScala.filter(thread => !isSystemThread(thread))
  }

  /** Returns true for implementation threads that user code did not start. These are AWT-* threads other
   * than AWT-EventQueue* and AWT-Shutdown*, and any non-AWT thread whose thread group is named 'system'.*/
  private def isSystemThread(t: Thread): Boolean =
  {
    val name = t.getName
    if(name.startsWith("AWT-"))
      !(name.startsWith("AWT-EventQueue") || name.startsWith("AWT-Shutdown"))
    else
    {
      val group = t.getThreadGroup
      (group != null) && (group.getName == "system")
    }
  }

  /** Calls the provided function for each thread returned by `allThreads`, except those in ignoreThreads.*/
  private def processThreads(ignoreThreads: Set[Thread], process: Thread => Unit): Unit =
    allThreads.filterNot(thread => ignoreThreads.contains(thread)).foreach(process)

  /** Handles System.exit by disposing all frames and calling interrupt on all user threads. */
  private def stopAll(originalThreads: Set[Thread]): Unit =
  {
    disposeAllFrames()
    interruptAll(originalThreads)
  }

  /** Disposes every AWT frame and, if there were any, sleeps 2s to give the AWT threads time to exit. */
  private def disposeAllFrames(): Unit =
  {
    val allFrames = java.awt.Frame.getFrames
    if(allFrames.length > 0)
    {
      allFrames.foreach(_.dispose()) // dispose all top-level windows, which will cause the AWT-EventQueue-* threads to exit
      Thread.sleep(2000) // AWT Thread doesn't exit immediately, so wait to interrupt it
    }
  }

  /** Interrupts all threads that appear to have been started by the user. */
  private def interruptAll(originalThreads: Set[Thread]): Unit =
    processThreads(originalThreads, safeInterrupt)

  /** Interrupts the given non-AWT thread. It first replaces the thread's exception handler so that a
   * resulting InterruptedException is not printed. */
  private def safeInterrupt(thread: Thread): Unit =
  {
    if(!thread.getName.startsWith("AWT-"))
    {
      thread.setUncaughtExceptionHandler(new TrapInterrupt(thread.getUncaughtExceptionHandler))
      thread.interrupt()
    }
  }

  /** An uncaught exception handler that swallows InterruptedExceptions and otherwise defers to originalHandler.
   * In both cases the thread's original handler is reinstated afterwards. */
  private final class TrapInterrupt(originalHandler: Thread.UncaughtExceptionHandler) extends Thread.UncaughtExceptionHandler
  {
    def uncaughtException(thread: Thread, e: Throwable): Unit =
    {
      try
        withCause[InterruptedException, Unit](e)(_ => ())(_ => originalHandler.uncaughtException(thread, e))
      finally
        thread.setUncaughtExceptionHandler(originalHandler)
    }
  }

  /** An uncaught exception handler that delegates to the original uncaught exception handler, except when
   * the cause was a call to System.exit (which generated a SecurityException). In that case it records the
   * exit code and stops all user threads.*/
  private final class ExitHandler(originalHandler: Thread.UncaughtExceptionHandler, originalThreads: Set[Thread], codeHolder: ExitCode, log: Logger) extends Thread.UncaughtExceptionHandler
  {
    def uncaughtException(t: Thread, e: Throwable): Unit =
    {
      try
      {
        codeHolder.set(exitCode(e)) // throws if e was not caused by a call to System.exit
        stopAll(originalThreads)
      }
      catch
      {
        case _: Throwable =>
          log.trace(e)
          delegate(t, e)
      }
    }

    /** Passes `e` to the original handler. When there is none (no default uncaught exception handler was
     * installed), it prints the exception to System.err, similar to java.lang.ThreadGroup's fallback,
     * instead of failing with a NullPointerException that would hide `e`. */
    private def delegate(t: Thread, e: Throwable): Unit =
      if(originalHandler ne null)
        originalHandler.uncaughtException(t, e)
      else
      {
        System.err.print("Exception in thread \"" + t.getName + "\" ")
        e.printStackTrace(System.err)
      }
  }

  /** A thread group named "trap.exit" that routes uncaught exceptions of its threads to `handler`. */
  private final class ExitThreadGroup(handler: Thread.UncaughtExceptionHandler) extends ThreadGroup("trap.exit")
  {
    override def uncaughtException(t: Thread, e: Throwable): Unit = handler.uncaughtException(t, e)
  }
}
/** Holds the exit code of a sandboxed run. Only the first code passed to `set` is kept.
 * Access is synchronized on the instance, so threads may report codes concurrently.
 * (The former `extends NotNull` marker was dropped: that trait was deprecated in Scala 2.11 and is
 * absent from current Scala, and it had no runtime effect.) */
private final class ExitCode
{
  private var code: Option[Int] = None
  /** Records `c` unless a code has already been recorded. */
  def set(c: Int): Unit = synchronized { code = code orElse Some(c) }
  /** The first recorded code, or None if `set` has never been called. */
  def value: Option[Int] = synchronized { code }
}
/////// These two classes are based on similar classes in Nailgun
/** A SecurityManager that disallows System.exit.
 * A real exit (one whose call stack passes through java.lang.Runtime.exit) is turned into a
 * TrapExitSecurityException carrying the requested status. Every other permission check is forwarded
 * to `delegateManager`, and all checks are permitted when there is none. New threads are placed in `group`. */
private final class TrapExitSecurityManager(delegateManager: SecurityManager, group: ThreadGroup) extends SecurityManager
{
  import java.security.Permission

  override def checkExit(status: Int): Unit =
  {
    val frames = Thread.currentThread.getStackTrace
    val exiting = (frames eq null) || frames.exists(frame => isRealExit(frame))
    if(exiting) throw new TrapExitSecurityException(status)
  }

  /** Matches the java.lang.Runtime.exit frame. This traps only real exits, not code that merely asks whether exiting is allowed.*/
  private def isRealExit(element: StackTraceElement): Boolean =
    element.getMethodName == "exit" && element.getClassName == "java.lang.Runtime"

  override def checkPermission(perm: Permission): Unit =
    if(delegateManager ne null) delegateManager.checkPermission(perm)

  override def checkPermission(perm: Permission, context: AnyRef): Unit =
    if(delegateManager ne null) delegateManager.checkPermission(perm, context)

  override def getThreadGroup: ThreadGroup = group
}
/** A custom SecurityException that tries not to be caught.
* TrapExitSecurityManager.checkExit throws it in place of a real exit, and the requested status is
* available as `exitCode`.
* Until `allowAccess` has been called, the overridden Throwable members below never return normally.
* Each one throws this exception again, so code that catches it and then tries to print or inspect
* it ends up rethrowing it.
* NOTE(review): java.lang.Throwable's constructor calls fillInStackTrace(), and `accessAllowed` is
* still false at that point. The override below therefore appears to throw this instance out of its
* own constructor. At `throw new TrapExitSecurityException(...)` sites the same object is thrown
* either way; confirm this before changing the initialization order.
* NOTE(review): printStackTrace(PrintStream), printStackTrace(PrintWriter) and getStackTrace are not
* overridden here.*/
private final class TrapExitSecurityException(val exitCode: Int) extends SecurityException
{
// Set to true by allowAccess. While false, every member routed through ifAccessAllowed rethrows this exception.
private var accessAllowed = false
/** Lets the overridden Throwable members behave normally from now on. */
def allowAccess
{
accessAllowed = true
}
override def printStackTrace = ifAccessAllowed(super.printStackTrace)
override def toString = ifAccessAllowed(super.toString)
override def getCause = ifAccessAllowed(super.getCause)
override def getMessage = ifAccessAllowed(super.getMessage)
override def fillInStackTrace = ifAccessAllowed(super.fillInStackTrace)
override def getLocalizedMessage = ifAccessAllowed(super.getLocalizedMessage)
/** Evaluates `f` if access has been allowed; otherwise throws this exception. */
private def ifAccessAllowed[T](f: => T): T =
{
if(accessAllowed)
f
else
throw this
}
}