In [36]:
from copy import copy
from dataclasses import dataclass
from typing import Any, List


@dataclass
class Expr:
    expr: Any

    def __init__(self, expr):
        if not (expr in [0, 1] or isinstance(expr, (str, And, Or, Xor, Not))):
            raise ValueError(
                "Value must be 0, 1, or an instance of str, And, Or, Xor, or Not."
            )
        self.expr = copy(expr)

    def eval(self):
        """Evaluates the expression and returns the result."""
        try:
            return copy(self.expr.eval())
        except AttributeError:
            return copy(self.expr)

    def __eq__(self, other):
        match self.eval(), other.eval():
            case (0, 0) | (1, 1):
                return True
            case (a, b) | (Not(a), Not(b)) if a == b:
                return True
            case (And(a, And(b, c)), And(And(d, e), f)) if set(
                Expr(x) for x in [a, b, c]
            ) == set(Expr(y) for y in [d, e, f]):
                return True
            case (Or(a, Or(b, c)), Or(Or(d, e), f)) if set(
                Expr(x) for x in [a, b, c]
            ) == set(Expr(y) for y in [d, e, f]):
                return True
            case (Xor(a, Xor(b, c)), Xor(Xor(d, e), f)) if set(
                Expr(x) for x in [a, b, c]
            ) == set(Expr(y) for y in [d, e, f]):
                return True
            case _:
                return False

    def __hash__(self):
        return hash(str(self))

    def __str__(self):
        r = self.eval()
        if r in [0, 1] or isinstance(r, str):
            return f"{r}"
        else:
            return f"{self.expr}"


@dataclass
class And:
    a: Any
    b: Any

    def __init__(self, a, b):
        self.a = Expr(a).eval()
        self.b = Expr(b).eval()

    def eval(self):
        match Expr(self.a).eval(), Expr(self.b).eval():
            case (0, _) | (_, 0):
                return 0
            case (Not(a), b) | (a, Not(b)) if a == b:
                return 0
            case (1, a) | (a, 1):
                return a
            case _:
                return self

    def __str__(self):
        r = self.eval()
        if r in [0, 1] or isinstance(r, str):
            return f"{r}"
        else:
            return f"({self.a}&{self.b})"


@dataclass
class Or:
    a: Any
    b: Any

    def __init__(self, a, b):
        self.a = Expr(a).eval()
        self.b = Expr(b).eval()

    def eval(self):
        match Expr(self.a).eval(), Expr(self.b).eval():
            case (1, _) | (_, 1):
                return 1
            case (0, a) | (a, 0):
                return a
            case (Not(a), b) | (a, Not(b)) if a == b:
                return 1
            case _:
                return self

    def __str__(self):
        r = self.eval()
        if r in [0, 1] or isinstance(r, str):
            return f"{r}"
        else:
            return f"({self.a}|{self.b})"


@dataclass
class Xor:
    a: Any
    b: Any

    def __init__(self, a, b):
        self.a = Expr(a).eval()
        self.b = Expr(b).eval()

    def eval(self):
        match Expr(self.a).eval(), Expr(self.b).eval():
            case (0, 1) | (1, 0):
                return 1
            case (a, b) if a == b:
                return 0
            case (0, a) | (a, 0):
                return a
            case (1, a) | (a, 1):
                return Not(a)
            case (Not(a), b) | (a, Not(b)) if a == b:
                return 1
            case (Xor(a, b), c) | (a, Xor(b, c)) if a == b:
                return c
            case (Xor(a, b), c) | (a, Xor(b, c)) if a == c:
                return b
            case (Xor(a, b), c) | (a, Xor(b, c)) if b == c:
                return a
            case _:
                return copy(self)

    def __str__(self):
        r = self.eval()
        if r in [0, 1] or isinstance(r, str):
            return f"{r}"
        else:
            return f"({self.a}^{self.b})"


