Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions Dockerfile-neuron
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef
WORKDIR /usr/src

ENV SCCACHE=0.10.0
ENV RUSTC_WRAPPER=/usr/local/bin/sccache

# Donwload, configure sccache
RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
chmod +x /usr/local/bin/sccache

FROM chef AS planner

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

# sccache specific variables
ARG SCCACHE_GHA_ENABLED

COPY --from=planner /usr/src/recipe.json recipe.json

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP

FROM builder AS http-builder

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s

FROM builder AS grpc-builder

COPY proto proto

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
cargo build --release --bin text-embeddings-router -F grpc -F python --no-default-features && sccache -s

FROM public.ecr.aws/docker/library/ubuntu:22.04 AS neuron

ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3 \
python3-pip \
python3-dev \
build-essential \
git \
curl \
cmake \
pkg-config \
protobuf-compiler \
ninja-build \
&& rm -rf /var/lib/apt/lists/*

RUN ln -s /usr/bin/python3 /usr/local/bin/python || true
RUN ln -s /usr/bin/pip3 /usr/local/bin/pip || true

WORKDIR /usr/src
COPY backends backends
COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
RUN cd backends/python/server && \
make install

ARG NEURONX_COLLECTIVES_LIB_VERSION=2.28.27.0-bc30ece58
ARG NEURONX_RUNTIME_LIB_VERSION=2.28.23.0-dd5879008
ARG NEURONX_TOOLS_VERSION=2.26.14.0

ARG NEURONX_CC_VERSION=2.21.33363.0+82129205
ARG NEURONX_FRAMEWORK_VERSION=2.8.0.2.10.16998+e9bf8a50
ARG NEURONX_DISTRIBUTED_VERSION=0.15.22404+1f27bddf

RUN apt-get update \
&& apt-get upgrade -y \
&& apt-get install -y --no-install-recommends \
apt-transport-https \
build-essential \
ca-certificates \
cmake \
curl \
emacs \
git \
gnupg2 \
gpg-agent \
jq \
libgl1-mesa-glx \
libglib2.0-0 \
libsm6 \
libxext6 \
libxrender-dev \
libcap-dev \
libhwloc-dev \
openjdk-11-jdk \
unzip \
vim \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/* \
&& rm -rf /tmp/tmp* \
&& apt-get clean

RUN echo "deb https://apt.repos.neuron.amazonaws.com focal main" > /etc/apt/sources.list.d/neuron.list
RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -

RUN apt-get update \
&& apt-get install -y \
aws-neuronx-tools=$NEURONX_TOOLS_VERSION \
aws-neuronx-collectives=$NEURONX_COLLECTIVES_LIB_VERSION \
aws-neuronx-runtime-lib=$NEURONX_RUNTIME_LIB_VERSION \
&& rm -rf /var/lib/apt/lists/* \
&& rm -rf /tmp/tmp* \
&& apt-get clean

ENV PATH="/opt/aws/neuron/bin:${PATH}"

RUN pip install --index-url https://pip.repos.neuron.amazonaws.com \
--extra-index-url https://pypi.org/simple \
--trusted-host pip.repos.neuron.amazonaws.com \
neuronx-cc==$NEURONX_CC_VERSION \
torch-neuronx==$NEURONX_FRAMEWORK_VERSION \
torchvision \
neuronx_distributed==$NEURONX_DISTRIBUTED_VERSION \
&& rm -rf ~/.cache/pip/*

# HF ARGS
ARG TRANSFORMERS_VERSION=4.57.1
ARG DIFFUSERS_VERSION=0.35.2
ARG HUGGINGFACE_HUB_VERSION=0.36.0
ARG OPTIMUM_NEURON_VERSION=0.4.1
ARG SENTENCE_TRANSFORMERS=5.1.2
ARG PEFT_VERSION=0.17.0
ARG DATASETS_VERSION=4.1.1

# install Hugging Face libraries and its dependencies
# optimum-neuron==${OPTIMUM_NEURON_VERSION} \
RUN pip install --no-cache-dir -U \
networkx==2.8.8 \
transformers[sentencepiece,audio,vision]==${TRANSFORMERS_VERSION} \
diffusers==${DIFFUSERS_VERSION} \
compel \
controlnet-aux \
huggingface_hub==${HUGGINGFACE_HUB_VERSION} \
hf_transfer \
datasets==${DATASETS_VERSION} \
"optimum-neuron @ git+https://github.com/huggingface/optimum-neuron@main" \
sentence_transformers==${SENTENCE_TRANSFORMERS} \
peft==${PEFT_VERSION} \
&& rm -rf ~/.cache/pip/*


FROM neuron AS grpc

COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]

FROM neuron

COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from text_embeddings_server.models.jinaBert_model import FlashJinaBert
from text_embeddings_server.models.flash_mistral import FlashMistral
from text_embeddings_server.models.flash_qwen3 import FlashQwen3
from text_embeddings_server.utils.device import get_device, use_ipex
from text_embeddings_server.models.neuron_models import NeuronSentenceTransformersModel

from text_embeddings_server.utils.device import get_device, use_ipex, is_neuron

__all__ = ["Model"]

Expand Down Expand Up @@ -74,6 +76,11 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
logger.info(f"backend device: {device}")

config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)

# Neuron cases
if is_neuron():
if config.model_type == "bert":
return create_model(NeuronSentenceTransformersModel, model_path, device, datatype)

if (
hasattr(config, "auto_map")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import inspect
import torch

from pathlib import Path
from typing import Type, List
from optimum.neuron import NeuronSentenceTransformers
from opentelemetry import trace

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Score

tracer = trace.get_tracer(__name__)


class NeuronSentenceTransformersModel(Model):
def __init__(
self,
model_path: Path,
device: torch.device,
dtype: torch.dtype,
):
model = NeuronSentenceTransformers.from_pretrained(model_path)

self.hidden_size = model.config.hidden_size
position_offset = 0
model_type = model.config.model_type
if model_type in ["xlm-roberta", "camembert", "roberta"]:
position_offset = model.config.pad_token_id + 1
if hasattr(model.config, "max_seq_length"):
self.max_input_length = model.config.max_seq_length
else:
self.max_input_length = (
model.config.max_position_embeddings - position_offset
)

self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
self.has_token_type_ids = (
inspect.signature(model.forward).parameters.get("token_type_ids", None)
is not None
)

super(NeuronSentenceTransformersModel, self).__init__(
model=model, dtype=dtype, device=device
)

@property
def batch_type(self) -> Type[PaddedBatch]:
return PaddedBatch

@tracer.start_as_current_span("embed")
def embed(self, batch: PaddedBatch) -> List[Embedding]:
kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
if self.has_token_type_ids:
kwargs["token_type_ids"] = batch.token_type_ids
output = self.model(**kwargs)

sentence_embedding = output["sentence_embedding"]

return [
Embedding(
values=sentence_embedding[i * self.hidden_size : (i + 1) * self.hidden_size]
)
for i in range(len(batch))
]

@tracer.start_as_current_span("predict")
def predict(self, batch: PaddedBatch) -> List[Score]:
pass
19 changes: 19 additions & 0 deletions backends/python/server/text_embeddings_server/utils/device.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import re
import functools
from loguru import logger
import importlib.metadata
import importlib.util
Expand Down Expand Up @@ -49,6 +51,21 @@ def is_hpu() -> bool:
is_hpu_available = False
return is_hpu_available

@functools.cache
def get_neuron_major() -> int:
MAJORS_FILE = "/proc/devices"
NEURON_MAJOR_LINE = re.compile(r"^\s*(\d+)\s+neuron\s*$")
if not os.path.exists(MAJORS_FILE):
return -1
with open(MAJORS_FILE, "r") as f:
for l in f.readlines():
m = NEURON_MAJOR_LINE.match(l)
if m:
return int(m.group(1))
return -1

def is_neuron() -> bool:
return get_neuron_major > -1

def use_ipex() -> bool:
value = os.environ.get("USE_IPEX", "True").lower()
Expand All @@ -72,5 +89,7 @@ def get_device():

if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device("xpu")
elif is_neuron():
device = torch.device("xla")

return device
Loading
Loading