-
Notifications
You must be signed in to change notification settings - Fork 299
Update column + temporal filtering #19
Changes from all commits
dcc8aaa
23f1cca
6d33dbc
179bccd
d6d0d80
61b120b
f4d75bc
439635b
a766ab7
3be1a08
e0f3d99
a893493
b7aa5e0
507e788
82390d4
2055d7e
42a48c8
5ec06ab
b9688f0
f95bc86
79a440e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,7 @@ | |
|
|
||
| from runtype import dataclass | ||
|
|
||
| from .sql import Select, Checksum, Compare, DbPath, DbKey, Count, Enum, TableName, In, Value | ||
| from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, Enum, TableName, In, Value, Time | ||
| from .database import Database | ||
|
|
||
| logger = logging.getLogger("diff_tables") | ||
|
|
@@ -23,23 +23,36 @@ class TableSegment: | |
| database: Database | ||
| table_path: DbPath | ||
| key_column: str | ||
| extra_columns: Tuple[str, ...] | ||
| start: DbKey = None | ||
| end: DbKey = None | ||
| update_column: str = None | ||
| extra_columns: Tuple[str, ...] = () | ||
| start_key: DbKey = None | ||
| end_key: DbKey = None | ||
| min_time: DbTime = None | ||
| max_time: DbTime = None | ||
|
|
||
| _count: int = None | ||
| _checksum: int = None | ||
|
|
||
| def _make_range_pred(self): | ||
| if self.start is not None: | ||
| yield Compare("<=", str(self.start), self.key_column) | ||
| if self.end is not None: | ||
| yield Compare("<", self.key_column, str(self.end)) | ||
| def __post_init__(self): | ||
| if not self.update_column and (self.min_time or self.max_time): | ||
| raise ValueError("Error: min_time/max_time feature requires to specify 'update_column'") | ||
|
|
||
| def _make_key_range(self): | ||
| if self.start_key is not None: | ||
| yield Compare("<=", str(self.start_key), self.key_column) | ||
| if self.end_key is not None: | ||
| yield Compare("<", self.key_column, str(self.end_key)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the key version will also benefit from a CLI arg 👍🏻 Thanks for changing its name and making it symmetric with time Lots of people will probably instead of temporal just get the id of the last records when they did the last insertion, then do the check based on that
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We already have |
||
|
|
||
| def _make_update_range(self): | ||
| if self.min_time is not None: | ||
| yield Compare("<=", Time(self.min_time), self.update_column) | ||
| if self.max_time is not None: | ||
| yield Compare("<", self.update_column, Time(self.max_time)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A Python question: Why is
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it's just a cleaner syntax, saving me from creating a list, appending to it, and returning it. |
||
|
|
||
| def _make_select(self, *, table=None, columns=None, where=None, group_by=None, order_by=None): | ||
| if columns is None: | ||
| columns = [self.key_column] | ||
| where = list(self._make_range_pred()) + ([] if where is None else [where]) | ||
| where = list(self._make_key_range()) + list(self._make_update_range()) + ([] if where is None else [where]) | ||
| order_by = None if order_by is None else [order_by] | ||
| return Select( | ||
| table=table or TableName(self.table_path), | ||
|
|
@@ -70,16 +83,16 @@ def find_checkpoints(self, checkpoints: List[DbKey]) -> List[DbKey]: | |
| def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment"]: | ||
| "Split the current TableSegment to a bunch of smaller ones, separate by the given checkpoints" | ||
|
|
||
| if self.start and self.end: | ||
| assert all(self.start <= c < self.end for c in checkpoints) | ||
| if self.start_key and self.end_key: | ||
| assert all(self.start_key <= c < self.end_key for c in checkpoints) | ||
| checkpoints.sort() | ||
|
|
||
| # Calculate sub-segments | ||
| positions = [self.start] + checkpoints + [self.end] | ||
| positions = [self.start_key] + checkpoints + [self.end_key] | ||
| ranges = list(zip(positions[:-1], positions[1:])) | ||
|
|
||
| # Create table segments | ||
| tables = [self.new(start=s, end=e) for s, e in ranges] | ||
| tables = [self.new(start_key=s, end_key=e) for s, e in ranges] | ||
|
|
||
| return tables | ||
|
|
||
|
|
@@ -102,7 +115,11 @@ def count(self) -> int: | |
|
|
||
| @property | ||
| def _relevant_columns(self) -> List[str]: | ||
| return [self.key_column] + list(self.extra_columns) | ||
| return ( | ||
| [self.key_column] | ||
| + ([self.update_column] if self.update_column is not None else []) | ||
| + list(self.extra_columns) | ||
| ) | ||
|
|
||
| @property | ||
| def checksum(self) -> int: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,36 +2,110 @@ | |
| import unittest | ||
|
|
||
| import preql | ||
| import arrow # comes with preql | ||
|
|
||
| from data_diff.database import connect_to_uri | ||
| from data_diff.diff_tables import TableDiffer, TableSegment | ||
|
|
||
| from .common import TEST_MYSQL_CONN_STRING, str_to_checksum | ||
|
|
||
|
|
||
| class TestDiffTables(unittest.TestCase): | ||
| class TestWithConnection(unittest.TestCase): | ||
| @classmethod | ||
| def setUpClass(cls): | ||
| # Avoid leaking connections that require waiting for the GC, which can | ||
| # cause deadlocks for table-level modifications. | ||
| cls.preql = preql.Preql(TEST_MYSQL_CONN_STRING) | ||
| cls.connection = connect_to_uri(TEST_MYSQL_CONN_STRING) | ||
|
|
||
| class TestDates(TestWithConnection): | ||
| def setUp(self): | ||
| self.connection.query("DROP TABLE IF EXISTS a", None) | ||
| self.connection.query("DROP TABLE IF EXISTS b", None) | ||
| self.preql(r""" | ||
| table a { | ||
| datetime: datetime | ||
| comment: string | ||
| } | ||
| commit() | ||
| func add(date, comment) { | ||
| new a(date, comment) | ||
| } | ||
| """) | ||
| self.now = now = arrow.get(self.preql.now()) | ||
| self.preql.add(now.shift(days=-50), "50 days ago") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't look related to |
||
| self.preql.add(now.shift(hours=-3), "3 hours ago") | ||
| self.preql.add(now.shift(minutes=-10), "10 mins ago") | ||
| self.preql.add(now.shift(seconds=-1), "1 second ago") | ||
| self.preql.add(now, "now") | ||
|
|
||
| self.preql(r""" | ||
| const table b = a | ||
| commit() | ||
| """) | ||
|
|
||
| self.preql.add(self.now.shift(seconds=-3), "2 seconds ago") | ||
| self.preql.commit() | ||
|
|
||
|
|
||
| def test_init(self): | ||
| a = TableSegment(self.connection, ('a', ), 'id', 'datetime', max_time=self.now.datetime) | ||
| self.assertRaises(ValueError, TableSegment, self.connection, ('a', ), 'id', max_time=self.now.datetime) | ||
|
|
||
| def test_basic(self): | ||
| differ = TableDiffer(10, 100) | ||
| a = TableSegment(self.connection, ('a', ), 'id', 'datetime') | ||
| b = TableSegment(self.connection, ('b', ), 'id', 'datetime') | ||
| assert a.count == 6 | ||
| assert b.count == 5 | ||
|
|
||
| assert not list(differ.diff_tables(a, a)) | ||
| self.assertEqual( len( list(differ.diff_tables(a, b)) ), 1 ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remember to run
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, thanks |
||
|
|
||
| def test_offset(self): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about a test where the diff actually passes when you pass a range?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's what I do at the end of |
||
| differ = TableDiffer(2, 10) | ||
| sec1 = self.now.shift(seconds=-1).datetime | ||
| a = TableSegment(self.connection, ('a', ), 'id', 'datetime', max_time=sec1) | ||
| b = TableSegment(self.connection, ('b', ), 'id', 'datetime', max_time=sec1) | ||
| assert a.count == 4 | ||
| assert b.count == 3 | ||
|
|
||
| assert not list(differ.diff_tables(a, a)) | ||
| self.assertEqual( len( list(differ.diff_tables(a, b)) ), 1 ) | ||
|
|
||
| a = TableSegment(self.connection, ('a', ), 'id', 'datetime', min_time=sec1) | ||
| b = TableSegment(self.connection, ('b', ), 'id', 'datetime', min_time=sec1) | ||
| assert a.count == 2 | ||
| assert b.count == 2 | ||
| assert not list(differ.diff_tables(a, b)) | ||
|
|
||
| day1 = self.now.shift(days=-1).datetime | ||
|
|
||
| a = TableSegment(self.connection, ('a', ), 'id', 'datetime', min_time=day1, max_time=sec1) | ||
| b = TableSegment(self.connection, ('b', ), 'id', 'datetime', min_time=day1, max_time=sec1) | ||
| assert a.count == 3 | ||
| assert b.count == 2 | ||
| assert not list(differ.diff_tables(a, a)) | ||
| self.assertEqual( len( list(differ.diff_tables(a, b)) ), 1) | ||
|
|
||
|
|
||
| class TestDiffTables(TestWithConnection): | ||
|
|
||
| def setUp(self): | ||
| self.connection.query("DROP TABLE IF EXISTS ratings_test", None) | ||
| self.connection.query("DROP TABLE IF EXISTS ratings_test2", None) | ||
| self.preql.load("./tests/setup.pql") | ||
| self.preql.commit() | ||
|
|
||
| self.table = TableSegment(TestDiffTables.connection, | ||
| self.table = TableSegment(self.connection, | ||
| ('ratings_test', ), | ||
| 'id', | ||
| ('timestamp', )) | ||
| 'timestamp') | ||
|
|
||
| self.table2 = TableSegment(TestDiffTables.connection, | ||
| self.table2 = TableSegment(self.connection, | ||
| ("ratings_test2", ), | ||
| 'id', | ||
| ('timestamp', )) | ||
| 'timestamp') | ||
|
|
||
| self.differ = TableDiffer(3, 4) | ||
|
|
||
|
|
||

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍🏻