Skip to content

Commit

Permalink
Merge pull request #53 from chanzuckerberg/jgadling/relay-node-fix
Browse files Browse the repository at this point in the history
Fix where clause tests and relay node tests
  • Loading branch information
jgadling committed Jun 20, 2024
2 parents 409eb53 + 1eb042d commit 97dbaf6
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 19 deletions.
13 changes: 11 additions & 2 deletions platformics/api/relay/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,17 @@ def resolver(
info: Info,
id: Annotated[strawberry.ID, argument(description="The ID of the object.")],
):
return id.resolve_type(info).resolve_node(
id.node_id,
type_resolvers = []
for selected_type in info.selected_fields[0].selections:
field_type = selected_type.type_condition
type_def = info.schema.get_type_by_name(field_type)
origin = type_def.origin.resolve_type if isinstance(type_def.origin, LazyType) else type_def.origin
assert issubclass(origin, Node)
type_resolvers.append(origin)
# FIXME TODO this only works if we're getting a *single* subclassed `Node` type --
# if we're getting multiple subclass types, we need to resolve them all somehow
return type_resolvers[0].resolve_node(
id,
info=info,
required=not is_optional,
)
Expand Down
12 changes: 11 additions & 1 deletion platformics/api/types/entities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable
from typing import Any, Iterable

import strawberry
from strawberry.types import Info
Expand Down Expand Up @@ -29,3 +29,13 @@ async def resolve_nodes(cls, *, info: Info, node_ids: Iterable[str], required: b
gql_type: str = cls.__strawberry_definition__.name # type: ignore
sql_model = getattr(db_module, gql_type)
return await dataloader.resolve_nodes(sql_model, node_ids)

@classmethod
async def resolve_node(cls, node_id: str, info: Info, required: bool = False) -> Any:
    """Resolve a single relay node of this entity type by its primary id.

    Looks up the SQLAlchemy model matching this GraphQL type (by name, via
    the ``db_module`` registered in the request context) and fetches the row
    through the batched ``sqlalchemy_loader`` dataloader.

    :param node_id: Primary-key id of the entity to fetch.
    :param info: Strawberry resolver info; its context must provide
        ``sqlalchemy_loader`` and ``db_module``.
    :param required: When True, a missing node raises instead of
        returning ``None``.
    :return: The loaded entity, or ``None`` when not found and not required.
    :raises ValueError: When ``required`` is True and no matching node exists.
    """
    dataloader = info.context["sqlalchemy_loader"]
    db_module = info.context["db_module"]
    # The GraphQL type name doubles as the SQLAlchemy model name in db_module.
    gql_type: str = cls.__strawberry_definition__.name  # type: ignore
    sql_model = getattr(db_module, gql_type)
    res = await dataloader.resolve_nodes(sql_model, [node_id])
    if res:
        return res[0]
    # Fix: `required` was previously accepted but ignored, so a missing
    # required node silently resolved to None instead of erroring.
    if required:
        raise ValueError(f"No {gql_type} node found for id {node_id!r}")
    return None
3 changes: 2 additions & 1 deletion platformics/test_infra/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def update_file_ids(cls) -> None:
session.execute(
sa.text(
f"""UPDATE {entity_name} SET {entity_field_name}_id = file.id
FROM file WHERE {entity_name}.entity_id = file.entity_id""",
FROM file WHERE {entity_name}.entity_id = file.entity_id and file.entity_field_name = :field_name""",
),
{"field_name": entity_field_name},
)
session.commit()
4 changes: 2 additions & 2 deletions test_app/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ init:
$(docker_compose_run) $(APP_CONTAINER) black .
# $(docker_compose_run) $(CONTAINER) ruff check --fix .
$(docker_compose_run) $(APP_CONTAINER) sh -c 'strawberry export-schema main:schema > /app/api/schema.graphql'
sleep 5 # wait for the app to reload after having files updated.
docker compose up -d
sleep 5
sleep 5 # wait for the app to reload after having files updated.
docker compose exec $(APP_CONTAINER) python3 -m sgqlc.introspection --exclude-deprecated --exclude-description http://localhost:9009/graphql api/schema.json

.PHONY: clean
Expand All @@ -55,6 +54,7 @@ clean: ## Remove all codegen'd artifacts.
rm -rf cerbos
rm -rf support
rm -rf database
rm -rf test_infra
$(docker_compose) --profile '*' down

.PHONY: start
Expand Down
4 changes: 3 additions & 1 deletion test_app/schema/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,10 @@ classes:
primer_file:
range: GenomicRange
inverse: GenomicRange.sequencing_reads
# Only mutable by system user (needed for upload flow)
annotations:
mutable: false
mutable: true
system_writable_only: True
contig:
range: Contig
inverse: Contig.sequencing_reads
Expand Down
15 changes: 3 additions & 12 deletions test_app/tests/test_nested_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,12 @@
Tests for nested queries + authorization
"""

import base64
import pytest
from collections import defaultdict
from platformics.database.connect import SyncDB
from conftest import GQLTestClient, SessionStorage
from test_infra.factories.sample import SampleFactory
from test_infra.factories.sequencing_read import SequencingReadFactory
from platformics.api.types.entities import Entity


def get_id(entity: Entity) -> str:
    """Return the relay global node id for *entity*.

    The id is the string ``"<TypeName>:<pk>"`` encoded as ASCII and then
    base64-encoded, matching the relay global-object-identification scheme.
    """
    raw_id = f"{type(entity).__name__}:{entity.id}".encode("ascii")
    return base64.b64encode(raw_id).decode("utf-8")


@pytest.mark.asyncio
Expand Down Expand Up @@ -155,9 +146,9 @@ async def test_relay_node_queries(
sample1 = SampleFactory(owner_user_id=111, collection_id=888)
sample2 = SampleFactory(owner_user_id=111, collection_id=888)
sequencing_read = SequencingReadFactory(sample=sample1, owner_user_id=111, collection_id=888)
sample1_id = get_id(sample1)
sample2_id = get_id(sample2)
sequencing_read_id = get_id(sequencing_read)
sample1_id = sample1.id
sample2_id = sample2.id
sequencing_read_id = sequencing_read.id

# Fetch one node
query = f"""
Expand Down
1 change: 1 addition & 0 deletions test_app/tests/test_where_clause.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ async def test_soft_deleted_objects(sync_db: SyncDB, gql_client: GQLTestClient)
}}
}}
"""
# Only service identities are allowed to soft delete entities
output = await gql_client.query(soft_delete_mutation, member_projects=[project_id], service_identity="workflows")
assert len(output["data"]["updateSequencingRead"]) == 3

Expand Down

0 comments on commit 97dbaf6

Please sign in to comment.