Skip to content

Commit

Permalink
Added Authorization header API KEY checker
Browse files Browse the repository at this point in the history
  • Loading branch information
c0sogi committed Aug 9, 2023
1 parent 7b554e2 commit 71a468c
Show file tree
Hide file tree
Showing 6 changed files with 377 additions and 333 deletions.
2 changes: 2 additions & 0 deletions llama_api/server/app_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def run(
skip_pytorch_install: bool = False,
skip_tensorflow_install: bool = False,
skip_compile: bool = False,
api_key: Optional[str] = None,
) -> None:
initialize_before_launch(
git_and_disk_paths=Config.git_and_disk_paths,
Expand All @@ -169,6 +170,7 @@ def run(
from uvicorn import Server as UvicornServer

environ["MAX_WORKERS"] = str(max_workers)
environ["API_KEY"] = api_key or ""

UvicornServer(
config=UvicornConfig(
Expand Down
105 changes: 86 additions & 19 deletions llama_api/utils/errors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import cached_property
from os import environ
from pathlib import Path
from re import Match, Pattern, compile
from typing import Callable, Coroutine, Dict, Optional, Tuple, Union
Expand Down Expand Up @@ -124,6 +126,17 @@ class RouteErrorHandler(APIRoute):
): ErrorResponseFormatters.model_not_found,
}

api_key: Optional[str] = environ.get("API_KEY", None) or None

@cached_property
def authorization(self) -> Optional[str]:
"""API key for authentication"""
if self.api_key is None:
return None
if not self.api_key.startswith("sk-"):
self.api_key = f"sk-{self.api_key}"
return f"Bearer {self.api_key}"

def error_message_wrapper(
self,
error: Exception,
Expand Down Expand Up @@ -154,7 +167,7 @@ def error_message_wrapper(
return 500, ErrorResponse(
message=str(error),
type="internal_server_error",
param=f"traceback:: {self.parse_trackback(error)}",
param=f"traceback:: {parse_trackback(error)}",
code=type(error).__name__,
)

Expand All @@ -167,6 +180,47 @@ async def custom_route_handler(self, request: Request) -> Response:
"""Defines custom route handler that catches exceptions and formats
in OpenAI style error response"""
try:
if self.authorization is not None:
# Check API key
authorization = request.headers.get(
"Authorization",
request.query_params.get("authorization", None),
) # type: Optional[str]
if not authorization or not authorization.startswith(
"Bearer "
):
error_response = ErrorResponse(
message=(
(
"You didn't provide an API key. "
"You need to provide your API key in "
"an Authorization header using Bearer auth "
"(i.e. Authorization: Bearer YOUR_KEY)."
)
),
type="invalid_request_error",
param=None,
code=None,
)
return JSONResponse(
{"error": error_response},
status_code=401,
)
if authorization != self.authorization:
api_key = authorization[len("Bearer ") :] # noqa: E203
error_response = ErrorResponse(
message=(
"Incorrect API key provided: "
+ mask_secret(api_key, 8, 4)
),
type="invalid_request_error",
param=None,
code="invalid_api_key",
)
return JSONResponse(
{"error": error_response},
status_code=401,
)
return await super().get_route_handler()(request)
except Exception as error:
json_body = await request.json()
Expand Down Expand Up @@ -200,23 +254,36 @@ async def custom_route_handler(self, request: Request) -> Response:
status_code=status_code,
)

def parse_trackback(self, exception: Exception) -> str:
"""Parses traceback information from the exception"""
if (
exception.__traceback__ is not None
and exception.__traceback__.tb_next is not None
):
# Get previous traceback from the exception
traceback = exception.__traceback__.tb_next

# Get filename, function name, and line number
try:
co_filename = Path(traceback.tb_frame.f_code.co_filename).name
except Exception:
co_filename = "UNKNOWN"
co_name = traceback.tb_frame.f_code.co_name
lineno = traceback.tb_lineno
return f"Error in {co_filename} at line {lineno} in {co_name}"
def parse_trackback(exception: Exception) -> str:
"""Parses traceback information from the exception"""
if (
exception.__traceback__ is not None
and exception.__traceback__.tb_next is not None
):
# Get previous traceback from the exception
traceback = exception.__traceback__.tb_next

# If traceback is not available, return UNKNOWN
return "UNKNOWN"
# Get filename, function name, and line number
try:
co_filename = Path(traceback.tb_frame.f_code.co_filename).name
except Exception:
co_filename = "UNKNOWN"
co_name = traceback.tb_frame.f_code.co_name
lineno = traceback.tb_lineno
return f"Error in {co_filename} at line {lineno} in {co_name}"

# If traceback is not available, return UNKNOWN
return "UNKNOWN"


def mask_secret(api_key: str, n_start: int, n_end: int) -> str:
length = len(api_key)
if length <= n_start + n_end:
return api_key
else:
return (
api_key[:n_start]
+ "*" * (length - n_start - n_end)
+ api_key[-n_end:]
)
11 changes: 6 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
help="Maximum number of process workers to run; default is 1",
)
parser.add_argument(
"-i",
"--install-pkgs",
action="store_true",
help="Install all required packages before running the server",
Expand All @@ -43,9 +42,11 @@
help="Skip installing tensorflow, if `install-pkgs` is set",
)
parser.add_argument(
"--skip-compile",
action="store_true",
help="Skip compiling the shared library of LLaMA C++ code",
"-k",
"--api-key",
type=str,
default=None,
help="API key to use for the server",
)

args = parser.parse_args()
Expand All @@ -56,5 +57,5 @@
force_cuda=args.force_cuda,
skip_pytorch_install=args.skip_torch_install,
skip_tensorflow_install=args.skip_tf_install,
skip_compile=args.skip_compile,
api_key=args.api_key,
)
Loading

0 comments on commit 71a468c

Please sign in to comment.