Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions api/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from .auth import Authentication
from .db import Database
from .models import User, UserGroup
from .models import User, UserGroup, UserProfile


async def setup_admin_group(db):
Expand All @@ -29,18 +29,22 @@ async def setup_admin_group(db):


async def setup_admin_user(db, username, admin_group):
user_obj = await db.find_one(User, username=username)
user_obj = await db.find_one_by_attributes(User,
{'profile.username': username})
if user_obj:
print(f"User {username} already exists, aborting.")
print(user_obj.json())
return None
password = getpass.getpass(f"Password for user '{args.username}': ")
hashed_password = Authentication.get_password_hash(password)
print(f"Creating {username} user...")
return await db.create(User(
profile = UserProfile(
username=username,
hashed_password=hashed_password,
groups=[admin_group]
)
return await db.create(User(
profile=profile
))


Expand Down
16 changes: 9 additions & 7 deletions api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,17 @@ def verify_password(cls, password_hash, user):
async def authenticate_user(self, username: str, password: str):
"""Authenticate a username / password pair

Look up a `User` in the database with the provided `username` and check
whether the provided clear text `password` matches the hash associated
with it.
Look up a `User` in the database with the provided `username`
and check whether the provided clear text `password` matches the hash
associated with it.
"""
user = await self._db.find_one(User, username=username)
user = await self._db.find_one_by_attributes(
User, {'profile.username': username})
if not user:
return False
if not self.verify_password(password, user):
if not self.verify_password(password, user.profile):
return False
return user
return user.profile

def create_access_token(self, data: dict):
"""Create a JWT access token using the provided arbitrary `data`"""
Expand Down Expand Up @@ -114,7 +115,8 @@ async def get_current_user(self, token, security_scopes):
except JWTError as error:
return None, str(error)

user = await self._db.find_one(User, username=username)
user = await self._db.find_one_by_attributes(
User, {'profile.username': username})
return user, None

async def validate_scopes(self, requested_scopes):
Expand Down
10 changes: 10 additions & 0 deletions api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ async def find_one(self, model, **kwargs):
obj = await col.find_one(kwargs)
return model(**obj) if obj else None

async def find_one_by_attributes(self, model, attributes):
"""Find one object with matching attributes without pagination

The attributes dictionary provides key/value pairs used to find an
object with matching attributes.
"""
col = self._get_collection(model)
obj = await col.find_one(attributes)
return model(**obj) if obj else None

async def find_by_id(self, model, obj_id):
"""Find one object with a given id"""
col = self._get_collection(model)
Expand Down
29 changes: 20 additions & 9 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
Regression,
User,
UserGroup,
UserProfile,
Password,
get_model_from_kind
)
Expand Down Expand Up @@ -127,9 +128,9 @@ async def authorize_user(node_id: str, user: User = Depends(get_current_user)):
# user groups will be allowed to update the node
node_from_id = await db.find_by_id(Node, node_id)
if node_from_id.owner:
if not user.username == node_from_id.owner:
if not user.profile.username == node_from_id.owner:
if not any(group.name in node_from_id.user_groups
for group in user.groups):
for group in user.profile.groups):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized to complete the operation"
Expand Down Expand Up @@ -157,11 +158,11 @@ async def post_user(
detail=f"User group does not exist with name: \
{group_name}")
group_obj.append(group)
obj = await db.create(User(
username=username,
hashed_password=hashed_password,
groups=group_obj
))
profile = UserProfile(
username=username,
hashed_password=hashed_password,
groups=group_obj)
obj = await db.create(User(profile=profile))
except DuplicateKeyError as error:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand Down Expand Up @@ -387,7 +388,7 @@ async def post_node(node: Node, current_user: str = Depends(get_user)):
)

await _verify_user_group_existence(node.user_groups)
node.owner = current_user.username
node.owner = current_user.profile.username
obj = await db.create(node)
data = _get_node_event_data('created', obj)
await pubsub.publish_cloudevent('node', data)
Expand Down Expand Up @@ -423,12 +424,22 @@ async def put_node(node_id: str, node: Node,
return obj


async def _set_node_ownership_recursively(user: User, hierarchy: Hierarchy):
"""Set node ownership information for a hierarchy of nodes"""
if not hierarchy.node.owner:
hierarchy.node.owner = user.profile.username
for node in hierarchy.child_nodes:
await _set_node_ownership_recursively(user, node)


@app.put('/nodes/{node_id}', response_model=List[Node],
response_model_by_alias=False)
async def put_nodes(
node_id: str, nodes: Hierarchy, token: str = Depends(get_user)):
node_id: str, nodes: Hierarchy,
user: str = Depends(authorize_user)):
"""Add a hierarchy of nodes to an existing root node"""
nodes.node.id = ObjectId(node_id)
await _set_node_ownership_recursively(user, nodes)
obj_list = await db.create_hierarchy(nodes, Node)
data = _get_node_event_data('updated', obj_list[0])
await pubsub.publish_cloudevent('node', data)
Expand Down
20 changes: 14 additions & 6 deletions api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,23 +117,31 @@ def create_indexes(cls, collection):
collection.create_index("name", unique=True)


