In [1]:
import random

l = 64  # Default bit width
n = 32  # Total bit width for specific operations
J = 16  # Width of the MSB part
j = n - J  # Width of the LSB part of an n-bit value

# get_msb(J, ...) is applied to both n-bit and l-bit values below, so J must fit in both.
assert 0 <= J <= min(n, l), "J must fit in both n and l bits"

# Set SEED to an int for reproducible random draws; None seeds from system
# randomness, which is equivalent to not seeding at all.
SEED = None
random.seed(SEED)

class BinaryNumber:
    """An unsigned integer paired with an explicit bit width.

    ``+`` and ``-`` wrap modulo ``2**width``. ``<`` and ``>`` return a
    BinaryNumber (of the default width) holding 0 or 1. Every binary operator
    requires both operands to be BinaryNumbers; for anything else it returns
    ``NotImplemented``, so Python raises ``TypeError`` (or, for ``==``, falls
    back to identity). Arithmetic and ordering additionally require equal
    widths and return ``NotImplemented`` otherwise.
    """

    def __init__(self, value=None, width=l):
        """Create a ``width``-bit number.

        Args:
            value: integer to store as-is (it is NOT masked to ``width``).
                If None, a value is drawn with ``random.randint`` from
                ``[0, 2**width - 1]``.
            width: bit width; defaults to the module-level ``l`` (bound when
                the class is defined).
        """
        self.width = width
        if value is not None:
            self.value = value
        else:
            max_value = (1 << self.width) - 1
            self.value = random.randint(0, max_value)

    def _same_width(self, other):
        # True only when ``other`` is a BinaryNumber whose width matches ours.
        return isinstance(other, BinaryNumber) and self.width == other.width

    def __eq__(self, other):
        # Equal when both value and width match; different widths compare unequal.
        if not isinstance(other, BinaryNumber):
            return NotImplemented
        return self.value == other.value and self.width == other.width

    def __bool__(self):
        # Truthiness follows the value. Without this every instance -- including
        # a 0 result from ``<``/``>`` -- would be truthy, so ``if a < b:`` would
        # always take the branch.
        return bool(self.value)

    def __lt__(self, other):
        # 1 if self.value < other.value else 0, as a BinaryNumber of the default width.
        if not self._same_width(other):
            return NotImplemented
        return BinaryNumber(int(self.value < other.value))

    def __gt__(self, other):
        # 1 if self.value > other.value else 0, as a BinaryNumber of the default width.
        if not self._same_width(other):
            return NotImplemented
        return BinaryNumber(int(self.value > other.value))

    def __add__(self, other):
        # Sum modulo 2**width, same width as the operands.
        if not self._same_width(other):
            return NotImplemented
        result = (self.value + other.value) % (1 << self.width)
        return BinaryNumber(result, self.width)

    def __sub__(self, other):
        # Difference modulo 2**width (wraps around below zero), same width as the operands.
        if not self._same_width(other):
            return NotImplemented
        result = (self.value - other.value) % (1 << self.width)
        return BinaryNumber(result, self.width)

    def __repr__(self):
        # Unambiguous form for notebook/REPL display.
        return f"BinaryNumber(value={self.value}, width={self.width})"

    def __str__(self):
        # Binary digits zero-padded to ``width``.
        return f'Value: {format(self.value, f"0{self.width}b")} (Width: {self.width})'
    
def get_msb(size, num: BinaryNumber):
    """Return the ``size`` most-significant bits of ``num`` as a new BinaryNumber.

    Bits are counted relative to ``num.width`` (not the bit length of
    ``num.value``), so leading zeros are part of the result.

    Args:
        size: number of high bits to keep; must satisfy ``0 <= size <= num.width``.
        num: the source number.

    Returns:
        A BinaryNumber of width ``size``.

    Raises:
        ValueError: if ``size`` is outside ``[0, num.width]``.
    """
    if not 0 <= size <= num.width:
        raise ValueError(f"size must be in [0, {num.width}], got {size}")
    msb_shift = num.width - size
    msb_value = (num.value >> msb_shift) & ((1 << size) - 1)
    return BinaryNumber(msb_value, size)

def get_lsb(size, num: BinaryNumber):
    """Return the ``size`` least-significant bits of ``num`` as a BinaryNumber of width ``size``."""
    low_mask = (1 << size) - 1
    return BinaryNumber(num.value & low_mask, size)

