# BB84_Demo
## Imports and tool setup

In [None]:
import numpy as np
import random

from qiskit_network.components import Network
from qiskit_network.components.storage import QuantumStorage, ClassicalStorage
from qiskit_network.channels import Channels, QuantumChannels, ClassicalChannels
from qiskit_network import logger
from qiskit_network.kernel import Timeline

from qiskit.utils import QuantumInstance
from qiskit import Aer

In [None]:
# Backend used to execute the network's circuits: the Aer statevector simulator with a
# single shot per run (presumably because each BB84 round needs one measurement outcome).
qi = QuantumInstance(Aer.get_backend('aer_simulator_statevector'), shots=1)

# The pennylane-qiskit plugin could be used here instead, but it is not yet clear whether
# that plugin will support pennylane-pulse.
#pennylane = qml.device('default.qubit', wires=1)

## Nodes

In [None]:
# Alice creates her own network (it may also interact with other networks); the backend
# should eventually be determined automatically. The network has three user nodes
# (Alice, Bob, Eve) and runs on the `qi` backend.
network  = Network.Environment(name='alice_net', user_nodes=['Alice', 'Bob','Eve'], backend=qi, hardware_node = None)  

Alice = network.get_user_node('Alice')
Bob = network.get_user_node('Bob')
# Eve is configured separately from the normal node setup.
# TODO: finish setting up Eve.
# NOTE(review): assumes eve_interception() returns the node (or a handle to it); `Eve` is
# later used as a node (Eve.get_classical) — confirm.
Eve = network.get_user_node('Eve').eve_interception()

## Create Network

In [None]:
# 'auto': the timeline length is determined from the scheduled actions and the hardware.
timeline = Timeline('auto')

In [None]:
# Every node in the network gets the same storage and channel setup.
# Classical storage: used to store packet messages.
Cstorage = ClassicalStorage()
# It is not yet clear how quantum storage will work on real devices, but the goal is to
# simulate the photonic components at the pulse level; see the references in #14 or
# PennyLane's photonics material.
# TODO: investigate how to transform this into pulse-level operations.
# NOTE(review): coherence_time=-1 presumably means "no decoherence" — confirm against the
# QuantumStorage documentation.
Qstorage = QuantumStorage(name="", timeline=timeline, num_memories=16,fidelity=0.85, frequency=80e6, efficiency=1, coherence_time=-1, wavelength=500)

# If any setting requires Qiskit Pulse, it is automatically turned into a qiskit-pulse
# scheduler, but this only works on a clean simulator.
QKD_QC = QuantumChannels(name="", timeline=timeline,distance = 1000,attenuation=0, polarization_fidelity = 1.0, light_speed=2e-4, frequency=8e7) 
ethernet_CC = ClassicalChannels(name="",timeline=timeline, distance = 1000,delay=1e9)
# Other channel types, e.g. service channels for other protocols.
# TODO: work out how to set up custom encryption; see
# https://github.com/OpenKMIP/PyKMIP/tree/master/kmip/demos
encrypted_CC = ClassicalChannels(name="encrypted data over fiber",timeline= timeline, distance = 1000,delay=1e9, ) 

# TODO: determine what else must be configured here (see the quantum-internet
# standardization work and recent research on channels).
# `channels` is also read as a module-level global by all_func.start_sender/start_receiver.
channels = Channels(Classical=[ethernet_CC,encrypted_CC], Quantum=QKD_QC, ) 
# `connection` can be any known Ethernet topology or a custom coupling map
# (useful for VQE-style optimization of the network's connectivity, see
# https://pennylane.ai/blog/2022/10/the-quantum-internet-and-variational-quantum-optimization/).
network.setup_all(storage=[Qstorage,Cstorage], channels=channels, connection='mesh') 

## Set up action functions

In [None]:
#cascade, reference: https://github.com/upsideon/qkd-qchack-2022/blob/main/qkd/src/cascade.py
import numpy as np

