In [None]:
#!pip install ipynb

In [None]:
import numpy as np
import threading
import random
from protocol_tests import test_all

import sys
# Evict any cached copy of each project module immediately before importing it,
# so re-running this cell re-executes the module source instead of reusing a
# stale object from sys.modules. Each module is evicted right before its own
# import (coin-flip notebook first, then `globals`).
sys.modules.pop('ipynb.fs.full.weak_global_coin', None)
import ipynb.fs.full.weak_global_coin as QCF

# NOTE(review): `globals` presumably provides n, Message and MAX_ALIVE_PROCESSES
# used further down — confirm; a star import hides where those names come from.
sys.modules.pop('globals', None)
from globals import *

In [None]:
# Shared log of every BroadcastMessage sent by any process; guarded by broadcasting_lock.
broadcasted_messages = []
# Id of the first process to decide; written once under decision_lock.
first_to_decide = None

QUESTION_MARK = "?"            # returned by get_majority_value when no value has a majority
WAITING_MESSAGE = "waiting"    # fixed payload broadcast (and required) in round 3
HALF_PLUS_ONE = n // 2 + 1     # smallest strict majority of the n processes

broadcasting_lock = threading.Lock()   # protects broadcasted_messages and each msg.read
decision_lock = threading.Lock()       # protects first_to_decide

In [None]:
class Process:
    """Per-participant state for the agreement protocol.

    Args:
        id: process index, also used to index BroadcastMessage.read.
        input_val: the process's initial proposal.
    """

    def __init__(self, id, input_val) -> None:
        self.id = id
        self.input_val = input_val
        # value -> number of times it was received in the current round;
        # agreement() clears it after every round.
        self.round_messages = {}
        # Final value written by agreement(); None until that returns.
        self.output = None
        # Epoch in which agreement() decided or stopped; None until then.
        self.decision_epoch = None
        # agreement() stops early and waiting_condition() stops counting this
        # process once this is False. NOTE(review): nothing in this notebook
        # clears it; presumably the coin-flip module does — confirm.
        self.non_faulty = True

    def __str__(self):
        return "id: {} | round_messages: {}".format(self.id, self.round_messages)
    
class BroadcastMessage(Message):
    """A protocol message tagged with the epoch and round it belongs to.

    Sender and receivers are handled by the Message base class (defined outside
    this notebook; `self.sender` is assumed to be set there — confirm).
    `read[i]` is True once process i has counted this message in receive().
    """

    def __init__(self, process_id, receivers, epoch, round, message) -> None:
        super().__init__(process_id, receivers)
        self.epoch, self.round, self.message = epoch, round, message
        # One "already consumed" flag per process id 0..n-1.
        self.read = [False] * n

    def __str__(self):
        return "sender: {} | epoch: {} | round: {} | message: {}".format(
            self.sender, self.epoch, self.round, self.message)

In [None]:
def broadcast(process_id, epoch, round, message):
    """Send `message` from `process_id` to all n processes for (epoch, round).

    Appends a BroadcastMessage to the shared `broadcasted_messages` log, where
    receive() picks it up, and prints a debug line. Both happen while holding
    `broadcasting_lock`. The `with` statement guarantees the lock is released
    even if the append or the print raises, so one failing thread cannot
    deadlock the others.
    """
    new_msg = BroadcastMessage(process_id, list(range(n)), epoch, round, message)
    with broadcasting_lock:
        broadcasted_messages.append(new_msg)
        print("process ", process_id, "adding =>", new_msg)

def waiting_condition(num_received_messages, round, process_id):
    """Return True while a receiver still has to wait for more messages.

    Rounds 1 and 2 wait for one message from every process whose
    `non_faulty` flag is currently set. Round 3 waits for
    MAX_ALIVE_PROCESSES messages.

    Args:
        num_received_messages: messages counted so far in this round.
        round: protocol round, 1, 2 or 3.
        process_id: unused; kept only for debugging (TO-DO: remove).

    Raises:
        ValueError: for any other round. Falling through would return None,
            which receive() would treat as "stop waiting" and silently skip
            the round.
    """
    if round in (1, 2):
        actual_alive_processes = sum(1 for pr in processes if pr.non_faulty)
        return num_received_messages < actual_alive_processes
    if round == 3:
        return num_received_messages < MAX_ALIVE_PROCESSES
    raise ValueError(f"unsupported round: {round!r}")

def receive(process, epoch, round, required_val=None):
    """Tally incoming messages for (epoch, round) until waiting_condition is met.

    Repeatedly scans the shared log (busy-waiting, with no sleep) for unread
    messages addressed to `process` with a matching epoch and round. Each
    payload is counted into `process.round_messages`, and the message is marked
    read so it is counted only once.

    Args:
        process: receiving Process; its `round_messages` is updated in place.
        epoch: epoch tag to accept.
        round: round tag to accept.
        required_val: in round 3, every accepted payload must equal this
            (checked with `assert`).
    """
    num_received_messages = 0
    while waiting_condition(num_received_messages, round, process.id):
        # `with` releases the lock even if the round-3 assert fails. A bare
        # acquire()/release() would leave it held and deadlock every other thread.
        with broadcasting_lock:
            for msg in broadcasted_messages:
                if (msg.epoch == epoch and msg.round == round
                        and process.id in msg.receivers
                        and not msg.read[process.id]):
                    if round == 3:
                        assert msg.message == required_val
                    process.round_messages[msg.message] = (
                        process.round_messages.get(msg.message, 0) + 1)
                    num_received_messages += 1
                    msg.read[process.id] = True

