# Prompt Embeddings with vLLM on AWS Inferentia2

This notebook demonstrates how to send **pre-computed embedding tensors** (instead of token IDs) to a vLLM server running on AWS Inferentia2 (Neuron). This enables use cases like:

- Embedding manipulation (e.g., interpolation, arithmetic)
- Custom embedding pipelines (e.g., from a different encoder)
- Avoiding redundant tokenization when embeddings are already available

**Requirements:** Run this notebook on an **inf2.8xlarge** instance with the **Deep Learning AMI Neuron (Ubuntu 24.04) 20260126** and a Hugging Face token with access to `meta-llama/Llama-3.1-8B-Instruct`. The setup is highly version-dependent and is meant as an example of how the capability can be added. To deploy an earlier version of the DLAMI, see https://builder.aws.com/content/30beZEzf3XencHUcq7XDVLClkQ6/deploying-previous-versions-of-neuron-sdk-on-trainium-and-inferentia-ec2-instances

If you are running in a remote VS Code session, consider running `ln -s /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13 ~/.venv` from a terminal before selecting your kernel.

## Step 1: Set Up Environment Variables

In [1]:
import os

# Set your Hugging Face token (required for gated models like Llama)
os.environ["HF_TOKEN"] = os.environ.get("HF_TOKEN", "your-hf-token-here")

# Model configuration
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
TP_SIZE = 2
MAX_MODEL_LEN = 512
SERVER_PORT = 8000
SERVER_URL = f"http://localhost:{SERVER_PORT}"

print(f"Model: {MODEL_NAME}")
print(f"HF_TOKEN: {'set' if os.environ['HF_TOKEN'] != 'your-hf-token-here' else 'NOT SET -- update above'}")

Model: meta-llama/Llama-3.1-8B-Instruct
HF_TOKEN: set


## Step 2: Clone and Install Modified Branches

The `prompt_embeds` feature requires modifications to two repositories:

1. **vllm-neuron** -- the Neuron platform plugin for vLLM (passes embeddings from the API to the model)
2. **neuronx-distributed-inference (NxDI)** -- the Neuron model compilation/execution library (handles embeddings in the traced model graph)

We install these over the pre-installed DLAMI venv using editable installs.

In [2]:
%%bash
source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate

cd /home/ubuntu
if [ ! -d "neuronx-distributed-inference" ]; then
    git clone -b embeddings https://github.com/jimburtoft/neuronx-distributed-inference.git
fi
cd neuronx-distributed-inference
git checkout embeddings
pip install -e . --quiet 2>&1 | tail -3
echo "NxDI installed successfully"

Cloning into 'neuronx-distributed-inference'...


Already on 'embeddings'


Your branch is up to date with 'origin/embeddings'.


NxDI installed successfully


In [3]:
%%bash
source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate

cd /home/ubuntu
if [ ! -d "vllm-neuron" ]; then
    git clone -b embeddings https://github.com/jimburtoft/vllm-neuron.git
fi
cd vllm-neuron
git checkout embeddings
pip install -e . --quiet 2>&1 | tail -3
echo "vllm-neuron installed successfully"

Cloning into 'vllm-neuron'...


Already on 'embeddings'


Your branch is up to date with 'origin/embeddings'.


vllm-neuron installed successfully
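Because both packages are installed in editable mode, it can be worth confirming that Python resolves them to the cloned directories rather than to the copies that shipped with the DLAMI. The sketch below uses only the standard library. The import names `neuronx_distributed_inference` and `vllm_neuron` are assumptions (adjust them if the branches use different module names), and `locate_package` is a hypothetical helper, not part of either repository. It reports what the current kernel sees, so it is only meaningful when the kernel runs in the same venv.

```python
import importlib.util


def locate_package(module_name):
    """Return the location a top-level module would be imported from, or None."""
    spec = importlib.util.find_spec(module_name)
    if spec is None:
        return None
    # Packages expose their directories here; plain modules only have `origin`.
    if spec.submodule_search_locations:
        return list(spec.submodule_search_locations)[0]
    return spec.origin


# Assumed import names for the two editable installs.
for name in ("neuronx_distributed_inference", "vllm_neuron"):
    print(f"{name}: {locate_package(name) or 'not importable from this kernel'}")
```

If the editable installs took effect, the printed paths should point under `/home/ubuntu/neuronx-distributed-inference` and `/home/ubuntu/vllm-neuron`.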


## Step 3: Start the vLLM Server

We launch the vLLM OpenAI-compatible server with `--enable-prompt-embeds`. The server will download model weights, compile the model for Neuron (first run takes ~15-25 minutes; cached NEFFs are used on subsequent runs), load it onto Neuron cores, and start serving.

**Important flags:**
- `--enable-prompt-embeds` -- enables the prompt embedding API
- `--no-enable-prefix-caching` -- **required** (prefix caching is on by default in vLLM v0.13 and is not yet compatible with prompt_embeds)
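One way to keep these flags in a single place is to build the server command as an argument list instead of a shell string. This is a minimal sketch, not what the launch cell in this notebook does: `build_server_args` is a hypothetical helper, and it only assembles flags that already appear in this notebook. Note that running the interpreter without activating the venv may leave venv-provided tools off `PATH`, so activation may still be needed in practice.

```python
import shlex


def build_server_args(model, tp_size, max_model_len, port, python_exe="python"):
    """Assemble the vLLM server command as a list of strings (no shell quoting needed)."""
    return [
        python_exe, "-m", "vllm.entrypoints.openai.api_server",
        "--model", model,
        "--tensor-parallel-size", str(tp_size),
        "--max-model-len", str(max_model_len),
        "--max-num-seqs", "1",
        "--enable-prompt-embeds",       # enables the prompt embedding API
        "--no-enable-prefix-caching",   # required alongside prompt_embeds
        "--port", str(port),
    ]


args = build_server_args("meta-llama/Llama-3.1-8B-Instruct", 2, 512, 8000)
# shlex.join renders the list as a copy-pasteable shell command for inspection.
print(shlex.join(args))
```

A list like this can be passed to `subprocess.Popen` directly, which avoids nesting quotes inside `bash -c '...'`.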

In [4]:
import subprocess

log_file = "/home/ubuntu/vllm_server.log"

server_cmd = (
    f"source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate && "
    f"HF_TOKEN={os.environ['HF_TOKEN']} "
    f"python -m vllm.entrypoints.openai.api_server "
    f"--model {MODEL_NAME} "
    f"--tensor-parallel-size {TP_SIZE} "
    f"--max-model-len {MAX_MODEL_LEN} "
    f"--max-num-seqs 1 "
    f"--enable-prompt-embeds "
    f"--no-enable-prefix-caching "
    f"--port {SERVER_PORT}"
)

subprocess.Popen(
    f"bash -c '{server_cmd}' > {log_file} 2>&1 &",
    shell=True,
)

print(f"Server starting in background. Log: {log_file}")

Server starting in background. Log: /home/ubuntu/vllm_server.log


### Wait for the Server to Be Ready

This polls the server until it responds. First run (compilation) takes ~15-25 minutes.

In [5]:
import requests
import time

start_time = time.time()
print("Waiting for server to be ready...")

while True:
    try:
        resp = requests.get(f"{SERVER_URL}/v1/models", timeout=2)
        if resp.status_code == 200:
            elapsed = time.time() - start_time
            model_id = resp.json()["data"][0]["id"]
            print(f"Server ready! Model: {model_id} (took {elapsed:.0f}s)")
            break
    except (requests.ConnectionError, requests.Timeout):
        pass

    elapsed = time.time() - start_time
    if elapsed > 2400:  # 40 min timeout
        raise TimeoutError(f"Server did not start within {elapsed:.0f}s")
    time.sleep(15)

Waiting for server to be ready...


Server ready! Model: meta-llama/Llama-3.1-8B-Instruct (took 1215s)
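While the polling loop waits, the server log is the only place where compilation progress or startup errors show up. The following sketch prints the last few lines of the log file; `tail_lines` is a hypothetical helper written for this notebook, and the path is the `log_file` value used in the launch cell.

```python
from collections import deque
from pathlib import Path


def tail_lines(path, n=10):
    """Return the last n lines of a text file, or an empty list if it does not exist."""
    p = Path(path)
    if not p.exists():
        return []
    with p.open("r", errors="replace") as fh:
        # A bounded deque keeps only the final n lines while streaming the file.
        return [line.rstrip("\n") for line in deque(fh, maxlen=n)]


for line in tail_lines("/home/ubuntu/vllm_server.log", n=5):
    print(line)
```

Calling this inside the polling loop (for example, every few iterations) would make a failed startup visible long before the 40-minute timeout.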


## Step 4: Test Regular Text Completion

Verify the server works with a standard text prompt. Both regular text and prompt embeddings work on the same compiled model -- no recompilation needed.

In [6]:
TEST_PROMPT = "The capital of France is"

response = requests.post(
    f"{SERVER_URL}/v1/completions",
    json={
        "model": MODEL_NAME,
        "prompt": TEST_PROMPT,
        "max_tokens": 30,
        "temperature": 0,
    },
    timeout=60,
)

assert response.status_code == 200, f"Server error: {response.text}"
baseline_text = response.json()["choices"][0]["text"]
print(f"Prompt: \"{TEST_PROMPT}\"")
print(f"Output: \"{baseline_text}\"")

Prompt: "The capital of France is"
Output: " a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture."


## Step 5: Compute Embeddings from the Model's Embedding Layer

To test `prompt_embeds`, we produce embedding vectors that match what the model's internal embedding layer would generate:

1. Tokenize the prompt
2. Load **only** the embedding weight matrix from safetensors (not the full model)
3. Look up the token embeddings

Since these are the same embeddings the model would compute internally, the output should match the baseline exactly.
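The reason this works is that an embedding layer is just a row lookup into its weight matrix. The toy example below illustrates that with made-up sizes (a 10-token vocabulary and 4-dimensional vectors, both chosen only for illustration); it does not touch the real model weights.

```python
import torch

torch.manual_seed(0)

# Toy sizes for illustration only; the real model uses 128256 x 4096.
vocab_size, hidden_size = 10, 4
weight = torch.randn(vocab_size, hidden_size)

embed = torch.nn.Embedding(vocab_size, hidden_size)
embed.weight = torch.nn.Parameter(weight)

input_ids = torch.tensor([[3, 1, 7]])  # shape [batch=1, seq_len=3]

with torch.no_grad():
    via_layer = embed(input_ids)

# Indexing the weight matrix with the token IDs selects the same rows.
via_index = weight[input_ids]

print(via_layer.shape)
print(torch.equal(via_layer, via_index))
```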

In [7]:
import torch
from transformers import AutoTokenizer, AutoConfig
from safetensors import safe_open
import glob
import gc

# Tokenize
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokens = tokenizer(TEST_PROMPT, return_tensors="pt")
input_ids = tokens["input_ids"]

print(f"Prompt:     \"{TEST_PROMPT}\"")
print(f"Token IDs:  {input_ids[0].tolist()}")
print(f"Tokens:     {[tokenizer.decode(t) for t in input_ids[0]]}")
print(f"Seq length: {input_ids.shape[1]}")

Prompt:     "The capital of France is"
Token IDs:  [128000, 791, 6864, 315, 9822, 374]
Tokens:     ['<|begin_of_text|>', 'The', ' capital', ' of', ' France', ' is']
Seq length: 6


In [8]:
# Load only the embedding weight matrix (avoids loading the full ~16 GB model)
config = AutoConfig.from_pretrained(MODEL_NAME)

cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
model_dirs = glob.glob(os.path.join(
    cache_dir,
    f"models--{MODEL_NAME.replace('/', '--')}",
    "snapshots", "*"
))

if model_dirs:
    model_dir = model_dirs[0]
    st_files = sorted(glob.glob(os.path.join(model_dir, "model-*.safetensors")))
else:
    # Fallback: fetch only the first shard, which holds embed_tokens for this model
    from huggingface_hub import hf_hub_download
    st_files = [hf_hub_download(MODEL_NAME, "model-00001-of-00004.safetensors")]

embed_weight = None
for st_file in st_files:
    with safe_open(st_file, framework="pt") as f:
        if "model.embed_tokens.weight" in f.keys():
            embed_weight = f.get_tensor("model.embed_tokens.weight")
            print(f"Loaded embed_tokens.weight from {os.path.basename(st_file)}")
            break

assert embed_weight is not None, "Could not find embedding weights"

embed_layer = torch.nn.Embedding(config.vocab_size, config.hidden_size)
embed_layer.weight = torch.nn.Parameter(embed_weight)
print(f"Embedding matrix: {embed_weight.shape} ({embed_weight.dtype})")

Loaded embed_tokens.weight from model-00001-of-00004.safetensors


Embedding matrix: torch.Size([128256, 4096]) (torch.bfloat16)


In [9]:
# Compute embeddings
with torch.no_grad():
    # Cast to bfloat16 to match the Neuron model's computation dtype
    embeddings = embed_layer(input_ids).to(torch.bfloat16)

# vLLM API expects shape [seq_len, hidden_dim] (no batch dim)
prompt_embeds = embeddings.squeeze(0)

print(f"Embeddings shape: {prompt_embeds.shape}")
print(f"Embeddings dtype: {prompt_embeds.dtype}")
print(f"First 5 values:   {prompt_embeds[0, :5].tolist()}")

del embed_layer, embed_weight
gc.collect()

Embeddings shape: torch.Size([6, 4096])
Embeddings dtype: torch.bfloat16
First 5 values:   [0.0002651214599609375, -0.000499725341796875, -0.000583648681640625, 0.00133514404296875, 9.393692016601562e-05]


33

## Step 6: Serialize Embeddings for the API

The vLLM `prompt_embeds` API expects the tensor serialized as **base64-encoded `torch.save()` bytes**.
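When debugging a payload, it can help to confirm that the string decodes back to the tensor that went in. The sketch below wraps both directions in two hypothetical helpers, `encode_prompt_embeds` and `decode_prompt_embeds`. The decoder mirrors the encoding described above and is an assumption for client-side checking only; it is not necessarily how the server decodes the field. The example tensor uses a small made-up hidden size.

```python
import base64
import io

import torch


def encode_prompt_embeds(tensor):
    """Serialize a tensor with torch.save and return it as a base64 string."""
    buf = io.BytesIO()
    torch.save(tensor, buf)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def decode_prompt_embeds(b64_string):
    """Inverse of encode_prompt_embeds, for client-side verification."""
    raw = base64.b64decode(b64_string)
    # weights_only=True restricts unpickling to tensors and other plain data.
    return torch.load(io.BytesIO(raw), weights_only=True)


example = torch.randn(6, 8).to(torch.bfloat16)  # toy [seq_len, hidden_dim]
restored = decode_prompt_embeds(encode_prompt_embeds(example))
print(restored.shape, restored.dtype, torch.equal(example, restored))
```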

In [10]:
import base64
import io

buf = io.BytesIO()
torch.save(prompt_embeds, buf)
buf.seek(0)
b64_data = base64.b64encode(buf.read()).decode("utf-8")

print(f"Serialized size: {len(b64_data):,} characters (base64)")
print(f"Raw tensor size: {prompt_embeds.nelement() * prompt_embeds.element_size():,} bytes")

Serialized size: 67,640 characters (base64)
Raw tensor size: 49,152 bytes
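The gap between the two numbers above comes from two sources: base64 expands every 3 bytes into 4 characters, and `torch.save` wraps the raw tensor data in a container with some metadata. The arithmetic can be sketched as follows, using only the figures printed above; `base64_length` is a helper defined here for illustration, and the overhead figure is approximate because base64 padding hides up to 2 bytes.

```python
import base64
import math


def base64_length(n_bytes):
    """Number of base64 characters (including padding) for n_bytes of input."""
    return 4 * math.ceil(n_bytes / 3)


# Values taken from the outputs above.
seq_len, hidden_size, bytes_per_element = 6, 4096, 2  # bfloat16 is 2 bytes
raw_bytes = seq_len * hidden_size * bytes_per_element
observed_b64_chars = 67_640

# Invert the 4/3 expansion; exact only to within 2 bytes because of '=' padding.
approx_payload_bytes = observed_b64_chars * 3 // 4

print(f"Raw tensor bytes:           {raw_bytes:,}")
print(f"base64 of raw bytes alone:  {base64_length(raw_bytes):,} characters")
print(f"Approx. torch.save bytes:   {approx_payload_bytes:,}")
print(f"Approx. container overhead: {approx_payload_bytes - raw_bytes:,} bytes")

# Sanity check of the formula against the standard library.
assert len(base64.b64encode(bytes(1000))) == base64_length(1000)
```

The container overhead is roughly fixed, so for longer prompts the payload size is dominated by the 4/3 base64 expansion of the raw tensor bytes.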


## Step 7: Send Prompt Embeddings to the Server

Send the pre-computed embeddings via the `prompt_embeds` field. The server skips the internal embedding lookup and uses our embeddings directly.

In [11]:
response = requests.post(
    f"{SERVER_URL}/v1/completions",
    json={
        "model": MODEL_NAME,
        "prompt_embeds": b64_data,
        "max_tokens": 30,
        "temperature": 0,
    },
    timeout=60,
)

assert response.status_code == 200, f"Server error: {response.text}"
embeds_text = response.json()["choices"][0]["text"]

print("=== Results ===")
print(f"Baseline (text):      \"{baseline_text}\"")
print(f"Prompt embeds:        \"{embeds_text}\"")
print()
if baseline_text == embeds_text:
    print("PERFECT MATCH -- prompt_embeds produces identical output to text tokens!")
else:
    print("OUTPUTS DIFFER -- see notes below for possible causes.")

=== Results ===
Baseline (text):      " a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture."
Prompt embeds:        " a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture."

PERFECT MATCH -- prompt_embeds produces identical output to text tokens!


## Step 8: Embedding Manipulation Example

The real power of `prompt_embeds` is the ability to manipulate embeddings before sending them. Here we blend two prompts by averaging their embedding vectors.

In [12]:
# Reload the embedding weights (they were freed at the end of Step 5)
embed_weight_2 = None
for st_file in st_files:
    with safe_open(st_file, framework="pt") as f:
        if "model.embed_tokens.weight" in f.keys():
            embed_weight_2 = f.get_tensor("model.embed_tokens.weight")
            break

embed_layer_2 = torch.nn.Embedding(config.vocab_size, config.hidden_size)
embed_layer_2.weight = torch.nn.Parameter(embed_weight_2)

prompt_a = "The weather in Paris is"
prompt_b = "The weather in Tokyo is"

tokens_a = tokenizer(prompt_a, return_tensors="pt")["input_ids"]
tokens_b = tokenizer(prompt_b, return_tensors="pt")["input_ids"]

with torch.no_grad():
    embeds_a = embed_layer_2(tokens_a).to(torch.bfloat16).squeeze(0)
    embeds_b = embed_layer_2(tokens_b).to(torch.bfloat16).squeeze(0)

del embed_layer_2, embed_weight_2
gc.collect()

print(f"Prompt A: \"{prompt_a}\" -> shape {embeds_a.shape}")
print(f"Prompt B: \"{prompt_b}\" -> shape {embeds_b.shape}")

if embeds_a.shape == embeds_b.shape:
    blended = 0.5 * embeds_a + 0.5 * embeds_b
    print(f"Blended:  shape {blended.shape}")

    buf = io.BytesIO()
    torch.save(blended, buf)
    buf.seek(0)
    b64_blended = base64.b64encode(buf.read()).decode("utf-8")

    resp = requests.post(
        f"{SERVER_URL}/v1/completions",
        json={
            "model": MODEL_NAME,
            "prompt_embeds": b64_blended,
            "max_tokens": 30,
            "temperature": 0,
        },
        timeout=60,
    )
    assert resp.status_code == 200, f"Server error: {resp.text}"
    print(f"\nBlended output: \"{resp.json()['choices'][0]['text']}\"")
else:
    print(f"\nPrompts have different token lengths ({embeds_a.shape[0]} vs {embeds_b.shape[0]}).")
    print("Pad the shorter embedding to blend them.")

Prompt A: "The weather in Paris is" -> shape torch.Size([6, 4096])
Prompt B: "The weather in Tokyo is" -> shape torch.Size([6, 4096])
Blended:  shape torch.Size([6, 4096])



Blended output: " mild and pleasant in the spring, making it an ideal time to visit the city. The average temperature in April is around 12°C (54°F"
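The 50/50 average is one point on a line between the two prompts. A small generalization is to expose the mixing weight, as in the sketch below. `blend_embeddings` is a hypothetical helper, and computing in float32 before casting back is a design choice made here to limit rounding, so its results may differ in the last bits from the bfloat16 arithmetic in the cell above. It still requires both prompts to tokenize to the same length.

```python
import torch


def blend_embeddings(embeds_a, embeds_b, alpha=0.5):
    """Linear interpolation: alpha=0 returns embeds_a, alpha=1 returns embeds_b."""
    if embeds_a.shape != embeds_b.shape:
        raise ValueError(
            f"Shape mismatch: {tuple(embeds_a.shape)} vs {tuple(embeds_b.shape)}"
        )
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be between 0 and 1")
    # Mix in float32, then cast back to the input dtype (e.g. bfloat16).
    mixed = (1.0 - alpha) * embeds_a.float() + alpha * embeds_b.float()
    return mixed.to(embeds_a.dtype)


# Toy tensors standing in for two [seq_len, hidden_dim] prompt embeddings.
a = torch.zeros(2, 4, dtype=torch.bfloat16)
b = torch.ones(2, 4, dtype=torch.bfloat16)
quarter = blend_embeddings(a, b, alpha=0.25)
print(quarter.dtype, quarter[0].tolist())
```

Sweeping `alpha` from 0 to 1 and sending each result to the server is a simple way to observe where the generated text shifts from one prompt's topic to the other's.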


## Cleanup

In [13]:
%%bash
pkill -f "vllm.entrypoints.openai.api_server" 2>/dev/null && echo "Server stopped." || echo "No server running."

Server stopped.


## How It Works

The Neuron compiler traces a static computation graph, so it cannot use Python-level `if/else` branching to choose between token embeddings and pre-computed embeddings at runtime. Instead, the model **always computes token embeddings** from `input_ids`, then uses `torch.where` to select between the computed embeddings and the provided `inputs_embeds` based on whether `inputs_embeds` is non-zero.

This means both regular text prompts and prompt embeddings work on the same compiled model with no recompilation.
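The selection step can be illustrated on toy tensors. This is a minimal sketch of the idea, not the code from the modified NxDI branch: the function name `select_embeddings` is hypothetical, and the exact rule used here (a position counts as provided if any element of its vector is non-zero) is an assumption about how the non-zero check might be expressed.

```python
import torch


def select_embeddings(token_embeds, inputs_embeds):
    """Pick inputs_embeds where provided (non-zero), else the token embeddings.

    Both tensors have shape [seq_len, hidden_dim]. There is no Python branch:
    both inputs are always present, and torch.where selects per position.
    """
    # Assumed rule: a position is "provided" if any element of its vector is non-zero.
    provided = (inputs_embeds != 0).any(dim=-1, keepdim=True)  # [seq_len, 1]
    return torch.where(provided, inputs_embeds, token_embeds)


token_embeds = torch.ones(3, 4)      # stands in for the embedding lookup result
inputs_embeds = torch.zeros(3, 4)    # all zeros means "nothing provided"
inputs_embeds[1] = 2.0               # only position 1 carries a provided vector

out = select_embeddings(token_embeds, inputs_embeds)
print(out)
```

In this toy run, positions 0 and 2 keep the token embeddings and position 1 takes the provided vector, which is the same behaviour a text-only request (all zeros) or an embeddings-only request (all non-zero) would exercise on the single compiled graph.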

### Performance Note on `torch.where`

The `torch.where` approach always runs the token embedding lookup even when pre-computed embeddings are provided -- the embedding table lookup still executes on every request, and its result is simply discarded when `inputs_embeds` is non-zero. For most use cases this overhead is negligible (the embedding lookup is a small fraction of total inference time). However, if your production workload **exclusively** sends pre-computed embeddings and never uses token IDs, a more efficient approach would be to compile a separate model that skips the embedding lookup entirely. That would require a dedicated traced model variant and is not covered in this notebook.

### Limitations

- **Prefix caching not supported**: Must use `--no-enable-prefix-caching`. Prefix caching is on by default in vLLM v0.13.
- **Sequence length**: `prompt_embeds` length must fit within `--max-model-len` and the compiled bucket sizes.
- **dtype**: Embeddings should be `torch.bfloat16` to match the model's computation dtype.
- **Serialization format**: Must use `torch.save()` + base64 encoding (not JSON lists).
- **Tested with**: Llama-3.1-8B-Instruct on inf2.8xlarge. Other model architectures supported by NxD Inference should also work but are untested.
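Several of these limitations can be checked on the client before a request is sent. The sketch below is one hypothetical pre-flight check (`check_prompt_embeds` is not part of vLLM). It covers shape, dtype, and prompt length against `--max-model-len` only; it knows nothing about the compiled bucket sizes and does not account for `max_tokens`.

```python
import torch


def check_prompt_embeds(embeds, hidden_size, max_model_len):
    """Return a list of human-readable problems; an empty list means no issue found."""
    problems = []
    if embeds.dim() != 2:
        problems.append(
            f"expected shape [seq_len, hidden_dim], got {tuple(embeds.shape)}"
        )
    else:
        if embeds.shape[1] != hidden_size:
            problems.append(
                f"hidden dim is {embeds.shape[1]}, model expects {hidden_size}"
            )
        if embeds.shape[0] > max_model_len:
            problems.append(
                f"sequence length {embeds.shape[0]} exceeds max_model_len {max_model_len}"
            )
    if embeds.dtype != torch.bfloat16:
        problems.append(f"dtype is {embeds.dtype}, expected torch.bfloat16")
    return problems


good = torch.zeros(6, 4096, dtype=torch.bfloat16)
bad = torch.zeros(1, 6, 4096, dtype=torch.float32)  # batch dim left in, wrong dtype

print(check_prompt_embeds(good, hidden_size=4096, max_model_len=512))
for problem in check_prompt_embeds(bad, hidden_size=4096, max_model_len=512):
    print("-", problem)
```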

### Modified Repositories

- **vllm-neuron** `embeddings` branch: [github.com/jimburtoft/vllm-neuron](https://github.com/jimburtoft/vllm-neuron/tree/embeddings)
- **neuronx-distributed-inference** `embeddings` branch: [github.com/jimburtoft/neuronx-distributed-inference](https://github.com/jimburtoft/neuronx-distributed-inference/tree/embeddings)