Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend RBAC to support project id as input #673

Merged
merged 4 commits into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion registry/access_control/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Admin roles can add or delete roles in management UI page or through management
| RBAC_API_AUDIENCE | Used as audience to decode jwt tokens |

## Notes

Please note that User Role records are **NOT** case-sensitive. All records are converted to lowercase before being saved to the database.
The status of supported scenarios is tracked below:

- General Foundations:
Expand Down
42 changes: 21 additions & 21 deletions registry/access_control/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ async def get_projects() -> list[str]:


@router.get('/projects/{project}', name="Get My Project [Read Access Required]")
async def get_project(project: str, requestor: User = Depends(project_read_access)):
async def get_project(project: str, access: UserAccess = Depends(project_read_access)):
response = requests.get(url=f"{registry_url}/projects/{project}",
headers=get_api_header(requestor)).content.decode('utf-8')
headers=get_api_header(access.user_name)).content.decode('utf-8')
return json.loads(response)


@router.get("/projects/{project}/datasources", name="Get data sources of my project [Read Access Required]")
def get_project_datasources(project: str, requestor: User = Depends(project_read_access)) -> list:
def get_project_datasources(project: str, access: UserAccess = Depends(project_read_access)) -> list:
response = requests.get(url=f"{registry_url}/projects/{project}/datasources",
headers=get_api_header(requestor)).content.decode('utf-8')
headers=get_api_header(access.user_name)).content.decode('utf-8')
return json.loads(response)


Expand All @@ -46,16 +46,16 @@ def get_project_datasource(project: str, datasource: str, requestor: User = Depe


@router.get("/projects/{project}/features", name="Get features under my project [Read Access Required]")
def get_project_features(project: str, keyword: Optional[str] = None, requestor: User = Depends(project_read_access)) -> list:
def get_project_features(project: str, keyword: Optional[str] = None, access: UserAccess = Depends(project_read_access)) -> list:
response = requests.get(url=f"{registry_url}/projects/{project}/features",
headers=get_api_header(requestor)).content.decode('utf-8')
headers=get_api_header(access.user_name)).content.decode('utf-8')
return json.loads(response)


@router.get("/features/{feature}", name="Get a single feature by feature Id [Read Access Required]")
def get_feature(feature: str, requestor: User = Depends(get_user)) -> dict:
response = requests.get(url=f"{registry_url}/features/{feature}",
headers=get_api_header(requestor)).content.decode('utf-8')
headers=get_api_header(requestor.username)).content.decode('utf-8')
ret = json.loads(response)

feature_qualifiedName = ret['attributes']['qualifiedName']
Expand All @@ -67,7 +67,7 @@ def get_feature(feature: str, requestor: User = Depends(get_user)) -> dict:
@router.get("/features/{feature}/lineage", name="Get Feature Lineage [Read Access Required]")
def get_feature_lineage(feature: str, requestor: User = Depends(get_user)) -> dict:
response = requests.get(url=f"{registry_url}/features/{feature}/lineage",
headers=get_api_header(requestor)).content.decode('utf-8')
headers=get_api_header(requestor.username)).content.decode('utf-8')
ret = json.loads(response)

feature_qualifiedName = ret['guidEntityMap'][feature]['attributes']['qualifiedName']
Expand All @@ -80,35 +80,35 @@ def get_feature_lineage(feature: str, requestor: User = Depends(get_user)) -> di
def new_project(definition: dict, requestor: User = Depends(get_user)) -> dict:
rbac.init_userrole(requestor.username, definition["name"])
response = requests.post(url=f"{registry_url}/projects", json=definition,
headers=get_api_header(requestor)).content.decode('utf-8')
headers=get_api_header(requestor.username)).content.decode('utf-8')
return json.loads(response)


@router.post("/projects/{project}/datasources", name="Create new data source of my project [Write Access Required]")
def new_project_datasource(project: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict:
def new_project_datasource(project: str, definition: dict, access: UserAccess = Depends(project_write_access)) -> dict:
response = requests.post(url=f"{registry_url}/projects/{project}/datasources", json=definition, headers=get_api_header(
requestor)).content.decode('utf-8')
access.user_name)).content.decode('utf-8')
return json.loads(response)