class User(DatabaseModel):
"""API user model"""
class UserProfile(BaseModel):
"""API user profile model"""
username: str
hashed_password: str = Field(description="Hash of the plaintext password")
groups: conlist(UserGroup, unique_items=True) = Field(
default=[],
description="A list of groups that user belongs to"
)


class User(DatabaseModel):
"""API user model
The model will be accessible by admin users only"""
active: bool = Field(
default=True,
description="To check if user is active or not"
)
groups: conlist(UserGroup, unique_items=True) = Field(
default=[],
description="A list of groups that user belongs to"
profile: UserProfile = Field(
description="User profile details accessible by all users"
)

@classmethod
def create_indexes(cls, collection):
"""Create an index to bind unique constraint to username"""
collection.create_index("username", unique=True)
collection.create_index("profile.username", unique=True)


class KernelVersion(BaseModel):
Expand Down
30 changes: 18 additions & 12 deletions tests/e2e_tests/test_user_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
import pytest

from api.models import User, UserGroup
from api.models import User, UserGroup, UserProfile
from api.db import Database
from e2e_tests.conftest import db_create

Expand All @@ -35,13 +35,15 @@ async def test_create_admin_user(test_async_client):
hashed_password = response.json()
assert response.status_code == 200

profile = UserProfile(
username=username,
hashed_password=hashed_password,
groups=[UserGroup(name="admin")]
)
obj = await db_create(
Database.COLLECTIONS[User],
User(
username=username,
hashed_password=hashed_password,
groups=[UserGroup(name="admin")]
))
Database.COLLECTIONS[User],
User(profile=profile)
)
assert obj is not None

response = await test_async_client.post(
Expand Down Expand Up @@ -82,8 +84,10 @@ async def test_create_regular_user(test_async_client):
data=json.dumps({'password': password})
)
assert response.status_code == 200
assert ('id', 'username', 'hashed_password', 'active',
'groups') == tuple(response.json().keys())
assert ('id', 'active',
'profile') == tuple(response.json().keys())
assert ('username', 'hashed_password',
'groups') == tuple(response.json()['profile'].keys())

response = await test_async_client.post(
"token",
Expand Down Expand Up @@ -118,9 +122,11 @@ def test_whoami(test_client):
},
)
assert response.status_code == 200
assert ('id', 'username', 'hashed_password', 'active',
'groups') == tuple(response.json().keys())
assert response.json()['username'] == 'test_user'
assert ('id', 'active',
'profile') == tuple(response.json().keys())
assert ('username', 'hashed_password',
'groups') == tuple(response.json()['profile'].keys())
assert response.json()['profile']['username'] == 'test_user'


@pytest.mark.dependency(depends=["test_create_regular_user"])
Expand Down
36 changes: 27 additions & 9 deletions tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest

from api.main import app
from api.models import User, UserGroup
from api.models import User, UserGroup, UserProfile
from api.pubsub import PubSub

BEARER_TOKEN = "Bearer \
Expand Down Expand Up @@ -105,17 +105,31 @@ def mock_db_find_one(mocker):
return async_mock


@pytest.fixture
def mock_db_find_one_by_attributes(mocker):
"""
Mocks async call to Database class method
used to find an object with matching attributes
"""
async_mock = AsyncMock()
mocker.patch('api.db.Database.find_one_by_attributes',
side_effect=async_mock)
return async_mock


@pytest.fixture
def mock_get_current_user(mocker):
"""
Mocks async call to Authentication class method
used to get current user
"""
async_mock = AsyncMock()
user = User(username='bob',
hashed_password='$2b$12$CpJZx5ooxM11bCFXT76/z.o6HWs2sPJy4iP8.'
'xCZGmM8jWXUXJZ4K',
active=True)
profile = UserProfile(
username='bob',
hashed_password='$2b$12$CpJZx5ooxM11bCFXT76/z.o6HWs2sPJy4iP8.'
'xCZGmM8jWXUXJZ4K',
)
user = User(profile=profile, active=True)
mocker.patch('api.auth.Authentication.get_current_user',
side_effect=async_mock)
async_mock.return_value = user, None
Expand All @@ -129,10 +143,14 @@ def mock_get_current_admin_user(mocker):
used to get current user
"""
async_mock = AsyncMock()
user = User(username='admin',
hashed_password='$2b$12$CpJZx5ooxM11bCFXT76/z.o6HWs2sPJy4iP8.'
'xCZGmM8jWXUXJZ4K',
active=True, groups=[UserGroup(name='admin')])
profile = UserProfile(
username='admin',
hashed_password='$2b$12$CpJZx5ooxM11bCFXT76/z.o6HWs2sPJy4iP8.'
'xCZGmM8jWXUXJZ4K',
groups=[UserGroup(name='admin')])
user = User(
profile=profile,
active=True)
mocker.patch('api.auth.Authentication.get_current_user',
side_effect=async_mock)
async_mock.return_value = user, None
Expand Down
Loading