def generate_random_pair(num: BinaryNumber):
    """Split ``num`` into two additive shares modulo ``2**num.width``.

    The first share is drawn with ``random.randint`` from ``[0, 2**num.width - 1]``;
    the second is chosen so that ``(share1 + share2) % 2**num.width`` equals
    ``num.value % 2**num.width``.

    Returns:
        A tuple ``(share1, share2)`` of BinaryNumbers with ``num``'s width.
    """
    modulus = 1 << num.width
    first_share = random.randint(0, modulus - 1)
    second_share = (num.value - first_share) % modulus
    return BinaryNumber(first_share, num.width), BinaryNumber(second_share, num.width)

# Demo: a 32-digit binary literal shown at width 41, padded with leading zeros.
print(BinaryNumber(value=0b00001010001000111000000111111011, width=41))

Value: 00000000000001010001000111000000111111011 (Width: 41)


# This first experiment looks only at the simpler case: all values are n bits wide, and nothing lives in l bits.

In [10]:
# Draw random n-bit r and a, and split each into two additive shares mod 2**n.
r = BinaryNumber(width=n)
r1, r2 = generate_random_pair(r)

a = BinaryNumber(width=n)
a1, a2 = generate_random_pair(a)

# x = r - a (mod 2**n). t1, t2 are the share-wise differences, so
# t1 + t2 == x (mod 2**n).
x = r - a

t1 = r1 - a1
t2 = r2 - a2

# Borrow out of the low j bits of r - a. It occurs exactly when
# lsb_j(r) < lsb_j(a), which is equivalent to lsb_j(x) > lsb_j(r).
carryin = get_lsb(j, x) > get_lsb(j, r)
# Re-wrap the 0/1 flag at width J so it can be added to J-bit MSB values.
carryin = BinaryNumber(carryin.value, J) 


max_label_length = 50

print("*"*3*max_label_length)
print(f"{'r'.ljust(max_label_length)} {r}")
print(f"{'x'.ljust(max_label_length)} {x}")
print(f"{'a'.ljust(max_label_length)} {a}")
print()

print(f"{'r1'.ljust(max_label_length)} {r1}")
print(f"{'r2'.ljust(max_label_length)} {r2}")
print(f"{'a1'.ljust(max_label_length)} {a1}")
print(f"{'a2'.ljust(max_label_length)} {a2}")
print()

print(f"{'r1 - a1'.ljust(max_label_length)} {t1}")
print(f"{'r2 - a2'.ljust(max_label_length)} {t2}")
print()

print(f"{'(r1-a1).msb'.ljust(max_label_length)} {get_msb(J, t1)}")
print(f"{'(r2-a2).msb'.ljust(max_label_length)} {get_msb(J, t2)}")
print()

print(f"{'r.msb'.ljust(max_label_length)} {get_msb(J, r)}")
print(f"{'a.msb'.ljust(max_label_length)} {get_msb(J, a)}")
print()

# Check if msbJ(r) - msbJ(a) == msbJ(x) + carryin
print(f"{'msbJ(r) - msbJ(a) == msbJ(x) + carryin'.ljust(max_label_length)} {get_msb(J, r) - get_msb(J, a) == get_msb(J, x) + carryin}")
print()

print(f"{'want msbJ(r) - msbJ(a)'.ljust(max_label_length)} {get_msb(J, r) - get_msb(J, a)}")
print(f"{'have open(msbJ(rshare-ashare))'.ljust(max_label_length)} {get_msb(J, t1) + get_msb(J, t2)}")
print(f"{'msbJ(x)'.ljust(max_label_length)} {get_msb(J, x)}")
print("*"*3*max_label_length)

# Both relations hold for every random draw; fail loudly if either breaks.
assert get_msb(J, r) - get_msb(J, a) == get_msb(J, x) + carryin
# msbJ(x) minus the opened share MSBs is the carry out of
# lsb_j(t1) + lsb_j(t2), so it must be 0 or 1.
assert (get_msb(J, x) - (get_msb(J, t1) + get_msb(J, t2))).value in (0, 1)

******************************************************************************************************************************************************
r                                                  Value: 01100110100001010010101000001101 (Width: 32)
x                                                  Value: 11001001010101100101001011100000 (Width: 32)
a                                                  Value: 10011101001011101101011100101101 (Width: 32)

