Skip to content

Commit

Permalink
Implement transitive cleaning + add missing documentation
Browse files Browse the repository at this point in the history
See in-code comments for more detail on what this means.
  • Loading branch information
Andrew Or committed Apr 24, 2015
1 parent 4c722d7 commit 86f7823
Showing 1 changed file with 208 additions and 41 deletions.
249 changes: 208 additions & 41 deletions core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@ package org.apache.spark.util

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.mutable.Map
import scala.collection.mutable.Set
import scala.collection.mutable.{Map, Set}

import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._

import org.apache.spark.{Logging, SparkEnv, SparkException}

/**
* A cleaner that renders closures serializable if they can be done so safely.
*/
private[spark] object ClosureCleaner extends Logging {

// Get an ASM class reader for a given class from the JAR that loaded it
private def getClassReader(cls: Class[_]): ClassReader = {
def getClassReader(cls: Class[_]): ClassReader = {
// Copy data over, before delegating to ClassReader - else we can run out of open file handles.
val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
val resourceStream = cls.getResourceAsStream(className)
Expand Down Expand Up @@ -77,6 +80,9 @@ private[spark] object ClosureCleaner extends Logging {
Nil
}

/**
* Return a list of classes that represent closures enclosed in the given closure object.
*/
private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
val seen = Set[Class[_]](obj.getClass)
var stack = List[Class[_]](obj.getClass)
Expand All @@ -101,21 +107,110 @@ private[spark] object ClosureCleaner extends Logging {
}
}

def clean(func: AnyRef, checkSerializable: Boolean = true) {
/**
* Clean the given closure in place.
*
* More specifically, this renders the given closure serializable as long as it does not
* explicitly reference unserializable objects.
*
* @param closure the closure to clean
* @param checkSerializable whether to verify that the closure is serializable after cleaning
* @param cleanTransitively whether to clean enclosing closures transitively
*/
def clean(
closure: AnyRef,
checkSerializable: Boolean = true,
cleanTransitively: Boolean = true): Unit = {
clean(closure, checkSerializable, cleanTransitively, Map.empty)
}

/**
* Helper method to clean the given closure in place.
*
* The mechanism is to traverse the hierarchy of enclosing closures and null out any
* references along the way that are not actually used by the starting closure, but are
* nevertheless included in the compiled anonymous classes. Note that it is unsafe to
* simply mutate the enclosing closures, as other code paths may depend on them. Instead,
* we clone each enclosing closure and set the parent pointers accordingly.
*
* By default, closures are cleaned transitively. This means we detect whether enclosing
* objects are actually referenced by the starting one, either directly or transitively,
* and, if not, sever these closures from the hierarchy. In other words, in addition to
* nulling out unused field references, we also null out any parent pointers that refer
* to enclosing objects not actually needed by the starting closure.
*
* For instance, transitive cleaning is necessary in the following scenario:
*
* class SomethingNotSerializable {
* def someValue = 1
* def someMethod(): Unit = scope("one") {
* def x = someValue
* def y = 2
* scope("two") { println(y + 1) }
* }
* def scope(name: String)(body: => Unit) = body
* }
*
* In this example, scope "two" is not serializable because it references scope "one", which
* references SomethingNotSerializable. Note that, however, scope "two" does not actually
* depend on SomethingNotSerializable. This means we can null out the parent pointer of
* a cloned scope "one" and set it the parent of scope "two", such that scope "two" no longer
* references SomethingNotSerializable transitively.
*
* @param func the starting closure to clean
* @param checkSerializable whether to verify that the closure is serializable after cleaning
* @param cleanTransitively whether to clean enclosing closures transitively
* @param accessedFields a map from a class to a set of its fields that are accessed by
* the starting closure
*/
private def clean(
func: AnyRef,
checkSerializable: Boolean,
cleanTransitively: Boolean,
accessedFields: Map[Class[_], Set[String]]) {

// TODO: clean all inner closures first. This requires us to find the inner objects.
// TODO: cache outerClasses / innerClasses / accessedFields
val outerClasses = getOuterClasses(func)

logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}}) +++")

// A list of classes that represents closures enclosed in the given one
val innerClasses = getInnerClasses(func)

// A list of enclosing objects and their respective classes, from innermost to outermost
// An outer object at a given index is of type outer class at the same index
val outerClasses = getOuterClasses(func)
val outerObjects = getOuterObjects(func)

val accessedFields = Map[Class[_], Set[String]]()

logDebug(s" + inner classes: " + innerClasses.size)
innerClasses.foreach { c => logDebug(" " + c.getName) }
logDebug(s" + outer classes: " + outerClasses.size)
outerClasses.foreach { c => logDebug(" " + c.getName) }
logDebug(s" + outer objects: " + outerObjects.size)
outerObjects.foreach { o => logDebug(" " + o) }

// Fail fast if we detect return statements in closures
getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)

