In [None]:
# bb84_simulation.py
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, transpile
from qiskit_aer import AerSimulator
from qiskit_aer.noise import NoiseModel, depolarizing_error
import random
from hashlib import sha256

def random_bits_string(n):
    """Return a string of n characters, each independently '0' or '1'.

    Draws come from the module-level ``random`` generator, one
    ``random.choice('01')`` call per position, so results are reproducible
    under ``random.seed``.
    """
    draws = [random.choice('01') for _ in range(n)]
    return ''.join(draws)

def random_basis(n):
    """Return a string of n basis labels, each independently 'Z' or 'X'.

    'Z' denotes the computational basis and 'X' the Hadamard basis.
    Uses one ``random.choice('ZX')`` call per position.
    """
    labels = []
    for _ in range(n):
        labels.append(random.choice('ZX'))
    return ''.join(labels)

def encode_qubits(alice_bits: str, alice_bases: str):
    """
    Build Alice's BB84 state-preparation circuit (without measurement).

    For every position idx:
    - alice_bits[idx] == '1'  -> X on q[idx]   (|0> becomes |1>)
    - alice_bases[idx] == 'X' -> H on q[idx], applied after the X
                                 (|0>/|1> become |+>/|->)

    A classical register of the same width ('c') is allocated for Bob's
    later measurement results.

    Returns: QuantumCircuit, qreg, creg
    """
    assert len(alice_bits) == len(alice_bases)
    width = len(alice_bits)
    qreg = QuantumRegister(width, 'q')
    creg = ClassicalRegister(width, 'c')  # filled in by Bob's measurement
    qc = QuantumCircuit(qreg, creg, name='alice_encode')

    for idx in range(width):
        target = qreg[idx]
        if alice_bits[idx] == '1':
            qc.x(target)
        if alice_bases[idx] == 'X':
            qc.h(target)

    # Measurement is deliberately left to Bob.
    return qc, qreg, creg

def measure_qubits(qc: QuantumCircuit, qreg: QuantumRegister, creg: ClassicalRegister, bob_bases: str):
    """
    Append Bob's measurement to qc (mutates qc in place and returns it).

    Qubits whose entry in bob_bases is 'X' first get an H so that a
    computational-basis measurement reads them in the X basis; then every
    q[i] is measured into c[i].
    """
    x_positions = [pos for pos, basis in enumerate(bob_bases) if basis == 'X']
    for pos in x_positions:
        qc.h(qreg[pos])
    qc.measure(qreg, creg)
    return qc

def run_simulator(qc: QuantumCircuit, shots=1024, noise_model=None, *, optimization_level=0):
    """
    Run AerSimulator and return memory (list of shot bitstrings) and counts.

    Args:
        qc: circuit to execute (must contain measurements).
        shots: number of repetitions.
        noise_model: optional qiskit_aer NoiseModel passed to the simulator.
        optimization_level: forwarded to ``transpile``. Defaults to 0 so the
            circuit's gates reach the simulator as written. At higher levels
            the transpiler may merge or cancel runs of single-qubit gates
            (e.g. an H immediately followed by another H is the identity),
            and a noise model attached to specific gate names such as
            'x'/'h' would then silently not be applied to those qubits.

    Returns:
        (memory, counts). Each memory string is in Qiskit's classical-bit
        order: the rightmost character is clbit c[0].
    """
    simulator = AerSimulator(noise_model=noise_model)
    transpiled = transpile(qc, simulator, optimization_level=optimization_level)
    job = simulator.run(transpiled, memory=True, shots=shots)
    result = job.result()
    memory = result.get_memory(transpiled)  # one bitstring per shot
    counts = result.get_counts(transpiled)
    return memory, counts

def sift_keys(alice_bits: str, alice_bases: str, bob_bases: str, bob_shot: str):
    """
    Single-shot sifting: keep positions where alice_bases == bob_bases.

    bob_shot is one shot bitstring exactly as returned by Qiskit
    (``result.get_memory()``). Qiskit writes classical bits little-endian:
    the rightmost character is c[0], so c[i] sits at index
    ``len(bob_shot) - 1 - i``. The string is reversed here so that index i
    lines up with qubit i and with alice_bits[i] (measure(q[i], c[i]) was used).

    Returns (alice_sift, bob_sift) as strings.
    """
    assert len(alice_bits) == len(alice_bases) == len(bob_bases) == len(bob_shot)
    # Convert from Qiskit display order (c[n-1] ... c[0]) to qubit order.
    bob_bits = bob_shot[::-1]
    alice_sift = []
    bob_sift = []
    for a_bit, a_basis, b_basis, b_bit in zip(alice_bits, alice_bases, bob_bases, bob_bits):
        if a_basis == b_basis:
            alice_sift.append(a_bit)
            bob_sift.append(b_bit)
    return ''.join(alice_sift), ''.join(bob_sift)

