Skip to content

Commit

Permalink
Merge pull request #12 from demml/feature/auth-jwt
Browse files Browse the repository at this point in the history
Feature/auth jwt
  • Loading branch information
thorrester committed May 5, 2024
2 parents 45c0a62 + 72b9221 commit dd52d36
Show file tree
Hide file tree
Showing 34 changed files with 882 additions and 152 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/lint-unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ jobs:


Unit-Test-310-Coverage:
needs: Lints
strategy:
matrix:
python-version: ["3.10"]
Expand All @@ -101,6 +100,9 @@ jobs:
env:
OPSML_TESTING: 1
LOG_LEVEL: DEBUG
OPSML_AUTH: True
OPSML_USERNAME: admin
OPSML_PASSWORD: admin
steps:
- uses: actions/checkout@v4
- name: Install poetry
Expand All @@ -119,7 +121,7 @@ jobs:
rm -rf /opt/hostedtoolcache/node
sudo apt clean
make setup.project
make test.unit
make test.coverage
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4.0.1
Expand Down
4 changes: 2 additions & 2 deletions .tool-versions
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
python 3.10.14
python 3.11.7
poetry 1.8.2
gitleaks 8.18.0
gitleaks 8.18.0
10 changes: 10 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ setup.sysdeps:
&& echo " feel free to ignore when on drone")

test.unit:
poetry run pytest \
-m "not large and not compat and not appsec" \
--ignore tests/integration \
--cov \
--cov-fail-under=0 \
--cov-report xml:./coverage.xml \
--cov-report term \
--junitxml=./results.xml

test.coverage:
poetry run pytest \
-m "not large and not compat" \
--ignore tests/integration \
Expand Down
9 changes: 8 additions & 1 deletion opsml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from opsml.cards import DataCard, ModelCard, PipelineCard, ProjectCard, RunCard
from opsml.cards import (
AuditCard,
DataCard,
ModelCard,
PipelineCard,
ProjectCard,
RunCard,
)
from opsml.data import (
ArrowData,
DataInterface,
Expand Down
5 changes: 5 additions & 0 deletions opsml/app/core/event_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@

from opsml.helpers.logging import ArtifactLogger
from opsml.model.registrar import ModelRegistrar
from opsml.registry.backend import _set_registry
from opsml.registry.registry import CardRegistries
from opsml.settings.config import config
from opsml.storage import client
from opsml.types import RegistryType

logger = ArtifactLogger.get_logger()

Expand All @@ -39,6 +41,9 @@ def _init_registries(app: FastAPI) -> None:
app.state.model_registrar = ModelRegistrar(client.storage_client)
app.state.storage_root = config.storage_root

if config.opsml_auth:
app.state.auth_db = _set_registry(RegistryType.AUTH)


def _shutdown_registries(app: FastAPI) -> None:
app.state.registries = None
Expand Down
37 changes: 22 additions & 15 deletions opsml/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from typing import Any, List, Optional
from typing import Optional

import uvicorn
from fastapi import Depends, FastAPI
from fastapi.staticfiles import StaticFiles
from prometheus_fastapi_instrumentator import Instrumentator

from opsml.app.core.event_handlers import lifespan
from opsml.app.core.login import get_current_username
from opsml.app.core.middleware import rollbar_middleware
from opsml.app.routes.router import api_router
from opsml.app.routes import auth
from opsml.app.routes.router import build_router
from opsml.helpers.logging import ArtifactLogger
from opsml.settings.config import config
from opsml.settings.config import OpsmlConfig, config

logger = ArtifactLogger.get_logger()

Expand All @@ -25,19 +25,26 @@


class OpsmlApp:
def __init__(self, port: int = 8888, login: bool = False):
def __init__(self, port: int = 8888, app_config: Optional[OpsmlConfig] = None):
self.port = port
self.login = login
self.app = FastAPI(title=config.app_name, dependencies=self.get_login(), lifespan=lifespan)
if app_config is None:
self.app_config = config
else:
self.app_config = app_config

def get_login(self) -> Optional[List[Any]]:
"""Sets the login dependency for an app if specified"""

if self.login:
return [Depends(get_current_username)]
return None
self.app = FastAPI(title=self.app_config.app_name, lifespan=lifespan)

def build_app(self) -> None:
# build routes for the app and include auth deps

if self.app_config.opsml_auth:
deps = [Depends(auth.get_current_active_user)]
else:
deps = None

api_router = build_router(dependencies=deps)
api_router.include_router(auth.router, tags=["auth"], prefix="/opsml")

self.app.include_router(api_router)
self.app.mount("/static", StaticFiles(directory=STATIC_PATH), name="static")

Expand All @@ -55,8 +62,8 @@ def get_app(self) -> FastAPI:
return self.app


def run_app(login: bool = False) -> FastAPI:
return OpsmlApp(login=login).get_app()
def run_app(port: int = 8888, app_config: Optional[OpsmlConfig] = None) -> FastAPI:
return OpsmlApp(port, app_config).get_app()


if __name__ == "__main__":
Expand Down
2 changes: 0 additions & 2 deletions opsml/app/routes/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@

AUDIT_FILE = "audit_file.csv"

templates = Jinja2Templates(directory=TEMPLATE_PATH)

audit_route_helper = AuditRouteHelper()
router = APIRouter()

Expand Down
211 changes: 211 additions & 0 deletions opsml/app/routes/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
from typing import Annotated

import jwt
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel

from opsml.helpers.logging import ArtifactLogger
from opsml.registry.sql.base.server import ServerAuthRegistry
from opsml.settings.config import config
from opsml.types.extra import User

logger = ArtifactLogger.get_logger()

router = APIRouter()


oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/opsml/auth/token")


class Token(BaseModel):
access_token: str
token_type: str


class TokenData(BaseModel):
username: str


class UserCreated(BaseModel):
created: bool = False


class UserUpdated(BaseModel):
updated: bool = False


class UserDeleted(BaseModel):
deleted: bool = False


router = APIRouter()


async def get_current_user(
request: Request,
token: Annotated[str, Depends(oauth2_scheme)],
) -> User:
auth_db: ServerAuthRegistry = request.app.state.auth_db

credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)

