Updating TRT-LLM template #866

Merged: 14 commits, Mar 29, 2024
16 changes: 14 additions & 2 deletions truss/config/trt_llm.py
@@ -1,3 +1,4 @@
import json
import logging
from enum import Enum
from typing import Optional
@@ -8,13 +9,13 @@
logger = logging.getLogger(__name__)


class TRTLLMModelArchitecture(Enum):
class TRTLLMModelArchitecture(str, Enum):
LLAMA: str = "llama"
MISTRAL: str = "mistral"
DEEPSEEK: str = "deepseek"


class TRTLLMQuantizationType(Enum):
class TRTLLMQuantizationType(str, Enum):
NO_QUANT: str = "no_quant"
WEIGHTS_ONLY_INT8: str = "weights_int8"
WEIGHTS_KV_INT8: str = "weights_kv_int8"
@@ -48,6 +49,12 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
TrussTRTLLMPluginConfiguration()
)

class Config:
json_encoders = {
TRTLLMModelArchitecture: lambda x: x.value,
TRTLLMQuantizationType: lambda x: x.value,
}


class TrussTRTLLMServingConfiguration(BaseModel):
engine_repository: str
@@ -85,3 +92,8 @@ def requires_build(self):
if self.build is not None:
return True
return False

# TODO(Abu): Replace this with model_dump(json=True)
# when pydantic v2 is used here
def to_json_dict(self):
return json.loads(self.json())
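
Aside on the serialization change above: mixing `str` into the enums and registering `json_encoders` keeps these pydantic v1 models JSON-serializable end to end, and `to_json_dict` round-trips through `.json()` to produce a plain dict (the pydantic v2 equivalent would be `model_dump(mode="json")`). A minimal, self-contained sketch of the same pattern, assuming pydantic v1 and using hypothetical stand-in class names rather than the actual truss models:

```python
import json
from enum import Enum

from pydantic import BaseModel  # assumed: pydantic v1, as used by truss here


class Quantization(str, Enum):  # hypothetical stand-in for TRTLLMQuantizationType
    NO_QUANT = "no_quant"
    WEIGHTS_ONLY_INT8 = "weights_int8"


class BuildConfig(BaseModel):  # hypothetical stand-in for TrussTRTLLMBuildConfiguration
    quantization_type: Quantization = Quantization.NO_QUANT

    class Config:
        # Tell pydantic v1 how to encode the enum when dumping to JSON.
        json_encoders = {Quantization: lambda x: x.value}

    def to_json_dict(self):
        # Round-trip through .json() so nested enums become plain strings.
        return json.loads(self.json())


print(BuildConfig().to_json_dict())  # -> {'quantization_type': 'no_quant'}
```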
2 changes: 1 addition & 1 deletion truss/constants.py
@@ -98,7 +98,7 @@

REGISTRY_BUILD_SECRET_PREFIX = "DOCKER_REGISTRY_"

TRTLLM_BASE_IMAGE = "baseten/trtllm-build-server:r23.12_baseten_v0.7.1_20240111"
TRTLLM_BASE_IMAGE = "baseten/trtllm-build-server:r23.12_baseten_v0.9.0_20240325_dev3"
BASE_TRTLLM_REQUIREMENTS = [
"tritonclient[all]==2.42.0",
"transformers==4.33.1",
26 changes: 13 additions & 13 deletions truss/contexts/image_builder/serving_image_builder.py
@@ -44,7 +44,7 @@
)
from truss.contexts.truss_context import TrussContext
from truss.patch.hash import directory_content_hash
from truss.truss_config import BaseImage, ModelServer, TrussConfig
from truss.truss_config import BaseImage, TrussConfig
from truss.truss_spec import TrussSpec
from truss.util.jinja import read_template_from_fs
from truss.util.path import (
@@ -331,20 +331,20 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
# Copy over truss
copy_tree_path(truss_dir, build_dir, ignore_patterns=truss_ignore_patterns)
# Copy over template truss for TRT-LLM (we overwrite the model and packages dir)
if config.build.model_server is ModelServer.TRT_LLM:

if config.trt_llm is not None:
copy_tree_path(TRTLLM_TRUSS_DIR, build_dir, ignore_patterns=[])

# Check to see if TP and GPU count are the same
# TODO(Abu): Consolidate these config parameters so that we don't have to
# keep truss + template in sync if we change the interface
if "tensor_parallel_count" in config.build.arguments:
if (
config.build.arguments["tensor_parallel_count"]
!= config.resources.accelerator.count
):
raise ValueError(
"Tensor parallelism and GPU count must be the same for TRT-LLM"
)
tensor_parallel_count = (
config.trt_llm.build.tensor_parallel_count # type: ignore[union-attr]
if config.trt_llm.build is not None
else config.trt_llm.serve.tensor_parallel_count # type: ignore[union-attr]
)

if tensor_parallel_count != config.resources.accelerator.count:
raise ValueError(
"Tensor parallelism and GPU count must be the same for TRT-LLM"
)

config.base_image = BaseImage(
image=TRTLLM_BASE_IMAGE, python_executable_path="/usr/bin/python3"
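
The builder change above reads tensor parallelism straight from `config.trt_llm` (preferring `build` over `serve`) instead of from free-form `config.build.arguments`, and still requires it to match the provisioned GPU count. A standalone sketch of that check, with simplified stand-in types (the real code uses the truss config objects):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ParallelSpec:  # stand-in for the TRT-LLM build/serve configuration objects
    tensor_parallel_count: int = 1


def resolve_tensor_parallelism(
    build: Optional[ParallelSpec],
    serve: Optional[ParallelSpec],
    accelerator_count: int,
) -> int:
    # Prefer the build configuration; fall back to the serve configuration.
    spec = build if build is not None else serve
    if spec is None:
        raise ValueError("A TRT-LLM build or serve configuration is required")
    if spec.tensor_parallel_count != accelerator_count:
        raise ValueError(
            "Tensor parallelism and GPU count must be the same for TRT-LLM"
        )
    return spec.tensor_parallel_count


# A TP=2 build on a 2-GPU accelerator passes; TP=4 on 2 GPUs would raise.
assert resolve_tensor_parallelism(ParallelSpec(tensor_parallel_count=2), None, 2) == 2
```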
66 changes: 48 additions & 18 deletions truss/templates/trtllm/model/model.py
@@ -2,15 +2,20 @@
from itertools import count

import build_engine_utils
from builder.types import TrussTRTLLMConfiguration
from constants import (
GRPC_SERVICE_PORT,
HF_AUTH_KEY_CONSTANT,
HTTP_SERVICE_PORT,
TOKENIZER_KEY_CONSTANT,
)
from schema import ModelInput, TrussBuildConfig
from schema import ModelInput
from transformers import AutoTokenizer
from triton_client import TritonClient, TritonServer
from utils import execute_command

DEFAULT_MAX_TOKENS = 500
DEFAULT_MAX_NEW_TOKENS = 500


class Model:
@@ -25,65 +30,90 @@ def __init__(self, data_dir, config, secrets):
self.uses_openai_api = None

def load(self):
build_config = TrussBuildConfig(**self._config["build"]["arguments"])
execute_command(["ldconfig"])
trtllm_config = TrussTRTLLMConfiguration(**self._config.get("trt_llm", {}))
self.uses_openai_api = "openai-compatible" in self._config.get(
"model_metadata", {}
).get("tags", [])
hf_access_token = None
if "hf_access_token" in self._secrets._base_secrets.keys():
hf_access_token = self._secrets["hf_access_token"]

# TODO(Abu): Move to pre-runtime
if build_config.requires_build:
# The underlying engine build call is idempotent, so we can call it in the load
# even if the engine is already built. The engine build call checks to see if
# there are .engine files in the destination directory and skips the build if
# they are present.
if trtllm_config.requires_build:
build_engine_utils.build_engine_from_config_args(
engine_build_args=build_config.engine_build_args,
truss_trtllm_configuration=trtllm_config,
checkpoint_dir_path=None,
dst=self._data_dir,
)

self.triton_server = TritonServer(
grpc_port=GRPC_SERVICE_PORT,
http_port=HTTP_SERVICE_PORT,
)
self.triton_client = TritonClient(
grpc_service_port=GRPC_SERVICE_PORT,
)

if not trtllm_config.requires_build:
engine_repository_path = trtllm_config.serve.engine_repository
tokenizer_repository = trtllm_config.serve.tokenizer_repository
tensor_parallel_count = trtllm_config.serve.tensor_parallel_count
pipeline_parallel_count = trtllm_config.serve.pipeline_parallel_count
else:
# If this model required a build, the engine lives inside the data_dir
engine_repository_path = self._data_dir
tokenizer_repository = trtllm_config.build.huggingface_ckpt_repository
tensor_parallel_count = trtllm_config.build.tensor_parallel_count
pipeline_parallel_count = trtllm_config.build.pipeline_parallel_count

world_size = tensor_parallel_count * pipeline_parallel_count

self.triton_server.create_model_repository(
truss_data_dir=self._data_dir,
engine_repository_path=build_config.engine_repository
if not build_config.requires_build
else None,
engine_repository_path=engine_repository_path,
huggingface_auth_token=hf_access_token,
)

env = {}
if hf_access_token:
env[HF_AUTH_KEY_CONSTANT] = hf_access_token
env[TOKENIZER_KEY_CONSTANT] = build_config.tokenizer_repository
env[TOKENIZER_KEY_CONSTANT] = tokenizer_repository

world_size = (
build_config.tensor_parallel_count * build_config.pipeline_parallel_count
)
self.triton_server.start(
world_size=world_size,
env=env,
)

self.triton_client = TritonClient(
grpc_service_port=GRPC_SERVICE_PORT,
)

self.tokenizer = AutoTokenizer.from_pretrained(
build_config.tokenizer_repository, token=hf_access_token
tokenizer_repository, token=hf_access_token
)
self.eos_token_id = self.tokenizer.eos_token_id

async def predict(self, model_input):
if "messages" not in model_input and "prompt" not in model_input:
raise ValueError("Prompt or messages must be provided")

model_input.setdefault("max_tokens", DEFAULT_MAX_TOKENS)
model_input.setdefault("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
model_input["request_id"] = str(os.getpid()) + str(
next(self._request_id_counter)
)
model_input["eos_token_id"] = self.eos_token_id

self.triton_client.start_grpc_stream()
if "messages" in model_input:
messages = model_input.pop("messages")
if self.uses_openai_api and "prompt" not in model_input:
model_input["prompt"] = self.tokenizer.apply_chat_template(
messages, tokenize=False
)

self.triton_client.start_grpc_stream()
model_input = ModelInput(**model_input)

result_iterator = self.triton_client.infer(model_input)

async def generate():
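
To summarize the new `load()` flow above: when the truss ships a prebuilt engine, the engine path, tokenizer, and parallelism all come from the `serve` block; when it builds at load time, the engine lands in `data_dir` and the tokenizer is taken from the Hugging Face checkpoint repository, with Triton's `world_size` always being tensor parallelism times pipeline parallelism. A compact sketch of that resolution, using the field names from this diff:

```python
def resolve_serving_inputs(trtllm_config, data_dir):
    """Sketch of how load() picks engine/tokenizer sources (field names per this PR)."""
    if not trtllm_config.requires_build:
        # Prebuilt engine: everything comes from the `serve` block.
        serve = trtllm_config.serve
        engine_repository_path = serve.engine_repository
        tokenizer_repository = serve.tokenizer_repository
        tp, pp = serve.tensor_parallel_count, serve.pipeline_parallel_count
    else:
        # Built at load time: engine artifacts are written into the data dir,
        # and the tokenizer comes from the Hugging Face checkpoint repository.
        build = trtllm_config.build
        engine_repository_path = data_dir
        tokenizer_repository = build.huggingface_ckpt_repository
        tp, pp = build.tensor_parallel_count, build.pipeline_parallel_count
    # Triton launches one process per rank: world_size = TP * PP.
    return engine_repository_path, tokenizer_repository, tp * pp
```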
40 changes: 15 additions & 25 deletions truss/templates/trtllm/packages/build_engine_utils.py
@@ -1,34 +1,24 @@
from pathlib import Path
from typing import Optional

from schema import EngineBuildArgs
from builder.types import TrussTRTLLMConfiguration


def build_engine_from_config_args(
engine_build_args: EngineBuildArgs,
truss_trtllm_configuration: TrussTRTLLMConfiguration,
dst: Path,
checkpoint_dir_path: Optional[Path] = None,
):
import os
import shutil
import sys

# NOTE: These are provided by the underlying base image
# TODO(Abu): Remove this when we have a better way of handling this
sys.path.append("/app/baseten")
from build_engine import Engine, build_engine
from trtllm_utils import docker_tag_aware_file_cache

engine = Engine(**engine_build_args.model_dump())

with docker_tag_aware_file_cache("/root/.cache/trtllm"):
built_engine = build_engine(engine, download_remote=True)

if not os.path.exists(dst):
os.makedirs(dst)

for filename in os.listdir(str(built_engine)):
source_file = os.path.join(str(built_engine), filename)
destination_file = os.path.join(dst, filename)
if not os.path.exists(destination_file):
shutil.copy(source_file, destination_file)

return dst
from builder.main import build_engine

build_engine(
engine_configuration=truss_trtllm_configuration,
engine_serialization_path=dst,
# If checkpoint_dir_path is provided, we'll look there for the
# weight files. If not, we will attempt to use the `huggingface_ckpt_repository`
# key in the `truss_trtllm_configuration` to download the weights.
checkpoint_dir_path=checkpoint_dir_path,
)
return dst
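
With this change the wrapper simply forwards the whole `TrussTRTLLMConfiguration` to `builder.main.build_engine` from the base image instead of assembling separate `EngineBuildArgs`. A usage sketch mirroring how `model.py` drives it in this PR; it is only meaningful inside the TRT-LLM base image (where `builder.*` and `build_engine_utils` are importable), and the `truss_config` dict and data path are illustrative assumptions:

```python
from pathlib import Path

import build_engine_utils
from builder.types import TrussTRTLLMConfiguration


def maybe_build_engine(truss_config: dict, data_dir: Path) -> None:
    """Build the TRT-LLM engine into data_dir if the config asks for a build."""
    trtllm_config = TrussTRTLLMConfiguration(**truss_config.get("trt_llm", {}))
    if trtllm_config.requires_build:
        build_engine_utils.build_engine_from_config_args(
            truss_trtllm_configuration=trtllm_config,
            dst=data_dir,  # model.py passes self._data_dir here
            checkpoint_dir_path=None,  # None -> weights fetched via huggingface_ckpt_repository
        )
```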