Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic interface inheritance support (fixes #8) #9

Merged
merged 2 commits into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions gql_schema_codegen/block/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .block import Block, BlockInfo, BlockField, BlockFieldInfo
from .block import (Block, BlockField, BlockFieldInfo, BlockInfo,
get_inheritance_tree)

__all__ = ["Block", "BlockInfo", "BlockField", "BlockFieldInfo"]
__all__ = ["Block", "BlockInfo", "BlockField", "BlockFieldInfo", "get_inheritance_tree"]
40 changes: 32 additions & 8 deletions gql_schema_codegen/block/block.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
import re
from typing import List, Literal, NamedTuple, Optional, Union
from typing import Dict, List, Literal, NamedTuple, Optional, Set, Union

from ..base import BaseInfo
from ..constants import RESOLVER_TYPES, VALUE_TYPES
from ..constants.block_fields import all_block_fields
from ..dependency import Dependency, DependencyGroup
from ..dependency import (Dependency, DependencyGroup,
get_interface_dependencies,
remove_interface_dependencies)
from ..utils import pascal_case


Expand All @@ -18,6 +20,14 @@ class BlockFieldInfo(NamedTuple):
value_type: str


# a dictionary where for each node, we hold its children
inheritanceTree: Dict[str, Set[str]] = {"root": set()}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
inheritanceTree: Dict[str, Set[str]] = {"root": set()}
inheritanceTree: Dict[str, Set[str]] = defaultdict(lambda: set())

https://docs.python.org/3.11/library/collections.html#collections.defaultdict

then whatever you access on the dict is always present automatically



def get_inheritance_tree():
return inheritanceTree


class Block(BaseInfo):
def __init__(self, info, dependency_group: DependencyGroup) -> None:
super().__init__(info)
Expand All @@ -35,6 +45,7 @@ def display_name(self):

@property
def heading_file_line(self):
global inheritanceTree
display_name = self.display_name

if self.type == "enum":
Expand All @@ -59,15 +70,28 @@ def heading_file_line(self):
)

if not self.implements:
# check if we have an interface implementing another interface
deps = get_interface_dependencies()
if display_name in deps:
siblings = inheritanceTree.get(deps[display_name], set())
siblings.add(display_name)
inheritanceTree[deps[display_name]] = siblings
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
siblings = inheritanceTree.get(deps[display_name], set())
siblings.add(display_name)
inheritanceTree[deps[display_name]] = siblings
inheritanceTree.setdefault(deps[display_name], set()).add(display_name)

https://docs.python.org/3.11/library/stdtypes.html#dict.setdefault

or if you do defaultdict as per above:

Suggested change
siblings = inheritanceTree.get(deps[display_name], set())
siblings.add(display_name)
inheritanceTree[deps[display_name]] = siblings
inheritanceTree[deps[display_name]].add(display_name)

return f"@dataclass(kw_only=True)\nclass {display_name}({deps[display_name]}):"

inheritanceTree["root"].add(display_name)
return (
f"@dataclass(kw_only=True)\nclass {display_name}(DataClassJSONMixin):"
)

for el in self.info.implements.split("&"): # type: ignore
self.parent_classes.add(el.strip())

parent = ", ".join(list(self.parent_classes))
return f"@dataclass(kw_only=True)\nclass {display_name}({parent}):"
parents = remove_interface_dependencies(
[x.strip() for x in self.info.implements.split("&")] # type: ignore
)
for p in parents:
siblings = inheritanceTree.get(p, set())
siblings.add(display_name)
inheritanceTree[p] = siblings
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar pattern as above with either setdefault or defaultdict()[]

parent_str = ", ".join(parents)
return f"@dataclass(kw_only=True)\nclass {display_name}({parent_str}):"

@property
def category(self):
Expand Down Expand Up @@ -99,8 +123,8 @@ def file_representation(self):
parent_fields = set()
for p in self.parent_classes:
parent_fields.update(all_block_fields.get(p, set()))

for f in self.fields:
# don't re-include parent fields
if str(f).split(":")[1].strip() in parent_fields:
continue

