(CDAP-16870) Fix PySpark support for Spark 2.1.3+ #12316

Merged 1 commit on Jun 12, 2020

@@ -60,6 +60,8 @@
public final class SparkRuntimeUtils {

public static final String CDAP_SPARK_EXECUTION_SERVICE_URI = "CDAP_SPARK_EXECUTION_SERVICE_URI";
public static final String PYSPARK_PORT_FILE_NAME = "cdap.py4j.gateway.port.txt";
public static final String PYSPARK_SECRET_FILE_NAME = "cdap.py4j.gateway.secret.txt";

private static final String LOCALIZED_RESOURCES = "spark.cdap.localized.resources";
private static final Logger LOG = LoggerFactory.getLogger(SparkRuntimeUtils.class);
@@ -190,13 +190,9 @@ public void run() {
       }

       // Otherwise start the gateway server using reflection. Also write the port number to a local file
-      Path portFile = Paths.get("cdap.py4j.gateway.port.txt");
       try {
         final Object server = classLoader.loadClass(SparkPythonUtil.class.getName())
-          .getMethod("startPy4jGateway", Path.class).invoke(null, portFile);
-
-        log(logger, "info", "Py4j GatewayServer started, listening at port {}",
-            new String(Files.readAllBytes(portFile), StandardCharsets.UTF_8));
+          .getMethod("startPy4jGateway", Path.class).invoke(null, Paths.get(System.getProperty("user.dir")));

         return new Runnable() {
           @Override
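The gateway is started reflectively because SparkPythonUtil lives in a Spark-version-specific jar loaded through a separate classloader, and the caller now passes the process working directory so the utility can write the port file (and, on Spark 2.1.3+, the secret file) next to it. A minimal standalone sketch of this pattern, assuming the class is resolvable from the current context classloader (in the PR it comes from CDAP's Spark classloader):

    import java.nio.file.Path;
    import java.nio.file.Paths;

    public class GatewayLauncherSketch {
      public static void main(String[] args) throws Exception {
        // Stand-in for the Spark-version-specific classloader used in the PR.
        ClassLoader sparkClassLoader = Thread.currentThread().getContextClassLoader();

        // Load SparkPythonUtil reflectively so the matching Spark 1 / Spark 2
        // implementation is picked, then start the gateway in the working
        // directory, where the port (and secret) files get written.
        Object server = sparkClassLoader
            .loadClass("io.cdap.cdap.app.runtime.spark.python.SparkPythonUtil")
            .getMethod("startPy4jGateway", Path.class)
            .invoke(null, Paths.get(System.getProperty("user.dir")));
      }
    }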

This file was deleted.

@@ -16,22 +16,51 @@

package io.cdap.cdap.app.runtime.spark.python;

import io.cdap.cdap.app.runtime.spark.SparkRuntimeUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import py4j.CallbackClient;
import py4j.GatewayServer;
import py4j.Py4JNetworkException;

import java.io.IOException;
import java.lang.reflect.Field;
import java.net.InetAddress;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

/**
* Utility class to provide methods for PySpark integration.
*/
@SuppressWarnings("unused")
public final class SparkPythonUtil extends AbstractSparkPythonUtil {
public final class SparkPythonUtil {

private static final Logger LOG = LoggerFactory.getLogger(SparkPythonUtil.class);

/**
* Starts a Py4j gateway server.
*
* @param dir the local directory for writing information about the gateway server, such as the port.
* @return the gateway server
* @throws IOException if the server fails to start or the port file cannot be written.
*/
public static GatewayServer startPy4jGateway(Path dir) throws IOException {
GatewayServer server = new GatewayServer(null, 0);
try {
server.start();
} catch (Py4JNetworkException e) {
throw new IOException(e);
}

// Write the port number in string form to the port file
Files.write(dir.resolve(SparkRuntimeUtils.PYSPARK_PORT_FILE_NAME),
Integer.toString(server.getListeningPort()).getBytes(StandardCharsets.UTF_8));

LOG.debug("Py4j Gateway server started at port {}", server.getListeningPort());
return server;
}

/**
* Updates the python callback port in the {@link GatewayServer}.
*/
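The PySpark side discovers the gateway by reading this well-known port file back from the working directory. A minimal Java sketch of the consumer side of that contract (the real consumer is context.py below); the class name is illustrative:

    import java.nio.charset.StandardCharsets;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.Paths;

    public class PortFileReaderSketch {
      public static void main(String[] args) throws Exception {
        // Same well-known name as SparkRuntimeUtils.PYSPARK_PORT_FILE_NAME.
        Path portFile = Paths.get(System.getProperty("user.dir"), "cdap.py4j.gateway.port.txt");
        int port = Integer.parseInt(
            new String(Files.readAllBytes(portFile), StandardCharsets.UTF_8).trim());
        System.out.println("Py4j gateway is listening on port " + port);
      }
    }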
@@ -16,13 +16,73 @@

package io.cdap.cdap.app.runtime.spark.python;

import io.cdap.cdap.app.runtime.spark.SparkRuntimeUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import py4j.GatewayServer;
import py4j.Py4JNetworkException;

import java.io.IOException;
import java.nio.channels.ByteChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.nio.file.attribute.PosixFilePermission;
import java.nio.file.attribute.PosixFilePermissions;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.EnumSet;

/**
* Utility class to provide methods for PySpark integration.
*/
@SuppressWarnings("unused")
public final class SparkPythonUtil extends AbstractSparkPythonUtil {
public final class SparkPythonUtil {

private static final Logger LOG = LoggerFactory.getLogger(SparkPythonUtil.class);

/**
* Starts a Py4j gateway server.
*
* @param dir the local directory for writing information about the gateway server, such as the port and auth token.
* @return the gateway server
* @throws IOException if the server fails to start or the port or secret file cannot be written.
*/
public static GatewayServer startPy4jGateway(Path dir) throws IOException {
// 256-bit secret
byte[] secret = new byte[256 / 8];
new SecureRandom().nextBytes(secret);
String authToken = Base64.getEncoder().encodeToString(secret);

GatewayServer server = new GatewayServer.GatewayServerBuilder()
.javaPort(0)
.authToken(authToken)
.build();
try {
server.start();
} catch (Py4JNetworkException e) {
throw new IOException(e);
}

// Write the port number in string form to the port file
Files.write(dir.resolve(SparkRuntimeUtils.PYSPARK_PORT_FILE_NAME),
Integer.toString(server.getListeningPort()).getBytes(StandardCharsets.UTF_8));
// Write the auth token
Path secretFile = dir.resolve(SparkRuntimeUtils.PYSPARK_SECRET_FILE_NAME);
try (ByteChannel channel = Files.newByteChannel(secretFile,
EnumSet.of(StandardOpenOption.CREATE, StandardOpenOption.WRITE),
PosixFilePermissions.asFileAttribute(
EnumSet.of(PosixFilePermission.OWNER_READ,
PosixFilePermission.OWNER_WRITE)))) {
channel.write(StandardCharsets.UTF_8.encode(authToken));
}

LOG.debug("Py4j Gateway server started at port {} with auth token of {} bits",
server.getListeningPort(), secret.length);

return server;
}

/**
* Updates the python callback port in the {@link GatewayServer}.
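Note the design choice above: the owner-only permissions are passed as a file attribute when the secret file is created, instead of creating the file and then restricting it, so there is no window in which another local user could read the token (this requires a POSIX file system). A minimal isolated sketch of that pattern, with a hypothetical file name and token:

    import java.nio.channels.ByteChannel;
    import java.nio.charset.StandardCharsets;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.Paths;
    import java.nio.file.StandardOpenOption;
    import java.nio.file.attribute.PosixFilePermission;
    import java.nio.file.attribute.PosixFilePermissions;
    import java.util.EnumSet;

    public class OwnerOnlyFileSketch {
      public static void main(String[] args) throws Exception {
        Path secretFile = Paths.get("example.secret.txt");  // hypothetical name
        // Permissions are applied atomically at creation, so the token is
        // never readable by other users, even briefly.
        try (ByteChannel channel = Files.newByteChannel(secretFile,
            EnumSet.of(StandardOpenOption.CREATE, StandardOpenOption.WRITE),
            PosixFilePermissions.asFileAttribute(
                EnumSet.of(PosixFilePermission.OWNER_READ, PosixFilePermission.OWNER_WRITE)))) {
          channel.write(StandardCharsets.UTF_8.encode("hypothetical-token"));
        }
      }
    }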
Binary file modified cdap-spark-core2_2.11/src/main/resources/pyspark/py4j-src.zip
Binary file modified cdap-spark-core2_2.11/src/main/resources/pyspark/pyspark.zip
26 changes: 18 additions & 8 deletions cdap-spark-python/src/main/resources/cdap/pyspark/context.py
@@ -15,9 +15,8 @@
# the License.

import os
from threading import RLock

from py4j.java_gateway import java_import, JavaGateway
from threading import RLock

try:
# The JavaObject is only needed for the Spark 1 hack. Failure to import in future Spark/py4j versions is ok.
@@ -157,7 +156,7 @@ class SparkRuntimeContext(object):
_runtimeContext = None
_onDemandCallback = False

def __init__(self, gatewayPort = None, driver = True):
def __init__(self, gatewayPort = None, gatewaySecret = None, driver = True):
# If the gateway port file is there, always use it. This is for distributed mode.
if os.path.isfile("cdap.py4j.gateway.port.txt"):
fd = open("cdap.py4j.gateway.port.txt", "r")
@@ -169,21 +168,31 @@ def __init__(self, gatewayPort = None, driver = True):
else:
raise Exception("Cannot determine Py4j GatewayServer port")

self.__class__.__ensureGatewayInit(gatewayPort, driver)
# Load the gateway secret if available
if os.path.isfile("cdap.py4j.gateway.secret.txt"):
fd = open("cdap.py4j.gateway.secret.txt", "r")
gatewaySecret = fd.read()
fd.close()
elif gatewaySecret is None:
if "PYSPARK_GATEWAY_SECRET" in os.environ:
gatewaySecret = os.environ["PYSPARK_GATEWAY_SECRET"]

self.__class__.__ensureGatewayInit(gatewayPort, gatewaySecret, driver)
self._allowCallback = driver
self._gatewayPort = gatewayPort
self._gatewaySecret = gatewaySecret

def __getstate__(self):
return { "gatewayPort" : self._gatewayPort }
return { "gatewayPort" : self._gatewayPort , "gatewaySecret" : self._gatewaySecret}

def __setstate__(self, state):
self.__init__(state["gatewayPort"], False)
self.__init__(state["gatewayPort"], state["gatewaySecret"], False)

def getSparkRuntimeContext(self):
return self.__class__._runtimeContext

@classmethod
def __ensureGatewayInit(cls, gatewayPort, driver):
def __ensureGatewayInit(cls, gatewayPort, gatewaySecret, driver):
with cls._lock:
if not cls._gateway:
# Spark 1.6 and Spark 2 use later versions of py4j (0.9 and 0.10+ respectively),
@@ -194,7 +203,8 @@ def __ensureGatewayInit(cls, gatewayPort, driver):
from py4j.java_gateway import GatewayParameters, CallbackServerParameters
callbackServerParams = CallbackServerParameters(port = 0, daemonize = True,
daemonize_connections = True) if driver else None
gateway = JavaGateway(gateway_parameters = GatewayParameters(port = gatewayPort, auto_convert = True),
gateway = JavaGateway(gateway_parameters = GatewayParameters(port = gatewayPort, auto_convert = True,
auth_token = gatewaySecret),
callback_server_parameters = callbackServerParams)
except:
from py4j.java_gateway import CallbackServer, GatewayClient
@@ -35,6 +35,7 @@
import io.cdap.cdap.api.metrics.MetricTimeSeries;
import io.cdap.cdap.common.conf.Constants;
import io.cdap.cdap.common.test.AppJarHelper;
import io.cdap.cdap.common.utils.DirUtils;
import io.cdap.cdap.common.utils.Tasks;
import io.cdap.cdap.internal.DefaultId;
import io.cdap.cdap.proto.NamespaceMeta;
@@ -44,6 +45,7 @@
import io.cdap.cdap.proto.id.DatasetId;
import io.cdap.cdap.proto.id.NamespaceId;
import io.cdap.cdap.spark.app.CharCountProgram;
import io.cdap.cdap.spark.app.PythonSpark2;
import io.cdap.cdap.spark.app.ScalaCharCountProgram;
import io.cdap.cdap.spark.app.ScalaCrossNSProgram;
import io.cdap.cdap.spark.app.ScalaSparkServiceProgram;
@@ -62,7 +64,9 @@
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
@@ -73,6 +77,7 @@
import java.net.URI;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
@@ -91,6 +96,9 @@ public class Spark2Test extends TestBaseWithSpark2 {
@ClassRule
public static final TestConfiguration CONFIG = new TestConfiguration(Constants.Explore.EXPLORE_ENABLED, false);

@ClassRule
public static final TemporaryFolder TEMP_FOLDER = new TemporaryFolder();

private static final String TEST_STRING_1 = "persisted data";
private static final String TEST_STRING_2 = "distributed systems";

@@ -254,6 +262,57 @@ public void testSparkWithLocalFiles() throws Exception {
SparkAppUsingLocalFiles.ScalaSparkUsingLocalFiles.class.getSimpleName(), "scala");
}

@Test
public void testPySpark() throws Exception {
ApplicationManager appManager = deploy(NamespaceId.DEFAULT, Spark2TestApp.class);

// Write some data to a local file
File inputFile = TEMP_FOLDER.newFile();
try (BufferedWriter writer = Files.newBufferedWriter(inputFile.toPath(), StandardCharsets.UTF_8)) {
for (int i = 0; i < 100; i++) {
writer.write("Event " + i);
writer.newLine();
}
}

File outputDir = new File(TEMP_FOLDER.newFolder(), "output");
appManager.getSparkManager(PythonSpark2.class.getSimpleName())
.startAndWaitForRun(ImmutableMap.of("input.file", inputFile.getAbsolutePath(),
"output.path", outputDir.getAbsolutePath()),
ProgramRunStatus.COMPLETED, 2, TimeUnit.MINUTES);

// Verify the result
File resultFile = DirUtils.listFiles(outputDir).stream()
.filter(f -> !f.getName().endsWith(".crc"))
.filter(f -> !f.getName().startsWith("_SUCCESS"))
.findFirst()
.orElse(null);
Assert.assertNotNull(resultFile);

List<String> lines = Files.readAllLines(resultFile.toPath(), StandardCharsets.UTF_8);
Assert.assertFalse(lines.isEmpty());

// Expect only even-numbered events in the output
int count = 0;
for (String line : lines) {
line = line.trim();
if (!line.isEmpty()) {
Assert.assertEquals("Event " + count, line);
count += 2;
}
}

Assert.assertEquals(100, count);

final Map<String, String> tags = ImmutableMap.of(
Constants.Metrics.Tag.NAMESPACE, NamespaceId.DEFAULT.getNamespace(),
Constants.Metrics.Tag.APP, Spark2TestApp.class.getSimpleName(),
Constants.Metrics.Tag.SPARK, PythonSpark2.class.getSimpleName());

Tasks.waitFor(100L, () -> getMetricsManager().getTotalMetric(tags, "user.body"),
5, TimeUnit.SECONDS, 100, TimeUnit.MILLISECONDS);
}

private void prepareInputData(DataSetManager<ObjectStore<String>> manager) {
ObjectStore<String> keys = manager.get();
keys.write(Bytes.toBytes(TEST_STRING_1), TEST_STRING_1);