In [None]:
%pip install qiskit==1.2.4
%pip install qiskit-aer==0.15.1
%pip install pylatexenc==2.10

In [None]:
import math

import numpy as np
from qiskit import QuantumCircuit
from qiskit import transpile
from qiskit.circuit import ControlledGate
from qiskit.circuit.library import UnitaryGate
from qiskit.converters import circuit_to_gate
from qiskit.providers.basic_provider import BasicSimulator
from qiskit.quantum_info import Operator
from qiskit.quantum_info import Statevector
from qiskit.visualization import array_to_latex
from qiskit.visualization import plot_histogram

# The aim of the assignment is to simulate the Ekert91 key distribution protocol.

# This notebook is for a simulation of the protocol with an attacker, to demonstrate that the attacker can be detected.

In [None]:
class CustomCircuit(QuantumCircuit):
    """QuantumCircuit with basis-rotation helpers applied before a Z-basis measurement.

    Each helper applies an Ry rotation to qubit ``q`` so that a subsequent
    computational-basis measurement is taken along a tilted axis in the X-Z plane.
    Both helpers return ``self`` so calls can be chained.
    """

    def w(self, q):
        """Rotate qubit ``q`` into the W eigenbasis (Ry(-pi/4)) ahead of a measurement."""
        self.ry(-math.pi / 4, q)
        return self

    def v(self, q):
        """Rotate qubit ``q`` by Ry(+pi/4), the mirror image of ``w``, ahead of a measurement.

        The original author described this as the -V eigenbasis: the same basis as V,
        but with the +1/-1 outcomes swapped.
        NOTE(review): with Qiskit's Ry convention, Ry(+pi/4) followed by a Z measurement
        measures (Z - X)/sqrt(2). Confirm this against the assignment's definition of V.
        """
        self.ry(+math.pi / 4, q)
        return self


class RandomCircuit(QuantumCircuit):
    """QuantumCircuit with a ``p`` helper used to produce biased (1/3 : 2/3) random bits."""

    # Real unitary (1/sqrt(3)) * [[1, -sqrt(2)], [sqrt(2), 1]]. Applied to |0> it gives
    # (|0> + sqrt(2)|1>)/sqrt(3): outcome 0 with probability 1/3 and outcome 1 with 2/3.
    # The gate is built once, at class-definition time, so `np` (numpy) and `UnitaryGate`
    # (qiskit.circuit.library) must already be imported when this cell runs.
    third_prob_gate = UnitaryGate(
        (1 / np.sqrt(3)) * np.array([[1, -np.sqrt(2)], [np.sqrt(2), 1]], dtype=complex),
        label="P",
    )

    def p(self, q):
        """Append the 1/3 : 2/3 probability gate ``P`` to qubit ``q`` and return ``self``."""
        self.append(self.third_prob_gate, [q])
        return self

In [None]:
# Builds the shared entangled pair that Alice (qubit 0) and Bob (qubit 1) measure.
def initialize_entangled_circuit():
    """Return a fresh 2-qubit / 2-clbit CustomCircuit holding the singlet state.

    The prepared state is (|0>_0|1>_1 - |1>_0|0>_1)/sqrt(2). No measurements are
    added, so callers can append basis rotations before measuring.
    """
    pair = CustomCircuit(2, 2)

    # Bell pair (|0>_0|0>_1 + |1>_0|1>_1)/sqrt(2).
    pair.h(0)
    pair.cx(0, 1)

    # Flip qubit 1 and add a relative phase on qubit 0: the anti-correlated singlet.
    pair.x(1)
    pair.z(0)

    return pair

# Quantum random-number generator for {0, 1, 2}. Outcomes come from measurement rather
# than a pseudo-random algorithm (on real hardware; here the device is a simulator).
def random_number_3(sim):
    """Return 0, 1 or 2, each with probability 1/3, from one shot on ``sim``.

    Qubit 0 passes through the ``P`` gate (outcome 0 with probability 1/3) and
    qubit 1 through a Hadamard (a fair coin). A 0 on qubit 0 yields 0; otherwise
    qubit 1 picks 1 or 2, each with probability 2/3 * 1/2 = 1/3.
    """
    generator = RandomCircuit(2, 2)
    generator.p(0)  # biased 1/3 : 2/3 bit
    generator.h(1)  # fair bit
    generator.measure([0, 1], [0, 1])

    compiled = transpile(generator, sim)
    shot = sim.run(compiled, shots=1, memory=True).result().get_memory()[0]

    # The memory string lists the highest clbit first, so clbit 0 is the last character.
    if shot[-1] == "0":
        return 0
    return 1 if shot[-2] == "0" else 2

