Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug fixes, refactor names #9

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/graphql_sqlalchemy/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
from sqlalchemy.ext.declarative import DeclarativeMeta

from .graphql_types import get_graphql_type_from_column
from .helpers import get_table
from .helpers import get_table, has_int
from .inputs import (
ON_CONFLICT_INPUT,
get_inc_input_type,
get_insert_input_type,
get_order_input_type,
get_pk_columns_input,
get_set_input_type,
get_where_input_type,
)
Expand Down Expand Up @@ -60,15 +59,15 @@ def make_delete_args(model: DeclarativeMeta, inputs: Inputs) -> GraphQLArgumentM

def make_update_args(model: DeclarativeMeta, inputs: Inputs) -> GraphQLArgumentMap:
return {
"_inc": GraphQLArgument(get_inc_input_type(model, inputs)),
**({"_inc": GraphQLArgument(get_inc_input_type(model, inputs))} if has_int(model) else {}),
"_set": GraphQLArgument(get_set_input_type(model, inputs)),
"where": GraphQLArgument(get_where_input_type(model, inputs)),
}


def make_update_by_pk_args(model: DeclarativeMeta, inputs: Inputs) -> GraphQLArgumentMap:
return {
"_inc": GraphQLArgument(get_inc_input_type(model, inputs)),
**({"_inc": GraphQLArgument(get_inc_input_type(model, inputs))} if has_int(model) else {}),
"_set": GraphQLArgument(get_set_input_type(model, inputs)),
"pk_columns": GraphQLArgument(GraphQLNonNull(get_pk_columns_input(model))),
**make_pk_args(model),
}
15 changes: 5 additions & 10 deletions src/graphql_sqlalchemy/dialects/pg/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,23 @@

from ...helpers import get_table
from ...inputs import get_where_input_type
from ...names import (
get_model_column_update_enum_name,
get_model_conflict_input_name,
get_model_constraint_enum_name,
get_model_constraint_key_name,
)
from ...names import get_field_name
from ...types import Inputs


def get_constraint_enum(model: DeclarativeMeta) -> GraphQLEnumType:
type_name = get_model_constraint_enum_name(model)
type_name = get_field_name(model, "constraint")

fields = {}
for column in get_table(model).primary_key:
key_name = get_model_constraint_key_name(model, column, is_primary_key=True)
key_name = get_field_name(model, "pkey")
fields[key_name] = key_name

return GraphQLEnumType(type_name, fields)


def get_update_column_enums(model: DeclarativeMeta) -> GraphQLEnumType:
type_name = get_model_column_update_enum_name(model)
type_name = get_field_name(model, "update_column")

fields = {}
for column in get_table(model).columns:
Expand All @@ -34,7 +29,7 @@ def get_update_column_enums(model: DeclarativeMeta) -> GraphQLEnumType:


def get_conflict_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
type_name = get_model_conflict_input_name(model)
type_name = get_field_name(model, "on_conflict")
if type_name in inputs:
return inputs[type_name]

Expand Down
17 changes: 11 additions & 6 deletions src/graphql_sqlalchemy/helpers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
from typing import List, Tuple
from typing import List, Tuple, Union

from sqlalchemy import Table
from graphql import GraphQLList, GraphQLScalarType
from sqlalchemy import Table, Integer, Float
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import Mapper, RelationshipProperty


def get_table(model: DeclarativeMeta) -> Table:
return model.__table__ # type: ignore
def get_table(model: Union[DeclarativeMeta, GraphQLScalarType, GraphQLList]) -> Table:
return getattr(model, "__table__")


def get_mapper(model: DeclarativeMeta) -> Mapper:
return model.__mapper__ # type: ignore
return getattr(model, "__mapper__")


def get_relationships(model: DeclarativeMeta) -> List[Tuple[str, RelationshipProperty]]:
return get_mapper(model).relationships.items() # type: ignore
return getattr(get_mapper(model).relationships, "items")()


def has_int(model: DeclarativeMeta) -> bool:
return any([isinstance(i.type, (Integer, Float)) for i in get_table(model).columns])
51 changes: 23 additions & 28 deletions src/graphql_sqlalchemy/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,7 @@

from .graphql_types import get_base_comparison_fields, get_graphql_type_from_column, get_string_comparison_fields
from .helpers import get_relationships, get_table
from .names import (
get_graphql_type_comparison_name,
get_model_inc_input_type_name,
get_model_insert_input_name,
get_model_order_by_input_name,
get_model_pk_columns_input_type_name,
get_model_set_input_type_name,
get_model_where_input_name,
)
from .names import get_field_name
from .types import Inputs

