Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support authorization on the gateway #851

Merged
merged 14 commits into from
Jan 29, 2024
6 changes: 5 additions & 1 deletion gateway/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@ dependencies = [
"httpx",
"jinja2",
"uvicorn",
"aiocache",
]

[tool.setuptools.package-data]
"dstack.gateway" = ["systemd/resources/*"]
"dstack.gateway" = [
"resources/nginx/*",
"resources/systemd/*",
]

[tool.setuptools.dynamic]
version = {attr = "dstack.gateway.version.__version__"}
Empty file.
17 changes: 17 additions & 0 deletions gateway/src/dstack/gateway/auth/routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from fastapi import APIRouter, Depends, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from dstack.gateway.services.auth import AuthProvider, get_auth

router = APIRouter()


@router.get("/{project}")
async def get_auth(
project: str,
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
auth: AuthProvider = Depends(get_auth),
):
if await auth.has_access(project, token.credentials):
return {"status": "ok"}
raise HTTPException(status_code=403)
10 changes: 10 additions & 0 deletions gateway/src/dstack/gateway/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,20 @@
import functools
from typing import Callable, ParamSpec, TypeVar

import httpx

R = TypeVar("R")
P = ParamSpec("P")


async def run_async(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
func_with_args = functools.partial(func, *args, **kwargs)
return await asyncio.get_running_loop().run_in_executor(None, func_with_args)


class AsyncClientWrapper(httpx.AsyncClient):
def __del__(self):
try:
asyncio.get_running_loop().create_task(self.aclose())
except Exception:
pass
2 changes: 2 additions & 0 deletions gateway/src/dstack/gateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import dstack.gateway.openai.store as openai_store
import dstack.gateway.version
from dstack.gateway.auth.routes import router as auth_router
from dstack.gateway.logging import configure_logging
from dstack.gateway.openai.routes import router as openai_router
from dstack.gateway.registry.routes import router as registry_router
Expand Down Expand Up @@ -36,6 +37,7 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan)
app.include_router(registry_router, prefix="/api/registry")
app.include_router(openai_router, prefix="/api/openai")
app.include_router(auth_router, prefix="/auth")


@app.get("/")
Expand Down
2 changes: 1 addition & 1 deletion gateway/src/dstack/gateway/openai/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResp

@abstractmethod
async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCompletionsChunk]:
pass
yield
11 changes: 1 addition & 10 deletions gateway/src/dstack/gateway/openai/clients/tgi.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import asyncio
import datetime
import json
import uuid
from typing import AsyncIterator, Dict, List, Optional

import httpx
import jinja2
import jinja2.sandbox

from dstack.gateway.common import AsyncClientWrapper
from dstack.gateway.errors import GatewayError
from dstack.gateway.openai.clients import ChatCompletionsClient
from dstack.gateway.openai.schemas import (
Expand Down Expand Up @@ -185,13 +184,5 @@ def trim_stop_tokens(text: str, stop_tokens: List[str]) -> str:
return text


class AsyncClientWrapper(httpx.AsyncClient):
def __del__(self):
try:
asyncio.get_running_loop().create_task(self.aclose())
except Exception:
pass


def raise_exception(message: str):
raise jinja2.TemplateError(message)
16 changes: 14 additions & 2 deletions gateway/src/dstack/gateway/openai/routes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Annotated, AsyncIterator

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException, Security
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from dstack.gateway.errors import GatewayError
from dstack.gateway.openai.schemas import (
Expand All @@ -11,8 +12,19 @@
ModelsResponse,
)
from dstack.gateway.openai.store import OpenAIStore, get_store
from dstack.gateway.services.auth import AuthProvider, get_auth

router = APIRouter()

async def auth_required(
project: str,
auth: AuthProvider = Depends(get_auth),
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
):
if not await auth.has_access(project, token.credentials):
raise HTTPException(status_code=403)


router = APIRouter(dependencies=[Depends(auth_required)])


@router.get("/{project}/models")
Expand Down
24 changes: 24 additions & 0 deletions gateway/src/dstack/gateway/resources/nginx/entrypoint.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
server {
server_name {{ domain }};
location / {
proxy_pass http://localhost:{{ port }}/{{ prefix.strip('/') }}/;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
}
listen 80;
listen 443 ssl;
ssl_certificate /etc/letsencrypt/live/{{ domain }}/fullchain.pem;
ssl_certificate_key /etc/letsencrypt/live/{{ domain }}/privkey.pem;
include /etc/letsencrypt/options-ssl-nginx.conf;
ssl_dhparam /etc/letsencrypt/ssl-dhparams.pem;
set $force_https 1;
if ($scheme = "https") {
set $force_https 0;
}
if ($remote_addr = 127.0.0.1) {
set $force_https 0;
}
if ($force_https) {
return 301 https://$host$request_uri;
}
}
54 changes: 54 additions & 0 deletions gateway/src/dstack/gateway/resources/nginx/service.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
upstream {{ upstream }} {
server {{ server }};
}
server {
server_name {{ domain }};
location / {
{% if auth %}
auth_request /auth;
{% endif %}
try_files /nonexistent @$http_upgrade;
}
location @websocket {
proxy_pass http://{{ upstream }};
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "Upgrade";
}
location @ {
proxy_pass http://{{ upstream }};
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
}
{% if auth %}
location = /auth {
internal;
if ($remote_addr = 127.0.0.1) {
return 200;
}
proxy_pass http://localhost:{{ port }}/auth/{{ project }};
proxy_pass_request_body off;
proxy_set_header Content-Length "";
proxy_set_header X-Original-URI $request_uri;
proxy_set_header Authorization $http_authorization;
}
{% endif %}
listen 80;
listen 443 ssl;
ssl_certificate /etc/letsencrypt/live/{{ domain }}/fullchain.pem;
ssl_certificate_key /etc/letsencrypt/live/{{ domain }}/privkey.pem;
include /etc/letsencrypt/options-ssl-nginx.conf;
ssl_dhparam /etc/letsencrypt/ssl-dhparams.pem;
set $force_https 1;
if ($scheme = "https") {
set $force_https 0;
}
if ($remote_addr = 127.0.0.1) {
set $force_https 0;
}
if ($force_https) {
return 301 https://$host$request_uri;
}
}
1 change: 1 addition & 0 deletions gateway/src/dstack/gateway/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ class Service(BaseModel):
docker_ssh_host: Optional[str] = None
docker_ssh_port: Optional[int] = None

auth: bool = True
options: dict = {}
33 changes: 33 additions & 0 deletions gateway/src/dstack/gateway/services/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import logging
from functools import lru_cache

import httpx
from aiocache import cached

from dstack.gateway.common import AsyncClientWrapper

DSTACK_SERVER_TUNNEL_PORT = 8001
logger = logging.getLogger(__name__)


class AuthProvider:
def __init__(self):
self.client = AsyncClientWrapper(base_url=f"http://localhost:{DSTACK_SERVER_TUNNEL_PORT}")

@cached(ttl=60, noself=True)
async def has_access(self, project: str, token: str) -> bool:
try:
resp = await self.client.post(
f"/api/projects/{project}/get",
headers={"Authorization": f"Bearer {token}"},
)
if resp.status_code == 200:
return True
except httpx.RequestError as e:
logger.debug("Failed to check access: %s", e)
return False


@lru_cache()
def get_auth() -> AuthProvider:
return AuthProvider()
Loading
Loading