In [73]:
#!pip install qiskit
#!pip install qiskit[visualization]
#!pip install qiskit-aer  # provides `qiskit_aer` (imported below); not bundled with qiskit

In [74]:
from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector
from qiskit.visualization import plot_histogram
from qiskit_aer import Aer
import numpy as np

### Initial parameters

In [75]:
# Protocol size: number of processes (also the number of coin qubits and leader qudits).
n = 3
n_cube = n ** 3
# Normalisation of an equal superposition over n_cube basis states.
norm = np.sqrt(n_cube)
# Qubits per leader qudit so values up to n_cube fit: floor(log2(n_cube)) + 1.
nr_bits = n_cube.bit_length()
# Dimension of a single leader qudit (2**nr_bits basis states).
dims_value = 2 ** nr_bits

### Quantum Factory

In [76]:
class QuantumFactory:
    """Builds the protocol's two template circuits once and hands out copies.

    * ``coin``: an ``n``-qubit circuit with H on qubit 0 and a CNOT from qubit 0
      to every other qubit (a GHZ-style state), followed by ``measure_all``.
    * ``leader``: an ``n * nr_bits``-qubit circuit initialised to
      (1/norm) * sum_{v=1}^{n_cube} |v>^{(x)n}, followed by ``measure_all``.

    Reads the module-level parameters ``n``, ``n_cube``, ``norm``, ``nr_bits``
    and ``dims_value``.
    """

    def __init__(self) -> None:
        self.coin = None
        self.leader = None
        self.generate_coin_circuit()
        self.generate_leader_circuit()

    def generate_coin_circuit(self):
        """Build the coin circuit and store it in ``self.coin``.

        No figure is drawn here; call ``self.coin.draw('mpl')`` explicitly if a
        picture is wanted (drawing inside the factory only produced a discarded
        figure on every construction).
        """
        qc = QuantumCircuit(n)
        qc.h(0)
        for target in range(1, n):
            # Fan out from qubit 0 so all n qubits share one measurement outcome.
            qc.cx(0, target)
        qc.measure_all()

        self.coin = qc

    def tensor_n_qudits(self, qudit, nr_times):
        """Return ``qudit`` tensored with itself ``nr_times`` times.

        Args:
            qudit: a ``Statevector`` (only ``copy`` and ``tensor`` are used).
            nr_times: number of tensor factors; must be >= 1.

        Returns:
            A new state equal to qudit^{(x) nr_times}. For ``nr_times == 1`` this is
            a copy of ``qudit`` (previously it wrongly returned qudit (x) qudit).

        Raises:
            ValueError: if ``nr_times < 1``.
        """
        if nr_times < 1:
            raise ValueError(f"nr_times must be >= 1, got {nr_times}")
        tensor_vec = qudit.copy()
        for _ in range(nr_times - 1):
            tensor_vec = tensor_vec.tensor(qudit)
        return tensor_vec

    def generate_leader_circuit(self):
        """Build the leader circuit and store it in ``self.leader``.

        Prepares (1/norm) * sum_{v=1}^{n_cube} |v>^{(x)n}, where each |v> is a basis
        state of a ``dims_value``-dimensional qudit, so all ``n`` qudits carry the
        same value v on measurement.
        """
        final_ket = None
        for value in range(1, n_cube + 1):
            qudit = Statevector.from_int(value, dims=dims_value)
            term = (1 / norm) * self.tensor_n_qudits(qudit, n)
            final_ket = term if final_ket is None else final_ket + term

        # n qudits of nr_bits qubits each (dims_value == 2**nr_bits).
        circuit_nr_bits = n * nr_bits
        another_qc = QuantumCircuit(circuit_nr_bits)
        another_qc.initialize(final_ket.data)
        another_qc.measure_all()

        self.leader = another_qc

    def get_coin_circuit(self) -> QuantumCircuit:
        """Return an independent copy of the coin circuit."""
        return self.coin.copy()

    def get_leader_circuit(self) -> QuantumCircuit:
        """Return an independent copy of the leader circuit."""
        return self.leader.copy()


In [77]:
class Circuit:
    """Wraps a circuit and records the outcome of one simulated shot.

    Attributes:
        system: the wrapped circuit (presumably already containing measurements,
            since ``get_memory`` is read after running).
        result: the simulator ``Result`` of the last run, or ``None``.
        memory: the measured bitstring of the last shot, or ``None``.
    """

    def __init__(self, system) -> None:
        self.system = system
        self.result = None
        self.memory = None

    def measure_circuit(self):
        """Run ``self.system`` once on the Aer simulator; store result and memory.

        The statevector snapshot is attached to a copy of the circuit rather than to
        ``self.system`` itself, so repeated calls do not keep appending
        ``save_statevector`` instructions to the wrapped circuit.
        """
        backend = Aer.get_backend('aer_simulator')
        run_qc = self.system.copy()
        run_qc.save_statevector()
        self.result = backend.run(run_qc, memory=True, shots=1).result()
        # shots=1, so memory holds exactly one bitstring.
        self.memory = self.result.get_memory()[0]

### Send messages between processes
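Each `Message` carries:

- `sender` → ID of the sending process
- `receivers` → list of IDs of the processes meant to receive the message
- `circuit` → the quantum circuit sent from the sender to the receiver(s), wrapped in a `Circuit` that stores its measurement outcome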

In [78]:
class Message:
    """A quantum message sent from one process to a group of processes.

    Attributes:
        sender: id of the sending process.
        receivers: ids of the processes meant to receive the message.
        circuit: ``Circuit`` wrapping the transmitted quantum circuit.
    """

    def __init__(self, sender, receivers, system) -> None:
        self.sender, self.receivers = sender, receivers
        self.circuit = Circuit(system)

    def __str__(self):
        sender, receivers = self.sender, self.receivers
        memory = self.circuit.memory
        return f"sender: {sender} | receivers: {receivers} | memory: {memory}"
    

### Global variables

In [79]:
# Shared protocol state: one factory holding the template circuits, plus the
# outboxes that collect the coin and leader Message objects.
quantum_factory = QuantumFactory()
coin_msgs, leader_msgs = [], []

### Generate coin state

In [80]:
def send_coin(process_id):
    """Queue a coin message from ``process_id`` addressed to all ``n`` processes.

    A fresh copy of the factory's coin circuit is wrapped in a ``Message`` and
    appended to the global ``coin_msgs``; nothing is measured here.
    """
    all_processes = list(range(n))
    coin_msgs.append(Message(process_id, all_processes, quantum_factory.get_coin_circuit()))

### Generate Leader state

In [81]:
def send_leader(process_id):
    """Queue a leader message from ``process_id`` addressed to all ``n`` processes.

    A fresh copy of the factory's leader circuit is wrapped in a ``Message`` and
    appended to the global ``leader_msgs``; nothing is measured here.
    """
    all_processes = list(range(n))
    leader_msgs.append(Message(process_id, all_processes, quantum_factory.get_leader_circuit()))

### QuantumCoinFlip

In [82]:
def get_highest_leader(messages=None, leader_width=None):
    """Return the id of the elected leader process.

    Reads the first ``leader_width`` characters of each message's measured
    bitstring (``circuit.memory``) as a binary integer. The process with the
    largest value wins; ties go to the lowest process id.

    Args:
        messages: leader messages to inspect; defaults to the global ``leader_msgs``.
            Each must already have been measured.
        leader_width: number of leading bitstring characters to read; defaults to
            the global ``nr_bits`` (one leader qudit).

    Returns:
        The ``sender`` id of the winning message.

    Raises:
        ValueError: if there are no messages to choose from.
    """
    if messages is None:
        messages = leader_msgs
    if leader_width is None:
        leader_width = nr_bits

    ids_by_value = {}
    for message in messages:
        # Compare outcomes numerically. Keying on only the first character made
        # e.g. "10000" and "11011" tie, so the wrong process could be elected.
        value = int(message.circuit.memory[:leader_width], 2)
        ids_by_value.setdefault(value, []).append(message.sender)

    if not ids_by_value:
        raise ValueError("get_highest_leader() needs at least one measured leader message")

    return min(ids_by_value[max(ids_by_value)])

In [83]:
# Run one round of the protocol. Empty the outboxes first so that re-running this
# cell does not mix in messages (and measurements) left over from a previous run.
coin_msgs.clear()
leader_msgs.clear()

for process_id in range(n):
    send_coin(process_id)
    send_leader(process_id)

# Measure every leader message, then let get_highest_leader() pick the leader.
for leader_msg in leader_msgs:
    leader_msg.circuit.measure_circuit()

leader_process_id = get_highest_leader()

# Only the elected leader's coin message is measured; its bitstring is the coin flip.
leader_coin = None
for coin_msg in coin_msgs:
    if coin_msg.sender == leader_process_id:
        coin_msg.circuit.measure_circuit()
        leader_coin = coin_msg.circuit.memory

leader_process_id, leader_coin