Skip to content

Commit

Permalink
[PIP-193] Support Transform Function with LocalRunner (apache#17445)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Sep 8, 2022
1 parent 7508fe5 commit 327648c
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,11 @@ public void testPulsarSourceLocalRunMultipleInstances() throws Throwable {
}

/**
 * Runs the sink local-run scenario without a transform function.
 *
 * <p>Delegates to the five-argument overload, forwarding the three arguments unchanged and
 * passing {@code null} for both {@code transformFunction} and {@code transformFunctionClassName}.
 *
 * @param jarFilePathUrl forwarded unchanged to the five-argument overload
 * @param parallelism forwarded unchanged to the five-argument overload
 * @param className forwarded unchanged to the five-argument overload
 * @throws Exception whatever the five-argument overload throws
 */
private void testPulsarSinkLocalRun(String jarFilePathUrl, int parallelism, String className) throws Exception {
testPulsarSinkLocalRun(jarFilePathUrl, parallelism, className, null, null);
}

private void testPulsarSinkLocalRun(String jarFilePathUrl, int parallelism, String className,
String transformFunction, String transformFunctionClassName) throws Exception {
final String namespacePortion = "io";
final String replNamespace = tenant + "/" + namespacePortion;
final String sourceTopic = "persistent://" + replNamespace + "/input";
Expand All @@ -921,6 +926,9 @@ private void testPulsarSinkLocalRun(String jarFilePathUrl, int parallelism, Stri

sinkConfig.setArchive(jarFilePathUrl);
sinkConfig.setParallelism(parallelism);
sinkConfig.setTransformFunction(transformFunction);
sinkConfig.setTransformFunctionClassName(transformFunctionClassName);

int metricsPort = FunctionCommon.findAvailablePort();
@Cleanup
LocalRunner localRunner = LocalRunner.builder()
Expand All @@ -933,6 +941,7 @@ private void testPulsarSinkLocalRun(String jarFilePathUrl, int parallelism, Stri
.tlsHostNameVerificationEnabled(false)
.brokerServiceUrl(pulsar.getBrokerServiceUrlTls())
.connectorsDirectory(workerConfig.getConnectorsDirectory())
.functionsDirectory(workerConfig.getFunctionsDirectory())
.metricsPortStart(metricsPort)
.build();

Expand Down Expand Up @@ -1083,6 +1092,12 @@ public void close() throws Exception {
public void testPulsarSinkStatsByteBufferType() throws Throwable {
runWithNarClassLoader(() -> testPulsarSinkLocalRun(null, 1, StatsNullSink.class.getName()));
}

// NOTE(review): the annotation variant below (with timeOut = 20000) is commented out and the
// test currently runs without a timeout — confirm whether the timeout should be restored.
//@Test(timeOut = 20000, groups = "builtin")
/**
 * Runs the sink local-run helper with a transform function configured.
 *
 * <p>Passes a {@code null} archive, a parallelism of 1, {@code StatsNullSink} as the sink class,
 * {@code "builtin://exclamation"} as the transform function and
 * {@code org.apache.pulsar.functions.api.examples.RecordFunction} as the transform function class name.
 */
@Test(groups = "builtin")
public void testPulsarSinkWithFunction() throws Throwable {
testPulsarSinkLocalRun(null, 1, StatsNullSink.class.getName(), "builtin://exclamation", "org.apache.pulsar.functions.api.examples.RecordFunction");
}

public static class TestErrorSink implements Sink<byte[]> {
private Map config;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.pulsar.functions;

import static org.apache.commons.lang3.StringUtils.isNotEmpty;
import static org.apache.pulsar.common.functions.Utils.inferMissingArguments;
import com.beust.jcommander.IStringConverter;
import com.beust.jcommander.JCommander;
Expand Down Expand Up @@ -48,6 +49,7 @@
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.Builder;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.apache.pulsar.common.functions.FunctionConfig;
import org.apache.pulsar.common.functions.Utils;
Expand Down Expand Up @@ -92,8 +94,8 @@ public class LocalRunner implements AutoCloseable {
private final String functionsDir;
private final Thread shutdownHook;
private final int instanceLivenessCheck;
private ClassLoader userCodeClassLoader;
private boolean userCodeClassLoaderCreated;
private UserCodeClassLoader userCodeClassLoader;
private UserCodeClassLoader transformFunctionCodeClassLoader;
private RuntimeFactory runtimeFactory;
private HTTPServer metricsServer;

Expand All @@ -102,6 +104,12 @@ public enum RuntimeEnv {
PROCESS
}

/**
 * Pairs a class loader with a flag recording whether this runner created it.
 *
 * <p>{@code extractClassLoader} sets the flag to {@code true} on the paths where it obtains the
 * loader from {@code FunctionCommon.getClassLoaderFromPackage}; {@code closeClassLoaderIfneeded}
 * only attempts to close the loader when the flag is set.
 */
@Value
private static class UserCodeClassLoader {
// The resolved class loader; callers null-check it and fall back to the thread context class loader.
ClassLoader classLoader;
// True when the loader was built from a package by this runner (and is therefore a candidate for closing).
boolean classLoaderCreated;
}

public static class FunctionConfigConverter implements IStringConverter<FunctionConfig> {
@Override
public FunctionConfig convert(String value) {
Expand Down Expand Up @@ -310,16 +318,21 @@ public synchronized void stop() {
runtimeFactory = null;
}

if (userCodeClassLoaderCreated) {
if (userCodeClassLoader instanceof Closeable) {
try {
((Closeable) userCodeClassLoader).close();
} catch (IOException e) {
log.warn("Error closing classloader", e);
}
closeClassLoaderIfneeded(userCodeClassLoader);
userCodeClassLoader = null;
closeClassLoaderIfneeded(transformFunctionCodeClassLoader);
transformFunctionCodeClassLoader = null;
}
}

private static void closeClassLoaderIfneeded(UserCodeClassLoader userCodeClassLoader) {
if (userCodeClassLoader != null && userCodeClassLoader.isClassLoaderCreated()) {
if (userCodeClassLoader.getClassLoader() instanceof Closeable) {
try {
((Closeable) userCodeClassLoader.getClassLoader()).close();
} catch (IOException e) {
log.warn("Error closing classloader", e);
}
userCodeClassLoaderCreated = false;
userCodeClassLoader = null;
}
}
}
Expand All @@ -333,16 +346,18 @@ public void start(boolean blocking) throws Exception {
Runtime.getRuntime().addShutdownHook(shutdownHook);
Function.FunctionDetails functionDetails = null;
String userCodeFile;
String transformFunctionFile = null;
int parallelism;
if (functionConfig != null) {
FunctionConfigUtils.inferMissingArguments(functionConfig, true);
parallelism = functionConfig.getParallelism();
if (functionConfig.getRuntime() == FunctionConfig.Runtime.JAVA) {
userCodeFile = functionConfig.getJar();
ClassLoader functionClassLoader = extractClassLoader(
userCodeClassLoader = extractClassLoader(
userCodeFile, ComponentType.FUNCTION, functionConfig.getClassName());
functionDetails = FunctionConfigUtils.convert(
functionConfig, FunctionConfigUtils.validateJavaFunction(functionConfig, functionClassLoader));
functionConfig,
FunctionConfigUtils.validateJavaFunction(functionConfig, getCurrentOrUserCodeClassLoader()));
} else if (functionConfig.getRuntime() == FunctionConfig.Runtime.GO) {
userCodeFile = functionConfig.getGo();
} else if (functionConfig.getRuntime() == FunctionConfig.Runtime.PYTHON) {
Expand All @@ -352,26 +367,42 @@ public void start(boolean blocking) throws Exception {
}

if (functionDetails == null) {
functionDetails = FunctionConfigUtils.convert(functionConfig,
userCodeClassLoader != null ? userCodeClassLoader :
Thread.currentThread().getContextClassLoader());
functionDetails = FunctionConfigUtils.convert(functionConfig, getCurrentOrUserCodeClassLoader());
}
} else if (sourceConfig != null) {
inferMissingArguments(sourceConfig);
userCodeFile = sourceConfig.getArchive();
parallelism = sourceConfig.getParallelism();
ClassLoader sourceClassLoader = extractClassLoader(
userCodeClassLoader = extractClassLoader(
userCodeFile, ComponentType.SOURCE, sourceConfig.getClassName());
functionDetails = SourceConfigUtils.convert(
sourceConfig, SourceConfigUtils.validateAndExtractDetails(sourceConfig, sourceClassLoader, true));
sourceConfig,
SourceConfigUtils.validateAndExtractDetails(sourceConfig, getCurrentOrUserCodeClassLoader(), true));
} else if (sinkConfig != null) {
inferMissingArguments(sinkConfig);
userCodeFile = sinkConfig.getArchive();
transformFunctionFile = sinkConfig.getTransformFunction();
parallelism = sinkConfig.getParallelism();
ClassLoader sinkClassLoader = extractClassLoader(
userCodeClassLoader = extractClassLoader(
userCodeFile, ComponentType.SINK, sinkConfig.getClassName());
if (isNotEmpty(sinkConfig.getTransformFunction())) {
transformFunctionCodeClassLoader = extractClassLoader(
sinkConfig.getTransformFunction(),
ComponentType.FUNCTION,
sinkConfig.getTransformFunctionClassName());
}

ClassLoader functionClassLoader = null;
if (transformFunctionCodeClassLoader != null) {
functionClassLoader = transformFunctionCodeClassLoader.getClassLoader() == null
? Thread.currentThread().getContextClassLoader()
: transformFunctionCodeClassLoader.getClassLoader();
}

functionDetails = SinkConfigUtils.convert(
sinkConfig, SinkConfigUtils.validateAndExtractDetails(sinkConfig, sinkClassLoader, null, true));
sinkConfig,
SinkConfigUtils.validateAndExtractDetails(sinkConfig, getCurrentOrUserCodeClassLoader(),
functionClassLoader, true));
} else {
throw new IllegalArgumentException("Must specify Function, Source or Sink config");
}
Expand Down Expand Up @@ -401,10 +432,10 @@ public void start(boolean blocking) throws Exception {
&& (runtimeEnv == null || runtimeEnv == RuntimeEnv.THREAD)) {
// By default run java functions as threads
startThreadedMode(functionDetails, parallelism, instanceIdOffset, serviceUrl,
stateStorageServiceUrl, authConfig, userCodeFile);
stateStorageServiceUrl, authConfig, userCodeFile, transformFunctionFile);
} else {
startProcessMode(functionDetails, parallelism, instanceIdOffset, serviceUrl,
stateStorageServiceUrl, authConfig, userCodeFile);
stateStorageServiceUrl, authConfig, userCodeFile, transformFunctionFile);
}
local.addAll(spawners);
}
Expand All @@ -426,15 +457,22 @@ public void start(boolean blocking) throws Exception {
}
}

private ClassLoader extractClassLoader(String userCodeFile, ComponentType componentType, String className)
/**
 * Resolves the class loader to use for validating user code.
 *
 * @return the class loader held by {@code userCodeClassLoader} when both the holder and its
 *         loader are non-null; otherwise the current thread's context class loader
 */
private ClassLoader getCurrentOrUserCodeClassLoader() {
    if (userCodeClassLoader != null) {
        ClassLoader extracted = userCodeClassLoader.getClassLoader();
        if (extracted != null) {
            return extracted;
        }
    }
    // No user-code loader available: fall back to the calling thread's context class loader.
    return Thread.currentThread().getContextClassLoader();
}

private UserCodeClassLoader extractClassLoader(String userCodeFile, ComponentType componentType, String className)
throws IOException, URISyntaxException {
userCodeClassLoader = userCodeFile != null ? isBuiltIn(userCodeFile, componentType) : null;
if (userCodeClassLoader == null) {
ClassLoader classLoader = userCodeFile != null ? isBuiltIn(userCodeFile, componentType) : null;
boolean classLoaderCreated = false;
if (classLoader == null) {
if (userCodeFile != null && Utils.isFunctionPackageUrlSupported(userCodeFile)) {
File file = FunctionCommon.extractFileFromPkgURL(userCodeFile);
userCodeClassLoader = FunctionCommon.getClassLoaderFromPackage(
classLoader = FunctionCommon.getClassLoaderFromPackage(
componentType, className, file, narExtractionDirectory);
userCodeClassLoaderCreated = true;
classLoaderCreated = true;
} else if (userCodeFile != null) {
File file = new File(userCodeFile);
if (!file.exists()) {
Expand All @@ -454,9 +492,9 @@ private ClassLoader extractClassLoader(String userCodeFile, ComponentType compon
}
throw new RuntimeException(errorMsg + " (" + userCodeFile + ") does not exist");
}
userCodeClassLoader = FunctionCommon.getClassLoaderFromPackage(
classLoader = FunctionCommon.getClassLoaderFromPackage(
componentType, className, file, narExtractionDirectory);
userCodeClassLoaderCreated = true;
classLoaderCreated = true;
} else {
if (!(runtimeEnv == null || runtimeEnv == RuntimeEnv.THREAD)) {
String errorMsg;
Expand All @@ -477,15 +515,13 @@ private ClassLoader extractClassLoader(String userCodeFile, ComponentType compon
}
}
}
return userCodeClassLoader == null
? Thread.currentThread().getContextClassLoader()
: userCodeClassLoader;
return new UserCodeClassLoader(classLoader, classLoaderCreated);
}

private void startProcessMode(org.apache.pulsar.functions.proto.Function.FunctionDetails functionDetails,
int parallelism, int instanceIdOffset, String serviceUrl,
String stateStorageServiceUrl, AuthenticationConfig authConfig,
String userCodeFile) throws Exception {
String userCodeFile, String transformFunctionFile) throws Exception {
SecretsProviderConfigurator secretsProviderConfigurator = getSecretsProviderConfigurator();
runtimeFactory = new ProcessRuntimeFactory(
serviceUrl,
Expand Down Expand Up @@ -532,7 +568,7 @@ private void startProcessMode(org.apache.pulsar.functions.proto.Function.Functio
instanceConfig,
userCodeFile,
null,
null,
transformFunctionFile,
null,
runtimeFactory,
instanceLivenessCheck);
Expand Down Expand Up @@ -568,7 +604,7 @@ public void run() {
private void startThreadedMode(org.apache.pulsar.functions.proto.Function.FunctionDetails functionDetails,
int parallelism, int instanceIdOffset, String serviceUrl,
String stateStorageServiceUrl, AuthenticationConfig authConfig,
String userCodeFile) throws Exception {
String userCodeFile, String transformFunctionFile) throws Exception {

if (metricsPortStart != null) {
if (metricsPortStart < 0 || metricsPortStart > 65535) {
Expand Down Expand Up @@ -599,8 +635,8 @@ private void startThreadedMode(org.apache.pulsar.functions.proto.Function.Functi

ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader();
try {
if (userCodeClassLoader != null) {
Thread.currentThread().setContextClassLoader(userCodeClassLoader);
if (userCodeClassLoader != null && userCodeClassLoader.getClassLoader() != null) {
Thread.currentThread().setContextClassLoader(userCodeClassLoader.getClassLoader());
}
runtimeFactory = new ThreadRuntimeFactory("LocalRunnerThreadGroup",
serviceUrl,
Expand All @@ -620,6 +656,7 @@ private void startThreadedMode(org.apache.pulsar.functions.proto.Function.Functi
// TODO: correctly implement function version and id
instanceConfig.setFunctionVersion(UUID.randomUUID().toString());
instanceConfig.setFunctionId(UUID.randomUUID().toString());
instanceConfig.setTransformFunctionId(UUID.randomUUID().toString());
instanceConfig.setInstanceId(i + instanceIdOffset);
instanceConfig.setMaxBufferedTuples(1024);
if (metricsPortStart != null) {
Expand All @@ -638,7 +675,7 @@ private void startThreadedMode(org.apache.pulsar.functions.proto.Function.Functi
instanceConfig,
userCodeFile,
null,
null,
transformFunctionFile,
null,
runtimeFactory,
instanceLivenessCheck);
Expand Down

0 comments on commit 327648c

Please sign in to comment.