@router.post("/projects/{project}/anchors", name="Create new anchors of my project [Write Access Required]")
def new_project_anchor(project: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict:
def new_project_anchor(project: str, definition: dict, access: UserAccess = Depends(project_write_access)) -> dict:
response = requests.post(url=f"{registry_url}/projects/{project}/anchors", json=definition, headers=get_api_header(
requestor)).content.decode('utf-8')
access.user_name)).content.decode('utf-8')
return json.loads(response)


@router.post("/projects/{project}/anchors/{anchor}/features", name="Create new anchor features of my project [Write Access Required]")
def new_project_anchor_feature(project: str, anchor: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict:
def new_project_anchor_feature(project: str, anchor: str, definition: dict, access: UserAccess = Depends(project_write_access)) -> dict:
response = requests.post(url=f"{registry_url}/projects/{project}/anchors/{anchor}/features", json=definition, headers=get_api_header(
requestor)).content.decode('utf-8')
access.user_name)).content.decode('utf-8')
return json.loads(response)


@router.post("/projects/{project}/derivedfeatures", name="Create new derived features of my project [Write Access Required]")
def new_project_derived_feature(project: str, definition: dict, requestor: User = Depends(project_write_access)) -> dict:
def new_project_derived_feature(project: str, definition: dict, access: UserAccess = Depends(project_write_access)) -> dict:
response = requests.post(url=f"{registry_url}/projects/{project}/derivedfeatures",
json=definition, headers=get_api_header(requestor)).content.decode('utf-8')
json=definition, headers=get_api_header(access.user_name)).content.decode('utf-8')
return json.loads(response)

# Below are access control management APIs
Expand All @@ -118,10 +118,10 @@ def get_userroles(requestor: User = Depends(get_user)) -> list:


@router.post("/users/{user}/userroles/add", name="Add a new user role [Project Manage Access Required]")
def add_userrole(project: str, user: str, role: str, reason: str, requestor: User = Depends(project_manage_access)):
return rbac.add_userrole(project, user, role, reason, requestor.username)
def add_userrole(project: str, user: str, role: str, reason: str, access: UserAccess = Depends(project_manage_access)):
return rbac.add_userrole(access.project_name, user, role, reason, access.user_name)


@router.delete("/users/{user}/userroles/delete", name="Delete a user role [Project Manage Access Required]")
def delete_userrole(project: str, user: str, role: str, reason: str, requestor: User = Depends(project_manage_access)):
return rbac.delete_userrole(project, user, role, reason, requestor.username)
def delete_userrole(user: str, role: str, reason: str, access: UserAccess= Depends(project_manage_access)):
return rbac.delete_userrole(access.project_name, user, role, reason, access.user_name)
41 changes: 28 additions & 13 deletions registry/access_control/rbac/access.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any
from typing import Any, Union
from uuid import UUID
from fastapi import Depends, HTTPException, status
from rbac.db_rbac import DbRBAC

from rbac.models import AccessType, User
from rbac.models import AccessType, User, UserAccess,_to_uuid
from rbac.auth import authorize

