24 changes: 24 additions & 0 deletions containerized_job/Dockerfile
@@ -0,0 +1,24 @@
FROM python:3.10-slim

WORKDIR /app

RUN apt-get update && apt-get install -y \
git build-essential \
&& rm -rf /var/lib/apt/lists/*

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY download_model.py .

ARG LLM_MODEL
ENV MODEL_PATH=./downloaded_model

# Download the model at build time so it is baked into the image.
RUN python download_model.py --model-name "$LLM_MODEL" --model-path "$MODEL_PATH"

COPY . .
# vllm_llm.py comes from the named build context supplied by build.sh (--build-context external_context=../prompting/llms).
COPY --from=external_context /vllm_llm.py .

EXPOSE 8000

CMD ["python", "app.py"]
50 changes: 50 additions & 0 deletions containerized_job/app.py
@@ -0,0 +1,50 @@
import os

import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from schema import ChatRequest, LogitsRequest
from vllm_llm import ReproducibleVLLM

MODEL_PATH = os.getenv("MODEL_PATH")


class ReproducibleVllmApp:
    def __init__(self):
        self.llm = ReproducibleVLLM(model_id=MODEL_PATH)
        self.app = FastAPI()
        # Register the two POST endpoints on the FastAPI app.
        self.app.post("/generate")(self.generate)
        self.app.post("/generate_logits")(self.generate_logits)

    async def generate(self, request: ChatRequest):
        try:
            result = await self.llm.generate(
                messages=[m.dict() for m in request.messages],
                sampling_params=request.sampling_parameters.dict(),
                seed=request.seed,
                continue_last_message=request.continue_last_message,
            )
            return {"result": result}
        except Exception as e:
            return JSONResponse(status_code=500, content={"error": str(e)})

Check warning (Code scanning / CodeQL): Information exposure through an exception. Stack trace information flows to this location and may be exposed to an external user.

    async def generate_logits(self, request: LogitsRequest):
        try:
            logits, prompt = await self.llm.generate_logits(
                messages=[m.dict() for m in request.messages],
                top_logprobs=request.top_logprobs,
                sampling_params=request.sampling_parameters.dict(),
                seed=request.seed,
                continue_last_message=request.continue_last_message,
            )
            return {"logits": logits, "prompt": prompt}
        except Exception as e:
            return JSONResponse(status_code=500, content={"error": str(e)})

Check warning (Code scanning / CodeQL): Information exposure through an exception. Stack trace information flows to this location and may be exposed to an external user.

    def run(self):
        uvicorn.run(self.app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    server = ReproducibleVllmApp()
    server.run()
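The two CodeQL warnings above flag that returning str(e) from an arbitrary exception can leak stack-trace details to external callers. A minimal sketch of an alternative, assuming the loguru dependency already pinned in requirements.txt (generate_safely, llm, and request are illustrative names, not part of the PR), would log the full traceback server-side and return only a generic message:

# Hedged sketch: keep the traceback in the server logs, return an opaque error
# to the client, which addresses the CodeQL "information exposure" finding.
from fastapi.responses import JSONResponse
from loguru import logger


async def generate_safely(llm, request):
    try:
        result = await llm.generate(
            messages=[m.dict() for m in request.messages],
            sampling_params=request.sampling_parameters.dict(),
            seed=request.seed,
            continue_last_message=request.continue_last_message,
        )
        return {"result": result}
    except Exception:
        logger.exception("generation failed")  # full traceback stays server-side
        return JSONResponse(status_code=500, content={"error": "internal server error"})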
10 changes: 10 additions & 0 deletions containerized_job/build.sh
@@ -0,0 +1,10 @@
#!/bin/bash

IMAGE_NAME="sn1-validator-api"
MODEL_NAME="mrfakename/mistral-small-3.1-24b-instruct-2503-hf"

# BuildKit is required for the named build context consumed by the Dockerfile's
# COPY --from=external_context step.
DOCKER_BUILDKIT=1 docker build \
    --build-arg LLM_MODEL="$MODEL_NAME" \
    -t "$IMAGE_NAME" \
    --build-context external_context=../prompting/llms \
    .
24 changes: 24 additions & 0 deletions containerized_job/download_model.py
@@ -0,0 +1,24 @@
import argparse

from huggingface_hub import snapshot_download

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download model files")
    parser.add_argument(
        "--model-name",
        type=str,
        required=True,
        help="Model name to use",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to save the model files",
    )

    args = parser.parse_args()

    print(f"Downloading model {args.model_name} to {args.model_path}")

    snapshot_download(repo_id=args.model_name, local_dir=args.model_path)

    print(f"Model files downloaded to {args.model_path}")
8 changes: 8 additions & 0 deletions containerized_job/requirements.txt
@@ -0,0 +1,8 @@
fastapi==0.115.0
uvicorn==0.23.2
pydantic==2.9.0
vllm==0.8.3
torch==2.6.0
numpy==1.26.4
loguru==0.7.2
huggingface-hub==0.30.0
29 changes: 29 additions & 0 deletions containerized_job/schema.py
@@ -0,0 +1,29 @@
from typing import List, Literal, Optional

from pydantic import BaseModel


class ChatMessage(BaseModel):
    content: str
    role: Literal["user", "assistant", "system"]


class SamplingParameters(BaseModel):
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = 512
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    top_k: Optional[int] = -1
    logprobs: Optional[int] = None


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    seed: Optional[int]
    sampling_parameters: Optional[SamplingParameters] = SamplingParameters()
    continue_last_message: Optional[bool] = False


class LogitsRequest(ChatRequest):
    top_logprobs: Optional[int] = 10
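For reference, a minimal client-side sketch of a request matching this schema, assuming the container is running locally and exposing port 8000 as in the Dockerfile (the use of the standard library's urllib here is illustrative, not part of the PR):

# Build a /generate payload from the Pydantic models and POST it to the service.
import json
import urllib.request

from schema import ChatMessage, ChatRequest, SamplingParameters

payload = ChatRequest(
    messages=[ChatMessage(role="user", content="Say hello in one sentence.")],
    seed=42,
    sampling_parameters=SamplingParameters(temperature=0.7, max_tokens=64),
).dict()

req = urllib.request.Request(
    "http://localhost:8000/generate",  # assumed local deployment, see EXPOSE 8000
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))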