Skip to content

Commit

Permalink
Add bitwise and other helper methods to BigBitField.
Browse files Browse the repository at this point in the history
Also ensures that when checking if a bit is set, that the buffer is not
extended unnecessarily.

These changes are partially derived from #2802 / @nyaoouo - thanks!
  • Loading branch information
coleifer committed Oct 31, 2023
1 parent 40ad4f2 commit 83de3b6
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
51 changes: 50 additions & 1 deletion peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -5047,6 +5047,9 @@ def __init__(self, instance, name):
value = bytearray(value)
self._buffer = self.instance.__data__[self.name] = value

def clear(self):
self._buffer.clear()

def _ensure_length(self, idx):
byte_num, byte_offset = divmod(idx, 8)
cur_size = len(self._buffer)
Expand All @@ -5068,9 +5071,55 @@ def toggle_bit(self, idx):
return bool(self._buffer[byte_num] & (1 << byte_offset))

def is_set(self, idx):
byte_num, byte_offset = self._ensure_length(idx)
byte_num, byte_offset = divmod(idx, 8)
cur_size = len(self._buffer)
if cur_size <= byte_num:
return False
return bool(self._buffer[byte_num] & (1 << byte_offset))

__getitem__ = is_set
def __setitem__(self, item, value):
self.set_bit(item) if value else self.clear_bit(item)
__delitem__ = clear_bit

def __len__(self):
return len(self._buffer)

def _get_compatible_data(self, other):
if isinstance(other, BigBitFieldData):
data = other._buffer
elif isinstance(other, (bytes, bytearray, memoryview)):
data = other
else:
raise ValueError('Incompatible data-type')
diff = len(data) - len(self)
if diff > 0: self._buffer.extend(b'\x00' * diff)
return data

def _bitwise_op(self, other, op):
if isinstance(other, BigBitFieldData):
data = other._buffer
elif isinstance(other, (bytes, bytearray, memoryview)):
data = other
else:
raise ValueError('Incompatible data-type')
buf = bytearray(b'\x00' * max(len(self), len(other)))
for i, (a, b) in enumerate(zip(self._buffer, data)):
buf[i] = op(a, b)
return buf

def __and__(self, other):
return self._bitwise_op(other, operator.and_)
def __or__(self, other):
return self._bitwise_op(other, operator.or_)
def __xor__(self, other):
return self._bitwise_op(other, operator.xor)

def __iter__(self):
for b in self._buffer:
for j in range(8):
yield 1 if (b & (1 << j)) else 0

def __repr__(self):
return repr(self._buffer)
if sys.version_info[0] < 3:
Expand Down
49 changes: 49 additions & 0 deletions tests/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,35 @@ def test_bigbit_zero_idx(self):
b.data.clear_bit(0)
self.assertFalse(b.data.is_set(0))

# Out-of-bounds returns False and does not extend data.
self.assertFalse(b.data.is_set(1000))
self.assertTrue(len(b.data), 1)

def test_bigbit_item_methods(self):
b = Bits()
idxs = [0, 1, 4, 7, 8, 15, 16, 31, 32, 63]
for i in idxs:
b.data[i] = True
for i in range(64):
self.assertEqual(b.data[i], i in idxs)

data = list(b.data)
self.assertEqual(data, [1 if i in idxs else 0 for i in range(64)])

for i in range(64):
del b.data[i]
self.assertEqual(len(b.data), 8)
self.assertEqual(b.data._buffer, b'\x00' * 8)

def test_bigbit_set_clear(self):
b = Bits()
b.data = b'\x01'
for i in range(8):
self.assertEqual(b.data[i], i == 0)

b.data.clear()
self.assertEqual(len(b.data), 0)

def test_bigbit_field(self):
b = Bits.create()
b.data.set_bit(1)
Expand All @@ -692,6 +721,26 @@ def test_bigbit_field(self):
else:
self.assertFalse(b_db.data.is_set(x))

def test_bigbit_field_bitwise(self):
b1 = Bits(data=b'\x11')
b2 = Bits(data=b'\x12')
b3 = Bits(data=b'\x99')
self.assertEqual(b1.data & b2.data, b'\x10')
self.assertEqual(b1.data | b2.data, b'\x13')
self.assertEqual(b1.data ^ b2.data, b'\x03')
self.assertEqual(b1.data & b3.data, b'\x11')
self.assertEqual(b1.data | b3.data, b'\x99')
self.assertEqual(b1.data ^ b3.data, b'\x88')

b1.data &= b2.data
self.assertEqual(b1.data._buffer, b'\x10')

b1.data |= b2.data
self.assertEqual(b1.data._buffer, b'\x12')

b1.data ^= b3.data
self.assertEqual(b1.data._buffer, b'\x8b')

def test_bigbit_field_bulk_create(self):
b1, b2, b3 = Bits(), Bits(), Bits()
b1.data.set_bit(1)
Expand Down

0 comments on commit 83de3b6

Please sign in to comment.