In [None]:
# Runs the protocol rounds (entangle, choose bases, rotate, measure) on `sim`.
def run_rounds(N, sim, A, B, rounds=None):
    """Run Ekert91 measurement rounds with randomly chosen bases.

    Each round does the following:
      1. prepares a fresh entangled pair;
      2. draws Alice's and Bob's basis indices with ``random_number_3``;
      3. applies ``A[a_choice]`` to qubit 0 and ``B[b_choice]`` to qubit 1;
      4. measures both qubits with a single shot on ``sim`` and records the outcomes.

    Args:
        N: desired key length. Only used to choose the default number of rounds.
        sim: backend used to transpile and run every circuit.
        A, B: sequences of three callables ``f(circuit, qubit)`` that apply Alice's
            and Bob's basis rotation for basis indices 0, 1 and 2.
        rounds: number of rounds to run. Defaults to ceil(9*N/2).

    Returns:
        ``(alice_choices, bob_choices, alice_meas, bob_meas, matches)``. The first
        four are per-round lists. ``matches`` holds the indices of rounds in which
        both parties measured in the same basis, i.e. the key rounds.
    """
    if rounds is None:
        # Both basis indices are uniform over {0, 1, 2} (see random_number_3), and 2 of
        # the 9 (a, b) pairs are key rounds, so P(match) = 2/9. ceil(9*N/2) rounds
        # therefore give N matches only *on average*; a substantial fraction of runs
        # fall short of N. Pass a larger `rounds` (e.g. 9*N) for a high probability of
        # getting a full-length key.
        rounds = (9 * N + 1) // 2

    # (Alice index, Bob index) pairs that are same-basis rounds for the A/B lists this
    # notebook uses: Alice W with Bob W, and Alice Z with Bob Z.
    # NOTE: this ties run_rounds to that list ordering.
    key_pairs = {(1, 0), (2, 1)}

    alice_choices, bob_choices = [], []
    alice_meas, bob_meas = [], []
    matches = []

    for i in range(rounds):
        circuit = initialize_entangled_circuit()
        a_choice = random_number_3(sim)
        b_choice = random_number_3(sim)
        alice_choices.append(a_choice)
        bob_choices.append(b_choice)

        if (a_choice, b_choice) in key_pairs:
            matches.append(i)

        # Alice owns qubit 0 and Bob owns qubit 1.
        A[a_choice](circuit, 0)
        B[b_choice](circuit, 1)
        circuit.measure([0, 1], [0, 1])

        transpiled = transpile(circuit, sim)
        shot = sim.run(transpiled, shots=1, memory=True).result().get_memory()[0]

        # The memory string lists the highest clbit first, so clbit 0 (Alice) is last.
        alice_meas.append(int(shot[-1]))
        bob_meas.append(int(shot[-2]))

    return alice_choices, bob_choices, alice_meas, bob_meas, matches

def extract_keys(N, alice_meas, bob_meas, matches):
    """Build Alice's and Bob's raw keys from at most the first ``N`` matching rounds.

    Bob's bits are inverted (``1 - bit``) because the shared pair gives
    anti-correlated outcomes when both parties measure in the same basis.

    Returns:
        ``(alice_key, bob_key, mismatches)``. ``mismatches`` lists the round indices
        (taken from ``matches``) whose key bits disagree.
    """
    alice_key = []
    bob_key = []
    mismatches = []

    for taken, round_index in enumerate(matches):
        # `taken` equals the current key length, since every earlier pass appended one bit.
        if taken >= N:
            break
        alice_bit = alice_meas[round_index]
        bob_bit = 1 - bob_meas[round_index]
        alice_key.append(alice_bit)
        bob_key.append(bob_bit)
        if alice_bit != bob_bit:
            mismatches.append(round_index)

    return alice_key, bob_key, mismatches


