Skip to content

Commit

Permalink
:add: unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
linkdd committed Aug 26, 2016
1 parent dfffefe commit b1d8836
Show file tree
Hide file tree
Showing 17 changed files with 636 additions and 35 deletions.
18 changes: 16 additions & 2 deletions link/crdt/core.py
Expand Up @@ -8,15 +8,24 @@ class CRDT(object):

_type_err_msg = 'Invalid value type'

@classmethod
def _check_same_value(cls, a, b):
return a._value == b._value

@classmethod
def _assert_mergeable(cls, a, b):
if not isinstance(a, cls) and not isinstance(b, cls):
raise TypeError(
'Supplied arguments are not {0}'.format(cls.__name__)
)

if not cls._check_same_value(a, b):
raise ValueError(
'Supplied arguments does not have the same initial value'
)

@classmethod
def merge(cls, a, b):
def merge(cls, a, b, context=None):
raise NotImplementedError()

def __init__(self, value=None, context=None):
Expand Down Expand Up @@ -45,7 +54,12 @@ def _update_vclock(self):
def _post_init(self):
pass

def _check_type(self, value):
@classmethod
def _match_py_type(cls, pytype):
return cls._py_type in pytype.mro()

@classmethod
def _check_type(cls, value):
raise NotImplementedError()

def _assert_type(self, value):
Expand Down
13 changes: 9 additions & 4 deletions link/crdt/counter.py
Expand Up @@ -10,10 +10,10 @@ class Counter(CRDT):
_type_err_msg = 'Counters can only be integers'

@classmethod
def merge(cls, a, b):
def merge(cls, a, b, context=None):
cls._assert_mergeable(a, b)

crdt = cls()
crdt = cls(value=a._value, context=context)
crdt._increment = a._increment + b._increment
crdt._vclock = max(a._vclock, b._vclock)
crdt._update_vclock()
Expand All @@ -25,8 +25,13 @@ def _post_init(self):
def _default_value(self):
return 0

def _check_type(self, value):
return isinstance(value, self._py_type)
@classmethod
def _match_py_type(cls, pytype):
return cls._py_type in pytype.mro() and bool not in pytype.mro()

@classmethod
def _check_type(cls, value):
return isinstance(value, cls._py_type) and not isinstance(value, bool)

def increment(self, amount=1):
self._assert_type(amount)
Expand Down
8 changes: 6 additions & 2 deletions link/crdt/diff.py
Expand Up @@ -57,11 +57,15 @@ def __call__(self, a, b):
amro = set(a.__class__.mro())
bmro = set(b.__class__.mro())

if not (amro.issubset(bmro) or bmro.issubset(amro)):
if any([
not (amro.issubset(bmro) or bmro.issubset(amro)),
isinstance(a, bool) and not isinstance(b, bool),
not isinstance(a, bool) and isinstance(b, bool)
]):
raise TypeError('Supplied arguments must be of the same type')

for crdt_type in TYPES.values():
if isinstance(a, crdt_type._py_type):
if crdt_type._check_type(a):
crdt = crdt_type(value=a)

method = getattr(self, crdt_type._type_name)
Expand Down
9 changes: 5 additions & 4 deletions link/crdt/flag.py
Expand Up @@ -10,10 +10,10 @@ class Flag(CRDT):
_type_err_msg = 'Flags can only be booleans'

@classmethod
def merge(cls, a, b):
def merge(cls, a, b, context=None):
cls._assert_mergeable(a, b)

crdt = cls()
crdt = cls(value=a._value, context=context)
crdt._mutation = a._mutation if a._vclock >= b._vclock else b._mutation
crdt._vclock = max(a._vclock, b._vclock)
crdt._update_vclock()
Expand All @@ -25,8 +25,9 @@ def _post_init(self):
def _default_value(self):
return False

def _check_type(self, value):
return isinstance(value, self._py_type)
@classmethod
def _check_type(cls, value):
return isinstance(value, cls._py_type)

def enable(self):
self._mutation = 'enable'
Expand Down
60 changes: 49 additions & 11 deletions link/crdt/map.py
Expand Up @@ -16,26 +16,56 @@ class Map(Mapping, CRDT):
_type_err_msg = 'Map must be a dict with keys ending with "_{datatype}"'

@classmethod
def merge(cls, a, b):
def _check_same_value(cls, a, b):
for key in a._value:
if key in b._value:
aval = a._value[key]
bval = b._value[key]

if not aval._check_same_value(aval, bval):
return False

else:
return False

return True

@classmethod
def merge(cls, a, b, context=None):
cls._assert_mergeable(a, b)

crdt = cls()
crdt = cls(value=a._compute_value(), context=context)
crdt._removes = a._removes.union(b._removes)

# merge a and b values
for key in a._value:
if key in b._value:
suba = a._value[key]
subb = b._value[key]

crdt._updates[key] = suba.merge(suba, subb, context=crdt)
del crdt._value[key]

# complete with missing keys from a
for key in b._value:
if key not in a._value:
crdt._updates[key] = b._value[key]

# merge a and b updates
for key in a._updates:
if key not in b._updates:
crdt._updates[key] = a._updates
crdt._updates[key] = a._updates[key]

else:
suba = a._updates[key]
subb = b._updates[key]

crdt._updates[key] = suba.merge(suba, subb)
crdt._updates[key] = suba.merge(suba, subb, context=crdt)

# complete with missing keys from a
for key in b._updates:
if key not in a._updates:
crdt._updates[key] = a._updates
crdt._updates[key] = a._updates[key]

