In [1]:
import random
import re

from qiskit_aer import Aer
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, transpile

%reload_ext jupyter_black

In [2]:
# Number of entangled-pair circuits to generate; each one is simulated with a single shot.
numberOfSinglets = 1000

In [3]:
# Two-qubit quantum register and two-bit classical register shared by every circuit below.
# Index 0 is used by the verifier's measurement circuits, index 1 by the prover's.
qr = QuantumRegister(2)
cr = ClassicalRegister(2)

In [4]:
# Entangle the two qubits: Hadamard on qr[0], then CNOT with qr[0] as control and qr[1] as target.
# NOTE(review): starting from |00> this gate pair yields (|00> + |11>)/sqrt(2), not the
# antisymmetric singlet state — confirm that the variable name is intentional.
singlet = QuantumCircuit(qr, cr)
singlet.h(qr[0])
singlet.cx(qr[0], qr[1])

<qiskit.circuit.instructionset.InstructionSet at 0x722101fcbb50>

In [5]:
def _measure_in_basis(position, rotations=()):
    """Return a circuit that rotates qr[position] and measures it into cr[position].

    `rotations` is a sequence of single-qubit QuantumCircuit method names
    (e.g. "h", "s", "t", "tdg") applied in order before the measurement.
    """
    basis_circuit = QuantumCircuit(qr, cr)
    for gate_name in rotations:
        getattr(basis_circuit, gate_name)(qr[position])
    basis_circuit.measure(qr[position], cr[position])
    return basis_circuit


## Verifier (acts on qubit 0, writes classical bit 0)
# X
measureAV0 = _measure_in_basis(0, ("h",))

# W
measureAV1 = _measure_in_basis(0, ("s", "h", "t", "h"))

# Z
measureAV2 = _measure_in_basis(0)

## Prover (acts on qubit 1, writes classical bit 1)
# W
measureAP0 = _measure_in_basis(1, ("s", "h", "t", "h"))

# Z
measureAP1 = _measure_in_basis(1)

# V
measureAP2 = _measure_in_basis(1, ("s", "h", "tdg", "h"))

# Lookup tables: a basis index (0, 1 or 2) selects the measurement circuit to append.
avBases = [measureAV0, measureAV1, measureAV2]
apBases = [measureAP0, measureAP1, measureAP2]

In [6]:
# One uniformly random basis index (0, 1 or 2) per pair for the verifier (av) and the prover (ap).
av = [random.randint(0, 2) for _ in range(numberOfSinglets)]
ap = [random.randint(0, 2) for _ in range(numberOfSinglets)]

# One random bit per pair; passed later as the `bases` argument when encoding/decoding the message.
b = [random.randint(0, 1) for _ in range(numberOfSinglets)]

In [7]:
# Build one circuit per pair: entangling preparation followed by the verifier's and the
# prover's randomly chosen measurement. The name records the index and both basis choices.
circuits = []

for index in range(numberOfSinglets):
    verifier_choice = av[index]
    prover_choice = ap[index]

    pair_circuit = singlet.compose(avBases[verifier_choice])
    pair_circuit = pair_circuit.compose(apBases[prover_choice])
    pair_circuit.name = f"{index}:V{verifier_choice}_P{prover_choice}"

    circuits.append(pair_circuit)

In [8]:
# Simulate every circuit exactly once; memory=True keeps the per-shot bitstrings.
aer_sim = Aer.get_backend("aer_simulator")
transpiled_circuits = transpile(circuits, aer_sim)
simulation_job = aer_sim.run(transpiled_circuits, shots=1, memory=True)
result = simulation_job.result()

In [9]:
# Compiled patterns, in list order, for bitstrings ending in 00, 01, 10 and 11.
abPatterns = list(map(re.compile, ("00$", "01$", "10$", "11$")))

In [10]:
# Translate each single-shot bitstring into a pair of +/-1 values.
# The tuple at position k is the (verifier, prover) pair recorded when abPatterns[k] matches.
VerifierResults = []
ProverResults = []

_signed_outcomes = ((-1, -1), (1, -1), (-1, 1), (1, 1))

for shot in range(numberOfSinglets):
    bitstring = list(result.get_counts(circuits[shot]).keys())[0]

    # Every pattern is tested (no early exit), mirroring four independent checks.
    for pattern, (verifier_value, prover_value) in zip(abPatterns, _signed_outcomes):
        if pattern.search(bitstring):
            VerifierResults.append(verifier_value)
            ProverResults.append(prover_value)

In [11]:
# Split each circuit's single-shot bitstring into its two characters:
# m collects the first character, m_prime the second.
# NOTE(review): Qiskit prints the highest-index classical bit leftmost, so the first
# character should be cr[1] (prover) and the second cr[0] (verifier) — confirm the naming.
m = []
m_prime = []

for circuit in circuits:
    shot_bits = result.get_memory(circuit)[0]

    m.append(int(shot_bits[0]))
    m_prime.append(int(shot_bits[1]))

In [12]:
def encode_message(bits, bases, n):
    """Encode the first ``n`` bits as single-qubit circuits.

    For each position a fresh one-qubit, one-bit circuit is built:
    an X gate is applied when the bit is non-zero, then an H gate when the
    basis is non-zero, and finally a barrier.

    Args:
        bits: Sequence of bit values (0 or 1), indexed 0..n-1.
        bases: Sequence of basis choices (0 or 1), indexed 0..n-1.
        n: Number of positions to encode.

    Returns:
        A list of ``n`` QuantumCircuit objects, one per position.
    """
    encoded = []
    for position in range(n):
        qubit_circuit = QuantumCircuit(1, 1)
        basis = bases[position]
        bit = bits[position]

        # Bit value first (X flips |0> to |1>), then the basis change (H).
        if bit != 0:
            qubit_circuit.x(0)
        if basis != 0:
            qubit_circuit.h(0)

        qubit_circuit.barrier()
        encoded.append(qubit_circuit)
    return encoded


