From 55d3d67b3dcac3d3fcb5bffaa5fa84d5d0965ede Mon Sep 17 00:00:00 2001 From: Eleftherios Zisis Date: Wed, 21 May 2025 11:48:37 +0200 Subject: [PATCH] Fix missing ion channel models in memodel's emodel --- app/schemas/emodel.py | 3 -- app/service/emodel.py | 10 ++--- app/service/memodel.py | 1 + tests/conftest.py | 34 +++++++-------- tests/test_emodel.py | 78 ++++++++++++++++++++--------------- tests/test_memodel.py | 94 ++++++++++++++++++++++-------------------- 6 files changed, 116 insertions(+), 104 deletions(-) diff --git a/app/schemas/emodel.py b/app/schemas/emodel.py index 580998b9..7aef0788 100644 --- a/app/schemas/emodel.py +++ b/app/schemas/emodel.py @@ -56,7 +56,4 @@ class EModelRead( mtypes: list[MTypeClassRead] | None etypes: list[ETypeClassRead] | None exemplar_morphology: ExemplarMorphology - - -class EModelReadExpanded(EModelRead, AssetsMixin): ion_channel_models: list[IonChannelModelWAssets] diff --git a/app/service/emodel.py b/app/service/emodel.py index 942bcb6f..b175db70 100644 --- a/app/service/emodel.py +++ b/app/service/emodel.py @@ -28,7 +28,7 @@ from app.dependencies.db import SessionDep from app.filters.emodel import EModelFilterDep from app.queries.common import router_create_one, router_read_many, router_read_one -from app.schemas.emodel import EModelCreate, EModelRead, EModelReadExpanded +from app.schemas.emodel import EModelCreate, EModelRead from app.schemas.types import ListResponse if TYPE_CHECKING: @@ -61,13 +61,13 @@ def read_one( user_context: UserContextDep, db: SessionDep, id_: uuid.UUID, -) -> EModelReadExpanded: +) -> EModelRead: return router_read_one( id_=id_, db=db, db_model_class=EModel, authorized_project_id=user_context.project_id, - response_schema_class=EModelReadExpanded, + response_schema_class=EModelRead, apply_operations=_load, ) @@ -96,7 +96,7 @@ def read_many( with_search: SearchDep, facets: FacetsDep, in_brain_region: InBrainRegionDep, -) -> ListResponse[EModelReadExpanded]: +) -> ListResponse[EModelRead]: morphology_alias = aliased(ReconstructionMorphology, flat=True) agent_alias = aliased(Agent, flat=True) created_by_alias = aliased(Agent, flat=True) @@ -170,7 +170,7 @@ def read_many( apply_filter_query_operations=None, apply_data_query_operations=_load, pagination_request=pagination_request, - response_schema_class=EModelReadExpanded, + response_schema_class=EModelRead, name_to_facet_query_params=name_to_facet_query_params, filter_model=emodel_filter, filter_joins=filter_joins, diff --git a/app/service/memodel.py b/app/service/memodel.py index f2a654cc..4073b858 100644 --- a/app/service/memodel.py +++ b/app/service/memodel.py @@ -57,6 +57,7 @@ def _load(select: Select): joinedload(EModel.createdBy), joinedload(EModel.updatedBy), selectinload(EModel.assets), + selectinload(EModel.ion_channel_models), ), joinedload(MEModel.morphology).options( joinedload(ReconstructionMorphology.brain_region), diff --git a/tests/conftest.py b/tests/conftest.py index 14c562f7..80e24f77 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -444,27 +444,27 @@ def emodel_id(create_emodel_ids: CreateIds) -> str: @pytest.fixture def create_memodel_ids( - db, morphology_id, brain_region_id, species_id, strain_id, emodel_id, agents + db, client, morphology_id, brain_region_id, species_id, strain_id, emodel_id, agents ) -> CreateIds: def _create_memodel_ids(count: int): memodel_ids: list[str] = [] for i in range(count): - memodel_id = add_db( - db, - MEModel( - name=f"{i}", - description=f"{i}_description", - brain_region_id=brain_region_id, - species_id=species_id, - strain_id=strain_id, - morphology_id=morphology_id, - emodel_id=emodel_id, - authorized_public=False, - authorized_project_id=PROJECT_ID, - holding_current=0, - threshold_current=0, - ), - ).id + memodel_id = assert_request( + client.post, + url="/memodel", + json={ + "name": f"{i}", + "description": f"{i}_description", + "brain_region_id": str(brain_region_id), + "species_id": str(species_id), + "strain_id": str(strain_id), + "morphology_id": str(morphology_id), + "emodel_id": str(emodel_id), + "authorized_public": False, + "holding_current": 0, + "threshold_current": 0, + }, + ).json()["id"] add_contributions(db, agents, memodel_id) diff --git a/tests/test_emodel.py b/tests/test_emodel.py index 691c7208..3a6f55e7 100644 --- a/tests/test_emodel.py +++ b/tests/test_emodel.py @@ -1,59 +1,69 @@ import itertools as it import uuid +import pytest from fastapi.testclient import TestClient +from app.db.model import EModel from app.db.types import EntityType from .conftest import CreateIds, EModelIds -from .utils import create_reconstruction_morphology_id +from .utils import assert_request, create_reconstruction_morphology_id from tests.routers.test_asset import _upload_entity_asset ROUTE = "/emodel" -def test_create_emodel(client: TestClient, species_id, strain_id, brain_region_id, morphology_id): - response = client.post( - ROUTE, - json={ - "brain_region_id": str(brain_region_id), - "species_id": species_id, - "strain_id": strain_id, - "description": "Test EModel Description", - "name": "Test EModel Name", - "iteration": "test iteration", - "score": -1, - "seed": -1, - "exemplar_morphology_id": morphology_id, - }, - ) - assert response.status_code == 200, f"Failed to create emodel: {response.text}" - data = response.json() - assert data["brain_region"]["id"] == str(brain_region_id), ( - f"Failed to get id for emodel: {data}" - ) - assert data["species"]["id"] == species_id, f"Failed to get species_id for emodel: {data}" - assert data["strain"]["id"] == strain_id, f"Failed to get strain_id for emodel: {data}" +@pytest.fixture +def json_data(db, emodel_id): + emodel = db.get(EModel, emodel_id) + return { + "brain_region_id": str(emodel.brain_region_id), + "species_id": str(emodel.species_id), + "strain_id": str(emodel.strain_id), + "description": emodel.description, + "name": emodel.name, + "iteration": emodel.iteration, + "score": emodel.score, + "seed": emodel.seed, + "exemplar_morphology_id": str(emodel.exemplar_morphology_id), + } + + +def _assert_read_response(data, json_data): + assert data["name"] == json_data["name"] + assert data["description"] == json_data["description"] + assert data["brain_region"]["id"] == json_data["brain_region_id"] + assert data["species"]["id"] == json_data["species_id"] + assert data["strain"]["id"] == json_data["strain_id"] assert data["createdBy"]["id"] == data["updatedBy"]["id"] + assert data["exemplar_morphology"]["id"] == json_data["exemplar_morphology_id"] + assert data["iteration"] == json_data["iteration"] + assert data["score"] == json_data["score"] + assert data["seed"] == json_data["seed"] + assert "ion_channel_models" in data + assert "assets" in data - response = client.get(ROUTE) - assert response.status_code == 200, f"Failed to get emodels: {response.text}" - data = response.json()["data"] - assert data[0]["createdBy"]["id"] == data[0]["updatedBy"]["id"] +def test_create_emodel(client: TestClient, json_data): + data = assert_request(client.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data) + + data = assert_request(client.get, url=ROUTE).json()["data"] + _assert_read_response(data[0], json_data) -def test_get_emodel(client: TestClient, emodel_id: str): + +def test_get_emodel(client: TestClient, emodel_id: str, json_data): _upload_entity_asset(client, EntityType.emodel, uuid.UUID(emodel_id)) - response = client.get(f"{ROUTE}/{emodel_id}") + data = assert_request( + client.get, + url=f"{ROUTE}/{emodel_id}", + ).json() + _assert_read_response(data, json_data) - assert response.status_code == 200 - data = response.json() assert data["id"] == emodel_id - assert "assets" in data assert len(data["assets"]) == 1 - assert "ion_channel_models" in data - assert data["createdBy"]["id"] == data["updatedBy"]["id"] def test_missing(client): diff --git a/tests/test_memodel.py b/tests/test_memodel.py index 8651fcae..706832bc 100644 --- a/tests/test_memodel.py +++ b/tests/test_memodel.py @@ -1,6 +1,7 @@ import operator as op import uuid +import pytest from fastapi.testclient import TestClient from app.db.model import MEModel @@ -9,6 +10,7 @@ from .conftest import CreateIds, MEModels from .utils import ( PROJECT_ID, + assert_request, check_brain_region_filter, create_reconstruction_morphology_id, ) @@ -16,21 +18,50 @@ ROUTE = "/memodel" -def test_get_memodel(client: TestClient, memodel_id): - response = client.get(f"{ROUTE}/{memodel_id}") - assert response.status_code == 200 - data = response.json() - assert data["id"] == memodel_id - assert "morphology" in data - assert "emodel" in data - assert "brain_region" in data - assert "species" in data - assert "strain" in data +@pytest.fixture +def json_data(db, memodel_id): + me_model = db.get(MEModel, memodel_id) + return { + "brain_region_id": str(me_model.brain_region_id), + "species_id": str(me_model.species_id), + "strain_id": str(me_model.strain_id), + "description": me_model.description, + "name": me_model.name, + "morphology_id": str(me_model.morphology_id), + "emodel_id": str(me_model.emodel_id), + "holding_current": me_model.holding_current, + "threshold_current": me_model.threshold_current, + } + + +def _assert_read_response(data, json_data): + assert data["name"] == json_data["name"] + assert data["description"] == json_data["description"] + assert data["brain_region"]["id"] == json_data["brain_region_id"] + assert data["species"]["id"] == json_data["species_id"] + assert data["strain"]["id"] == json_data["strain_id"] + assert data["createdBy"]["id"] == data["updatedBy"]["id"] + assert data["threshold_current"] == json_data["threshold_current"] + assert data["holding_current"] == json_data["holding_current"] + assert data["emodel"]["id"] == json_data["emodel_id"] + assert "ion_channel_models" in data["emodel"] + assert data["morphology"]["id"] == json_data["morphology_id"] + assert "assets" in data["emodel"] + assert "assets" in data["morphology"] assert "mtypes" in data assert "etypes" in data MEModelRead.model_validate(data) +def test_get_memodel(client: TestClient, memodel_id, json_data): + data = assert_request( + client.get, + url=f"{ROUTE}/{memodel_id}", + ).json() + _assert_read_response(data, json_data) + assert data["id"] == memodel_id + + def test_missing(client): response = client.get(f"{ROUTE}/{uuid.uuid4()}") assert response.status_code == 404 @@ -41,43 +72,16 @@ def test_missing(client): def test_create_memodel( client: TestClient, - species_id: str, - strain_id: str, - brain_region_id: int, - morphology_id: str, - emodel_id: str, + json_data, ): - response = client.post( - ROUTE, - json={ - "brain_region_id": str(brain_region_id), - "species_id": species_id, - "strain_id": strain_id, - "description": "Test MEModel Description", - "name": "Test MEModel Name", - "morphology_id": morphology_id, - "emodel_id": emodel_id, - "holding_current": 0, - "threshold_current": 0, - }, - ) - assert response.status_code == 200, f"Failed to create memodel: {response.text}" - data = response.json() - assert data["brain_region"]["id"] == str(brain_region_id), ( - f"Failed to get id for memodel: {data}" - ) - assert data["species"]["id"] == species_id, f"Failed to get species_id for memodel: {data}" - assert data["strain"]["id"] == strain_id, f"Failed to get strain_id for memodel: {data}" - assert "assets" in data["emodel"] - assert "assets" in data["morphology"] - assert data["createdBy"]["id"] == data["updatedBy"]["id"] + data = assert_request(client.post, url=ROUTE, json=json_data).json() + _assert_read_response(data, json_data) - response = client.get(f"{ROUTE}/{data['id']}") - assert response.status_code == 200, f"Failed to get morphologys: {response.text}" - data = response.json() - assert "assets" in data["emodel"] - assert "assets" in data["morphology"] - assert data["createdBy"]["id"] == data["updatedBy"]["id"] + data = assert_request(client.get, url=f"{ROUTE}/{data['id']}").json() + _assert_read_response(data, json_data) + + data = assert_request(client.get, url=ROUTE).json()["data"][0] + _assert_read_response(data, json_data) def test_facets(client: TestClient, faceted_memodels: MEModels):