diff --git a/containerized_job/Dockerfile b/containerized_job/Dockerfile
new file mode 100644
index 000000000..b2e754d41
--- /dev/null
+++ b/containerized_job/Dockerfile
@@ -0,0 +1,24 @@
+FROM python:3.10-slim
+
+WORKDIR /app
+
+RUN apt-get update && apt-get install -y \
+    git build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY download_model.py .
+
+ARG LLM_MODEL
+ENV MODEL_PATH=./downloaded_model
+
+RUN python download_model.py --model-name "$LLM_MODEL" --model-path "$MODEL_PATH"
+
+COPY . .
+COPY --from=external_context /vllm_llm.py .
+
+EXPOSE 8000
+
+CMD ["python", "app.py"]
diff --git a/containerized_job/app.py b/containerized_job/app.py
new file mode 100644
index 000000000..e93c75e46
--- /dev/null
+++ b/containerized_job/app.py
@@ -0,0 +1,50 @@
+import os
+
+import uvicorn
+from fastapi import FastAPI
+from fastapi.responses import JSONResponse
+from schema import ChatRequest, LogitsRequest
+from vllm_llm import ReproducibleVLLM
+
+MODEL_PATH = os.getenv("MODEL_PATH")
+
+
+class ReproducibleVllmApp:
+    def __init__(self):
+        self.llm = ReproducibleVLLM(model_id=MODEL_PATH)
+        self.app = FastAPI()
+        self.app.post("/generate")(self.generate)
+        self.app.post("/generate_logits")(self.generate_logits)
+
+    async def generate(self, request: ChatRequest):
+        try:
+            result = await self.llm.generate(
+                messages=[m.dict() for m in request.messages],
+                sampling_params=request.sampling_parameters.dict(),
+                seed=request.seed,
+                continue_last_message=request.continue_last_message,
+            )
+            return {"result": result}
+        except Exception as e:
+            return JSONResponse(status_code=500, content={"error": str(e)})
+
+    async def generate_logits(self, request: LogitsRequest):
+        try:
+            logits, prompt = await self.llm.generate_logits(
+                messages=[m.dict() for m in request.messages],
+                top_logprobs=request.top_logprobs,
+                sampling_params=request.sampling_parameters.dict(),
+                seed=request.seed,
+                continue_last_message=request.continue_last_message,
+            )
+            return {"logits": logits, "prompt": prompt}
+        except Exception as e:
+            return JSONResponse(status_code=500, content={"error": str(e)})
+
+    def run(self):
+        uvicorn.run(self.app, host="0.0.0.0", port=8000)
+
+
+if __name__ == "__main__":
+    server = ReproducibleVllmApp()
+    server.run()
diff --git a/containerized_job/build.sh b/containerized_job/build.sh
new file mode 100755
index 000000000..f85949397
--- /dev/null
+++ b/containerized_job/build.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+IMAGE_NAME="sn1-validator-api"
+MODEL_NAME="mrfakename/mistral-small-3.1-24b-instruct-2503-hf"
+
+DOCKER_BUILDKIT=1 docker build \
+    --build-arg LLM_MODEL="$MODEL_NAME" \
+    -t "$IMAGE_NAME" \
+    --build-context external_context=../prompting/llms \
+    .
diff --git a/containerized_job/download_model.py b/containerized_job/download_model.py
new file mode 100644
index 000000000..31a436fd6
--- /dev/null
+++ b/containerized_job/download_model.py
@@ -0,0 +1,24 @@
+import argparse
+
+from huggingface_hub import snapshot_download
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Download model files")
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        help="Model name to use",
+    )
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        help="Path to save the model files",
+    )
+
+    args = parser.parse_args()
+
+    print(f"Downloading model {args.model_name} to {args.model_path}")
+
+    snapshot_download(repo_id=args.model_name, local_dir=args.model_path)
+
+    print(f"Model files downloaded to {args.model_path}")
diff --git a/containerized_job/requirements.txt b/containerized_job/requirements.txt
new file mode 100644
index 000000000..56862af5c
--- /dev/null
+++ b/containerized_job/requirements.txt
@@ -0,0 +1,8 @@
+fastapi==0.115.0
+uvicorn==0.23.2
+pydantic==2.9.0
+vllm==0.8.3
+torch==2.6.0
+numpy==1.26.4
+loguru==0.7.2
+huggingface-hub==0.30.0
diff --git a/containerized_job/schema.py b/containerized_job/schema.py
new file mode 100644
index 000000000..96354aa34
--- /dev/null
+++ b/containerized_job/schema.py
@@ -0,0 +1,29 @@
+from typing import List, Literal, Optional
+
+from pydantic import BaseModel
+
+
+class ChatMessage(BaseModel):
+    content: str
+    role: Literal["user", "assistant", "system"]
+
+
+class SamplingParameters(BaseModel):
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    max_tokens: Optional[int] = 512
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    top_k: Optional[int] = -1
+    logprobs: Optional[int] = None
+
+
+class ChatRequest(BaseModel):
+    messages: List[ChatMessage]
+    seed: Optional[int]
+    sampling_parameters: Optional[SamplingParameters] = SamplingParameters()
+    continue_last_message: Optional[bool] = False
+
+
+class LogitsRequest(ChatRequest):
+    top_logprobs: Optional[int] = 10
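Usage sketch (not part of the diff): a minimal Python client for the new /generate endpoint, assuming the image built by build.sh is running locally with port 8000 published, e.g. docker run --gpus all -p 8000:8000 sn1-validator-api. The payload fields follow the ChatRequest schema above; the host, port mapping, prompt, and sampling values are illustrative.

import json
import urllib.request

# Request body mirroring schema.ChatRequest: messages, seed, sampling_parameters,
# continue_last_message. seed has no default in the schema, so it must be sent.
payload = {
    "messages": [{"role": "user", "content": "Say hello."}],
    "seed": 42,
    "sampling_parameters": {"temperature": 0.7, "max_tokens": 64},
    "continue_last_message": False,
}

# POST to the FastAPI server started by app.py inside the container (port 8000).
req = urllib.request.Request(
    "http://localhost:8000/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    # On success the server wraps the completion as {"result": ...}.
    print(json.load(resp)["result"])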