# Grover algorithm

In [1]:
import cirq
import random


In [2]:
def buildOracle(inputQubits, outputQubit, x):
    yield(cirq.X(q) for (q, bit) in zip(inputQubits, x) if not bit)
    yield(cirq.TOFFOLI(inputQubits[0], inputQubits[1], outputQubit))
    yield(cirq.X(q) for (q, bit) in zip(inputQubits, x) if not bit)
    

In [3]:
def buildGroverCircuit(inputQubits, output, oracle):
    circuit = cirq.Circuit()

    circuit.append([cirq.X(output),
                    cirq.H(output),
                    cirq.H.on_each(*inputQubits)
                    ])

    circuit.append(oracle)

    # Construction of the Grover operator
    circuit.append(cirq.H.on_each(*inputQubits))
    circuit.append(cirq.X.on_each(*inputQubits))
    circuit.append(cirq.H.on(inputQubits[1]))
    circuit.append(cirq.CNOT(inputQubits[0], inputQubits[1]))
    circuit.append(cirq.H.on(inputQubits[1]))
    circuit.append(cirq.X.on_each(*inputQubits))
    circuit.append(cirq.H.on_each(*inputQubits))

    # Measure the result.
    circuit.append(cirq.measure(*inputQubits, key='result'))

    return circuit


Parameter setup

In [4]:
qubitCount = 2
circuitSampleCount = 10

inputQubits = [cirq.GridQubit(i, 0) for i in range(qubitCount)]
outputQubit = cirq.GridQubit(qubitCount, 0)


Set up x'

In [5]:
xBits = [random.randint(0, 1) for _ in range(qubitCount)]
print(xBits)


[1, 0]


Make oracle

In [6]:
oracle = buildOracle(inputQubits, outputQubit, xBits)


Build circuit that implement Grover's algorithm

In [7]:
circuit = buildGroverCircuit(inputQubits, outputQubit, oracle)
print(circuit)


(0, 0): ───H───────@───H───X───────────@───X───H───────M('result')───
                   │                   │               │
(1, 0): ───H───X───@───X───H───X───H───X───H───X───H───M─────────────
                   │
(2, 0): ───X───H───X─────────────────────────────────────────────────


Run a simulator


In [9]:
simulator = cirq.Simulator()
state = simulator.run(circuit, repetitions=circuitSampleCount)


Define folding function

In [10]:
def bitstring(bits):
    return ''.join(str(int(b)) for b in bits )


Get the result

In [12]:
frequencies = state.histogram(key='result', fold_func=bitstring)


In [16]:
print(frequencies)

Counter({'10': 10})


The most common bit string:

In [18]:
mostCommonBitString = frequencies.most_common(1)[0][0]
print(mostCommonBitString)


10


In [19]:
matched = (mostCommonBitString == bitstring(xBits))
print(matched)


True