ORDER_BY_ENUM = GraphQLEnumType("order_by", {"desc": "desc", "asc": "asc"})
Expand All @@ -35,7 +27,7 @@

def get_comparison_input_type(column: Column, inputs: Inputs) -> GraphQLInputObjectType:
graphql_type = get_graphql_type_from_column(column.type)
type_name = get_graphql_type_comparison_name(graphql_type)
type_name = get_field_name(graphql_type, "comparison")

if type_name in inputs:
return inputs[type_name]
Expand All @@ -50,7 +42,7 @@ def get_comparison_input_type(column: Column, inputs: Inputs) -> GraphQLInputObj


def get_where_input_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
type_name = get_model_where_input_name(model)
type_name = get_field_name(model, "where")
if type_name in inputs:
return inputs[type_name]

Expand All @@ -65,7 +57,7 @@ def get_fields() -> GraphQLInputFieldMap:
fields[column.name] = GraphQLInputField(get_comparison_input_type(column, inputs))

for name, relationship in get_relationships(model):
fields[name] = GraphQLInputField(inputs[get_model_where_input_name(relationship.mapper.entity)])
fields[name] = GraphQLInputField(inputs[get_field_name(relationship.mapper.entity, "where")])

return fields

Expand All @@ -74,7 +66,7 @@ def get_fields() -> GraphQLInputFieldMap:


def get_order_input_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
type_name = get_model_order_by_input_name(model)
type_name = get_field_name(model, "order_by")

def get_fields() -> GraphQLInputFieldMap:
fields = {}
Expand All @@ -83,7 +75,7 @@ def get_fields() -> GraphQLInputFieldMap:
fields[column.name] = GraphQLInputField(ORDER_BY_ENUM)

for name, relationship in get_relationships(model):
fields[name] = GraphQLInputField(inputs[get_model_order_by_input_name(relationship.mapper.entity)])
fields[name] = GraphQLInputField(inputs[get_field_name(relationship.mapper.entity, "order_by")])

return fields

Expand All @@ -100,16 +92,30 @@ def make_model_fields_input_type(model: DeclarativeMeta, type_name: str) -> Grap


def get_insert_input_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
type_name = get_model_insert_input_name(model)
type_name = get_field_name(model, "insert_input")
if type_name in inputs:
return inputs[type_name]

inputs[type_name] = make_model_fields_input_type(model, type_name)
return inputs[type_name]


def get_conflict_input_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
type_name = get_field_name(model, "on_conflict")
if type_name in inputs:
return inputs[type_name]

fields = {
"merge": GraphQLInputField(GraphQLNonNull(GraphQLBoolean)),
}

input_type = GraphQLInputObjectType(type_name, fields)
inputs[type_name] = input_type
return input_type


def get_inc_input_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
type_name = get_model_inc_input_type_name(model)
type_name = get_field_name(model, "inc_input")
if type_name in inputs:
return inputs[type_name]

Expand All @@ -123,20 +129,9 @@ def get_inc_input_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputOb


def get_set_input_type(model: DeclarativeMeta, inputs: Inputs) -> GraphQLInputObjectType:
type_name = get_model_set_input_type_name(model)
type_name = get_field_name(model, "set_input")
if type_name in inputs:
return inputs[type_name]

inputs[type_name] = make_model_fields_input_type(model, type_name)
return inputs[type_name]


def get_pk_columns_input(model: DeclarativeMeta) -> GraphQLInputObjectType:
type_name = get_model_pk_columns_input_type_name(model)
primary_key = get_table(model).primary_key

fields = {}
for column in primary_key.columns:
fields[column.name] = GraphQLInputField(GraphQLNonNull(get_graphql_type_from_column(column.type)))

return GraphQLInputObjectType(type_name, fields)
121 changes: 41 additions & 80 deletions src/graphql_sqlalchemy/names.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Optional

from graphql import GraphQLList, GraphQLScalarType
from sqlalchemy import Column
Expand All @@ -7,87 +7,48 @@
from .helpers import get_table


def get_table_name(model: DeclarativeMeta) -> str:
FIELD_NAMES = {
"by_pk": "%s_by_pk",
"order_by": "%s_order_by",
"where": "%s_bool_exp",
"insert": "insert_%s",
"insert_one": "insert_%s_one",
"insert_input": "%s_insert_input",
"mutation_response": "%s_mutation_response",
"update": "update_%s",
"update_by_pk": "update_%s_by_pk",
"delete": "delete_%s",
"delete_by_pk": "delete_%s_by_pk",
"inc_input": "%s_inc_input",
"set_input": "%s_set_input",
"comparison": "%s_comparison_exp",
"arr_comparison": "arr_%s_comparison_exp",
"constraint": "%s_constraint",
"update_column": "%s_update_column",
"on_conflict": "%s_on_conflict",
"pkey": "%s_pkey",
"key": "%s_%s_key",
}


