In [None]:
#!pip install qiskit
#!pip install qiskit[visualization]

In [None]:
from qiskit import QuantumCircuit
from qiskit_aer import AerSimulator
import numpy as np
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager
from qiskit_ibm_runtime import SamplerV2

from threading import Lock, Condition
from globals import *

import ipynb.fs.full.adversary as ADV

### Initial parameters

In [None]:
# Each process's leader-election register must be able to hold any integer in
# [0, n**3); a large range presumably makes ties between processes unlikely
# (TODO confirm intent).
n_cube = pow(n, 3)

# Number of bits needed for values 0 .. n_cube - 1, i.e. ceil(log2(n_cube)) for
# n_cube >= 1, computed with exact integer arithmetic instead of float log2.
qb_per_process = int(n_cube - 1).bit_length()

### Quantum Factory

In [None]:
class QuantumFactory:
    """Builds the coin and leader-election circuits once and hands out copies.

    Both circuits are sized from the module-level `n` (number of processes) and
    `qb_per_process`, which must be defined before the factory is instantiated.
    """

    def __init__(self) -> None:
        self.coin = None
        self.leader = None
        self.generate_coin_circuit()
        self.generate_leader_circuit()

    def generate_coin_circuit(self):
        """Build the n-qubit coin circuit: H on qubit 0, then CX from qubit 0
        onto every other qubit, followed by measure_all()."""
        circuit = QuantumCircuit(n)
        circuit.h(0)
        circuit.cx(0, range(1, n))
        circuit.measure_all()
        self.coin = circuit
        print("generated coin circuit")

    def generate_leader_circuit(self):
        """Build the leader circuit: n blocks of `qb_per_process` qubits.

        Block 0 is put into uniform superposition with H gates, and each qubit of
        block 0 is CX-ed onto the matching qubit of every other block.
        """
        width = qb_per_process
        circuit = QuantumCircuit(n * width)
        circuit.h(range(width))

        for block in range(1, n):
            offset = block * width
            for bit in range(width):
                circuit.cx(bit, offset + bit)
        circuit.measure_all()
        self.leader = circuit
        print("generated leader circuit")

    def get_coin_circuit(self) -> QuantumCircuit:
        """Return an independent copy of the cached coin circuit."""
        return self.coin.copy()

    def get_leader_circuit(self) -> QuantumCircuit:
        """Return an independent copy of the cached leader circuit."""
        return self.leader.copy()

In [None]:
class Circuit:
    """Holds a quantum circuit together with the outcome of running it once.

    Attributes:
        system: the circuit to execute.
        result: counts dict from the single-shot run, or None before measuring.
        memory: the single measured bitstring, or None before measuring.
        measured: True once measure_circuit() has completed. Callers check this
            flag so each circuit is sampled only once.
    """

    def __init__(self, system) -> None:
        self.system = system
        self.result = None
        self.memory = None
        self.measured = False

    def measure_circuit(self):
        """Run `system` for one shot on a stabilizer AerSimulator and store the outcome.

        The circuits built in this notebook use only H, CX and measurements, which
        the stabilizer method supports.
        """
        aer_sim = AerSimulator(method="stabilizer")
        pm = generate_preset_pass_manager(backend=aer_sim, optimization_level=1)
        isa_qc = pm.run(self.system)
        sampler = SamplerV2(backend=aer_sim)

        pub_result = sampler.run([isa_qc], shots=1).result()[0]
        # measure_all() stores its bits in a classical register named "meas".
        counts = pub_result.data.meas.get_counts()
        self.result = counts
        # With shots=1 exactly one bitstring was observed.
        self.memory = next(iter(counts))
        self.measured = True

In [None]:
class CircuitMessage(Message):
    """A Message whose payload is a quantum circuit wrapped in a Circuit."""

    def __init__(self, process_id, receivers, system) -> None:
        super().__init__(process_id, receivers)
        # Wrap the raw circuit so its measurement state travels with the message.
        self.circuit = Circuit(system)

    def __str__(self) -> str:
        circ = self.circuit
        return (
            f"sender: {self.sender} | receivers: {self.receivers} | "
            f"measured: {circ.measured} | result: {circ.result}"
        )