for (cls <- outerClasses)
accessedFields(cls) = Set[String]()
for (cls <- func.getClass :: innerClasses)
getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
// logInfo("accessedFields: " + accessedFields)

// If accessed fields is not populated yet, we assume that
// the closure we are trying to clean is the starting one
if (accessedFields.isEmpty) {
logDebug(s" + populating accessed fields because this is the starting closure")
// Initialize accessed fields with the outer classes first
// This step is needed to associate the fields to the correct classes later
for (cls <- outerClasses) {
accessedFields(cls) = Set[String]()
}
// Populate accessed fields by visiting all fields and methods accessed by this and
// all of its inner closures. If transitive cleaning is enabled, this may recursively
// visits methods that belong to other classes in search of transitively referenced fields.
for (cls <- func.getClass :: innerClasses) {
getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
}
}

logDebug(s" + fields accessed by starting closure: " + accessedFields.size)
accessedFields.foreach { f => logDebug(" " + f) }

val inInterpreter = {
try {
Expand All @@ -126,34 +221,66 @@ private[spark] object ClosureCleaner extends Logging {
}
}

// List of outer (class, object) pairs, ordered from outermost to innermost
// Note that all outer objects but the outermost one (first one in this list) must be closures
var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse
var outer: AnyRef = null
var parent: AnyRef = null
if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) {
// The closure is ultimately nested inside a class; keep the object of that
// class without cloning it since we don't want to clone the user's objects.
outer = outerPairs.head._2
// Note that we still need to keep around the outermost object itself because
// we need it to clone its child closure later (see below).
logDebug(s" + outermost object is not a closure, so do not clone it: ${outerPairs.head}")
parent = outerPairs.head._2 // e.g. SparkContext
outerPairs = outerPairs.tail
} else if (outerPairs.size > 0) {
logDebug(s" + outermost object is a closure, so we just keep it: ${outerPairs.head}")
} else {
logDebug(" + there are no enclosing objects!")
}

// Clone the closure objects themselves, nulling out any fields that are not
// used in the closure we're working on or any of its inner closures.
for ((cls, obj) <- outerPairs) {
outer = instantiateClass(cls, outer, inInterpreter)
logDebug(s" + cloning the object $obj of class ${cls.getName}")
// We null out these unused references by cloning each object and then filling in all
// required fields from the original object. We need the parent here because the Java
// language specification requires the first constructor parameter of any closure to be
// its enclosing object.
val clone = instantiateClass(cls, parent, inInterpreter)
for (fieldName <- accessedFields(cls)) {
val field = cls.getDeclaredField(fieldName)
field.setAccessible(true)
val value = field.get(obj)
// logInfo("1: Setting " + fieldName + " on " + cls + " to " + value);
field.set(outer, value)
field.set(clone, value)
}
// If transitive cleaning is enabled, we recursively clean any enclosing closure using
// the already populated accessed fields map of the starting closure
if (cleanTransitively && isClosure(clone.getClass)) {
logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})")
clean(clone, checkSerializable, cleanTransitively, accessedFields)
}
parent = clone
}

