Skip to content

Commit

Permalink
Fixing cdd_gae gen parquet -> sqalalchemy table (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
skushnir123 committed Feb 14, 2023
1 parent fec1e5c commit 6b2e79e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 19 deletions.
2 changes: 1 addition & 1 deletion cdd/compound/gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def gen_module(
if emit_and_infer_imports:
imports = "{}{}".format(
imports or "",
" ".join(map(to_code, map(infer_imports, functions_and_classes))),
" ".join(map(to_code, chain(*map(infer_imports, functions_and_classes)))),
)

# Too many params! - Clean things up for debugging:
Expand Down
17 changes: 8 additions & 9 deletions cdd/shared/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1646,26 +1646,25 @@ def infer_imports(module):
"""
import cdd.sqlalchemy.utils.parser_utils # Should this be a function param instead?

if isinstance(module, (ClassDef, FunctionDef, AsyncFunctionDef)):
if isinstance(module, (ClassDef, FunctionDef, AsyncFunctionDef, Assign)):
module = Module(body=[module], type_ignores=[], stmt=None)
assert isinstance(module, Module), "Expected `Module` got `{type_name}`".format(
type_name=type(module).__name__
)

sqlalchemy_classes = filter(
lambda cls_def: any(
sqlalchemy_class_or_assigns = filter(
lambda class_or_assign_def: any(
filter(
lambda base: isinstance(base, Name) and base.id == "Base", cls_def.bases
lambda base: isinstance(base, Name) and base.id == "Base", class_or_assign_def.bases
)
),
filter(rpartial(isinstance, ClassDef), module.body),
) if isinstance(class_or_assign_def, ClassDef) else isinstance(class_or_assign_def.value, Call) and class_or_assign_def.value.func.id.endswith("Table"),
filter(rpartial(isinstance, (ClassDef, Assign)), module.body),
)

# reduce(, sqlalchemy_classes, set)
return list(
(
(cdd.sqlalchemy.utils.parser_utils.imports_from(sqlalchemy_classes),)
if sqlalchemy_classes
(cdd.sqlalchemy.utils.parser_utils.imports_from(sqlalchemy_class_or_assigns),)
if sqlalchemy_class_or_assigns
else iter(())
)
)
Expand Down
18 changes: 9 additions & 9 deletions cdd/sqlalchemy/utils/parser_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,12 @@ def column_call_name_manipulator(call, operation="remove", name=None):
return call


def infer_imports_from_sqlalchemy(sqlalchemy_class_def):
def infer_imports_from_sqlalchemy(sqlalchemy_class_or_assigns):
"""
Infer imports from SQLalchemy class
Infer imports from SQLalchemy ast
:param sqlalchemy_class_def: SQLalchemy class
:type sqlalchemy_class_def: ```ClassDef```
:param sqlalchemy_class_or_assigns: SQLalchemy Class or Assign
:type sqlalchemy_class_or_assigns: ```Union[ClassDef, Assign]```
:return: filter of imports (can be considered ```Iterable[str]```)
:rtype: ```filter```
Expand All @@ -275,7 +275,7 @@ def infer_imports_from_sqlalchemy(sqlalchemy_class_def):
body=list(
filter(
rpartial(isinstance, Call),
ast.walk(sqlalchemy_class_def),
ast.walk(sqlalchemy_class_or_assigns),
)
),
type_ignores=[],
Expand All @@ -297,12 +297,12 @@ def infer_imports_from_sqlalchemy(sqlalchemy_class_def):
return candidates_not_in_valid_types ^ candidates


def imports_from(sqlalchemy_classes):
def imports_from(sqlalchemy_asts):
"""
Generate `from sqlalchemy import <>` from the body of SQLalchemy `class`es
:param sqlalchemy_classes: SQLalchemy `class`es with base class of `Base`
:type sqlalchemy_classes: ```ClassDef```
:param sqlalchemy_asts: SQLalchemy `class`es with base class of `Base`
:type sqlalchemy_asts: ```ClassDef```
:return: `from sqlalchemy import <>` where <> is what was inferred from `sqlalchemy_classes`
:rtype: ```ImportFrom```
Expand All @@ -324,7 +324,7 @@ def imports_from(sqlalchemy_classes):
chain.from_iterable(
map(
infer_imports_from_sqlalchemy,
sqlalchemy_classes,
sqlalchemy_asts,
)
),
)
Expand Down

0 comments on commit 6b2e79e

Please sign in to comment.