def quantum_bit_error_rate(local_set, remote_set):
    """
    Estimate the quantum bit error rate between two bit sequences.

    The result is the fraction of positions in ``local_set`` whose value
    differs from the value at the same position in ``remote_set``. Extra
    entries in ``remote_set`` are ignored. An empty ``local_set`` raises
    ZeroDivisionError.
    """
    mismatches = sum(
        1
        for position, local_bit in enumerate(local_set)
        if remote_set[position] != local_bit
    )
    return mismatches / len(local_set)

def binary_algorithm(block, block_indices, ask_parity_fn):
    """
    Recursively splits a block with odd error parity into
    left and right sub-blocks to find and correct one-bit
    errors.

    Args:
        block: this side's bits for the block (list or NumPy array of 0/1).
        block_indices: key positions matching ``block`` element by element;
            sub-slices of it are what ``ask_parity_fn`` is asked about.
        ask_parity_fn: callable returning the correct parity (0 or 1) of the
            given positions. It is called once per recursion level, always for
            the left half only, so this call sequence is the message sequence
            when the callable talks to a remote peer.

    Returns:
        The block with exactly one bit flipped. For a single-bit block this is
        ``block`` itself, modified in place. Otherwise it is a new array built
        by ``np.concatenate``.

    Notes:
        ``block`` must be non-empty and have odd error parity. An empty block
        splits into two empty halves, so the recursion never terminates.
        When ``block`` is a NumPy array, the slices below are views, so the
        bit flip also writes through to the caller's array. For a list, the
        slices are copies and the caller's list is left untouched.
    """

    # If we have a block of size one, we correct the bit as
    # it must have an odd number of errors per the input
    # assumptions of the binary algorithm.
    if len(block) == 1:
        if block[0] == 0:
            block[0] = 1
        else:
            block[0] = 0
        return block

    # The block split index selection ensures that the left
    # block has one more bit than the right when the block
    # size is odd.
    block_split_index = (len(block) + 1) // 2

    left_block = block[:block_split_index]
    right_block = block[block_split_index:]

    left_block_indices = block_indices[:block_split_index]
    right_block_indices = block_indices[block_split_index:]

    # Computing the current parity of the left block.
    current_left_block_parity = np.sum(left_block) % 2

    # Asking for the correct parity of the left block. The
    # parity of the right block can be inferred from the left
    # block's parity.
    correct_left_block_parity = ask_parity_fn(left_block_indices)

    # Determining the error parity for the left block.
    left_block_error_parity = current_left_block_parity ^ correct_left_block_parity

    # Recursing on the block with odd error parity.
    if left_block_error_parity == 1:
        left_block = binary_algorithm(left_block, left_block_indices, ask_parity_fn)
    else:
        right_block = binary_algorithm(right_block, right_block_indices, ask_parity_fn)

    return np.concatenate((left_block, right_block))

def client_cascade(noisy_key, qber, ask_parity_fn, rng=None):
    """
    An implementation of the Cascade information reconciliation algorithm
    used for post-processing of keys exchanged via quantum key distribution.

    Args:
        noisy_key: this side's bits (sequence of 0/1). It is copied into a
            new NumPy array, so the caller's object is not modified.
        qber: estimated quantum bit error rate as a fraction, e.g. the value
            returned by ``quantum_bit_error_rate``. 0.0 is replaced by 0.1.
        ask_parity_fn: callable that takes an array of key positions and
            returns the correct parity (0 or 1) of the bits at those positions.
        rng: optional ``numpy.random.Generator`` used to shuffle the key in
            every pass after the first. A fresh ``np.random.default_rng()`` is
            used when omitted.

    Returns:
        The reconciled key as a NumPy array.
    """
    # Representing the noisy key as a NumPy array (always a copy).
    noisy_key = np.array(noisy_key)
    key_length = len(noisy_key)

    # If the estimated quantum bit error rate is 0%, assume that a reasonable
    # amount of errors were present outside of the sampling set.
    if qber == 0.0:
        qber = 0.1

    # The top level block size is determined by the quantum bit error rate.
    # Clamp it to at least 1. A large qber (e.g. a percentage passed by
    # mistake) would otherwise round to 0 and divide by zero below. A negative
    # qber would give a negative size that never grows past key_length.
    block_size = max(1, int(np.round(0.73 / qber)))

    if rng is None:
        rng = np.random.default_rng()

    iteration = 0
    while block_size <= key_length:
        if iteration > 0:
            # Later passes: shuffle the key and double the block size.
            permutation = rng.permutation(key_length)
            shuffled_key = noisy_key[permutation]
            block_size *= 2
        else:
            # First pass: identity permutation over an unshuffled copy.
            permutation = np.arange(key_length)
            shuffled_key = noisy_key.copy()

        # Slicing past the end truncates, so the final (possibly shorter)
        # block needs no special case.
        for block_start in range(0, key_length, block_size):
            block_end = block_start + block_size
            block = shuffled_key[block_start:block_end]
            # Positions of this block's bits in the unshuffled key.
            block_indices = permutation[block_start:block_end]

            current_block_parity = np.sum(block) % 2
            correct_block_parity = ask_parity_fn(block_indices)
            error_parity = current_block_parity ^ correct_block_parity

            # Correcting one-bit errors for blocks with odd error parity.
            if error_parity == 1:
                updated_block = binary_algorithm(block, block_indices, ask_parity_fn)
                noisy_key[block_indices] = updated_block

        iteration += 1

    return noisy_key

