Skip to content

Commit

Permalink
[cdd/emit/utils/sqlalchemy_utils.py] Rewrite foreign keys to the corr…
Browse files Browse the repository at this point in the history
…ect type ; [cdd/parse/utils/sqlalchemy_utils.py] Implement `get_pk_and_type` and `get_table_name` helpers for foreign key inference
  • Loading branch information
SamuelMarks committed Jan 9, 2023
1 parent 2de9f50 commit ee58e7f
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 8 deletions.
171 changes: 167 additions & 4 deletions cdd/emit/utils/sqlalchemy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import ast
from ast import (
AST,
Assign,
Attribute,
BinOp,
Call,
ClassDef,
Expr,
FunctionDef,
ImportFrom,
Expand All @@ -27,9 +29,17 @@
from cdd.ast_utils import get_value, maybe_type_comment, set_arg, set_value
from cdd.parse.utils.sqlalchemy_utils import (
column_type2typ,
get_pk_and_type,
get_table_name,
sqlalchemy_top_level_imports,
)
from cdd.pure_utils import none_types, rpartial, tab, upper_camelcase_to_pascal
from cdd.pure_utils import (
find_module_filepath,
none_types,
rpartial,
tab,
upper_camelcase_to_pascal,
)
from cdd.source_transformer import to_code
from cdd.tests.mocks.docstrings import docstring_repr_google_str, docstring_repr_str

Expand Down Expand Up @@ -478,7 +488,7 @@ def update_fk_for_file(filename):
- All SQLalchemy models being in the same directory as filename
- Correct imports being added
Then it can transform:
Then it can transform classes with members like:
```py
Column(
TableName0,
Expand All @@ -488,13 +498,166 @@ def update_fk_for_file(filename):
```
To the following, inferring that the primary key field is `id` by resolving the symbol and `ast.parse`ing it:
```py
Column(Integer, ForeignKey("table_name.id"))
Column(Integer, ForeignKey("table_name0.id"))
```
:param filename: Filename
:type filename: ```str```
"""
raise NotImplementedError("update_fk_for_file" + filename)
with open(filename, "rt") as f:
mod = ast.parse(f.read())

def handle_sqlalchemy_cls(symbol_to_module, sqlalchemy_class_def):
"""
Ensure the SQLalchemy classes have their foreign keys resolved properly
:param symbol_to_module: Dictionary of symbol to module, like `{"join": "os.path"}`
:type symbol_to_module: ```Dict[str,str]````
:param sqlalchemy_class_def: SQLalchemy `class`
:type sqlalchemy_class_def: ```ClassDef```
:return: SQLalchemy with foreign keys resolved properly
:rtype: ```ClassDef```
"""
sqlalchemy_class_def.body = list(
map(
lambda outer_node: rewrite_fk(symbol_to_module, outer_node)
if isinstance(outer_node, Assign)
and isinstance(outer_node.value, Call)
and isinstance(outer_node.value.func, Name)
and outer_node.value.func.id == "Column"
and any(
filter(
lambda node: isinstance(node, Call)
and isinstance(node.func, Name)
and node.func.id == "ForeignKey",
outer_node.value.args,
)
)
else outer_node,
sqlalchemy_class_def.body,
)
)
return sqlalchemy_class_def

symbol2module = dict(
chain.from_iterable(
map(
lambda import_from: map(
lambda _alias: (_alias.name, import_from.module), import_from.names
),
filterfalse(
lambda import_from: import_from.module == "sqlalchemy",
filter(
rpartial(isinstance, ImportFrom),
ast.walk(mod),
),
),
)
)
)

mod.body = list(
map(
lambda node: handle_sqlalchemy_cls(symbol2module, node)
if isinstance(node, ClassDef)
and any(
filter(
lambda base: isinstance(base, Name) and base.id == "Base",
node.bases,
)
)
else node,
mod.body,
)
)

with open(filename, "wt") as f:
f.write(to_code(mod))


def rewrite_fk(symbol_to_module, column_assign):
"""
Rewrite of the form:
```py
column_name = Column(
TableName0,
ForeignKey("TableName0"),
nullable=True,
)
```
To the following, inferring that the primary key field is `id` by resolving the symbol and `ast.parse`ing it:
```py
column_name = Column(Integer, ForeignKey("table_name0.id"))
```
:param symbol_to_module: Dictionary of symbol to module, like `{"join": "os.path"}`
:type symbol_to_module: ```Dict[str,str]````
:param column_assign: `column_name = Column()` in SQLalchemy with unresolved foreign key
:type column_assign: ```Assign```d
:return: `Assign()` in SQLalchemy with resolved foreign key
:rtype: ```Assign```
"""
assert (
isinstance(column_assign.value, Call)
and isinstance(column_assign.value.func, Name)
and column_assign.value.func.id == "Column"
)

def rewrite_fk_from_import(column_name, foreign_key_call):
"""
:param column_name: Field name
:type column_name: ```Name```
:param foreign_key_call: `ForeignKey` function call
:type foreign_key_call: ```Call```
:return:
:rtype: ```Tuple[Name, Call]```
"""
assert isinstance(column_name, Name)
assert (
isinstance(foreign_key_call, Call)
and isinstance(foreign_key_call.func, Name)
and foreign_key_call.func.id == "ForeignKey"
)
if column_name.id in symbol_to_module:
with open(
find_module_filepath(symbol_to_module[column_name.id], column_name.id),
"rt",
) as f:
mod = ast.parse(f.read())
matching_class = next(
filter(
lambda node: isinstance(node, ClassDef)
and node.name == column_name.id,
mod.body,
)
)
pk_typ = get_pk_and_type(matching_class)
assert pk_typ is not None
pk, typ = pk_typ
del pk_typ
return Name(id=typ, ctx=Load()), Call(
func=Name(id="ForeignKey", ctx=Load()),
args=[set_value(".".join((get_table_name(matching_class), pk)))],
keywords=[],
)
return column_name, foreign_key_call

column_assign.value.args = list(
chain.from_iterable(
(
rewrite_fk_from_import(*column_assign.value.args[:2]),
column_assign.value.args[2:],
)
)
)

return column_assign


typ2column_type = {v: k for k, v in column_type2typ.items()}
Expand Down
1 change: 0 additions & 1 deletion cdd/parse/utils/json_schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def transform_ref_fk_set(ref, foreign_key):
"""
entity = pascal_to_upper_camelcase(ref.rpartition("/")[2])
foreign_key["fk"] = entity
print(repr(ref), "->", repr(entity))
return entity

fk = {"fk": None}
Expand Down
87 changes: 86 additions & 1 deletion cdd/parse/utils/sqlalchemy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@
"""

import ast
from ast import Call, Constant, ImportFrom, Load, Module, Name, Str, alias
from ast import (
Assign,
Call,
ClassDef,
Constant,
ImportFrom,
Load,
Module,
Name,
Str,
alias,
)
from itertools import chain, filterfalse
from operator import attrgetter

Expand Down Expand Up @@ -300,6 +311,80 @@ def imports_from(sqlalchemy_classes):
)


def get_pk_and_type(sqlalchemy_class):
"""
Get the primary key and its type from an SQLalchemy class
:param sqlalchemy_class: SQLalchemy class
:type sqlalchemy_class: ```ClassDef```
:return: Primary key and its type
:rtype: ```Tuple[str, str]```
"""
assert isinstance(
sqlalchemy_class, ClassDef
), "Expected `ClassDef` got `{type_name}`".format(
type_name=type(sqlalchemy_class).__name__
)
return (
lambda assign: assign
if assign is None
else (
assign.targets[0].id,
assign.value.args[0].id, # First arg is type
)
)(
next(
filter(
lambda assign: any(
filter(
lambda key_word: key_word.arg == "primary_key"
and get_value(key_word.value) is True,
assign.value.keywords,
)
),
filter(
lambda assign: isinstance(assign.value, Call)
and isinstance(assign.value.func, Name)
and assign.value.func.id == "Column",
filter(rpartial(isinstance, Assign), sqlalchemy_class.body),
),
),
None,
)
)


def get_table_name(sqlalchemy_class):
"""
Get the primary key and its type from an SQLalchemy class
:param sqlalchemy_class: SQLalchemy class
:type sqlalchemy_class: ```ClassDef```
:return: Primary key and its type
:rtype: ```str```
"""
return next(
map(
lambda assign: get_value(assign.value),
filter(
lambda node: next(
filter(lambda target: target.id == "__tablename__", node.targets),
None,
)
and node,
filter(
lambda node: isinstance(node, Assign)
and isinstance(node.value, (Str, Constant)),
sqlalchemy_class.body,
),
),
),
sqlalchemy_class.name,
)


# Construct from https://docs.sqlalchemy.org/en/13/core/type_basics.html#generic-types
column_type2typ = {
"BigInteger": "int",
Expand Down
29 changes: 27 additions & 2 deletions cdd/tests/test_parse/test_parse_sqlalchemy_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
"""
Tests for the utils that is used by the SQLalchemy parsers
"""

from copy import deepcopy
from unittest import TestCase

from cdd.parse.utils.sqlalchemy_utils import (
column_call_name_manipulator,
column_call_to_param,
get_pk_and_type,
get_table_name,
)
from cdd.tests.mocks.ir import intermediate_repr_node_pk
from cdd.tests.mocks.json_schema import config_schema
from cdd.tests.mocks.sqlalchemy import dataset_primary_key_column_assign, node_fk_call
from cdd.tests.mocks.sqlalchemy import (
config_decl_base_ast,
dataset_primary_key_column_assign,
node_fk_call,
)
from cdd.tests.utils_for_tests import unittest_main


Expand Down Expand Up @@ -62,5 +67,25 @@ def test_column_call_to_param_not_implemented(self) -> None:
call.args[2].func.id = "NotFound"
self.assertRaises(NotImplementedError, column_call_to_param, call)

def test_get_pk_and_type(self) -> None:
"""
Tests get_pk_and_type
"""
self.assertEqual(
get_pk_and_type(config_decl_base_ast), ("dataset_name", "String")
)
no_pk = deepcopy(config_decl_base_ast)
del no_pk.body[2]
self.assertIsNone(get_pk_and_type(no_pk))

def test_get_table_name(self) -> None:
"""
Tests `get_table_name`
"""
self.assertEqual(get_table_name(config_decl_base_ast), "config_tbl")
no_table_name = deepcopy(config_decl_base_ast)
del no_table_name.body[1]
self.assertEqual(get_table_name(no_table_name), "Config")


unittest_main()

0 comments on commit ee58e7f

Please sign in to comment.