In [None]:
#!pip install qiskit
#!pip install qiskit[visualization]
#!pip install waiting

In [None]:
from qiskit import QuantumCircuit
from qiskit_aer import AerSimulator
import numpy as np
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager
from qiskit_ibm_runtime import SamplerV2

from threading import Thread, Lock, Condition

### Initial parameters

In [None]:
n = 4   # number of processes (the driver cell starts one thread per process)
t = 0   # processes left out of the quorum: each process waits for n - t messages
n_cube = n ** 3

# Qubits per process register: enough to hold a leader value in [0, n^3),
# i.e. ceil(log2(n^3)).
qb_per_process = int(np.ceil(np.log2(n_cube)))

# Every process blocks until at least this many coin / leader messages are published.
MAX_ALIVE_PROCESSES = n - t

### Quantum Factory

In [None]:
class QuantumFactory:
    """Build the two template circuits used by the protocol and hand out copies.

    * coin: an ``n``-qubit circuit with H on qubit 0 followed by a CNOT fan-out to
      every other qubit, so all qubits measure to the same random bit.
    * leader: ``n`` registers of ``qb_per_process`` qubits each. Register 0 gets H
      on every qubit and is then copied into each other register with CNOTs, so
      all registers measure to the same random integer.

    Reads the notebook-level ``n`` and ``qb_per_process`` when the circuits are built.
    """

    def __init__(self) -> None:
        self.coin = None
        self.leader = None
        self.generate_coin_circuit()
        self.generate_leader_circuit()

    def generate_coin_circuit(self):
        """Build the coin template circuit and store it in ``self.coin``."""
        circuit = QuantumCircuit(n)
        circuit.h(0)
        # A single cx call broadcasts over the target range: control 0 -> each of 1..n-1.
        circuit.cx(0, range(1, n))
        circuit.measure_all()
        self.coin = circuit

    def generate_leader_circuit(self):
        """Build the leader template circuit and store it in ``self.leader``."""
        circuit = QuantumCircuit(n * qb_per_process)
        circuit.h(range(qb_per_process))

        # Copy register 0 into register k, qubit by qubit, for every k >= 1.
        for register in range(1, n):
            offset = register * qb_per_process
            for qubit in range(qb_per_process):
                circuit.cx(qubit, offset + qubit)
        circuit.measure_all()
        self.leader = circuit

    def get_coin_circuit(self) -> QuantumCircuit:
        """Return a fresh copy of the coin template (callers may mutate it freely)."""
        return self.coin.copy()

    def get_leader_circuit(self) -> QuantumCircuit:
        """Return a fresh copy of the leader template (callers may mutate it freely)."""
        return self.leader.copy()

In [None]:
class Circuit():
    """Wrap one circuit carried by a ``Message`` and cache its single-shot measurement.

    Attributes:
        system: the circuit to sample.
        result: sampler result of the last ``measure_circuit`` call, or None.
        memory: the measured bitstring, or None before measurement. Qiskit orders
            the string little-endian: the rightmost character is qubit 0.
        measured: True once ``measure_circuit`` has completed.
    """
    def __init__(self, system) -> None:
        self.system = system
        self.result = None
        self.memory = None
        self.measured = False

    def measure_circuit(self):
        """Sample ``system`` once on a local stabilizer simulator.

        Stores the sampler result in ``self.result`` and the outcome bitstring in
        ``self.memory``, then sets ``self.measured``. Every call re-samples;
        callers that want a single shared outcome should check ``measured`` first.

        Returns:
            The measured bitstring (also stored in ``self.memory``).
        """
        aer_sim = AerSimulator(method="stabilizer")
        pm = generate_preset_pass_manager(backend=aer_sim, optimization_level=1)
        isa_qc = pm.run(self.system)
        # Pass the backend positionally: the first parameter is named `backend` in
        # older qiskit-ibm-runtime releases and `mode` in newer ones.
        sampler = SamplerV2(aer_sim)

        result = sampler.run([isa_qc], shots=1).result()
        counts = result[0].data.meas.get_counts()
        # shots=1, so the counts dict holds exactly one bitstring.
        self.memory = next(iter(counts))
        self.result = result
        self.measured = True
        return self.memory

### Send messages between processes
Each `Message` has three fields:

- `sender` → id of the sending process
- `receivers` → list of ids of the processes meant to receive the message
- `system` → the payload (i.e. a circuit) sent from the sender to the receiver(s)

