diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 19f4c95fcad74..7670a870e860a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -613,7 +613,16 @@ private[spark] object PythonRDD extends Logging { setDaemon(true) override def run() { try { - val sock = serverSocket.accept() + var sock: Socket = null + try { + sock = serverSocket.accept() + } catch { + case e: SocketTimeoutException => + // there is a small chance that the client had connected, so retry + logWarning("Timed out after 4 seconds, retry once") + serverSocket.setSoTimeout(10) + sock = serverSocket.accept() + } val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) try { writeIteratorToStream(items, out) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c337a43c8a7fc..b66d32c9d5b39 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -113,6 +113,7 @@ def _parse_memory(s): def _load_from_socket(port, serializer): sock = socket.socket() + sock.settimeout(5) try: sock.connect(("localhost", port)) rf = sock.makefile("rb", 65536)