### Global variables

In [None]:
# Circuit templates shared by every process thread.
quantum_factory = QuantumFactory()

# Per-epoch message boards: index epoch-1 holds the CircuitMessages of that epoch.
coin_msgs, leader_msgs = [], []

# Per-epoch rows of per-process counters, guarded by msg_quantity_lock.
waiting_num_of_msgs = []
msg_quantity_lock = Lock()

# Notified once enough coin / leader messages have been posted for an epoch.
coin_condition, leader_condition = Condition(), Condition()

In [None]:
def get_msgs_for_process(process_id, curr_leader_msgs=None, curr_coin_msgs=None):
    """Count the messages in one epoch's board that list `process_id` as a receiver.

    If `curr_leader_msgs` is non-empty it is counted under `leader_lock`.
    Otherwise, if `curr_coin_msgs` is non-empty, it is counted under `coin_lock`.
    Otherwise the count is 0.

    Args:
        process_id: id of the process whose incoming messages are counted.
        curr_leader_msgs: list of leader CircuitMessages for one epoch.
        curr_coin_msgs: list of coin CircuitMessages for one epoch.

    Returns:
        Number of messages whose `receivers` contain `process_id`.
    """
    if curr_leader_msgs:
        lock, msgs = leader_lock, curr_leader_msgs
    elif curr_coin_msgs:
        lock, msgs = coin_lock, curr_coin_msgs
    else:
        return 0

    # `with` releases the lock even if inspecting a message raises, e.g. when an
    # adversary has replaced `receivers` with a non-container.
    with lock:
        return sum(1 for msg in msgs if process_id in msg.receivers)

### Acquire coin state (for each process)

In [None]:
def send_coin(processes, process_id, coin_condition, epoch):
    """Post this process's coin circuit for `epoch`, addressed to all `n` processes.

    Appends a new CircuitMessage to `coin_msgs[epoch-1]`. Once the epoch has at
    least MAX_ALIVE_PROCESSES coin messages, it notifies every thread waiting on
    `coin_condition`.

    Args:
        processes: not used here; kept for call-site compatibility.
        process_id: id of the sending process.
        coin_condition: condition that waiting processes block on.
        epoch: 1-based epoch number.
    """
    coin_qc = quantum_factory.get_coin_circuit()
    with coin_lock:
        new_msg = CircuitMessage(process_id, list(range(n)), coin_qc)
        # `while` (not `if`) so a skipped epoch cannot make the index below
        # raise IndexError while the lock is held.
        while len(coin_msgs) < epoch:
            coin_msgs.append([])
        coin_msgs[epoch - 1].append(new_msg)
        print("adding coin =>", new_msg)

    # The `with` statement already owns the condition's lock; a second
    # acquire() would deadlock if the condition wrapped a non-reentrant Lock.
    with coin_condition:
        if len(coin_msgs[epoch - 1]) >= MAX_ALIVE_PROCESSES:
            print("notifying coin_condition pr.", process_id)
            coin_condition.notify_all()

### Acquire leader state (for each process)

In [None]:
def send_leader(processes, process_id, leader_condition, epoch):
    """Post this process's leader circuit for `epoch`, addressed to all `n` processes.

    Appends a new CircuitMessage to `leader_msgs[epoch-1]`. Once the epoch has at
    least MAX_ALIVE_PROCESSES leader messages, it notifies every thread waiting on
    `leader_condition`.

    Args:
        processes: not used here; kept for call-site compatibility.
        process_id: id of the sending process.
        leader_condition: condition that waiting processes block on.
        epoch: 1-based epoch number.
    """
    leader_qc = quantum_factory.get_leader_circuit()
    with leader_lock:
        new_msg = CircuitMessage(process_id, list(range(n)), leader_qc)
        # `while` (not `if`) so a skipped epoch cannot make the index below
        # raise IndexError while the lock is held.
        while len(leader_msgs) < epoch:
            leader_msgs.append([])
        leader_msgs[epoch - 1].append(new_msg)
        print("adding leader =>", new_msg)

    # The `with` statement already owns the condition's lock; a second
    # acquire() would deadlock if the condition wrapped a non-reentrant Lock.
    with leader_condition:
        if len(leader_msgs[epoch - 1]) >= MAX_ALIVE_PROCESSES:
            print("notifying leader_condition pr.", process_id)
            leader_condition.notify_all()

