diff --git a/alembic/versions/20250602_145126_8f0d631d2bd4_default_migration_message.py b/alembic/versions/20250602_145126_8f0d631d2bd4_default_migration_message.py
new file mode 100644
index 00000000..0d5b61ef
--- /dev/null
+++ b/alembic/versions/20250602_145126_8f0d631d2bd4_default_migration_message.py
@@ -0,0 +1,45 @@
+"""Default migration message
+
+Revision ID: 8f0d631d2bd4
+Revises: 1589bff44728
+Create Date: 2025-06-02 14:51:26.859717
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+from sqlalchemy import Text
+import app.db.types
+
+# revision identifiers, used by Alembic.
+revision: str = "8f0d631d2bd4"
+down_revision: Union[str, None] = "1589bff44728"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table(
+        "derivation",
+        sa.Column("used_id", sa.Uuid(), nullable=False),
+        sa.Column("generated_id", sa.Uuid(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["generated_id"], ["entity.id"], name=op.f("fk_derivation_generated_id_entity")
+        ),
+        sa.ForeignKeyConstraint(
+            ["used_id"], ["entity.id"], name=op.f("fk_derivation_used_id_entity")
+        ),
+        sa.PrimaryKeyConstraint("used_id", "generated_id", name=op.f("pk_derivation")),
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table("derivation")
+    # ### end Alembic commands ###
diff --git a/app/cli/import_data.py b/app/cli/import_data.py
index 731639e4..7c560867 100644
--- a/app/cli/import_data.py
+++ b/app/cli/import_data.py
@@ -7,7 +7,6 @@
 from abc import ABC, abstractmethod
 from collections import Counter, defaultdict
 from contextlib import closing
-from operator import attrgetter
 from pathlib import Path
 from typing import Any
 
@@ -18,7 +17,7 @@
 from app.cli import curate, utils
 from app.cli.brain_region_data import BRAIN_ATLAS_REGION_VOLUMES
-from app.cli.curation import electrical_cell_recording
+from app.cli.curation import cell_composition, electrical_cell_recording
 from app.cli.types import ContentType
 from app.cli.utils import (
     AUTHORIZED_PUBLIC,
@@ -37,6 +36,7 @@
     BrainRegionHierarchy,
     CellComposition,
     DataMaturityAnnotationBody,
+    Derivation,
     ElectricalCellRecording,
     EModel,
     Entity,
@@ -71,10 +71,6 @@
 from app.logger import L
 from app.schemas.base import ProjectContext
 
-from app.cli.curation import electrical_cell_recording, cell_composition
-from app.cli.types import ContentType
-
-
 BRAIN_ATLAS_NAME = "BlueBrain Atlas"
 REQUIRED_PATH = click.Path(exists=True, readable=True, dir_okay=False, resolve_path=True)
 
@@ -561,6 +557,75 @@ def ingest(
         db.commit()
 
 
+class ImportEModelDerivations(Import):
+    name = "EModelWorkflow"
+
+    @staticmethod
+    def is_correct_type(data):
+        types = ensurelist(data["@type"])
+        return "EModelWorkflow" in types
+
+    @staticmethod
+    def ingest(
+        db,
+        project_context,
+        data_list: list[dict],
+        all_data_by_id: dict[str, dict],
+        hierarchy_name: str,
+    ):
+        """Import emodel derivations from EModelWorkflow."""
+        legacy_emodel_ids = set()
+        derivations = {}
+        for data in tqdm(data_list, desc="EModelWorkflow"):
+            legacy_emodel_id = utils.find_id_in_entity(data, "EModel", "generates")
+            legacy_etc_id = utils.find_id_in_entity(
+                data, "ExtractionTargetsConfiguration", "hasPart"
+            )
+            if not legacy_emodel_id:
+                L.warning("Not found EModel id in EModelWorkflow: {}", data["@id"])
+                continue
+            if not legacy_etc_id:
+                L.warning(
+                    "Not found ExtractionTargetsConfiguration id in EModelWorkflow: {}", data["@id"]
+                )
+                continue
+            if not (etc := all_data_by_id.get(legacy_etc_id)):
+                L.warning("Not found ExtractionTargetsConfiguration with id {}", legacy_etc_id)
+                continue
+            if not (legacy_trace_ids := list(utils.find_ids_in_entity(etc, "Trace", "uses"))):
+                L.warning(
+                    "Not found traces in ExtractionTargetsConfiguration with id {}", legacy_etc_id
+                )
+                continue
+            if legacy_emodel_id in legacy_emodel_ids:
+                L.warning("Duplicated and ignored traces for EModel id {}", legacy_emodel_id)
+                continue
+            legacy_emodel_ids.add(legacy_emodel_id)
+            if not (emodel := utils._find_by_legacy_id(legacy_emodel_id, EModel, db)):
+                L.warning("Not found EModel with legacy id {}", legacy_emodel_id)
+                continue
+            if emodel.id in derivations:
+                L.warning("Duplicated and ignored traces for EModel uuid {}", emodel.id)
+            derivations[emodel.id] = [
+                utils._find_by_legacy_id(legacy_trace_id, ElectricalCellRecording, db).id
+                for legacy_trace_id in legacy_trace_ids
+            ]
+
+        rows = [
+            Derivation(used_id=trace_id, generated_id=emodel_id)
+            for emodel_id, trace_ids in derivations.items()
+            for trace_id in trace_ids
+        ]
+        L.info(
+            "Imported derivations for {} EModels from {} records", len(derivations), len(data_list)
+        )
+        # delete everything from derivation table before adding the records
+        query = sa.delete(Derivation)
+        db.execute(query)
+        db.add_all(rows)
+        db.commit()
+
+
 class ImportBrainAtlas(Import):
     name = "BrainAtlas"
 
@@ -1370,6 +1435,7 @@ def _do_import(db, input_dir, project_context, hierarchy_name):
         ImportBrainAtlas,
         ImportDistribution,
         ImportNeuronMorphologyFeatureAnnotation,
+        ImportEModelDerivations,
     ]
 
     for importer in importers:
diff --git a/app/cli/utils.py b/app/cli/utils.py
index 3eae6483..374f9135 100644
--- a/app/cli/utils.py
+++ b/app/cli/utils.py
@@ -361,13 +361,14 @@ def get_or_create_distribution(
 def find_id_in_entity(entity: dict | None, type_: str, entity_list_key: str):
     if not entity:
         return None
-    return next(
-        (
-            part.get("@id")
-            for part in ensurelist(entity.get(entity_list_key, []))
-            if is_type(part, type_)
-        ),
-        None,
+    return next(find_ids_in_entity(entity, type_, entity_list_key), None)
+
+
+def find_ids_in_entity(entity: dict, type_: str, entity_list_key: str):
+    return (
+        part.get("@id")
+        for part in ensurelist(entity.get(entity_list_key, []))
+        if is_type(part, type_)
     )
 
 
diff --git a/app/db/model.py b/app/db/model.py
index e353d37b..02fd425f 100644
--- a/app/db/model.py
+++ b/app/db/model.py
@@ -919,3 +919,11 @@ class CellComposition(NameDescriptionVectorMixin, LocationMixin, SpeciesMixin, E
     __tablename__ = EntityType.cell_composition
     id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), primary_key=True)
     __mapper_args__ = {"polymorphic_identity": __tablename__}  # noqa: RUF012
+
+
+class Derivation(Base):
+    __tablename__ = "derivation"
+    used_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), primary_key=True)
+    generated_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("entity.id"), primary_key=True)
+    used: Mapped["Entity"] = relationship(foreign_keys=[used_id])
+    generated: Mapped["Entity"] = relationship(foreign_keys=[generated_id])
diff --git a/app/filters/entity.py b/app/filters/entity.py
new file mode 100644
index 00000000..066ac67a
--- /dev/null
+++ b/app/filters/entity.py
@@ -0,0 +1,20 @@
+from typing import Annotated
+
+from fastapi_filter import FilterDepends
+
+from app.db.model import Entity
+from app.db.types import EntityType
+from app.filters.base import CustomFilter
+
+
+class BasicEntityFilter(CustomFilter):
+    type: EntityType | None = None
+
+    order_by: list[str] = ["-creation_date"]  # noqa: RUF012
+
+    class Constants(CustomFilter.Constants):
+        model = Entity
+        ordering_model_fields = ["creation_date", "update_date", "name"]  # noqa: RUF012
+
+
+BasicEntityFilterDep = Annotated[BasicEntityFilter, FilterDepends(BasicEntityFilter)]
diff --git a/app/routers/__init__.py b/app/routers/__init__.py
index 1d8d7c79..4a17bf63 100644
--- a/app/routers/__init__.py
+++ b/app/routers/__init__.py
@@ -10,6 +10,7 @@
     brain_region_hierarchy,
     cell_composition,
     contribution,
+    derivation,
     electrical_cell_recording,
     emodel,
     etype,
@@ -45,6 +46,7 @@
     brain_region_hierarchy.router,
     cell_composition.router,
     contribution.router,
+    derivation.router,
     electrical_cell_recording.router,
     emodel.router,
     etype.router,
diff --git a/app/routers/asset.py b/app/routers/asset.py
index e8d84c00..d00016a0 100644
--- a/app/routers/asset.py
+++ b/app/routers/asset.py
@@ -1,16 +1,15 @@
 """Generic asset routes."""
 
 import uuid
-from enum import StrEnum
 from http import HTTPStatus
 from pathlib import Path
-from typing import TYPE_CHECKING, Annotated
+from typing import Annotated
 
 from fastapi import APIRouter, Form, HTTPException, UploadFile, status
 from starlette.responses import RedirectResponse
 
 from app.config import settings
-from app.db.types import AssetLabel, EntityType
+from app.db.types import AssetLabel
 from app.dependencies.auth import UserContextDep, UserContextWithProjectIdDep
 from app.dependencies.db import RepoGroupDep
 from app.dependencies.s3 import S3ClientDep
@@ -19,6 +18,7 @@
 from app.schemas.types import ListResponse, PaginationResponse
 from app.service import asset as asset_service
 from app.utils.files import calculate_sha256_digest, get_content_type
+from app.utils.routers import EntityRoute, entity_route_to_type
 from app.utils.s3 import (
     delete_from_s3,
     generate_presigned_url,
@@ -32,18 +32,6 @@
     tags=["assets"],
 )
 
-if not TYPE_CHECKING:
-    # EntityRoute (hyphen-separated) <-> EntityType (underscore_separated)
-    EntityRoute = StrEnum(
-        "EntityRoute", {item.name: item.name.replace("_", "-") for item in EntityType}
-    )
-else:
-    EntityRoute = StrEnum
-
-
-def _entity_route_to_type(entity_route: EntityRoute) -> EntityType:
-    return EntityType[entity_route.name]
-
 
 @router.get("/{entity_route}/{entity_id}/assets")
 def get_entity_assets(
@@ -56,7 +44,7 @@
     assets = asset_service.get_entity_assets(
         repos,
         user_context=user_context,
-        entity_type=_entity_route_to_type(entity_route),
+        entity_type=entity_route_to_type(entity_route),
         entity_id=entity_id,
     )
     # TODO: proper pagination
@@ -76,7 +64,7 @@
     return asset_service.get_entity_asset(
         repos,
         user_context=user_context,
-        entity_type=_entity_route_to_type(entity_route),
+        entity_type=entity_route_to_type(entity_route),
         entity_id=entity_id,
         asset_id=asset_id,
     )
@@ -117,7 +105,7 @@
     asset_read = asset_service.create_entity_asset(
         repos=repos,
         user_context=user_context,
-        entity_type=_entity_route_to_type(entity_route),
+        entity_type=entity_route_to_type(entity_route),
         entity_id=entity_id,
         filename=file.filename,
         content_type=content_type,
@@ -149,7 +137,7 @@
     asset = asset_service.get_entity_asset(
         repos,
         user_context=user_context,
-        entity_type=_entity_route_to_type(entity_route),
+        entity_type=entity_route_to_type(entity_route),
         entity_id=entity_id,
         asset_id=asset_id,
     )
@@ -205,7 +193,7 @@ def delete_entity_asset(
     asset = asset_service.delete_entity_asset(
         repos,
         user_context=user_context,
-        entity_type=_entity_route_to_type(entity_route),
+        entity_type=entity_route_to_type(entity_route),
         entity_id=entity_id,
         asset_id=asset_id,
     )
diff --git a/app/routers/derivation.py b/app/routers/derivation.py
new file mode 100644
index 00000000..c5be93cc
--- /dev/null
+++ b/app/routers/derivation.py
@@ -0,0 +1,12 @@
+"""Generic derivation routes."""
+
+from fastapi import APIRouter
+
+import app.service.derivation
+
+router = APIRouter(
+    prefix="",
+    tags=["derivation"],
+)
+
+router.get("/{entity_route}/{entity_id}/derived-from")(app.service.derivation.read_many)
diff --git a/app/routers/emodel.py b/app/routers/emodel.py
index 49b5fed0..ffb9edbd 100644
--- a/app/routers/emodel.py
+++ b/app/routers/emodel.py
@@ -1,5 +1,6 @@
 from fastapi import APIRouter
 
+import app.service.electrical_cell_recording
 import app.service.emodel
 
 router = APIRouter(
diff --git a/app/schemas/base.py b/app/schemas/base.py
index d1efb98d..f9fd1b61 100644
--- a/app/schemas/base.py
+++ b/app/schemas/base.py
@@ -109,3 +109,7 @@ class LicensedCreateMixin(BaseModel):
 class LicensedReadMixin(BaseModel):
     model_config = ConfigDict(from_attributes=True)
     license: LicenseRead | None
+
+
+class BasicEntityRead(IdentifiableMixin, EntityTypeMixin):
+    pass
diff --git a/app/service/derivation.py b/app/service/derivation.py
new file mode 100644
index 00000000..e7cd3c6a
--- /dev/null
+++ b/app/service/derivation.py
@@ -0,0 +1,54 @@
+"""Generic derivation service."""
+
+import uuid
+
+from sqlalchemy import and_
+from sqlalchemy.orm import aliased
+
+from app.db.model import Derivation, Entity
+from app.dependencies.auth import UserContextDep
+from app.dependencies.common import PaginationQuery
+from app.dependencies.db import SessionDep
+from app.filters.entity import BasicEntityFilterDep
+from app.queries.common import router_read_many
+from app.schemas.base import BasicEntityRead
+from app.schemas.types import ListResponse
+from app.utils.routers import EntityRoute, entity_route_to_type
+
+
+def read_many(
+    *,
+    user_context: UserContextDep,
+    db: SessionDep,
+    entity_route: EntityRoute,
+    entity_id: uuid.UUID,
+    pagination_request: PaginationQuery,
+    entity_filter: BasicEntityFilterDep,
+) -> ListResponse[BasicEntityRead]:
+    """Return a list of basic entities used to generate the specified entity."""
+    db_model_class = Entity
+    generated_alias = aliased(Entity, flat=True, name="generated_alias")
+    entity_type = entity_route_to_type(entity_route)
+    # always needed regardless of the filter, so they cannot go to filter_keys
+    apply_filter_query_operations = (
+        lambda q: q.join(Derivation, db_model_class.id == Derivation.used_id)
+        .join(generated_alias, Derivation.generated_id == generated_alias.id)
+        .where(and_(generated_alias.id == entity_id, generated_alias.type == entity_type))
+    )
+    name_to_facet_query_params = filter_joins = None
+    return router_read_many(
+        db=db,
+        db_model_class=db_model_class,
+        authorized_project_id=user_context.project_id,
+        with_search=None,
+        with_in_brain_region=None,
+        facets=None,
+        aliases={},
+        apply_filter_query_operations=apply_filter_query_operations,
+        apply_data_query_operations=None,
+        pagination_request=pagination_request,
+        response_schema_class=BasicEntityRead,
+        name_to_facet_query_params=name_to_facet_query_params,
+        filter_model=entity_filter,
+        filter_joins=filter_joins,
+    )
diff --git a/app/utils/routers.py b/app/utils/routers.py
new file mode 100644
index 00000000..ffda91b1
--- /dev/null
+++ b/app/utils/routers.py
@@ -0,0 +1,16 @@
+from enum import StrEnum
+from typing import TYPE_CHECKING
+
+from app.db.types import EntityType
+
+if not TYPE_CHECKING:
+    # EntityRoute (hyphen-separated) <-> EntityType (underscore_separated)
+    EntityRoute = StrEnum(
+        "EntityRoute", {item.name: item.name.replace("_", "-") for item in EntityType}
+    )
+else:
+    EntityRoute = StrEnum
+
+
+def entity_route_to_type(entity_route: EntityRoute) -> EntityType:
+    return EntityType[entity_route.name]
diff --git a/tests/conftest.py b/tests/conftest.py
index 8c160734..1319d709 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -55,6 +55,7 @@
     ClientProxy,
     add_db,
     assert_request,
+    create_electrical_cell_recording_id_with_assets,
 )
 
 
@@ -683,3 +684,26 @@ def faceted_memodels(db: Session, client: TestClient, agents: tuple[Agent, Agent
         brain_region_ids=brain_region_ids,
         agent_ids=agent_ids,
     )
+
+
+@pytest.fixture
+def electrical_cell_recording_json_data(brain_region_id, subject_id, license_id):
+    return {
+        "name": "my-name",
+        "description": "my-description",
+        "subject_id": subject_id,
+        "brain_region_id": str(brain_region_id),
+        "license_id": str(license_id),
+        "recording_location": ["soma[0]_0.5"],
+        "recording_type": "intracellular",
+        "recording_origin": "in_vivo",
+        "ljp": 11.5,
+        "authorized_public": False,
+    }
+
+
+@pytest.fixture
+def trace_id_with_assets(db, client, tmp_path, electrical_cell_recording_json_data):
+    return create_electrical_cell_recording_id_with_assets(
+        db, client, tmp_path, electrical_cell_recording_json_data
+    )
diff --git a/tests/test_derivation.py b/tests/test_derivation.py
new file mode 100644
index 00000000..62f3557a
--- /dev/null
+++ b/tests/test_derivation.py
@@ -0,0 +1,32 @@
+from app.db.model import Derivation
+
+from tests.utils import add_all_db, assert_response, create_electrical_cell_recording_id
+
+
+def test_get_derived_from(db, client, create_emodel_ids, electrical_cell_recording_json_data):
+    # create two emodels, one with derivations and one without
+    generated_emodel_id, other_emodel_id = create_emodel_ids(2)
+    trace_ids = [
+        create_electrical_cell_recording_id(
+            client, json_data=electrical_cell_recording_json_data | {"name": f"name-{i}"}
+        )
+        for i in range(2)
+    ]
+    derivations = [
+        Derivation(used_id=ecr_id, generated_id=generated_emodel_id) for ecr_id in trace_ids
+    ]
+    add_all_db(db, derivations)
+
+    response = client.get(url=f"/emodel/{generated_emodel_id}/derived-from")
+
+    assert_response(response, 200)
+    data = response.json()["data"]
+    assert len(data) == 2
+    assert [data[0]["id"], data[1]["id"]] == [str(id_) for id_ in reversed(trace_ids)]
+    assert all(d["type"] == "electrical_cell_recording" for d in data)
+
+    response = client.get(url=f"/emodel/{other_emodel_id}/derived-from")
+
+    assert_response(response, 200)
+    data = response.json()["data"]
+    assert len(data) == 0
diff --git a/tests/test_electrical_cell_recording.py b/tests/test_electrical_cell_recording.py
index 25f46dc9..1ded43a1 100644
--- a/tests/test_electrical_cell_recording.py
+++ b/tests/test_electrical_cell_recording.py
@@ -2,113 +2,29 @@
 
 import pytest
 
-from app.db.model import ElectricalCellRecording, ElectricalRecordingStimulus
 from app.db.types import EntityType
 
 from .utils import (
     PROJECT_ID,
-    add_db,
     assert_request,
     check_authorization,
     check_brain_region_filter,
     check_missing,
-    create_asset_file,
     create_brain_region,
+    create_electrical_cell_recording_db,
+    create_electrical_cell_recording_id,
 )
 
-MODEL = ElectricalCellRecording
 ROUTE = "electrical-cell-recording"
 
 
-@pytest.fixture
-def json_data(brain_region_id, subject_id, license_id):
-    return {
-        "name": "my-name",
-        "description": "my-description",
-        "subject_id": subject_id,
-        "brain_region_id": str(brain_region_id),
-        "license_id": str(license_id),
-        "recording_location": ["soma[0]_0.5"],
-        "recording_type": "intracellular",
-        "recording_origin": "in_vivo",
-        "ljp": 11.5,
-        "authorized_public": False,
-    }
-
-
-@pytest.fixture
-def create(client, json_data):
-    def _create(**kwargs):
-        return assert_request(client.post, url=ROUTE, json=json_data | kwargs).json()
-
-    return _create
-
-
-@pytest.fixture
-def create_id(create):
-    def _create_id(**kwargs):
-        return create(**kwargs)["id"]
-
-    return _create_id
-
-
-@pytest.fixture
-def create_db(db, create_id):
-    def _create_db(**kwargs):
-        return db.get(MODEL, create_id(**kwargs))
-
-    return _create_db
-
-
-def _create_electrical_recording_id(
-    db,
-    recording_id,
+def test_create_one(
+    client, subject_id, license_id, brain_region_id, electrical_cell_recording_json_data
 ):
-    return add_db(
-        db,
-        ElectricalRecordingStimulus(
-            name="protocol",
-            description="protocol-description",
-            dt=0.1,
-            injection_type="current_clamp",
-            shape="sinusoidal",
-            start_time=0.0,
-            end_time=1.0,
-            recording_id=recording_id,
-            authorized_public=False,
-            authorized_project_id=PROJECT_ID,
-        ),
-    ).id
-
-
-@pytest.fixture
-def trace_id(tmp_path, client, db, create_id):
-    trace_id = create_id()
-
-    # add two protocols that refer to it
-    _create_electrical_recording_id(db, trace_id)
-    _create_electrical_recording_id(db, trace_id)
-
-    filepath = tmp_path / "trace.nwb"
-    filepath.write_bytes(b"trace")
-
-    # add an asset too
-    create_asset_file(
-        client=client,
-        entity_type="electrical_cell_recording",
-        entity_id=trace_id,
-        file_name="my-trace.nwb",
-        file_obj=filepath.read_bytes(),
-    )
-
-    return trace_id
-
-
-def test_create_one(client, subject_id, license_id, brain_region_id, json_data):
     data = assert_request(
         client.post,
         url=ROUTE,
-        json=json_data,
+        json=electrical_cell_recording_json_data,
     ).json()
 
     assert data["name"] == "my-name"
@@ -121,10 +37,10 @@
     assert data["created_by"]["id"] == data["updated_by"]["id"]
 
 
-def test_read_one(client, subject_id, license_id, brain_region_id, trace_id):
+def test_read_one(client, subject_id, license_id, brain_region_id, trace_id_with_assets):
     data = assert_request(
         client.get,
-        url=f"{ROUTE}/{trace_id}",
+        url=f"{ROUTE}/{trace_id_with_assets}",
     ).json()
 
     assert data["name"] == "my-name"
@@ -144,12 +60,21 @@ def test_missing(client):
     check_missing(ROUTE, client)
 
 
-def test_authorization(client_user_1, client_user_2, client_no_project, json_data):
-    check_authorization(ROUTE, client_user_1, client_user_2, client_no_project, json_data)
+def test_authorization(
+    client_user_1, client_user_2, client_no_project, electrical_cell_recording_json_data
+):
+    check_authorization(
+        ROUTE, client_user_1, client_user_2, client_no_project, electrical_cell_recording_json_data
+    )
 
 
-def test_pagination(client, create_id):
-    _ = [create_id(name=f"entity-{i}") for i in range(2)]
+def test_pagination(client, electrical_cell_recording_json_data):
+    _ = [
+        create_electrical_cell_recording_id(
+            client, json_data=electrical_cell_recording_json_data | {"name": f"entity-{i}"}
+        )
+        for i in range(2)
+    ]
     response = assert_request(
         client.get,
         url=ROUTE,
@@ -163,7 +88,7 @@
 
 
 @pytest.fixture
-def faceted_ids(db, brain_region_hierarchy_id, create_id):
+def faceted_ids(db, client, brain_region_hierarchy_id, electrical_cell_recording_json_data):
     brain_region_ids = [
         create_brain_region(
             db, brain_region_hierarchy_id, annotation_value=i, name=f"region-{i}"
@@ -172,8 +97,14 @@ def faceted_ids(db, brain_region_hierarchy_id, create_id):
     ]
 
     trace_ids = [
-        create_id(
-            name=f"trace-{i}", description=f"brain-region-{i}", brain_region_id=str(region_id)
+        create_electrical_cell_recording_id(
+            client,
+            json_data=electrical_cell_recording_json_data
+            | {
+                "name": f"trace-{i}",
+                "description": f"brain-region-{i}",
+                "brain_region_id": str(region_id),
+            },
         )
         for i, region_id in enumerate(brain_region_ids)
     ]
@@ -225,8 +156,15 @@ def test_facets(client, faceted_ids):
     ]
 
 
-def test_brain_region_filter(db, client, brain_region_hierarchy_id, create_db):
+def test_brain_region_filter(
+    db, client, brain_region_hierarchy_id, electrical_cell_recording_json_data
+):
     def create_model_function(_db, name, brain_region_id):
-        return create_db(name=name, brain_region_id=str(brain_region_id))
+        return create_electrical_cell_recording_db(
+            db,
+            client,
+            json_data=electrical_cell_recording_json_data
+            | {"name": name, "brain_region_id": str(brain_region_id)},
+        )
 
     check_brain_region_filter(ROUTE, client, db, brain_region_hierarchy_id, create_model_function)
diff --git a/tests/utils.py b/tests/utils.py
index c4986fe2..1becb1bc 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -11,8 +11,11 @@
 from app.db.model import (
     BrainRegion,
     BrainRegionHierarchy,
+    ElectricalCellRecording,
+    ElectricalRecordingStimulus,
     MTypeClass,
     MTypeClassification,
+    ReconstructionMorphology,
 )
 from app.db.types import EntityType
 from app.routers.asset import EntityRoute
@@ -44,6 +47,11 @@
     "project-id": UNRELATED_PROJECT_ID,
 }
 
+ROUTES = {
+    ReconstructionMorphology: "/reconstruction-morphology",
+    ElectricalCellRecording: "/electrical-cell-recording",
+}
+
 
 class ClientProxy:
     """Proxy TestClient to pass default headers without creating a new instance.