@dataclass
class Not:
    a: Any

    def __init__(self, a):
        self.a = Expr(a).eval()

    def eval(self):
        match Expr(self.a).eval():
            case 0:
                return 1
            case 1:
                return 0
            case Not(a):
                return a
            case _:
                return copy(self)

    def __str__(self):
        r = self.eval()
        if r in [0, 1]:
            return f"{r}"
        return f"~{self.a}"

In [37]:
# Example usage:
assert And("a", 0).eval() == 0
assert And("a", 1).eval() == "a"
assert And("a", Not("a")).eval() == 0
assert Or("a", 1).eval() == 1
assert Or("a", 0).eval() == "a"
assert Or("a", Not("a")).eval() == 1
assert Xor("a", 1).eval() == Not("a")
assert Xor("a", 0).eval() == "a"
assert Xor("a", Not("a")).eval() == 1
assert Xor("a", "a").eval() == 0
assert Xor("a", Xor("a", "b")).eval() == "b"
assert Not(Not("a")).eval() == "a"

assert Expr(And("a", And("b", "c"))) == Expr(And(And("a", "b"), "c"))
assert Expr(Or("a", Or("b", "c"))) == Expr(Or(Or("a", "b"), "c"))
assert Expr(Xor("a", Xor("b", "c"))) == Expr(Xor(Xor("a", "b"), "c"))

In [38]:
@dataclass
class BitArray:
    bits: List

    def from_int(n, length):
        """Create a BitArray from an integer"""
        return BitArray([int(i) for i in bin(n)[2:].zfill(length)])

    def to_int(self):
        """
        Converts the BitArray to an integer. If the BitArray contains bits with unknown values,
        returns None.
        """
        if not all(isinstance(Expr(bit).eval(), int) for bit in self.bits):
            return None
        return int("".join(str(bit) for bit in self.bits), 2)

    def len(self):
        """Returns the length of the BitArray"""
        return len(self.bits)

    def _validate_other(self, other):
        """ "Checks that `other` is a BitArray and has the same length as `self`."""
        if not isinstance(other, BitArray):
            raise ValueError("Operand must be a BitArray.")
        if self.len() != other.len():
            raise ValueError("BitArray lengths must be equal.")

    def __getitem__(self, index):
        """Returns the bits at the given index"""
        return self.bits[index]

    def __setitem__(self, index, value):
        """Sets the bits at the given index"""
        self.bits[index] = value

    def __add__(self, other):
        """
        Unsigned addition. Returns the tuple `(sum, carry)` where `sum` is a BitArray containing
        the sum, and `carry` is the carry bit.
        """
        self._validate_other(other)
        c, sum = 0, []
        for a, b in zip(reversed(self.bits), reversed(other.bits)):
            # Full adder logic, see https://en.wikipedia.org/wiki/Adder_(electronics)#Full_adder
            s = Xor(Xor(a, b), c).eval()
            c = Or(And(a, b), And(c, Xor(a, b))).eval()
            sum.append(s)
        sum.reverse()
        return BitArray(sum), c

    def __mul__(self, other):
        """
        Unsigned multiplication. Returns the tuple `(low, high)` where `low` is a BitArray
        containing the lower bits of the product, and `high` is a BitArray containing the higher
        bits of the product.
        """
        self._validate_other(other)
        len = self.len()
        product = BitArray([0] * (2 * len))
        for i, a in enumerate(reversed(self.bits)):
            # See https://en.wikipedia.org/wiki/Binary_multiplier#Unsigned_integers
            head = [0] * (len - i)
            tail = [0] * i
            part = head + [And(a, b).eval() for b in other.bits] + tail
            (product, _) = product + BitArray(part)
        return BitArray(product[len:]), BitArray(product[:len])

    def __and__(self, other):
        """Bitwise AND"""
        self._validate_other(other)
        return BitArray([And(a, b).eval() for a, b in zip(self.bits, other.bits)])

    def __or__(self, other):
        """Bitwise OR"""
        self._validate_other(other)
        return BitArray([Or(a, b).eval() for a, b in zip(self.bits, other.bits)])

    def __xor__(self, other):
        """Bitwise XOR"""
        self._validate_other(other)
        return BitArray([Xor(a, b).eval() for a, b in zip(self.bits, other.bits)])

    def __invert__(self):
        """Bitwise NOT"""
        return BitArray([Not(a) for a in self.bits])

    def __iand__(self, other):
        """In-place bitwise AND"""
        self._validate_other(other)
        return self & other

    def __ior__(self, other):
        """In-place bitwise OR"""
        self._validate_other(other)
        return self | other

    def __ixor__(self, other):
        """In-place bitwise XOR"""
        self._validate_other(other)
        return self ^ other

    def __lshift__(self, n):
        """Shifts left by `n` bits"""
        assert n >= 0, "Shift must be non-negative."
        shifted = self.bits[n:] + [0] * n
        return BitArray(shifted)

    def __rshift__(self, n):
        """Shifts right by `n` bits"""
        assert n >= 0, "Shift must be non-negative."
        shifted = [0] * n + self.bits[:-n] if n > 0 else self.bits
        return BitArray(shifted)

    def __str__(self):
        return "_".join(str(bit) for bit in self.bits)
    
    def print(self, bits_per_line):
        """Prints the BitArray in chunks of `bits_per_line` bits"""
        for i in range(0, len(self.bits), bits_per_line):
            print(" ".join(str(bit) for bit in self.bits[i : i + bits_per_line]))

    def reverse_bits(self):
        """Reverses the order of the bits"""
        return BitArray(self.bits[::-1])

    def rotate_left(self, n):
        """Rotates left by `n` bits"""
        assert n >= 0, "Shift must be non-negative."
        rotated = self.bits[n:] + self.bits[:n]
        return BitArray(rotated)

    def rotate_right(self, n):
        """Rotates right by `n` bits"""
        assert n >= 0, "Shift must be non-negative."
        rotated = self.bits[-n:] + self.bits[:-n] if n > 0 else self.bits
        return BitArray(rotated)

