Skip to content

Commit

Permalink
Merge pull request #9 from crashappsec/nettrino/interface_inheritance
Browse files Browse the repository at this point in the history
Add basic interface inheritance support (fixes #8)
  • Loading branch information
nettrino committed Dec 2, 2022
2 parents 0c4621e + 13a9c45 commit 61f540a
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 48 deletions.
5 changes: 3 additions & 2 deletions gql_schema_codegen/block/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .block import Block, BlockInfo, BlockField, BlockFieldInfo
from .block import (Block, BlockField, BlockFieldInfo, BlockInfo,
get_inheritance_tree)

__all__ = ["Block", "BlockInfo", "BlockField", "BlockFieldInfo"]
__all__ = ["Block", "BlockInfo", "BlockField", "BlockFieldInfo", "get_inheritance_tree"]
37 changes: 29 additions & 8 deletions gql_schema_codegen/block/block.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import os
import re
from typing import List, Literal, NamedTuple, Optional, Union
from collections import defaultdict
from typing import Dict, List, Literal, NamedTuple, Optional, Set, Union

from ..base import BaseInfo
from ..constants import RESOLVER_TYPES, VALUE_TYPES
from ..constants.block_fields import all_block_fields
from ..dependency import Dependency, DependencyGroup
from ..dependency import (Dependency, DependencyGroup,
get_interface_dependencies,
remove_interface_dependencies)
from ..utils import pascal_case


Expand All @@ -18,6 +21,14 @@ class BlockFieldInfo(NamedTuple):
value_type: str


# a dictionary where for each node, we hold its children
inheritanceTree: Dict[str, Set[str]] = defaultdict(lambda: set())


def get_inheritance_tree():
return inheritanceTree


class Block(BaseInfo):
def __init__(self, info, dependency_group: DependencyGroup) -> None:
super().__init__(info)
Expand All @@ -35,6 +46,7 @@ def display_name(self):

@property
def heading_file_line(self):
global inheritanceTree
display_name = self.display_name

if self.type == "enum":
Expand All @@ -59,15 +71,24 @@ def heading_file_line(self):
)

if not self.implements:
# check if we have an interface implementing another interface
deps = get_interface_dependencies()
if display_name in deps:
inheritanceTree[deps[display_name]].add(display_name)
return f"@dataclass(kw_only=True)\nclass {display_name}({deps[display_name]}):"

inheritanceTree["root"].add(display_name)
return (
f"@dataclass(kw_only=True)\nclass {display_name}(DataClassJSONMixin):"
)

for el in self.info.implements.split("&"): # type: ignore
self.parent_classes.add(el.strip())

parent = ", ".join(list(self.parent_classes))
return f"@dataclass(kw_only=True)\nclass {display_name}({parent}):"
parents = remove_interface_dependencies(
[x.strip() for x in self.info.implements.split("&")] # type: ignore
)
for p in parents:
inheritanceTree[p].add(display_name)
parent_str = ", ".join(parents)
return f"@dataclass(kw_only=True)\nclass {display_name}({parent_str}):"

@property
def category(self):
Expand Down Expand Up @@ -99,8 +120,8 @@ def file_representation(self):
parent_fields = set()
for p in self.parent_classes:
parent_fields.update(all_block_fields.get(p, set()))

for f in self.fields:
# don't re-include parent fields
if str(f).split(":")[1].strip() in parent_fields:
continue

Expand Down
13 changes: 11 additions & 2 deletions gql_schema_codegen/dependency/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
from .dependency import Dependency, DependencyGroup
from .dependency import (Dependency, DependencyGroup,
get_interface_dependencies,
remove_interface_dependencies,
update_interface_dependencies)

__all__ = ["Dependency", "DependencyGroup"]
__all__ = [
"Dependency",
"DependencyGroup",
"get_interface_dependencies",
"update_interface_dependencies",
"remove_interface_dependencies",
]
72 changes: 71 additions & 1 deletion gql_schema_codegen/dependency/dependency.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,82 @@
from itertools import groupby
from typing import List, NamedTuple, Set
from typing import Dict, List, NamedTuple, Set


