30 changes: 30 additions & 0 deletions gpu_container/Dockerfile
@@ -0,0 +1,30 @@
FROM nvidia/cuda:12.1.1-devel-ubuntu20.04

# Set the working directory
WORKDIR /app

# Install Python 3.9
RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install -y python3.9 python3.9-dev python3.9-distutils curl && \
# Install pip for python3.9
curl -sS https://bootstrap.pypa.io/get-pip.py | python3.9 && \
# Make python3 point to python3.9
update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1 && \
# Clean up
apt-get clean && \
rm -rf /var/lib/apt/lists/*

# Copy the requirements file into the container
COPY requirements.txt .

# Install the required packages
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY . ./gpu_container/

# Command to run the application
CMD ["uvicorn", "gpu_container.app:app", "--host", "0.0.0.0", "--port", "8000"]
Empty file added gpu_container/__init__.py
Empty file.
25 changes: 25 additions & 0 deletions gpu_container/app.py
@@ -0,0 +1,25 @@
from contextlib import asynccontextmanager

from fastapi import FastAPI

from gpu_container.embeddings.lifespan import lifespan as embeddings_lifespan
from gpu_container.embeddings.router import router as embeddings_router
from gpu_container.vllm.lifespan import lifespan as vllm_lifespan
from gpu_container.vllm.router import router as vllm_router


@asynccontextmanager
async def lifespan(app: FastAPI):
"""
A top-level lifespan handler that calls the lifespan handlers
for different parts of the application.
"""
async with embeddings_lifespan(app):
async with vllm_lifespan(app):
yield


app = FastAPI(lifespan=lifespan)

app.include_router(embeddings_router)
app.include_router(vllm_router)
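
A minimal sketch (with hypothetical first_lifespan/second_lifespan names) of the ordering that nesting the context managers above produces: handlers start outermost-first and shut down in reverse, so the embeddings model is loaded before the vLLM engine and released after it. It assumes FastAPI's TestClient (and its httpx dependency) is available.

from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.testclient import TestClient


@asynccontextmanager
async def first_lifespan(app: FastAPI):
    print("first: startup")
    yield
    print("first: shutdown")


@asynccontextmanager
async def second_lifespan(app: FastAPI):
    print("second: startup")
    yield
    print("second: shutdown")


@asynccontextmanager
async def combined(app: FastAPI):
    async with first_lifespan(app):
        async with second_lifespan(app):
            yield


demo_app = FastAPI(lifespan=combined)

with TestClient(demo_app):
    pass  # prints first/second startup, then second/first shutdown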
16 changes: 16 additions & 0 deletions gpu_container/docker-compose.yml
@@ -0,0 +1,16 @@
services:
  gpu-app:
    build: .
    ports:
      - "8000:8000"
    environment:
      - MODEL_ID=WhereIsAI/UAE-Large-V1
      - VLLM_MODEL_ID=mrfakename/mistral-small-3.1-24b-instruct-2503-hf
      - VLLM_GPU_UTILIZATION=0.8
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
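
As a quick sanity check that the device reservation above actually exposes a GPU inside the container, something like the following sketch could be run in the gpu-app container (for example via `docker compose exec gpu-app python3`); it is not part of the service itself.

import torch

# Confirms that the NVIDIA runtime made at least one GPU visible to PyTorch.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device 0:", torch.cuda.get_device_name(0))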
Empty file added gpu_container/embeddings/__init__.py
Empty file.
42 changes: 42 additions & 0 deletions gpu_container/embeddings/lifespan.py
@@ -0,0 +1,42 @@
import os
from contextlib import asynccontextmanager

import torch
from angle_emb import AnglE
from fastapi import FastAPI


def load_config_from_env():
"""Loads configuration from environment variables."""
model_id = os.getenv("MODEL_ID", "WhereIsAI/UAE-Large-V1")
device = os.getenv("DEVICE", "cpu")

return {"model_id": model_id, "device": device}


@asynccontextmanager
async def lifespan(app: FastAPI):
"""Handle embedding model startup and shutdown."""
print("Loading embeddings model...")
config = load_config_from_env()
print(f"Loading model: {config['model_id']} on device: {config['device']}")

model = AnglE.from_pretrained(config["model_id"], pooling_strategy="cls")

if config["device"] == "cuda" and torch.cuda.is_available():
model.to(torch.device("cuda"))
print("Embeddings model moved to CUDA.")
else:
model.to(torch.device("cpu"))
print("Embeddings model moved to CPU.")

app.state.embeddings_model = model
app.state.embeddings_model_id = config["model_id"]
print("Embeddings model loaded.")

yield

print("Shutting down embeddings model...")
app.state.model = None
app.state.model_id = None
print("Embeddings model shut down.")
44 changes: 44 additions & 0 deletions gpu_container/embeddings/router.py
@@ -0,0 +1,44 @@
from typing import List

import numpy as np
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel

router = APIRouter()


class EmbeddingRequest(BaseModel):
    input: List[str]


class Embedding(BaseModel):
    object: str = "embedding"
    index: int
    embedding: List[float]


class EmbeddingResponse(BaseModel):
    object: str = "list"
    data: List[Embedding]
    model: str


@router.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: Request, body: EmbeddingRequest):
"""Generate embeddings for a list of texts."""
model = request.app.state.embeddings_model
model_id = request.app.state.embeddings_model_id

if model is None:
return {"error": "Model not loaded"}, 503

# Generate embeddings
embeddings = model.encode(body.input, to_numpy=True)

# Ensure embeddings are a list of lists of floats
if isinstance(embeddings, np.ndarray):
embeddings = embeddings.tolist()

response_data = [Embedding(index=i, embedding=embedding) for i, embedding in enumerate(embeddings)]

return EmbeddingResponse(data=response_data, model=model_id)
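
A minimal client-side sketch of the request/response shape this endpoint exposes, assuming the service is reachable on localhost:8000 as mapped in the docker-compose file; only the standard library is used.

import json
import urllib.request

payload = json.dumps({"input": ["hello world", "second sentence"]}).encode("utf-8")
request = urllib.request.Request(
    "http://localhost:8000/v1/embeddings",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    result = json.load(response)

# Each entry in result["data"] carries an index and its embedding vector.
print(result["model"], len(result["data"]), len(result["data"][0]["embedding"]))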
6 changes: 6 additions & 0 deletions gpu_container/requirements.txt
@@ -0,0 +1,6 @@
angle-emb
torch
fastapi
uvicorn
pydantic
vllm==0.8.5
Empty file added gpu_container/vllm/__init__.py
Empty file.
34 changes: 34 additions & 0 deletions gpu_container/vllm/lifespan.py
@@ -0,0 +1,34 @@
import os
from contextlib import asynccontextmanager

from fastapi import FastAPI

from gpu_container.vllm.reproducible_vllm import ReproducibleVLLM


def load_config_from_env():
"""Loads vLLM configuration from environment variables."""
vllm_model_id = os.getenv("VLLM_MODEL_ID", "default_model_id")
device = os.getenv("DEVICE", "cuda")
# Add any other vLLM-specific environment variables here
return {"vllm_model_id": vllm_model_id, "device": device}


@asynccontextmanager
async def lifespan(app: FastAPI):
"""Handle vLLM engine startup and shutdown."""
print("Loading vLLM engine...")
config = load_config_from_env()

engine = ReproducibleVLLM(model_id=config["vllm_model_id"], device=config["device"])

app.state.vllm_engine = engine
app.state.vllm_model_id = config["vllm_model_id"]
print("vLLM engine loaded.")

yield

print("Shutting down vLLM engine...")
app.state.vllm_engine = None
app.state.vllm_model_id = None
print("vLLM engine shut down.")
169 changes: 169 additions & 0 deletions gpu_container/vllm/reproducible_vllm.py
@@ -0,0 +1,169 @@
import random
from typing import Dict, List, Optional, Union

import numpy as np
import torch
from vllm import LLM, SamplingParams


class ReproducibleVLLM:
    def __init__(
        self,
        model_id: str = "mrfakename/mistral-small-3.1-24b-instruct-2503-hf",
        device: str = "cuda:0",
        sampling_params: Optional[Dict[str, Union[str, float, int, bool]]] = None,
    ):
        """Deterministic VLLM model."""
        self._device = device
        self.model_id = model_id
        self.sampling_params = {} if sampling_params is None else sampling_params

        self.model = LLM(
            model=model_id,
            trust_remote_code=True,
            gpu_memory_utilization=0.9,
        )

        # Store tokenizer from VLLM for consistency
        self.tokenizer = self.model.get_tokenizer()

    @classmethod
    async def get_max_tokens(
        cls,
        sampling_params: Dict[str, Union[str, float, int, bool]],
        default_value: int = 512,
    ) -> int:
        # Process max tokens with backward compatibility.
        max_tokens = sampling_params.get("max_tokens")
        if max_tokens is None:
            max_tokens = sampling_params.get("max_new_tokens")
        if max_tokens is None:
            max_tokens = sampling_params.get("max_completion_tokens", default_value)
        return max_tokens

    @classmethod
    async def prepare_sampling_params(
        cls, sampling_params: Optional[Dict[str, Union[str, float, int, bool]]] = None
    ) -> SamplingParams:
        sampling_params = sampling_params or {}
        max_tokens = await cls.get_max_tokens(sampling_params)

        params = SamplingParams(
            temperature=float(sampling_params.get("temperature", 1.0)),
            top_p=float(sampling_params.get("top_p", 1.0)),
            max_tokens=int(max_tokens),
            presence_penalty=float(sampling_params.get("presence_penalty", 0.0)),
            frequency_penalty=float(sampling_params.get("frequency_penalty", 0.0)),
            top_k=int(sampling_params.get("top_k", -1)),
            logprobs=sampling_params.get("logprobs", None),
        )
        return params

    async def generate(
        self,
        messages: Union[List[str], List[Dict[str, str]]],
        sampling_params: Optional[Dict[str, Union[str, float, int, bool]]] = None,
        seed: Optional[int] = None,
        continue_last_message: bool = False,
    ) -> dict:
        """Generate text with optimized performance using VLLM.

        Returns an OpenAI-style completion dict of the form
        {"choices": [{"message": {"content": ...}}]}.
        """
        self.set_random_seeds(seed)

        # Convert chat messages to prompt string using tokenizer's chat template
        if isinstance(messages, list) and isinstance(messages[0], dict):
            try:
                # Extract any trailing whitespace before applying template
                trailing_space = ""
                if continue_last_message and messages[-1]["content"]:
                    content = messages[-1]["content"]
                    stripped = content.rstrip()
                    if len(content) > len(stripped):
                        trailing_space = content[len(stripped) :]

                # Try using the tokenizer's chat template
                prompt = self.tokenizer.apply_chat_template(
                    conversation=messages,
                    tokenize=False,
                    add_generation_prompt=not continue_last_message,
                    continue_final_message=continue_last_message,
                )

                # Append back just the trailing whitespace if it was stripped
                if trailing_space:
                    prompt += trailing_space
            except (AttributeError, NotImplementedError):
                raise ValueError(f"Chat template not supported for model {self.model_id}")
        else:
            prompt = messages[0] if isinstance(messages, list) else messages

        # Convert sampling parameters to vLLM format.
        params = sampling_params if sampling_params is not None else self.sampling_params
        vllm_params = await self.prepare_sampling_params(params)
        outputs = self.model.generate(prompt, vllm_params)

        if not outputs:
            return {"choices": [{"message": {"content": ""}}]}

        result = outputs[0].outputs[0].text
        return {"choices": [{"message": {"content": result}}]}

    async def generate_logits(
        self,
        messages: Union[List[str], List[Dict[str, str]]],
        top_logprobs: int = 10,
        sampling_params: Optional[Dict[str, Union[str, float, int, bool]]] = None,
        seed: Optional[int] = None,
        continue_last_message: bool = False,
    ) -> tuple[dict[str, float], str]:
        """Generate logits for the next token prediction.

        Args:
            messages: Input messages or text.
            top_logprobs: Number of top logits to return (default: 10).
            sampling_params: Generation parameters.
            seed: Random seed for reproducibility.
            continue_last_message: Whether to continue the last message in chat format.

        Returns:
            A tuple of (dictionary mapping tokens to their log probabilities,
            the formatted prompt that was sent to the model).
        """
        self.set_random_seeds(seed)
        params = sampling_params if sampling_params is not None else self.sampling_params
        params = params.copy()
        params["max_tokens"] = 1
        params["logprobs"] = top_logprobs
        vllm_params = await self.prepare_sampling_params(params)

        prompt = self.tokenizer.apply_chat_template(
            conversation=messages,
            tokenize=False,
            add_generation_prompt=not continue_last_message,
            continue_final_message=continue_last_message,
        )

        outputs = self.model.generate(prompt, vllm_params)

        if not outputs or not outputs[0].outputs[0].logprobs:
            return {}, prompt

        logprobs = outputs[0].outputs[0].logprobs[0]
        token_logprobs = {self.tokenizer.decode([token]): logprob.logprob for token, logprob in logprobs.items()}
        sorted_token_logprobs = dict(sorted(token_logprobs.items(), key=lambda item: item[1], reverse=True))
        return sorted_token_logprobs, prompt

    def set_random_seeds(self, seed: Optional[int] = 42):
        """Set random seeds for reproducibility across all relevant libraries."""
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
                torch.backends.cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False

    @staticmethod
    def format_messages(
        messages: Union[List[str], List[Dict[str, str]]],
    ) -> List[Dict[str, Union[str, List[Dict[str, str]]]]]:
        return messages
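
A hedged usage sketch of the class above; it assumes a GPU with enough memory for the configured model and that the weights can be fetched from the Hugging Face Hub.

import asyncio

from gpu_container.vllm.reproducible_vllm import ReproducibleVLLM


async def main():
    llm = ReproducibleVLLM(model_id="mrfakename/mistral-small-3.1-24b-instruct-2503-hf")
    messages = [{"role": "user", "content": "Name three prime numbers."}]
    completion = await llm.generate(
        messages,
        sampling_params={"temperature": 0.0, "max_tokens": 64},
        seed=42,
    )
    print(completion["choices"][0]["message"]["content"])


asyncio.run(main())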