Skip to content

Commit

Permalink
[python] Fixes build error
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Nov 2, 2023
1 parent a58a735 commit c5cc7ef
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 90 deletions.
15 changes: 5 additions & 10 deletions engines/python/src/main/java/ai/djl/python/engine/PyModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ai.djl.BaseModel;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.engine.EngineException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDManager;
Expand Down Expand Up @@ -155,15 +154,11 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
// Handle TRT-LLM
boolean isTrtLlmBackend = "TRT-LLM".equals(Utils.getenv("LMI_BACKEND"));
if (isTrtLlmBackend) {
try {
Optional<Path> trtLlmRepoDir = TrtLLMUtils.initTrtLlmModel(this);
if (trtLlmRepoDir.isPresent()) {
String modelId = trtLlmRepoDir.get().toAbsolutePath().toString();
setProperty("model_id", modelId);
pyEnv.addParameter("model_id", modelId);
}
} catch (ModelException e) {
throw new RuntimeException(e);
Optional<Path> trtLlmRepoDir = TrtLlmUtils.initTrtLlmModel(this);
if (trtLlmRepoDir.isPresent()) {
String modelId = trtLlmRepoDir.get().toAbsolutePath().toString();
setProperty("model_id", modelId);
pyEnv.addParameter("model_id", modelId);
}
}

Expand Down
167 changes: 87 additions & 80 deletions engines/python/src/main/java/ai/djl/python/engine/TrtLLMUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
*/
package ai.djl.python.engine;

import ai.djl.ModelException;
import ai.djl.engine.EngineException;
import ai.djl.util.Utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
Expand All @@ -24,93 +27,97 @@
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TrtLLMUtils {
final class TrtLlmUtils {

private static final Logger logger = LoggerFactory.getLogger(TrtLLMUtils.class);
private static final Logger logger = LoggerFactory.getLogger(TrtLlmUtils.class);

public static Optional<Path> initTrtLlmModel(PyModel model) throws ModelException, IOException {
// check if downloadS3Dir or local model path is a trt-llm repo
boolean isTrtLlmRepo = isValidTrtLlmModelRepo(model);
if (!isTrtLlmRepo) {
return Optional.of(buildTrtLlmArtifacts(model));
}
return Optional.empty();
}
private TrtLlmUtils() {}

public static Path buildTrtLlmArtifacts(PyModel model) throws ModelException, IOException {
logger.info("Converting model to TensorRT-LLM artifacts");
Path trtLlmRepoDir = Paths.get("/tmp/tensorrtllm");
String modelId = model.getProperty("model_id");
// invoke trt-llm build script
List<String> commandList = new ArrayList<>();
commandList.add("python");
commandList.add("/opt/djl/partition/trt_llm_partition.py");
commandList.add("--properties_dir");
commandList.add(model.getModelPath().toAbsolutePath().toString());
commandList.add("--trt_llm_model_repo");
commandList.add(trtLlmRepoDir.toAbsolutePath().toString());
if (modelId != null) {
commandList.add("--model_path");
commandList.add(modelId);
}
try {
Process exec = new ProcessBuilder(commandList).redirectErrorStream(true).start();
String logOutput;
try (InputStream is = exec.getInputStream()) {
logOutput = Utils.toString(is);
}
int exitCode = exec.waitFor();
if (0 != exitCode || logOutput.startsWith("ERROR ")) {
logger.error(logOutput);
throw new EngineException("Download model failed: " + logOutput);
} else {
logger.info(logOutput);
}
} catch (IOException | InterruptedException e) {
throw new ModelException("Failed to build TensorRT-LLM artifacts", e);
static Optional<Path> initTrtLlmModel(PyModel model) throws IOException {
// check if downloadS3Dir or local model path is a trt-llm repo
boolean isTrtLlmRepo = isValidTrtLlmModelRepo(model);
if (!isTrtLlmRepo) {
return Optional.of(buildTrtLlmArtifacts(model));
}
return Optional.empty();
}
logger.info("TensorRT-LLM artifacts built successfully");
return trtLlmRepoDir;
}

public static boolean isValidTrtLlmModelRepo(PyModel model) throws IOException {
Optional<Path> dirToCheckOptional = Optional.empty();
Path modelPath = Paths.get(model.getProperty("model_id"));
if (Files.exists(modelPath)) {
dirToCheckOptional = Optional.of(modelPath);
}
if (!dirToCheckOptional.isPresent()) {
return false;
static Path buildTrtLlmArtifacts(PyModel model) throws IOException {
logger.info("Converting model to TensorRT-LLM artifacts");
Path trtLlmRepoDir = Paths.get("/tmp/tensorrtllm");
String modelId = model.getProperty("model_id");
// invoke trt-llm build script
List<String> commandList = getStrings(model, trtLlmRepoDir, modelId);
try {
Process exec = new ProcessBuilder(commandList).redirectErrorStream(true).start();
String logOutput;
try (InputStream is = exec.getInputStream()) {
logOutput = Utils.toString(is);
}
int exitCode = exec.waitFor();
if (0 != exitCode || logOutput.startsWith("ERROR ")) {
logger.error(logOutput);
throw new EngineException("Download model failed: " + logOutput);
} else {
logger.info(logOutput);
}
logger.info("TensorRT-LLM artifacts built successfully");
return trtLlmRepoDir;
} catch (InterruptedException e) {
throw new IOException("Failed to build TensorRT-LLM artifacts", e);
}
}

Path dirToCheck = dirToCheckOptional.get();
List<Path> configFiles = new ArrayList<>();
List<Path> tokenizerFiles = new ArrayList<>();
try (Stream<Path> walk = Files.walk(dirToCheck)) {
walk.filter(Files::isRegularFile)
.forEach(
path -> {
if ("config.pbtxt".equals(path.getFileName().toString())) {
// check depth of config.pbtxt
Path relativePath = dirToCheck.relativize(path);
if (relativePath.getNameCount() == 2) {
configFiles.add(path);
}
}
// TODO: research required tokenizer files and add a tighter check
if ("tokenizer_config.json".equals(path.getFileName().toString())) {
tokenizerFiles.add(path);
}
});
private static List<String> getStrings(PyModel model, Path trtLlmRepoDir, String modelId) {
List<String> commandList = new ArrayList<>();
commandList.add("python");
commandList.add("/opt/djl/partition/trt_llm_partition.py");
commandList.add("--properties_dir");
commandList.add(model.getModelPath().toAbsolutePath().toString());
commandList.add("--trt_llm_model_repo");
commandList.add(trtLlmRepoDir.toAbsolutePath().toString());
if (modelId != null) {
commandList.add("--model_path");
commandList.add(modelId);
}
return commandList;
}
boolean isValidRepo = !configFiles.isEmpty() && tokenizerFiles.size() == 1;
if (isValidRepo) {
logger.info("Valid TRT-LLM model repo found");
}
return isValidRepo;
}

static boolean isValidTrtLlmModelRepo(PyModel model) throws IOException {
Optional<Path> dirToCheckOptional = Optional.empty();
Path modelPath = Paths.get(model.getProperty("model_id"));
if (Files.exists(modelPath)) {
dirToCheckOptional = Optional.of(modelPath);
}
if (!dirToCheckOptional.isPresent()) {
return false;
}

Path dirToCheck = dirToCheckOptional.get();
List<Path> configFiles = new ArrayList<>();
List<Path> tokenizerFiles = new ArrayList<>();
try (Stream<Path> walk = Files.walk(dirToCheck)) {
walk.filter(Files::isRegularFile)
.forEach(
path -> {
if ("config.pbtxt".equals(path.getFileName().toString())) {
// check depth of config.pbtxt
Path relativePath = dirToCheck.relativize(path);
if (relativePath.getNameCount() == 2) {
configFiles.add(path);
}
}
// TODO: research required tokenizer files and add a tighter check
if ("tokenizer_config.json".equals(path.getFileName().toString())) {
tokenizerFiles.add(path);
}
});
}
boolean isValidRepo = !configFiles.isEmpty() && tokenizerFiles.size() == 1;
if (isValidRepo) {
logger.info("Valid TRT-LLM model repo found");
}
return isValidRepo;
}
}

0 comments on commit c5cc7ef

Please sign in to comment.