try:
payload = jwt.decode(
token,
config.opsml_jwt_secret,
algorithms=[config.opsml_jwt_algorithm],
)
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_data = TokenData(username=username)

except jwt.exceptions.ExpiredSignatureError as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="token_expired",
headers={"WWW-Authenticate": "Bearer"},
) from exc

except jwt.exceptions.DecodeError as exc:
raise credentials_exception from exc

user = auth_db.get_user(token_data.username)
if user is None:
raise credentials_exception
return user


async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)],
) -> User:
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user


@router.post("/auth/token")
async def login_for_access_token(
request: Request,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> Token:
logger.info("Logging in user: {}", form_data.username)

# quick exit if auth is disabled
if not config.opsml_auth:
return Token(access_token="", token_type="bearer")

auth_db: ServerAuthRegistry = request.app.state.auth_db
user = auth_db.get_user(form_data.username)

if user is None:
logger.info("User does not exist: {}", form_data.username)

raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)

assert user is not None

# check if password is correct
authenicated = auth_db.authenticate_user(user, form_data.password)

if not authenicated:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)

logger.info("User authenticated: {}", form_data.username)
return Token(access_token=auth_db.create_access_token(user), token_type="bearer")


@router.get("/auth/user", response_model=User)
def get_user(
request: Request,
username: str,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> User:
"""Retrieves user by username"""
if not current_user.scopes.admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions")

auth_db: ServerAuthRegistry = request.app.state.auth_db
user = auth_db.get_user(username)

if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

return user


@router.post("/auth/user", response_model=UserCreated)
def create_user(
request: Request,
user: User,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> UserCreated:
"""Create new user"""
if not current_user.scopes.admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions")

auth_db: ServerAuthRegistry = request.app.state.auth_db

# check user not exists
if auth_db.get_user(user.username) is not None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="User already exists")

# add user
auth_db.add_user(user)

# test getting user
user = auth_db.get_user(user.username)

if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Failed to create user")

return UserCreated(created=True)


@router.put("/auth/user", response_model=UserUpdated)
def update_user(
request: Request,
user: User,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> UserUpdated:
"""Update user"""
if not current_user.scopes.admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions")

auth_db: ServerAuthRegistry = request.app.state.auth_db
updated = auth_db.update_user(user)

return UserUpdated(updated=updated)


@router.delete("/auth/user", response_model=UserDeleted)
def delete_user(
request: Request,
username: str,
current_user: Annotated[User, Depends(get_current_active_user)],
) -> UserDeleted:
"""Delete user"""
if not current_user.scopes.admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions")

auth_db: ServerAuthRegistry = request.app.state.auth_db
user = auth_db.get_user(username)

if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

deleted = auth_db.delete_user(user)
return UserDeleted(deleted=deleted)
1 change: 1 addition & 0 deletions opsml/app/routes/cards.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

logger = ArtifactLogger.get_logger()


router = APIRouter()


Expand Down
1 change: 1 addition & 0 deletions opsml/app/routes/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

MAX_FILE_SIZE = 1024 * 1024 * 1024 * 50 # = 50GB
MAX_REQUEST_BODY_SIZE = MAX_FILE_SIZE + 1024

router = APIRouter()


Expand Down
Loading

0 comments on commit dd52d36

Please sign in to comment.