diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index e3f52f6ff1e63..18c4e4b87cc83 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -19,17 +19,20 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import scala.collection.mutable.Map -import scala.collection.mutable.Set +import scala.collection.mutable.{Map, Set} import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ import org.apache.spark.{Logging, SparkEnv, SparkException} +/** + * A cleaner that renders closures serializable if this can be done safely. + */ private[spark] object ClosureCleaner extends Logging { + // Get an ASM class reader for a given class from the JAR that loaded it - private def getClassReader(cls: Class[_]): ClassReader = { + def getClassReader(cls: Class[_]): ClassReader = { // Copy data over, before delegating to ClassReader - else we can run out of open file handles. val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" val resourceStream = cls.getResourceAsStream(className) @@ -77,6 +80,9 @@ private[spark] object ClosureCleaner extends Logging { Nil } + /** + * Return a list of classes that represent closures enclosed in the given closure object. + */ private def getInnerClasses(obj: AnyRef): List[Class[_]] = { val seen = Set[Class[_]](obj.getClass) var stack = List[Class[_]](obj.getClass) @@ -101,21 +107,110 @@ private[spark] object ClosureCleaner extends Logging { } } - def clean(func: AnyRef, checkSerializable: Boolean = true) { + /** + * Clean the given closure in place. 
+ * + * More specifically, this renders the given closure serializable as long as it does not + * explicitly reference unserializable objects. + * + * @param closure the closure to clean + * @param checkSerializable whether to verify that the closure is serializable after cleaning + * @param cleanTransitively whether to clean enclosing closures transitively + */ + def clean( + closure: AnyRef, + checkSerializable: Boolean = true, + cleanTransitively: Boolean = true): Unit = { + clean(closure, checkSerializable, cleanTransitively, Map.empty) + } + + /** + * Helper method to clean the given closure in place. + * + * The mechanism is to traverse the hierarchy of enclosing closures and null out any + * references along the way that are not actually used by the starting closure, but are + * nevertheless included in the compiled anonymous classes. Note that it is unsafe to + * simply mutate the enclosing closures, as other code paths may depend on them. Instead, + * we clone each enclosing closure and set the parent pointers accordingly. + * + * By default, closures are cleaned transitively. This means we detect whether enclosing + * objects are actually referenced by the starting one, either directly or transitively, + * and, if not, sever these closures from the hierarchy. In other words, in addition to + * nulling out unused field references, we also null out any parent pointers that refer + * to enclosing objects not actually needed by the starting closure. + * + * For instance, transitive cleaning is necessary in the following scenario: + * + * class SomethingNotSerializable { + * def someValue = 1 + * def someMethod(): Unit = scope("one") { + * def x = someValue + * def y = 2 + * scope("two") { println(y + 1) } + * } + * def scope(name: String)(body: => Unit) = body + * } + * + * In this example, scope "two" is not serializable because it references scope "one", which + * references SomethingNotSerializable. 
Note, however, that scope "two" does not actually + * depend on SomethingNotSerializable. This means we can null out the parent pointer of + * a cloned scope "one" and set it as the parent of scope "two", such that scope "two" no longer + * references SomethingNotSerializable transitively. + * + * @param func the starting closure to clean + * @param checkSerializable whether to verify that the closure is serializable after cleaning + * @param cleanTransitively whether to clean enclosing closures transitively + * @param accessedFields a map from a class to a set of its fields that are accessed by + * the starting closure + */ + private def clean( + func: AnyRef, + checkSerializable: Boolean, + cleanTransitively: Boolean, + accessedFields: Map[Class[_], Set[String]]): Unit = { + + // TODO: clean all inner closures first. This requires us to find the inner objects. // TODO: cache outerClasses / innerClasses / accessedFields - val outerClasses = getOuterClasses(func) + + logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") + + // A list of classes that represent closures enclosed in the given one val innerClasses = getInnerClasses(func) + + // A list of enclosing objects and their respective classes, from innermost to outermost + // An outer object at a given index is of type outer class at the same index + val outerClasses = getOuterClasses(func) val outerObjects = getOuterObjects(func) - val accessedFields = Map[Class[_], Set[String]]() - + logDebug(s" + inner classes: " + innerClasses.size) + innerClasses.foreach { c => logDebug(" " + c.getName) } + logDebug(s" + outer classes: " + outerClasses.size) + outerClasses.foreach { c => logDebug(" " + c.getName) } + logDebug(s" + outer objects: " + outerObjects.size) + outerObjects.foreach { o => logDebug(" " + o) } + + // Fail fast if we detect return statements in closures getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) - - for (cls <- outerClasses) - accessedFields(cls) = Set[String]() - 
for (cls <- func.getClass :: innerClasses) - getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0) - // logInfo("accessedFields: " + accessedFields) + + // If accessed fields is not populated yet, we assume that + // the closure we are trying to clean is the starting one + if (accessedFields.isEmpty) { + logDebug(s" + populating accessed fields because this is the starting closure") + // Initialize accessed fields with the outer classes first + // This step is needed to associate the fields to the correct classes later + for (cls <- outerClasses) { + accessedFields(cls) = Set[String]() + } + // Populate accessed fields by visiting all fields and methods accessed by this and all of + // its inner closures, following calls into other classes only when cleaning transitively. + for (cls <- func.getClass :: innerClasses) { + getClassReader(cls) + .accept(new FieldAccessFinder(accessedFields, findTransitively = cleanTransitively), 0) + } + } + + logDebug(s" + fields accessed by starting closure: " + accessedFields.size) + accessedFields.foreach { f => logDebug(" " + f) } val inInterpreter = { try { @@ -126,34 +221,66 @@ private[spark] object ClosureCleaner extends Logging { } } + // List of outer (class, object) pairs, ordered from outermost to innermost + // Note that all outer objects but the outermost one (first one in this list) must be closures var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse - var outer: AnyRef = null + var parent: AnyRef = null if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) { // The closure is ultimately nested inside a class; keep the object of that // class without cloning it since we don't want to clone the user's objects. - outer = outerPairs.head._2 + // Note that we still need to keep around the outermost object itself because + // we need it to clone its child closure later (see below). 
+ logDebug(s" + outermost object is not a closure, so do not clone it: ${outerPairs.head}") + parent = outerPairs.head._2 // e.g. SparkContext outerPairs = outerPairs.tail + } else if (outerPairs.size > 0) { + logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") + } else { + logDebug(" + there are no enclosing objects!") } + // Clone the closure objects themselves, nulling out any fields that are not // used in the closure we're working on or any of its inner closures. for ((cls, obj) <- outerPairs) { - outer = instantiateClass(cls, outer, inInterpreter) + logDebug(s" + cloning the object $obj of class ${cls.getName}") + // We null out these unused references by cloning each object and then filling in all + // required fields from the original object. We need the parent here because the Java + // language specification requires the first constructor parameter of any closure to be + // its enclosing object. + val clone = instantiateClass(cls, parent, inInterpreter) for (fieldName <- accessedFields(cls)) { val field = cls.getDeclaredField(fieldName) field.setAccessible(true) val value = field.get(obj) - // logInfo("1: Setting " + fieldName + " on " + cls + " to " + value); - field.set(outer, value) + field.set(clone, value) } + // If transitive cleaning is enabled, we recursively clean any enclosing closure using + // the already populated accessed fields map of the starting closure + if (cleanTransitively && isClosure(clone.getClass)) { + logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})") + clean(clone, checkSerializable, cleanTransitively, accessedFields) + } + parent = clone } - if (outer != null) { - // logInfo("2: Setting $outer on " + func.getClass + " to " + outer); + // Update the parent pointer ($outer) of this closure + if (parent != null) { val field = func.getClass.getDeclaredField("$outer") field.setAccessible(true) - field.set(func, outer) + // If the starting closure doesn't actually need our 
enclosing object, then just null it out + if (accessedFields.contains(func.getClass) && + !accessedFields(func.getClass).contains("$outer")) { + logDebug(s" + the starting closure doesn't actually need $parent, so we null it out") + field.set(func, null) + } else { + // Update this closure's parent pointer to point to our enclosing object, + // which could either be a cloned closure or the original user object + field.set(func, parent) + } } - + + logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") + if (checkSerializable) { ensureSerializable(func) } @@ -167,15 +294,17 @@ private[spark] object ClosureCleaner extends Logging { } } - private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = { - // logInfo("Creating a " + cls + " with outer = " + outer) + private def instantiateClass( + cls: Class[_], + enclosingObject: AnyRef, + inInterpreter: Boolean): AnyRef = { if (!inInterpreter) { // This is a bona fide closure class, whose constructor has no effects // other than to set its fields, so use its constructor val cons = cls.getConstructors()(0) val params = cons.getParameterTypes.map(createNullValue).toArray - if (outer != null) { - params(0) = outer // First param is always outer object + if (enclosingObject != null) { + params(0) = enclosingObject // First param is always enclosing object } return cons.newInstance(params: _*).asInstanceOf[AnyRef] } else { @@ -184,11 +313,10 @@ private[spark] object ClosureCleaner extends Logging { val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() val newCtor = rf.newConstructorForSerialization(cls, parentCtor) val obj = newCtor.newInstance().asInstanceOf[AnyRef] - if (outer != null) { - // logInfo("3: Setting $outer on " + cls + " to " + outer); + if (enclosingObject != null) { val field = cls.getDeclaredField("$outer") field.setAccessible(true) - field.set(obj, outer) + field.set(obj, enclosingObject) } obj } @@ -213,29 +341,70 @@ class 
ReturnStatementFinder extends ClassVisitor(ASM4) { } } +/** + * Find the fields accessed by a given class. + * + * The fields are stored in the mutable map passed in, keyed by the class that contains them. + * This map is assumed to have its keys already populated by the classes of interest. + * + * @param fields the mutable map that stores the fields to return + * @param specificMethodNames if not empty, only visit methods whose names are in this set + * @param findTransitively if true, find fields indirectly referenced in other classes + * @param visitedMethods (class, method name) pairs already visited, used to break call cycles + */ private[spark] -class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { - override def visitMethod(access: Int, name: String, desc: String, - sig: String, exceptions: Array[String]): MethodVisitor = { +class FieldAccessFinder( + fields: Map[Class[_], Set[String]], + specificMethodNames: Set[String] = Set.empty, + findTransitively: Boolean = true, + visitedMethods: Set[(Class[_], String)] = Set.empty) + extends ClassVisitor(ASM4) { + + override def visitMethod( + access: Int, + name: String, + desc: String, + sig: String, + exceptions: Array[String]): MethodVisitor = { + + // Ignore this method if we don't want to visit it + if (specificMethodNames.nonEmpty && !specificMethodNames.contains(name)) { + return new MethodVisitor(ASM4) { } + } + new MethodVisitor(ASM4) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { - for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { - output(cl) += name + for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { + fields(cl) += name } } } - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { - // Check for calls a getter method for a variable in an interpreter wrapper object. - // This means that the corresponding field will be accessed, so we should save it. 
- if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) { - for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { - output(cl) += name + override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + if (isInvoke(op)) { + for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { + // Check for calls to a getter method for a variable in an interpreter wrapper object. + // This means that the corresponding field will be accessed, so we should save it. + if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) { + fields(cl) += name + } + // Visit each callee once (guarding against call cycles) to find transitive fields + if (findTransitively && visitedMethods.add((cl, name))) { + ClosureCleaner.getClassReader(cl).accept( + new FieldAccessFinder(fields, Set(name), findTransitively, visitedMethods), 0) + } } } } + + private def isInvoke(op: Int): Boolean = { + op == INVOKEVIRTUAL || + op == INVOKESPECIAL || + op == INVOKEDYNAMIC || + op == INVOKEINTERFACE || + op == INVOKESTATIC + } } } }