def get_table_name(model: Union[DeclarativeMeta, GraphQLScalarType, GraphQLList]) -> str:
return get_table(model).name


def get_model_pk_field_name(model: DeclarativeMeta) -> str:
return f"{get_table_name(model)}_by_pk"
def get_field_name(
model: Union[DeclarativeMeta, GraphQLScalarType, GraphQLList],
field_name: str,
column: Optional[Union[Column, GraphQLScalarType, GraphQLList]] = None,
) -> str:
if field_name == "comparison":
if isinstance(model, GraphQLList):
return FIELD_NAMES["arr_comparison"] % model.of_type.name.lower()
else:
return FIELD_NAMES[field_name] % getattr(model, "name").lower()

else:
name = get_table_name(model)
if isinstance(column, Column) and field_name == "key":
return FIELD_NAMES[field_name] % (name, column.name)

def get_model_order_by_input_name(model: DeclarativeMeta) -> str:
return f"{get_table_name(model)}_order_by"


def get_model_where_input_name(model: DeclarativeMeta) -> str:
return f"{get_table_name(model)}_bool_exp"


def get_graphql_type_comparison_name(graphql_type: Union[GraphQLList[GraphQLScalarType], GraphQLScalarType]) -> str:
if isinstance(graphql_type, GraphQLList):
return f"arr_{graphql_type.of_type.name.lower()}_comparison_exp"

return f"{graphql_type.name.lower()}_comparison_exp"


def get_model_insert_input_name(model: DeclarativeMeta) -> str:
return f"{get_table_name(model)}_insert_input"


def get_model_insert_object_name(model: DeclarativeMeta) -> str:
return f"insert_{get_table_name(model)}"


def get_model_insert_one_object_name(model: DeclarativeMeta) -> str:
return f"insert_{get_table_name(model)}_one"


def get_model_conflict_input_name(model: DeclarativeMeta) -> str:
return f"{get_table_name(model)}_on_conflict"


def get_model_mutation_response_object_name(model: DeclarativeMeta) -> str:
return f"{get_table_name(model)}_mutation_response"


def get_model_constraint_enum_name(model: DeclarativeMeta) -> str:
return f"{get_table_name(model)}_constraint"


def get_model_constraint_key_name(model: DeclarativeMeta, column: Column, is_primary_key: bool = False) -> str:
if is_primary_key:
return f"{get_table_name(model)}_pkey"

return f"{get_table_name(model)}_{column.name}_key"


def get_model_column_update_enum_name(model: DeclarativeMeta) -> str:
return f"{get_table_name(model)}_update_column"


def get_model_delete_name(model: DeclarativeMeta) -> str:
return f"delete_{get_table_name(model)}"


def get_model_delete_by_pk_name(model: DeclarativeMeta) -> str:
return f"delete_{get_table_name(model)}_by_pk"


def get_model_update_name(model: DeclarativeMeta) -> str:
return f"update_{get_table_name(model)}"


def get_model_update_by_pk_name(model: DeclarativeMeta) -> str:
return f"update_{get_table_name(model)}_by_pk"


def get_model_inc_input_type_name(model: DeclarativeMeta) -> str:
return f"{get_table_name(model)}_inc_input"


def get_model_set_input_type_name(model: DeclarativeMeta) -> str:
return f"{get_table_name(model)}_set_input"


def get_model_pk_columns_input_type_name(model: DeclarativeMeta) -> str:
return f"{get_table_name(model)}_pk_columns_input"
return FIELD_NAMES[field_name] % name
4 changes: 2 additions & 2 deletions src/graphql_sqlalchemy/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .graphql_types import get_graphql_type_from_column
from .helpers import get_relationships, get_table
from .names import get_model_mutation_response_object_name, get_table_name
from .names import get_field_name, get_table_name
from .resolvers import make_field_resolver
from .types import Objects

Expand Down Expand Up @@ -41,7 +41,7 @@ def get_fields() -> GraphQLFieldMap:


def build_mutation_response_type(model: DeclarativeMeta, objects: Objects) -> GraphQLObjectType:
type_name = get_model_mutation_response_object_name(model)
type_name = get_field_name(model, "mutation_response")

object_type = objects[get_table_name(model)]
fields = {
Expand Down
Loading