def compute_qber(alice_bits: str, alice_bases: str, bob_bases: str, bob_memory_list):
    """
    Compute QBER across all shots using only positions where bases match.

    Each entry of bob_memory_list is a raw Qiskit memory string, whose
    rightmost character is c[0] (little-endian). Shots are reversed into
    qubit order before being compared with alice_bits.

    Returns:
        errors / compared_bits as a float; 0.0 when there are no shots or no
        matching-basis positions.

    Raises:
        ValueError: if a shot's length differs from len(alice_bits).
    """
    if not bob_memory_list:
        return 0.0
    n = len(alice_bits)
    # Matching-basis positions are the same for every shot: compute once.
    matching = [i for i, (a_basis, b_basis) in enumerate(zip(alice_bases, bob_bases))
                if a_basis == b_basis]
    if not matching:
        return 0.0
    total_errors = 0
    for shot in bob_memory_list:
        if len(shot) != n:
            raise ValueError(f"shot length {len(shot)} does not match key length {n}")
        bob_bits = shot[::-1]  # Qiskit order -> qubit order
        total_errors += sum(bob_bits[i] != alice_bits[i] for i in matching)
    return total_errors / (len(matching) * len(bob_memory_list))

def bb84(n=29, shots=1024, use_noise=False, *, noise_prob=0.01):
    """
    Simulate one BB84 exchange with random choices for Alice and Bob.

    Alice picks n random bits and bases, Bob picks n random bases, and the
    prepare-and-measure circuit is run `shots` times on AerSimulator.

    Args:
        n: number of qubits / raw key positions.
        shots: number of simulator shots (must be >= 1).
        use_noise: if True, attach a single-qubit depolarizing error to
            every gate named 'x' or 'h'.
        noise_prob: depolarizing probability used when use_noise is True.

    Returns:
        dict with keys:
          'alice_bits', 'alice_bases', 'bob_bases' - the random choices;
          'first_shot' - raw memory string of shot 0 (Qiskit bit order);
          'alice_sift', 'bob_sift' - sifted keys from the first shot only;
          'qber' - error rate over matching-basis positions of all shots;
          'counts' - raw simulator counts.

    Raises:
        ValueError: if shots < 1 (there would be no shot to sift).
    """
    if shots < 1:
        raise ValueError("shots must be >= 1")

    alice_bits = random_bits_string(n)
    alice_bases = random_basis(n)
    bob_bases = random_basis(n)

    qc, qreg, creg = encode_qubits(alice_bits, alice_bases)
    qc = measure_qubits(qc, qreg, creg, bob_bases)

    noise_model = None
    if use_noise:
        err = depolarizing_error(noise_prob, 1)
        noise_model = NoiseModel()
        # Noise is tied to 'x'/'h' gates only; a qubit that receives neither
        # gate sees no noise from this model.
        noise_model.add_all_qubit_quantum_error(err, ['x', 'h'])

    memory, counts = run_simulator(qc, shots=shots, noise_model=noise_model)

    # Demonstrate sifting on the first shot only.
    first_shot = memory[0]
    alice_sift, bob_sift = sift_keys(alice_bits, alice_bases, bob_bases, first_shot)

    qber_value = compute_qber(alice_bits, alice_bases, bob_bases, memory)

    return {
        'alice_bits': alice_bits,
        'alice_bases': alice_bases,
        'bob_bases': bob_bases,
        'first_shot': first_shot,
        'alice_sift': alice_sift,
        'bob_sift': bob_sift,
        'qber': qber_value,
        'counts': counts,
    }

def privacy_amplify(key_str):
    """Return the SHA-256 hex digest of key_str (UTF-8 encoded).

    NOTE(review): a fixed 256-bit digest does not shorten the key, so this is
    not information-theoretic privacy amplification (which compresses the key
    by the estimated leaked information). Treat it as a demonstration only.
    """
    hasher = sha256()
    hasher.update(key_str.encode())
    return hasher.hexdigest()

if __name__ == '__main__':
    # Abort threshold for the observed QBER (0.11 matches the commonly cited
    # ~11% tolerable error rate for BB84).
    QBER_THRESHOLD = 0.11
    out = bb84(n=29, shots=2048, use_noise=True)
    print("Alice bits: ", out['alice_bits'])
    print("Alice bases:", out['alice_bases'])
    print("Bob bases:  ", out['bob_bases'])
    print("First shot:  ", out['first_shot'])
    print("Alice sift:  ", out['alice_sift'])
    print("Bob sift:    ", out['bob_sift'])
    print("QBER:        ", out['qber'])
    # Decide on eavesdropping BEFORE using the key: if the QBER exceeds the
    # threshold the sifted key must be discarded, not amplified.
    if out['qber'] > QBER_THRESHOLD:
        print("Eve detected: QBER above threshold, key discarded.")
    elif out['alice_sift']:
        print("Amplified:   ", privacy_amplify(out['alice_sift']))
    else:
        print("No sifted key (all bases mismatched).")


Alice bits:  00100110100010001110111001111
Alice bases: XZZZXZZXZXZZXXZXZZXXZZXZZXZXX
Bob bases:   XXXXXZZZZZZXXZXZZXXXXXXXXXZZX
First shot:   10111111101010111001101101000
Alice sift:   00111011101111
Bob sift:     11111111011100
QBER:         0.5013602120535714
Amplified:    d87490977a9dbc6c58fc7d227f4c468ad2b0fa85bf9a1e1eee44cf4d412afbe5
EVA detected