In [39]:
# Example usage:
expr = BitArray(["a", "b", "c", "d"])
b = BitArray.from_int(0b1100, 4)

assert f"{expr}" == "a_b_c_d"
assert f"{b}" == "1_1_0_0"
assert b.to_int() == 0b1100
assert f"{expr & b}" == "a_b_0_0"
assert f"{(expr | b)}" == "1_1_c_d"
assert f"{(expr ^ b)}" == "~a_~b_c_d"
assert f"{(~expr)}" == "~a_~b_~c_~d"
assert f"{(expr << 2)}" == "c_d_0_0"
assert f"{(expr >> 2)}" == "0_0_a_b" 
assert f"{expr.reverse_bits()}" == "d_c_b_a"
assert f"{expr.rotate_left(1)}" == "b_c_d_a"
assert f"{expr.rotate_right(1)}" == "d_a_b_c"

In [40]:
# Randomized tests for 64-bit BitArrays

@dataclass
class U64:
    """An utility class for C-like 64-bit unsigned integers"""
    x: int

    def __init__(self, x):
        self.x = x & (2**64 - 1)

    def __add__(self, other):
        return U64(self.x + other.x)
    
    def __mul__(self, other):
        return U64(self.x * other.x)

    def __and__(self, other):
        return U64(self.x & other.x)

    def __or__(self, other):
        return U64(self.x | other.x)

    def __xor__(self, other):
        return U64(self.x ^ other.x)

    def __invert__(self):
        return U64(~self.x & (2**64 - 1))

    def __lshift__(self, n):
        return U64(self.x << n)

    def __rshift__(self, n):
        return U64(self.x >> n)

    def __str__(self):
        return f"{self.x:#0{18}x}"

    def __int__(self):
        return self.x

    def rand():
        from random import randint

        return U64(randint(0, 2**64 - 1))

    def rotate_left(self, n):
        return U64((self.x << n) | (self.x >> (64 - n)))

    def rotate_right(self, n):
        return U64((self.x >> n) | (self.x << (64 - n)))