def compute_statistic(alice_meas, bob_meas, alice_choices, bob_choices):
    """Compute the CHSH statistic S from the Bell-test rounds.

    Rounds are bucketed by (Alice index, Bob index):
    (0, 0) -> XW, (0, 2) -> XV, (2, 0) -> ZW, (2, 2) -> ZV. All other rounds are
    ignored. Each round contributes the product of the +/-1 outcomes
    (bit 0 -> +1, bit 1 -> -1).

    Returns:
        ``(stat, counts, averages)``, where ``stat = |E(XW) - E(XV) + E(ZW) + E(ZV)|``.
        ``counts`` and ``averages`` are per-bucket lists in the order above; an empty
        bucket has average 0.
    """
    bucket_index = {(0, 0): 0, (0, 2): 1, (2, 0): 2, (2, 2): 3}
    totals = [0] * 4
    counts = [0] * 4

    for i, alice_bit in enumerate(alice_meas):
        product = (1 - 2 * alice_bit) * (1 - 2 * bob_meas[i])
        slot = bucket_index.get((alice_choices[i], bob_choices[i]))
        if slot is None:
            continue
        totals[slot] += product
        counts[slot] += 1

    averages = [total / count if count != 0 else 0 for total, count in zip(totals, counts)]

    xw, xv, zw, zv = averages
    stat = abs(xw - xv + zw + zv)
    return stat, counts, averages

def validate_security(stat, counts):
    """Decide whether the CHSH statistic exceeds the classical bound by more than noise.

    Args:
        stat: the CHSH statistic |S|.
        counts: number of rounds in each of the four CHSH buckets.

    Returns:
        ``(secure, threshold)``. If any bucket is empty, ``threshold`` is None and
        ``secure`` is False. Otherwise ``threshold = 2 + sqrt(sum(1/count))`` and
        ``secure`` is ``stat > threshold``.
    """
    if min(counts) <= 0:
        return False, None

    # Each bucket average's sampling error scales like 1/sqrt(count). Combining the four
    # buckets gives a margin that adapts to how many rounds landed in each bucket,
    # rather than using a fixed threshold.
    margin = math.sqrt(sum(1 / counts[k] for k in range(4)))
    threshold = 2 + margin
    return bool(stat > threshold), threshold

In [None]:
# A and B are Alice's and Bob's measurement-basis choices, indexed directly by the
# random number from random_number_3, so no conditionals are needed per round.
# Alice: 0 -> X (Hadamard), 1 -> W, 2 -> Z (no rotation).
# Bob:   0 -> W, 1 -> Z (no rotation), 2 -> V (see CustomCircuit.v).
A = [
    lambda q, j: q.h(j),
    lambda q, j: q.w(j),
    lambda q, j: None,
]

B = [
    lambda q, j: q.w(j),
    lambda q, j: None,
    lambda q, j: q.v(j),
]

# Configuration.
N = 50            # desired key length; run_rounds derives the number of rounds from it
EAVESDROP = True  # simulate an intercept-resend attacker (the point of this notebook)
sim = BasicSimulator()

if EAVESDROP:
    # Eve intercepts Alice's qubit and measures it in the Z basis before Alice applies
    # her basis rotation. The mid-circuit result goes to clbit j and is overwritten by the
    # final measurement, so only its effect is visible: the collapse destroys the
    # entanglement, which should pull S below the classical bound of 2.
    alice_ops = [lambda q, j, op=op: (q.measure(j, j), op(q, j)) for op in A]
else:
    alice_ops = A

alice_choices, bob_choices, alice_meas, bob_meas, matches = run_rounds(N, sim, alice_ops, B)

alice_key, bob_key, mismatches = extract_keys(N, alice_meas, bob_meas, matches)
stat, counts, averages = compute_statistic(alice_meas, bob_meas, alice_choices, bob_choices)

print("counts:", counts)
print("averages [XW, XV, ZW, ZV]:", averages)

secure, threshold = validate_security(stat, counts)

# compute_statistic already returns |S|, so no further abs() is needed.
if secure:
    print("No intruder, S:", stat)
elif threshold is None:
    print("Insufficient rounds ran, increase the number of rounds")
else:
    print("Intruder detected, S:", stat)


print("Alice key: ", "".join(map(str, alice_key)))
print("Bob key: ", "".join(map(str, bob_key)))
print(f"Desired length of key: {N}. Actual length of key: {len(alice_key)}")
print("Number of key mismatches: ", len(mismatches))