r1                                                 Value: 10101111100011010011010010110010 (Width: 32)
r2                                                 Value: 10110110111101111111010101011011 (Width: 32)
a1                                                 Value: 10101110000011011011010100111100 (Width: 32)
a2                                                 Value: 11101111001000010010000111110001 (Width: 32)

r1 - a1                                            Value: 00000001011111110111111101110110 (Width: 32)
r2 - a2                

# Note that msbJ(x) is either equal to open(msbJ(rshare - ashare)) or one more than it: the extra one is the carry out of the low bits when the two shares of x are added.

# Doing the Haar-Pika for l bits

Note: the protocol description followed here may contain a few typos.

In [24]:
# Same experiment for l-bit values: the top J bits form the MSB part and the
# remaining l - J bits form the LSB part.
r = BinaryNumber(width=l)
r1, r2 = generate_random_pair(r)

a = BinaryNumber(width=l)
a1, a2 = generate_random_pair(a)

# x = r - a (mod 2**l); t1 + t2 == x (mod 2**l).
x = r - a

t1 = r1 - a1
t2 = r2 - a2

# Borrow out of the low (l - J) bits of r - a: it occurs exactly when
# lsb_{l-J}(x) > lsb_{l-J}(r).
carryin = get_lsb(l-J, x) > get_lsb(l-J, r)
# Re-wrap the 0/1 flag at width J so it can be added to J-bit MSB values.
carryin = BinaryNumber(carryin.value, J)


max_label_length = 50

print("*"*3*max_label_length)
print(f"{'r'.ljust(max_label_length)} {r}")
print(f"{'x'.ljust(max_label_length)} {x}")
print(f"{'a'.ljust(max_label_length)} {a}")
print()

print(f"{'r1'.ljust(max_label_length)} {r1}")
print(f"{'r2'.ljust(max_label_length)} {r2}")
print(f"{'a1'.ljust(max_label_length)} {a1}")
print(f"{'a2'.ljust(max_label_length)} {a2}")
print()

print(f"{'r1 - a1'.ljust(max_label_length)} {t1}")
print(f"{'r2 - a2'.ljust(max_label_length)} {t2}")
print()

print(f"{'(r1-a1).msb'.ljust(max_label_length)} {get_msb(J, t1)}")
print(f"{'(r2-a2).msb'.ljust(max_label_length)} {get_msb(J, t2)}")
print()

print(f"{'r.msb'.ljust(max_label_length)} {get_msb(J, r)}")
print(f"{'a.msb'.ljust(max_label_length)} {get_msb(J, a)}")
print()

# Check if msbJ(r) - msbJ(a) == msbJ(x) + carryin
print(f"{'msbJ(r) - msbJ(a) == msbJ(x) + carryin'.ljust(max_label_length)} {get_msb(J, r) - get_msb(J, a) == get_msb(J, x) + carryin}")
print()

print(f"{'want msbJ(r) - msbJ(a)'.ljust(max_label_length)} {get_msb(J, r) - get_msb(J, a)}")
print(f"{'have open(msbJ(rshare-ashare))'.ljust(max_label_length)} {get_msb(J, t1) + get_msb(J, t2)}")
print(f"{'msbJ(x)'.ljust(max_label_length)} {get_msb(J, x)}")
print("*"*3*max_label_length)

# Both relations hold for every random draw; fail loudly if either breaks.
assert get_msb(J, r) - get_msb(J, a) == get_msb(J, x) + carryin
# msbJ(x) minus the opened share MSBs is the carry out of
# lsb_{l-J}(t1) + lsb_{l-J}(t2), so it must be 0 or 1.
assert (get_msb(J, x) - (get_msb(J, t1) + get_msb(J, t2))).value in (0, 1)

******************************************************************************************************************************************************
r                                                  Value: 1000011110001101010001011110000001000110010100010100000001101011 (Width: 64)
x                                                  Value: 0011000011011100011011011010010011001001011110000000011011101100 (Width: 64)
a                                                  Value: 0101011010110000110110000011101101111100110110010011100101111111 (Width: 64)

r1                                                 Value: 1001101000011001110010110010111100111001110101011011100000110011 (Width: 64)
r2                                                 Value: 1110110101110011011110101011000100001100011110111000100000111000 (Width: 64)
a1                                                 Value: 1101001111101100110111011101101010100111110011010101000101000111 (Width: 64)
a2                                    