@@ -80,7 +88,7 @@ def create_reconstruction_morphology_id(
     description="Test Morphology Description",
 ):
     response = client.post(
-        "/reconstruction-morphology",
+        ROUTES[ReconstructionMorphology],
        json={
             "name": name,
             "description": description,
@@ -164,6 +172,56 @@ def attach_mtype(db, entity_id, mtype_id):
     return add_db(db, MTypeClassification(entity_id=str(entity_id), mtype_class_id=str(mtype_id)))
 
 
+def create_electrical_recording_stimulus_id(db, recording_id):
+    return add_db(
+        db,
+        ElectricalRecordingStimulus(
+            name="protocol",
+            description="protocol-description",
+            dt=0.1,
+            injection_type="current_clamp",
+            shape="sinusoidal",
+            start_time=0.0,
+            end_time=1.0,
+            recording_id=recording_id,
+            authorized_public=False,
+            authorized_project_id=PROJECT_ID,
+        ),
+    ).id
+
+
+def create_electrical_cell_recording_id(client, json_data):
+    result = assert_request(client.post, url=ROUTES[ElectricalCellRecording], json=json_data).json()
+    return uuid.UUID(result["id"])
+
+
+def create_electrical_cell_recording_db(db, client, json_data):
+    trace_id = create_electrical_cell_recording_id(client, json_data)
+    return db.get(ElectricalCellRecording, trace_id)
+
+
+def create_electrical_cell_recording_id_with_assets(db, client, tmp_path, json_data):
+    trace_id = create_electrical_cell_recording_id(client, json_data)
+
+    # add two protocols that refer to it
+    create_electrical_recording_stimulus_id(db, trace_id)
+    create_electrical_recording_stimulus_id(db, trace_id)
+
+    filepath = tmp_path / "trace.nwb"
+    filepath.write_bytes(b"trace")
+
+    # add an asset too
+    create_asset_file(
+        client=client,
+        entity_type="electrical_cell_recording",
+        entity_id=trace_id,
+        file_name="my-trace.nwb",
+        file_obj=filepath.read_bytes(),
+    )
+
+    return trace_id
+
+
 def check_missing(route, client):
     assert_request(
         client.get,