Skip to content

Commit

Permalink
[serving] Avoid using special config.properties for DeepSpeed (#363)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Dec 6, 2022
1 parent f5d71d3 commit f883a83
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 17 deletions.
2 changes: 1 addition & 1 deletion serving/docker/deepspeed.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ CMD ["serve"]
COPY scripts scripts/
RUN mkdir -p /opt/djl/conf && \
mkdir -p /opt/djl/deps
COPY deepspeed.config.properties /opt/djl/conf/config.properties
COPY config.properties /opt/djl/conf/config.properties

RUN apt-get update && \
scripts/install_djl_serving.sh $djl_version && \
Expand Down
6 changes: 0 additions & 6 deletions serving/docker/deepspeed.config.properties

This file was deleted.

19 changes: 9 additions & 10 deletions serving/src/main/java/ai/djl/serving/ModelServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ private void initModelStore() throws IOException {
String modelUrl = matcher.group(3);
String version = null;
String engineName = null;
String deviceMapping = null;
String deviceMapping = "*";
String modelName;
if (endpoint != null) {
String[] tokens = endpoint.split(":", -1);
Expand All @@ -389,12 +389,9 @@ private void initModelStore() throws IOException {
continue;
}
}
String[] devices = {null};
if (deviceMapping != null) {
DependencyManager.getInstance().installEngine(engineName);
Engine engine = Engine.getEngine(engineName);
devices = parseDevices(deviceMapping, engine, pair.getValue());
}
DependencyManager.getInstance().installEngine(engineName);
Engine engine = Engine.getEngine(engineName);
String[] devices = parseDevices(deviceMapping, engine, pair.getValue());

WlmConfigManager wlmc = WlmConfigManager.getInstance();
ModelInfo<Input, Output> modelInfo =
Expand All @@ -410,13 +407,12 @@ private void initModelStore() throws IOException {
wlmc.getMaxBatchDelay(),
wlmc.getBatchSize());
Workflow workflow = new Workflow(modelInfo);
String[] finalDevices = devices;
CompletableFuture<Void> f =
modelManager
.registerWorkflow(workflow)
.thenAccept(
v -> {
for (String deviceName : finalDevices) {
for (String deviceName : devices) {
modelManager.initWorkers(workflow, deviceName, -1, -1);
}
})
Expand Down Expand Up @@ -627,7 +623,8 @@ private String[] parseDevices(String devices, Engine engine, Path modelDir) {
if ("*".equals(devices)) {
int gpuCount = engine.getGpuCount();
if (gpuCount > 0) {
if ("Python".equals(engine.getEngineName())) {
String engineName = engine.getEngineName();
if ("Python".equals(engineName)) {
Properties prop = getServingProperties(modelDir);
String v = Utils.getenv("TENSOR_PARALLEL_DEGREE", "-1");
v = prop.getProperty("option.tensor_parallel_degree", v);
Expand All @@ -642,6 +639,8 @@ private String[] parseDevices(String devices, Engine engine, Path modelDir) {
}
gpuCount = procs;
}
} else if ("DeepSpeed".equals(engineName)) {
return new String[] {"0"};
}

return IntStream.range(0, gpuCount)
Expand Down

0 comments on commit f883a83

Please sign in to comment.