In [None]:
class Message:
    """A message published by one process.

    Args:
        sender: id of the sending process.
        receivers: ids of the intended recipients (the senders here pass every
            process id, ``list(range(n))``).
        system: the circuit carried by the message. It is wrapped in a ``Circuit``
            so its measurement outcome is cached and shared by all readers.
    """
    def __init__(self, sender, receivers, system) -> None:
        self.sender, self.receivers = sender, receivers
        self.circuit = Circuit(system)

    def __str__(self) -> str:
        return f"sender: {self.sender} | measured: {self.circuit.measured}"

### Global variables

In [None]:
# Shared state for all process threads.
quantum_factory = QuantumFactory()   # source of the coin / leader template circuits

# Published messages; send_coin / send_leader append to these under the matching lock.
coin_msgs = []
leader_msgs = []

# Worker threads created and started by the driver cell.
threads = []

# Locks guarding the message lists.
coin_lock, leader_lock = Lock(), Lock()

# Conditions used to wake waiting processes once enough messages are published.
coin_condition, leader_condition = Condition(), Condition()

### Acquire and publish coin state (for each process)

In [None]:
def send_coin(process_id, coin_condition):
    """Publish this process's coin circuit and wake waiters once enough coins exist.

    Takes a fresh copy of the coin circuit from ``quantum_factory`` and wraps it in
    a ``Message`` addressed to every process. The message is appended to the shared
    ``coin_msgs`` list under ``coin_lock``. If at least ``MAX_ALIVE_PROCESSES`` coin
    messages are now published, all threads waiting on ``coin_condition`` are notified.

    Args:
        process_id: id of the sending process.
        coin_condition: ``threading.Condition`` that waiters block on until enough
            coins are published (the notebook passes the global ``coin_condition``).
    """
    coin_qc = quantum_factory.get_coin_circuit()
    # `with` releases the lock even if building the Message raises.
    with coin_lock:
        coin_msgs.append(Message(process_id, list(range(n)), coin_qc))

    # Entering `with` already holds the condition's lock. A second acquire() here
    # would deadlock if the condition wraps a plain (non-reentrant) Lock.
    with coin_condition:
        if len(coin_msgs) >= MAX_ALIVE_PROCESSES:
            coin_condition.notify_all()

### Acquire and publish leader state (for each process)

In [None]:
def send_leader(process_id, leader_condition):
    """Publish this process's leader circuit and wake waiters once enough exist.

    Takes a fresh copy of the leader circuit from ``quantum_factory`` and wraps it in
    a ``Message`` addressed to every process. The message is appended to the shared
    ``leader_msgs`` list under ``leader_lock``. If at least ``MAX_ALIVE_PROCESSES``
    leader messages are now published, all threads waiting on ``leader_condition``
    are notified.

    Args:
        process_id: id of the sending process.
        leader_condition: ``threading.Condition`` that waiters block on until enough
            leader messages are published (the notebook passes the global
            ``leader_condition``).
    """
    leader_qc = quantum_factory.get_leader_circuit()
    # `with` releases the lock even if building the Message raises.
    with leader_lock:
        leader_msgs.append(Message(process_id, list(range(n)), leader_qc))

    # Entering `with` already holds the condition's lock. A second acquire() here
    # would deadlock if the condition wraps a plain (non-reentrant) Lock.
    with leader_condition:
        if len(leader_msgs) >= MAX_ALIVE_PROCESSES:
            leader_condition.notify_all()

### QuantumCoinFlip

In [None]:
def get_highest_leader_id():
    """Return the id of the elected leader among the published leader messages.

    Each leader message's circuit is measured at most once; the result is cached via
    ``circuit.measured``, so every caller sees the same outcome. The first
    ``qb_per_process`` characters of the bitstring are read as an unsigned binary
    integer. The sender with the largest value wins, and ties go to the smallest
    process id.

    Returns:
        The process id of the leader.

    Raises:
        AssertionError: if fewer than ``MAX_ALIVE_PROCESSES`` leader messages exist.
    """
    # `with` guarantees leader_lock is released even when the assertion fails or a
    # measurement raises; a leaked lock would deadlock every other process thread.
    with leader_lock:
        assert len(leader_msgs) >= MAX_ALIVE_PROCESSES
        outcomes = []  # (leader outcome, sender id) per published message
        for msg in leader_msgs:
            if not msg.circuit.measured:
                msg.circuit.measure_circuit()
                msg.circuit.measured = True
            # Qiskit puts the highest-index qubit first, so this slice is the last
            # register. The leader circuit copies register 0 into every register,
            # so any register yields the same value.
            outcome = int(msg.circuit.memory[:qb_per_process], 2)
            outcomes.append((outcome, msg.sender))

    highest = max(outcome for outcome, _ in outcomes)
    return min(sender for outcome, sender in outcomes if outcome == highest)