Expand Down
13 changes: 11 additions & 2 deletions gql_schema_codegen/dependency/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
from .dependency import Dependency, DependencyGroup
from .dependency import (Dependency, DependencyGroup,
get_interface_dependencies,
remove_interface_dependencies,
update_interface_dependencies)

__all__ = ["Dependency", "DependencyGroup"]
__all__ = [
"Dependency",
"DependencyGroup",
"get_interface_dependencies",
"update_interface_dependencies",
"remove_interface_dependencies",
]
73 changes: 72 additions & 1 deletion gql_schema_codegen/dependency/dependency.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,83 @@
from itertools import groupby
from typing import List, NamedTuple, Set
from typing import Dict, List, NamedTuple, Set


class Dependency(NamedTuple):
imported_from: str
dependency: str


INTERMEDIATE_INTERFACES: Dict[str, str] = {}


def get_interface_dependencies():
return INTERMEDIATE_INTERFACES


def update_interface_dependencies(config_file_content):
global INTERMEDIATE_INTERFACES
if type(config_file_content) is dict:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if type(config_file_content) is dict:
if isinstance(config_file_content, dict)

doing type check with is is not ideal:

>>> class Foo(dict): pass
>>> type(Foo()) is dict
False
>>> isinstance(Foo(), dict)
True

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know but I am explicitly followed the style of the rest of the repo here in the thought that we will perhaps refactor everything together

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you do the defaultdict suggestion isinstance will be required. either way its up to you

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense to make the change - leaving extra bad code in was not great and leaves more work for potential future refactorings

data = config_file_content.get("interfaceInheritance")
if type(data) is dict:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

INTERMEDIATE_INTERFACES = data


def remove_interface_dependencies(
interfaces: List[str],
intermediate_interfaces: Dict[str, str] = {},
) -> List[str]:
"""Filter all dependencies from intermediate interfaces

Assumes that all keys in intermediate_interfaces are leaf nodes and
returns all other parent interfaces from a given list

Intemediate interfaces in GraphQL implement other interfaces. This is part
of the spec (see https://github.com/graphql/graphql-spec/pull/373)
but not implemented in all clients yet (e.g., neo4j)
Thus we are parsing intermediate interfaces in scalars.yml so as to emit
proper python code without relying on the graphql schema being correct


>>> remove_interface_dependencies(['t1', 't2', 'i1', 'i2', 't3'], {'i1':'i2'})
['t1', 't2', 'i1', 't3']

>>> remove_interface_dependencies(['t1', 't2', 't3'], {})
['t1', 't2', 't3']

>>> remove_interface_dependencies(['t1', 't2', 'i1', 'i2', 'i3', 't3'], \
{'i1':'i2', 'i2': 'i3'})
['t1', 't2', 'i1', 't3']

>>> remove_interface_dependencies(['t1', 't2', 'i1', 'i2', 'i3', \
'ni1', 'ni2', 't3'], \
{'i1':'i2', 'i2': 'i3', 'ni1':'ni2'})
['t1', 't2', 'i1', 'ni1', 't3']
"""
deps: Set[str] = set()
if not intermediate_interfaces:
intermediate_interfaces = INTERMEDIATE_INTERFACES
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if not intermediate_interfaces:
intermediate_interfaces = INTERMEDIATE_INTERFACES
intermediate_interfaces = intermediate_interfaces or INTERMEDIATE_INTERFACES

usually simplifies things a bit, especially if you count the mccabe function complexity its nice to reduce any direct if statements to more simple alternatives

for i in interfaces:
if i not in intermediate_interfaces:
# this is not an intermediate interface, thus we are keeping it
# as a dependency
continue

# add the parent dependency to the list of tracked dependencies:
deps.add(intermediate_interfaces[i])

# transitively fetch all dependencies
seen: Set[str] = set()
while deps:
d = deps.pop()
if d in seen:
continue
if d in intermediate_interfaces:
deps.add(intermediate_interfaces[d])
seen.add(d)

return list(filter(lambda x: x not in seen, interfaces))