count = 3
for i in range(count):
    expr = U64.rand()
    b = U64.rand()

    arr_a = BitArray.from_int(int(expr), 64)
    arr_b = BitArray.from_int(int(b), 64)

    sum, _ = arr_a + arr_b
    prod, _ = arr_a * arr_b
    assert int(expr + b) == sum.to_int()
    assert int(expr * b) == prod.to_int()
    
    assert int(expr & b) == (arr_a & arr_b).to_int()
    assert int(expr | b) == (arr_a | arr_b).to_int()
    assert int(expr ^ b) == (arr_a ^ arr_b).to_int()
    assert int(~expr) == (~arr_a).to_int()
    assert int(expr << 2) == (arr_a << 2).to_int()
    assert int(expr >> 2) == (arr_a >> 2).to_int()
    assert int(expr.rotate_left(2)) == (arr_a.rotate_left(2)).to_int()
    assert int(expr.rotate_right(2)) == (arr_a.rotate_right(2)).to_int()
    print(f"{i+1} of {count} tests passed.", end="\r")

3 of 3 tests passed.

In [41]:
# BitArray manipulation functions

def compress(x, mask):
    """Moves the masked bits in `x` to the right and sets the rest to 0."""
    from math import log2

    steps = int(log2(x.len()))
    assert 2**steps == x.len(), "BitArray length must be a power of 2."

    x &= mask
    mk = ~mask << 1
    for i in range(steps):
        p = 1
        mp = mk ^ (mk << p)
        for _ in range(steps - 1):
            p *= 2
            mp = mp ^ (mp << p)
        mv = mp & mask
        mask = (mask ^ mv) | (mv >> (1 << i))
        t = x & mv
        x = x ^ t | (t >> (1 << i))
        mk &= ~mp
    return x


def delta_swap(x, mask, shift):
    """Moves the masked bits to the left by `shift` positions. For this function to work properly,
    the mask and the shifted mask should not overlap, ie. `mask & (mask << shift) == 0` and no bits
    should be shifted out of the 64-bit integer, ie. `((mask << shift) >> shift) == mask`.
    """
    t = ((x >> shift) ^ x) & mask
    return (x ^ t) ^ (t << shift)


def delta_exchange(x, y, mask, shift):
    """Exchanges the masked bits in `x` with the bits in `y` masked by `mask << shift`. For this
    function to work properly, no bits should be shifted out of the 64-bit integers, ie.
    `((mask << shift) >> shift) == mask`.
    """
    t = ((y >> shift) ^ x) & mask
    return (x ^ t), (y ^ (t << shift))


def exchange(x, y, mask):
    """Exchanges the masked bits between `x` and `y`."""
    x = x ^ y
    y = y ^ (x & mask)
    x = x ^ y
    return x, y


x = BitArray(["a", "b", "c", "d", "e", "f", "g", "h"])
y = BitArray(["0", "1", "2", "3", "4", "5", "6", "7"])

mask = BitArray([0, 0, 0, 1, 0, 0, 0, 1])

swapped = delta_swap(x, mask, 3)
assert f"{swapped}" == "d_b_c_a_h_f_g_e"

x_swapped, y_swapped = delta_exchange(x, y, mask, 3)
assert f"{x_swapped}" == "a_b_c_0_e_f_g_4"
assert f"{y_swapped}" == "d_1_2_3_h_5_6_7"

x, y = exchange(x, y, mask)
assert f"{x}" == "a_b_c_3_e_f_g_7"
assert f"{y}" == "0_1_2_d_4_5_6_h"

