# Grover Algorithm

Import the required modules.

In [6]:
"""This module implements Grover algorithm"""

import math

from sympy import Add, Mul, sqrt
from sympy.physics.quantum import TensorProduct
from sympy.physics.quantum.gate import HadamardGate
from sympy.physics.quantum.qapply import qapply
from sympy.physics.quantum.qubit import (
    IntQubit,
    Qubit,
    matrix_to_qubit,
    measure_partial,
)
from sympy.physics.quantum.represent import represent

from oracle import oracle
from util.util import hn, get_sub_state

In [7]:
def inversion_about_mean(state):
    """
    Calculate the inversion about the mean of the input state.

    Every amplitude ``a`` of ``state`` is replaced by ``2 * mean - a``, where
    ``mean`` is the average amplitude of the terms present in ``state``.
    Basis states absent from the sum (zero amplitude) are not counted in the
    mean and do not appear in the result.

    :param state state: linear combination of ``Qubit`` kets (a single term,
        or a bare ket with implicit coefficient 1, is accepted too)
    :return: the inverted state, as a sympy expression
    """
    # Add.make_args / Mul.make_args also cope with a lone term and with a bare
    # ket whose coefficient is an implicit 1. Plain ``.args`` would return a
    # Mul's factors (or a ket's label digits) in those cases.
    coefficients = []
    kets = []
    for term in Add.make_args(state):
        factors = Mul.make_args(term)
        coefficients.append(
            Mul(*[factor for factor in factors if not isinstance(factor, Qubit)])
        )
        kets.append(Mul(*[factor for factor in factors if isinstance(factor, Qubit)]))

    # get avg
    mean = (Add(*coefficients) / len(coefficients)).doit()

    # apply inversion about mean: a -> 2 * mean - a
    return Add(*[ket * (2 * mean - coeff) for ket, coeff in zip(kets, coefficients)])

## The problem:
Given a function `f: {0, 1}^n -> {0, 1}` that we can evaluate, find an input `x` such that `f(x) = 1`.

The oracle `U_f` takes the qubits |x> and |y> as input and returns |x>|y XOR f(x)>.

## The solution:
```
                    +--------------------------------------------------------+
                    |   sqrt(2**n) times                                     |
                    |            +-----------------+         +---------+     |
  |x>  |0>-/n--H*n--|------------|                 |---\n----| -I + 2A |--\n-|----M
                    |            |       U_f       |         +---------+     |
  |y>               |  |1>---H---|                 |------- |y XOR f(x)>     |
                    |            +-----------------+                         |
                    +--------------------------------------------------------+
```                    
Measuring |x> at the end yields, with high probability, the input marked by f (the one with f(x) = 1).


In [8]:
def grover(f, n, iterations=None):
    """
    Run the Grover algorithm and return the final |x> register state.

    Each iteration does the following:
    - applies the oracle to |x> and |y> = H|1>;
    - measures the |y> qubit;
    - keeps the |x> part of the first measurement outcome;
    - applies the inversion about the mean to it.

    Intermediate states are printed.

    :param func f: oracle function, passed to ``oracle(x, y, f)``
    :param int n: string length (number of qubits in the |x> register)
    :param int iterations: number of Grover iterations; defaults to
        ``floor(sqrt(2**n))``. Note that the usual optimum is about
        ``floor(pi / 4 * sqrt(2**n))``, so the default may over-rotate.
    :return: the |x> state after the last iteration, as a sympy expression
    """
    if iterations is None:
        # Exact integer floor of sqrt(2**n); same value as int(sqrt(2**n))
        # without going through a symbolic expression.
        iterations = math.isqrt(2**n)

    # apply H*n gate to |x>
    x = qapply(hn(n) * Qubit("0" * n))
    print(f"|x>: {x}")

    # |y> = H|1> does not depend on the iteration, so compute it only once
    y = qapply(HadamardGate(0) * Qubit(1))

    # iterate over phase inversion block
    for i in range(iterations):
        print(f"\nIter {i}")
        print(f"|y>: {y}")

        # the joint |x>|y> state is built only to display it
        xy = matrix_to_qubit(represent(TensorProduct(x, y)))
        print(f"|xy>: {xy}")

        # apply oracle
        state = oracle(x, y, f)

        # we measure the y bit
        measure = measure_partial(state, (0,))

        # get x component of the tensor product, remove y
        x = get_sub_state(measure[0][0], 0, n)

        x = inversion_about_mean(x)
        print(f"|x>: {x}")
    return x

## Test