### Per-process message counters (one row per epoch)

In [None]:
def init_waiting_num_of_msgs(epoch):
    """Ensure `waiting_num_of_msgs` has a row of `n` per-process counters for `epoch`.

    - Epoch 1 starts every counter at `n`.
    - Each later epoch starts every counter at the minimum of the previous
      epoch's row.
    - Rows that already exist are left untouched, so concurrent callers for the
      same epoch initialise it only once.

    Args:
        epoch: 1-based epoch number.
    """
    with msg_quantity_lock:
        # Filling in a loop also covers skipped epochs. The original indexed
        # [epoch-2] directly and raised IndexError with the lock still held.
        # For sequential epochs the result is identical.
        while len(waiting_num_of_msgs) < epoch:
            if waiting_num_of_msgs:
                start = min(waiting_num_of_msgs[-1])
            else:
                start = n
            waiting_num_of_msgs.append([start] * n)
        print("waiting_num_of_msgs: ", waiting_num_of_msgs)
    
def update_waiting_num_of_msgs(processes, epoch, new_receivers):
    """Decrement the `epoch` counter of every process whose id is not in `new_receivers`.

    Args:
        processes: iterable of process objects exposing an integer `.id`.
        epoch: 1-based epoch number; its row must already exist.
        new_receivers: container of process ids returned by the adversary.
    """
    # `with` releases the lock even if an id is out of range or
    # `new_receivers` is not a container. Otherwise every later caller would
    # deadlock.
    with msg_quantity_lock:
        counters = waiting_num_of_msgs[epoch - 1]
        for pr in processes:
            if pr.id not in new_receivers:
                counters[pr.id] -= 1
        print(counters)

### Quantum coin flip

In [None]:
def get_highest_leader_id(process, epoch):
    """Elect the leader for `epoch` as seen by `process`.

    Each leader message addressed to `process` is measured if it has not been
    already, and its outcome is decoded as an integer. The sender with the
    largest outcome wins; ties go to the smallest sender id.

    Args:
        process: object exposing an integer `.id`.
        epoch: 1-based epoch number.

    Returns:
        The id of the elected leader process.

    Raises:
        AssertionError: fewer than MAX_ALIVE_PROCESSES leader messages exist.
        ValueError: no leader message is addressed to `process`.
    """
    # `with` guarantees leader_lock is released on the assert/ValueError paths.
    # Otherwise every other thread would block forever.
    with leader_lock:
        curr_leader_msgs = leader_msgs[epoch - 1]
        assert len(curr_leader_msgs) >= MAX_ALIVE_PROCESSES

        leader_measurements = {}  # outcome -> sender ids, in arrival order
        for msg in curr_leader_msgs:
            if process.id not in msg.receivers:
                continue
            if not msg.circuit.measured:
                msg.circuit.measure_circuit()
                msg.circuit.measured = True
            # Qiskit bitstrings put the highest-index bit leftmost, so this prefix
            # is the last qb_per_process-qubit block. The leader circuit entangles
            # every block with block 0, so every block measures the same value.
            leader_outcome = int(msg.circuit.memory[:qb_per_process], 2)
            leader_measurements.setdefault(leader_outcome, []).append(msg.sender)

        print("pr. ", process.id, "LEADER MEASUREMENTS e: ", epoch)
        print(leader_measurements)
        if not leader_measurements:
            raise ValueError(
                f"process {process.id} received no leader messages in epoch {epoch}"
            )
        highest_leader_outcome = max(leader_measurements)
        leader_id = min(leader_measurements[highest_leader_outcome])
        print("pr. ", process.id, "highest_leader_val: ", highest_leader_outcome, " id: ", leader_id)
    return leader_id