"""
Expand All @@ -22,24 +23,25 @@ def get_user(user: User = Depends(authorize)) -> User:
return user


def project_read_access(project: str, user: User = Depends(authorize)) -> User:
def project_read_access(project: str, user: User = Depends(authorize)) -> UserAccess:
return _project_access(project, user, AccessType.READ)


def project_write_access(project: str, user: User = Depends(authorize)) -> User:
def project_write_access(project: str, user: User = Depends(authorize)) -> UserAccess:
return _project_access(project, user, AccessType.WRITE)


def project_manage_access(project: str, user: User = Depends(authorize)) -> User:
def project_manage_access(project: str, user: User = Depends(authorize)) -> UserAccess:
return _project_access(project, user, AccessType.MANAGE)


def _project_access(project: str, user: User, access: str):
def _project_access(project: str, user: User, access: str) -> UserAccess:
project = _get_project_name(project)
if rbac.validate_project_access_users(project, user.username, access):
return user
return UserAccess(user.username, project)
else:
raise ForbiddenAccess(
f"{access} privileges for project {project} required for user {user.username}")
f"{access} access for project {project} is required for user {user.username}")


def global_admin_access(user: User = Depends(authorize)):
Expand All @@ -48,16 +50,29 @@ def global_admin_access(user: User = Depends(authorize)):
else:
raise ForbiddenAccess('Admin privileges required')

def validate_project_access_for_feature(feature:str, user:str, access:str):
def validate_project_access_for_feature(feature:str, user:User, access:str):
project = _get_project_from_feature(feature)
_project_access(project, user, access)


def _get_project_from_feature(feature: str):
feature_delimiter = "__"
return feature.split(feature_delimiter)[0]

def get_api_header(requestor: User):
def get_api_header(username: str):
return {
"x-registry-requestor": requestor.username
}
"x-registry-requestor": username
}

def _get_project_name(id_or_name: Union[str, UUID]):
    """
    Resolve a project reference, given as either a project id or a project name, to the name.

    If `id_or_name` parses as a UUID it is treated as a project id and looked up in
    `rbac.projects_ids`; when the id is missing the map is refreshed once through
    `rbac.get_projects_ids()` and the lookup is repeated. Anything that does not parse
    as a UUID is assumed to already be a project name and is returned unchanged.

    Raises:
        ForbiddenAccess: `id_or_name` is a well-formed UUID that is still unknown after
            the refresh.
    """
    try:
        project_id = _to_uuid(id_or_name)
    except ValueError:
        # Not a UUID, so it is a name.
        return id_or_name

    # Only the UUID parsing above is allowed to fail quietly. The refresh below is kept
    # outside that `try` so that a ValueError raised while refreshing (for example a JSON
    # decoding error) is not mistaken for "this is a name".
    #
    # Try the caller's own spelling first, then the canonical lower-case hyphenated form,
    # so UUID objects and equivalent spellings (upper case, no hyphens) can still match.
    # NOTE(review): assumes the map is keyed by canonical UUID strings - confirm against
    # the registry's `/projects-ids` response.
    candidates = (id_or_name, str(project_id))
    for refreshed in (False, True):
        if refreshed:
            # refresh project id map if id not found
            rbac.get_projects_ids()
        for key in candidates:
            if key in rbac.projects_ids:
                return rbac.projects_ids[key]
    raise ForbiddenAccess(f"Project Id {id_or_name} not found in Registry")
32 changes: 20 additions & 12 deletions registry/access_control/rbac/db_rbac.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import requests
from fastapi import HTTPException, status
from typing import Any
from rbac import config
from rbac.database import connect
from rbac.models import AccessType, UserRole, RoleType, SUPER_ADMIN_SCOPE
from rbac.models import AccessType, UserRole, RoleType, SUPER_ADMIN_SCOPE, _to_uuid
from rbac.interface import RBAC
import os
import logging
Expand All @@ -19,6 +21,7 @@ def __init__(self):
os.environ["RBAC_CONNECTION_STR"] = config.RBAC_CONNECTION_STR
self.conn = connect()
self.get_userroles()
self.get_projects_ids()

def get_userroles(self):
# Cache is not supported in cluster, make sure every operation read from database.
Expand Down Expand Up @@ -56,9 +59,9 @@ def get_userroles_by_user(self, user_name: str, role_name: str = None) -> list[U
where delete_reason is null and user_name ='%s'"""
if role_name:
query += fr"and role_name = '%s'"
rows = self.conn.query(query % (user_name, role_name))
rows = self.conn.query(query % (user_name.lower(), role_name.lower()))
else:
rows = self.conn.query(query % (user_name))
rows = self.conn.query(query % (user_name.lower()))
ret = []
for row in rows:
ret.append(UserRole(**row))
Expand All @@ -72,9 +75,9 @@ def get_userroles_by_project(self, project_name: str, role_name: str = None) ->
where delete_reason is null and project_name ='%s'"""
if role_name:
query += fr"and role_name = '%s'"
rows = self.conn.query(query % (project_name, role_name))
rows = self.conn.query(query % (project_name.lower(), role_name.lower()))
else:
rows = self.conn.query(query % (project_name))
rows = self.conn.query(query % (project_name.lower()))
ret = []
for row in rows:
ret.append(UserRole(**row))
Expand Down Expand Up @@ -106,8 +109,8 @@ def add_userrole(self, project_name: str, user_name: str, role_name: str, create
# insert new record
query = fr"""insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time)
values ('%s','%s','%s','%s' ,'%s', getutcdate())"""
self.conn.update(query % (project_name, user_name,
role_name, by, create_reason))
self.conn.update(query % (project_name.lower(), user_name.lower(),
role_name.lower(), by, create_reason.replace("'", "''")))
logging.info(
f"Userrole added with query: {query%(project_name, user_name, role_name, by, create_reason)}")
self.get_userroles()
Expand All @@ -122,8 +125,8 @@ def delete_userrole(self, project_name: str, user_name: str, role_name: str, del
[delete_time] = getutcdate()
WHERE [user_name] = '%s' and [project_name] = '%s' and [role_name] = '%s'
and [delete_time] is null"""
self.conn.update(query % (by, delete_reason,
user_name, project_name, role_name))
self.conn.update(query % (by, delete_reason.replace("'", "''"),
user_name.lower(), project_name.lower(), role_name.lower()))
logging.info(
f"Userrole removed with query: {query%(by, delete_reason, user_name, project_name, role_name)}")
self.get_userroles()
Expand All @@ -141,22 +144,27 @@ def init_userrole(self, creator_name: str, project_name:str):
query = fr"""select project_name, user_name, role_name, create_by, create_reason, create_time, delete_reason, delete_time
from userroles
where delete_reason is null and project_name ='%s'"""
rows = self.conn.query(query%(project_name))
rows = self.conn.query(query%(project_name.lower()))
if len(rows) > 0:
logging.warning(f"{project_name} already exist, please pick another name.")
return
else:
# initialize project admin if project not exist:
self.init_project_admin(creator_name, project_name)



