diff --git a/onto/database/mock.py b/onto/database/mock.py index 113f5c5..a1a2b5e 100644 --- a/onto/database/mock.py +++ b/onto/database/mock.py @@ -43,7 +43,7 @@ def set(cls, ref: Reference, snapshot: Snapshot, transaction=_NA): @classmethod def get(cls, ref: Reference, transaction=_NA): - return cls.d[str(ref)] + return Snapshot(cls.d[str(ref)]) update = set create = set @@ -63,7 +63,7 @@ def query(cls, q): qualifier = q._to_qualifier() for k, v in cls.d.items(): if qualifier(v): - yield k, v + yield MockReference.from_str(k), Snapshot(v) yield from () diff --git a/onto/sink/graphql.py b/onto/sink/graphql.py index fb18360..af7c27b 100644 --- a/onto/sink/graphql.py +++ b/onto/sink/graphql.py @@ -166,6 +166,10 @@ def _as_graphql_schema(self): def start(self): subscription_schema = self._as_graphql_schema() return subscription_schema + + @staticmethod + def _get_user(info): + return info.context.request.user class GraphQLSubscriptionSink(GraphQLSink): @@ -206,15 +210,20 @@ class GraphQLQuerySink(GraphQLSink): def _register_op(self): from gql import query - + extra_args = set() + if '__user' in extra_args: + kwargs = {__user: self._get_user(info), **kwargs} async def f(parent, info, **kwargs): - res = self._invoke_mediator(func_name='query', **kwargs) + res = await self._invoke_mediator(func_name='query', **kwargs) return res name = self.sink_name f.__name__ = name query(f) args = dict(self._args_of('query')) + if '__user' in args: + extra_args.add('__user') + args = {k: v for k, v in args.items() if not str(k).startswith('__')} return name, args @@ -225,7 +234,10 @@ class GraphQLMutationSink(GraphQLSink): def _register_op(self): from gql import mutate + extra_args = set() async def f(parent, info, **kwargs): + if '__user' in extra_args: + kwargs = { __user: self._get_user(info), **kwargs } res = await self._invoke_mediator(func_name='mutate', **kwargs) return res @@ -233,6 +245,10 @@ async def f(parent, info, **kwargs): f.__name__ = name mutate(f, snake_argument=False) args = dict(self._args_of('mutate')) + if '__user' in args: + extra_args.add('__user') + # TODO: NOTE: __ fields are filtered out + args = { k: v for k, v in args.items() if not str(k).startswith('__')} return name, args diff --git a/onto/utils.py b/onto/utils.py index 3064bb0..85febcb 100644 --- a/onto/utils.py +++ b/onto/utils.py @@ -73,7 +73,7 @@ def snapshot_to_obj( # if not snapshot.exists: # return None - d = snapshot.to_dict() if not isinstance(snapshot, dict) else snapshot # TODO: improve + d = snapshot.to_dict() obj_cls = super_cls if "obj_type" in d: