Skip to content

Commit

Permalink
Guard against potential infinite cycles in method visitor
Browse files Browse the repository at this point in the history
Now we keep track of the methods that we visited to avoid visiting
the same method twice.
  • Loading branch information
Andrew Or committed Apr 25, 2015
1 parent 6d36f38 commit e672170
Showing 1 changed file with 27 additions and 27 deletions.
54 changes: 27 additions & 27 deletions core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ private[spark] object ClosureCleaner extends Logging {
// all of its inner closures. If transitive cleaning is enabled, this may recursively
// visits methods that belong to other classes in search of transitively referenced fields.
for (cls <- func.getClass :: innerClasses) {
getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0)
}
}

Expand Down Expand Up @@ -358,20 +358,25 @@ private class ReturnStatementFinder extends ClassVisitor(ASM4) {
}
}

/** Helper class to identify a method. */
private case class MethodIdentifier(cls: Class[_], name: String, desc: String)

/**
* Find the fields accessed by a given class.
*
* The fields are stored in the mutable map passed in by the class that contains them.
* This map is assumed to have its keys already populated with the classes of interest.
*
* @param fields the mutable map that stores the fields to return
* @param specificMethodNames if not empty, only visit methods whose names are in this set
* @param findTransitively if true, find fields indirectly referenced in other classes
* @param specificMethod if not empty, visit only this method
* @param visitedMethods a list of visited methods to avoid cycles
*/
private class FieldAccessFinder(
fields: Map[Class[_], Set[String]],
specificMethodNames: Set[String] = Set.empty,
findTransitively: Boolean = true)
findTransitively: Boolean,
specificMethod: Option[MethodIdentifier] = None,
visitedMethods: Set[MethodIdentifier] = Set.empty)
extends ClassVisitor(ASM4) {

override def visitMethod(
Expand All @@ -381,9 +386,10 @@ private class FieldAccessFinder(
sig: String,
exceptions: Array[String]): MethodVisitor = {

// Ignore this method if we don't want to visit it
if (specificMethodNames.nonEmpty && !specificMethodNames.contains(name)) {
return new MethodVisitor(ASM4) { }
// Ignore this method unless we are told to visit it
if (specificMethod.nonEmpty &&
specificMethod.get.name != name || specificMethod.get.desc != desc) {
return null
}

new MethodVisitor(ASM4) {
Expand All @@ -396,30 +402,24 @@ private class FieldAccessFinder(
}

override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
if (isInvoke(op)) {
for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
// Check for calls a getter method for a variable in an interpreter wrapper object.
// This means that the corresponding field will be accessed, so we should save it.
if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) {
fields(cl) += name
}
// Visit other methods to find fields that are transitively referenced
// FIXME: This could lead to infinite cycles!!
if (findTransitively) {
ClosureCleaner.getClassReader(cl)
.accept(new FieldAccessFinder(fields, Set(name), findTransitively), 0)
for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
// Check for calls a getter method for a variable in an interpreter wrapper object.
// This means that the corresponding field will be accessed, so we should save it.
if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) {
fields(cl) += name
}
// Visit other methods to find fields that are transitively referenced
if (findTransitively) {
val m = MethodIdentifier(cl, name, desc)
if (!visitedMethods.contains(m)) {
// Keep track of visited methods to avoid potential infinite cycles
visitedMethods += m
ClosureCleaner.getClassReader(cl).accept(
new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
}
}
}
}

private def isInvoke(op: Int): Boolean = {
op == INVOKEVIRTUAL ||
op == INVOKESPECIAL ||
op == INVOKEDYNAMIC ||
op == INVOKEINTERFACE ||
op == INVOKESTATIC
}
}
}
}
Expand Down

0 comments on commit e672170

Please sign in to comment.