In [None]:
import numpy as np
from numpy.random import randint
from rotation_functions import *

%reload_ext jupyter_black

In [None]:
n = 128  # number of qubits Alice prepares and sends to Bob

## Step 1
# Alice generates random raw bits (0/1)
alice_bits = randint(2, size=n)

## Step 2
# Create an array to tell us which qubits
# are encoded in which bases (one random basis choice per qubit)
alice_bases = randint(2, size=n)
message = encode_message(alice_bits, alice_bases, n)

## Step 3
# Bob independently decides which basis to measure each qubit in
bob_bases = randint(2, size=n)
bob_results = decode_message(message, bob_bases, n)

## Step 4
# NOTE(review): remove_garbage presumably keeps only the positions where
# alice_bases and bob_bases agree (BB84 sifting) — confirm in rotation_functions.
alice_key = remove_garbage(alice_bases, bob_bases, alice_bits, n)
bob_key = remove_garbage(alice_bases, bob_bases, bob_results, n)

## Step 5
# Sample a subset of the bits to compare publicly
SAMPLE_FRACTION = 0.1  # fraction of n used for the public comparison
sample_size = int(n * SAMPLE_FRACTION)
# NOTE(review): indices are drawn from range(n), but the sifted keys are
# shorter than n; presumably sample_bits maps indices into the key's range —
# confirm in rotation_functions.
bit_selection = randint(n, size=sample_size)

alice_sample = sample_bits(alice_key, bit_selection)
print("Alice's sample:   ", alice_sample)
bob_sample = sample_bits(bob_key, bit_selection)
print("Bob's sample:     ", bob_sample)

# The sample comparison is the eavesdropping check; it must be reported, not
# evaluated and discarded. np.array_equal yields one bool for lists or arrays.
print("Samples match?    ", np.array_equal(alice_sample, bob_sample))

print("Alice's key:      ", alice_key)
print("Bob's key:        ", bob_key)
print("Key lengths:      ", len(alice_key), len(bob_key))
print("Are the keys equal?", np.array_equal(alice_key, bob_key))

In [None]:
# Use Alice's sifted key as a one-time pad and split it into two equal halves:
# the first half supplies the bases `a` (used for encoding), the second half
# the bases `b` (used for measuring).
otp = alice_key

# Halves must be equal, so an odd-length key loses its final bit.
segment_length = len(otp) // 2
otp_length = 2 * segment_length
otp = otp[:otp_length]

a_bases, b_bases = otp[:segment_length], otp[segment_length:otp_length]

# One random challenge bit per basis pair.
# TODO: source these from a QRNG instead of a pseudo-random generator.
c_bits = randint(2, size=segment_length)

print("a_bases: ", a_bases)
print("b_bases: ", b_bases)
print("c_bits:  ", c_bits)

In [None]:
# Encode c in a: prepare the challenge bits c_bits in the bases given by
# a_bases, using encode_message (same helper as the BB84 key-exchange cell).
# Note: this rebinds `message`, replacing the BB84 qubits created earlier.
message = encode_message(c_bits, a_bases, segment_length)

In [None]:
# Prover compares a and b bases and applies a rotation if they are different
# NOTE(review): apply_specific_rotation is presumably provided by
# rotation_functions; assumed to leave qubits untouched where the a and b bases
# agree — confirm there.
message_back = apply_specific_rotation(a_bases, b_bases, message)

In [None]:
# Verifier measures the qubits in the b basis
# c_prime is compared against c_bits in the next cell; the protocol succeeds
# when every decoded bit matches the challenge bit that was encoded.
c_prime = decode_message(message_back, b_bases, segment_length)

In [None]:
# Verifier compares c and c': the check passes only if every decoded bit
# equals the corresponding challenge bit (works for lists or numpy arrays).
verified = np.array_equal(c_bits, c_prime)
for label, bits in (("\nc bits:", c_bits), ("\nc' bits:", c_prime)):
    print(label, bits)
print(verified)