assert f"{compress(x, mask)}" == "0_0_0_0_0_0_3_7"
assert f"{compress(y, mask)}" == "0_0_0_0_0_0_d_h"

In [42]:
# Transposing a 8x8 bit matrix stored in a 64-bit integer:

# Input:    Mask 0:   Step 0:   Mask 1:   Step 1:  Mask 2:   Result:
# .$ZYXWVU  00000000  .TZRXPVN  00000000  .TLDXPHz 00000000  .TLDvnf7
# TSRQPONM  10101010  $SYQWOUM  00000000  $SKCWOGy 00000000  $SKCume6
# LKJIHGFE  00000000  LDJBHzFx  11001100  ZRJBVNFx 00000000  ZRJBtld5
# DCBAzyxw  10101010  KCIAGyEw  11001100  YQIAUMEw 00000000  YQIAskc4
# vutsrqpo  00000000  vntlrjph  00000000  vnf7rjb3 11110000  XPHzrjb3
# nmlkjihg  10101010  umskqiog  00000000  ume6qia2 11110000  WOGyqia2
# fedcba98  00000000  f7d5b391  11001100  tld5ph91 11110000  VNFxph91
# 76543210  10101010  e6c4a280  11001100  skc4og80 11110000  UMEwog80

# fmt: off
bits = [
    '.','$','Z','Y','X','W','V','U',
    'T','S','R','Q','P','O','N','M',
    'L','K','J','I','H','G','F','E',
    'D','C','B','A','z','y','x','w',
    'v','u','t','s','r','q','p','o',
    'n','m','l','k','j','i','h','g',
    'f','e','d','c','b','a','9','8',
    '7','6','5','4','3','2','1','0',
]
# fmt: on

transpose = [bits[8 * i + j] for j in range(8) for i in range(8)]
matrix = BitArray(bits)

print("Input:")
matrix.print(8)

m0 = BitArray.from_int(0x00AA00AA00AA00AA, 64)
m1 = BitArray.from_int(0x0000CCCC0000CCCC, 64)
m2 = BitArray.from_int(0x00000000F0F0F0F0, 64)

assert matrix.len() == 64
assert m0.len() == 64
assert m1.len() == 64
assert m2.len() == 64

matrix = delta_swap(matrix, m0, 7)
matrix = delta_swap(matrix, m1, 14)
matrix = delta_swap(matrix, m2, 28)

assert matrix == BitArray(transpose)

print("\nResult:")
matrix.print(8)

Input:
. $ Z Y X W V U
T S R Q P O N M
L K J I H G F E
D C B A z y x w
v u t s r q p o
n m l k j i h g
f e d c b a 9 8
7 6 5 4 3 2 1 0

Result:
. T L D v n f 7
$ S K C u m e 6
Z R J B t l d 5
Y Q I A s k c 4
X P H z r j b 3
W O G y q i a 2
V N F x p h 9 1
U M E w o g 8 0


In [43]:
# Multiplication of two 8x8 bit matrices stored in 64-bit integers.
#
# A                               x B
#
# a00 a01 a02 a03 a04 a05 a06 a07   b00 b01 b02 b03 b04 b05 b06 b07
# a10 a11 a12 a13 a14 a15 a16 a17   b10 b11 b12 b13 b14 b15 b16 b17
# a20 a21 a22 a23 a24 a25 a26 a27   b20 b21 b22 b23 b24 b25 b26 b27
# a30 a31 a32 a33 a34 a35 a36 a37   b30 b31 b32 b33 b34 b35 b36 b37
# a40 a41 a42 a43 a44 a45 a46 a47   b40 b41 b42 b43 b44 b45 b46 b47
# a50 a51 a52 a53 a54 a55 a56 a57   b50 b51 b52 b53 b54 b55 b56 b57
# a60 a61 a62 a63 a64 a65 a66 a67   b60 b61 b62 b63 b64 b65 b66 b67
# a70 a71 a72 a73 a74 a75 a76 a77   b70 b71 b72 b73 b74 b75 b76 b77
#
# The element cij of the resulting matrix C is
#
# cij = (ai0 & b0j) ^ (ai1 & b1j) ^ ... ^ (ai7 & b7j).
#
# We'll use two masks. The first masks a row in A and the second masks a column in B.
#
# ROW       COL
#
# 00000000  00000001
# 00000000  00000001
# 00000000  00000001
# 00000000  00000001
# 00000000  00000001
# 00000000  00000001
# 00000000  00000001
# 11111111  00000001

