diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 916ec25b9deab..6a9300cd8960b 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -223,7 +223,7 @@ private[spark] object ClosureCleaner extends Logging {
       // all of its inner closures. If transitive cleaning is enabled, this may recursively
       // visits methods that belong to other classes in search of transitively referenced fields.
       for (cls <- func.getClass :: innerClasses) {
-        getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
+        getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0)
       }
     }
 
@@ -358,6 +358,9 @@ private class ReturnStatementFinder extends ClassVisitor(ASM4) {
   }
 }
 
+/** Helper class to identify a method. */
+private case class MethodIdentifier(cls: Class[_], name: String, desc: String)
+
 /**
  * Find the fields accessed by a given class.
  *
@@ -365,13 +368,15 @@ private class ReturnStatementFinder extends ClassVisitor(ASM4) {
  * This map is assumed to have its keys already populated with the classes of interest.
  *
  * @param fields the mutable map that stores the fields to return
- * @param specificMethodNames if not empty, only visit methods whose names are in this set
  * @param findTransitively if true, find fields indirectly referenced in other classes
+ * @param specificMethod if not empty, visit only this method
+ * @param visitedMethods the set of methods already visited, shared with nested finders to avoid cycles
  */
 private class FieldAccessFinder(
     fields: Map[Class[_], Set[String]],
-    specificMethodNames: Set[String] = Set.empty,
-    findTransitively: Boolean = true)
+    findTransitively: Boolean,
+    specificMethod: Option[MethodIdentifier] = None,
+    visitedMethods: Set[MethodIdentifier] = Set.empty)
   extends ClassVisitor(ASM4) {
 
   override def visitMethod(
@@ -381,9 +386,10 @@ private class FieldAccessFinder(
       sig: String,
       exceptions: Array[String]): MethodVisitor = {
 
-    // Ignore this method if we don't want to visit it
-    if (specificMethodNames.nonEmpty && !specificMethodNames.contains(name)) {
-      return new MethodVisitor(ASM4) { }
+    // If told to visit only one method and this is not it, return null so ASM skips its code
+    if (specificMethod.isDefined &&
+        (specificMethod.get.name != name || specificMethod.get.desc != desc)) {
+      return null
     }
 
     new MethodVisitor(ASM4) {
@@ -396,30 +402,24 @@ private class FieldAccessFinder(
       }
 
       override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
-        if (isInvoke(op)) {
-          for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
-            // Check for calls a getter method for a variable in an interpreter wrapper object.
-            // This means that the corresponding field will be accessed, so we should save it.
-            if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) {
-              fields(cl) += name
-            }
-            // Visit other methods to find fields that are transitively referenced
-            // FIXME: This could lead to infinite cycles!!
-            if (findTransitively) {
-              ClosureCleaner.getClassReader(cl)
-                .accept(new FieldAccessFinder(fields, Set(name), findTransitively), 0)
+        for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
+          // Check for calls to a getter method for a variable in an interpreter wrapper object.
+          // This means that the corresponding field will be accessed, so we should save it.
+          if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) {
+            fields(cl) += name
+          }
+          // Visit other methods to find fields that are transitively referenced
+          if (findTransitively) {
+            val m = MethodIdentifier(cl, name, desc)
+            if (!visitedMethods.contains(m)) {
+              // Keep track of visited methods to avoid potential infinite cycles
+              visitedMethods += m
+              ClosureCleaner.getClassReader(cl).accept(
+                new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
             }
           }
         }
       }
-
-      private def isInvoke(op: Int): Boolean = {
-        op == INVOKEVIRTUAL ||
-        op == INVOKESPECIAL ||
-        op == INVOKEDYNAMIC ||
-        op == INVOKEINTERFACE ||
-        op == INVOKESTATIC
-      }
     }
   }
 }