In [None]:
def old_quantum_coin_flip():
    """Sequential (single-thread) variant of the coin flip.

    Publishes a coin and a leader circuit for every process id in ``range(n)`` from
    the calling thread, elects the leader with ``get_highest_leader_id`` (which
    measures the leader circuits), then measures the leader's coin circuit.

    It uses the same global message lists as the threaded ``quantum_coin_flip``, so
    they should be cleared before mixing the two.

    Returns:
        The leader's full coin bitstring, or None if no coin message from the leader
        is found.
    """
    for i in range(n):
        # send_* require the condition to notify; pass the notebook-level ones.
        send_coin(i, coin_condition)
        send_leader(i, leader_condition)

    # get_highest_leader_id measures each leader circuit once and caches the
    # outcome. Measuring here beforehand would only be overwritten by a re-sample.
    leader_process_id = get_highest_leader_id()

    leader_coin = None
    for msg in coin_msgs:
        if msg.sender == leader_process_id:
            if not msg.circuit.measured:
                msg.circuit.measure_circuit()
                msg.circuit.measured = True
            leader_coin = msg.circuit.memory
    return leader_coin

In [None]:
def quantum_coin_flip(process):
    """Run one process's side of the quantum coin flip and return its coin bit.

    Steps:
      1. Publish this process's coin and leader circuits.
      2. Wait until at least ``MAX_ALIVE_PROCESSES`` leader messages exist, then
         elect the leader with ``get_highest_leader_id``.
      3. Wait until at least ``MAX_ALIVE_PROCESSES`` coin messages exist.
      4. Measure the leader's coin circuit. The ``measured`` flag ensures it is
         sampled only once and shared by all processes. Take the character at
         index ``process.id``.

    Side effects: sets ``process.coin_output`` and prints the bit.

    Args:
        process: object with an integer ``id`` (used as index into the coin
            bitstring) and a writable ``coin_output`` attribute.

    Returns:
        The selected character of the leader's measured coin bitstring.

    Raises:
        RuntimeError: if no coin message from the elected leader is published.
    """
    send_coin(process.id, coin_condition)
    send_leader(process.id, leader_condition)

    # wait_for re-checks the predicate after every wake-up. The `with` block
    # already holds the condition's lock, so no extra acquire/release is needed.
    with leader_condition:
        leader_condition.wait_for(lambda: len(leader_msgs) >= MAX_ALIVE_PROCESSES)

    leader_process_id = get_highest_leader_id()

    with coin_condition:
        coin_condition.wait_for(lambda: len(coin_msgs) >= MAX_ALIVE_PROCESSES)

    # `with` releases coin_lock even if a measurement raises; otherwise every other
    # process would block here forever.
    with coin_lock:
        leader_coin = None
        for msg in coin_msgs:
            if msg.sender == leader_process_id:
                if not msg.circuit.measured:
                    msg.circuit.measure_circuit()
                    msg.circuit.measured = True
                leader_coin = msg.circuit.memory
        if leader_coin is None:
            raise RuntimeError(f"no coin message from leader process {leader_process_id}")
        coin_bit = leader_coin[process.id]
        process.coin_output = coin_bit
        print("Process(", process.id, "): ", coin_bit)
    return coin_bit

In [None]:
class Process:
    """One protocol participant, identified by ``id`` (also its thread's index)."""

    def __init__(self, id) -> None:
        self.id = id
        # Filled in by quantum_coin_flip with this process's coin bit.
        self.coin_output = None

In [None]:
# Reset shared state so this cell can be re-run without restarting the kernel.
# Stale messages from a previous run would otherwise satisfy the quorum at once,
# and Thread objects that were already started cannot be started again.
coin_msgs.clear()
leader_msgs.clear()
threads.clear()

# One Process object and one worker thread per process id.
processes = []
for i in range(n):
    pr = Process(i)
    processes.append(pr)
    threads.append(Thread(target=quantum_coin_flip, args=(pr,)))

for thr in threads:
    thr.start()

for thr in threads:
    thr.join()

for pr in processes:
    print("Coin for (", pr.id, ") ", pr.coin_output)