In [118]:
from copy import copy
from dataclasses import dataclass
from typing import Any, List


def evaluate(a):
    """Return a shallow copy of the simplified form of ``a``.

    If ``a`` exposes an ``eval`` attribute it is called and a copy of its
    result is returned; otherwise (plain constants and variable names) a copy
    of ``a`` itself is returned.

    The attribute lookup is kept separate from the call so that an
    ``AttributeError`` raised *inside* an ``eval()`` implementation propagates
    instead of being mistaken for "``a`` has no ``eval``" and silently
    returning the unevaluated object.
    """
    eval_method = getattr(a, "eval", None)
    if eval_method is None:
        return copy(a)
    return copy(eval_method())


def validate(a):
    """Check that ``a`` is a legal operand and return a shallow copy of it.

    Legal operands are the constants 0 and 1, a ``str`` (variable name), or an
    ``And``/``Or``/``Xor``/``Not`` node.

    Raises:
        ValueError: if ``a`` is anything else.
    """
    is_constant = a in (0, 1)
    if is_constant or isinstance(a, (str, And, Or, Xor, Not)):
        return copy(a)
    raise ValueError(
        "Value must be 0, 1, or an instance of str, And, Or, Xor, or Not."
    )


@dataclass
class And:
    a: Any
    b: Any

    def __init__(self, a, b):
        self.a = validate(a)
        self.b = validate(b)

    def eval(self):
        match evaluate(self.a), evaluate(self.b):
            case (0, _) | (_, 0):
                return 0
            case (Not(a), b) | (a, Not(b)) if a == b:
                return 0
            case (1, a) | (a, 1):
                return a
            case _:
                return self

    def __str__(self):
        r = self.eval()
        if r in [0, 1] or isinstance(r, str):
            return f"{r}"
        else:
            return f"({self.a}&{self.b})"


@dataclass
class Or:
    a: Any
    b: Any

    def __init__(self, a, b):
        self.a = validate(a)
        self.b = validate(b)

    def eval(self):
        match evaluate(self.a), evaluate(self.b):
            case (1, _) | (_, 1):
                return 1
            case (0, a) | (a, 0):
                return a
            case (Not(a), b) | (a, Not(b)) if a == b:
                return 1
            case _:
                return self

    def __str__(self):
        r = self.eval()
        if r in [0, 1] or isinstance(r, str):
            return f"{r}"
        else:
            return f"({self.a}|{self.b})"


@dataclass
class Xor:
    a: Any
    b: Any

    def __init__(self, a, b):
        self.a = validate(a)
        self.b = validate(b)

    def eval(self):
        match evaluate(self.a), evaluate(self.b):
            case (0, 1) | (1, 0):
                return 1
            case (0, 0) | (1, 1):
                return 0
            case (0, a) | (a, 0):
                return a
            case (1, a) | (a, 1):
                return Not(a)
            case (a, b) if a == b:
                return 0
            case (Not(a), b) | (a, Not(b)) if a == b:
                return 1
            case (Xor(a, b), c) | (a, Xor(b, c)) if a == b:
                return c
            case (Xor(a, b), c) | (a, Xor(b, c)) if a == c:
                return b
            case (Xor(a, b), c) | (a, Xor(b, c)) if b == c:
                return a
            case _:
                return copy(self)

    def __str__(self):
        r = self.eval()
        if r in [0, 1] or isinstance(r, str):
            return f"{r}"
        else:
            return f"({self.a}^{self.b})"


@dataclass
class Not:
    a: Any

    def __init__(self, a):
        self.a = validate(a)

    def eval(self):
        match evaluate(self.a):
            case 0:
                return 1
            case 1:
                return 0
            case Not(a):
                return a
            case _:
                return copy(self)

    def __str__(self):
        r = self.eval()
        if r in [0, 1]:
            return f"{r}"
        return f"~{self.a}"

In [119]:
# Example usage: every expression must simplify to the value paired with it.
assert all(
    expression.eval() == simplified
    for expression, simplified in [
        (And("a", 0), 0),
        (And("a", 1), "a"),
        (And("a", Not("a")), 0),
        (Or("a", 1), 1),
        (Or("a", 0), "a"),
        (Or("a", Not("a")), 1),
        (Xor("a", 1), Not("a")),
        (Xor("a", 0), "a"),
        (Xor("a", Not("a")), 1),
        (Xor("a", "a"), 0),
        (Xor("a", Xor("a", "b")), "b"),
        (Not(Not("a")), "a"),
    ]
)