def get_majority_value(process):
    """Return the first tallied value seen at least HALF_PLUS_ONE times.

    Scans `process.round_messages` (value -> count) in insertion order and
    returns QUESTION_MARK when no value reaches the majority threshold.
    """
    majority_values = (val for val, cnt in process.round_messages.items()
                       if cnt >= HALF_PLUS_ONE)
    return next(majority_values, QUESTION_MARK)

def get_most_frequent_val(process):
    """Return (value, count) for the most frequent non-QUESTION_MARK value.

    QUESTION_MARK entries in `process.round_messages` are ignored. Ties are
    broken by insertion order, which matches max(). Returns (None, 0) when
    there is no other value, including when the tally is empty; max() would
    otherwise raise ValueError on an empty sequence.
    `process.round_messages` is left unchanged.
    """
    candidates = {val: cnt for val, cnt in process.round_messages.items()
                  if val != QUESTION_MARK}
    if not candidates:
        return None, 0
    best = max(candidates, key=candidates.get)
    return best, candidates[best]

In [None]:
def agreement(process):
    """Run the agreement protocol for one process; return its final value.

    Each epoch consists of three broadcast rounds followed by
    QCF.quantum_coin_flip:
      1. broadcast `current`, then adopt the majority value (or QUESTION_MARK);
      2. broadcast that, then take the most frequent non-QUESTION_MARK answer;
      3. broadcast WAITING_MESSAGE and wait for the round-3 messages.

    Outcome of the epoch:
      - If the round-2 answer was seen at least HALF_PLUS_ONE times, the process
        decides on it. It records `decision_epoch` and, if it is the first
        process to decide, sets `first_to_decide`. It then joins one more epoch,
        broadcasting but not receiving, before it stops.
      - If the answer was seen at least once, the process adopts it.
      - Otherwise it adopts the coin value.

    The loop also stops, recording `decision_epoch`, once `process.non_faulty`
    is False. The result is stored in `process.output` and returned.
    """
    global first_to_decide
    current = process.input_val
    decided = False     # was `next`, which shadowed the builtin
    epoch = 0
    while True:
        epoch += 1

        broadcast(process.id, epoch, 1, current)
        if not decided:
            receive(process, epoch, 1)
            current = get_majority_value(process)
            print("pr. ", process.id, "current (r: 1 e: ", epoch, ")", current)
        # Reset the tally so round_messages can be reused for the next round's counts.
        process.round_messages.clear()

        broadcast(process.id, epoch, 2, current)
        if not decided:
            receive(process, epoch, 2)
            answer, number = get_most_frequent_val(process)
            print("pr. ", process.id, "e: ", epoch,  " ans: ", answer, "num:", number)
        process.round_messages.clear()

        broadcast(process.id, epoch, 3, WAITING_MESSAGE)
        if not decided:
            receive(process, epoch, 3, WAITING_MESSAGE)
        process.round_messages.clear()

        # Called every epoch, including the final post-decision one.
        coin = QCF.quantum_coin_flip(processes, process, epoch)
        print("pr. ", process.id, "coin: ", coin)

        if decided:
            break

        if number >= HALF_PLUS_ONE:
            print("current is more than HALF_PLUS_ONE in pr. ", process.id)
            current = answer
            decided = True
            process.decision_epoch = epoch
            with decision_lock:
                if first_to_decide is None:
                    first_to_decide = process.id
        elif number >= 1:
            current = answer
        else:
            current = coin
            print("current changed to coin pr. ", process.id)
        print("pr. ", process.id, "current at end of e", epoch, ": ", current)

        if not process.non_faulty:
            process.decision_epoch = epoch
            break
    process.output = current
    return current

In [None]:
# Reset the shared protocol state first. Re-running only this cell would
# otherwise mix the previous run's messages into the log and keep its decider.
# NOTE(review): any state kept inside the coin-flip module is not reset here.
broadcasted_messages.clear()
first_to_decide = None

threads = []
processes = []

for i in range(0, n):
    # Random binary input. Alternatives used for manual testing:
    #   Process(i, str(i%2))  -> alternating inputs
    #   Process(i, "1")       -> unanimous input
    pr = Process(i, str(random.choice([0,1])))
    processes.append(pr)
    print("pr.id: ", pr.id, " input: ", pr.input_val)
    thr = threading.Thread(target=agreement, args=(pr,))
    threads.append(thr)

# Start threads only after `processes` is complete; waiting_condition reads it.
for thr in threads:
    thr.start()

for thr in threads:
    thr.join()

print("----------Final broadcasted_msgs list----------")
for msg in broadcasted_messages:
    print(msg)

print("*******   SOLUTION:   *******")
for pr in processes:
    print("process(", pr.id, ") = ", pr.output, " @epoch: ", pr.decision_epoch, " | input: ", pr.input_val, " | faulty: ", not pr.non_faulty)

In [None]:
# Pass this run's processes, the id of the first process to decide, and the
# full broadcast log to protocol_tests.test_all.
# NOTE(review): test_all is defined outside this notebook; presumably it checks
# protocol properties such as agreement — confirm against protocol_tests.
test_all(processes, first_to_decide, broadcasted_messages)