Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ tables.
- Validate that your replication mechanism is working correctly
- Find changes between two versions of the same table

It uses a bisection algorithm to efficiently check if e.g. a table is the same
between MySQL and Postgres, or Postgres and Snowflake, or MySQL and RDS!
It uses a bisection algorithm and checksums to efficiently check if e.g. a table
is the same between MySQL and Postgres, or Postgres and Snowflake, or MySQL and
RDS!

```python
$ data-diff postgres:/// Original postgres:/// Original_1diff -v --bisection-factor=4
Expand Down Expand Up @@ -164,9 +165,9 @@ Postgres) to avoid incurring the long setup time repeatedly.
```shell-session
preql -f dev/prepare_db.pql postgres://postgres:Password1@127.0.0.1:5432/postgres
preql -f dev/prepare_db.pql mysql://mysql:Password1@127.0.0.1:3306/mysql
preql -f dev/prepare_db.psq snowflake://<uri>
preql -f dev/prepare_db.psq mssql://<uri>
preql -f dev/prepare_db_bigquery.pql bigquery:///<project> # Bigquery has its own
preql -f dev/prepare_db.pql snowflake://<uri>
preql -f dev/prepare_db.pql mssql://<uri>
preql -f dev/prepare_db_bigquery.pql bigquery:///<project> # Bigquery has its own scripts
```

**6. Run data-diff against seeded database**
Expand Down
10 changes: 6 additions & 4 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
@click.argument("table1_name")
@click.argument("db2_uri")
@click.argument("table2_name")
@click.option("-k", "--key_column", default="id", help="Name of primary key column")
@click.option("-c", "--columns", default=["updated_at"], multiple=True, help="Names of extra columns to compare")
@click.option("-k", "--key-column", default="id", help="Name of primary key column")
@click.option("-t", "--update-column", default=None, help="Name of updated_at/last_updated column")
@click.option("-c", "--columns", default=[], multiple=True, help="Names of extra columns to compare")
@click.option("-l", "--limit", default=None, help="Maximum number of differences to find")
@click.option("--bisection-factor", default=32, help="Segments per iteration")
@click.option("--bisection-threshold", default=1024**2, help="Minimal bisection threshold")
Expand All @@ -31,6 +32,7 @@ def main(
db2_uri,
table2_name,
key_column,
update_column,
columns,
limit,
bisection_factor,
Expand All @@ -53,8 +55,8 @@ def main(

start = time.time()

table1 = TableSegment(db1, (table1_name,), key_column, columns)
table2 = TableSegment(db2, (table2_name,), key_column, columns)
table1 = TableSegment(db1, (table1_name,), key_column, update_column, columns)
table2 = TableSegment(db2, (table2_name,), key_column, update_column, columns)

differ = TableDiffer(bisection_factor=bisection_factor, bisection_threshold=bisection_threshold, debug=debug)
diff_iter = differ.diff_tables(table1, table2)
Expand Down
47 changes: 32 additions & 15 deletions data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from runtype import dataclass

from .sql import Select, Checksum, Compare, DbPath, DbKey, Count, Enum, TableName, In, Value
from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, Enum, TableName, In, Value, Time
from .database import Database

logger = logging.getLogger("diff_tables")
Expand All @@ -23,23 +23,36 @@ class TableSegment:
database: Database
table_path: DbPath
key_column: str
extra_columns: Tuple[str, ...]
start: DbKey = None
end: DbKey = None
update_column: str = None
extra_columns: Tuple[str, ...] = ()
start_key: DbKey = None
end_key: DbKey = None
min_time: DbTime = None
max_time: DbTime = None

_count: int = None
_checksum: int = None

def _make_range_pred(self):
if self.start is not None:
yield Compare("<=", str(self.start), self.key_column)
if self.end is not None:
yield Compare("<", self.key_column, str(self.end))
def __post_init__(self):
    """Reject time bounds that were given without an update column.

    Raises:
        ValueError: if `min_time` or `max_time` is truthy while
            `update_column` is falsy (None or empty).
    """
    if self.update_column:
        # Nothing to validate when an update column is set.
        return
    time_bound_given = self.min_time or self.max_time
    if time_bound_given:
        raise ValueError("Error: min_time/max_time feature requires to specify 'update_column'")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏻


def _make_key_range(self):
    """Yield the comparisons that bound the key column of this segment.

    Yields ``Compare("<=", start_key, key_column)`` when `start_key` is set
    and ``Compare("<", key_column, end_key)`` when `end_key` is set. A bound
    that is None contributes nothing, so the generator may be empty.
    """
    lower, upper = self.start_key, self.end_key
    column = self.key_column
    if lower is not None:
        # The key value is handed to Compare as a string.
        yield Compare("<=", str(lower), column)
    if upper is not None:
        # Strict comparison on the end key, unlike the start key above.
        yield Compare("<", column, str(upper))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the key version will also benefit from a CLI arg 👍🏻 Thanks for changing its name and making it symmetric with time

Instead of a time range, lots of people will probably just record the id of the last row from their previous insertion, and then run the check based on that

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have -k. Do you mean something different?


def _make_update_range(self):
    """Yield the comparisons that bound the update column by time.

    Yields ``Compare("<=", Time(min_time), update_column)`` when `min_time`
    is set and ``Compare("<", update_column, Time(max_time))`` when
    `max_time` is set. A bound that is None contributes nothing, so the
    generator may be empty.
    """
    earliest, latest = self.min_time, self.max_time
    column = self.update_column
    if earliest is not None:
        # Wrap the raw value in Time so it is compiled as a SQL time literal.
        yield Compare("<=", Time(earliest), column)
    if latest is not None:
        # Strict comparison on the maximum, unlike the minimum above.
        yield Compare("<", column, Time(latest))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A Python question: Why is yield preferred here to just returning a list? Because it's annoying to merge the lists I'm assuming, versus adding generators together and then materializing them to a list like you're doing below? Kinda cool pattern.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's just a cleaner syntax, saving me from creating a list, appending to it, and returning it.


def _make_select(self, *, table=None, columns=None, where=None, group_by=None, order_by=None):
if columns is None:
columns = [self.key_column]
where = list(self._make_range_pred()) + ([] if where is None else [where])
where = list(self._make_key_range()) + list(self._make_update_range()) + ([] if where is None else [where])
order_by = None if order_by is None else [order_by]
return Select(
table=table or TableName(self.table_path),
Expand Down Expand Up @@ -70,16 +83,16 @@ def find_checkpoints(self, checkpoints: List[DbKey]) -> List[DbKey]:
def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment"]:
"Split the current TableSegment to a bunch of smaller ones, separate by the given checkpoints"

if self.start and self.end:
assert all(self.start <= c < self.end for c in checkpoints)
if self.start_key and self.end_key:
assert all(self.start_key <= c < self.end_key for c in checkpoints)
checkpoints.sort()

# Calculate sub-segments
positions = [self.start] + checkpoints + [self.end]
positions = [self.start_key] + checkpoints + [self.end_key]
ranges = list(zip(positions[:-1], positions[1:]))

# Create table segments
tables = [self.new(start=s, end=e) for s, e in ranges]
tables = [self.new(start_key=s, end_key=e) for s, e in ranges]

return tables

Expand All @@ -102,7 +115,11 @@ def count(self) -> int:

@property
def _relevant_columns(self) -> List[str]:
return [self.key_column] + list(self.extra_columns)
return (
[self.key_column]
+ ([self.update_column] if self.update_column is not None else [])
+ list(self.extra_columns)
)

@property
def checksum(self) -> int:
Expand Down
10 changes: 10 additions & 0 deletions data_diff/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
"""

from typing import List, Union, Tuple, Optional
from datetime import datetime

from runtype import dataclass

DbPath = Tuple[str, ...]
DbKey = Union[int, str, bytes]
DbTime = datetime


class Sql:
Expand Down Expand Up @@ -139,3 +141,11 @@ def compile(self, c: Compiler):
if self.column:
return f"count({c.compile(self.column)})"
return "count(*)"


@dataclass
class Time(Sql):
    """SQL node wrapping a single datetime value."""

    time: datetime

    def compile(self, c: Compiler):
        """Return the wrapped datetime as a single-quoted ISO 8601 string.

        The compiler argument `c` is accepted but not used here.
        """
        iso_text = self.time.isoformat()
        return f"'{iso_text}'"
86 changes: 80 additions & 6 deletions tests/test_diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,110 @@
import unittest

import preql
import arrow # comes with preql

from data_diff.database import connect_to_uri
from data_diff.diff_tables import TableDiffer, TableSegment

from .common import TEST_MYSQL_CONN_STRING, str_to_checksum


class TestDiffTables(unittest.TestCase):
class TestWithConnection(unittest.TestCase):
    """Base test case that opens one preql interpreter and one database connection per class."""

    @classmethod
    def setUpClass(cls):
        # Open the connections once per class. Leaked connections are only
        # released when the GC runs, and waiting for that can deadlock
        # table-level modifications.
        conn_string = TEST_MYSQL_CONN_STRING
        cls.preql = preql.Preql(conn_string)
        cls.connection = connect_to_uri(conn_string)

class TestDates(TestWithConnection):
def setUp(self):
self.connection.query("DROP TABLE IF EXISTS a", None)
self.connection.query("DROP TABLE IF EXISTS b", None)
self.preql(r"""
table a {
datetime: datetime
comment: string
}
commit()
func add(date, comment) {
new a(date, comment)
}
""")
self.now = now = arrow.get(self.preql.now())
self.preql.add(now.shift(days=-50), "50 days ago")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I try to run your tests, this line errors (and presumably the others will too):

CleanShot 2022-05-11 at 15 45 02@2x

Would love if preql gave some more useful output here 😅

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't look related to preql in any way. What is Python.Framework? Never heard of it.

self.preql.add(now.shift(hours=-3), "3 hours ago")
self.preql.add(now.shift(minutes=-10), "10 mins ago")
self.preql.add(now.shift(seconds=-1), "1 second ago")
self.preql.add(now, "now")

self.preql(r"""
const table b = a
commit()
""")

self.preql.add(self.now.shift(seconds=-3), "2 seconds ago")
self.preql.commit()


def test_init(self):
    """A time bound is accepted with an update column and rejected without one."""
    table_path = ("a",)
    upper_bound = self.now.datetime

    # With 'datetime' as the update column, construction must not raise.
    TableSegment(self.connection, table_path, "id", "datetime", max_time=upper_bound)

    # The same bound without an update column must raise ValueError.
    with self.assertRaises(ValueError):
        TableSegment(self.connection, table_path, "id", max_time=upper_bound)

def test_basic(self):
    """Diffing `a` against itself finds nothing; `a` against `b` finds one difference."""
    differ = TableDiffer(10, 100)
    a = TableSegment(self.connection, ("a",), "id", "datetime")
    b = TableSegment(self.connection, ("b",), "id", "datetime")

    # Use unittest assertions rather than bare `assert`: they still run under
    # `python -O` and report both values on failure.
    self.assertEqual(a.count, 6)
    self.assertEqual(b.count, 5)

    self.assertEqual(list(differ.diff_tables(a, a)), [])
    self.assertEqual(len(list(differ.diff_tables(a, b))), 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remember to run black(1), which'll probably complain about the spaces here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, thanks


def test_offset(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about a test where the diff actually passes when you pass a range?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what I do at the end of test_offset, unless you mean something else

differ = TableDiffer(2, 10)
sec1 = self.now.shift(seconds=-1).datetime
a = TableSegment(self.connection, ('a', ), 'id', 'datetime', max_time=sec1)
b = TableSegment(self.connection, ('b', ), 'id', 'datetime', max_time=sec1)
assert a.count == 4
assert b.count == 3

assert not list(differ.diff_tables(a, a))
self.assertEqual( len( list(differ.diff_tables(a, b)) ), 1 )

a = TableSegment(self.connection, ('a', ), 'id', 'datetime', min_time=sec1)
b = TableSegment(self.connection, ('b', ), 'id', 'datetime', min_time=sec1)
assert a.count == 2
assert b.count == 2
assert not list(differ.diff_tables(a, b))

day1 = self.now.shift(days=-1).datetime

a = TableSegment(self.connection, ('a', ), 'id', 'datetime', min_time=day1, max_time=sec1)
b = TableSegment(self.connection, ('b', ), 'id', 'datetime', min_time=day1, max_time=sec1)
assert a.count == 3
assert b.count == 2
assert not list(differ.diff_tables(a, a))
self.assertEqual( len( list(differ.diff_tables(a, b)) ), 1)


class TestDiffTables(TestWithConnection):

def setUp(self):
self.connection.query("DROP TABLE IF EXISTS ratings_test", None)
self.connection.query("DROP TABLE IF EXISTS ratings_test2", None)
self.preql.load("./tests/setup.pql")
self.preql.commit()

self.table = TableSegment(TestDiffTables.connection,
self.table = TableSegment(self.connection,
('ratings_test', ),
'id',
('timestamp', ))
'timestamp')

self.table2 = TableSegment(TestDiffTables.connection,
self.table2 = TableSegment(self.connection,
("ratings_test2", ),
'id',
('timestamp', ))
'timestamp')

self.differ = TableDiffer(3, 4)

Expand Down