In [120]:
@dataclass
class BitArray:
    """Fixed-length sequence of bits, most significant bit first.

    Each element of ``bits`` is either a concrete 0/1 or a symbolic expression
    (a variable name or an ``And``/``Or``/``Xor``/``Not`` node).
    """

    bits: List

    @staticmethod
    def from_int(n, length):
        """Create a BitArray holding the low ``length`` bits of ``n``.

        Bits are extracted by shifting, so the result always has exactly
        ``length`` elements: values wider than ``length`` are truncated and
        negative values use their two's-complement bit pattern.
        """
        return BitArray([(n >> shift) & 1 for shift in range(length - 1, -1, -1)])

    def to_int(self):
        """Convert the BitArray to an integer.

        Returns None if any bit does not simplify to a concrete 0 or 1.
        An empty BitArray converts to 0.
        """
        value = 0
        for bit in self.bits:
            resolved = evaluate(bit)
            if not (isinstance(resolved, int) and resolved in (0, 1)):
                return None
            value = (value << 1) | int(resolved)
        return value

    def len(self):
        """Returns the length of the BitArray"""
        return len(self.bits)

    def _combine(self, other, node):
        """Apply the binary expression class ``node`` bit by bit and simplify each result."""
        if self.len() != other.len():
            raise ValueError("BitArray lengths must be equal.")
        return BitArray([node(a, b).eval() for a, b in zip(self.bits, other.bits)])

    def __and__(self, other):
        """Bitwise AND"""
        return self._combine(other, And)

    def __or__(self, other):
        """Bitwise OR"""
        return self._combine(other, Or)

    def __xor__(self, other):
        """Bitwise XOR"""
        return self._combine(other, Xor)

    def __invert__(self):
        """Bitwise NOT, simplified per bit like the binary operators."""
        return BitArray([Not(a).eval() for a in self.bits])

    def __iand__(self, other):
        """In-place bitwise AND (rebinds the name to a new BitArray)"""
        return self & other

    def __ior__(self, other):
        """In-place bitwise OR (rebinds the name to a new BitArray)"""
        return self | other

    def __ixor__(self, other):
        """In-place bitwise XOR (rebinds the name to a new BitArray)"""
        return self ^ other

    def _shift_count(self, n):
        """Validate a shift amount and clamp it to the array length."""
        if n < 0:
            raise ValueError("negative shift count")
        # Shifting by more than the width must still yield `len` bits (all 0).
        return min(n, self.len())

    def __lshift__(self, n):
        """Shift left by n bits, filling with 0; the length is preserved."""
        n = self._shift_count(n)
        return BitArray(self.bits[n:] + [0] * n)

    def __rshift__(self, n):
        """Shift right by n bits, filling with 0; the length is preserved."""
        n = self._shift_count(n)
        return BitArray([0] * n + self.bits[: self.len() - n])

    def __str__(self):
        return "_".join(str(bit) for bit in self.bits)

In [121]:
# Example usage: a symbolic array combined with a constant one.
a = BitArray(["a", "b", "c", "d"])
b = BitArray.from_int(0b1100, 4)

# Rendering and integer conversion
assert str(a) == "a_b_c_d"
assert str(b) == "1_1_0_0"
assert b.to_int() == 0b1100

# Bitwise operators simplify each position independently
assert str(a & b) == "a_b_0_0"
assert str(a | b) == "1_1_c_d"
assert str(a ^ b) == "~a_~b_c_d"
assert str(~a) == "~a_~b_~c_~d"

# Shifts fill the vacated positions with 0
assert str(a << 2) == "c_d_0_0"
assert str(a >> 2) == "0_0_a_b"

In [124]:
@dataclass
class U64:
    """Unsigned 64-bit integer with wrap-around semantics.

    Every constructed value is reduced modulo 2**64, so all operators simply
    compute with Python integers and let the constructor truncate the result.
    """

    x: int

    # Not annotated on purpose: plain class attributes are not dataclass fields.
    _BITS = 64
    _MASK = 2**64 - 1

    def __init__(self, x):
        self.x = x & U64._MASK

    def __and__(self, other):
        return U64(self.x & other.x)

    def __or__(self, other):
        return U64(self.x | other.x)

    def __xor__(self, other):
        return U64(self.x ^ other.x)

    def __invert__(self):
        return U64(~self.x)

    def __lshift__(self, n):
        return U64(self.x << n)

    def __rshift__(self, n):
        return U64(self.x >> n)

    def __str__(self):
        """Format as ``0x`` followed by 16 zero-padded hex digits."""
        return f"{self.x:#0{18}x}"

    def __int__(self):
        return self.x

    @staticmethod
    def rand():
        """Return a uniformly random U64 drawn with the ``random`` module."""
        from random import randint

        return U64(randint(0, U64._MASK))

    def rotate_left(self, n):
        """Rotate left by ``n`` bits.

        ``n`` is reduced modulo 64, so counts above 64 wrap around and a
        negative count rotates in the opposite direction instead of raising
        ``ValueError`` for a negative shift.
        """
        n %= U64._BITS
        return U64((self.x << n) | (self.x >> (U64._BITS - n)))

    def rotate_right(self, n):
        """Rotate right by ``n`` bits; ``n`` is reduced modulo 64 as in ``rotate_left``."""
        n %= U64._BITS
        return U64((self.x >> n) | (self.x << (U64._BITS - n)))