class DependencyGroup:
def __init__(
self, deps: Set[Dependency] = set(), direct_deps: Set[str] = set()
Expand Down
78 changes: 49 additions & 29 deletions gql_schema_codegen/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,21 @@
import os
import re
import subprocess
from typing import List, Optional, Set
from typing import Dict, List, Optional, Set

import yaml
from graphql import (
build_client_schema,
build_schema,
get_introspection_query,
print_schema,
)
from graphql import (build_client_schema, build_schema,
get_introspection_query, print_schema)
from graphqlclient import GraphQLClient

from ..block import Block, BlockField, BlockFieldInfo, BlockInfo
from ..constants import (
BLOCK_PATTERN,
DIRECTIVE_PATTERN,
DIRECTIVE_USAGE_PATTERN,
FIELD_PATTERN,
RESOLVER_TYPES,
SCALAR_PATTERN,
UNION_PATTERN,
)
from ..block import (Block, BlockField, BlockFieldInfo, BlockInfo,
get_inheritance_tree)
from ..constants import (BLOCK_PATTERN, DIRECTIVE_PATTERN,
DIRECTIVE_USAGE_PATTERN, FIELD_PATTERN,
RESOLVER_TYPES, SCALAR_PATTERN, UNION_PATTERN)
from ..constants.block_fields import all_block_fields
from ..dependency import Dependency, DependencyGroup
from ..dependency import (Dependency, DependencyGroup,
update_interface_dependencies)
from ..scalar import ScalarInfo, ScalarType
from ..union import UnionInfo, UnionType

Expand Down Expand Up @@ -60,6 +52,8 @@ def __init__(self, **kwargs) -> None:
self._import_blocks = kwargs.get("import_blocks", self._import_blocks)
self._only_blocks = kwargs.get("only_blocks", self._only_blocks)

update_interface_dependencies(self.config_file_content)

self.dependency_group = DependencyGroup()

@property
Expand Down Expand Up @@ -209,7 +203,6 @@ def blocks(self):

block_type = block["type"]
block_name = block["name"]

all_block_fields[block_name] = set()
for field in self.get_fields_from_block(block["fields"]):
all_block_fields[block_name].add(field["name"])
Expand Down Expand Up @@ -250,11 +243,40 @@ def blocks(self):

@property
def sorted_blocks(self):
types_order = ["enum", "type", "param_type", "input"]
return sorted(
self.blocks,
key=lambda b: (types_order.index(b.type) if b.type in types_order else -1),
)
# first populate inheritance tree. this is VERY dirty for now but we
# should refactor this soon. We are calling b.heading_file_line for all
# blocks here as this is what populates the tree, and we only do this
# once
all_blocks: Dict[str, Block] = {}
for b in self.blocks:
all_blocks[b.name] = b
_ = b.heading_file_line
inheritanceTree = get_inheritance_tree()
sorted_bl: List[Block] = []
# first add enums - these have no dependencies
sorted_bl.extend(list(filter(lambda x: x.type == "enum", self.blocks)))

types_order = ["interface", "type", "param_type", "input"]
for t in types_order:
interfaces = list(filter(lambda x: x.type == t, self.blocks))

to_add = {b.name for b in interfaces}
blocks: List[Block] = []

# add nodes in a BFS manner to ensure we don't break dependencies
queue: List[str] = ["root"]
visited: Set[str] = set(["root"])
while queue:
node = queue.pop(0)
if node in to_add:
blocks.append(all_blocks[node])
for child_node in inheritanceTree.get(node, []):
if child_node not in visited:
visited.add(child_node)
queue.append(child_node)

sorted_bl.extend(blocks)
return sorted_bl

@property
def unions(self):
Expand Down Expand Up @@ -302,17 +324,15 @@ def file_representation(self):
lines: List[str] = ["\n" * 2]

if len(self.scalars) > 0:
# lines.extend(['## Scalars'] + ['\n' * 2])

for s in self.scalars:
lines.extend([s.file_representation] + ["\n" * 2])
lines.extend([s.file_representation] + ["\n"])

lines.append("\n" * 2)

if len(self.unions) > 0:
self.dependency_group.add_dependency(
Dependency(imported_from="typing", dependency="Union")
)
lines.extend(["## Union Types"] + ["\n" * 2])

for u in self.unions:
lines.extend([u.file_representation] + ["\n" * 2])

Expand Down