In [None]:
def test_0():
    """Grover search over 3 qubits must amplify the marked input |101>."""
    marked = 0b101

    def f(x, *args):
        # Oracle function: 1 for the marked basis state, 0 for any other.
        return 1 if IntQubit(Qubit(*x)).as_int() == marked else 0

    n = 3
    result = grover(f, n)

    # Expected amplitudes: 11*sqrt(2)/16 on the marked state, -sqrt(2)/16 elsewhere.
    expected = 0
    for index in range(2**n):
        weight = 11 if index == marked else -1
        expected += weight * sqrt(2) * Qubit(format(index, f"0{n}b")) / 16
    assert result == expected


test_0()

|x>: sqrt(2)*|000>/4 + sqrt(2)*|001>/4 + sqrt(2)*|010>/4 + sqrt(2)*|011>/4 + sqrt(2)*|100>/4 + sqrt(2)*|101>/4 + sqrt(2)*|110>/4 + sqrt(2)*|111>/4

Iter 0
|y>: sqrt(2)*|0>/2 - sqrt(2)*|1>/2
|xy>: |0000>/4 - |0001>/4 + |0010>/4 - |0011>/4 + |0100>/4 - |0101>/4 + |0110>/4 - |0111>/4 + |1000>/4 - |1001>/4 + |1010>/4 - |1011>/4 + |1100>/4 - |1101>/4 + |1110>/4 - |1111>/4
|x>: sqrt(2)*|000>/8 + sqrt(2)*|001>/8 + sqrt(2)*|010>/8 + sqrt(2)*|011>/8 + sqrt(2)*|100>/8 + 5*sqrt(2)*|101>/8 + sqrt(2)*|110>/8 + sqrt(2)*|111>/8

Iter 1
|y>: sqrt(2)*|0>/2 - sqrt(2)*|1>/2
|xy>: |0000>/8 - |0001>/8 + |0010>/8 - |0011>/8 + |0100>/8 - |0101>/8 + |0110>/8 - |0111>/8 + |1000>/8 - |1001>/8 + 5*|1010>/8 - 5*|1011>/8 + |1100>/8 - |1101>/8 + |1110>/8 - |1111>/8
|x>: -sqrt(2)*|000>/16 - sqrt(2)*|001>/16 - sqrt(2)*|010>/16 - sqrt(2)*|011>/16 - sqrt(2)*|100>/16 + 11*sqrt(2)*|101>/16 - sqrt(2)*|110>/16 - sqrt(2)*|111>/16


In [10]:
def test_1():
    """Grover search over 4 qubits must amplify the marked input |1101>."""
    marked = 0b1101

    def f(x, *args):
        # Oracle function: 1 for the marked basis state, 0 for any other.
        return 1 if IntQubit(Qubit(*x)).as_int() == marked else 0

    n = 4
    result = grover(f, n)

    # Expected amplitudes: 781/1024 on the marked state, -171/1024 elsewhere.
    expected = 0
    for index in range(2**n):
        numerator = 781 if index == marked else -171
        expected += numerator * Qubit(format(index, f"0{n}b")) / 1024
    assert result == expected


test_1()

|x>: |0000>/4 + |0001>/4 + |0010>/4 + |0011>/4 + |0100>/4 + |0101>/4 + |0110>/4 + |0111>/4 + |1000>/4 + |1001>/4 + |1010>/4 + |1011>/4 + |1100>/4 + |1101>/4 + |1110>/4 + |1111>/4

Iter 0
|y>: sqrt(2)*|0>/2 - sqrt(2)*|1>/2
|xy>: sqrt(2)*|00000>/8 - sqrt(2)*|00001>/8 + sqrt(2)*|00010>/8 - sqrt(2)*|00011>/8 + sqrt(2)*|00100>/8 - sqrt(2)*|00101>/8 + sqrt(2)*|00110>/8 - sqrt(2)*|00111>/8 + sqrt(2)*|01000>/8 - sqrt(2)*|01001>/8 + sqrt(2)*|01010>/8 - sqrt(2)*|01011>/8 + sqrt(2)*|01100>/8 - sqrt(2)*|01101>/8 + sqrt(2)*|01110>/8 - sqrt(2)*|01111>/8 + sqrt(2)*|10000>/8 - sqrt(2)*|10001>/8 + sqrt(2)*|10010>/8 - sqrt(2)*|10011>/8 + sqrt(2)*|10100>/8 - sqrt(2)*|10101>/8 + sqrt(2)*|10110>/8 - sqrt(2)*|10111>/8 + sqrt(2)*|11000>/8 - sqrt(2)*|11001>/8 + sqrt(2)*|11010>/8 - sqrt(2)*|11011>/8 + sqrt(2)*|11100>/8 - sqrt(2)*|11101>/8 + sqrt(2)*|11110>/8 - sqrt(2)*|11111>/8
|x>: 3*|0000>/16 + 3*|0001>/16 + 3*|0010>/16 + 3*|0011>/16 + 3*|0100>/16 + 3*|0101>/16 + 3*|0110>/16 + 3*|0111>/16 + 3*|1000>/16 + 3*|