diff --git a/app/service/derivation.py b/app/service/derivation.py index db9d4c50..acccebc7 100644 --- a/app/service/derivation.py +++ b/app/service/derivation.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import aliased, joinedload, raiseload from app.db.model import Derivation, DerivationType, Entity -from app.db.utils import load_db_model_from_pydantic +from app.db.utils import ENTITY_TYPE_TO_CLASS, load_db_model_from_pydantic from app.dependencies.auth import UserContextDep, UserContextWithProjectIdDep from app.dependencies.common import PaginationQuery from app.dependencies.db import SessionDep @@ -18,7 +18,7 @@ ) from app.filters.entity import BasicEntityFilterDep from app.queries.common import router_read_many -from app.queries.entity import get_writable_entity +from app.queries.entity import get_readable_entity, get_writable_entity from app.schemas.base import BasicEntityRead from app.schemas.derivation import DerivationCreate, DerivationRead from app.schemas.types import ListResponse @@ -35,14 +35,26 @@ def read_many( pagination_request: PaginationQuery, entity_filter: BasicEntityFilterDep, ) -> ListResponse[BasicEntityRead]: - """Return a list of basic entities used to generate the specified entity.""" - db_model_class = Entity + """Return a list of basic entities used to generate the specified entity. + + Only the used entities that are accessible by the user are returned. + """ + used_db_model_class = Entity generated_alias = aliased(Entity, flat=True, name="generated_alias") entity_type = entity_route_to_type(entity_route) - # always needed regardless of the filter, so they cannot go to filter_keys + generated_db_model_class = ENTITY_TYPE_TO_CLASS[entity_type] + + # ensure that the requested entity is readable + _ = get_readable_entity( + db, + db_model_class=generated_db_model_class, + entity_id=entity_id, + project_id=user_context.project_id, + ) + # always needed regardless of the filter, so it cannot go to filter_keys apply_filter_query_operations = ( - lambda q: q.join(Derivation, db_model_class.id == Derivation.used_id) + lambda q: q.join(Derivation, used_db_model_class.id == Derivation.used_id) .join(generated_alias, Derivation.generated_id == generated_alias.id) .where( generated_alias.id == entity_id, @@ -54,7 +66,7 @@ def read_many( name_to_facet_query_params = filter_joins = None return router_read_many( db=db, - db_model_class=db_model_class, + db_model_class=used_db_model_class, authorized_project_id=user_context.project_id, with_search=None, with_in_brain_region=None, @@ -75,7 +87,16 @@ def create_one( json_model: DerivationCreate, user_context: UserContextWithProjectIdDep, ) -> DerivationRead: - used_entity = get_writable_entity( + """Create a new derivation from a readable entity (used) to a writable entity (generated). + + Used entity: a readable entity (public in any project, or private in the same project). + Generated entity: a writable entity (public or private, in the same project). + + Even when the parent (used) is private, the child (generated) can be either public or private. + + See also https://github.com/openbraininstitute/entitycore/issues/427 + """ + used_entity = get_readable_entity( db, Entity, json_model.used_id, diff --git a/app/service/hierarchy.py b/app/service/hierarchy.py index 6d40a3b2..fe3dd90e 100644 --- a/app/service/hierarchy.py +++ b/app/service/hierarchy.py @@ -21,12 +21,22 @@ def _load_nodes( derivation_type: DerivationType, ) -> dict[uuid.UUID, HierarchyNode]: root = aliased(entity_class, flat=True, name="root") + root_parent = aliased(entity_class, flat=True, name="root_parent") parent = aliased(entity_class, flat=True, name="parent") child = aliased(entity_class, flat=True, name="child") order_by = ["name", "id"] - matching_derivation_for_root = sa.select(sa.literal(1)).where( - Derivation.generated_id == root.id, - Derivation.derivation_type == derivation_type, + matching_derivation_for_root = ( + sa.select(sa.literal(1)) + .select_from(Derivation) + .join(root_parent, root_parent.id == Derivation.used_id) + .where( + Derivation.generated_id == root.id, + Derivation.derivation_type == derivation_type, + ) + ) + # needed to consider as root also the children with private parents in a different project + matching_derivation_for_root = constrain_to_accessible_entities( + matching_derivation_for_root, project_id=project_id, db_model_class=root_parent ) query_roots = ( sa.select( diff --git a/tests/test_derivation.py b/tests/test_derivation.py index e5d96a3d..4bf65ef4 100644 --- a/tests/test_derivation.py +++ b/tests/test_derivation.py @@ -5,6 +5,8 @@ from app.schemas.api import ErrorResponse from tests.utils import ( + PROJECT_ID, + UNRELATED_PROJECT_ID, add_all_db, assert_request, assert_response, @@ -12,9 +14,10 @@ ) -def test_get_derived_from(db, client, create_emodel_ids, electrical_cell_recording_json_data): +def test_get_derived_from( + db, client, client_user_2, emodel_id, public_emodel_id, electrical_cell_recording_json_data +): # create two emodels, one with derivations and one without - generated_emodel_id, other_emodel_id = create_emodel_ids(2) trace_ids = [ create_electrical_cell_recording_id( client, json_data=electrical_cell_recording_json_data | {"name": f"name-{i}"} @@ -25,21 +28,23 @@ def test_get_derived_from(db, client, create_emodel_ids, electrical_cell_recordi [ Derivation( used_id=ecr_id, - generated_id=generated_emodel_id, + generated_id=public_emodel_id, derivation_type="circuit_extraction", ) for ecr_id in trace_ids[:3] ] + [ Derivation( - used_id=ecr_id, generated_id=generated_emodel_id, derivation_type="circuit_rewiring" + used_id=ecr_id, + generated_id=public_emodel_id, + derivation_type="circuit_rewiring", ) for ecr_id in trace_ids[3:5] ] + [ Derivation( used_id=trace_ids[5], - generated_id=generated_emodel_id, + generated_id=emodel_id, # private derivation_type="unspecified", ) ] @@ -47,7 +52,7 @@ def test_get_derived_from(db, client, create_emodel_ids, electrical_cell_recordi add_all_db(db, derivations) response = client.get( - url=f"/emodel/{generated_emodel_id}/derived-from", + url=f"/emodel/{public_emodel_id}/derived-from", params={"derivation_type": "circuit_extraction"}, ) assert_response(response, 200) @@ -57,7 +62,7 @@ def test_get_derived_from(db, client, create_emodel_ids, electrical_cell_recordi assert all(d["type"] == "electrical_cell_recording" for d in data) response = client.get( - url=f"/emodel/{generated_emodel_id}/derived-from", + url=f"/emodel/{public_emodel_id}/derived-from", params={"derivation_type": "circuit_rewiring"}, ) assert_response(response, 200) @@ -67,7 +72,8 @@ def test_get_derived_from(db, client, create_emodel_ids, electrical_cell_recordi assert all(d["type"] == "electrical_cell_recording" for d in data) response = client.get( - url=f"/emodel/{generated_emodel_id}/derived-from", params={"derivation_type": "unspecified"} + url=f"/emodel/{emodel_id}/derived-from", + params={"derivation_type": "unspecified"}, ) assert_response(response, 200) data = response.json()["data"] @@ -76,27 +82,45 @@ def test_get_derived_from(db, client, create_emodel_ids, electrical_cell_recordi assert data[0]["type"] == "electrical_cell_recording" # Test error not derivation_type param - response = client.get(url=f"/emodel/{generated_emodel_id}/derived-from") + response = client.get(url=f"/emodel/{public_emodel_id}/derived-from") assert_response(response, 422) error = ErrorResponse.model_validate(response.json()) assert error.error_code == ApiErrorCode.INVALID_REQUEST # Test error invalid derivation_type param response = client.get( - url=f"/emodel/{generated_emodel_id}/derived-from", + url=f"/emodel/{public_emodel_id}/derived-from", params={"derivation_type": "invalid_type"}, ) assert_response(response, 422) error = ErrorResponse.model_validate(response.json()) assert error.error_code == ApiErrorCode.INVALID_REQUEST + # Test empty result response = client.get( - url=f"/emodel/{other_emodel_id}/derived-from", params={"derivation_type": "unspecified"} + url=f"/emodel/{public_emodel_id}/derived-from", + params={"derivation_type": "unspecified"}, ) assert_response(response, 200) data = response.json()["data"] assert len(data) == 0 + # Test private unreadable entity + response = client_user_2.get( + url=f"/emodel/{emodel_id}/derived-from", + params={"derivation_type": "unspecified"}, + ) + assert_response(response, 404) + assert response.json()["error_code"] == "ENTITY_NOT_FOUND" + + # Test non existing entity + response = client_user_2.get( + url="/emodel/00000000-0000-0000-0000-000000000000/derived-from", + params={"derivation_type": "unspecified"}, + ) + assert_response(response, 404) + assert response.json()["error_code"] == "ENTITY_NOT_FOUND" + @pytest.mark.parametrize( "derivation_type", @@ -137,17 +161,10 @@ def test_create_invalid_data(client, root_circuit, circuit): assert data["error_code"] == "INVALID_REQUEST" -@pytest.mark.parametrize( - ("client_fixture", "expected_status", "expected_error"), - [ - ("client_user_2", 404, "ENTITY_NOT_FOUND"), - ("client_no_project", 403, "NOT_AUTHORIZED"), - ], -) -def test_create_non_authorized( - request, client_fixture, expected_status, expected_error, root_circuit, circuit -): - client = request.getfixturevalue(client_fixture) +def test_create_without_authorization(client_no_project, root_circuit, circuit): + client = client_no_project + expected_status = 403 + expected_error = "NOT_AUTHORIZED" data = assert_request( client.post, @@ -160,3 +177,94 @@ def test_create_non_authorized( expected_status_code=expected_status, ).json() assert data["error_code"] == expected_error + + +def _create_entities(route, client_user_1, client_user_2, json_data): + """Create public and private entities. + + Created entities: + + public_u1 (PROJECT_ID) + private_u1 (PROJECT_ID) + public_u2 (UNRELATED_PROJECT_ID) + private_u2 (UNRELATED_PROJECT_ID) + """ + public_u1 = assert_request( + client_user_1.post, + url=route, + json=json_data | {"name": "Public u1/0", "authorized_public": True}, + ).json() + assert public_u1["authorized_public"] is True + assert public_u1["authorized_project_id"] == PROJECT_ID + + private_u1 = assert_request( + client_user_1.post, + url=route, + json=json_data | {"name": "Private u1/0", "authorized_public": False}, + ).json() + assert private_u1["authorized_public"] is False + assert private_u1["authorized_project_id"] == PROJECT_ID + + public_u2 = assert_request( + client_user_2.post, + url=route, + json=json_data | {"name": "Public u2/0", "authorized_public": True}, + ).json() + assert public_u2["authorized_public"] is True + assert public_u2["authorized_project_id"] == UNRELATED_PROJECT_ID + + private_u2 = assert_request( + client_user_2.post, + url=route, + json=json_data | {"name": "Private u2/0", "authorized_public": False}, + ).json() + assert private_u2["authorized_public"] is False + assert private_u2["authorized_project_id"] == UNRELATED_PROJECT_ID + + return public_u1, private_u1, public_u2, private_u2 + + +def test_create_with_authorization(client_user_1, client_user_2, root_circuit_json_data): + """Check the authorization when trying to create the derivation.""" + route = "/circuit" + public_u1, private_u1, public_u2, private_u2 = _create_entities( + route, client_user_1, client_user_2, json_data=root_circuit_json_data + ) + + # these calls are done with client_user_1, that can create derivations for u1 entitites only + for i, (used_id, generated_id, expected_status) in enumerate( + [ + (public_u1["id"], public_u1["id"], 200), + (private_u1["id"], public_u1["id"], 200), + (public_u2["id"], public_u1["id"], 200), + (private_u2["id"], public_u1["id"], 404), # used_id cannot be read + (public_u1["id"], private_u1["id"], 200), + (private_u1["id"], private_u1["id"], 200), + (public_u2["id"], private_u1["id"], 200), + (private_u2["id"], private_u1["id"], 404), # used_id cannot be read + (public_u1["id"], public_u2["id"], 404), # generated_id is in a different project + (private_u1["id"], public_u2["id"], 404), # generated_id is in a different project + (public_u2["id"], public_u2["id"], 404), # generated_id is in a different project + (private_u2["id"], public_u2["id"], 404), # generated_id is in a different project + (public_u1["id"], private_u2["id"], 404), # generated_id cannot be read + (private_u1["id"], private_u2["id"], 404), # generated_id cannot be read + (public_u2["id"], private_u2["id"], 404), # generated_id cannot be read + (private_u2["id"], private_u2["id"], 404), # generated_id cannot be read + ] + ): + data = assert_request( + client_user_1.post, + url="/derivation", + json={ + "used_id": used_id, + "generated_id": generated_id, + "derivation_type": "circuit_extraction", + }, + expected_status_code=expected_status, + context=f"Test {i}", + ).json() + if expected_status == 200: + assert data["generated"]["id"] == generated_id, f"Error in test {i}" + assert data["used"]["id"] == used_id, f"Error in test {i}" + elif expected_status == 404: + assert data["error_code"] == "ENTITY_NOT_FOUND", f"Error in test {i}" diff --git a/tests/test_hierarchy.py b/tests/test_hierarchy.py index 3afd9d3c..79a207c3 100644 --- a/tests/test_hierarchy.py +++ b/tests/test_hierarchy.py @@ -100,6 +100,13 @@ def models(db, circuit_json_data, person_id, root_circuits): "authorized_public": False, "root_circuit_id": root_circuits[2].id, }, + { + "scale": CircuitScale.microcircuit, + "build_category": CircuitBuildCategory.em_reconstruction, + "authorized_project_id": UNRELATED_PROJECT_ID, + "authorized_public": True, + "root_circuit_id": None, # it's forbidden to link a private root_circuit_id + }, ] rows = [ Circuit( @@ -148,6 +155,8 @@ def hierarchy(db, root_circuits, models): C6[C6-u2-private] C7[C7-u2-private] + C8[C8-u2-public] + R0 -->|D0| C0 C0 -->|D0| C1 C0 -->|D1| C2 @@ -158,6 +167,7 @@ def hierarchy(db, root_circuits, models): C5 -->|D0| C6 R2 -->|D1| C7 + R2 -->|D0| C8 C0 -->|D1| C4 ``` @@ -175,6 +185,7 @@ def hierarchy(db, root_circuits, models): d0(used_id=r[1].id, generated_id=c[5].id), d0(used_id=c[5].id, generated_id=c[6].id), d1(used_id=r[2].id, generated_id=c[7].id), + d0(used_id=r[2].id, generated_id=c[8].id), d1(used_id=c[0].id, generated_id=c[4].id), # add multiple parents for c4 ] return add_all_db(db, derivations) @@ -214,6 +225,14 @@ def test_hierarchy(db, client_user_1, client_user_2, root_circuit, root_circuits "name": "circuit-3", "parent_id": None, }, + { + "authorized_project_id": str(UNRELATED_PROJECT_ID), + "authorized_public": True, + "children": [], + "id": str(models[8].id), + "name": "circuit-8", + "parent_id": None, + }, { "authorized_project_id": str(PROJECT_ID), "authorized_public": True, @@ -316,6 +335,14 @@ def test_hierarchy(db, client_user_1, client_user_2, root_circuit, root_circuits "name": "circuit-5", "parent_id": None, }, + { + "authorized_project_id": str(UNRELATED_PROJECT_ID), + "authorized_public": True, + "children": [], + "id": str(models[8].id), + "name": "circuit-8", + "parent_id": None, + }, { "authorized_project_id": str(PROJECT_ID), "authorized_public": True, @@ -412,7 +439,16 @@ def test_hierarchy(db, client_user_1, client_user_2, root_circuit, root_circuits { "authorized_project_id": str(UNRELATED_PROJECT_ID), "authorized_public": False, - "children": [], + "children": [ + { + "authorized_project_id": str(UNRELATED_PROJECT_ID), + "authorized_public": True, + "children": [], + "id": str(models[8].id), + "name": "circuit-8", + "parent_id": str(root_circuits[2].id), + }, + ], "id": str(root_circuits[2].id), "name": "root-circuit-2", "parent_id": None, @@ -443,6 +479,14 @@ def test_hierarchy(db, client_user_1, client_user_2, root_circuit, root_circuits "name": "circuit-6", "parent_id": None, }, + { + "authorized_project_id": str(UNRELATED_PROJECT_ID), + "authorized_public": True, + "children": [], + "id": str(models[8].id), + "name": "circuit-8", + "parent_id": None, + }, { "authorized_project_id": str(PROJECT_ID), "authorized_public": True, diff --git a/tests/utils.py b/tests/utils.py index e0ed8a02..235d238d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -262,17 +262,17 @@ def add_all_db(db, rows, *, same_transaction=False): return rows -def assert_response(response, expected_status_code=200): +def assert_response(response, expected_status_code=200, context=None): assert response.status_code == expected_status_code, ( f"Request {response.request.method} {response.request.url}: " f"expected={expected_status_code}, actual={response.status_code}, " - f"content={response.content}" + f"content={response.content}, context={context}" ) -def assert_request(client_method, *, expected_status_code=200, **kwargs): +def assert_request(client_method, *, expected_status_code=200, context=None, **kwargs): response = client_method(**kwargs) - assert_response(response, expected_status_code=expected_status_code) + assert_response(response, expected_status_code=expected_status_code, context=context) return response