From f8d24ea09b819774000858d2d35a8baf6fa16405 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 27 Oct 2022 16:36:57 -0300 Subject: [PATCH 1/2] Tests: one more data-type for oracle --- tests/test_database_types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_database_types.py b/tests/test_database_types.py index bb792826..1c36ff7f 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -178,6 +178,7 @@ def init_conns(): "numeric", "real", "double precision", + "Number(5, 2)", ], "uuid": [ "CHAR(100)", From 5e879faba934be56dcbaae99d3538bcab34b2390 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 28 Oct 2022 10:40:21 -0300 Subject: [PATCH 2/2] Queries: Added Param mechanism, to help speed up query construction. --- data_diff/queries/ast_classes.py | 17 ++++++++++++++++- data_diff/queries/compiler.py | 10 +++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 88d7ab11..f363df14 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -6,7 +6,7 @@ from data_diff.utils import ArithString, join_iter -from .compiler import Compilable, Compiler +from .compiler import Compilable, Compiler, cv_params from .base import SKIP, CompileError, DbPath, Schema, args_as_tuple @@ -691,3 +691,18 @@ def compile(self, c: Compiler) -> str: class Commit(Statement): def compile(self, c: Compiler) -> str: return "COMMIT" if not c.database.is_autocommit else SKIP + +@dataclass +class Param(ExprNode, ITable): + """A value placeholder, to be specified at compilation time using the `cv_params` context variable.""" + + name: str + + @property + def source_table(self): + return self + + def compile(self, c: Compiler) -> str: + params = cv_params.get() + return c._compile(params[self.name]) + diff --git a/data_diff/queries/compiler.py b/data_diff/queries/compiler.py index 31242131..e9a66bed 100644 --- a/data_diff/queries/compiler.py +++ b/data_diff/queries/compiler.py @@ -8,10 +8,15 @@ from data_diff.utils import ArithString from data_diff.databases.database_types import AbstractDialect, DbPath +import contextvars + +cv_params = contextvars.ContextVar("params") + @dataclass class Compiler: database: AbstractDialect + params: dict = {} in_select: bool = False # Compilation runtime flag in_join: bool = False # Compilation runtime flag @@ -21,7 +26,10 @@ class Compiler: _counter: List = [0] - def compile(self, elem) -> str: + def compile(self, elem, params=None) -> str: + if params: + cv_params.set(params) + res = self._compile(elem) if self.root and self._subqueries: subq = ", ".join(f"\n {k} AS ({v})" for k, v in self._subqueries.items())