diff --git a/link/crdt/core.py b/link/crdt/core.py index a6059c1..fe0b3ac 100644 --- a/link/crdt/core.py +++ b/link/crdt/core.py @@ -8,6 +8,10 @@ class CRDT(object): _type_err_msg = 'Invalid value type' + @classmethod + def _check_same_value(cls, a, b): + return a._value == b._value + @classmethod def _assert_mergeable(cls, a, b): if not isinstance(a, cls) and not isinstance(b, cls): @@ -15,8 +19,13 @@ def _assert_mergeable(cls, a, b): 'Supplied arguments are not {0}'.format(cls.__name__) ) + if not cls._check_same_value(a, b): + raise ValueError( + 'Supplied arguments does not have the same initial value' + ) + @classmethod - def merge(cls, a, b): + def merge(cls, a, b, context=None): raise NotImplementedError() def __init__(self, value=None, context=None): @@ -45,7 +54,12 @@ def _update_vclock(self): def _post_init(self): pass - def _check_type(self, value): + @classmethod + def _match_py_type(cls, pytype): + return cls._py_type in pytype.mro() + + @classmethod + def _check_type(cls, value): raise NotImplementedError() def _assert_type(self, value): diff --git a/link/crdt/counter.py b/link/crdt/counter.py index 55791d6..c3bd2eb 100644 --- a/link/crdt/counter.py +++ b/link/crdt/counter.py @@ -10,10 +10,10 @@ class Counter(CRDT): _type_err_msg = 'Counters can only be integers' @classmethod - def merge(cls, a, b): + def merge(cls, a, b, context=None): cls._assert_mergeable(a, b) - crdt = cls() + crdt = cls(value=a._value, context=context) crdt._increment = a._increment + b._increment crdt._vclock = max(a._vclock, b._vclock) crdt._update_vclock() @@ -25,8 +25,13 @@ def _post_init(self): def _default_value(self): return 0 - def _check_type(self, value): - return isinstance(value, self._py_type) + @classmethod + def _match_py_type(cls, pytype): + return cls._py_type in pytype.mro() and bool not in pytype.mro() + + @classmethod + def _check_type(cls, value): + return isinstance(value, cls._py_type) and not isinstance(value, bool) def increment(self, amount=1): 
self._assert_type(amount) diff --git a/link/crdt/diff.py b/link/crdt/diff.py index 41f3761..24dd33e 100644 --- a/link/crdt/diff.py +++ b/link/crdt/diff.py @@ -57,11 +57,15 @@ def __call__(self, a, b): amro = set(a.__class__.mro()) bmro = set(b.__class__.mro()) - if not (amro.issubset(bmro) or bmro.issubset(amro)): + if any([ + not (amro.issubset(bmro) or bmro.issubset(amro)), + isinstance(a, bool) and not isinstance(b, bool), + not isinstance(a, bool) and isinstance(b, bool) + ]): raise TypeError('Supplied arguments must be of the same type') for crdt_type in TYPES.values(): - if isinstance(a, crdt_type._py_type): + if crdt_type._check_type(a): crdt = crdt_type(value=a) method = getattr(self, crdt_type._type_name) diff --git a/link/crdt/flag.py b/link/crdt/flag.py index 66723cc..8fe0b3f 100644 --- a/link/crdt/flag.py +++ b/link/crdt/flag.py @@ -10,10 +10,10 @@ class Flag(CRDT): _type_err_msg = 'Flags can only be booleans' @classmethod - def merge(cls, a, b): + def merge(cls, a, b, context=None): cls._assert_mergeable(a, b) - crdt = cls() + crdt = cls(value=a._value, context=context) crdt._mutation = a._mutation if a._vclock >= b._vclock else b._mutation crdt._vclock = max(a._vclock, b._vclock) crdt._update_vclock() @@ -25,8 +25,9 @@ def _post_init(self): def _default_value(self): return False - def _check_type(self, value): - return isinstance(value, self._py_type) + @classmethod + def _check_type(cls, value): + return isinstance(value, cls._py_type) def enable(self): self._mutation = 'enable' diff --git a/link/crdt/map.py b/link/crdt/map.py index 0c26156..dfcb58d 100644 --- a/link/crdt/map.py +++ b/link/crdt/map.py @@ -16,26 +16,56 @@ class Map(Mapping, CRDT): _type_err_msg = 'Map must be a dict with keys ending with "_{datatype}"' @classmethod - def merge(cls, a, b): + def _check_same_value(cls, a, b): + for key in a._value: + if key in b._value: + aval = a._value[key] + bval = b._value[key] + + if not aval._check_same_value(aval, bval): + return False + + else: 
+ return False + + return True + + @classmethod + def merge(cls, a, b, context=None): cls._assert_mergeable(a, b) - crdt = cls() + crdt = cls(value=a._compute_value(), context=context) crdt._removes = a._removes.union(b._removes) + # merge a and b values + for key in a._value: + if key in b._value: + suba = a._value[key] + subb = b._value[key] + + crdt._updates[key] = suba.merge(suba, subb, context=crdt) + del crdt._value[key] + + # complete with keys of b that are missing from a + for key in b._value: + if key not in a._value: + crdt._updates[key] = b._value[key] + + # merge a and b updates for key in a._updates: if key not in b._updates: - crdt._updates[key] = a._updates + crdt._updates[key] = a._updates[key] else: suba = a._updates[key] subb = b._updates[key] - crdt._updates[key] = suba.merge(suba, subb) + crdt._updates[key] = suba.merge(suba, subb, context=crdt) # complete with missing keys from a for key in b._updates: if key not in a._updates: - crdt._updates[key] = a._updates + crdt._updates[key] = a._updates[key] crdt._vclock = max(a._vclock, b._vclock) crdt._update_vclock() @@ -49,7 +79,8 @@ def _post_init(self): def _default_value(self): return dict() - def _check_key(self, key): + @classmethod + def _check_key(cls, key): for typename in TYPES: suffix = '_{0}'.format(typename) @@ -62,13 +93,14 @@ def _get_key_type(self, key): datatype = key.rsplit('_', 1)[1] return TYPES[datatype] - def _check_type(self, value): - if not isinstance(value, self._py_type): - raise TypeError(self._type_err_msg) + @classmethod + def _check_type(cls, value): + if not isinstance(value, cls._py_type): + return False for key in value: try: - self._check_key(key) + cls._check_key(key) except TypeError: return False @@ -105,7 +137,7 @@ def __iter__(self): return iter(self.current) def isdirty(self): - return self._removes and self._updates + return bool(self._removes) or bool(self._updates) def _coerce_value(self, value): cvalue = {} @@ -153,6 +185,12 @@ def current(self): return cvalue + def
_compute_value(self): + return { + key: v.current if not isinstance(v, Map) else v._compute_value() + for key, v in self._value.items() + } + TYPES = { cls._type_name: cls diff --git a/link/crdt/register.py b/link/crdt/register.py index 8e52bb9..1422c3b 100644 --- a/link/crdt/register.py +++ b/link/crdt/register.py @@ -12,10 +12,10 @@ class Register(CRDT): _type_err_msg = 'Registers can only be strings' @classmethod - def merge(cls, a, b): + def merge(cls, a, b, context=None): cls._assert_mergeable(a, b) - crdt = cls() + crdt = cls(value=a._value, context=context) crdt._new = a._new if a._vclock >= b._vclock else b._new crdt._vclock = max(a._vclock, b._vclock) crdt._update_vclock() @@ -27,8 +27,16 @@ def _post_init(self): def _default_value(self): return '' - def _check_type(self, value): - return isinstance(value, self._py_type) + @classmethod + def _match_py_type(cls, pytype): + return any([ + string_type in pytype.mro() + for string_type in string_types + ]) + + @classmethod + def _check_type(cls, value): + return isinstance(value, cls._py_type) def assign(self, value): self._assert_type(value) diff --git a/link/crdt/set.py b/link/crdt/set.py index e57350e..d94e62a 100644 --- a/link/crdt/set.py +++ b/link/crdt/set.py @@ -13,10 +13,10 @@ class Set(collections.Set, CRDT): _type_err_msg = 'Sets can only be set of strings' @classmethod - def merge(cls, a, b): + def merge(cls, a, b, context=None): cls._assert_mergeable(a, b) - crdt = cls() + crdt = cls(value=a._value, context=context) crdt._adds = a._adds.union(b._adds) crdt._removes = a._removes.union(b._removes) crdt._vclock = max(a._vclock, b._vclock) @@ -34,8 +34,9 @@ def _check_element(self, element): if not isinstance(element, string_types): raise TypeError('Set elements can only be strings') - def _check_type(self, value): - if not isinstance(value, self._py_type): + @classmethod + def _check_type(cls, value): + if not isinstance(value, cls._py_type): return False for element in value: @@ -71,10 +72,10 @@ def 
isdirty(self): def mutation(self): return { - 'adds': list(self._adds), - 'removes': list(self._removes) + 'adds': set(self._adds), + 'removes': set(self._removes) } @CRDT.current.getter def current(self): - return set(self._value) + self._adds - self._removes + return set(self._value).union(self._adds).difference(self._removes) diff --git a/link/crdt/test/__init__.py b/link/crdt/test/__init__.py new file mode 100644 index 0000000..40a96af --- /dev/null +++ b/link/crdt/test/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/link/crdt/test/base.py b/link/crdt/test/base.py new file mode 100644 index 0000000..e1ed2bf --- /dev/null +++ b/link/crdt/test/base.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from unittest import TestCase +from six import PY2, PY3 + + +class UTCase(TestCase): + def assertItemsEqual(self, a, b): + if PY2: + return super(UTCase, self).assertItemsEqual(a, b) + + elif PY3: + return self.assertCountEqual(a, b) diff --git a/link/crdt/test/counter.py b/link/crdt/test/counter.py new file mode 100644 index 0000000..d7801bf --- /dev/null +++ b/link/crdt/test/counter.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +from link.crdt.test.base import UTCase +from unittest import main + +from link.crdt.counter import Counter + + +class TestCounter(UTCase): + def test_increment(self): + crdt = Counter(value=5) + self.assertEqual(crdt.current, 5) + + crdt.increment() + + self.assertEqual(crdt.current, 6) + self.assertEqual(crdt._increment, 1) + self.assertEqual(crdt._vclock, 1) + self.assertTrue(crdt.isdirty()) + self.assertEqual(crdt.mutation(), {'increment': 1}) + + crdt.decrement() + + self.assertEqual(crdt.current, 5) + self.assertEqual(crdt._increment, 0) + self.assertEqual(crdt._vclock, 2) + self.assertFalse(crdt.isdirty()) + self.assertEqual(crdt.mutation(), {'increment': 0}) + + def test_merge(self): + a = Counter(value=5) + b = Counter(value=5) + + b.increment(3) + b.increment(2) + a.increment(2) + + c = Counter.merge(a, b) + + 
self.assertEqual(c.current, 12) + self.assertEqual(c._increment, 7) + self.assertEqual(c._vclock, 3) + self.assertTrue(c.isdirty()) + self.assertEqual(c.mutation(), {'increment': 7}) + + def test_fail_merge(self): + a = Counter(value=5) + b = Counter(value=7) + + with self.assertRaises(ValueError): + Counter.merge(a, b) + + def test_fail_type(self): + with self.assertRaises(TypeError): + Counter(value='not int') + + +if __name__ == '__main__': + main() diff --git a/link/crdt/test/diff.py b/link/crdt/test/diff.py new file mode 100644 index 0000000..b3eefb6 --- /dev/null +++ b/link/crdt/test/diff.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- + +from link.crdt.test.base import UTCase +from unittest import main + +from link.crdt.diff import crdt_diff + + +class TestDiff(UTCase): + def test_diff_counter(self): + crdt = crdt_diff(5, 6) + + self.assertEqual(crdt._value, 5) + self.assertEqual(crdt._increment, 1) + self.assertEqual(crdt.current, 6) + + def test_diff_flag(self): + crdt = crdt_diff(False, True) + + self.assertEqual(crdt._value, False) + self.assertEqual(crdt.current, True) + self.assertEqual(crdt._mutation, 'enable') + + crdt = crdt_diff(True, False) + + self.assertEqual(crdt._value, True) + self.assertEqual(crdt.current, False) + self.assertEqual(crdt._mutation, 'disable') + + def test_diff_register(self): + crdt = crdt_diff('test', 'test2') + + self.assertEqual(crdt._value, 'test') + self.assertEqual(crdt.current, 'test2') + self.assertEqual(crdt._new, 'test2') + + def test_diff_set(self): + crdt = crdt_diff({'1', '2'}, {'2', '3'}) + + self.assertEqual(crdt._value, {'1', '2'}) + self.assertEqual(crdt.current, {'2', '3'}) + self.assertEqual(crdt._adds, {'2', '3'}) + self.assertEqual(crdt._removes, {'1'}) + + def test_diff_map(self): + a = { + 'a_counter': 5, + 'b_flag': True, + 'c_register': 'test', + 'd_set': {'1', '2'}, + 'e_map': { + 'a_counter': 5 + } + } + b = { + 'a_counter': 7, + 'b_flag': False, + 'c_register': 'test2', + 'd_set': {'2', '3'} + } + 
crdt = crdt_diff(a, b) + + self.assertIn('a_counter', crdt._value) + self.assertIn('b_flag', crdt._value) + self.assertIn('c_register', crdt._value) + self.assertIn('d_set', crdt._value) + self.assertIn('e_map', crdt._removes) + self.assertEqual(crdt.current, b) + self.assertItemsEqual(crdt.mutation(), [ + {'remove': 'e_map'}, + { + 'update': 'a_counter', + 'mutation': { + 'increment': 2 + } + }, + { + 'update': 'b_flag', + 'mutation': { + 'disable': None + } + }, + { + 'update': 'c_register', + 'mutation': { + 'assign': 'test2' + } + }, + { + 'update': 'd_set', + 'mutation': { + 'adds': {'2', '3'}, + 'removes': {'1'} + } + } + ]) + + def test_fail_diff(self): + with self.assertRaises(TypeError): + crdt_diff(1, True) + + with self.assertRaises(TypeError): + crdt_diff('str', 2) + +if __name__ == '__main__': + main() diff --git a/link/crdt/test/flag.py b/link/crdt/test/flag.py new file mode 100644 index 0000000..2106c78 --- /dev/null +++ b/link/crdt/test/flag.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- + +from link.crdt.test.base import UTCase +from unittest import main + +from link.crdt.flag import Flag + + +class TestFlag(UTCase): + def test_enable(self): + crdt = Flag(value=False) + + self.assertFalse(crdt.current) + self.assertFalse(crdt.isdirty()) + + crdt.enable() + + self.assertTrue(crdt.current) + self.assertEqual(crdt._mutation, 'enable') + self.assertEqual(crdt._vclock, 1) + self.assertTrue(crdt.isdirty()) + self.assertEqual(crdt.mutation(), {'enable': None}) + + crdt.disable() + + self.assertFalse(crdt.current) + self.assertEqual(crdt._mutation, 'disable') + self.assertEqual(crdt._vclock, 2) + self.assertTrue(crdt.isdirty()) + self.assertEqual(crdt.mutation(), {'disable': None}) + + def test_merge(self): + a = Flag(value=False) + b = Flag(value=False) + + b.enable() + b.disable() + a.enable() + + c = Flag.merge(a, b) + + self.assertFalse(c.current) + self.assertEqual(c._mutation, 'disable') + self.assertEqual(c._vclock, 3) + self.assertTrue(c.isdirty()) + 
self.assertEqual(c.mutation(), {'disable': None}) + + def test_fail_merge(self): + a = Flag(value=False) + b = Flag(value=True) + + with self.assertRaises(ValueError): + Flag.merge(a, b) + + def test_fail_type(self): + with self.assertRaises(TypeError): + Flag(value='not bool') + + +if __name__ == '__main__': + main() diff --git a/link/crdt/test/map.py b/link/crdt/test/map.py new file mode 100644 index 0000000..015b681 --- /dev/null +++ b/link/crdt/test/map.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- + +from link.crdt.test.base import UTCase +from unittest import main + +from link.crdt.map import Map + + +class TestMap(UTCase): + def test_enable(self): + expected = { + 'a_counter': 5, + 'b_flag': False, + 'c_register': 'test', + 'd_set': {'1', '2'}, + 'e_map': { + 'a_counter': 5 + } + } + crdt = Map(value=expected) + + self.assertEqual(crdt.current, expected) + + crdt['a_counter'].increment(3) + crdt['b_flag'].enable() + crdt['c_register'].assign('test2') + crdt['d_set'].add('3') + del crdt['e_map'] + crdt['e_register'].assign('test') + + expected = { + 'a_counter': 8, + 'b_flag': True, + 'c_register': 'test2', + 'd_set': {'1', '2', '3'}, + 'e_register': 'test' + } + + self.assertEqual(crdt.current, expected) + self.assertIn('a_counter', crdt._value) + self.assertIn('b_flag', crdt._value) + self.assertIn('c_register', crdt._value) + self.assertIn('d_set', crdt._value) + self.assertIn('e_register', crdt._updates) + self.assertEqual(crdt._vclock, 7) + self.assertTrue(crdt.isdirty()) + self.assertItemsEqual(crdt.mutation(), [ + {'remove': 'e_map'}, + { + 'update': 'a_counter', + 'mutation': { + 'increment': 3 + } + }, + { + 'update': 'b_flag', + 'mutation': { + 'enable': None + } + }, + { + 'update': 'c_register', + 'mutation': { + 'assign': 'test2' + } + }, + { + 'update': 'd_set', + 'mutation': { + 'adds': {'3'}, + 'removes': set() + } + }, + { + 'update': 'e_register', + 'mutation': { + 'assign': 'test' + } + } + ]) + + def test_merge(self): + a = 
Map(value={'a_counter': 5}) + b = Map(value={'a_counter': 5}) + + a['a_counter'].increment(3) + b['a_counter'].increment(3) + b['a_counter'].decrement(2) + + c = Map.merge(a, b) + + self.assertEqual(c.current, {'a_counter': 9}) + self.assertIn('a_counter', c._updates) + self.assertEqual(c._vclock, 3) + self.assertTrue(c.isdirty()) + self.assertEqual(c.mutation(), [ + { + 'update': 'a_counter', + 'mutation': { + 'increment': 4 + } + } + ]) + + def test_fail_merge(self): + a = Map(value={'a_counter': 5}) + b = Map(value={'b_counter': 5}) + + with self.assertRaises(ValueError): + Map.merge(a, b) + + def test_fail_type(self): + with self.assertRaises(TypeError): + not_dict = 42 + Map(value=not_dict) + + with self.assertRaises(TypeError): + Map(value={'a_counter': 'not int'}) + + +if __name__ == '__main__': + main() diff --git a/link/crdt/test/register.py b/link/crdt/test/register.py new file mode 100644 index 0000000..07aaf66 --- /dev/null +++ b/link/crdt/test/register.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +from link.crdt.test.base import UTCase +from unittest import main + +from link.crdt.register import Register + + +class TestRegister(UTCase): + def test_enable(self): + crdt = Register(value='initial') + + self.assertEqual(crdt.current, 'initial') + + crdt.assign('new') + + self.assertEqual(crdt.current, 'new') + self.assertEqual(crdt._new, 'new') + self.assertEqual(crdt._vclock, 1) + self.assertTrue(crdt.isdirty()) + self.assertEqual(crdt.mutation(), {'assign': 'new'}) + + def test_merge(self): + a = Register(value='initial') + b = Register(value='initial') + + b.assign('1') + b.assign('2') + a.assign('1') + + c = Register.merge(a, b) + + self.assertEqual(c.current, '2') + self.assertEqual(c._new, '2') + self.assertEqual(c._vclock, 3) + self.assertTrue(c.isdirty()) + self.assertEqual(c.mutation(), {'assign': '2'}) + + def test_fail_merge(self): + a = Register(value='1') + b = Register(value='2') + + with self.assertRaises(ValueError): + Register.merge(a, 
b) + + def test_fail_type(self): + with self.assertRaises(TypeError): + not_str = 42 + Register(value=not_str) + + +if __name__ == '__main__': + main() diff --git a/link/crdt/test/set.py b/link/crdt/test/set.py new file mode 100644 index 0000000..06b5b9e --- /dev/null +++ b/link/crdt/test/set.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- + +from link.crdt.test.base import UTCase +from unittest import main + +from link.crdt.set import Set + + +class TestSet(UTCase): + def test_enable(self): + crdt = Set(value={'1', '2'}) + + self.assertEqual(crdt.current, {'1', '2'}) + + crdt.add('3') + crdt.discard('2') + + self.assertEqual(crdt.current, {'1', '3'}) + self.assertEqual(crdt._adds, {'3'}) + self.assertEqual(crdt._removes, {'2'}) + self.assertEqual(crdt._vclock, 2) + self.assertTrue(crdt.isdirty()) + self.assertEqual(crdt.mutation(), { + 'adds': {'3'}, + 'removes': {'2'} + }) + + def test_merge(self): + a = Set(value={'1', '2'}) + b = Set(value={'1', '2'}) + + b.add('3') + a.discard('2') + + c = Set.merge(a, b) + + self.assertEqual(c.current, {'1', '3'}) + self.assertEqual(c._adds, {'3'}) + self.assertEqual(c._removes, {'2'}) + self.assertEqual(c._vclock, 2) + self.assertTrue(c.isdirty()) + self.assertEqual(c.mutation(), { + 'adds': {'3'}, + 'removes': {'2'} + }) + + def test_fail_merge(self): + a = Set(value={'1'}) + b = Set(value={'2'}) + + with self.assertRaises(ValueError): + Set.merge(a, b) + + def test_fail_type(self): + with self.assertRaises(TypeError): + not_set = 42 + Set(value=not_set) + + with self.assertRaises(TypeError): + Set(value={42}) + + +if __name__ == '__main__': + main() diff --git a/link/crdt/test/utils.py b/link/crdt/test/utils.py new file mode 100644 index 0000000..ed2bacb --- /dev/null +++ b/link/crdt/test/utils.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- + +from link.crdt.test.base import UTCase +from unittest import main + +from link.crdt.utils import get_crdt_type_by_name, get_crdt_type_by_py_type +from link.crdt.counter import Counter 
+from link.crdt.flag import Flag +from link.crdt.register import Register +from link.crdt.set import Set +from link.crdt.map import Map + +from six import string_types +import collections + + +class TestUtils(UTCase): + def test_get_type_by_name(self): + l = [ + (Counter, 'counter'), + (Flag, 'flag'), + (Register, 'register'), + (Set, 'set'), + (Map, 'map') + ] + + for crdt_type, name in l: + got = get_crdt_type_by_name(name) + self.assertIs(got, crdt_type) + + got = get_crdt_type_by_name('unknown') + self.assertIsNone(got) + + def test_get_type_by_py_type(self): + l = [ + (Counter, int), + (Flag, bool), + (Register, str), + (Set, collections.Set), + (Map, dict) + ] + + for crdt_type, py_type in l: + got = get_crdt_type_by_py_type(py_type) + self.assertIs(got, crdt_type) + + got = get_crdt_type_by_py_type(UTCase) + self.assertIsNone(got) + +if __name__ == '__main__': + main() diff --git a/link/crdt/utils.py b/link/crdt/utils.py index 77af4a1..b3af6b3 100644 --- a/link/crdt/utils.py +++ b/link/crdt/utils.py @@ -10,7 +10,7 @@ def get_crdt_type_by_name(name): def get_crdt_type_by_py_type(pytype): for crdt_type in TYPES.values(): - if crdt_type._py_type in pytype.mro(): + if crdt_type._match_py_type(pytype): return crdt_type return None