diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index cca8cedb1d080..6c0c926755c20 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -49,7 +49,6 @@ class ReceivedBlockHandlerSuite
val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1")
val hadoopConf = new Configuration()
- val storageLevel = StorageLevel.MEMORY_ONLY_SER
val streamId = 1
val securityMgr = new SecurityManager(conf)
val mapOutputTracker = new MapOutputTrackerMaster(conf)
@@ -57,10 +56,12 @@ class ReceivedBlockHandlerSuite
val serializer = new KryoSerializer(conf)
val manualClock = new ManualClock
val blockManagerSize = 10000000
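+  // Every BlockManager created by a test (via createBlockManager), so that `after` can stop them all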
+ val blockManagerBuffer = new ArrayBuffer[BlockManager]()
var rpcEnv: RpcEnv = null
var blockManagerMaster: BlockManagerMaster = null
var blockManager: BlockManager = null
+ var storageLevel: StorageLevel = null
var tempDirectory: File = null
before {
@@ -70,20 +71,21 @@ class ReceivedBlockHandlerSuite
blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true)
- blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer,
- blockManagerSize, conf, mapOutputTracker, shuffleManager,
- new NioBlockTransferService(conf, securityMgr), securityMgr, 0)
- blockManager.initialize("app-id")
+ storageLevel = StorageLevel.MEMORY_ONLY_SER
+ blockManager = createBlockManager(blockManagerSize, conf)
tempDirectory = Utils.createTempDir()
manualClock.setTime(0)
}
after {
- if (blockManager != null) {
- blockManager.stop()
- blockManager = null
+ for (blockManager <- blockManagerBuffer) {
+ if (blockManager != null) {
+ blockManager.stop()
+ }
}
+ blockManager = null
+ blockManagerBuffer.clear()
if (blockManagerMaster != null) {
blockManagerMaster.stop()
blockManagerMaster = null
@@ -174,6 +176,130 @@ class ReceivedBlockHandlerSuite
}
}
+ test("Test Block - count messages") {
+ // Test count with BlockManagerBasedBlockHandler
+ testCountWithBlockManagerBasedBlockHandler(true)
+ // Test count with WriteAheadLogBasedBlockHandler
+ testCountWithBlockManagerBasedBlockHandler(false)
+ }
+
+ test("Test Block - isFullyConsumed") {
+ val sparkConf = new SparkConf()
+ sparkConf.set("spark.storage.unrollMemoryThreshold", "512")
+ // spark.storage.unrollFraction set to 0.4 for BlockManager
+ sparkConf.set("spark.storage.unrollFraction", "0.4")
+ // Block Manager with 12000 * 0.4 = 4800 bytes of free space for unroll
+ blockManager = createBlockManager(12000, sparkConf)
+
+ // There is not enough space to store this block in MEMORY,
+ // but BlockManager will be able to serialize this block to WAL
+ // and hence count returns the correct value.
+ testRecordcount(false, StorageLevel.MEMORY_ONLY,
+ IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70))
+
+ // There is not enough space to store this block in MEMORY,
+ // but BlockManager will be able to serialize this block to DISK
+ // and hence count returns the correct value.
+ testRecordcount(true, StorageLevel.MEMORY_AND_DISK,
+ IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70))
+
+ // There is not enough space to store this block with the MEMORY_ONLY StorageLevel.
+ // BlockManager will not be able to unroll this block,
+ // and hence it will not tryToPut this block, resulting in a SparkException.
+ storageLevel = StorageLevel.MEMORY_ONLY
+ withBlockManagerBasedBlockHandler { handler =>
+ val thrown = intercept[SparkException] {
+ storeSingleBlock(handler, IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator))
+ }
+ }
+ }
+
+ private def testCountWithBlockManagerBasedBlockHandler(isBlockManagerBasedBlockHandler: Boolean) {
+ // ByteBufferBlock-MEMORY_ONLY
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY,
+ ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None)
+ // ByteBufferBlock-MEMORY_ONLY_SER
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER,
+ ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None)
+ // ArrayBufferBlock-MEMORY_ONLY
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY,
+ ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25))
+ // ArrayBufferBlock-MEMORY_ONLY_SER
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER,
+ ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25))
+ // ArrayBufferBlock-DISK_ONLY
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY,
+ ArrayBufferBlock(ArrayBuffer.fill(50)(0)), blockManager, Some(50))
+ // ArrayBufferBlock-MEMORY_AND_DISK
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK,
+ ArrayBufferBlock(ArrayBuffer.fill(75)(0)), blockManager, Some(75))
+ // IteratorBlock-MEMORY_ONLY
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY,
+ IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100))
+ // IteratorBlock-MEMORY_ONLY_SER
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER,
+ IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100))
+ // IteratorBlock-DISK_ONLY
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY,
+ IteratorBlock((ArrayBuffer.fill(125)(0)).iterator), blockManager, Some(125))
+ // IteratorBlock-MEMORY_AND_DISK
+ testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK,
+ IteratorBlock((ArrayBuffer.fill(150)(0)).iterator), blockManager, Some(150))
+ }
+
+ private def createBlockManager(
+ maxMem: Long,
+ conf: SparkConf,
+ name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
+ val transfer = new NioBlockTransferService(conf, securityMgr)
+ val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf,
+ mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
+ manager.initialize("app-id")
+ blockManagerBuffer += manager
+ manager
+ }
+
+ /**
+ * Test storing of data using different types of handlers, storage levels and ReceivedBlocks,
+ * and verify the correct record count
+ */
+ private def testRecordcount(isBlockManagerBasedBlockHandler: Boolean,
+ sLevel: StorageLevel,
+ receivedBlock: ReceivedBlock,
+ bManager: BlockManager,
+ expectedNumRecords: Option[Long]
+ ) {
+ blockManager = bManager
+ storageLevel = sLevel
+ var bId: StreamBlockId = null
+ try {
+ if (isBlockManagerBasedBlockHandler) {
+ // test received block with BlockManager based handler
+ withBlockManagerBasedBlockHandler { handler =>
+ val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock)
+ bId = blockId
+ assert(blockStoreResult.numRecords === expectedNumRecords,
+ "Message count not matches for a " +
+ receivedBlock.getClass.getName +
+ " being inserted using BlockManagerBasedBlockHandler with " + sLevel)
+ }
+ } else {
+ // test received block with WAL based handler
+ withWriteAheadLogBasedBlockHandler { handler =>
+ val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock)
+ bId = blockId
+ assert(blockStoreResult.numRecords === expectedNumRecords,
+ "Message count not matches for a " +
+ receivedBlock.getClass.getName +
+ " being inserted using WriteAheadLogBasedBlockHandler with " + sLevel)
+ }
+ }
+ } finally {
+ // Remove the block so that the same blockManager can be reused by the next test
+ blockManager.removeBlock(bId, true)
+ }
+ }
+
/**
* Test storing of data using different forms of ReceivedBlocks and verify that they succeeded
* using the given verification function
@@ -251,9 +377,21 @@ class ReceivedBlockHandlerSuite
(blockIds, storeResults)
}
+ /** Store single block using a handler */
+ private def storeSingleBlock(
+ handler: ReceivedBlockHandler,
+ block: ReceivedBlock
+ ): (StreamBlockId, ReceivedBlockStoreResult) = {
+ val blockId = generateBlockId
+ val blockStoreResult = handler.storeBlock(blockId, block)
+ logDebug("Done inserting")
+ (blockId, blockStoreResult)
+ }
+
private def getWriteAheadLogFiles(): Seq[String] = {
getLogFilesInDirectory(checkpointDirToLogDir(tempDirectory.toString, streamId))
}
private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong)
}
+
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
index be305b5e0dfea..f793a12843b2f 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -225,7 +225,7 @@ class ReceivedBlockTrackerSuite
/** Generate blocks infos using random ids */
def generateBlockInfos(): Seq[ReceivedBlockInfo] = {
List.fill(5)(ReceivedBlockInfo(streamId, Some(0L), None,
- BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)))))
+ BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)), Some(0L))))
}
/** Get all the data written in the given write ahead log file. */
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
index 1dc8960d60528..7bc7727a9fbe4 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -116,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
ssc.start()
try {
- eventually(timeout(2000 millis), interval(20 millis)) {
+ eventually(timeout(30 seconds), interval(20 millis)) {
collector.startedReceiverStreamIds.size should equal (1)
collector.startedReceiverStreamIds(0) should equal (0)
collector.stoppedReceiverStreamIds should have size 1
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
index cbc24aee4fa1e..a08578680cff9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
@@ -27,9 +27,10 @@ import org.scalatest.selenium.WebBrowser
import org.scalatest.time.SpanSugar._
import org.apache.spark._
+import org.apache.spark.ui.SparkUICssErrorHandler
/**
- * Selenium tests for the Spark Web UI.
+ * Selenium tests for the Spark Streaming Web UI.
*/
class UISeleniumSuite
extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase {
@@ -37,7 +38,9 @@ class UISeleniumSuite
implicit var webDriver: WebDriver = _
override def beforeAll(): Unit = {
- webDriver = new HtmlUnitDriver
+ webDriver = new HtmlUnitDriver {
+ getWebClient.setCssErrorHandler(new SparkUICssErrorHandler)
+ }
}
override def afterAll(): Unit = {
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
index 62c6354f1e203..33782c6c66f90 100644
--- a/unsafe/pom.xml
+++ b/unsafe/pom.xml
@@ -67,7 +67,7 @@
<groupId>org.mockito</groupId>
- <artifactId>mockito-all</artifactId>
+ <artifactId>mockito-core</artifactId>
<scope>test</scope>
@@ -80,7 +80,7 @@
net.alchim31.maven
scala-maven-plugin
-
+
-XDignore.symbol.file
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
new file mode 100644
index 0000000000000..9302b472925ed
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.types;
+
+import javax.annotation.Nonnull;
+import java.io.Serializable;
+import java.io.UnsupportedEncodingException;
+import java.util.Arrays;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * A UTF-8 String for internal Spark use.
+ *
+ * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison and
+ * search; see http://en.wikipedia.org/wiki/UTF-8 for details.
+ *
+ * Note: This is not designed for general use and should not be used outside SQL.
+ */
+public final class UTF8String implements Comparable<UTF8String>, Serializable {
+
+ @Nonnull
+ private byte[] bytes;
+
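+  // Number of bytes for a UTF-8 code point, indexed by (first byte & 0xFF) - 192.
+  // Covers lead bytes 0xC0-0xFF; the 5- and 6-byte entries match byte patterns that are
+  // invalid in modern UTF-8 but are kept so that malformed input does not break iteration.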
+ private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+ 4, 4, 4, 4, 4, 4, 4, 4,
+ 5, 5, 5, 5,
+ 6, 6, 6, 6};
+
+ public static UTF8String fromBytes(byte[] bytes) {
+ return (bytes != null) ? new UTF8String().set(bytes) : null;
+ }
+
+ public static UTF8String fromString(String str) {
+ return (str != null) ? new UTF8String().set(str) : null;
+ }
+
+ /**
+ * Updates the UTF8String with String.
+ */
+ protected UTF8String set(final String str) {
+ try {
+ bytes = str.getBytes("utf-8");
+ } catch (UnsupportedEncodingException e) {
+ // Turn the exception into unchecked so we can find out about it at runtime, but
+ // don't need to add lots of boilerplate code everywhere.
+ PlatformDependent.throwException(e);
+ }
+ return this;
+ }
+
+ /**
+ * Updates the UTF8String with byte[], which should be encoded in UTF-8.
+ */
+ protected UTF8String set(final byte[] bytes) {
+ this.bytes = bytes;
+ return this;
+ }
+
+ /**
+ * Returns the number of bytes of the code point whose first byte is `b`.
+ * @param b The first byte of a code point
+ */
+ public int numBytes(final byte b) {
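+    // Lead bytes of multi-byte code points are >= 0xC0 (192); anything smaller (ASCII or a
+    // UTF-8 continuation byte) is counted as a single byte.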
+ final int offset = (b & 0xFF) - 192;
+ return (offset >= 0) ? bytesOfCodePointInUTF8[offset] : 1;
+ }
+
+ /**
+ * Returns the number of code points in it.
+ *
+ * This is only used by Substring() when `start` is negative.
+ */
+ public int length() {
+ int len = 0;
+ for (int i = 0; i < bytes.length; i += numBytes(bytes[i])) {
+ len += 1;
+ }
+ return len;
+ }
+
+ public byte[] getBytes() {
+ return bytes;
+ }
+
+ /**
+ * Returns a substring of this.
+ * @param start the position of the first code point
+ * @param until the position after the last code point, exclusive.
+ */
+ public UTF8String substring(final int start, final int until) {
+ if (until <= start || start >= bytes.length) {
+ return UTF8String.fromBytes(new byte[0]);
+ }
+
+ int i = 0;
+ int c = 0;
+ for (; i < bytes.length && c < start; i += numBytes(bytes[i])) {
+ c += 1;
+ }
+
+ int j = i;
+ for (; j < bytes.length && c < until; j += numBytes(bytes[j])) {
+ c += 1;
+ }
+
+ return UTF8String.fromBytes(Arrays.copyOfRange(bytes, i, j));
+ }
+
+ public boolean contains(final UTF8String substring) {
+ final byte[] b = substring.getBytes();
+ if (b.length == 0) {
+ return true;
+ }
+
+ for (int i = 0; i <= bytes.length - b.length; i++) {
+ if (bytes[i] == b[0] && startsWith(b, i)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private boolean startsWith(final byte[] prefix, int offsetInBytes) {
+ if (prefix.length + offsetInBytes > bytes.length || offsetInBytes < 0) {
+ return false;
+ }
+ int i = 0;
+ while (i < prefix.length && prefix[i] == bytes[i + offsetInBytes]) {
+ i++;
+ }
+ return i == prefix.length;
+ }
+
+ public boolean startsWith(final UTF8String prefix) {
+ return startsWith(prefix.getBytes(), 0);
+ }
+
+ public boolean endsWith(final UTF8String suffix) {
+ return startsWith(suffix.getBytes(), bytes.length - suffix.getBytes().length);
+ }
+
+ public UTF8String toUpperCase() {
+ return UTF8String.fromString(toString().toUpperCase());
+ }
+
+ public UTF8String toLowerCase() {
+ return UTF8String.fromString(toString().toLowerCase());
+ }
+
+ @Override
+ public String toString() {
+ try {
+ return new String(bytes, "utf-8");
+ } catch (UnsupportedEncodingException e) {
+ // Turn the exception into unchecked so we can find out about it at runtime, but
+ // don't need to add lots of boilerplate code everywhere.
+ PlatformDependent.throwException(e);
+ return "unknown"; // we will never reach here.
+ }
+ }
+
+ @Override
+ public UTF8String clone() {
+ return new UTF8String().set(bytes);
+ }
+
+ @Override
+ public int compareTo(final UTF8String other) {
+ final byte[] b = other.getBytes();
+ for (int i = 0; i < bytes.length && i < b.length; i++) {
+ // Compare as unsigned bytes so that non-ASCII (high-bit) bytes sort after ASCII.
+ int res = (bytes[i] & 0xFF) - (b[i] & 0xFF);
+ if (res != 0) {
+ return res;
+ }
+ }
+ return bytes.length - b.length;
+ }
+
+ public int compare(final UTF8String other) {
+ return compareTo(other);
+ }
+
+ @Override
+ public boolean equals(final Object other) {
+ if (other instanceof UTF8String) {
+ return Arrays.equals(bytes, ((UTF8String) other).getBytes());
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(bytes);
+ }
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java
index 18393db9f382f..a93fc0ee297c4 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java
@@ -18,7 +18,6 @@
package org.apache.spark.unsafe.bitset;
import junit.framework.Assert;
-import org.apache.spark.unsafe.bitset.BitSet;
import org.junit.Test;
import org.apache.spark.unsafe.memory.MemoryBlock;
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
new file mode 100644
index 0000000000000..796cdc9dbebdb
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -0,0 +1,91 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.unsafe.types;
+
+import java.io.UnsupportedEncodingException;
+
+import junit.framework.Assert;
+import org.junit.Test;
+
+public class UTF8StringSuite {
+
+ private void checkBasic(String str, int len) throws UnsupportedEncodingException {
+ Assert.assertEquals(UTF8String.fromString(str).length(), len);
+ Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).length(), len);
+
+ Assert.assertEquals(UTF8String.fromString(str).toString(), str);
+ Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).toString(), str);
+ Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), UTF8String.fromString(str));
+
+ Assert.assertEquals(UTF8String.fromString(str).hashCode(),
+ UTF8String.fromBytes(str.getBytes("utf8")).hashCode());
+ }
+
+ @Test
+ public void basicTest() throws UnsupportedEncodingException {
+ checkBasic("hello", 5);
+ checkBasic("世 界", 3);
+ }
+
+ @Test
+ public void contains() {
+ Assert.assertTrue(UTF8String.fromString("hello").contains(UTF8String.fromString("ello")));
+ Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("vello")));
+ Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("hellooo")));
+ Assert.assertTrue(UTF8String.fromString("大千世界").contains(UTF8String.fromString("千世")));
+ Assert.assertFalse(UTF8String.fromString("大千世界").contains(UTF8String.fromString("世千")));
+ Assert.assertFalse(
+ UTF8String.fromString("大千世界").contains(UTF8String.fromString("大千世界好")));
+ }
+
+ @Test
+ public void startsWith() {
+ Assert.assertTrue(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hell")));
+ Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("ell")));
+ Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hellooo")));
+ Assert.assertTrue(UTF8String.fromString("数据砖头").startsWith(UTF8String.fromString("数据")));
+ Assert.assertFalse(UTF8String.fromString("大千世界").startsWith(UTF8String.fromString("千")));
+ Assert.assertFalse(
+ UTF8String.fromString("大千世界").startsWith(UTF8String.fromString("大千世界好")));
+ }
+
+ @Test
+ public void endsWith() {
+ Assert.assertTrue(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ello")));
+ Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ellov")));
+ Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("hhhello")));
+ Assert.assertTrue(UTF8String.fromString("大千世界").endsWith(UTF8String.fromString("世界")));
+ Assert.assertFalse(UTF8String.fromString("大千世界").endsWith(UTF8String.fromString("世")));
+ Assert.assertFalse(
+ UTF8String.fromString("数据砖头").endsWith(UTF8String.fromString("我的数据砖头")));
+ }
+
+ @Test
+ public void substring() {
+ Assert.assertEquals(
+ UTF8String.fromString("hello").substring(0, 0), UTF8String.fromString(""));
+ Assert.assertEquals(
+ UTF8String.fromString("hello").substring(1, 3), UTF8String.fromString("el"));
+ Assert.assertEquals(
+ UTF8String.fromString("数据砖头").substring(0, 1), UTF8String.fromString("数"));
+ Assert.assertEquals(
+ UTF8String.fromString("数据砖头").substring(1, 3), UTF8String.fromString("据砖"));
+ Assert.assertEquals(
+ UTF8String.fromString("数据砖头").substring(3, 5), UTF8String.fromString("头"));
+ }
+}
diff --git a/yarn/pom.xml b/yarn/pom.xml
index 644def7501dc8..2aeed98285aa8 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -107,7 +107,7 @@
<groupId>org.mockito</groupId>
- <artifactId>mockito-all</artifactId>
+ <artifactId>mockito-core</artifactId>
<scope>test</scope>
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 002d7b6eaf498..83dafa4a125d2 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -32,7 +32,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.spark.rpc._
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.SparkException
-import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil}
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
@@ -46,6 +46,14 @@ private[spark] class ApplicationMaster(
client: YarnRMClient)
extends Logging {
+ // Load the properties file with the Spark configuration and set entries as system properties,
+ // so that user code run inside the AM also has access to them.
+ if (args.propertiesFile != null) {
+ Utils.getPropertiesFromFile(args.propertiesFile).foreach { case (k, v) =>
+ sys.props(k) = v
+ }
+ }
+
// TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
// optimal as more containers are available. Might need to handle this better.
@@ -490,9 +498,11 @@ private[spark] class ApplicationMaster(
new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
}
+ var userArgs = args.userArgs
if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) {
- System.setProperty("spark.submit.pyFiles",
- PythonRunner.formatPaths(args.pyFiles).mkString(","))
+ // When running pyspark, the app is run using PythonRunner. The second argument is the list
+ // of files to add to PYTHONPATH, which Client.scala already handles, so it's empty.
+ userArgs = Seq(args.primaryPyFile, "") ++ userArgs
}
if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) {
// TODO(davies): add R dependencies here
@@ -503,9 +513,7 @@ private[spark] class ApplicationMaster(
val userThread = new Thread {
override def run() {
try {
- val mainArgs = new Array[String](args.userArgs.size)
- args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size)
- mainMethod.invoke(null, mainArgs)
+ mainMethod.invoke(null, userArgs.toArray)
finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
logDebug("Done running users class")
} catch {
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
index ae6dc1094d724..68e9f6b4db7f4 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -26,11 +26,11 @@ class ApplicationMasterArguments(val args: Array[String]) {
var userClass: String = null
var primaryPyFile: String = null
var primaryRFile: String = null
- var pyFiles: String = null
- var userArgs: Seq[String] = Seq[String]()
+ var userArgs: Seq[String] = Nil
var executorMemory = 1024
var executorCores = 1
var numExecutors = DEFAULT_NUMBER_EXECUTORS
+ var propertiesFile: String = null
parseArgs(args.toList)
@@ -59,10 +59,6 @@ class ApplicationMasterArguments(val args: Array[String]) {
primaryRFile = value
args = tail
- case ("--py-files") :: value :: tail =>
- pyFiles = value
- args = tail
-
case ("--args" | "--arg") :: value :: tail =>
userArgsBuffer += value
args = tail
@@ -79,6 +75,10 @@ class ApplicationMasterArguments(val args: Array[String]) {
executorCores = value
args = tail
+ case ("--properties-file") :: value :: tail =>
+ propertiesFile = value
+ args = tail
+
case _ =>
printUsageAndExit(1, args)
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 234051eb7d3bb..67a5c95400e53 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -17,18 +17,21 @@
package org.apache.spark.deploy.yarn
-import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException}
+import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException,
+ OutputStreamWriter}
import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException}
import java.nio.ByteBuffer
import java.security.PrivilegedExceptionAction
-import java.util.UUID
+import java.util.{Properties, UUID}
import java.util.zip.{ZipEntry, ZipOutputStream}
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map}
import scala.reflect.runtime.universe
import scala.util.{Try, Success, Failure}
+import scala.util.control.NonFatal
+import com.google.common.base.Charsets.UTF_8
import com.google.common.base.Objects
import com.google.common.io.Files
@@ -121,24 +124,31 @@ private[spark] class Client(
} catch {
case e: Throwable =>
if (appId != null) {
- val appStagingDir = getAppStagingDir(appId)
- try {
- val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false)
- val stagingDirPath = new Path(appStagingDir)
- val fs = FileSystem.get(hadoopConf)
- if (!preserveFiles && fs.exists(stagingDirPath)) {
- logInfo("Deleting staging directory " + stagingDirPath)
- fs.delete(stagingDirPath, true)
- }
- } catch {
- case ioe: IOException =>
- logWarning("Failed to cleanup staging dir " + appStagingDir, ioe)
- }
+ cleanupStagingDir(appId)
}
throw e
}
}
+ /**
+ * Cleanup application staging directory.
+ */
+ private def cleanupStagingDir(appId: ApplicationId): Unit = {
+ val appStagingDir = getAppStagingDir(appId)
+ try {
+ val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false)
+ val stagingDirPath = new Path(appStagingDir)
+ val fs = FileSystem.get(hadoopConf)
+ if (!preserveFiles && fs.exists(stagingDirPath)) {
+ logInfo("Deleting staging directory " + stagingDirPath)
+ fs.delete(stagingDirPath, true)
+ }
+ } catch {
+ case ioe: IOException =>
+ logWarning("Failed to cleanup staging dir " + appStagingDir, ioe)
+ }
+ }
+
/**
* Set up the context for submitting our ApplicationMaster.
* This uses the YarnClientApplication not available in the Yarn alpha API.
@@ -240,7 +250,9 @@ private[spark] class Client(
* This is used for setting up a container launch context for our ApplicationMaster.
* Exposed for testing.
*/
- def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = {
+ def prepareLocalResources(
+ appStagingDir: String,
+ pySparkArchives: Seq[String]): HashMap[String, LocalResource] = {
logInfo("Preparing resources for our AM container")
// Upload Spark and the application JAR to the remote file system if necessary,
// and add them as local resources to the application master.
@@ -270,20 +282,6 @@ private[spark] class Client(
"for alternatives.")
}
- // If we passed in a keytab, make sure we copy the keytab to the staging directory on
- // HDFS, and setup the relevant environment vars, so the AM can login again.
- if (loginFromKeytab) {
- logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" +
- " via the YARN Secure Distributed Cache.")
- val localUri = new URI(args.keytab)
- val localPath = getQualifiedLocalPath(localUri, hadoopConf)
- val destinationPath = copyFileToRemote(dst, localPath, replication)
- val destFs = FileSystem.get(destinationPath.toUri(), hadoopConf)
- distCacheMgr.addResource(
- destFs, hadoopConf, destinationPath, localResources, LocalResourceType.FILE,
- sparkConf.get("spark.yarn.keytab"), statCache, appMasterOnly = true)
- }
-
def addDistributedUri(uri: URI): Boolean = {
val uriStr = uri.toString()
if (distributedUris.contains(uriStr)) {
@@ -295,6 +293,57 @@ private[spark] class Client(
}
}
+ /**
+ * Distribute a file to the cluster.
+ *
+ * If the file's path is a "local:" URI, it's actually not distributed. Other files are copied
+ * to HDFS (if not already there) and added to the application's distributed cache.
+ *
+ * @param path URI of the file to distribute.
+ * @param resType Type of resource being distributed.
+ * @param destName Name of the file in the distributed cache.
+ * @param targetDir Subdirectory where to place the file.
+ * @param appMasterOnly Whether to distribute only to the AM.
+ * @return A 2-tuple. First item is whether the file is a "local:" URI. Second item is the
+ * localized path for non-local paths, or the input `path` for local paths.
+ * The localized path will be null if the URI has already been added to the cache.
+ */
+ def distribute(
+ path: String,
+ resType: LocalResourceType = LocalResourceType.FILE,
+ destName: Option[String] = None,
+ targetDir: Option[String] = None,
+ appMasterOnly: Boolean = false): (Boolean, String) = {
+ val localURI = new URI(path.trim())
+ if (localURI.getScheme != LOCAL_SCHEME) {
+ if (addDistributedUri(localURI)) {
+ val localPath = getQualifiedLocalPath(localURI, hadoopConf)
+ val linkname = targetDir.map(_ + "/").getOrElse("") +
+ destName.orElse(Option(localURI.getFragment())).getOrElse(localPath.getName())
+ val destPath = copyFileToRemote(dst, localPath, replication)
+ distCacheMgr.addResource(
+ fs, hadoopConf, destPath, localResources, resType, linkname, statCache,
+ appMasterOnly = appMasterOnly)
+ (false, linkname)
+ } else {
+ (false, null)
+ }
+ } else {
+ (true, path.trim())
+ }
+ }
+
+ // If we passed in a keytab, make sure we copy the keytab to the staging directory on
+ // HDFS, and setup the relevant environment vars, so the AM can login again.
+ if (loginFromKeytab) {
+ logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" +
+ " via the YARN Secure Distributed Cache.")
+ val (_, localizedPath) = distribute(args.keytab,
+ destName = Some(sparkConf.get("spark.yarn.keytab")),
+ appMasterOnly = true)
+ require(localizedPath != null, "Keytab file already distributed.")
+ }
+
/**
* Copy the given main resource to the distributed cache if the scheme is not "local".
* Otherwise, set the corresponding key in our SparkConf to handle it downstream.
@@ -307,33 +356,18 @@ private[spark] class Client(
(SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR),
(APP_JAR, args.userJar, CONF_SPARK_USER_JAR),
("log4j.properties", oldLog4jConf.orNull, null)
- ).foreach { case (destName, _localPath, confKey) =>
- val localPath: String = if (_localPath != null) _localPath.trim() else ""
- if (!localPath.isEmpty()) {
- val localURI = new URI(localPath)
- if (localURI.getScheme != LOCAL_SCHEME) {
- if (addDistributedUri(localURI)) {
- val src = getQualifiedLocalPath(localURI, hadoopConf)
- val destPath = copyFileToRemote(dst, src, replication)
- val destFs = FileSystem.get(destPath.toUri(), hadoopConf)
- distCacheMgr.addResource(destFs, hadoopConf, destPath,
- localResources, LocalResourceType.FILE, destName, statCache)
- }
- } else if (confKey != null) {
+ ).foreach { case (destName, path, confKey) =>
+ if (path != null && !path.trim().isEmpty()) {
+ val (isLocal, localizedPath) = distribute(path, destName = Some(destName))
+ if (isLocal && confKey != null) {
+ require(localizedPath != null, s"Path $path already distributed.")
// If the resource is intended for local use only, handle this downstream
// by setting the appropriate property
- sparkConf.set(confKey, localPath)
+ sparkConf.set(confKey, localizedPath)
}
}
}
- createConfArchive().foreach { file =>
- require(addDistributedUri(file.toURI()))
- val destPath = copyFileToRemote(dst, new Path(file.toURI()), replication)
- distCacheMgr.addResource(fs, hadoopConf, destPath, localResources, LocalResourceType.ARCHIVE,
- LOCALIZED_HADOOP_CONF_DIR, statCache, appMasterOnly = true)
- }
-
/**
* Do the same for any additional resources passed in through ClientArguments.
* Each resource category is represented by a 3-tuple of:
@@ -349,21 +383,10 @@ private[spark] class Client(
).foreach { case (flist, resType, addToClasspath) =>
if (flist != null && !flist.isEmpty()) {
flist.split(',').foreach { file =>
- val localURI = new URI(file.trim())
- if (localURI.getScheme != LOCAL_SCHEME) {
- if (addDistributedUri(localURI)) {
- val localPath = new Path(localURI)
- val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
- val destPath = copyFileToRemote(dst, localPath, replication)
- distCacheMgr.addResource(
- fs, hadoopConf, destPath, localResources, resType, linkname, statCache)
- if (addToClasspath) {
- cachedSecondaryJarLinks += linkname
- }
- }
- } else if (addToClasspath) {
- // Resource is intended for local use only and should be added to the class path
- cachedSecondaryJarLinks += file.trim()
+ val (_, localizedPath) = distribute(file, resType = resType)
+ require(localizedPath != null)
+ if (addToClasspath) {
+ cachedSecondaryJarLinks += localizedPath
}
}
}
@@ -372,11 +395,31 @@ private[spark] class Client(
sparkConf.set(CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(","))
}
+ if (isClusterMode && args.primaryPyFile != null) {
+ distribute(args.primaryPyFile, appMasterOnly = true)
+ }
+
+ pySparkArchives.foreach { f => distribute(f) }
+
+ // The python files list needs special handling. All files that are not an
+ // archive need to be placed in a subdirectory that will be added to PYTHONPATH.
+ args.pyFiles.foreach { f =>
+ val targetDir = if (f.endsWith(".py")) Some(LOCALIZED_PYTHON_DIR) else None
+ distribute(f, targetDir = targetDir)
+ }
+
+ // Distribute an archive with Hadoop and Spark configuration for the AM.
+ val (_, confLocalizedPath) = distribute(createConfArchive().getAbsolutePath(),
+ resType = LocalResourceType.ARCHIVE,
+ destName = Some(LOCALIZED_CONF_DIR),
+ appMasterOnly = true)
+ require(confLocalizedPath != null)
+
localResources
}
/**
- * Create an archive with the Hadoop config files for distribution.
+ * Create an archive with the config files for distribution.
*
* These are only used by the AM, since executors will use the configuration object broadcast by
* the driver. The files are zipped and added to the job as an archive, so that YARN will explode
@@ -388,8 +431,11 @@ private[spark] class Client(
*
* Currently this makes a shallow copy of the conf directory. If there are cases where a
* Hadoop config directory contains subdirectories, this code will have to be fixed.
+ *
+ * The archive also contains some Spark configuration. Namely, it saves the contents of
+ * SparkConf in a file to be loaded by the AM process.
*/
- private def createConfArchive(): Option[File] = {
+ private def createConfArchive(): File = {
val hadoopConfFiles = new HashMap[String, File]()
Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey =>
sys.env.get(envKey).foreach { path =>
@@ -404,28 +450,32 @@ private[spark] class Client(
}
}
- if (!hadoopConfFiles.isEmpty) {
- val hadoopConfArchive = File.createTempFile(LOCALIZED_HADOOP_CONF_DIR, ".zip",
- new File(Utils.getLocalDir(sparkConf)))
+ val confArchive = File.createTempFile(LOCALIZED_CONF_DIR, ".zip",
+ new File(Utils.getLocalDir(sparkConf)))
+ val confStream = new ZipOutputStream(new FileOutputStream(confArchive))
- val hadoopConfStream = new ZipOutputStream(new FileOutputStream(hadoopConfArchive))
- try {
- hadoopConfStream.setLevel(0)
- hadoopConfFiles.foreach { case (name, file) =>
- if (file.canRead()) {
- hadoopConfStream.putNextEntry(new ZipEntry(name))
- Files.copy(file, hadoopConfStream)
- hadoopConfStream.closeEntry()
- }
+ try {
+ confStream.setLevel(0)
+ hadoopConfFiles.foreach { case (name, file) =>
+ if (file.canRead()) {
+ confStream.putNextEntry(new ZipEntry(name))
+ Files.copy(file, confStream)
+ confStream.closeEntry()
}
- } finally {
- hadoopConfStream.close()
}
- Some(hadoopConfArchive)
- } else {
- None
+ // Save Spark configuration to a file in the archive.
+ val props = new Properties()
+ sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) }
+ confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE))
+ val writer = new OutputStreamWriter(confStream, UTF_8)
+ props.store(writer, "Spark configuration.")
+ writer.flush()
+ confStream.closeEntry()
+ } finally {
+ confStream.close()
}
+ confArchive
}
/**
@@ -453,7 +503,9 @@ private[spark] class Client(
/**
* Set up the environment for launching our ApplicationMaster container.
*/
- private def setupLaunchEnv(stagingDir: String): HashMap[String, String] = {
+ private def setupLaunchEnv(
+ stagingDir: String,
+ pySparkArchives: Seq[String]): HashMap[String, String] = {
logInfo("Setting up the launch environment for our AM container")
val env = new HashMap[String, String]()
val extraCp = sparkConf.getOption("spark.driver.extraClassPath")
@@ -471,9 +523,6 @@ private[spark] class Client(
val renewalInterval = getTokenRenewalInterval(stagingDirPath)
sparkConf.set("spark.yarn.token.renewal.interval", renewalInterval.toString)
}
- // Set the environment variables to be passed on to the executors.
- distCacheMgr.setDistFilesEnv(env)
- distCacheMgr.setDistArchivesEnv(env)
// Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.*
val amEnvPrefix = "spark.yarn.appMasterEnv."
@@ -490,15 +539,32 @@ private[spark] class Client(
env("SPARK_YARN_USER_ENV") = userEnvs
}
- // if spark.submit.pyArchives is in sparkConf, append pyArchives to PYTHONPATH
- // that can be passed on to the ApplicationMaster and the executors.
- if (sparkConf.contains("spark.submit.pyArchives")) {
- var pythonPath = sparkConf.get("spark.submit.pyArchives")
- if (env.contains("PYTHONPATH")) {
- pythonPath = Seq(env.get("PYTHONPATH"), pythonPath).mkString(File.pathSeparator)
+ // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH
+ // of the container processes too. Add all non-.py files directly to PYTHONPATH.
+ //
+ // NOTE: the code currently does not handle .py files defined with a "local:" scheme.
+ val pythonPath = new ListBuffer[String]()
+ val (pyFiles, pyArchives) = args.pyFiles.partition(_.endsWith(".py"))
+ if (pyFiles.nonEmpty) {
+ pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
+ LOCALIZED_PYTHON_DIR)
+ }
+ (pySparkArchives ++ pyArchives).foreach { path =>
+ val uri = new URI(path)
+ if (uri.getScheme != LOCAL_SCHEME) {
+ pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
+ new Path(path).getName())
+ } else {
+ pythonPath += uri.getPath()
}
- env("PYTHONPATH") = pythonPath
- sparkConf.setExecutorEnv("PYTHONPATH", pythonPath)
+ }
+
+ // Finally, update the Spark config to propagate PYTHONPATH to the AM and executors.
+ if (pythonPath.nonEmpty) {
+ val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath)
+ .mkString(YarnSparkHadoopUtil.getClassPathSeparator)
+ env("PYTHONPATH") = pythonPathStr
+ sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr)
}
// In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to
@@ -548,8 +614,19 @@ private[spark] class Client(
logInfo("Setting up container launch context for our AM")
val appId = newAppResponse.getApplicationId
val appStagingDir = getAppStagingDir(appId)
- val localResources = prepareLocalResources(appStagingDir)
- val launchEnv = setupLaunchEnv(appStagingDir)
+ val pySparkArchives =
+ if (sys.props.getOrElse("spark.yarn.isPython", "false").toBoolean) {
+ findPySparkArchives()
+ } else {
+ Nil
+ }
+ val launchEnv = setupLaunchEnv(appStagingDir, pySparkArchives)
+ val localResources = prepareLocalResources(appStagingDir, pySparkArchives)
+
+ // Set the environment variables to be passed on to the executors.
+ distCacheMgr.setDistFilesEnv(launchEnv)
+ distCacheMgr.setDistArchivesEnv(launchEnv)
+
val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
amContainer.setLocalResources(localResources)
amContainer.setEnvironment(launchEnv)
@@ -589,13 +666,6 @@ private[spark] class Client(
javaOpts += "-XX:CMSIncrementalDutyCycle=10"
}
- // Forward the Spark configuration to the application master / executors.
- // TODO: it might be nicer to pass these as an internal environment variable rather than
- // as Java options, due to complications with string parsing of nested quotes.
- for ((k, v) <- sparkConf.getAll) {
- javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v")
- }
-
// Include driver-specific java options if we are launching a driver
if (isClusterMode) {
val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions")
@@ -606,7 +676,7 @@ private[spark] class Client(
val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"),
sys.props.get("spark.driver.libraryPath")).flatten
if (libraryPaths.nonEmpty) {
- prefixEnv = Some(Utils.libraryPathEnvPrefix(libraryPaths))
+ prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths)))
}
if (sparkConf.getOption("spark.yarn.am.extraJavaOptions").isDefined) {
logWarning("spark.yarn.am.extraJavaOptions will not take effect in cluster mode")
@@ -628,7 +698,7 @@ private[spark] class Client(
}
sparkConf.getOption("spark.yarn.am.extraLibraryPath").foreach { paths =>
- prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(paths)))
+ prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths))))
}
}
@@ -648,14 +718,8 @@ private[spark] class Client(
Nil
}
val primaryPyFile =
- if (args.primaryPyFile != null) {
- Seq("--primary-py-file", args.primaryPyFile)
- } else {
- Nil
- }
- val pyFiles =
- if (args.pyFiles != null) {
- Seq("--py-files", args.pyFiles)
+ if (isClusterMode && args.primaryPyFile != null) {
+ Seq("--primary-py-file", new Path(args.primaryPyFile).getName())
} else {
Nil
}
@@ -671,9 +735,6 @@ private[spark] class Client(
} else {
Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName
}
- if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) {
- args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs
- }
if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) {
args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs
}
@@ -681,11 +742,13 @@ private[spark] class Client(
Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg))
}
val amArgs =
- Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++
+ Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++
userArgs ++ Seq(
"--executor-memory", args.executorMemory.toString + "m",
"--executor-cores", args.executorCores.toString,
- "--num-executors ", args.numExecutors.toString)
+ "--num-executors ", args.numExecutors.toString,
+ "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
+ LOCALIZED_CONF_DIR, SPARK_CONF_FILE))
// Command for the ApplicationMaster
val commands = prefixEnv ++ Seq(
@@ -764,6 +827,9 @@ private[spark] class Client(
case e: ApplicationNotFoundException =>
logError(s"Application $appId not found.")
return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED)
+ case NonFatal(e) =>
+ logError(s"Failed to contact YARN for application $appId.", e)
+ return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED)
}
val state = report.getYarnApplicationState
@@ -782,6 +848,7 @@ private[spark] class Client(
if (state == YarnApplicationState.FINISHED ||
state == YarnApplicationState.FAILED ||
state == YarnApplicationState.KILLED) {
+ cleanupStagingDir(appId)
return (state, report.getFinalApplicationStatus)
}
@@ -849,6 +916,22 @@ private[spark] class Client(
}
}
}
+
+ private def findPySparkArchives(): Seq[String] = {
+ sys.env.get("PYSPARK_ARCHIVES_PATH")
+ .map(_.split(",").toSeq)
+ .getOrElse {
+ val pyLibPath = Seq(sys.env("SPARK_HOME"), "python", "lib").mkString(File.separator)
+ val pyArchivesFile = new File(pyLibPath, "pyspark.zip")
+ require(pyArchivesFile.exists(),
+ "pyspark.zip not found; cannot run pyspark application in YARN mode.")
+ val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip")
+ require(py4jFile.exists(),
+ "py4j-0.8.2.1-src.zip not found; cannot run pyspark application in YARN mode.")
+ Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath())
+ }
+ }
+
}
object Client extends Logging {
@@ -899,8 +982,14 @@ object Client extends Logging {
// Distribution-defined classpath to add to processes
val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH"
- // Subdirectory where the user's hadoop config files will be placed.
- val LOCALIZED_HADOOP_CONF_DIR = "__hadoop_conf__"
+ // Subdirectory where the user's Spark and Hadoop config files will be placed.
+ val LOCALIZED_CONF_DIR = "__spark_conf__"
+
+ // Name of the file in the conf archive containing Spark configuration.
+ val SPARK_CONF_FILE = "__spark_conf__.properties"
+
+ // Subdirectory where the user's python files (not archives) will be placed.
+ val LOCALIZED_PYTHON_DIR = "__pyfiles__"
/**
* Find the user-defined Spark jar if configured, or return the jar containing this
@@ -1017,15 +1106,15 @@ object Client extends Logging {
env: HashMap[String, String],
isAM: Boolean,
extraClassPath: Option[String] = None): Unit = {
- extraClassPath.foreach(addClasspathEntry(_, env))
- addClasspathEntry(
- YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env
- )
+ extraClassPath.foreach { cp =>
+ addClasspathEntry(getClusterPath(sparkConf, cp), env)
+ }
+ addClasspathEntry(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env)
if (isAM) {
addClasspathEntry(
YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR +
- LOCALIZED_HADOOP_CONF_DIR, env)
+ LOCALIZED_CONF_DIR, env)
}
if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) {
@@ -1036,12 +1125,14 @@ object Client extends Logging {
getUserClasspath(sparkConf)
}
userClassPath.foreach { x =>
- addFileToClasspath(x, null, env)
+ addFileToClasspath(sparkConf, x, null, env)
}
}
- addFileToClasspath(new URI(sparkJar(sparkConf)), SPARK_JAR, env)
+ addFileToClasspath(sparkConf, new URI(sparkJar(sparkConf)), SPARK_JAR, env)
populateHadoopClasspath(conf, env)
- sys.env.get(ENV_DIST_CLASSPATH).foreach(addClasspathEntry(_, env))
+ sys.env.get(ENV_DIST_CLASSPATH).foreach { cp =>
+ addClasspathEntry(getClusterPath(sparkConf, cp), env)
+ }
}
/**
@@ -1070,16 +1161,18 @@ object Client extends Logging {
*
* If not a "local:" file and no alternate name, the environment is not modified.
*
+ * @param conf Spark configuration.
* @param uri URI to add to classpath (optional).
* @param fileName Alternate name for the file (optional).
* @param env Map holding the environment variables.
*/
private def addFileToClasspath(
+ conf: SparkConf,
uri: URI,
fileName: String,
env: HashMap[String, String]): Unit = {
if (uri != null && uri.getScheme == LOCAL_SCHEME) {
- addClasspathEntry(uri.getPath, env)
+ addClasspathEntry(getClusterPath(conf, uri.getPath), env)
} else if (fileName != null) {
addClasspathEntry(buildPath(
YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env)
@@ -1093,6 +1186,29 @@ object Client extends Logging {
private def addClasspathEntry(path: String, env: HashMap[String, String]): Unit =
YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path)
+ /**
+ * Returns the path to be sent to the NM for a path that is valid on the gateway.
+ *
+ * This method uses two configuration values:
+ *
+ * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may
+ * only be valid in the gateway node.
+ * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may
+ * contain, for example, env variable references, which will be expanded by the NMs when
+ * starting containers.
+ *
+ * If either config is not available, the input path is returned.
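+ *
+ * For example (illustrative values): with gatewayPath "/usr/local/spark" and replacementPath
+ * "/opt/spark", the input "/usr/local/spark/lib/a.jar" becomes "/opt/spark/lib/a.jar".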
+ */
+ def getClusterPath(conf: SparkConf, path: String): String = {
+ val localPath = conf.get("spark.yarn.config.gatewayPath", null)
+ val clusterPath = conf.get("spark.yarn.config.replacementPath", null)
+ if (localPath != null && clusterPath != null) {
+ path.replace(localPath, clusterPath)
+ } else {
+ path
+ }
+ }
+
/**
* Obtains token for the Hive metastore and adds them to the credentials.
*/
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index 9c7b1b3988082..35e990602a6cf 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -30,7 +30,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
var archives: String = null
var userJar: String = null
var userClass: String = null
- var pyFiles: String = null
+ var pyFiles: Seq[String] = Nil
var primaryPyFile: String = null
var primaryRFile: String = null
var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]()
@@ -228,7 +228,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
args = tail
case ("--py-files") :: value :: tail =>
- pyFiles = value
+ pyFiles = value.split(",")
args = tail
case ("--files") :: value :: tail =>
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index 9d04d241dae9e..78e27fb7f3337 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -146,7 +146,7 @@ class ExecutorRunnable(
javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
}
sys.props.get("spark.executor.extraLibraryPath").foreach { p =>
- prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p)))
+ prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p))))
}
javaOpts += "-Djava.io.tmpdir=" +
@@ -195,7 +195,7 @@ class ExecutorRunnable(
val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri =>
val absPath =
if (new File(uri.getPath()).isAbsolute()) {
- uri.getPath()
+ Client.getClusterPath(sparkConf, uri.getPath())
} else {
Client.buildPath(Environment.PWD.$(), uri.getPath())
}
@@ -303,8 +303,8 @@ class ExecutorRunnable(
val address = container.getNodeHttpAddress
val baseUrl = s"$httpScheme$address/node/containerlogs/$containerId/$user"
- env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=0"
- env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=0"
+ env("SPARK_LOG_URL_STDERR") = s"$baseUrl/stderr?start=-4096"
+ env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=-4096"
}
System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v }
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 99c05329b4d73..1c8d7ec57635f 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -76,7 +76,8 @@ private[spark] class YarnClientSchedulerBackend(
("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"),
("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"),
("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"),
- ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue")
+ ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"),
+ ("--py-files", null, "spark.submit.pyFiles")
)
// Warn against the following deprecated environment variables: env var -> suggestion
val deprecatedEnvVars = Map(
@@ -86,7 +87,7 @@ private[spark] class YarnClientSchedulerBackend(
optionTuples.foreach { case (optionName, envVar, sparkProp) =>
if (sc.getConf.contains(sparkProp)) {
extraArgs += (optionName, sc.getConf.get(sparkProp))
- } else if (System.getenv(envVar) != null) {
+ } else if (envVar != null && System.getenv(envVar) != null) {
extraArgs += (optionName, System.getenv(envVar))
if (deprecatedEnvVars.contains(envVar)) {
logWarning(s"NOTE: $envVar is deprecated. Use ${deprecatedEnvVars(envVar)} instead.")
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
index 1ace1a97d5156..33f580aaebdc0 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
@@ -115,8 +115,9 @@ private[spark] class YarnClusterSchedulerBackend(
val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://"
val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user"
logDebug(s"Base URL for logs: $baseUrl")
- driverLogs = Some(
- Map("stderr" -> s"$baseUrl/stderr?start=0", "stdout" -> s"$baseUrl/stdout?start=0"))
+ driverLogs = Some(Map(
+ "stderr" -> s"$baseUrl/stderr?start=-4096",
+ "stdout" -> s"$baseUrl/stdout?start=-4096"))
}
}
} catch {
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index 01d33c9ce9297..837f8d3fa55a7 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -113,7 +113,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll {
Environment.PWD.$()
}
cp should contain(pwdVar)
- cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_HADOOP_CONF_DIR}")
+ cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_CONF_DIR}")
cp should not contain (Client.SPARK_JAR)
cp should not contain (Client.APP_JAR)
}
@@ -129,7 +129,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll {
val tempDir = Utils.createTempDir()
try {
- client.prepareLocalResources(tempDir.getAbsolutePath())
+ client.prepareLocalResources(tempDir.getAbsolutePath(), Nil)
sparkConf.getOption(Client.CONF_SPARK_USER_JAR) should be (Some(USER))
// The non-local path should be propagated by name only, since it will end up in the app's
@@ -151,6 +151,25 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll {
}
}
+ test("Cluster path translation") {
+ val conf = new Configuration()
+ val sparkConf = new SparkConf()
+ .set(Client.CONF_SPARK_JAR, "local:/localPath/spark.jar")
+ .set("spark.yarn.config.gatewayPath", "/localPath")
+ .set("spark.yarn.config.replacementPath", "/remotePath")
+
+ Client.getClusterPath(sparkConf, "/localPath") should be ("/remotePath")
+ Client.getClusterPath(sparkConf, "/localPath/1:/localPath/2") should be (
+ "/remotePath/1:/remotePath/2")
+
+ val env = new MutableHashMap[String, String]()
+ Client.populateClasspath(null, conf, sparkConf, env, false,
+ extraClassPath = Some("/localPath/my1.jar"))
+ val cp = classpath(env)
+ cp should contain ("/remotePath/spark.jar")
+ cp should contain ("/remotePath/my1.jar")
+ }
+
object Fixtures {
val knownDefYarnAppCP: Seq[String] =
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 93d587d0cb36a..335e966519c7c 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -56,6 +56,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher
""".stripMargin
private val TEST_PYFILE = """
+ |import mod1, mod2
|import sys
|from operator import add
|
@@ -67,7 +68,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher
| sc = SparkContext(conf=SparkConf())
| status = open(sys.argv[1],'w')
| result = "failure"
- | rdd = sc.parallelize(range(10))
+ | rdd = sc.parallelize(range(10)).map(lambda x: x * mod1.func() * mod2.func())
| cnt = rdd.count()
| if cnt == 10:
| result = "success"
@@ -76,6 +77,11 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher
| sc.stop()
""".stripMargin
+ private val TEST_PYMODULE = """
+ |def func():
+ | return 42
+ """.stripMargin
+
private var yarnCluster: MiniYARNCluster = _
private var tempDir: File = _
private var fakeSparkJar: File = _
@@ -124,7 +130,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher
logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
- hadoopConfDir = new File(tempDir, Client.LOCALIZED_HADOOP_CONF_DIR)
+ hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR)
assert(hadoopConfDir.mkdir())
File.createTempFile("token", ".txt", hadoopConfDir)
}
@@ -151,26 +157,12 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher
}
}
- // Enable this once fix SPARK-6700
- test("run Python application in yarn-cluster mode") {
- val primaryPyFile = new File(tempDir, "test.py")
- Files.write(TEST_PYFILE, primaryPyFile, UTF_8)
- val pyFile = new File(tempDir, "test2.py")
- Files.write(TEST_PYFILE, pyFile, UTF_8)
- var result = File.createTempFile("result", null, tempDir)
+ test("run Python application in yarn-client mode") {
+ testPySpark(true)
+ }
- // The sbt assembly does not include pyspark / py4j python dependencies, so we need to
- // propagate SPARK_HOME so that those are added to PYTHONPATH. See PythonUtils.scala.
- val sparkHome = sys.props("spark.test.home")
- val extraConf = Map(
- "spark.executorEnv.SPARK_HOME" -> sparkHome,
- "spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome)
-
- runSpark(false, primaryPyFile.getAbsolutePath(),
- sparkArgs = Seq("--py-files", pyFile.getAbsolutePath()),
- appArgs = Seq(result.getAbsolutePath()),
- extraConf = extraConf)
- checkResult(result)
+ test("run Python application in yarn-cluster mode") {
+ testPySpark(false)
}
test("user class path first in client mode") {
@@ -188,6 +180,33 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher
checkResult(result)
}
+ private def testPySpark(clientMode: Boolean): Unit = {
+ val primaryPyFile = new File(tempDir, "test.py")
+ Files.write(TEST_PYFILE, primaryPyFile, UTF_8)
+
+ val moduleDir =
+ if (clientMode) {
+ // In client-mode, .py files added with --py-files are not visible in the driver.
+ // This is something that the launcher library would have to handle.
+ tempDir
+ } else {
+ val subdir = new File(tempDir, "pyModules")
+ subdir.mkdir()
+ subdir
+ }
+ val pyModule = new File(moduleDir, "mod1.py")
+ Files.write(TEST_PYMODULE, pyModule, UTF_8)
+
+ val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir)
+ val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",")
+ val result = File.createTempFile("result", null, tempDir)
+
+ runSpark(clientMode, primaryPyFile.getAbsolutePath(),
+ sparkArgs = Seq("--py-files", pyFiles),
+ appArgs = Seq(result.getAbsolutePath()))
+ checkResult(result)
+ }
+
private def testUseClassPathFirst(clientMode: Boolean): Unit = {
// Create a jar file that contains a different version of "test.resource".
val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir)
@@ -357,7 +376,7 @@ private object YarnClusterDriver extends Logging with Matchers {
new URL(urlStr)
val containerId = YarnSparkHadoopUtil.get.getContainerId
val user = Utils.getCurrentUserName()
- assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=0"))
+ assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096"))
}
}
|