Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
@click.option("-s", "--stats", is_flag=True, help="Print stats instead of a detailed diff")
@click.option("-d", "--debug", is_flag=True, help="Print debug info")
@click.option("-v", "--verbose", is_flag=True, help="Print extra info")
@click.option('-i', '--interactive', is_flag=True, help='Confirm queries, implies --debug')
def main(
db1_uri,
table1_name,
Expand All @@ -52,10 +53,13 @@ def main(
stats,
debug,
verbose,
interactive,
):
if limit and stats:
print("Error: cannot specify a limit when using the -s/--stats switch")
return
if interactive:
debug = True

if debug:
logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT)
Expand All @@ -64,6 +68,10 @@ def main(

db1 = connect_to_uri(db1_uri)
db2 = connect_to_uri(db2_uri)

if interactive:
db1.enable_interactive()
db2.enable_interactive()

start = time.time()

Expand Down
17 changes: 15 additions & 2 deletions data_diff/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from typing import Tuple

import dsnparse
import os

from .sql import SqlOrStr, Compiler
from .sql import SqlOrStr, Compiler, Explain, Select


logger = logging.getLogger("database")
Expand Down Expand Up @@ -65,8 +66,17 @@ def _query(self, sql_code: str) -> list:

def query(self, sql_ast: SqlOrStr, res_type: type):
"Query the given SQL AST, and attempt to convert the result to type 'res_type'"
sql_code = Compiler(self).compile(sql_ast)
compiler = Compiler(self)
sql_code = compiler.compile(sql_ast)
logger.debug("Running SQL (%s): %s", type(self).__name__, sql_code)
# Interactive mode: before running a SELECT, show its EXPLAIN output and ask
# the user to confirm. `_interactive` only exists after enable_interactive()
# has been called, hence the getattr() with a False default.
if getattr(self, '_interactive', False) and isinstance(sql_ast, Select):
    explained_sql = compiler.compile(Explain(sql_ast))
    logger.info("EXPLAIN for SQL SELECT")
    logger.info(self._query(explained_sql))
    answer = input("Continue? [y/n] ")
    # Tolerate surrounding whitespace and any letter case in the reply.
    if answer.strip().lower() not in ("y", "yes"):
        # The `os` module has no `exit` function (only `os._exit`), so the
        # previous `os.exit(1)` raised AttributeError instead of exiting.
        # Raising SystemExit is what `sys.exit(1)` does and needs no import.
        raise SystemExit(1)

res = self._query(sql_code)
if res_type is int:
res = _one(_one(res))
Expand All @@ -81,6 +91,9 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
else:
raise ValueError(res_type)
return res

def enable_interactive(self):
    """Turn on interactive mode for this object.

    Sets the ``_interactive`` attribute to ``True``; nothing is returned.
    """
    # The attribute is created here on demand, so readers of this flag
    # should treat a missing attribute as "interactive mode is off".
    self._interactive = True

@abstractmethod
def quote(self, s: str):
Expand Down
13 changes: 13 additions & 0 deletions data_diff/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def compile(self, c: Compiler):
return f"({c.compile(self.expr)} IN ({elems}))"



@dataclass
class Count(Sql):
column: Optional[SqlOrStr] = None
Expand All @@ -146,6 +147,18 @@ def compile(self, c: Compiler):
@dataclass
class Time(Sql):
time: datetime
column: Optional[SqlOrStr] = None

def compile(self, c: Compiler):
return "'%s'" % self.time.isoformat()
if self.column:
return f"count({c.compile(self.column)})"
return 'count(*)'


@dataclass
class Explain(Sql):
    """SQL node that prefixes a wrapped statement with ``EXPLAIN``."""

    # The statement to be explained.
    sql: Select

    def compile(self, c: Compiler):
        """Compile the wrapped statement with *c* and prepend ``EXPLAIN``."""
        inner_sql = c.compile(self.sql)
        return f"EXPLAIN {inner_sql}"
13 changes: 12 additions & 1 deletion tests/test_sql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest

from data_diff.database import connect_to_uri
from data_diff.sql import Checksum, Compare, Compiler, Enum, In, Select, TableName, Count
from data_diff.sql import (Checksum, Compare, Compiler, Count, Enum, Explain, In,
Select, TableName)

from .common import TEST_MYSQL_CONN_STRING

Expand Down Expand Up @@ -91,3 +92,13 @@ def test_count_with_column(self):
Select([Count("id")], TableName(("marine_mammals", "walrus")), [In("id", [1, 2, 3])])
),
)

def test_explain(self):
    # Wrapping a SELECT in Explain should compile to the same SELECT text
    # prefixed with "EXPLAIN ".
    select = Select(
        [Count("id")],
        TableName(("marine_mammals", "walrus")),
        [In("id", [1, 2, 3])],
    )
    compiled = self.compiler.compile(Explain(select))
    self.assertEqual(
        "EXPLAIN SELECT count(id) FROM `marine_mammals.walrus` WHERE (id IN (1, 2, 3))",
        compiled,
    )