# Swap Test

We want to extract the quare of the inner product $|<a|b>|^2$

In [1]:
from sympy import sqrt
from sympy.physics.quantum.qubit import Qubit, measure_partial, matrix_to_qubit
from sympy.physics.quantum import TensorProduct
from sympy.physics.quantum.gate import HadamardGate, CGate, SwapGate
from sympy.physics.quantum.qapply import qapply
from sympy.physics.quantum.represent import represent

In [2]:
# example of inner products
a = Qubit("0")
b = Qubit("1")

r = (a.dual * b) ** 2
print(r)
r = r.doit()
print(r)

# example of inner products
a = Qubit("0")
b = Qubit("0")

r = (a.dual * b) ** 2
print(r)
r = r.doit()
print(r)

<0|1>**2
0
<0|0>**2
1


In [3]:
def SwapTest(a, b):
    print(f"a={a} b={b}")
    s = TensorProduct(TensorProduct(a, b), 1 / sqrt(2) * (Qubit(0) + Qubit(1)))
    s = qapply(s)
    s = matrix_to_qubit(represent(s))
    print(f"s={s}")

    # *note that qubits are indexed from right to left*
    g = CGate(0, SwapGate(1, 2))
    c = g * s
    print(c)
    c = HadamardGate(0) * c
    print(c)
    c = qapply(c)
    measure = measure_partial(c, [0])
    print(f"measure={measure}\n")


a = Qubit("0")
b = Qubit("0")
SwapTest(a, b)

a = Qubit("1")
b = Qubit("1")
SwapTest(a, b)

a = Qubit("1")
b = Qubit("0")
SwapTest(a, b)

a = Qubit("0")
b = Qubit("1")
SwapTest(a, b)

a=|0> b=|0>
s=sqrt(2)*|000>/2 + sqrt(2)*|001>/2
C((0),SWAP(1,2))*(sqrt(2)*|000>/2 + sqrt(2)*|001>/2)
H(0)*C((0),SWAP(1,2))*(sqrt(2)*|000>/2 + sqrt(2)*|001>/2)
measure=[(|000>, 1)]

a=|1> b=|1>
s=sqrt(2)*|110>/2 + sqrt(2)*|111>/2
C((0),SWAP(1,2))*(sqrt(2)*|110>/2 + sqrt(2)*|111>/2)
H(0)*C((0),SWAP(1,2))*(sqrt(2)*|110>/2 + sqrt(2)*|111>/2)
measure=[(|110>, 1)]

a=|1> b=|0>
s=sqrt(2)*|100>/2 + sqrt(2)*|101>/2
C((0),SWAP(1,2))*(sqrt(2)*|100>/2 + sqrt(2)*|101>/2)
H(0)*C((0),SWAP(1,2))*(sqrt(2)*|100>/2 + sqrt(2)*|101>/2)
measure=[(sqrt(2)*|010>/2 + sqrt(2)*|100>/2, 1/2), (-sqrt(2)*|011>/2 + sqrt(2)*|101>/2, 1/2)]

a=|0> b=|1>
s=sqrt(2)*|010>/2 + sqrt(2)*|011>/2
C((0),SWAP(1,2))*(sqrt(2)*|010>/2 + sqrt(2)*|011>/2)
H(0)*C((0),SWAP(1,2))*(sqrt(2)*|010>/2 + sqrt(2)*|011>/2)
measure=[(sqrt(2)*|010>/2 + sqrt(2)*|100>/2, 1/2), (sqrt(2)*|011>/2 - sqrt(2)*|101>/2, 1/2)]

