diff --git a/data_diff/__main__.py b/data_diff/__main__.py
index 4a358e4c..1e6e22a4 100644
--- a/data_diff/__main__.py
+++ b/data_diff/__main__.py
@@ -36,6 +36,7 @@
 @click.option("-s", "--stats", is_flag=True, help="Print stats instead of a detailed diff")
 @click.option("-d", "--debug", is_flag=True, help="Print debug info")
 @click.option("-v", "--verbose", is_flag=True, help="Print extra info")
+@click.option("-i", "--interactive", is_flag=True, help="Confirm queries, implies --debug")
 def main(
     db1_uri,
     table1_name,
@@ -52,10 +53,14 @@ def main(
     stats,
     debug,
     verbose,
+    interactive,
 ):
     if limit and stats:
         print("Error: cannot specify a limit when using the -s/--stats switch")
         return
+    # The EXPLAIN output shown in interactive mode goes through logging
+    if interactive:
+        debug = True
 
     if debug:
         logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT)
@@ -64,6 +69,11 @@ def main(
 
     db1 = connect_to_uri(db1_uri)
     db2 = connect_to_uri(db2_uri)
+
+    # Have both connections ask for confirmation before running a SELECT
+    if interactive:
+        db1.enable_interactive()
+        db2.enable_interactive()
 
     start = time.time()
 
diff --git a/data_diff/database.py b/data_diff/database.py
index 8d9f058c..5fefd290 100644
--- a/data_diff/database.py
+++ b/data_diff/database.py
@@ -3,8 +3,9 @@
 from typing import Tuple
 
 import dsnparse
+import sys
 
-from .sql import SqlOrStr, Compiler
+from .sql import SqlOrStr, Compiler, Explain, Select
 
 logger = logging.getLogger("database")
 
@@ -65,8 +66,20 @@ def _query(self, sql_code: str) -> list:
 
     def query(self, sql_ast: SqlOrStr, res_type: type):
         "Query the given SQL AST, and attempt to convert the result to type 'res_type'"
-        sql_code = Compiler(self).compile(sql_ast)
+        compiler = Compiler(self)
+        sql_code = compiler.compile(sql_ast)
         logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code)
+        if getattr(self, "_interactive", False) and isinstance(sql_ast, Select):
+            # Interactive mode: show the query plan and let the user veto the
+            # SELECT before it is actually executed.
+            explained_sql = compiler.compile(Explain(sql_ast))
+            logger.info("EXPLAIN for SQL SELECT")
+            logger.info(self._query(explained_sql))
+            answer = input("Continue? [y/n] ")
+            if answer.strip().lower() not in ["y", "yes"]:
+                # Declined: abort the whole run with a non-zero exit status.
+                sys.exit(1)
+
         res = self._query(sql_code)
         if res_type is int:
             res = _one(_one(res))
@@ -81,6 +94,10 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
         else:
             raise ValueError(res_type)
         return res
+
+    def enable_interactive(self):
+        "Ask for confirmation (after logging the EXPLAIN output) before running each Select"
+        self._interactive = True
 
     @abstractmethod
     def quote(self, s: str):
diff --git a/data_diff/sql.py b/data_diff/sql.py
index 94978584..a160824e 100644
--- a/data_diff/sql.py
+++ b/data_diff/sql.py
@@ -146,6 +146,15 @@ def compile(self, c: Compiler):
 @dataclass
 class Time(Sql):
     time: datetime
 
     def compile(self, c: Compiler):
         return "'%s'" % self.time.isoformat()
+
+
+@dataclass
+class Explain(Sql):
+    "Compiles to the wrapped Select prefixed with EXPLAIN"
+    sql: Select
+
+    def compile(self, c: Compiler):
+        return f"EXPLAIN {c.compile(self.sql)}"
diff --git a/tests/test_sql.py b/tests/test_sql.py
index 89e2d658..184df571 100644
--- a/tests/test_sql.py
+++ b/tests/test_sql.py
@@ -1,7 +1,8 @@
 import unittest
 
 from data_diff.database import connect_to_uri
-from data_diff.sql import Checksum, Compare, Compiler, Enum, In, Select, TableName, Count
+from data_diff.sql import (Checksum, Compare, Compiler, Count, Enum, Explain, In,
+                           Select, TableName)
 
 from .common import TEST_MYSQL_CONN_STRING
 
@@ -91,3 +92,13 @@ def test_count_with_column(self):
                 Select([Count("id")], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])])
             ),
         )
+
+    def test_explain(self):
+        expected_sql = "EXPLAIN SELECT count(id) FROM `marine_mammals.walrus` WHERE (id IN (1, 2, 3))"
+        self.assertEqual(expected_sql, self.compiler.compile(
+            Explain(Select(
+                [Count("id")],
+                TableName(("marine_mammals", "walrus")),
+                [In("id", [1, 2, 3])]
+            )))
+        )