class Dependency(NamedTuple):
imported_from: str
dependency: str


INTERMEDIATE_INTERFACES: Dict[str, str] = {}


def get_interface_dependencies():
return INTERMEDIATE_INTERFACES


def update_interface_dependencies(config_file_content):
global INTERMEDIATE_INTERFACES
if isinstance(config_file_content, dict):
data = config_file_content.get("interfaceInheritance")
if isinstance(data, dict):
INTERMEDIATE_INTERFACES = data


def remove_interface_dependencies(
interfaces: List[str],
intermediate_interfaces: Dict[str, str] = {},
) -> List[str]:
"""Filter all dependencies from intermediate interfaces
Assumes that all keys in intermediate_interfaces are leaf nodes and
returns all other parent interfaces from a given list
Intemediate interfaces in GraphQL implement other interfaces. This is part
of the spec (see https://github.com/graphql/graphql-spec/pull/373)
but not implemented in all clients yet (e.g., neo4j)
Thus we are parsing intermediate interfaces in scalars.yml so as to emit
proper python code without relying on the graphql schema being correct
>>> remove_interface_dependencies(['t1', 't2', 'i1', 'i2', 't3'], {'i1':'i2'})
['t1', 't2', 'i1', 't3']
>>> remove_interface_dependencies(['t1', 't2', 't3'], {})
['t1', 't2', 't3']
>>> remove_interface_dependencies(['t1', 't2', 'i1', 'i2', 'i3', 't3'], \
{'i1':'i2', 'i2': 'i3'})
['t1', 't2', 'i1', 't3']
>>> remove_interface_dependencies(['t1', 't2', 'i1', 'i2', 'i3', \
'ni1', 'ni2', 't3'], \
{'i1':'i2', 'i2': 'i3', 'ni1':'ni2'})
['t1', 't2', 'i1', 'ni1', 't3']
"""
deps: Set[str] = set()
intermediate_interfaces = intermediate_interfaces or INTERMEDIATE_INTERFACES
for i in interfaces:
if i not in intermediate_interfaces:
# this is not an intermediate interface, thus we are keeping it
# as a dependency
continue

# add the parent dependency to the list of tracked dependencies:
deps.add(intermediate_interfaces[i])

# transitively fetch all dependencies
seen: Set[str] = set()
while deps:
d = deps.pop()
if d in seen:
continue
if d in intermediate_interfaces:
deps.add(intermediate_interfaces[d])
seen.add(d)

return list(filter(lambda x: x not in seen, interfaces))


