# Test Shor's algorithm with Cirq


Set up the requirements, then import and verify the required libraries. Source links and more information are at the end of this notebook.

In [1]:
try:
    import cirq
except ImportError:
    print("installing cirq...")
    !pip install --quiet cirq
    print("installed cirq.")

In [2]:
import fractions
import math
import random
import matplotlib.pyplot as plt
import numpy as np
import sympy
from typing import Callable, List, Optional, Sequence, Union

import cirq
#import cirq_google

print(cirq.google.Foxtail)  
# Verify the Cirq install; this should print the device layout shown below.
# (0, 0)───(0, 1)───(0, 2)───(0, 3)───(0, 4)───(0, 5)───(0, 6)───(0, 7)───(0, 8)───(0, 9)───(0, 10)
# │        │        │        │        │        │        │        │        │        │        │
# │        │        │        │        │        │        │        │        │        │        │
# (1, 0)───(1, 1)───(1, 2)───(1, 3)───(1, 4)───(1, 5)───(1, 6)───(1, 7)───(1, 8)───(1, 9)───(1, 10)

(0, 0)───(0, 1)───(0, 2)───(0, 3)───(0, 4)───(0, 5)───(0, 6)───(0, 7)───(0, 8)───(0, 9)───(0, 10)
│        │        │        │        │        │        │        │        │        │        │
│        │        │        │        │        │        │        │        │        │        │
(1, 0)───(1, 1)───(1, 2)───(1, 3)───(1, 4)───(1, 5)───(1, 6)───(1, 7)───(1, 8)───(1, 9)───(1, 10)




# Understand Order Finding


## Case A: Classical Order Finding

A function for classically computing the order $r$ of an element $x \in \mathbb{Z}_n$ is provided below. This function simply computes the sequence 

$$ x^2 \text{ mod } n $$
$$ x^3 \text{ mod } n $$
$$ x^4 \text{ mod } n $$
$$ \vdots $$

until an integer $r$ is found such that $x^r = 1 \text{ mod } n$. Since $|\mathbb{Z}_n| = \phi(n)$, this algorithm for order finding has time complexity $O(\phi(n))$ which is inefficient. (Roughly $O(2^{L / 2})$ where $L$ is the number of bits in $n$.)

In [3]:
def classical_order_finder(x: int, n: int) -> Optional[int]:

    # Make sure x is both valid and in Z_n.
    if x < 2 or x >= n or math.gcd(x, n) > 1:
        raise ValueError(f"Invalid x={x} for modulus n={n}.")
    
    # Determine the order.
    r, y = 1, x
    while y != 1:
        y = (x * y) % n
        r += 1
    return r

An example of classically computing $r$ for a given $x \in \mathbb{Z}_n$ and given $n$ is shown in the code block below.

In [4]:
n = 15  # The multiplicative group is [1, 2, 4, 7, 8, 11, 13, 14].
x = 8
r = classical_order_finder(x, n)

# Check that the order is indeed correct.
print(f"x^r mod n = {x}^{r} mod {n} = {x**r % n}")

x^r mod n = 8^4 mod 15 = 1
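
To make the sequence of powers explicit, here is a minimal sketch that lists every intermediate value instead of only counting them. The helper name `power_sequence` is hypothetical and not part of the notebook; it assumes $\text{gcd}(x, n) = 1$, otherwise the loop would never reach 1.

```python
from typing import List


def power_sequence(x: int, n: int) -> List[int]:
    """Return [x^1 mod n, x^2 mod n, ...] up to and including the first 1.

    Assumes gcd(x, n) == 1; otherwise the sequence never reaches 1.
    """
    powers = [x % n]
    while powers[-1] != 1:
        powers.append((powers[-1] * x) % n)
    return powers


seq = power_sequence(8, 15)
print(seq)       # [8, 4, 2, 1]
print(len(seq))  # 4, the order of 8 modulo 15
```

The length of the list is the order $r$, which matches the value returned by `classical_order_finder` for the same inputs.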


The quantum part of Shor's algorithm is order finding, but done via a quantum circuit, which is discussed below.

Quantum order finding is quantum phase estimation with a [unitary $U$](https://en.wikipedia.org/wiki/Unitary_matrix) that computes the modular exponential function $f_x(z)$ for some randomly chosen $x \in \mathbb{Z}_n$. This tutorial uses arithmetic operations in Cirq to implement such a unitary $U$. To start, define and use a Modular Exponential operation.

<img src="https://i1.wp.com/www.physicsmindboggler.co/wp-content/uploads/2020/10/Screenshot-2020-10-26-212025-1.png?w=462&ssl=1">

### Define Modular Exponential

The key part of the code block below is the `apply` method which defines the arithmetic operation.

The `apply` method evaluates `(target * base**exponent) % modulus`. The `target` and the `exponent` depend on the values of the respective qubit registers, and the `base` and `modulus` are constant -- namely, the `modulus` is $n$ and the `base` is some $x \in \mathbb{Z}_n$. 
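
The arithmetic can be checked without any qubits. The sketch below copies the same formula into a plain function (the name `modular_exp_apply` is hypothetical) and confirms that, for a fixed exponent, it permutes the possible target values. Leaving `target >= modulus` unchanged is what keeps the map one-to-one on the whole register, which a reversible operation needs. The values `base = 7` and `modulus = 15` are chosen only for illustration.

```python
def modular_exp_apply(target: int, exponent: int, base: int, modulus: int) -> int:
    """Plain-Python copy of the arithmetic described above."""
    if target >= modulus:
        # Values outside Z_n are left untouched so the map stays a bijection.
        return target
    return (target * base**exponent) % modulus


base, modulus, num_target_bits = 7, 15, 4
all_targets = list(range(2**num_target_bits))

for exponent in range(4):
    outputs = [modular_exp_apply(t, exponent, base, modulus) for t in all_targets]
    # Every target value appears exactly once in the outputs.
    assert sorted(outputs) == all_targets

# Starting from target = 1, the outputs are the powers of the base.
print([modular_exp_apply(1, e, base, modulus) for e in range(4)])  # [1, 7, 4, 13]
```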

In [5]:
class ModularExp(cirq.ArithmeticOperation):
    
    def __init__(
        self, 
        target: Sequence[cirq.Qid],
        exponent: Union[int, Sequence[cirq.Qid]], 
        base: int,
        modulus: int
    ) -> None:
        if len(target) < modulus.bit_length():
            raise ValueError(f'Register with {len(target)} qubits is too small '
                             f'for modulus {modulus}')
        self.target = target
        self.exponent = exponent
        self.base = base
        self.modulus = modulus

    def registers(self) -> Sequence[Union[int, Sequence[cirq.Qid]]]:
        return self.target, self.exponent, self.base, self.modulus

    def with_registers(
            self,
            *new_registers: Union[int, Sequence['cirq.Qid']],
    ) -> cirq.ArithmeticOperation:
        if len(new_registers) != 4:
            raise ValueError(f'Expected 4 registers (target, exponent, base, '
                             f'modulus), but got {len(new_registers)}')
        target, exponent, base, modulus = new_registers
        if not isinstance(target, Sequence):
            raise ValueError(
                f'Target must be a qubit register, got {type(target)}')
        if not isinstance(base, int):
            raise ValueError(
                f'Base must be a classical constant, got {type(base)}')
        if not isinstance(modulus, int):
            raise ValueError(
                f'Modulus must be a classical constant, got {type(modulus)}')
        return ModularExp(target, exponent, base, modulus)

    def apply(self, *register_values: int) -> int:
        assert len(register_values) == 4
        target, exponent, base, modulus = register_values
        if target >= modulus:
            return target
        return (target * base**exponent) % modulus

    def _circuit_diagram_info_(
            self,
            args: cirq.CircuitDiagramInfoArgs,
    ) -> cirq.CircuitDiagramInfo:
        assert args.known_qubits is not None
        wire_symbols: List[str] = []
        t, e = 0, 0
        for qubit in args.known_qubits:
            if qubit in self.target:
                if t == 0:
                    if isinstance(self.exponent, Sequence):
                        e_str = 'e'
                    else:
                        e_str = str(self.exponent)
                    wire_symbols.append(
                        f'ModularExp(t*{self.base}**{e_str} % {self.modulus})')
                else:
                    wire_symbols.append('t' + str(t))
                t += 1
            if isinstance(self.exponent, Sequence) and qubit in self.exponent:
                wire_symbols.append('e' + str(e))
                e += 1
        return cirq.CircuitDiagramInfo(wire_symbols=tuple(wire_symbols))

The circuit uses $3 (L + 1)$ qubits in total, where $L$ is the number of bits needed to store the integer $n$ to factor. The unitary which implements the modular exponential therefore has $4^{3(L + 1)}$ entries. For $n = 15$ ($L = 4$), that is $4^{15} = 2^{30}$ floating point numbers, which is out of reach of most current standard laptops.
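
As a quick check of this arithmetic, the following sketch tabulates the qubit count and the number of unitary entries for a few values of $n$. The function name `simulation_cost` is hypothetical, and the values of $n$ other than 15 are chosen only for comparison.

```python
from typing import Tuple


def simulation_cost(n: int) -> Tuple[int, int, int]:
    """Return (L, total qubits, number of entries in the full unitary)."""
    L = n.bit_length()
    num_qubits = 3 * (L + 1)
    dimension = 2**num_qubits
    # A unitary on num_qubits qubits is a dimension x dimension matrix.
    return L, num_qubits, dimension * dimension


for n in (6, 15, 21):
    L, num_qubits, entries = simulation_cost(n)
    print(f"n = {n}: L = {L}, qubits = {num_qubits}, "
          f"unitary entries = 2^{entries.bit_length() - 1}")
```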

In [6]:
n = 15
L = n.bit_length()

# The target register has L qubits.
target = cirq.LineQubit.range(L)

# The exponent register has 2L + 3 qubits.
exponent = cirq.LineQubit.range(L, 3 * L + 3)

# Display the total number of qubits to factor this n.
print(f"To factor n = {n} which has L = {L} bits, we need 3L + 3 = {3 * L + 3} qubits.")

To factor n = 15 which has L = 4 bits, we need 3L + 3 = 15 qubits.


### Use ModularExp in a circuit

The quantum part of Shor's algorithm is phase estimation with the unitary $U$ corresponding to the modular exponential operation. The following cell defines a function which creates the circuit for Shor's algorithm using the `ModularExp` operation defined above.

Use this function to visualize the circuit for a given $x$ and $n$ as follows.

In [7]:
def make_order_finding_circuit(x: int, n: int) -> cirq.Circuit:
    L = n.bit_length()
    target = cirq.LineQubit.range(L)
    exponent = cirq.LineQubit.range(L, 3 * L + 3)
    return cirq.Circuit(
        cirq.X(target[L - 1]),
        cirq.H.on_each(*exponent),
        ModularExp(target, exponent, x, n),
        cirq.qft(*exponent, inverse=True),
        cirq.measure(*exponent, key='exponent'),
    )

n = 15
x = 7
circuit = make_order_finding_circuit(x, n)
print(circuit)

0: ────────ModularExp(t*7**e % 15)────────────────────────────
           │
1: ────────t1─────────────────────────────────────────────────
           │
2: ────────t2─────────────────────────────────────────────────
           │
3: ────X───t3─────────────────────────────────────────────────
           │
4: ────H───e0────────────────────────qft^-1───M('exponent')───
           │                         │        │
5: ────H───e1────────────────────────#2───────M───────────────
           │                         │        │
6: ────H───e2────────────────────────#3───────M───────────────
           │                         │        │
7: ────H───e3────────────────────────#4───────M───────────────
           │                         │        │
8: ────H───e4────────────────────────#5───────M───────────────
           │                         │        │
9: ────H───e5────────────────────────#6───────M───────────────
           │                         │        │
10: ───H───e6─────────────────

Hadamard gates put the exponent register into an equal superposition. The $X$ gate on the last qubit in the target register is used for phase kickback. The modular exponential operation performs the sequence of controlled unitaries in phase estimation. The inverse quantum Fourier transform (inv-QFT) is then applied to the exponent register, which is measured to read out the result.

To illustrate the measurement results, sample from a smaller circuit. (Note that in practice we would never run Shor's algorithm with $n = 6$ because it is even. This is just an example to illustrate the measurement outcomes.)

Each measured bitstring is interpreted as an integer, but what do these integers tell us? The next section uses classical post-processing to interpret them.

In [8]:
circuit = make_order_finding_circuit(x=5, n=6)
res = cirq.sample(circuit, repetitions=8)

print("Raw measurements:")
print(res)

print("\nInteger in exponent register:")
print(res.data)

def binary_labels(num_qubits):
    return [bin(x)[2:].zfill(num_qubits) for x in range(2 ** num_qubits)]

# Refactor this plot to use this cell's output
# q = cirq.LineQubit.range(3)
# circuit = cirq.Circuit([cirq.H.on_each(*q), cirq.measure(*q)])
# result = cirq.Simulator().run(circuit, repetitions=100)
# _ = cirq.vis.plot_state_histogram(result, plt.subplot(), title = 'Integer in exponent Register', xlabel = 'Integer', ylabel = 'Count', tick_label=binary_labels(3))

Raw measurements:
exponent=10110100, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000

Integer in exponent register:
   exponent
0       256
1         0
2       256
3       256
4         0
5       256
6         0
7         0


### Use Classical post-processing

The measured integer, divided by $2^m$ where $m$ is the number of qubits in the exponent register, is close to $s / r$ where $r$ is the order of $x \in \mathbb{Z}_n$ and $0 \le s < r$ is an integer. The function below uses the continued fractions algorithm to determine $r$ from $s / r$, then returns it if the order finding circuit succeeded, else returns `None`.

In [9]:
def process_measurement(result: cirq.Result, x: int, n: int) -> Optional[int]:
    # Read the output integer of the exponent register.
    exponent_as_integer = result.data["exponent"][0]
    exponent_num_bits = result.measurements["exponent"].shape[1]
    eigenphase = float(exponent_as_integer / 2**exponent_num_bits)

    # Run the continued fractions algorithm to determine f = s / r.
    f = fractions.Fraction.from_float(eigenphase).limit_denominator(n)
    
    if f.numerator == 0:
        return None
    
    r = f.denominator
    if x**r % n != 1:
        return None
    return r
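
The continued fractions step can also be tried in isolation, without a `cirq.Result`. The sketch below is a simplified stand-in for `process_measurement` that takes the measured integer directly. The name `order_from_measurement` is hypothetical, it uses an exact `Fraction` instead of a float, and the measured values for $n = 15$ are made up to illustrate $s / r = 3/4$ and $s / r = 2/4$.

```python
from fractions import Fraction
from typing import Optional


def order_from_measurement(measured: int, num_bits: int, x: int, n: int) -> Optional[int]:
    """Recover the order r from one measured integer, or return None."""
    eigenphase = Fraction(measured, 2**num_bits)
    # limit_denominator finds the closest fraction with denominator <= n.
    f = eigenphase.limit_denominator(n)
    if f.numerator == 0:
        return None
    r = f.denominator
    if pow(x, r, n) != 1:
        return None
    return r


print(order_from_measurement(256, 9, x=5, n=6))     # 2
print(order_from_measurement(0, 9, x=5, n=6))       # None, s = 0 carries no information
print(order_from_measurement(1536, 11, x=7, n=15))  # 4, from s / r = 3/4
print(order_from_measurement(1024, 11, x=7, n=15))  # None, 2/4 reduces to 1/2
```

The last case shows why a run can fail even on an ideal device: when $s$ and $r$ share a factor, the reduced fraction hides the true denominator, and the check $x^r \text{ mod } n = 1$ rejects the candidate.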

The next code block shows an example of creating an order finding circuit, executing it, then using the classical postprocessing function to determine the order. Recall that the quantum part of the algorithm succeeds with some probability. If the order is `None`, try re-running the cell a few times. 

In [10]:
n = 6
x = 5

print(f"Finding the order of x = {x} modulo n = {n}\n")
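# Build the circuit for this x and n rather than reusing one from an earlier cell.
circuit = make_order_finding_circuit(x, n)
measurement = cirq.sample(circuit, repetitions=1)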
print("Raw measurements:")
print(measurement)

print("\nInteger in exponent register:")
print(measurement.data)

r = process_measurement(measurement, x, n)
print("\nOrder r =", r)
if r is not None:
    print(f"x^r mod n = {x}^{r} mod {n} = {x**r % n}")

Finding the order of x = 5 modulo n = 6

Raw measurements:
exponent=1, 0, 0, 0, 0, 0, 0, 0, 0

Integer in exponent register:
   exponent
0       256

Order r = 2
x^r mod n = 5^2 mod 6 = 1


## Case B: Quantum Order Finding

### Define Quantum Order Finder

We can now define a streamlined function for the quantum version of order finding using the functions we have previously written. The quantum order finder below creates the circuit, executes it, and processes the measurement result.  

In the example above, you should see that the order of $x = 5$ in $\mathbb{Z}_6$ is $r = 2$. Indeed, $5^2 \text{ mod } 6 = 25 \text{ mod } 6 = 1$. 

This completes a quantum implementation of an order finder, and the quantum part of Shor's algorithm.

In [11]:
def quantum_order_finder(x: int, n: int) -> Optional[int]:
    if x < 2 or n <= x or math.gcd(x, n) > 1:
        raise ValueError(f'Invalid x={x} for modulus n={n}.')

    circuit = make_order_finding_circuit(x, n)
    
    measurement = cirq.sample(circuit)
    return process_measurement(measurement, x, n)

# Use Order Finding

### Use the Quantum Order Finder

Use this quantum order finder to **complete** Shor's algorithm. In the following code block, add a few pre-processing steps which check if $n$ is even or prime or a prime power.

These steps can be done efficiently with a **classical** computer. Then add a final post-processing step which uses the order $r$ to compute a non-trivial factor $p$ of $n$. This is achieved by computing $y = x^{r / 2} \text{ mod } n$ (assuming $r$ is even), then computing $p = \text{gcd}(y - 1, n)$.
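
The final post-processing step can be worked through on small numbers. The sketch below is a minimal version of that step alone, with the order $r$ supplied by hand. The name `factor_from_order` is hypothetical, and the example inputs are chosen only for illustration.

```python
import math
from typing import Optional


def factor_from_order(x: int, r: int, n: int) -> Optional[int]:
    """Try to turn the order r of x modulo n into a non-trivial factor of n."""
    if r % 2 != 0:
        return None  # x^(r/2) is not defined for odd r
    y = pow(x, r // 2, n)
    c = math.gcd(y - 1, n)
    if 1 < c < n:
        return c
    return None


print(factor_from_order(7, 4, 15))   # 3: y = 7^2 mod 15 = 4, gcd(3, 15) = 3
print(factor_from_order(2, 6, 21))   # 7: y = 2^3 mod 21 = 8, gcd(7, 21) = 7
print(factor_from_order(4, 3, 21))   # None: the order is odd
print(factor_from_order(14, 2, 15))  # None: y = 14, gcd(13, 15) = 1
```

The two `None` cases are the reason the full algorithm retries with a new random $x$.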

The function `find_factor` uses the `quantum_order_finder` by default, which executes Shor's algorithm. 

In [12]:
def find_factor_of_prime_power(n: int) -> Optional[int]:
    for k in range(2, math.floor(math.log2(n)) + 1):
        c = math.pow(n, 1 / k)
        c1 = math.floor(c)
        if c1**k == n:
            return c1
        c2 = math.ceil(c)
        if c2**k == n:
            return c2
    return None

def find_factor(
    n: int,
    order_finder: Callable[[int, int], Optional[int]] = quantum_order_finder,
    max_attempts: int = 30
) -> Optional[int]:
   
    if sympy.isprime(n):
        print("n is prime!")
        return None
    
    if n % 2 == 0:
        return 2
    
    c = find_factor_of_prime_power(n)
    if c is not None:
        return c
    
    for _ in range(max_attempts):
        x = random.randint(2, n - 1)
        c = math.gcd(x, n)
        
        if 1 < c < n:
            return c
        
        r = order_finder(x, n)
        
        if r is None:
            continue
        
        if r % 2 != 0:
            continue
        
        y = x**(r // 2) % n
        assert 1 < y < n
        c = math.gcd(y - 1, n)
        if 1 < c < n:
            return c

    print(f"Failed to find a non-trivial factor in {max_attempts} attempts.")
    return None
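
`find_factor_of_prime_power` takes roots with floating-point `math.pow`, which is fine for the small $n$ used here but can lose precision for very large integers. As a hedged alternative, the sketch below does the same check with integer arithmetic only. The names `integer_kth_root` and `perfect_power_base` are hypothetical and not part of the notebook.

```python
from typing import Optional


def integer_kth_root(n: int, k: int) -> int:
    """Return floor(n ** (1 / k)) using binary search on integers."""
    lo, hi = 0, 1 << (n.bit_length() // k + 1)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if mid**k <= n:
            lo = mid
        else:
            hi = mid - 1
    return lo


def perfect_power_base(n: int) -> Optional[int]:
    """Return c > 1 with c**k == n for the smallest k >= 2, or None."""
    for k in range(2, n.bit_length() + 1):
        c = integer_kth_root(n, k)
        if c > 1 and c**k == n:
            return c
    return None


print(perfect_power_base(27))     # 3
print(perfect_power_base(15))     # None
print(perfect_power_base(3**40))  # 3486784401, which is 3**20
```

Like the original function, this returns a factor that is not necessarily prime.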

### Use the Classical Order Finder
The `find_factor` function uses the `quantum_order_finder` by default, which executes Shor's algorithm. Due to the large memory requirements for classically *simulating* this circuit, the cell below passes the `classical_order_finder` instead.

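NOTE: Simulating the quantum order finder is practical only for small $n$. As noted above, $n = 15$ already requires a unitary with $2^{30}$ entries.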

In [13]:
# Test with non-prime numbers only: find_factor returns None for a prime n,
# and the division below would then fail.
n = 184572
p = find_factor(n, order_finder=classical_order_finder)
q = n // p

print("If p * q == n is True, then this answer is correct.")
print(f"The result of {p} * {q} = {n} is {p * q == n}")

If p * q == n is True, then this answer is correct.
The result of 2 * 92286 = 184572 is True


# Learn More and Sources
##### Copyright 2020 The Cirq Developers

### License

In [14]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

This tutorial presents a pedagogical demonstration of Shor's algorithm. It is a modified and expanded version of [this Cirq example](https://github.com/quantumlib/Cirq/blob/master/examples/shor.py).

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://quantumai.google/cirq/tutorials/shor"><img src="https://quantumai.google/site-assets/images/buttons/quantumai_logo_1x.png" />View on QuantumAI</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/quantumlib/Cirq/blob/master/docs/tutorials/shor.ipynb"><img src="https://quantumai.google/site-assets/images/buttons/colab_logo_1x.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/quantumlib/Cirq/blob/master/docs/tutorials/shor.ipynb"><img src="https://quantumai.google/site-assets/images/buttons/github_logo_1x.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/Cirq/docs/tutorials/shor.ipynb"><img src="https://quantumai.google/site-assets/images/buttons/download_icon_1x.png" />Download notebook</a>
  </td>
</table>