diff --git a/README.md b/README.md index 8e061a10..ec5fd77e 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,9 @@ tables. - Validate that your replication mechnism is working correctly - Find changes between two versions of the same table -It uses a bisection algorithm to efficiently check if e.g. a table is the same -between MySQL and Postgres, or Postgres and Snowflake, or MySQL and RDS! +It uses a bisection algorithm and checksums to efficiently check if e.g. a table +is the same between MySQL and Postgres, or Postgres and Snowflake, or MySQL and +RDS! ```python $ data-diff postgres:/// Original postgres:/// Original_1diff -v --bisection-factor=4 @@ -164,9 +165,9 @@ Postgres) to avoid incurring the long setup time repeatedly. ```shell-session preql -f dev/prepare_db.pql postgres://postgres:Password1@127.0.0.1:5432/postgres preql -f dev/prepare_db.pql mysql://mysql:Password1@127.0.0.1:3306/mysql -preql -f dev/prepare_db.psq snowflake:// -preql -f dev/prepare_db.psq mssql:// -preql -f dev/prepare_db_bigquery.pql bigquery:/// # Bigquery has its own +preql -f dev/prepare_db.pql snowflake:// +preql -f dev/prepare_db.pql mssql:// +preql -f dev/prepare_db_bigquery.pql bigquery:/// # Bigquery has its own scripts ``` **6. Run data-diff against seeded database** diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 9f68bdda..57a76752 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -17,8 +17,9 @@ @click.argument("table1_name") @click.argument("db2_uri") @click.argument("table2_name") -@click.option("-k", "--key_column", default="id", help="Name of primary key column") -@click.option("-c", "--columns", default=["updated_at"], multiple=True, help="Names of extra columns to compare") +@click.option("-k", "--key-column", default="id", help="Name of primary key column") +@click.option("-t", "--update-column", default=None, help="Name of updated_at/last_updated column") +@click.option("-c", "--columns", default=[], multiple=True, help="Names of extra columns to compare") @click.option("-l", "--limit", default=None, help="Maximum number of differences to find") @click.option("--bisection-factor", default=32, help="Segments per iteration") @click.option("--bisection-threshold", default=1024**2, help="Minimal bisection threshold") @@ -31,6 +32,7 @@ def main( db2_uri, table2_name, key_column, + update_column, columns, limit, bisection_factor, @@ -53,8 +55,8 @@ def main( start = time.time() - table1 = TableSegment(db1, (table1_name,), key_column, columns) - table2 = TableSegment(db2, (table2_name,), key_column, columns) + table1 = TableSegment(db1, (table1_name,), key_column, update_column, columns) + table2 = TableSegment(db2, (table2_name,), key_column, update_column, columns) differ = TableDiffer(bisection_factor=bisection_factor, bisection_threshold=bisection_threshold, debug=debug) diff_iter = differ.diff_tables(table1, table2) diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 53053826..b372b661 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -6,7 +6,7 @@ from runtype import dataclass -from .sql import Select, Checksum, Compare, DbPath, DbKey, Count, Enum, TableName, In, Value +from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, Enum, TableName, In, Value, Time from .database import Database logger = logging.getLogger("diff_tables") @@ -23,23 +23,36 @@ class TableSegment: database: Database table_path: DbPath key_column: str - extra_columns: Tuple[str, ...] - start: DbKey = None - end: DbKey = None + update_column: str = None + extra_columns: Tuple[str, ...] = () + start_key: DbKey = None + end_key: DbKey = None + min_time: DbTime = None + max_time: DbTime = None _count: int = None _checksum: int = None - def _make_range_pred(self): - if self.start is not None: - yield Compare("<=", str(self.start), self.key_column) - if self.end is not None: - yield Compare("<", self.key_column, str(self.end)) + def __post_init__(self): + if not self.update_column and (self.min_time or self.max_time): + raise ValueError("Error: min_time/max_time feature requires to specify 'update_column'") + + def _make_key_range(self): + if self.start_key is not None: + yield Compare("<=", str(self.start_key), self.key_column) + if self.end_key is not None: + yield Compare("<", self.key_column, str(self.end_key)) + + def _make_update_range(self): + if self.min_time is not None: + yield Compare("<=", Time(self.min_time), self.update_column) + if self.max_time is not None: + yield Compare("<", self.update_column, Time(self.max_time)) def _make_select(self, *, table=None, columns=None, where=None, group_by=None, order_by=None): if columns is None: columns = [self.key_column] - where = list(self._make_range_pred()) + ([] if where is None else [where]) + where = list(self._make_key_range()) + list(self._make_update_range()) + ([] if where is None else [where]) order_by = None if order_by is None else [order_by] return Select( table=table or TableName(self.table_path), @@ -70,16 +83,16 @@ def find_checkpoints(self, checkpoints: List[DbKey]) -> List[DbKey]: def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment"]: "Split the current TableSegment to a bunch of smaller ones, separate by the given checkpoints" - if self.start and self.end: - assert all(self.start <= c < self.end for c in checkpoints) + if self.start_key and self.end_key: + assert all(self.start_key <= c < self.end_key for c in checkpoints) checkpoints.sort() # Calculate sub-segments - positions = [self.start] + checkpoints + [self.end] + positions = [self.start_key] + checkpoints + [self.end_key] ranges = list(zip(positions[:-1], positions[1:])) # Create table segments - tables = [self.new(start=s, end=e) for s, e in ranges] + tables = [self.new(start_key=s, end_key=e) for s, e in ranges] return tables @@ -102,7 +115,11 @@ def count(self) -> int: @property def _relevant_columns(self) -> List[str]: - return [self.key_column] + list(self.extra_columns) + return ( + [self.key_column] + + ([self.update_column] if self.update_column is not None else []) + + list(self.extra_columns) + ) @property def checksum(self) -> int: diff --git a/data_diff/sql.py b/data_diff/sql.py index 3699b27a..94978584 100644 --- a/data_diff/sql.py +++ b/data_diff/sql.py @@ -2,11 +2,13 @@ """ from typing import List, Union, Tuple, Optional +from datetime import datetime from runtype import dataclass DbPath = Tuple[str, ...] DbKey = Union[int, str, bytes] +DbTime = datetime class Sql: @@ -139,3 +141,11 @@ def compile(self, c: Compiler): if self.column: return f"count({c.compile(self.column)})" return "count(*)" + + +@dataclass +class Time(Sql): + time: datetime + + def compile(self, c: Compiler): + return "'%s'" % self.time.isoformat() diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index eb0bb6fe..9b340890 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -2,14 +2,14 @@ import unittest import preql +import arrow # comes with preql from data_diff.database import connect_to_uri from data_diff.diff_tables import TableDiffer, TableSegment from .common import TEST_MYSQL_CONN_STRING, str_to_checksum - -class TestDiffTables(unittest.TestCase): +class TestWithConnection(unittest.TestCase): @classmethod def setUpClass(cls): # Avoid leaking connections that require waiting for the GC, which can @@ -17,21 +17,95 @@ def setUpClass(cls): cls.preql = preql.Preql(TEST_MYSQL_CONN_STRING) cls.connection = connect_to_uri(TEST_MYSQL_CONN_STRING) +class TestDates(TestWithConnection): + def setUp(self): + self.connection.query("DROP TABLE IF EXISTS a", None) + self.connection.query("DROP TABLE IF EXISTS b", None) + self.preql(r""" + table a { + datetime: datetime + comment: string + } + commit() + + func add(date, comment) { + new a(date, comment) + } + """) + self.now = now = arrow.get(self.preql.now()) + self.preql.add(now.shift(days=-50), "50 days ago") + self.preql.add(now.shift(hours=-3), "3 hours ago") + self.preql.add(now.shift(minutes=-10), "10 mins ago") + self.preql.add(now.shift(seconds=-1), "1 second ago") + self.preql.add(now, "now") + + self.preql(r""" + const table b = a + commit() + """) + + self.preql.add(self.now.shift(seconds=-3), "2 seconds ago") + self.preql.commit() + + + def test_init(self): + a = TableSegment(self.connection, ('a', ), 'id', 'datetime', max_time=self.now.datetime) + self.assertRaises(ValueError, TableSegment, self.connection, ('a', ), 'id', max_time=self.now.datetime) + + def test_basic(self): + differ = TableDiffer(10, 100) + a = TableSegment(self.connection, ('a', ), 'id', 'datetime') + b = TableSegment(self.connection, ('b', ), 'id', 'datetime') + assert a.count == 6 + assert b.count == 5 + + assert not list(differ.diff_tables(a, a)) + self.assertEqual( len( list(differ.diff_tables(a, b)) ), 1 ) + + def test_offset(self): + differ = TableDiffer(2, 10) + sec1 = self.now.shift(seconds=-1).datetime + a = TableSegment(self.connection, ('a', ), 'id', 'datetime', max_time=sec1) + b = TableSegment(self.connection, ('b', ), 'id', 'datetime', max_time=sec1) + assert a.count == 4 + assert b.count == 3 + + assert not list(differ.diff_tables(a, a)) + self.assertEqual( len( list(differ.diff_tables(a, b)) ), 1 ) + + a = TableSegment(self.connection, ('a', ), 'id', 'datetime', min_time=sec1) + b = TableSegment(self.connection, ('b', ), 'id', 'datetime', min_time=sec1) + assert a.count == 2 + assert b.count == 2 + assert not list(differ.diff_tables(a, b)) + + day1 = self.now.shift(days=-1).datetime + + a = TableSegment(self.connection, ('a', ), 'id', 'datetime', min_time=day1, max_time=sec1) + b = TableSegment(self.connection, ('b', ), 'id', 'datetime', min_time=day1, max_time=sec1) + assert a.count == 3 + assert b.count == 2 + assert not list(differ.diff_tables(a, a)) + self.assertEqual( len( list(differ.diff_tables(a, b)) ), 1) + + +class TestDiffTables(TestWithConnection): + def setUp(self): self.connection.query("DROP TABLE IF EXISTS ratings_test", None) self.connection.query("DROP TABLE IF EXISTS ratings_test2", None) self.preql.load("./tests/setup.pql") self.preql.commit() - self.table = TableSegment(TestDiffTables.connection, + self.table = TableSegment(self.connection, ('ratings_test', ), 'id', - ('timestamp', )) + 'timestamp') - self.table2 = TableSegment(TestDiffTables.connection, + self.table2 = TableSegment(self.connection, ("ratings_test2", ), 'id', - ('timestamp', )) + 'timestamp') self.differ = TableDiffer(3, 4)