def init_project_admin(self, creator_name: str, project_name: str):
"""initialize the creator as project admin when a new project is created
"""
create_by = "system"
create_reason = "creator of project, get admin by default."
query = fr"""insert into userroles (project_name, user_name, role_name, create_by, create_reason, create_time)
values ('%s','%s','%s','%s','%s', getutcdate())"""
self.conn.update(query % (project_name, creator_name, RoleType.ADMIN.value, create_by, create_reason))
self.conn.update(query % (project_name.lower(), creator_name.lower(), RoleType.ADMIN.value, create_by, create_reason))
logging.info(f"Userrole initialized with query: {query%(project_name, creator_name, RoleType.ADMIN.value, create_by, create_reason)}")
return self.get_userroles()

def get_projects_ids(self):
    """Cache the project ids returned by the registry API in `self.projects_ids`.

    Requests `{config.RBAC_REGISTRY_URL}/projects-ids` and stores the decoded JSON body.
    When the registry answers with an error status, the body is NOT cached: the previous
    map is kept (or an empty one is created on first use) and a warning is logged, so an
    error payload can never be mistaken for the id map.
    """
    response = requests.get(url=f"{config.RBAC_REGISTRY_URL}/projects-ids")
    if not response.ok:
        logging.warning(
            f"Failed to refresh project ids from registry, status code: {response.status_code}")
        # Keep whatever was cached before; start from an empty map on the first call.
        self.projects_ids = getattr(self, "projects_ids", {})
        return
    self.projects_ids = json.loads(response.content.decode('utf-8'))
66 changes: 66 additions & 0 deletions registry/access_control/rbac/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import re
from typing import List, Optional
from pydantic import BaseModel
from datetime import datetime
from enum import Enum
from uuid import UUID

class User(BaseModel):
id: str
Expand Down Expand Up @@ -97,3 +99,67 @@ def to_dict(self) -> dict:
"project_name": self.project_name,
"access": self.access_name,
}