if (outer != null) {
// logInfo("2: Setting $outer on " + func.getClass + " to " + outer);
// Update the parent pointer ($outer) of this closure
if (parent != null) {
val field = func.getClass.getDeclaredField("$outer")
field.setAccessible(true)
field.set(func, outer)
// If the starting closure doesn't actually need our enclosing object, then just null it out
if (accessedFields.contains(func.getClass) &&
!accessedFields(func.getClass).contains("$outer")) {
logDebug(s" + the starting closure doesn't actually need $parent, so we null it out")
field.set(func, null)
} else {
// Update this closure's parent pointer to point to our enclosing object,
// which could either be a cloned closure or the original user object
field.set(func, parent)
}
}


logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++")

if (checkSerializable) {
ensureSerializable(func)
}
Expand All @@ -167,15 +294,17 @@ private[spark] object ClosureCleaner extends Logging {
}
}

private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
// logInfo("Creating a " + cls + " with outer = " + outer)
private def instantiateClass(
cls: Class[_],
enclosingObject: AnyRef,
inInterpreter: Boolean): AnyRef = {
if (!inInterpreter) {
// This is a bona fide closure class, whose constructor has no effects
// other than to set its fields, so use its constructor
val cons = cls.getConstructors()(0)
val params = cons.getParameterTypes.map(createNullValue).toArray
if (outer != null) {
params(0) = outer // First param is always outer object
if (enclosingObject!= null) {
params(0) = enclosingObject // First param is always enclosing object
}
return cons.newInstance(params: _*).asInstanceOf[AnyRef]
} else {
Expand All @@ -184,11 +313,10 @@ private[spark] object ClosureCleaner extends Logging {
val parentCtor = classOf[java.lang.Object].getDeclaredConstructor()
val newCtor = rf.newConstructorForSerialization(cls, parentCtor)
val obj = newCtor.newInstance().asInstanceOf[AnyRef]
if (outer != null) {
// logInfo("3: Setting $outer on " + cls + " to " + outer);
if (enclosingObject != null) {
val field = cls.getDeclaredField("$outer")
field.setAccessible(true)
field.set(obj, outer)
field.set(obj, enclosingObject)
}
obj
}
Expand All @@ -213,29 +341,68 @@ class ReturnStatementFinder extends ClassVisitor(ASM4) {
}
}

/**
* Find the fields accessed by a given class.
*
* The fields are stored in the mutable map passed in by the class that contains them.
* This map is assumed to have its keys already populated by the classes of interest.
*
* @param fields the mutable map that stores the fields to return
* @param specificMethodNames if not empty, only visit methods whose names are in this set
* @param findTransitively if true, find fields indirectly referenced in other classes
*/
private[spark]
class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
class FieldAccessFinder(
fields: Map[Class[_], Set[String]],
specificMethodNames: Set[String] = Set.empty,
findTransitively: Boolean = true)
extends ClassVisitor(ASM4) {

override def visitMethod(
access: Int,
name: String,
desc: String,
sig: String,
exceptions: Array[String]): MethodVisitor = {

// Ignore this method if we don't want to visit it
if (specificMethodNames.nonEmpty && !specificMethodNames.contains(name)) {
return new MethodVisitor(ASM4) { }
}

new MethodVisitor(ASM4) {
override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
if (op == GETFIELD) {
for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
output(cl) += name
for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
fields(cl) += name
}
}
}

override def visitMethodInsn(op: Int, owner: String, name: String,
desc: String) {
// Check for calls a getter method for a variable in an interpreter wrapper object.
// This means that the corresponding field will be accessed, so we should save it.
if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) {
for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
output(cl) += name
override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
if (isInvoke(op)) {
for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
// Check for calls a getter method for a variable in an interpreter wrapper object.
// This means that the corresponding field will be accessed, so we should save it.
if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) {
fields(cl) += name
}
// Visit other methods to find fields that are transitively referenced
if (findTransitively) {
ClosureCleaner.getClassReader(cl)
.accept(new FieldAccessFinder(fields, Set(name), findTransitively), 0)
}
}
}
}

private def isInvoke(op: Int): Boolean = {
op == INVOKEVIRTUAL ||
op == INVOKESPECIAL ||
op == INVOKEDYNAMIC ||
op == INVOKEINTERFACE ||
op == INVOKESTATIC
}
}
}
}
Expand Down

0 comments on commit 86f7823

Please sign in to comment.