diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 5dbdde02e0133..1779c16050eae 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -287,23 +287,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private[streaming] def scheduleReceivers(receivers: Seq[Receiver[_]], executors: List[String]): Array[ArrayBuffer[String]] = { val locations = new Array[ArrayBuffer[String]](receivers.length) - if (!executors.isEmpty) { - var i = 0 - for (i <- 0 until receivers.length) { - locations(i) = new ArrayBuffer[String]() - if (receivers(i).preferredLocation.isDefined) { - locations(i) += receivers(i).preferredLocation.get - } + var i = 0 + for (i <- 0 until receivers.length) { + locations(i) = new ArrayBuffer[String]() + if (receivers(i).preferredLocation.isDefined) { + locations(i) += receivers(i).preferredLocation.get } - - var count = 0; - for (i <- 0 until max(receivers.length, executors.length)) { - if (!receivers(i % receivers.length).preferredLocation.isDefined) { - locations(i % receivers.length) += executors(count) - count += 1; - if (count == executors.length) { - count = 0; - } + } + var count = 0 + for (i <- 0 until max(receivers.length, executors.length)) { + if (!receivers(i % receivers.length).preferredLocation.isDefined) { + locations(i % receivers.length) += executors(count) + count += 1 + if (count == executors.length) { + count = 0 } } } @@ -345,9 +342,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Get the list of executors and schedule receivers val executors = getExecutors(ssc) - val locations = scheduleReceivers(receivers, executors) val tempRDD = - if (locations(0) != null) { + if (!executors.isEmpty) { + val locations = scheduleReceivers(receivers, executors) val roundRobinReceivers = (0 until receivers.length).map(i => (receivers(i), locations(i))) ssc.sc.makeRDD[Receiver[_]](roundRobinReceivers) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index 988135390e0f1..864dcf0d33d84 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -29,50 +29,45 @@ class ReceiverTrackerSuite extends TestSuiteBase { val ssc = new StreamingContext(sparkConf, Milliseconds(100)) val tracker = new ReceiverTracker(ssc) val launcher = new tracker.ReceiverLauncher() + val executors: List[String] = List("0", "1", "2", "3") - test("receiver scheduling - no preferred location") { - val numReceivers = 10; - val receivers = (1 to numReceivers).map(i => new DummyReceiver) - val executors: List[String] = List("Host1", "Host2", "Host3", "Host4", "Host5") - val locations = launcher.scheduleReceivers(receivers, executors) - assert(locations(0)(0) === "Host1") - assert(locations(4)(0) === "Host5") - assert(locations(5)(0) === "Host1") - assert(locations(9)(0) === "Host5") - } - - test("receiver scheduling - no preferred location, numExecutors > numReceivers") { - val numReceivers = 3; - val receivers = (1 to numReceivers).map(i => new DummyReceiver) - val executors: List[String] = List("Host1", "Host2", "Host3", "Host4", "Host5") - val locations = launcher.scheduleReceivers(receivers, executors) - assert(locations(0)(0) === "Host1") - assert(locations(2)(0) === "Host3") - assert(locations(0)(1) === "Host4") - assert(locations(1)(1) === "Host5") - } - - test("receiver scheduling - all have preferred location") { - val numReceivers = 5; - val receivers = (1 to numReceivers).map(i => new DummyReceiver(host = Some("Host" + i))) - val executors: List[String] = List("Host1", "Host5", "Host4", "Host3", "Host2") - val locations = launcher.scheduleReceivers(receivers, executors) - assert(locations(1)(0) === "Host2") - assert(locations(4)(0) === "Host5") + test("receiver scheduling - all or none have preferred location") { + def parse(s: String): Array[Array[String]] = { + val outerSplit = s.split("\\|") + val loc = new Array[Array[String]](outerSplit.length) + var i = 0 + for (i <- 0 until outerSplit.length) { + loc(i) = outerSplit(i).split("\\,") + } + loc + } + def testScheduler(numReceivers: Int, preferredLocation: Boolean, allocation: String) { + val receivers = + if (preferredLocation) { + (0 until numReceivers).map(i => new DummyReceiver(host = + Some(((i + 1) % executors.length).toString))) + } else { + (0 until numReceivers).map(i => new DummyReceiver) + } + val locations = launcher.scheduleReceivers(receivers, executors) + val expectedLocations = parse(allocation) + assert(locations.deep === expectedLocations.deep) + } + testScheduler(numReceivers = 5, preferredLocation = false, allocation = "0|1|2|3|0") + testScheduler(numReceivers = 3, preferredLocation = false, allocation = "0,3|1|2") + testScheduler(numReceivers = 4, preferredLocation = true, allocation = "1|2|3|0") } test("receiver scheduling - some have preferred location") { - val numReceivers = 3; - val receivers: Seq[Receiver[_]] = Seq( - new DummyReceiver(host = Some("Host2")), - new DummyReceiver, - new DummyReceiver) - val executors: List[String] = List("Host1", "Host2", "Host3", "Host4", "Host5") + val numReceivers = 4; + val receivers: Seq[Receiver[_]] = Seq(new DummyReceiver(host = Some("1")), + new DummyReceiver, new DummyReceiver, new DummyReceiver) val locations = launcher.scheduleReceivers(receivers, executors) - assert(locations(0)(0) === "Host2") - assert(locations(1)(0) === "Host1") - assert(locations(2)(0) === "Host2") - assert(locations(1)(1) === "Host3") + assert(locations(0)(0) === "1") + assert(locations(1)(0) === "0") + assert(locations(2)(0) === "1") + assert(locations(0).length === 1) + assert(locations(3).length === 1) } }