def get_ask_block_parity_fn(secret_key, socket):
    """
    Returns a function for requesting block parities that is compatible
    with the signature expected by client_cascade, but which communicates
    over a socket.

    The returned function sends the requested indices as one comma-separated
    string via ``socket.send`` and converts the reply from ``socket.recv()``
    to ``int``.

    Args:
        secret_key: this side's key. It is accepted for interface
            compatibility but not needed: only indices are sent, and the peer
            computes the parity from its own key.
        socket: object with ``send(str)`` and ``recv() -> str`` (presumably a
            NetQASM-style classical socket — confirm).
    """

    def ask_block_parity(block_indices):
        request = ",".join(str(index) for index in block_indices)
        socket.send(request)
        return int(socket.recv())

    return ask_block_parity

def get_block_parity_from_indices(full_key, indices):
    """
    Return the sum of ``full_key`` at the given positions, modulo 2.

    For a key of 0/1 bits, this is the parity of that subset of the key.
    """
    return sum(full_key[position] for position in indices) % 2

def send_cascade_stop(socket):
    """
    Send the "STOP" sentinel that ends the peer's block-parity listening loop.
    """
    stop_message = "STOP"
    socket.send(stop_message)

def listen_and_respond_block_parity(correct_key, socket):
    """
    Serve block-parity questions until the peer sends "STOP".

    Each question received via ``socket.recv()`` is a comma-separated list of
    key indices. The reply is the parity of ``correct_key`` at those indices,
    sent back as a string.
    """
    # iter(callable, sentinel) keeps calling socket.recv() until it returns "STOP".
    for question in iter(socket.recv, "STOP"):
        indices = [int(token) for token in question.split(",")]
        socket.send(str(get_block_parity_from_indices(correct_key, indices)))

