Skip to content

Commit

Permalink
Added bearer token header to worker http client (for HF API) (#3240)
Browse files Browse the repository at this point in the history
  • Loading branch information
yk committed May 27, 2023
1 parent f26406a commit 274e65d
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
1 change: 1 addition & 0 deletions inference/worker/__main__.py
Expand Up @@ -40,6 +40,7 @@ def main():
base_url=settings.inference_server_url,
basic_auth_username=settings.basic_auth_username,
basic_auth_password=settings.basic_auth_password,
bearer_token=settings.bearer_token,
)

while True:
Expand Down
3 changes: 3 additions & 0 deletions inference/worker/settings.py
Expand Up @@ -5,6 +5,7 @@ class Settings(pydantic.BaseSettings):
backend_url: str = "ws://localhost:8000"
model_config_name: str = "distilgpt2"
inference_server_url: str = "http://localhost:8001"
inference_server_route: str = "/generate_stream"
safety_server_url: str = "http://localhost:8002"
api_key: str = "0000"

Expand All @@ -21,6 +22,8 @@ class Settings(pydantic.BaseSettings):
# for hf basic server
quantize: bool = False

bearer_token: str | None = None

basic_auth_username: str | None = None
basic_auth_password: str | None = None

Expand Down
13 changes: 12 additions & 1 deletion inference/worker/utils.py
Expand Up @@ -172,6 +172,7 @@ class HttpClient(pydantic.BaseModel):
base_url: str
basic_auth_username: str | None = None
basic_auth_password: str | None = None
bearer_token: str | None = None

@property
def auth(self):
Expand All @@ -180,10 +181,19 @@ def auth(self):
else:
return None

def _maybe_add_bearer_token(self, headers: dict[str, str] | None):
if self.bearer_token:
if headers is None:
headers = {}
headers["Authorization"] = f"Bearer {self.bearer_token}"
return headers

def get(self, path: str, **kwargs):
kwargs["headers"] = self._maybe_add_bearer_token(kwargs.get("headers"))
return requests.get(self.base_url + path, auth=self.auth, **kwargs)

def post(self, path: str, **kwargs):
kwargs["headers"] = self._maybe_add_bearer_token(kwargs.get("headers"))
return requests.post(self.base_url + path, auth=self.auth, **kwargs)


Expand All @@ -192,9 +202,10 @@ def get_inference_server_stream_events(request: interface.GenerateStreamRequest)
base_url=settings.inference_server_url,
basic_auth_username=settings.basic_auth_username,
basic_auth_password=settings.basic_auth_password,
bearer_token=settings.bearer_token,
)
response = http.post(
"/generate_stream",
settings.inference_server_route,
json=request.dict(),
stream=True,
headers={"Accept": "text/event-stream"},
Expand Down

0 comments on commit 274e65d

Please sign in to comment.