In [None]:
def get_coin_result(sender_id, curr_coin_msgs):
    """Return the measured bitstring of the coin message sent by `sender_id`.

    The first matching message's circuit is measured if it has not been already.

    Args:
        sender_id: id of the process whose coin is requested.
        curr_coin_msgs: list of coin CircuitMessages for one epoch.

    Returns:
        The coin circuit's measured bitstring.

    Raises:
        AssertionError: no message from `sender_id` is present.
    """
    memory_val = None
    # `with` releases coin_lock even if the simulator raises during measurement.
    with coin_lock:
        for msg in curr_coin_msgs:
            if msg.sender == sender_id:
                if not msg.circuit.measured:
                    msg.circuit.measure_circuit()
                    msg.circuit.measured = True
                memory_val = msg.circuit.memory
                break
    assert memory_val is not None, f"no coin message from sender {sender_id}"
    return memory_val

In [None]:
def print_after_adversary(process, epoch):
    """Print this epoch's leader board, then its coin board, for `process`.

    Each board is printed while holding its own lock. `with` ensures the lock is
    released even if formatting a message raises.
    """
    boards = (
        (leader_lock, leader_msgs, " leader msgs (after adversary intervened)"),
        (coin_lock, coin_msgs, " coin msgs (after adversary intervened)"),
    )
    for board_lock, board, heading in boards:
        with board_lock:
            print("pr. ", process.id, heading)
            for msg in board[epoch - 1]:
                print(msg)

In [None]:
def quantum_coin_flip(processes, process, epoch):
    """Run one epoch of the quantum coin flip for `process`.

    The process posts its coin and leader circuits, then the adversary acts.

    - Faulty processes (`process.non_faulty` is falsy) return the bit of their
      own coin.
    - Non-faulty processes wait until MAX_ALIVE_PROCESSES leader messages are
      addressed to them, and elect a leader with get_highest_leader_id. They
      then wait for MAX_ALIVE_PROCESSES coin messages and return the leader's
      coin bit.

    Args:
        processes: all process objects (each exposing `.id`).
        process: the process running this call.
        epoch: 1-based epoch number.

    Returns:
        One character of the measured coin bitstring, taken at index `process.id`.
    """
    init_waiting_num_of_msgs(epoch)
    send_coin(processes, process.id, coin_condition, epoch)
    send_leader(processes, process.id, leader_condition, epoch)

    # Adversary's action goes here
    new_receivers = ADV.adversary_take_over(process, coin_msgs[epoch-1], leader_msgs[epoch-1])
    if not process.non_faulty:
        print_after_adversary(process, epoch)

        update_waiting_num_of_msgs(processes, epoch, new_receivers)

        own_coin_val = get_coin_result(process.id, coin_msgs[epoch-1])
        return own_coin_val[process.id]

    # `with` already holds the condition's lock. wait() releases it while
    # blocked and re-acquires it before returning, so no extra
    # acquire()/release() is needed.
    with leader_condition:
        while get_msgs_for_process(process.id, curr_leader_msgs=leader_msgs[epoch-1]) < MAX_ALIVE_PROCESSES:
            print("pr. ", process.id, " waiting other leaders")
            leader_condition.wait()
            print("pr. ", process.id, " finished waiting leaders")

    leader_process_id = get_highest_leader_id(process, epoch)
    print("pr. ", process.id, " leader_process: ", leader_process_id)

    with coin_condition:
        while get_msgs_for_process(process.id, curr_coin_msgs=coin_msgs[epoch-1]) < MAX_ALIVE_PROCESSES:
            print("pr. ", process.id, " waiting other coins")
            coin_condition.wait()
            print("pr. ", process.id, " finished waiting coins")

    leader_coin = get_coin_result(leader_process_id, coin_msgs[epoch-1])
    return leader_coin[process.id]