
Merge pull request apache#2 from apache/master
merge upstream changes
nchammas committed Aug 4, 2014
2 parents aa5b4b5 + 8e7d5ba commit 9da347f
Showing 53 changed files with 1,976 additions and 264 deletions.
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -18,6 +18,7 @@
 package org.apache.spark
 
 import java.io.File
+import java.net.Socket
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
@@ -102,10 +103,10 @@ class SparkEnv (
   }
 
   private[spark]
-  def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) {
+  def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
     synchronized {
       val key = (pythonExec, envVars)
-      pythonWorkers(key).stop()
+      pythonWorkers.get(key).foreach(_.stopWorker(worker))
     }
   }
 }
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -62,8 +62,8 @@ private[spark] class PythonRDD(
     val env = SparkEnv.get
     val localdir = env.blockManager.diskBlockManager.localDirs.map(
       f => f.getPath()).mkString(",")
-    val worker: Socket = env.createPythonWorker(pythonExec,
-      envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir))
+    envVars += ("SPARK_LOCAL_DIR" -> localdir) // it's also used in monitor thread
+    val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
@@ -241,7 +241,7 @@ private[spark] class PythonRDD(
       if (!context.completed) {
         try {
           logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
-          env.destroyPythonWorker(pythonExec, envVars.toMap)
+          env.destroyPythonWorker(pythonExec, envVars.toMap, worker)
         } catch {
           case e: Exception =>
             logError("Exception when trying to kill worker", e)
@@ -685,9 +685,8 @@ private[spark] object PythonRDD extends Logging {
 
   /**
    * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
-   * This function is outdated, PySpark does not use it anymore
    */
-  @deprecated
+  @deprecated("PySpark does not use it anymore", "1.1")
   def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
     pyRDD.rdd.mapPartitions { iter =>
       val unpickle = new Unpickler
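Side note on the SPARK_LOCAL_DIR hunk above: Python workers are cached in SparkEnv under a (pythonExec, envVars) key (visible in the destroyPythonWorker hunk), so the monitor thread has to build the same key that compute() used when it created the worker. A minimal sketch of that constraint, assuming only the standard library; the object name is illustrative, not Spark code:

import scala.collection.mutable

object WorkerKeySketch {
  def main(args: Array[String]): Unit = {
    val pythonExec = "python"
    val envVars = mutable.HashMap("PYTHONPATH" -> "/opt/pyspark")

    // compute(): mutate the shared map before building the cache key ...
    envVars += ("SPARK_LOCAL_DIR" -> "/tmp/spark-local")
    val creationKey = (pythonExec, envVars.toMap)

    // ... monitor thread, later: the same mutable map yields the same key.
    val destroyKey = (pythonExec, envVars.toMap)
    println(creationKey == destroyKey) // true; an augmented copy at creation time would not match
  }
}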
core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.api.python
 
-import java.io.{DataInputStream, InputStream, OutputStreamWriter}
+import java.lang.Runtime
+import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter}
 import java.net.{InetAddress, ServerSocket, Socket, SocketException}
 
+import scala.collection.mutable
 import scala.collection.JavaConversions._
 
 import org.apache.spark._
@@ -39,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0
+  var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+
+  var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
 
   val pythonPath = PythonUtils.mergePythonPaths(
     PythonUtils.sparkPythonPath,
@@ -58,25 +63,31 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
    * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
    */
   private def createThroughDaemon(): Socket = {
+
+    def createSocket(): Socket = {
+      val socket = new Socket(daemonHost, daemonPort)
+      val pid = new DataInputStream(socket.getInputStream).readInt()
+      if (pid < 0) {
+        throw new IllegalStateException("Python daemon failed to launch worker")
+      }
+      daemonWorkers.put(socket, pid)
+      socket
+    }
+
     synchronized {
       // Start the daemon if it hasn't been started
       startDaemon()
 
       // Attempt to connect, restart and retry once if it fails
       try {
-        val socket = new Socket(daemonHost, daemonPort)
-        val launchStatus = new DataInputStream(socket.getInputStream).readInt()
-        if (launchStatus != 0) {
-          throw new IllegalStateException("Python daemon failed to launch worker")
-        }
-        socket
+        createSocket()
       } catch {
         case exc: SocketException =>
           logWarning("Failed to open socket to Python daemon:", exc)
           logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
           stopDaemon()
           startDaemon()
-          new Socket(daemonHost, daemonPort)
+          createSocket()
       }
     }
   }
@@ -107,7 +118,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
       // Wait for it to connect to our socket
       serverSocket.setSoTimeout(10000)
       try {
-        return serverSocket.accept()
+        val socket = serverSocket.accept()
+        simpleWorkers.put(socket, worker)
+        return socket
       } catch {
         case e: Exception =>
           throw new SparkException("Python worker did not connect back in time", e)
@@ -189,19 +202,40 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
 
   private def stopDaemon() {
     synchronized {
-      // Request shutdown of existing daemon by sending SIGTERM
-      if (daemon != null) {
-        daemon.destroy()
-      }
+      if (useDaemon) {
+        // Request shutdown of existing daemon by sending SIGTERM
+        if (daemon != null) {
+          daemon.destroy()
+        }
 
-      daemon = null
-      daemonPort = 0
+        daemon = null
+        daemonPort = 0
+      } else {
+        simpleWorkers.mapValues(_.destroy())
+      }
     }
   }
 
   def stop() {
     stopDaemon()
   }
+
+  def stopWorker(worker: Socket) {
+    if (useDaemon) {
+      if (daemon != null) {
+        daemonWorkers.get(worker).foreach { pid =>
+          // tell daemon to kill worker by pid
+          val output = new DataOutputStream(daemon.getOutputStream)
+          output.writeInt(pid)
+          output.flush()
+          daemon.getOutputStream.flush()
+        }
+      }
+    } else {
+      simpleWorkers.get(worker).foreach(_.destroy())
+    }
+    worker.close()
+  }
 }
 
 private object PythonWorkerFactory {
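The daemon-mode changes above hinge on a small framed-int protocol: createSocket() reads the forked worker's pid from the accepted socket (a negative value signals failure), and stopWorker() later writes that pid to the daemon's stdin to request a targeted kill. A minimal sketch of just the framing, with in-memory streams standing in for the real socket and daemon pipe; not Spark code:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object PidFramingSketch {
  def main(args: Array[String]): Unit = {
    val pid = 12345

    // Daemon side: announce the forked worker's pid as a single big-endian Int.
    val wire = new ByteArrayOutputStream()
    val out = new DataOutputStream(wire)
    out.writeInt(pid) // a negative value would mean "failed to launch worker"
    out.flush()

    // Factory side: read the pid, remember it, and later write it back to request a kill.
    val in = new DataInputStream(new ByteArrayInputStream(wire.toByteArray))
    val announced = in.readInt()
    require(announced >= 0, "Python daemon failed to launch worker")
    println(s"tracking worker pid $announced for a later targeted shutdown")
  }
}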
core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -35,16 +35,15 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
   /**
    * Calling reset to avoid memory leak:
    * http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api
-   * But only call it every 10,000th time to avoid bloated serialization streams (when
+   * But only call it every 100th time to avoid bloated serialization streams (when
    * the stream 'resets' object class descriptions have to be re-written)
    */
   def writeObject[T: ClassTag](t: T): SerializationStream = {
     objOut.writeObject(t)
+    counter += 1
     if (counterReset > 0 && counter >= counterReset) {
       objOut.reset()
       counter = 0
-    } else {
-      counter += 1
     }
     this
   }
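For context on the counterReset hunk: ObjectOutputStream keeps back-references to every object it has written, so never calling reset() leaks memory, while resetting on every write re-sends class descriptors and bloats the stream. A minimal standalone sketch of the periodic-reset pattern the fixed writeObject uses; the threshold value is illustrative, not Spark code:

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object PeriodicResetSketch {
  def main(args: Array[String]): Unit = {
    val counterReset = 100 // reset every 100 writes, as the updated comment describes
    val bytes = new ByteArrayOutputStream()
    val objOut = new ObjectOutputStream(bytes)
    var counter = 0

    for (i <- 0 until 1000) {
      objOut.writeObject(("record", i))
      counter += 1
      if (counterReset > 0 && counter >= counterReset) {
        objOut.reset() // drop the back-reference table so already-written objects can be GCed
        counter = 0
      }
    }
    objOut.close()
    println(s"wrote ${bytes.size} bytes for 1000 records")
  }
}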
core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.util.collection
 
-import java.io.{InputStream, BufferedInputStream, FileInputStream, File, Serializable, EOFException}
+import java.io._
 import java.util.Comparator
 
 import scala.collection.BufferedIterator
@@ -28,7 +28,7 @@ import com.google.common.io.ByteStreams
 
 import org.apache.spark.{Logging, SparkEnv}
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.{BlockId, BlockManager}
 import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
 
@@ -199,13 +199,16 @@ class ExternalAppendOnlyMap[K, V, C](
 
     // Flush the disk writer's contents to disk, and update relevant variables
     def flush() = {
-      writer.commitAndClose()
-      val bytesWritten = writer.bytesWritten
+      val w = writer
+      writer = null
+      w.commitAndClose()
+      val bytesWritten = w.bytesWritten
       batchSizes.append(bytesWritten)
       _diskBytesSpilled += bytesWritten
       objectsWritten = 0
     }
 
+    var success = false
     try {
       val it = currentMap.destructiveSortedIterator(keyComparator)
       while (it.hasNext) {
@@ -215,16 +218,28 @@
 
         if (objectsWritten == serializerBatchSize) {
           flush()
-          writer.close()
           writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
         }
       }
       if (objectsWritten > 0) {
         flush()
+      } else if (writer != null) {
+        val w = writer
+        writer = null
+        w.revertPartialWritesAndClose()
       }
+      success = true
     } finally {
-      // Partial failures cannot be tolerated; do not revert partial writes
-      writer.close()
+      if (!success) {
+        // This code path only happens if an exception was thrown above before we set success;
+        // close our stuff and let the exception be thrown further
+        if (writer != null) {
+          writer.revertPartialWritesAndClose()
+        }
+        if (file.exists()) {
+          file.delete()
+        }
+      }
     }
 
     currentMap = new SizeTrackingAppendOnlyMap[K, C]
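The spill() hunk above replaces the old unconditional writer.close() in finally with an explicit success flag: on the normal path the writer is committed (or reverted if nothing was written), and the finally block only reverts and deletes the file when an exception escaped first. A minimal standalone sketch of that pattern, with a plain FileWriter standing in for Spark's disk block writer; not Spark code:

import java.io.{File, FileWriter}

object SpillCommitSketch {
  def main(args: Array[String]): Unit = {
    val file = File.createTempFile("spill", ".tmp")
    var writer: FileWriter = new FileWriter(file) // stand-in for the disk block writer
    var success = false
    try {
      writer.write("spilled records\n")
      writer.close() // "commit": the data is safely on disk
      writer = null
      success = true
    } finally {
      if (!success) {
        // Only reached if an exception was thrown above: undo the partial write.
        if (writer != null) writer.close()
        if (file.exists()) file.delete()
      }
    }
    println(s"spill file kept: ${file.exists()}")
  }
}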
@@ -389,27 +404,51 @@
    * An iterator that returns (K, C) pairs in sorted order from an on-disk map
    */
   private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
-    extends Iterator[(K, C)] {
-    private val fileStream = new FileInputStream(file)
-    private val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
+    extends Iterator[(K, C)]
+  {
+    private val batchOffsets = batchSizes.scanLeft(0L)(_ + _)  // Size will be batchSize.length + 1
+    assert(file.length() == batchOffsets(batchOffsets.length - 1))
+
+    private var batchIndex = 0  // Which batch we're in
+    private var fileStream: FileInputStream = null
 
     // An intermediate stream that reads from exactly one batch
     // This guards against pre-fetching and other arbitrary behavior of higher level streams
-    private var batchStream = nextBatchStream()
-    private var compressedStream = blockManager.wrapForCompression(blockId, batchStream)
-    private var deserializeStream = ser.deserializeStream(compressedStream)
+    private var deserializeStream = nextBatchStream()
    private var nextItem: (K, C) = null
     private var objectsRead = 0
 
     /**
      * Construct a stream that reads only from the next batch.
      */
-    private def nextBatchStream(): InputStream = {
-      if (batchSizes.length > 0) {
-        ByteStreams.limit(bufferedStream, batchSizes.remove(0))
+    private def nextBatchStream(): DeserializationStream = {
+      // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
+      // we're still in a valid batch.
+      if (batchIndex < batchOffsets.length - 1) {
+        if (deserializeStream != null) {
+          deserializeStream.close()
+          fileStream.close()
+          deserializeStream = null
+          fileStream = null
+        }
+
+        val start = batchOffsets(batchIndex)
+        fileStream = new FileInputStream(file)
+        fileStream.getChannel.position(start)
+        batchIndex += 1
+
+        val end = batchOffsets(batchIndex)
+
+        assert(end >= start, "start = " + start + ", end = " + end +
+          ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
+
+        val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
+        val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
+        ser.deserializeStream(compressedStream)
       } else {
         // No more batches left
-        bufferedStream
+        cleanup()
+        null
       }
     }
 
@@ -424,10 +463,8 @@
         val item = deserializeStream.readObject().asInstanceOf[(K, C)]
         objectsRead += 1
         if (objectsRead == serializerBatchSize) {
-          batchStream = nextBatchStream()
-          compressedStream = blockManager.wrapForCompression(blockId, batchStream)
-          deserializeStream = ser.deserializeStream(compressedStream)
           objectsRead = 0
+          deserializeStream = nextBatchStream()
         }
         item
       } catch {
@@ -439,6 +476,9 @@
 
     override def hasNext: Boolean = {
       if (nextItem == null) {
+        if (deserializeStream == null) {
+          return false
+        }
         nextItem = readNextItem()
       }
       nextItem != null
@@ -455,7 +495,11 @@
 
     // TODO: Ensure this gets called even if the iterator isn't drained.
     private def cleanup() {
-      deserializeStream.close()
+      batchIndex = batchOffsets.length  // Prevent reading any other batch
+      val ds = deserializeStream
+      deserializeStream = null
+      fileStream = null
+      ds.close()
       file.delete()
     }
   }
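The DiskMapIterator hunks above all follow from one bookkeeping change: instead of consuming batchSizes through a shared buffered stream, the iterator precomputes cumulative offsets with scanLeft, so batch i occupies the byte range [offsets(i), offsets(i + 1)) and can be reopened at an exact file position. A minimal standalone sketch of that arithmetic; not Spark code:

object BatchOffsetsSketch {
  def main(args: Array[String]): Unit = {
    val batchSizes = Seq(10L, 20L, 15L)                // bytes written by each flushed batch
    val batchOffsets = batchSizes.scanLeft(0L)(_ + _)  // Seq(0, 10, 30, 45): length = numBatches + 1
    assert(batchOffsets.length == batchSizes.length + 1)
    assert(batchOffsets.last == batchSizes.sum)        // mirrors the file-length assertion above

    for (i <- batchSizes.indices) {
      val (start, end) = (batchOffsets(i), batchOffsets(i + 1))
      println(s"batch $i -> bytes [$start, $end), ${end - start} bytes")
    }
  }
}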