Skip to content

Commit

Permalink
[serving] Make yaml support optional
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Dec 12, 2022
1 parent 6aec9be commit 2e0d9ad
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
2 changes: 1 addition & 1 deletion serving/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ dependencies {
api "io.netty:netty-transport-native-kqueue:${netty_version}:osx-aarch_64"
api "io.netty:netty-transport-native-kqueue:${netty_version}:osx-x86_64"

implementation "org.yaml:snakeyaml:${snakeyaml_version}"
//noinspection GradlePackageUpdate
implementation "commons-cli:commons-cli:${commons_cli_version}"
implementation "org.apache.logging.log4j:log4j-slf4j-impl:${log4j_slf4j_version}"
Expand All @@ -30,6 +29,7 @@ dependencies {
}
runtimeOnly project(":engines:python")

testRuntimeOnly "org.yaml:snakeyaml:${snakeyaml_version}"
testImplementation("org.testng:testng:${testng_version}") {
exclude group: "junit", module: "junit"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import ai.djl.serving.wlm.util.WlmConfigManager;
import ai.djl.serving.workflow.WorkflowExpression.Item;
import ai.djl.serving.workflow.function.WorkflowFunction;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.JsonUtils;
import ai.djl.util.Utils;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
Expand All @@ -28,15 +30,12 @@
import com.google.gson.JsonParseException;
import com.google.gson.annotations.SerializedName;

import org.yaml.snakeyaml.Yaml;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
Expand Down Expand Up @@ -69,7 +68,6 @@ public class WorkflowDefinition {
int maxBatchDelay;
int batchSize;

private static final Yaml YAML = new Yaml();
public static final Gson GSON =
JsonUtils.builder()
.registerTypeAdapter(ModelInfo.class, new ModelDefinitionDeserializer())
Expand All @@ -85,7 +83,7 @@ public class WorkflowDefinition {
* @throws IOException if it fails to load the file for parsing
*/
public static WorkflowDefinition parse(Path path) throws IOException {
return parse(path.toUri(), Files.newBufferedReader(path));
return parse(path.toUri(), Files.newInputStream(path));
}

/**
Expand All @@ -94,9 +92,10 @@ public static WorkflowDefinition parse(Path path) throws IOException {
* @param uri the uri of the file
* @param input the input
* @return the parsed {@link WorkflowDefinition}
* @throws IOException if failed to read from input
*/
public static WorkflowDefinition parse(URI uri, InputStream input) {
return parse(uri, new InputStreamReader(input, StandardCharsets.UTF_8));
public static WorkflowDefinition parse(URI uri, InputStream input) throws IOException {
return parse(uri, Utils.toString(input));
}

/**
Expand All @@ -106,12 +105,21 @@ public static WorkflowDefinition parse(URI uri, InputStream input) {
* @param input the input
* @return the parsed {@link WorkflowDefinition}
*/
public static WorkflowDefinition parse(URI uri, Reader input) {
public static WorkflowDefinition parse(URI uri, String input) {
String fileName = Objects.requireNonNull(uri.toString());
if (fileName.endsWith(".yml") || fileName.endsWith(".yaml")) {
Object yaml = YAML.load(input);
String asJson = GSON.toJson(yaml);
return GSON.fromJson(asJson, WorkflowDefinition.class);
try {
ClassLoader cl = ClassLoaderUtils.getContextClassLoader();
Class<?> clazz = Class.forName("org.yaml.snakeyaml.Yaml", true, cl);
Constructor<?> constructor = clazz.getConstructor();
Method method = clazz.getMethod("load", String.class);
Object obj = constructor.newInstance();
Object yaml = method.invoke(obj, input);
String asJson = GSON.toJson(yaml);
return GSON.fromJson(asJson, WorkflowDefinition.class);
} catch (ReflectiveOperationException e) {
throw new IllegalArgumentException("Yaml parsing is not supported.", e);
}
} else if (fileName.endsWith(".json")) {
return GSON.fromJson(input, WorkflowDefinition.class);
} else {
Expand Down
13 changes: 4 additions & 9 deletions serving/src/test/java/ai/djl/serving/WorkflowTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.concurrent.ExecutionException;

public class WorkflowTest {

Expand All @@ -52,29 +51,25 @@ public void beforeAll() throws IOException {
}

@Test
public void testJson()
throws IOException, BadWorkflowException, ExecutionException, InterruptedException {
public void testJson() throws IOException, BadWorkflowException {
Path workflowFile = Paths.get("src/test/resources/workflows/basic.json");
runWorkflow(workflowFile, zeroInput);
}

@Test
public void testYaml()
throws IOException, BadWorkflowException, ExecutionException, InterruptedException {
public void testYaml() throws IOException, BadWorkflowException {
Path workflowFile = Paths.get("src/test/resources/workflows/basic.yaml");
runWorkflow(workflowFile, zeroInput);
}

@Test
public void testCriteria()
throws IOException, BadWorkflowException, ExecutionException, InterruptedException {
public void testCriteria() throws IOException, BadWorkflowException {
Path workflowFile = Paths.get("src/test/resources/workflows/criteria.json");
runWorkflow(workflowFile, zeroInput);
}

@Test
public void testFunctions()
throws IOException, BadWorkflowException, ExecutionException, InterruptedException {
public void testFunctions() throws IOException, BadWorkflowException {
Path workflowFile = Paths.get("src/test/resources/workflows/functions.json");
runWorkflow(workflowFile, zeroInput);
}
Expand Down

0 comments on commit 2e0d9ad

Please sign in to comment.