Skip to content

Commit

Permalink
Use private method tester for a few things
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Or committed Apr 25, 2015
1 parent a3aa465 commit eb127e5
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ private[spark] object ClosureCleaner extends Logging {
}

// Check whether a class represents a Scala closure
private[util] def isClosure(cls: Class[_]): Boolean = {
private def isClosure(cls: Class[_]): Boolean = {
cls.getName.contains("$anonfun$")
}

Expand All @@ -55,7 +55,7 @@ private[spark] object ClosureCleaner extends Logging {
// for outer objects beyond that because cloning the user's object is probably
// not a good idea (whereas we can clone closure objects just fine since we
// understand how all their fields are used).
private[util] def getOuterClasses(obj: AnyRef): List[Class[_]] = {
private def getOuterClasses(obj: AnyRef): List[Class[_]] = {
for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
f.setAccessible(true)
val outer = f.get(obj)
Expand All @@ -72,7 +72,7 @@ private[spark] object ClosureCleaner extends Logging {
}

// Get a list of the outer objects for a given closure object.
private[util] def getOuterObjects(obj: AnyRef): List[AnyRef] = {
private def getOuterObjects(obj: AnyRef): List[AnyRef] = {
for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
f.setAccessible(true)
val outer = f.get(obj)
Expand All @@ -91,7 +91,7 @@ private[spark] object ClosureCleaner extends Logging {
/**
* Return a list of classes that represent closures enclosed in the given closure object.
*/
private[util] def getInnerClasses(obj: AnyRef): List[Class[_]] = {
private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
val seen = Set[Class[_]](obj.getClass)
var stack = List[Class[_]](obj.getClass)
while (!stack.isEmpty) {
Expand Down
102 changes: 62 additions & 40 deletions core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.io.NotSerializableException

import scala.collection.mutable

import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.{BeforeAndAfterAll, FunSuite, PrivateMethodTester}

import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.serializer.SerializerInstance
Expand All @@ -30,7 +30,7 @@ import org.apache.spark.serializer.SerializerInstance
* Another test suite for the closure cleaner that is finer-grained.
* For tests involving end-to-end Spark jobs, see {{ClosureCleanerSuite}}.
*/
class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll {
class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll with PrivateMethodTester {

// Start a SparkContext so that SparkEnv.get.closureSerializer is accessible
// We do not actually use this explicitly except to stop it later
Expand Down Expand Up @@ -111,6 +111,28 @@ class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll {
fields.mapValues(_.toSet).toMap
}

// Accessors for private methods
private val _isClosure = PrivateMethod[Boolean]('isClosure)
private val _getInnerClasses = PrivateMethod[List[Class[_]]]('getInnerClasses)
private val _getOuterClasses = PrivateMethod[List[Class[_]]]('getOuterClasses)
private val _getOuterObjects = PrivateMethod[List[AnyRef]]('getOuterObjects)

private def isClosure(obj: AnyRef): Boolean = {
ClosureCleaner invokePrivate _isClosure(obj)
}

private def getInnerClasses(closure: AnyRef): List[Class[_]] = {
ClosureCleaner invokePrivate _getInnerClasses(closure)
}

private def getOuterClasses(closure: AnyRef): List[Class[_]] = {
ClosureCleaner invokePrivate _getOuterClasses(closure)
}

private def getOuterObjects(closure: AnyRef): List[AnyRef] = {
ClosureCleaner invokePrivate _getOuterObjects(closure)
}

test("get inner classes") {
val closure1 = () => 1
val closure2 = () => { () => 1 }
Expand All @@ -124,17 +146,17 @@ class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll {
}
}
}
val inner1 = ClosureCleaner.getInnerClasses(closure1)
val inner2 = ClosureCleaner.getInnerClasses(closure2)
val inner3 = ClosureCleaner.getInnerClasses(closure3)
val inner4 = ClosureCleaner.getInnerClasses(closure4)
val inner1 = getInnerClasses(closure1)
val inner2 = getInnerClasses(closure2)
val inner3 = getInnerClasses(closure3)
val inner4 = getInnerClasses(closure4)
assert(inner1.isEmpty)
assert(inner2.size === 1)
assert(inner3.size === 2)
assert(inner4.size === 3)
assert(inner2.forall(ClosureCleaner.isClosure))
assert(inner3.forall(ClosureCleaner.isClosure))
assert(inner4.forall(ClosureCleaner.isClosure))
assert(inner2.forall(isClosure))
assert(inner3.forall(isClosure))
assert(inner4.forall(isClosure))
}

test("get outer classes and objects") {
Expand All @@ -143,14 +165,14 @@ class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll {
val closure2 = () => localValue
val closure3 = () => someSerializableValue
val closure4 = () => someSerializableMethod()
val outerClasses1 = ClosureCleaner.getOuterClasses(closure1)
val outerClasses2 = ClosureCleaner.getOuterClasses(closure2)
val outerClasses3 = ClosureCleaner.getOuterClasses(closure3)
val outerClasses4 = ClosureCleaner.getOuterClasses(closure4)
val outerObjects1 = ClosureCleaner.getOuterObjects(closure1)
val outerObjects2 = ClosureCleaner.getOuterObjects(closure2)
val outerObjects3 = ClosureCleaner.getOuterObjects(closure3)
val outerObjects4 = ClosureCleaner.getOuterObjects(closure4)
val outerClasses1 = getOuterClasses(closure1)
val outerClasses2 = getOuterClasses(closure2)
val outerClasses3 = getOuterClasses(closure3)
val outerClasses4 = getOuterClasses(closure4)
val outerObjects1 = getOuterObjects(closure1)
val outerObjects2 = getOuterObjects(closure2)
val outerObjects3 = getOuterObjects(closure3)
val outerObjects4 = getOuterObjects(closure4)

// The classes and objects should have the same size
assert(outerClasses1.size === outerObjects1.size)
Expand All @@ -167,8 +189,8 @@ class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll {
// The second $outer pointer refers to ClosureCleanerSuite2
assert(outerClasses3.size === 2)
assert(outerClasses4.size === 2)
assert(ClosureCleaner.isClosure(outerClasses3(0)))
assert(ClosureCleaner.isClosure(outerClasses4(0)))
assert(isClosure(outerClasses3(0)))
assert(isClosure(outerClasses4(0)))
assert(outerClasses3(0) === outerClasses4(0)) // part of the same "FunSuite#test" scope
assert(outerClasses3(1) === this.getClass)
assert(outerClasses4(1) === this.getClass)
Expand All @@ -183,10 +205,10 @@ class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll {
val x = 1
val closure1 = () => 1
val closure2 = () => x
val outerClasses1 = ClosureCleaner.getOuterClasses(closure1)
val outerClasses2 = ClosureCleaner.getOuterClasses(closure2)
val outerObjects1 = ClosureCleaner.getOuterObjects(closure1)
val outerObjects2 = ClosureCleaner.getOuterObjects(closure2)
val outerClasses1 = getOuterClasses(closure1)
val outerClasses2 = getOuterClasses(closure2)
val outerObjects1 = getOuterObjects(closure1)
val outerObjects2 = getOuterObjects(closure2)
assert(outerClasses1.size === outerObjects1.size)
assert(outerClasses2.size === outerObjects2.size)
// These inner closures only reference local variables, and so do not have $outer pointer
Expand All @@ -199,12 +221,12 @@ class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll {
val closure1 = () => 1
val closure2 = () => y
val closure3 = () => localValue
val outerClasses1 = ClosureCleaner.getOuterClasses(closure1)
val outerClasses2 = ClosureCleaner.getOuterClasses(closure2)
val outerClasses3 = ClosureCleaner.getOuterClasses(closure3)
val outerObjects1 = ClosureCleaner.getOuterObjects(closure1)
val outerObjects2 = ClosureCleaner.getOuterObjects(closure2)
val outerObjects3 = ClosureCleaner.getOuterObjects(closure3)
val outerClasses1 = getOuterClasses(closure1)
val outerClasses2 = getOuterClasses(closure2)
val outerClasses3 = getOuterClasses(closure3)
val outerObjects1 = getOuterObjects(closure1)
val outerObjects2 = getOuterObjects(closure2)
val outerObjects3 = getOuterObjects(closure3)
assert(outerClasses1.size === outerObjects1.size)
assert(outerClasses2.size === outerObjects2.size)
assert(outerClasses3.size === outerObjects3.size)
Expand All @@ -216,10 +238,10 @@ class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll {
// This closure references the "test2" scope because it needs to find the
// `localValue` defined outside of this scope
assert(outerClasses3.size === 3)
assert(ClosureCleaner.isClosure(outerClasses2(0)))
assert(ClosureCleaner.isClosure(outerClasses3(0)))
assert(ClosureCleaner.isClosure(outerClasses2(1)))
assert(ClosureCleaner.isClosure(outerClasses3(1)))
assert(isClosure(outerClasses2(0)))
assert(isClosure(outerClasses3(0)))
assert(isClosure(outerClasses2(1)))
assert(isClosure(outerClasses3(1)))
assert(outerClasses2(0) === outerClasses3(0)) // part of the same "test2" scope
assert(outerClasses2(1) === outerClasses3(1)) // part of the same "FunSuite#test" scope
assert(outerClasses2(2) === this.getClass)
Expand All @@ -237,9 +259,9 @@ class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll {
val closure1 = () => 1
val closure2 = () => localValue
val closure3 = () => someSerializableValue
val outerClasses1 = ClosureCleaner.getOuterClasses(closure1)
val outerClasses2 = ClosureCleaner.getOuterClasses(closure2)
val outerClasses3 = ClosureCleaner.getOuterClasses(closure3)
val outerClasses1 = getOuterClasses(closure1)
val outerClasses2 = getOuterClasses(closure2)
val outerClasses3 = getOuterClasses(closure3)

val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)
val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false)
Expand Down Expand Up @@ -277,10 +299,10 @@ class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll {
val closure2 = () => a
val closure3 = () => localValue
val closure4 = () => someSerializableValue
val outerClasses1 = ClosureCleaner.getOuterClasses(closure1)
val outerClasses2 = ClosureCleaner.getOuterClasses(closure2)
val outerClasses3 = ClosureCleaner.getOuterClasses(closure3)
val outerClasses4 = ClosureCleaner.getOuterClasses(closure4)
val outerClasses1 = getOuterClasses(closure1)
val outerClasses2 = getOuterClasses(closure2)
val outerClasses3 = getOuterClasses(closure3)
val outerClasses4 = getOuterClasses(closure4)

// First, find only fields the closures directly access
val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)
Expand Down

0 comments on commit eb127e5

Please sign in to comment.