class UserAccess():
    """Pair of a user name and a project name.

    Attributes are plain and mutable:
        user_name: name of the user.
        project_name: name of the project.
    """

    def __init__(self,
                 user_name: str,
                 project_name: str):
        self.user_name = user_name
        self.project_name = project_name

    def __repr__(self) -> str:
        # Readable form for logs and debugging instead of the default `<... object at 0x...>`.
        return (f"{type(self).__name__}(user_name={self.user_name!r}, "
                f"project_name={self.project_name!r})")

def to_snake(d, level: int = 0):
    """
    Convert a string, or every string key of a (possibly nested) `dict`, into snake case.

    - `str`: truncated to its first 100 characters; an underscore is inserted before every
      upper-case ASCII letter that is not the first character, then the result is lower-cased
      (`"qualifiedName"` -> `"qualified_name"`).
    - `list`: truncated to its first 100 items; nested dicts and lists are converted
      recursively, every other item (strings included) is returned unchanged.
    - `dict`: string keys are converted, other keys are kept as they are; values that are
      dicts or lists are converted recursively, every other value is returned unchanged.

    The maximum nested level is 10.

    Raises:
        ValueError: nesting reaches level 10, or a dict has more than 100 keys.
        TypeError: `d` is not a str, list or dict-like object.
    """
    if level >= 10:
        raise ValueError("Too many nested levels")
    if isinstance(d, str):
        return re.sub(r'(?<!^)(?=[A-Z])', '_', d[:100]).lower()
    if isinstance(d, list):
        return [to_snake(item, level + 1) if isinstance(item, (dict, list)) else item
                for item in d[:100]]
    try:
        items = d.items()
    except AttributeError:
        # Report the real problem instead of an unrelated `len()` / `.items()` failure.
        raise TypeError(
            f"to_snake() expects a str, list or dict, got {type(d).__name__}") from None
    if len(d) > 100:
        raise ValueError("Dict has too many keys")
    return {
        (to_snake(key, level + 1) if isinstance(key, str) else key):
            (to_snake(val, level + 1) if isinstance(val, (dict, list)) else val)
        for key, val in items
    }



def _to_type(value, type):
"""
Convert `value` into `type`,
or `list[type]` if `value` is a list
NOTE: This is **not** a generic implementation, only for objects in this module
"""
if isinstance(value, type):
return value
if isinstance(value, list):
return list([_to_type(v, type) for v in value])
if isinstance(value, dict):
if hasattr(type, "new"):
try:
# The convention is to use `new` method to create the object from a dict
return type.new(**to_snake(value))
except TypeError:
pass
return type(**to_snake(value))
if issubclass(type, Enum):
try:
n = int(value)
return type(n)
except ValueError:
pass
if hasattr(type, "new"):
try:
# As well as Enum types, some of them have alias that cannot be handled by default Enum constructor
return type.new(value)
except KeyError:
pass
return type[value]
return type(value)


def _to_uuid(value):
    """Convert `value` (or every item of a list) into a `UUID` by delegating to `_to_type`."""
    return _to_type(value=value, type=UUID)
5 changes: 5 additions & 0 deletions registry/purview-registry/api-spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@ List **names** of all projects.

Response Type: `array<string>`

### `GET /projects-ids`
Dictionary mapping each project **id** to its **name**, for all projects.

Response Type: `dict`

### `GET /projects/{project}`
Get everything defined in the project

Expand Down
3 changes: 3 additions & 0 deletions registry/purview-registry/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def to_camel(s):
def get_projects() -> list[str]:
return registry.get_projects()

# Route for `GET /projects-ids`: returns whatever `registry.get_projects_ids()` returns.
@router.get("/projects-ids")
def get_projects_ids() -> dict:
    # Thin pass-through to the registry; no filtering or reshaping happens here.
    ids_to_names = registry.get_projects_ids()
    return ids_to_names

@router.get("/projects/{project}",tags=["Project"])
def get_projects(project: str) -> dict:
Expand Down
Loading