In [None]:
# The action object a node runs, whether it is the sender or the receiver.
class all_func:
    """
    BB84 actions for one sender/receiver pair.

    An instance holds the per-run state: the number of rounds ``n``, this
    node's random measurement bases (``basis_flips``) and its measurement
    outcomes (``bit_flips``).

    NOTE(review): ``StructuredMessage`` and the module-level ``channels``
    object are used below but are not imported in this notebook.
    ``StructuredMessage`` looks like NetQASM's structured message class —
    confirm where both come from.
    """

    def __init__(self, sender, receiver, n=10):
        self.sender = sender
        self.receiver = receiver
        self.n = n

        # Measurement outcomes, one per round; filled in by distribute_bb84_states.
        self.bit_flips = [None for _ in range(self.n)]
        # Random basis per round: 1 -> apply H before measuring, 0 -> measure directly.
        self.basis_flips = [random.randint(0, 1) for _ in range(self.n)]
        # Intended size of the error-estimation sample (not used yet).
        self.num_test_bits = max(self.n // 4, 1)

    def distribute_bb84_states(self, conn, epr_socket, sender=False):
        """
        Share ``n`` EPR qubits over ``epr_socket`` and measure each one.

        The sender creates each pair and the receiver receives it. In rounds
        whose basis flag is 1, an H gate is applied before measuring.

        Returns:
            ``(bit_flips, basis_flips)`` — the instance's own lists, not copies.
        """
        for i in range(self.n):
            # Other elements (e.g. a BSM node) will need to be included here;
            # this is where the setups differ.
            if sender:
                qc = epr_socket.create_epr(num_qubit=1)
            else:
                qc = epr_socket.receive_epr(num_qubit=1)

            if self.basis_flips[i]:
                qc.h(0)
            # TODO: support other measurements or settings.
            m = qc.run_measure(0, 0)

            conn.flush()
            self.bit_flips[i] = int(m)
        return self.bit_flips, self.basis_flips

    def estimate_error_rate(self, socket, key, start=None, end=None, sender=True):
        """
        Compare a window of ``key`` with the peer's outcomes over ``socket``.

        The sender picks the window ``key[start:end]`` (``None`` bounds mean the
        start/end of the key), announces the bounds, receives the receiver's
        outcomes and then sends its own. The receiver mirrors that exchange.

        Returns:
            The percentage (0-100) of compared positions where the two sides differ.

        Raises:
            ValueError: if the test window is empty.
        """
        if sender:
            # Normalize the bounds exactly as slicing key[start:end] would, so
            # the announced indices are concrete integers matching the window.
            start, end, _ = slice(start, end).indices(len(key))
            test_outcomes = key[start:end]
            # Fail fast: an empty window would otherwise divide by zero and
            # could never detect an eavesdropper.
            if len(test_outcomes) == 0:
                raise ValueError("the error-estimation window is empty")
            test_indices = start, end

            socket.send_structured(StructuredMessage("Test indices", test_indices))
            target_test_outcomes = socket.recv_structured().payload
            socket.send_structured(StructuredMessage("Test outcomes", test_outcomes))
        else:
            test_indices = socket.recv_structured().payload
            start, end = test_indices
            test_outcomes = key[start:end]

            #logger.info(f"bob test indices: {test_indices}")
            #logger.info(f"bob test outcomes: {test_outcomes}")

            socket.send_structured(StructuredMessage("Test outcomes", test_outcomes))
            target_test_outcomes = socket.recv_structured().payload
            if len(test_outcomes) == 0:
                raise ValueError("the error-estimation window is empty")

        num_error = sum(
            1 for (i1, i2) in zip(test_outcomes, target_test_outcomes) if i1 != i2
        )
        # Divide by the real window length, not end - start, which can exceed
        # len(key) when `end` is past the end of the key.
        return (num_error / len(test_outcomes)) * 100

    def start_sender(self, start, end):
        """
        Run the sender side: distribute states, estimate the error rate on
        ``bit_flips[start:end]`` and tell the receiver whether to accept.

        Returns:
            dict with ``error_rate`` (percentage) and ``secret_key``.
        """
        self.start, self.end = start, end
        # sender=True: the sending node must create the EPR pairs; the default
        # (False) would make both ends call receive_epr.
        bit_flips, basis_flips = self.distribute_bb84_states(
            self.sender, channels.quantum(receiver=self.receiver), sender=True
        )

        #logger.info(f"sender outcomes: {bit_flips}")
        #logger.info(f"sender theta: {basis_flips}")

        socket = channels.classical(sender=self.sender, receiver=self.receiver)
        error_rate = self.estimate_error_rate(socket, bit_flips, start, end)

        # The key is accepted only when no errors were seen in the test window.
        socket.send('1' if error_rate <= 0.0 else '0')
        return {
            "error_rate": error_rate,
            # Key material is the measurement outcomes, not the basis choices.
            # NOTE: basis sifting and removal of the disclosed test bits are
            # not implemented yet, so this is the raw key.
            "secret_key": bit_flips,
        }

    def start_receiver(self):
        """
        Run the receiver side: receive and measure states, answer the error
        estimation and read the sender's accept/reject decision.

        Returns:
            dict with ``error_rate`` (percentage), ``secret_key`` and ``accept`` (bool).
        """
        # The quantum channel is the one coming from the sender.
        bit_flips, basis_flips = self.distribute_bb84_states(
            self.receiver, channels.quantum(sender=self.sender), sender=False
        )

        #logger.info(f"receiver outcomes: {bit_flips}")
        #logger.info(f"receiver theta: {basis_flips}")

        socket = channels.classical(sender=self.sender, receiver=self.receiver)
        error_rate = self.estimate_error_rate(socket, bit_flips, sender=False)

        accept_string = socket.recv()
        accept_key = accept_string == '1'
        return {
            "error_rate": error_rate,
            # Raw measurement outcomes (see the note in start_sender).
            "secret_key": bit_flips,
            "accept": accept_key,
        }

        

In [None]:
# Interception hook, inserted in the middle of the circuit (registered with
# network.set_function(..., interception=eve)).
def eve(qc):
    """
    Intercept one qubit of ``qc``.

    A random choice selects a gate (X for 0, H for 1). The gate is applied to
    qubit 0, the qubit is measured into classical bit 0, and Eve then queries
    the classical traffic from Alice.

    Returns:
        A one-element list holding the measurement result from ``qc.run_measure``.
    """
    result = []
    key_string = random.randint(0, 1)
    if key_string == 0:
        # NOTE(review): an intercept-resend attack measuring in the Z basis needs
        # no gate here; X flips the qubit instead — confirm this is intended.
        qc.x(0)
    else:
        qc.h(0)
    # qiskit dynamic circuit here
    result.append(qc.run_measure(0,0))
    # NOTE(review): the value returned by get_classical is currently discarded.
    Eve.get_classical(Alice, )

    # TODO: track Eve's view of the classical channels.
    # Return the measurement instead of silently discarding it.
    return result

In [None]:
# Register the node action class and the interception hook. More function settings may be
# added later. all_func is always constructed with `sender` and `receiver`, which it uses
# to run the protocol and the analysis; `eve` is passed as the `interception` callable.
network.set_function(all_func, interception=eve)

## Action timeline

In [None]:
# Schedule the protocol on the shared timeline: Alice sends to Bob, Bob receives from Alice.
# NOTE(review): these node-level calls take (timeline, receiver/sender), while
# all_func.start_sender takes (start, end) — presumably the node wraps all_func; confirm
# how the error-estimation window (start, end) is supplied.
Alice.start_sender(timeline=timeline, receiver=Bob)
Bob.start_receiver(timeline=timeline, sender=Alice)

In [None]:
# Draw the overall circuit built for the network, using the 'mpl' (matplotlib) drawer.
network.construct_circuit().draw('mpl')

### Start the timeline

In [None]:
# Run the scheduled actions and print the per-node results.
result = network.run(run_time=False)
print(result)
# Sample output (one dict per node). "error_rate" is a percentage, because
# all_func.estimate_error_rate multiplies the error fraction by 100.
# NOTE(review): all_func.start_receiver returns "accept" as a bool, whereas the sample
# shows 0 — confirm how the network serializes it.
"""
{"Alice": {"error_rate": 6.2, "secret_key": xxxxxx}, "Bob": {"error_rate": 6.2, "secret_key": xxxxxx, "accept" :  0}  }
"""

In [None]:
#result.logger

In [None]:
# visualize the timeline and results with a widget, e.g.
#result

In [None]:
# or setup for individually
#alice = network.get_node('Alice')


# separate functions for each node
#def alice_func(Alice):
#        msg_buff = []
#        distribute_bb84_states(alice, msg_buff, secret_key, network.get_node('eve'))
#        estimate_error_rate(alice,key, )
#
#def bob_func(bob):
#        msg_buff = []
#        distribute_bb84_states(bob, msg_buff, secret_key, network.get_node('eve'))
#        estimate_error_rate(bob,key, )