crdt._vclock = max(a._vclock, b._vclock)
crdt._update_vclock()
Expand All @@ -49,7 +79,8 @@ def _post_init(self):
def _default_value(self):
return dict()

def _check_key(self, key):
@classmethod
def _check_key(cls, key):
for typename in TYPES:
suffix = '_{0}'.format(typename)

Expand All @@ -62,13 +93,14 @@ def _get_key_type(self, key):
datatype = key.rsplit('_', 1)[1]
return TYPES[datatype]

def _check_type(self, value):
if not isinstance(value, self._py_type):
raise TypeError(self._type_err_msg)
@classmethod
def _check_type(cls, value):
if not isinstance(value, cls._py_type):
return False

for key in value:
try:
self._check_key(key)
cls._check_key(key)

except TypeError:
return False
Expand Down Expand Up @@ -105,7 +137,7 @@ def __iter__(self):
return iter(self.current)

def isdirty(self):
return self._removes and self._updates
return bool(self._removes) or bool(self._updates)

def _coerce_value(self, value):
cvalue = {}
Expand Down Expand Up @@ -153,6 +185,12 @@ def current(self):

return cvalue

def _compute_value(self):
return {
key: v.current if not isinstance(v, Map) else v._compute_value()
for key, v in self._value.items()
}


TYPES = {
cls._type_name: cls
Expand Down
16 changes: 12 additions & 4 deletions link/crdt/register.py
Expand Up @@ -12,10 +12,10 @@ class Register(CRDT):
_type_err_msg = 'Registers can only be strings'

@classmethod
def merge(cls, a, b):
def merge(cls, a, b, context=None):
cls._assert_mergeable(a, b)

crdt = cls()
crdt = cls(value=a._value, context=context)
crdt._new = a._new if a._vclock >= b._vclock else b._new
crdt._vclock = max(a._vclock, b._vclock)
crdt._update_vclock()
Expand All @@ -27,8 +27,16 @@ def _post_init(self):
def _default_value(self):
return ''

def _check_type(self, value):
return isinstance(value, self._py_type)
@classmethod
def _match_py_type(cls, pytype):
return any([
string_type in pytype.mro()
for string_type in string_types
])

@classmethod
def _check_type(cls, value):
return isinstance(value, cls._py_type)

def assign(self, value):
self._assert_type(value)
Expand Down
15 changes: 8 additions & 7 deletions link/crdt/set.py
Expand Up @@ -13,10 +13,10 @@ class Set(collections.Set, CRDT):
_type_err_msg = 'Sets can only be set of strings'

@classmethod
def merge(cls, a, b):
def merge(cls, a, b, context=None):
cls._assert_mergeable(a, b)

crdt = cls()
crdt = cls(value=a._value, context=context)
crdt._adds = a._adds.union(b._adds)
crdt._removes = a._removes.union(b._removes)
crdt._vclock = max(a._vclock, b._vclock)
Expand All @@ -34,8 +34,9 @@ def _check_element(self, element):
if not isinstance(element, string_types):
raise TypeError('Set elements can only be strings')

def _check_type(self, value):
if not isinstance(value, self._py_type):
@classmethod
def _check_type(cls, value):
if not isinstance(value, cls._py_type):
return False

for element in value:
Expand Down Expand Up @@ -71,10 +72,10 @@ def isdirty(self):

def mutation(self):
return {
'adds': list(self._adds),
'removes': list(self._removes)
'adds': set(self._adds),
'removes': set(self._removes)
}

@CRDT.current.getter
def current(self):
return set(self._value) + self._adds - self._removes
return set(self._value).union(self._adds).difference(self._removes)
1 change: 1 addition & 0 deletions link/crdt/test/__init__.py
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
13 changes: 13 additions & 0 deletions link/crdt/test/base.py
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

from unittest import TestCase
from six import PY2, PY3


class UTCase(TestCase):
def assertItemsEqual(self, a, b):
if PY2:
return super(UTCase, self).assertItemsEqual(a, b)

elif PY3:
return self.assertCountEqual(a, b)
59 changes: 59 additions & 0 deletions link/crdt/test/counter.py
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-

from link.crdt.test.base import UTCase
from unittest import main

from link.crdt.counter import Counter


class TestCounter(UTCase):
def test_increment(self):
crdt = Counter(value=5)
self.assertEqual(crdt.current, 5)

crdt.increment()

self.assertEqual(crdt.current, 6)
self.assertEqual(crdt._increment, 1)
self.assertEqual(crdt._vclock, 1)
self.assertTrue(crdt.isdirty())
self.assertEqual(crdt.mutation(), {'increment': 1})

crdt.decrement()

self.assertEqual(crdt.current, 5)
self.assertEqual(crdt._increment, 0)
self.assertEqual(crdt._vclock, 2)
self.assertFalse(crdt.isdirty())
self.assertEqual(crdt.mutation(), {'increment': 0})

def test_merge(self):
a = Counter(value=5)
b = Counter(value=5)

b.increment(3)
b.increment(2)
a.increment(2)

c = Counter.merge(a, b)

self.assertEqual(c.current, 12)
self.assertEqual(c._increment, 7)
self.assertEqual(c._vclock, 3)
self.assertTrue(c.isdirty())
self.assertEqual(c.mutation(), {'increment': 7})

def test_fail_merge(self):
a = Counter(value=5)
b = Counter(value=7)

with self.assertRaises(ValueError):
Counter.merge(a, b)

def test_fail_type(self):
with self.assertRaises(TypeError):
Counter(value='not int')


if __name__ == '__main__':
main()

0 comments on commit b1d8836

Please sign in to comment.