def decode_message(message, bases, n, draw_circuit=False):
    """Measure the first ``n`` encoded circuits and return one bit per circuit.

    For each position ``q`` a copy of ``message[q]`` receives an H gate when
    ``bases[q] == 1`` and then a measurement of qubit 0 into classical bit 0.
    All prepared circuits are simulated together in a single one-shot job.

    The circuits in ``message`` are not modified, so the function can be
    called repeatedly on the same message without stacking extra gates.

    Args:
        message: Sequence of single-qubit QuantumCircuit objects.
        bases: Sequence of basis choices (0 or 1), indexed 0..n-1.
        n: Number of circuits to decode.
        draw_circuit: When true, print a label and display a drawing of each
            measured circuit.

    Returns:
        A list of ``n`` measured bits (ints), in message order.
    """
    measured_circuits = []
    for q in range(n):
        # Work on a copy so the caller's circuit stays free of measurement gates.
        circuit = message[q].copy()
        if bases[q] == 1:
            circuit.h(0)
        circuit.measure(0, 0)

        if draw_circuit:
            print(f"Circuit {q}:")
            display(circuit.draw(output="mpl"))

        measured_circuits.append(circuit)

    # Nothing to simulate; avoid submitting an empty job to the backend.
    if not measured_circuits:
        return []

    # One backend lookup and one job for the whole batch instead of one per circuit.
    aer_sim = Aer.get_backend("aer_simulator")
    result = aer_sim.run(
        transpile(measured_circuits, aer_sim), shots=1, memory=True
    ).result()

    # Look results up by position so duplicate circuit names cannot collide.
    return [
        int(result.get_memory(position)[0])
        for position in range(len(measured_circuits))
    ]

In [13]:
# Encode the m_prime bits as single-qubit circuits, using the random bits in b as basis choices.
message = encode_message(m_prime, b, numberOfSinglets)

In [14]:
# Measure each encoded circuit with the same basis choices b; circuit drawing is disabled.
m_prime_two = decode_message(message, b, numberOfSinglets, False)

In [None]:
def chsh_corr(result):
    """Compute the CHSH combination from the single-shot simulation results.

    Only pairs where both basis indices are 0 or 2 contribute. For each of the
    four (av, ap) combinations a histogram over the outcomes 00, 01, 10, 11 is
    built, turned into an expectation value, and the four values are combined
    as E(0,0) - E(0,2) + E(2,0) + E(2,2).

    Reads the notebook globals ``numberOfSinglets``, ``circuits``, ``av`` and ``ap``.

    Args:
        result: Result object that provides ``get_counts`` for each circuit.

    Returns:
        The CHSH correlation value.
    """
    # One four-slot histogram per contributing (verifier basis, prover basis) pair.
    histograms = {
        (0, 0): [0, 0, 0, 0],
        (0, 2): [0, 0, 0, 0],
        (2, 0): [0, 0, 0, 0],
        (2, 2): [0, 0, 0, 0],
    }

    for shot in range(numberOfSinglets):
        bitstring = list(result.get_counts(circuits[shot]).keys())[0]
        slot = int(bitstring, 2)

        histogram = histograms.get((av[shot], ap[shot]))
        if histogram is not None:
            histogram[slot] += 1

    def expectation(counts):
        # Equal outcomes (00, 11) count +1, differing outcomes (01, 10) count -1.
        total = sum(counts)
        if total == 0:
            return 0
        same_low, diff_low, diff_high, same_high = counts
        return (same_low - diff_low - diff_high + same_high) / total

    return (
        expectation(histograms[(0, 0)])
        - expectation(histograms[(0, 2)])
        + expectation(histograms[(2, 0)])
        + expectation(histograms[(2, 2)])
    )

In [None]:
# Evaluate the CHSH combination on the simulation results and report it to three decimals.
corr = chsh_corr(result)
print(f"CHSH correlation value: {round(corr, 3)}")


def compare_results(m, m_prime_two, av, ap):
    """Count agreeing and disagreeing bits on the rounds with paired basis indices.

    A round is considered only when (av, ap) is (1, 0) or (2, 1); every other
    round is ignored.

    Args:
        m: Sequence of bits, one per round.
        m_prime_two: Sequence of bits to compare against ``m``.
        av: Verifier basis index per round.
        ap: Prover basis index per round.

    Returns:
        Tuple ``(matches, mismatches)`` over the considered rounds.
    """
    paired_bases = ((1, 0), (2, 1))
    agree = 0
    disagree = 0

    for idx in range(len(av)):
        # Skip rounds whose basis indices are not one of the paired combinations.
        if (av[idx], ap[idx]) not in paired_bases:
            continue

        if m[idx] == m_prime_two[idx]:
            agree += 1
        else:
            disagree += 1

    return agree, disagree


# Compare the m bits with the decoded bits on the rounds selected by compare_results
# and report how many agree and how many differ.
matches, mismatches = compare_results(m, m_prime_two, av, ap)
print("Number of matches:", matches)
print("Number of mismatches:", mismatches)

CHSH correlation value: 2.92
Number of matches: 223
Number of mismatches: 0