class DependencyGroup:
def __init__(
self, deps: Set[Dependency] = set(), direct_deps: Set[str] = set()
Expand Down
2 changes: 1 addition & 1 deletion gql_schema_codegen/scalar/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def file_representation(self):
self.dependency_group.add_direct_dependency("neo4j")
return f"{self.name} = datetime"

if type(self.value) is not str:
if not isinstance(self.value, str):
self.dependency_group.add_dependency(
Dependency(imported_from="typing", dependency="Any")
)
Expand Down
88 changes: 54 additions & 34 deletions gql_schema_codegen/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,21 @@
import os
import re
import subprocess
from typing import List, Optional, Set
from typing import Dict, List, Optional, Set

import yaml
from graphql import (
build_client_schema,
build_schema,
get_introspection_query,
print_schema,
)
from graphql import (build_client_schema, build_schema,
get_introspection_query, print_schema)
from graphqlclient import GraphQLClient

from ..block import Block, BlockField, BlockFieldInfo, BlockInfo
from ..constants import (
BLOCK_PATTERN,
DIRECTIVE_PATTERN,
DIRECTIVE_USAGE_PATTERN,
FIELD_PATTERN,
RESOLVER_TYPES,
SCALAR_PATTERN,
UNION_PATTERN,
)
from ..block import (Block, BlockField, BlockFieldInfo, BlockInfo,
get_inheritance_tree)
from ..constants import (BLOCK_PATTERN, DIRECTIVE_PATTERN,
DIRECTIVE_USAGE_PATTERN, FIELD_PATTERN,
RESOLVER_TYPES, SCALAR_PATTERN, UNION_PATTERN)
from ..constants.block_fields import all_block_fields
from ..dependency import Dependency, DependencyGroup
from ..dependency import (Dependency, DependencyGroup,
update_interface_dependencies)
from ..scalar import ScalarInfo, ScalarType
from ..union import UnionInfo, UnionType

Expand All @@ -47,19 +39,21 @@ class Schema:
_only_blocks: bool = False

def __init__(self, **kwargs) -> None:
if "path" in kwargs and type(kwargs["path"]) is str:
if "path" in kwargs and isinstance(kwargs["path"], str):
self.path = kwargs["path"]

if "url" in kwargs and type(kwargs["url"]) is str:
if "url" in kwargs and isinstance(kwargs["url"], str):
self.url = kwargs["url"]

if "config_file" in kwargs and type(kwargs["config_file"]) is str:
if "config_file" in kwargs and isinstance(kwargs["config_file"], str):
self.config_file = kwargs["config_file"]

self._special_blocks = kwargs.get("blocks", self._special_blocks)
self._import_blocks = kwargs.get("import_blocks", self._import_blocks)
self._only_blocks = kwargs.get("only_blocks", self._only_blocks)

update_interface_dependencies(self.config_file_content)

self.dependency_group = DependencyGroup()

@property
Expand All @@ -76,9 +70,9 @@ def config_file_content(self) -> Optional[dict[str, str]]:

@property
def custom_scalars(self) -> dict[str, str]:
if type(self.config_file_content) is dict:
if isinstance(self.config_file_content, dict):
data = self.config_file_content.get("scalars")
if type(data) is dict:
if isinstance(data, dict):
return data

return {}
Expand Down Expand Up @@ -209,7 +203,6 @@ def blocks(self):

block_type = block["type"]
block_name = block["name"]

all_block_fields[block_name] = set()
for field in self.get_fields_from_block(block["fields"]):
all_block_fields[block_name].add(field["name"])
Expand Down Expand Up @@ -250,11 +243,40 @@ def blocks(self):

@property
def sorted_blocks(self):
types_order = ["enum", "type", "param_type", "input"]
return sorted(
self.blocks,
key=lambda b: (types_order.index(b.type) if b.type in types_order else -1),
)
# first populate inheritance tree. this is VERY dirty for now but we
# should refactor this soon. We are calling b.heading_file_line for all
# blocks here as this is what populates the tree, and we only do this
# once
all_blocks: Dict[str, Block] = {}
for b in self.blocks:
all_blocks[b.name] = b
_ = b.heading_file_line
inheritanceTree = get_inheritance_tree()
sorted_bl: List[Block] = []
# first add enums - these have no dependencies
sorted_bl.extend(list(filter(lambda x: x.type == "enum", self.blocks)))

types_order = ["interface", "type", "param_type", "input"]
for t in types_order:
interfaces = list(filter(lambda x: x.type == t, self.blocks))

to_add = {b.name for b in interfaces}
blocks: List[Block] = []

# add nodes in a BFS manner to ensure we don't break dependencies
queue: List[str] = ["root"]
visited: Set[str] = set(["root"])
while queue:
node = queue.pop(0)
if node in to_add:
blocks.append(all_blocks[node])
for child_node in inheritanceTree.get(node, []):
if child_node not in visited:
visited.add(child_node)
queue.append(child_node)

sorted_bl.extend(blocks)
return sorted_bl

@property
def unions(self):
Expand Down Expand Up @@ -302,17 +324,15 @@ def file_representation(self):
lines: List[str] = ["\n" * 2]

if len(self.scalars) > 0:
# lines.extend(['## Scalars'] + ['\n' * 2])

for s in self.scalars:
lines.extend([s.file_representation] + ["\n" * 2])
lines.extend([s.file_representation] + ["\n"])

lines.append("\n" * 2)

if len(self.unions) > 0:
self.dependency_group.add_dependency(
Dependency(imported_from="typing", dependency="Union")
)
lines.extend(["## Union Types"] + ["\n" * 2])

for u in self.unions:
lines.extend([u.file_representation] + ["\n" * 2])

Expand Down

0 comments on commit 61f540a

Please sign in to comment.