expr = BitArray([f"a{i}{j}" for j in range(8) for i in range(8)])
b = BitArray([f"b{i}{j}" for j in range(8) for i in range(8)])

row = BitArray.from_int(0xFF, 64)
col = BitArray.from_int(0x0101010101010101, 64)

c = BitArray([0] * 64)

for i in range(8):
    col_a = col & (expr >> i)
    row_b = row & (b >> 8 * i)
    (prod, overflow) = col_a * row_b
    assert overflow.to_int() == 0
    c ^= prod
    
print("\nMultiplication:")
c.print(8)


Multiplication:
((((((((a70&b07)^(a60&b06))^(a50&b05))^(a40&b04))^(a30&b03))^(a20&b02))^(a10&b01))^(a00&b00)) ((((((((a70&b17)^(a60&b16))^(a50&b15))^(a40&b14))^(a30&b13))^(a20&b12))^(a10&b11))^(a00&b10)) ((((((((a70&b27)^(a60&b26))^(a50&b25))^(a40&b24))^(a30&b23))^(a20&b22))^(a10&b21))^(a00&b20)) ((((((((a70&b37)^(a60&b36))^(a50&b35))^(a40&b34))^(a30&b33))^(a20&b32))^(a10&b31))^(a00&b30)) ((((((((a70&b47)^(a60&b46))^(a50&b45))^(a40&b44))^(a30&b43))^(a20&b42))^(a10&b41))^(a00&b40)) ((((((((a70&b57)^(a60&b56))^(a50&b55))^(a40&b54))^(a30&b53))^(a20&b52))^(a10&b51))^(a00&b50)) ((((((((a70&b67)^(a60&b66))^(a50&b65))^(a40&b64))^(a30&b63))^(a20&b62))^(a10&b61))^(a00&b60)) ((((((((a70&b77)^(a60&b76))^(a50&b75))^(a40&b74))^(a30&b73))^(a20&b72))^(a10&b71))^(a00&b70))
((((((((a71&b07)^(a61&b06))^(a51&b05))^(a41&b04))^(a31&b03))^(a21&b02))^(a11&b01))^(a01&b00)) ((((((((a71&b17)^(a61&b16))^(a51&b15))^(a41&b14))^(a31&b13))^(a21&b12))^(a11&b11))^(a01&b10)) ((((((((a71&b27)^(a61&b26))^(a51&b25))^(a41

In [44]:
# Transposing a 16x16 bit matrix stored in four 64-bit integers.

# We want to get from this:
# fmt: off
expr = ['a00', 'a01', 'a02', 'a03', 'a04', 'a05', 'a06', 'a07', 'a08', 'a09', 'a10', 'a11', 'a12', 'a13', 'a14', 'a15',
     'a16', 'a17', 'a18', 'a19', 'a20', 'a21', 'a22', 'a23', 'a24', 'a25', 'a26', 'a27', 'a28', 'a29', 'a30', 'a31',
     'a32', 'a33', 'a34', 'a35', 'a36', 'a37', 'a38', 'a39', 'a40', 'a41', 'a42', 'a43', 'a44', 'a45', 'a46', 'a47',
     'a48', 'a49', 'a50', 'a51', 'a52', 'a53', 'a54', 'a55', 'a56', 'a57', 'a58', 'a59', 'a60', 'a61', 'a62', 'a63']

b = ['b00', 'b01', 'b02', 'b03', 'b04', 'b05', 'b06', 'b07', 'b08', 'b09', 'b10', 'b11', 'b12', 'b13', 'b14', 'b15',
     'b16', 'b17', 'b18', 'b19', 'b20', 'b21', 'b22', 'b23', 'b24', 'b25', 'b26', 'b27', 'b28', 'b29', 'b30', 'b31',
     'b32', 'b33', 'b34', 'b35', 'b36', 'b37', 'b38', 'b39', 'b40', 'b41', 'b42', 'b43', 'b44', 'b45', 'b46', 'b47',
     'b48', 'b49', 'b50', 'b51', 'b52', 'b53', 'b54', 'b55', 'b56', 'b57', 'b58', 'b59', 'b60', 'b61', 'b62', 'b63']

c = ['c00', 'c01', 'c02', 'c03', 'c04', 'c05', 'c06', 'c07', 'c08', 'c09', 'c10', 'c11', 'c12', 'c13', 'c14', 'c15',
     'c16', 'c17', 'c18', 'c19', 'c20', 'c21', 'c22', 'c23', 'c24', 'c25', 'c26', 'c27', 'c28', 'c29', 'c30', 'c31',
     'c32', 'c33', 'c34', 'c35', 'c36', 'c37', 'c38', 'c39', 'c40', 'c41', 'c42', 'c43', 'c44', 'c45', 'c46', 'c47',
     'c48', 'c49', 'c50', 'c51', 'c52', 'c53', 'c54', 'c55', 'c56', 'c57', 'c58', 'c59', 'c60', 'c61', 'c62', 'c63']

d = ['d00', 'd01', 'd02', 'd03', 'd04', 'd05', 'd06', 'd07', 'd08', 'd09', 'd10', 'd11', 'd12', 'd13', 'd14', 'd15',
     'd16', 'd17', 'd18', 'd19', 'd20', 'd21', 'd22', 'd23', 'd24', 'd25', 'd26', 'd27', 'd28', 'd29', 'd30', 'd31',
     'd32', 'd33', 'd34', 'd35', 'd36', 'd37', 'd38', 'd39', 'd40', 'd41', 'd42', 'd43', 'd44', 'd45', 'd46', 'd47',
     'd48', 'd49', 'd50', 'd51', 'd52', 'd53', 'd54', 'd55', 'd56', 'd57', 'd58', 'd59', 'd60', 'd61', 'd62', 'd63']

# To this:
a_t = ['a00', 'a16', 'a32', 'a48', 'b00', 'b16', 'b32', 'b48', 'c00', 'c16', 'c32', 'c48', 'd00', 'd16', 'd32', 'd48',
       'a01', 'a17', 'a33', 'a49', 'b01', 'b17', 'b33', 'b49', 'c01', 'c17', 'c33', 'c49', 'd01', 'd17', 'd33', 'd49',
       'a02', 'a18', 'a34', 'a50', 'b02', 'b18', 'b34', 'b50', 'c02', 'c18', 'c34', 'c50', 'd02', 'd18', 'd34', 'd50',
       'a03', 'a19', 'a35', 'a51', 'b03', 'b19', 'b35', 'b51', 'c03', 'c19', 'c35', 'c51', 'd03', 'd19', 'd35', 'd51']

b_t = ['a04', 'a20', 'a36', 'a52', 'b04', 'b20', 'b36', 'b52', 'c04', 'c20', 'c36', 'c52', 'd04', 'd20', 'd36', 'd52',
       'a05', 'a21', 'a37', 'a53', 'b05', 'b21', 'b37', 'b53', 'c05', 'c21', 'c37', 'c53', 'd05', 'd21', 'd37', 'd53',
       'a06', 'a22', 'a38', 'a54', 'b06', 'b22', 'b38', 'b54', 'c06', 'c22', 'c38', 'c54', 'd06', 'd22', 'd38', 'd54',
       'a07', 'a23', 'a39', 'a55', 'b07', 'b23', 'b39', 'b55', 'c07', 'c23', 'c39', 'c55', 'd07', 'd23', 'd39', 'd55']

c_t = ['a08', 'a24', 'a40', 'a56', 'b08', 'b24', 'b40', 'b56', 'c08', 'c24', 'c40', 'c56', 'd08', 'd24', 'd40', 'd56',
       'a09', 'a25', 'a41', 'a57', 'b09', 'b25', 'b41', 'b57', 'c09', 'c25', 'c41', 'c57', 'd09', 'd25', 'd41', 'd57',
       'a10', 'a26', 'a42', 'a58', 'b10', 'b26', 'b42', 'b58', 'c10', 'c26', 'c42', 'c58', 'd10', 'd26', 'd42', 'd58',
       'a11', 'a27', 'a43', 'a59', 'b11', 'b27', 'b43', 'b59', 'c11', 'c27', 'c43', 'c59', 'd11', 'd27', 'd43', 'd59']

d_t = ['a12', 'a28', 'a44', 'a60', 'b12', 'b28', 'b44', 'b60', 'c12', 'c28', 'c44', 'c60', 'd12', 'd28', 'd44', 'd60',
       'a13', 'a29', 'a45', 'a61', 'b13', 'b29', 'b45', 'b61', 'c13', 'c29', 'c45', 'c61', 'd13', 'd29', 'd45', 'd61',
       'a14', 'a30', 'a46', 'a62', 'b14', 'b30', 'b46', 'b62', 'c14', 'c30', 'c46', 'c62', 'd14', 'd30', 'd46', 'd62',
       'a15', 'a31', 'a47', 'a63', 'b15', 'b31', 'b47', 'b63', 'c15', 'c31', 'c47', 'c63', 'd15', 'd31', 'd47', 'd63']
# fmt: on

expr = BitArray(expr)
b = BitArray(b)
c = BitArray(c)
d = BitArray(d)

# The strategy is to (0) swap the bits in the integers `a`, `b`, `c`, and `d` so that they contain correct bits,
# but in the wrong order. Then (1) we will permute the integers to get the correct order.
#
# The following performs the first step:
#
# A helper function to print the 16x4 matrix to understand what's going on:
# def print16x4(bits):
#     for i in range(4):
#         print(" ".join(str(bits[i * 16 + j]) for j in range(16)))
#     print()


m = BitArray.from_int(0xF000F000F000F000, 64)

expr, b = delta_exchange(expr, b, m >> 4, 4)
expr, c = delta_exchange(expr, c, m >> 8, 8)
expr, d = delta_exchange(expr, d, m >> 12, 12)
b, c = delta_exchange(b, c, m >> 8, 4)
b, d = delta_exchange(b, d, m >> 12, 8)
c, d = delta_exchange(c, d, m >> 12, 4)

# Now the integers `a`, `b`, `c`, and `d` contain the correct bits. The following code will permute the bits
# to get the correct order. The index mapping for the bit permutation calculator can be generated by the following
# code:
#
# mapping = [a.bits.index(a_t[i]) for i in range(64)]
# print(" ".join(str(i) for i in mapping))

m0 = BitArray.from_int(0x0000AAAA0000AAAA, 64)
m1 = BitArray.from_int(0x00000000CCCCCCCC, 64)

expr = delta_swap(expr, m0, 15)
b = delta_swap(b, m0, 15)
c = delta_swap(c, m0, 15)
d = delta_swap(d, m0, 15)

expr = delta_swap(expr, m1, 30)
b = delta_swap(b, m1, 30)
c = delta_swap(c, m1, 30)
d = delta_swap(d, m1, 30)

assert expr.bits == a_t
assert b.bits == b_t
assert c.bits == c_t
assert d.bits == d_t