In [None]:
import numpy as np
import threading

In [None]:
# --- Simulation parameters -------------------------------------------------
n = 4                        # number of simulated processes (ids 0 .. n-1)
t = 0                        # receive() waits for n - t messages; presumably the number of tolerated faults

QUESTION_MARK = "?"          # round-1 outcome when no value reached HALF_PLUS_ONE votes
WAITING_MESSAGE = "waiting"  # payload every process broadcasts in round 3
HALF_PLUS_ONE = n // 2 + 1   # strict-majority threshold, floor(n / 2) + 1

# --- Shared state ------------------------------------------------------------
round_messages = {}          # {value: count}; NOTE(review): appears unused, each Process keeps its own tally
broadcasted_messages = []    # every BroadcastMessage ever sent, appended by broadcast()
broadcasting_lock = threading.Lock()  # guards appends to broadcasted_messages

In [None]:
class Process:
    """One simulated participant in the agreement protocol.

    Attributes:
        id: the process identifier, also used to match ``BroadcastMessage.receivers``.
        input_val: the value the process starts the protocol with.
        round_messages: per-round tally ``{value: count}``, filled by ``receive``.
        output: the decided value, ``None`` until the protocol finishes.
    """

    def __init__(self, id, input_val) -> None:
        self.output = None
        self.round_messages = {}  # fresh dict per instance, never shared
        self.input_val = input_val
        self.id = id
        
class Message:
    """Base message: who sent it and which process ids it is addressed to."""

    def __init__(self, process_id, receivers) -> None:
        self.message = None         # payload; subclasses overwrite it
        self.receivers = receivers  # stored as given (not copied)
        self.sender = process_id

class BroadcastMessage(Message):
    """A payload tagged with the epoch and round of the protocol it belongs to."""

    def __init__(self, process_id, receivers, epoch, round, message) -> None:
        super().__init__(process_id, receivers)
        self.epoch, self.round = epoch, round
        self.message = message  # overrides the None set by Message.__init__
        # Never set to True in the visible code; receive() only filters on it.
        self.read = False

    def __str__(self):
        return "id: {} | epoch: {} | round: {} | message: {}".format(
            self.sender, self.epoch, self.round, self.message
        )

In [None]:
def broadcast(process_id, epoch, round, message):
    """Publish ``message`` from ``process_id`` for (epoch, round) to every process.

    Appends one ``BroadcastMessage`` addressed to all ids ``0 .. n-1`` to the
    shared ``broadcasted_messages`` list.
    """
    new_msg = BroadcastMessage(process_id, list(range(n)), epoch, round, message)
    # `with` guarantees the lock is released even if the append raises;
    # no `global` is needed because the list is mutated, not rebound.
    with broadcasting_lock:
        broadcasted_messages.append(new_msg)

def receive(process, epoch, round, required_val=None):
    """Tally the messages for (epoch, round) into ``process.round_messages``.

    Scans ``broadcasted_messages`` at least once and keeps scanning (busy-wait)
    until at least ``n - t`` distinct matching messages addressed to ``process``
    have been counted. Each message object is counted at most once per call.

    Args:
        process: the receiving ``Process``; its ``round_messages`` dict
            ({value: count}) is incremented in place.
        epoch: epoch number a message must carry to be counted.
        round: round number a message must carry to be counted.
        required_val: in round 3, every counted message must carry this
            payload (checked with ``assert``).

    Note:
        If fewer than ``n - t`` processes ever broadcast for this
        (epoch, round), this call spins forever.
    """
    quorum = n - t
    # The list is rescanned on every pass; without remembering what was already
    # tallied, the same messages would be re-counted and the quorum would be
    # reached without any new message arriving.
    counted = set()
    while True:
        # Snapshot under the lock so concurrent broadcast() appends can't race the scan.
        with broadcasting_lock:
            snapshot = list(broadcasted_messages)
        for msg in snapshot:
            if msg in counted:
                continue
            if msg.epoch != epoch or msg.round != round:
                continue
            if process.id not in msg.receivers or msg.read:
                continue
            if round == 3:
                assert msg.message == required_val
            process.round_messages[msg.message] = process.round_messages.get(msg.message, 0) + 1
            counted.add(msg)
        # Every round waits for the quorum. A single pass may see only this
        # process's own message. Unsynchronized rounds would also let processes
        # decide in different epochs, and the late ones would then block forever
        # in the round-3 wait.
        if len(counted) >= quorum:
            break
def get_majority_value(process):
    """Return the first tallied value whose count reaches HALF_PLUS_ONE, else QUESTION_MARK.

    ``process.round_messages`` ({value: count}) is scanned in insertion order.
    """
    return next(
        (value for value, count in process.round_messages.items() if count >= HALF_PLUS_ONE),
        QUESTION_MARK,
    )
def quantum_coin_flip():
    """Placeholder coin: the string "1" repeated n times.

    Deterministic, with no randomness. NOTE(review): presumably a stand-in
    for a real shared coin; confirm before relying on it.
    """
    return "1" * n

In [None]:
def agreement(process):
    """Run the epoch loop for ``process`` and return its decided value.

    Each epoch has three rounds:
      1. broadcast ``current``; keep it only if it reaches a strict majority,
         otherwise ``current`` becomes QUESTION_MARK;
      2. broadcast the round-1 outcome and tally the concrete (non-"?") values;
      3. broadcast WAITING_MESSAGE and wait for the other processes.
    Then ``current`` is set:
      - to a value seen at least HALF_PLUS_ONE times in round 2 (and the
        process is marked as decided);
      - otherwise to any value seen at least once;
      - otherwise to the coin.
    A decided process runs one more full epoch before stopping.

    Side effects: sets ``process.output`` and prints the decision.
    """
    current = process.input_val
    decided = False  # set once a value was seen HALF_PLUS_ONE times in round 2
    epoch = 0
    while True:
        epoch += 1

        # Round 1: propose `current`; it survives only with a strict majority.
        broadcast(process.id, epoch, 1, current)
        receive(process, epoch, 1)
        current = get_majority_value(process)
        process.round_messages.clear()  # reuse the tally for round 2

        # Round 2: exchange round-1 outcomes. Only concrete values count here.
        # If "?" were allowed to win the max, a value that was actually
        # received would be discarded in favour of the coin.
        broadcast(process.id, epoch, 2, current)
        receive(process, epoch, 2)
        candidates = {
            value: count
            for value, count in process.round_messages.items()
            if value != QUESTION_MARK
        }
        answer = None
        number = 0
        if candidates:
            answer = max(candidates, key=candidates.get)
            number = candidates[answer]

        # Round 3: synchronisation step before the next epoch.
        broadcast(process.id, epoch, 3, WAITING_MESSAGE)
        receive(process, epoch, 3, WAITING_MESSAGE)
        process.round_messages.clear()

        coin = quantum_coin_flip()

        if decided:
            break

        if number >= HALF_PLUS_ONE:
            current = answer
            decided = True
        elif number >= 1:
            current = answer
        else:
            current = coin
    process.output = current
    print("Decision for ", process.id, " = ", process.output, " | ", current)
    return current

In [None]:
# One thread per process; process `pid` starts with input value `pid`.
processes = [Process(pid, pid) for pid in range(n)]

# `args` must be a 1-tuple: `(process)` is just `process`, and Thread would call
# agreement(*process). Loop variables are deliberately not named `t`, because
# that would overwrite the global `t` that receive() uses in `n - t`.
threads = [threading.Thread(target=agreement, args=(process,)) for process in processes]

for thread in threads:
    thread.start()

for thread in threads:
    thread.join()

[process.output for process in processes]