diff --git a/Dockerfile-neuron b/Dockerfile-neuron
new file mode 100644
index 000000000..9f4b23740
--- /dev/null
+++ b/Dockerfile-neuron
@@ -0,0 +1,187 @@
+FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef
+WORKDIR /usr/src
+
+ENV SCCACHE=0.10.0
+ENV RUSTC_WRAPPER=/usr/local/bin/sccache
+
+# Download and configure sccache
+RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
+    chmod +x /usr/local/bin/sccache
+
+FROM chef AS planner
+
+COPY backends backends
+COPY core core
+COPY router router
+COPY Cargo.toml ./
+COPY Cargo.lock ./
+
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM chef AS builder
+
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
+# sccache specific variables
+ARG SCCACHE_GHA_ENABLED
+
+COPY --from=planner /usr/src/recipe.json recipe.json
+
+RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
+    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
+    cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s
+
+COPY backends backends
+COPY core core
+COPY router router
+COPY Cargo.toml ./
+COPY Cargo.lock ./
+
+RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
+    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
+    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
+    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
+    rm -f $PROTOC_ZIP
+
+FROM builder AS http-builder
+
+RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
+    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
+    cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s
+
+FROM builder AS grpc-builder
+
+COPY proto proto
+
+RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
+    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
+    cargo build --release --bin text-embeddings-router -F grpc -F python --no-default-features && sccache -s
+
+FROM public.ecr.aws/docker/library/ubuntu:22.04 AS neuron
+
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    PORT=80
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    python3 \
+    python3-pip \
+    python3-dev \
+    build-essential \
+    git \
+    curl \
+    cmake \
+    pkg-config \
+    protobuf-compiler \
+    ninja-build \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN ln -s /usr/bin/python3 /usr/local/bin/python || true
+RUN ln -s /usr/bin/pip3 /usr/local/bin/pip || true
+
+WORKDIR /usr/src
+COPY backends backends
+COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
+COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
+RUN cd backends/python/server && \
+    make install
+
+ARG NEURONX_COLLECTIVES_LIB_VERSION=2.28.27.0-bc30ece58
+ARG NEURONX_RUNTIME_LIB_VERSION=2.28.23.0-dd5879008
+ARG NEURONX_TOOLS_VERSION=2.26.14.0
+
+ARG NEURONX_CC_VERSION=2.21.33363.0+82129205
+ARG NEURONX_FRAMEWORK_VERSION=2.8.0.2.10.16998+e9bf8a50
+ARG NEURONX_DISTRIBUTED_VERSION=0.15.22404+1f27bddf
+
+RUN apt-get update \
+    && apt-get upgrade -y \
+    && apt-get install -y --no-install-recommends \
+    apt-transport-https \
+    build-essential \
+    ca-certificates \
+    cmake \
+    curl \
+    emacs \
+    git \
+    gnupg2 \
+    gpg-agent \
+    jq \
+    libgl1-mesa-glx \
+    libglib2.0-0 \
+    libsm6 \
+    libxext6 \
+    libxrender-dev \
+    libcap-dev \
+    libhwloc-dev \
+    openjdk-11-jdk \
+    unzip \
+    vim \
+    wget \
+    zlib1g-dev \
+    && rm -rf /var/lib/apt/lists/* \
+    && rm -rf /tmp/tmp* \
+    && apt-get clean
+
+# The base image is ubuntu:22.04, so use the `jammy` Neuron apt channel
+RUN echo "deb https://apt.repos.neuron.amazonaws.com jammy main" > /etc/apt/sources.list.d/neuron.list
+RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -
+
+RUN apt-get update \
+    && apt-get install -y \
+    aws-neuronx-tools=$NEURONX_TOOLS_VERSION \
+    aws-neuronx-collectives=$NEURONX_COLLECTIVES_LIB_VERSION \
+    aws-neuronx-runtime-lib=$NEURONX_RUNTIME_LIB_VERSION \
+    && rm -rf /var/lib/apt/lists/* \
+    && rm -rf /tmp/tmp* \
+    && apt-get clean
+
+ENV PATH="/opt/aws/neuron/bin:${PATH}"
+
+RUN pip install --index-url https://pip.repos.neuron.amazonaws.com \
+    --extra-index-url https://pypi.org/simple \
+    --trusted-host pip.repos.neuron.amazonaws.com \
+    neuronx-cc==$NEURONX_CC_VERSION \
+    torch-neuronx==$NEURONX_FRAMEWORK_VERSION \
+    torchvision \
+    neuronx_distributed==$NEURONX_DISTRIBUTED_VERSION \
+    && rm -rf ~/.cache/pip/*
+
+# HF ARGS
+ARG TRANSFORMERS_VERSION=4.57.1
+ARG DIFFUSERS_VERSION=0.35.2
+ARG HUGGINGFACE_HUB_VERSION=0.36.0
+ARG OPTIMUM_NEURON_VERSION=0.4.1
+ARG SENTENCE_TRANSFORMERS=5.1.2
+ARG PEFT_VERSION=0.17.0
+ARG DATASETS_VERSION=4.1.1
+
+# Install the Hugging Face libraries and their dependencies.
+# Note: optimum-neuron is installed from `main` below, not the pinned ${OPTIMUM_NEURON_VERSION}.
+RUN pip install --no-cache-dir -U \
+    networkx==2.8.8 \
+    transformers[sentencepiece,audio,vision]==${TRANSFORMERS_VERSION} \
+    diffusers==${DIFFUSERS_VERSION} \
+    compel \
+    controlnet-aux \
+    huggingface_hub==${HUGGINGFACE_HUB_VERSION} \
+    hf_transfer \
+    datasets==${DATASETS_VERSION} \
+    "optimum-neuron @ git+https://github.com/huggingface/optimum-neuron@main" \
+    sentence_transformers==${SENTENCE_TRANSFORMERS} \
+    peft==${PEFT_VERSION} \
+    && rm -rf ~/.cache/pip/*
+
+
+FROM neuron AS grpc
+
+COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
+
+ENTRYPOINT ["text-embeddings-router"]
+CMD ["--json-output"]
+
+FROM neuron
+
+COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
+
+ENTRYPOINT ["text-embeddings-router"]
+CMD ["--json-output"]
diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py
index 1e919f233..0ca8b584c 100644
--- a/backends/python/server/text_embeddings_server/models/__init__.py
+++ b/backends/python/server/text_embeddings_server/models/__init__.py
@@ -14,7 +14,9 @@
 from text_embeddings_server.models.jinaBert_model import FlashJinaBert
 from text_embeddings_server.models.flash_mistral import FlashMistral
 from text_embeddings_server.models.flash_qwen3 import FlashQwen3
-from text_embeddings_server.utils.device import get_device, use_ipex
+from text_embeddings_server.models.neuron_models import NeuronSentenceTransformersModel
+
+from text_embeddings_server.utils.device import get_device, use_ipex, is_neuron
 
 __all__ = ["Model"]
 
@@ -74,6 +76,11 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
     logger.info(f"backend device: {device}")
 
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
+
+    # Neuron cases
+    if is_neuron():
+        if config.model_type == "bert":
+            return create_model(NeuronSentenceTransformersModel, model_path, device, datatype)
 
     if (
         hasattr(config, "auto_map")
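The Dockerfile above ends with two runtime stages on the same `neuron` base: the named `grpc` stage and a final unnamed stage carrying the HTTP router, which is what a plain `docker build` produces. A minimal sketch of building each flavor, assuming illustrative image tags:

```shell
# Default (last) stage: image with the HTTP router
docker build . -f Dockerfile-neuron -t tei-neuron:http

# Named stage: image with the gRPC router
docker build . -f Dockerfile-neuron --target grpc -t tei-neuron:grpc
```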
diff --git a/backends/python/server/text_embeddings_server/models/neuron_models.py b/backends/python/server/text_embeddings_server/models/neuron_models.py
new file mode 100644
index 000000000..e3b850c3e
--- /dev/null
+++ b/backends/python/server/text_embeddings_server/models/neuron_models.py
@@ -0,0 +1,71 @@
+import inspect
+import torch
+
+from pathlib import Path
+from typing import Type, List
+from optimum.neuron import NeuronSentenceTransformers
+from opentelemetry import trace
+
+from text_embeddings_server.models import Model
+from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
+
+tracer = trace.get_tracer(__name__)
+
+
+class NeuronSentenceTransformersModel(Model):
+    def __init__(
+        self,
+        model_path: Path,
+        device: torch.device,
+        dtype: torch.dtype,
+    ):
+        model = NeuronSentenceTransformers.from_pretrained(model_path)
+
+        self.hidden_size = model.config.hidden_size
+        position_offset = 0
+        model_type = model.config.model_type
+        if model_type in ["xlm-roberta", "camembert", "roberta"]:
+            position_offset = model.config.pad_token_id + 1
+        if hasattr(model.config, "max_seq_length"):
+            self.max_input_length = model.config.max_seq_length
+        else:
+            self.max_input_length = (
+                model.config.max_position_embeddings - position_offset
+            )
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
+        )
+        self.has_token_type_ids = (
+            inspect.signature(model.forward).parameters.get("token_type_ids", None)
+            is not None
+        )
+
+        super(NeuronSentenceTransformersModel, self).__init__(
+            model=model, dtype=dtype, device=device
+        )
+
+    @property
+    def batch_type(self) -> Type[PaddedBatch]:
+        return PaddedBatch
+
+    @tracer.start_as_current_span("embed")
+    def embed(self, batch: PaddedBatch) -> List[Embedding]:
+        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
+        if self.has_token_type_ids:
+            kwargs["token_type_ids"] = batch.token_type_ids
+        output = self.model(**kwargs)
+
+        # Flatten the (batch_size, hidden_size) tensor to a flat list of floats,
+        # as the other python backends do, before slicing per-sequence embeddings.
+        cpu_results = output["sentence_embedding"].view(-1).tolist()
+
+        return [
+            Embedding(
+                values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
+            )
+            for i in range(len(batch))
+        ]
+
+    @tracer.start_as_current_span("predict")
+    def predict(self, batch: PaddedBatch) -> List[Score]:
+        raise NotImplementedError("predict is not implemented for Neuron models")
diff --git a/backends/python/server/text_embeddings_server/utils/device.py b/backends/python/server/text_embeddings_server/utils/device.py
index 3f3b04dd7..46b81370f 100644
--- a/backends/python/server/text_embeddings_server/utils/device.py
+++ b/backends/python/server/text_embeddings_server/utils/device.py
@@ -1,4 +1,6 @@
 import os
+import re
+import functools
 from loguru import logger
 import importlib.metadata
 import importlib.util
@@ -49,6 +51,21 @@ def is_hpu() -> bool:
         is_hpu_available = False
     return is_hpu_available
 
+@functools.cache
+def get_neuron_major() -> int:
+    MAJORS_FILE = "/proc/devices"
+    NEURON_MAJOR_LINE = re.compile(r"^\s*(\d+)\s+neuron\s*$")
+    if not os.path.exists(MAJORS_FILE):
+        return -1
+    with open(MAJORS_FILE, "r") as f:
+        for line in f:
+            m = NEURON_MAJOR_LINE.match(line)
+            if m:
+                return int(m.group(1))
+    return -1
+
+def is_neuron() -> bool:
+    return get_neuron_major() > -1
 
 def use_ipex() -> bool:
     value = os.environ.get("USE_IPEX", "True").lower()
@@ -72,5 +89,7 @@ def get_device():
 
         if hasattr(torch, "xpu") and torch.xpu.is_available():
             device = torch.device("xpu")
+    elif is_neuron():
+        device = torch.device("xla")
 
     return device
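For reference, `get_neuron_major` keys off the `neuron` entry that the Neuron driver registers in `/proc/devices`. A hedged illustration of the line it matches on an Inferentia/Trainium host; the major number is assigned dynamically, so `250` is made up for the example:

```shell
$ grep neuron /proc/devices
250 neuron
```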
diff --git a/backends/src/lib.rs b/backends/src/lib.rs
index 245715b38..b53067de1 100644
--- a/backends/src/lib.rs
+++ b/backends/src/lib.rs
@@ -67,6 +67,15 @@ fn is_hpu() -> bool {
     }
 }
 
+fn is_neuron() -> bool {
+    match Command::new("neuron-ls").output() {
+        Ok(output) => output.status.success(),
+        Err(_) => false,
+    }
+}
+
 #[derive(Debug, Clone)]
 pub struct Backend {
     /// Channel to communicate with the background thread
@@ -409,16 +418,39 @@ async fn init_backend(
     if let Some(api_repo) = api_repo.as_ref() {
         if cfg!(feature = "python") || cfg!(feature = "candle") {
             let start = std::time::Instant::now();
-            if download_safetensors(api_repo).await.is_err() {
-                tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
-                tracing::info!("Downloading `pytorch_model.bin`");
-                api_repo
-                    .get("pytorch_model.bin")
+            if is_neuron() {
+                let model_files = download_neuron(api_repo)
                     .await
                     .map_err(|err| BackendError::WeightsNotFound(err.to_string()))?;
-            }
 
-            tracing::info!("Model weights downloaded in {:?}", start.elapsed());
+                if model_files.is_empty() {
+                    tracing::error!(
+                        "Neuron model files not found in the repository. \
+                        You can compile your model to the Neuron format by following the guide: \
+                        https://huggingface.co/docs/optimum-neuron/en/model_doc/sentence_transformers/overview"
+                    );
+                    return Err(BackendError::WeightsNotFound(
+                        "No Neuron model files found".into(),
+                    ));
+                }
+
+                tracing::info!("Neuron model downloaded in {:?}", start.elapsed());
+            } else {
+                if download_safetensors(api_repo).await.is_err() {
+                    tracing::warn!(
+                        "safetensors weights not found. Using `pytorch_model.bin` instead. \
+                        Model loading will be significantly slower."
+                    );
+                    tracing::info!("Downloading `pytorch_model.bin`");
+                    api_repo
+                        .get("pytorch_model.bin")
+                        .await
+                        .map_err(|err| BackendError::WeightsNotFound(err.to_string()))?;
+                }
+
+                tracing::info!("Model weights downloaded in {:?}", start.elapsed());
+            }
         }
     }
 
@@ -655,6 +687,20 @@ async fn download_onnx(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
     Ok(model_files)
 }
 
+async fn download_neuron(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
+    let mut model_files: Vec<PathBuf> = Vec::new();
+
+    tracing::info!("Downloading `model.neuron`");
+    match api.get("model.neuron").await {
+        Ok(p) => model_files.push(p),
+        Err(err) => {
+            tracing::warn!("Could not download `model.neuron`: {err}");
+        }
+    };
+
+    Ok(model_files)
+}
+
 #[cfg(feature = "candle")]
 #[derive(Debug, Clone, Deserialize, PartialEq)]
 enum ModuleType {
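The error path above points users to the optimum-neuron guide for producing the `model.neuron` artifact that `download_neuron` fetches. As a rough sketch, a compiled repository can typically be produced with `optimum-cli`; the model id and static shapes below are illustrative assumptions, and exact flags may differ between optimum-neuron releases, so check the linked guide:

```shell
# Compile a BERT embedding model to a Neuron artifact with static shapes
# (model id, batch size, and sequence length here are assumptions, not requirements)
optimum-cli export neuron \
  --model BAAI/bge-base-en-v1.5 \
  --batch_size 1 \
  --sequence_length 384 \
  bge-base-en-v1.5-neuron/
```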
diff --git a/docs/source/en/aws_neuron.md b/docs/source/en/aws_neuron.md
new file mode 100644
index 000000000..d383fdba8
--- /dev/null
+++ b/docs/source/en/aws_neuron.md
@@ -0,0 +1,37 @@
+
+# Using TEI Container with AWS Trainium and Inferentia Instances
+
+## Build Docker Image
+
+To build a container optimized for AWS Neuron devices, run the following command:
+
+```shell
+docker build . -f Dockerfile-neuron -t tei-neuron:main
+```
+
+## Deploy Docker Container
+
+To deploy your model on an AWS Trainium or Inferentia instance, use the following command:
+
+```shell
+model='optimum/bge-base-en-v1.5-neuronx'
+volume=$PWD/data
+
+docker run -p 8080:80 -v $volume:/data tei-neuron:main --model-id $model --dtype float32
+```
\ No newline at end of file
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index fa6f21e63..b9eebac2c 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -19,6 +19,8 @@
     title: Build custom container for TEI
   - local: intel_container
     title: Using TEI container with Intel Hardware
+  - local: aws_neuron
+    title: Using TEI container with AWS Neuron
   - local: examples
     title: Example uses
   title: Tutorials
diff --git a/integration_tests/neuron/conftest.py b/integration_tests/neuron/conftest.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/integration_tests/neuron/test_embed.py b/integration_tests/neuron/test_embed.py
new file mode 100644
index 000000000..e69de29bb
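Once a container built from `Dockerfile-neuron` is running as shown in `docs/source/en/aws_neuron.md`, the router exposes TEI's standard HTTP API, so the deployment can be sanity-checked with the usual `/embed` request:

```shell
curl 127.0.0.1:8080/embed \
    -X POST \
    -d '{"inputs":"What is Deep Learning?"}' \
    -H 'Content-Type: application/json'
```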