# Randomized tests for BitArray with 64 bits.
# The operands are random and unseeded, so every assertion reports them:
# without that a failure could not be reproduced.
count = 100
for i in range(count):
    a = U64.rand()
    b = U64.rand()

    arr_a = BitArray.from_int(int(a), 64)
    arr_b = BitArray.from_int(int(b), 64)

    assert int(a & b) == (arr_a & arr_b).to_int(), f"AND mismatch for a={a}, b={b}"
    assert int(a | b) == (arr_a | arr_b).to_int(), f"OR mismatch for a={a}, b={b}"
    assert int(a ^ b) == (arr_a ^ arr_b).to_int(), f"XOR mismatch for a={a}, b={b}"
    assert int(~a) == (~arr_a).to_int(), f"NOT mismatch for a={a}"
    assert int(a << 2) == (arr_a << 2).to_int(), f"<< mismatch for a={a}"
    assert int(a >> 2) == (arr_a >> 2).to_int(), f">> mismatch for a={a}"
    print(f"{i+1} of {count} tests passed.", end="\r")

100 of 100 tests passed.

In [125]:
# The index mapping for transposing an 8x8 bit matrix stored in a 64-bit integer:
# 0 8 16 24 32 40 48 56 1 9 17 25 33 41 49 57 2 10 18 26 34 42 50 58 3 11 19 27 35 43 51 59 4 12 20 28 36 44 52 60 5 13 21 29 37 45 53 61 6 14 22 30 38 46 54 62 7 15 23 31 39 47 55 63 

def permute_step(x, mask, shift):
    """Exchange the bits of ``x`` selected by ``mask`` with the bits ``shift`` places above them.

    Works on any type supporting ``>>``, ``<<``, ``^`` and ``&`` (plain
    integers as well as ``BitArray``).

    Args:
        x: the value whose bits are permuted.
        mask: marks the lower position of every pair to exchange.
        shift: distance between the two positions of each pair.

    Returns:
        ``x`` with each selected pair of bits exchanged.
    """
    shifted = x >> shift
    # `difference` is set exactly where a selected pair holds two different bits.
    difference = (shifted ^ x) & mask
    lower_fixed = x ^ difference
    return lower_fixed ^ (difference << shift)

# Start from 64 distinct symbolic bits a00..a63 so the final position of each
# input bit can be read off the result.
x = BitArray([f"a{i:02}" for i in range(64)])
m0 = BitArray.from_int(0x00aa00aa00aa00aa, 64)
m1 = BitArray.from_int(0x0000cccc0000cccc, 64)
m2 = BitArray.from_int(0x00000000f0f0f0f0, 64)

assert x.len() == 64
assert m0.len() == 64
assert m1.len() == 64
assert m2.len() == 64

x = permute_step(x, m0, 7)
x = permute_step(x, m1, 14)
x = permute_step(x, m2, 28)

# Check the result instead of relying on reading the printed output:
# output position k must hold input bit 8 * (k % 8) + k // 8, i.e. the
# transposition mapping 0 8 16 ... 56 1 9 17 ... 63.
assert f"{x}" == "_".join(f"a{8 * (k % 8) + k // 8:02}" for k in range(64))

print(f"{x}")

a00_a08_a16_a24_a32_a40_a48_a56_a01_a09_a17_a25_a33_a41_a49_a57_a02_a10_a18_a26_a34_a42_a50_a58_a03_a11_a19_a27_a35_a43_a51_a59_a04_a12_a20_a28_a36_a44_a52_a60_a05_a13_a21_a29_a37_a45_a53_a61_a06_a14_a22_a30_a38_a46_a54_a62_a07_a15_a23_a31_a39_a47_a55_a63
