diff --git a/cdd/compound/gen_utils.py b/cdd/compound/gen_utils.py index 1b062d7..7c6811d 100644 --- a/cdd/compound/gen_utils.py +++ b/cdd/compound/gen_utils.py @@ -271,7 +271,7 @@ def gen_module( if emit_and_infer_imports: imports = "{}{}".format( imports or "", - " ".join(map(to_code, map(infer_imports, functions_and_classes))), + " ".join(map(to_code, chain(*map(infer_imports, functions_and_classes)))), ) # Too many params! - Clean things up for debugging: diff --git a/cdd/shared/ast_utils.py b/cdd/shared/ast_utils.py index e193b93..91c2139 100644 --- a/cdd/shared/ast_utils.py +++ b/cdd/shared/ast_utils.py @@ -1646,26 +1646,25 @@ def infer_imports(module): """ import cdd.sqlalchemy.utils.parser_utils # Should this be a function param instead? - if isinstance(module, (ClassDef, FunctionDef, AsyncFunctionDef)): + if isinstance(module, (ClassDef, FunctionDef, AsyncFunctionDef, Assign)): module = Module(body=[module], type_ignores=[], stmt=None) assert isinstance(module, Module), "Expected `Module` got `{type_name}`".format( type_name=type(module).__name__ ) - sqlalchemy_classes = filter( - lambda cls_def: any( + sqlalchemy_class_or_assigns = filter( + lambda class_or_assign_def: any( filter( - lambda base: isinstance(base, Name) and base.id == "Base", cls_def.bases + lambda base: isinstance(base, Name) and base.id == "Base", class_or_assign_def.bases ) - ), - filter(rpartial(isinstance, ClassDef), module.body), + ) if isinstance(class_or_assign_def, ClassDef) else isinstance(class_or_assign_def.value, Call) and class_or_assign_def.value.func.id.endswith("Table"), + filter(rpartial(isinstance, (ClassDef, Assign)), module.body), ) - # reduce(, sqlalchemy_classes, set) return list( ( - (cdd.sqlalchemy.utils.parser_utils.imports_from(sqlalchemy_classes),) - if sqlalchemy_classes + (cdd.sqlalchemy.utils.parser_utils.imports_from(sqlalchemy_class_or_assigns),) + if sqlalchemy_class_or_assigns else iter(()) ) ) diff --git a/cdd/sqlalchemy/utils/parser_utils.py b/cdd/sqlalchemy/utils/parser_utils.py index 6f46da8..91ea2e1 100644 --- a/cdd/sqlalchemy/utils/parser_utils.py +++ b/cdd/sqlalchemy/utils/parser_utils.py @@ -255,12 +255,12 @@ def column_call_name_manipulator(call, operation="remove", name=None): return call -def infer_imports_from_sqlalchemy(sqlalchemy_class_def): +def infer_imports_from_sqlalchemy(sqlalchemy_class_or_assigns): """ - Infer imports from SQLalchemy class + Infer imports from SQLalchemy ast - :param sqlalchemy_class_def: SQLalchemy class - :type sqlalchemy_class_def: ```ClassDef``` + :param sqlalchemy_class_or_assigns: SQLalchemy Class or Assign + :type sqlalchemy_class_or_assigns: ```Union[ClassDef, Assign]``` :return: filter of imports (can be considered ```Iterable[str]```) :rtype: ```filter``` @@ -275,7 +275,7 @@ def infer_imports_from_sqlalchemy(sqlalchemy_class_def): body=list( filter( rpartial(isinstance, Call), - ast.walk(sqlalchemy_class_def), + ast.walk(sqlalchemy_class_or_assigns), ) ), type_ignores=[], @@ -297,12 +297,12 @@ def infer_imports_from_sqlalchemy(sqlalchemy_class_def): return candidates_not_in_valid_types ^ candidates -def imports_from(sqlalchemy_classes): +def imports_from(sqlalchemy_asts): """ Generate `from sqlalchemy import <>` from the body of SQLalchemy `class`es - :param sqlalchemy_classes: SQLalchemy `class`es with base class of `Base` - :type sqlalchemy_classes: ```ClassDef``` + :param sqlalchemy_asts: SQLalchemy `class`es with base class of `Base` + :type sqlalchemy_asts: ```ClassDef``` :return: `from sqlalchemy import <>` where <> is what was inferred from `sqlalchemy_classes` :rtype: ```ImportFrom``` @@ -324,7 +324,7 @@ def imports_from(sqlalchemy_classes): chain.from_iterable( map( infer_imports_from_sqlalchemy, - sqlalchemy_classes, + sqlalchemy_asts, ) ), )