<a href="https://colab.research.google.com/github/ishwet/Shor-s-Algorithm-Implementation/blob/main/shor's_algorithm_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""Install Cirq."""

# Use Cirq if it is already importable; otherwise install it with pip.
# The leading "!" is notebook (IPython) syntax that runs the rest of the
# line as a shell command, so this cell is not plain Python.
try:
    import cirq
except ImportError:
    print("installing cirq...")
    !pip install --quiet cirq
    print("installed cirq.")

installing cirq...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m24.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m670.8/670.8 kB[0m [31m39.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m73.5/73.5 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m430.5/430.5 kB[0m [31m28.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.8/2.8 MB[0m [31m70.4 MB/s[0m eta [36m0:00:00[0m
[?25hinstalled cirq.


In [2]:
"""Imports for the notebook."""

import fractions
import math
import random

import numpy as np
import sympy
from typing import Callable, Iterable, Sequence

import cirq

In [6]:
"""Zn and finding the GCD part."""


def multiplicative_group(n: int) -> list[int]:
    """Return the multiplicative group of integers modulo n, in ascending order.

    The group consists of the integers in [1, n) that are relatively prime
    to n.

    Args:
        n: Modulus; must be greater than one.

    Returns:
        List of the integers in [1, n) whose GCD with n is one.
    """
    assert n > 1
    # 1 always belongs to the group, so only 2..n-1 need a GCD test.
    coprime_above_one = [k for k in range(2, n) if math.gcd(k, n) == 1]
    return [1, *coprime_above_one]

In [7]:
"""Example explained."""

# Modulus used for the worked example.
n = 15
print(f"The multiplicative group modulo n = {n} is:")
# Lists the integers in [1, n) that are relatively prime to n.
print(multiplicative_group(n))

The multiplicative group modulo n = 15 is:
[1, 2, 4, 7, 8, 11, 13, 14]


In [9]:
""" Classical Order Finding """


def classical_order_finder(x: int, n: int) -> int | None:

    # Take correct values.
    if x < 2 or x >= n or math.gcd(x, n) > 1:
        raise ValueError(f"Invalid x={x} for modulus n={n}.")

    # order determination part.
    r, y = 1, x
    while y != 1:
        y = (x * y) % n
        r += 1
    return r

In [11]:
"""Example solved"""

# Worked example: order of x = 8 in the multiplicative group modulo 15.
n = 15
x = 8
r = classical_order_finder(x, n)


# Sanity check: raising x to its order must give 1 modulo n.
print(f"x^r mod n = {x}^{r} mod {n} = {x**r % n}")

x^r mod n = 8^4 mod 15 = 1


In [12]:
"""Example of defining an arithmetic (quantum) gate in Cirq."""


class Adder(cirq.ArithmeticGate):
    """Quantum addition: adds the input register into the target register.

    Each register is either a sequence of qid dimensions (a quantum register,
    e.g. ``[2, 2]`` for two qubits) or a plain int, which acts as a classical
    constant and occupies no qubits.
    """

    def __init__(self, target_register: int | Sequence[int], input_register: int | Sequence[int]):
        """Stores the two registers.

        Args:
            target_register: Register that receives the sum.
            input_register: Register (or classical constant) that is added in.
        """
        self.target_register = target_register
        self.input_register = input_register

    def registers(self) -> Sequence[int | Sequence[int]]:
        """Returns the registers in the order (target, input)."""
        return self.target_register, self.input_register

    def with_registers(self, *new_registers: int | Sequence[int]) -> 'Adder':
        """Returns a new Adder built from the given (target, input) registers."""
        return Adder(*new_registers)

    def apply(self, *register_values: int) -> int | Iterable[int]:
        """Returns the sum of all register values.

        A single int is returned, which cirq.ArithmeticGate treats as the new
        value of the first register (the target).
        """
        return sum(register_values)

    def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs):
        """Labels every wire the gate acts on with ' + '."""
        # Only quantum registers occupy wires. A register given as an int is a
        # classical constant and has no length, so calling len() on it (as in
        # Adder(target_register=[2, 2], input_register=1)) would raise.
        wire_count = sum(len(reg) for reg in self.registers() if not isinstance(reg, int))
        return cirq.CircuitDiagramInfo(wire_symbols=(' + ',) * wire_count)

In [13]:
"""Example of using an Adder in a circuit."""

# Two qubit registers.
qreg1 = cirq.LineQubit.range(2)
qreg2 = cirq.LineQubit.range(2, 4)

# Define an adder gate whose input and target registers are two qubits each
# (every 2 in a register list stands for one two-level qid, i.e. a qubit).
adder = Adder(input_register=[2, 2], target_register=[2, 2])

# Define the circuit.
circ = cirq.Circuit(
    # Flip one qubit in each register so both start in a non-zero state.
    cirq.X.on(qreg1[0]),
    cirq.X.on(qreg2[1]),
    # Adder.registers() lists the target first, so qreg1 acts as the target
    # register and qreg2 as the input register.
    adder.on(*qreg1, *qreg2),
    cirq.measure_each(*qreg1),
    cirq.measure_each(*qreg2),
)

# Display it.
print("Circuit:\n")
print(circ)

# Print the measurement outcomes.
print("\n\nMeasurement outcomes:\n")
print(cirq.sample(circ, repetitions=5).data)

Circuit:

0: ───X─── + ───M───
          │
1: ─────── + ───M───
          │
2: ─────── + ───M───
          │
3: ───X─── + ───M───


Measurement outcomes:

   q(0)  q(1)  q(2)  q(3)
0     1     1     0     1
1     1     1     0     1
2     1     1     0     1
3     1     1     0     1
4     1     1     0     1


In [14]:
"""Example of the unitary of an Adder gate."""

# Real part of the unitary of an Adder that adds the classical constant 1
# into a two-qubit target register; the result is a 4 x 4 matrix.
cirq.unitary(Adder(target_register=[2, 2], input_register=1)).real

array([[0., 0., 0., 1.],
       [1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.]])

In [15]:
"""Defines the modular exponential gate used in Shor's algorithm."""


class ModularExp(cirq.ArithmeticGate):
    """Quantum modular exponentiation.

    This class represents the unitary which multiplies base raised to exponent
    into the target modulo the given modulus. More precisely, it represents the
    unitary V which computes modular exponentiation x**e mod n:

        V|y⟩|e⟩ = |y * x**e mod n⟩ |e⟩     0 <= y < n
        V|y⟩|e⟩ = |y⟩ |e⟩                  n <= y

    where y is the target register, e is the exponent register, x is the base
    and n is the modulus. Consequently,

        V|y⟩|e⟩ = (U**e|y⟩)|e⟩

    where U is the unitary defined as

        U|y⟩ = |y * x mod n⟩      0 <= y < n
        U|y⟩ = |y⟩                n <= y
    """

    def __init__(
        self, target: Sequence[int], exponent: int | Sequence[int], base: int, modulus: int
    ) -> None:
        """Stores the registers and constants of the gate.

        Args:
            target: Qid dimensions of the target register.
            exponent: Qid dimensions of the exponent register, or an int for
                a classical exponent.
            base: Classical base x.
            modulus: Classical modulus n.

        Raises:
            ValueError: If the target register has fewer qubits than the
                modulus has bits.
        """
        if len(target) < modulus.bit_length():
            raise ValueError(
                f'Register with {len(target)} qubits is too small for modulus' f' {modulus}'
            )
        self.target = target
        self.exponent = exponent
        self.base = base
        self.modulus = modulus

    def registers(self) -> Sequence[int | Sequence[int]]:
        """Returns the registers in the order (target, exponent, base, modulus)."""
        return self.target, self.exponent, self.base, self.modulus

    def with_registers(self, *new_registers: int | Sequence[int]) -> 'ModularExp':
        """Returns a new ModularExp object with new registers."""
        if len(new_registers) != 4:
            raise ValueError(
                f'Expected 4 registers (target, exponent, base, '
                f'modulus), but got {len(new_registers)}'
            )
        target, exponent, base, modulus = new_registers
        if not isinstance(target, Sequence):
            raise ValueError(f'Target must be a qubit register, got {type(target)}')
        if not isinstance(base, int):
            raise ValueError(f'Base must be a classical constant, got {type(base)}')
        if not isinstance(modulus, int):
            raise ValueError(f'Modulus must be a classical constant, got {type(modulus)}')
        return ModularExp(target, exponent, base, modulus)

    def apply(self, *register_values: int) -> int:
        """Applies modular exponentiation to the registers.

        Four values should be passed in.  They are, in order:
          - the target
          - the exponent
          - the base
          - the modulus

        Note that the target and exponent should be qubit
        registers, while the base and modulus should be
        constant parameters that control the resulting unitary.

        Returns:
            The new value of the target register.
        """
        assert len(register_values) == 4
        target, exponent, base, modulus = register_values
        # Target values outside [0, modulus) are left untouched.
        if target >= modulus:
            return target
        # Three-argument pow reduces modulo `modulus` at every step, giving
        # the same value as base**exponent % modulus without first building
        # the full power (the exponent register can hold very large values).
        return (target * pow(base, exponent, modulus)) % modulus

    def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
        """Returns a 'CircuitDiagramInfo' object for printing circuits.

        This function just returns information on how to print this operation
        out in a circuit diagram so that the registers are labeled
        appropriately as exponent ('e') and target ('t').
        """
        assert args.known_qubits is not None
        wire_symbols = [f't{i}' for i in range(len(self.target))]
        # A classical exponent is printed by value; a quantum one as 'e',
        # with one extra labelled wire per exponent qubit.
        e_str = str(self.exponent)
        if isinstance(self.exponent, Sequence):
            e_str = 'e'
            wire_symbols += [f'e{i}' for i in range(len(self.exponent))]
        # The first wire carries the full description of the operation.
        wire_symbols[0] = f'ModularExp(t*{self.base}**{e_str} % {self.modulus})'
        return cirq.CircuitDiagramInfo(wire_symbols=tuple(wire_symbols))

In [16]:
"""Create the target and exponent registers for phase estimation,
and see the number of qubits needed for Shor's algorithm.
"""

n = 15
# L is the number of bits needed to write n in binary.
L = n.bit_length()

# The target register has L qubits.
target = cirq.LineQubit.range(L)

# The exponent register has 2L + 3 qubits, placed directly after the target
# register on the line (indices L through 3L + 2).
exponent = cirq.LineQubit.range(L, 3 * L + 3)

# Display the total number of qubits to factor this n.
print(f"To factor n = {n} which has L = {L} bits, we need 3L + 3 = {3 * L + 3} qubits.")

To factor n = 15 which has L = 4 bits, we need 3L + 3 = 15 qubits.


In [17]:
"""See (part of) the unitary for a modular exponential gate."""

# Pick some element of the multiplicative group modulo n, i.e. an integer
# in [2, n) that is relatively prime to n. With n = 15 the value 5 does not
# qualify (gcd(5, 15) = 5), so use 7 instead.
x = 7

# Display (part of) the unitary. Uncomment if n is small enough.
# cirq.unitary(ModularExp(target, exponent, x, n))

In [18]:
"""Function to make the quantum circuit for order finding."""


def make_order_finding_circuit(x: int, n: int) -> cirq.Circuit:
    """Builds the quantum circuit that estimates the order of x modulo n.

    Quantum Phase Estimation is applied to the unitary

        U|y⟩ = |y * x mod n⟩      0 <= y < n
        U|y⟩ = |y⟩                n <= y

    so that measuring the exponent register yields an estimate of one of
    its eigenphases.

    Args:
        x: positive integer whose order modulo n is to be found
        n: modulus relative to which the order of x is to be found

    Returns:
        Circuit acting on 3L + 3 line qubits, where L is the bit length of n,
        whose measurement is stored under the key 'exponent'.
    """
    num_target_qubits = n.bit_length()
    num_exponent_qubits = 2 * num_target_qubits + 3

    # Lay both registers out on one line: target first, exponent after it.
    qubits = cirq.LineQubit.range(num_target_qubits + num_exponent_qubits)
    target = qubits[:num_target_qubits]
    exponent = qubits[num_target_qubits:]

    # The gate is described by the dimension of each qid; all are qubits.
    mod_exp = ModularExp([2] * num_target_qubits, [2] * num_exponent_qubits, x, n)

    operations = [
        # Flip the last target qubit.
        cirq.X(target[-1]),
        # Put every exponent qubit into an equal superposition.
        cirq.H.on_each(*exponent),
        mod_exp.on(*target, *exponent),
        # The inverse QFT followed by measurement reads out the phase.
        cirq.qft(*exponent, inverse=True),
        cirq.measure(*exponent, key='exponent'),
    ]
    return cirq.Circuit(*operations)

In [19]:
"""Example of the quantum circuit for period finding."""

n = 15
x = 7
# Build (without running) the order-finding circuit for x modulo n and
# print its diagram.
circuit = make_order_finding_circuit(x, n)
print(circuit)

0: ────────ModularExp(t*7**e % 15)────────────────────────────
           │
1: ────────t1─────────────────────────────────────────────────
           │
2: ────────t2─────────────────────────────────────────────────
           │
3: ────X───t3─────────────────────────────────────────────────
           │
4: ────H───e0────────────────────────qft^-1───M('exponent')───
           │                         │        │
5: ────H───e1────────────────────────#2───────M───────────────
           │                         │        │
6: ────H───e2────────────────────────#3───────M───────────────
           │                         │        │
7: ────H───e3────────────────────────#4───────M───────────────
           │                         │        │
8: ────H───e4────────────────────────#5───────M───────────────
           │                         │        │
9: ────H───e5────────────────────────#6───────M───────────────
           │                         │        │
10: ───H───e6─────────────────

In [20]:
"""Measuring Shor's period finding circuit."""

# Build the order-finding circuit for x = 5 modulo n = 6 and sample it
# eight times.
circuit = make_order_finding_circuit(x=5, n=6)
res = cirq.sample(circuit, repetitions=8)

# The result object prints the measured bits stored under 'exponent'.
print("Raw measurements:")
print(res)

# res.data is a table with one row per repetition, holding the exponent
# register read as a single integer.
print("\nInteger in exponent register:")
print(res.data)

Raw measurements:
exponent=00000000, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000

Integer in exponent register:
   exponent
0         0
1         0
2         0
3         0
4         0
5         0
6         0
7         0


In [21]:
def process_measurement(result: cirq.Result, x: int, n: int) -> int | None:
    """Interprets the output of the order finding circuit.

    Specifically, it determines s/r such that exp(2πis/r) is an eigenvalue
    of the unitary

        U|y⟩ = |xy mod n⟩  0 <= y < n
        U|y⟩ = |y⟩         n <= y

    then computes r (by continued fractions) if possible, and returns it.

    Args:
        result: result obtained by sampling the output of the
            circuit built by make_order_finding_circuit; only the first
            repetition is used.
        x: integer whose order modulo n is sought; used to validate the
            candidate r.
        n: modulus; also the largest denominator the continued fractions
            step is allowed to produce.

    Returns:
        r such that x**r % n == 1, or None if the measurement gave a zero
        phase or a denominator that fails that check.
    """
    # Read the output integer of the exponent register.
    exponent_as_integer = int(result.data["exponent"][0])
    exponent_num_bits = result.measurements["exponent"].shape[1]

    # Keep the eigenphase as an exact rational instead of going through a
    # float, which would lose low-order bits on wide exponent registers.
    eigenphase = fractions.Fraction(exponent_as_integer, 2**exponent_num_bits)

    # Run the continued fractions algorithm to determine f = s / r.
    f = eigenphase.limit_denominator(n)

    # If the numerator is zero, the order finder failed.
    if f.numerator == 0:
        return None

    # Else, return the denominator if it is valid. Three-argument pow avoids
    # building the full integer x**r before reducing modulo n.
    r = f.denominator
    if pow(x, r, n) != 1:
        return None
    return r

In [22]:
"""Example of the classical post-processing."""

# Set n and x here
n = 6
x = 5

print(f"Finding the order of x = {x} modulo n = {n}\n")
# Build the circuit from the n and x chosen above, so the sample being
# processed always belongs to these values rather than to whichever circuit
# an earlier cell happened to leave in `circuit`.
circuit = make_order_finding_circuit(x, n)
measurement = cirq.sample(circuit, repetitions=1)
print("Raw measurements:")
print(measurement)

print("\nInteger in exponent register:")
print(measurement.data)

# process_measurement returns None when this sample does not yield an order.
r = process_measurement(measurement, x, n)
print("\nOrder r =", r)
if r is not None:
    print(f"x^r mod n = {x}^{r} mod {n} = {x**r % n}")

Finding the order of x = 5 modulo n = 6

Raw measurements:
exponent=1, 0, 0, 0, 0, 0, 0, 0, 0

Integer in exponent register:
   exponent
0       256

Order r = 2
x^r mod n = 5^2 mod 6 = 1


In [23]:
def quantum_order_finder(x: int, n: int) -> int | None:
    """Computes smallest positive r such that x**r mod n == 1.

    Args:
        x: integer whose order is to be computed, must be greater than one
           and belong to the multiplicative group of integers modulo n (which
           consists of positive integers relatively prime to n),
        n: modulus of the multiplicative group.
    """
    # Check that the integer x is a valid element of the multiplicative group
    # modulo n.
    if x < 2 or n <= x or math.gcd(x, n) > 1:
        raise ValueError(f'Invalid x={x} for modulus n={n}.')

    # Create the order finding circuit.
    circuit = make_order_finding_circuit(x, n)

    # Sample from the order finding circuit.
    measurement = cirq.sample(circuit)

    # Return the processed measurement result.
    return process_measurement(measurement, x, n)

In [24]:
"""Functions for factoring from start to finish."""


def find_factor_of_prime_power(n: int) -> int | None:
    """Returns non-trivial factor of n if n is a prime power, else None."""
    for k in range(2, math.floor(math.log2(n)) + 1):
        c = math.pow(n, 1 / k)
        c1 = math.floor(c)
        if c1**k == n:
            return c1
        c2 = math.ceil(c)
        if c2**k == n:
            return c2
    return None


def find_factor(
    n: int,
    order_finder: Callable[[int, int], int | None] = quantum_order_finder,
    max_attempts: int = 30,
) -> int | None:
    """Returns a non-trivial factor of composite integer n.

    Args:
        n: Integer to factor.
        order_finder: Function for finding the order of elements of the
            multiplicative group of integers modulo n.
        max_attempts: number of random x's to try, also an upper limit
            on the number of order_finder invocations.

    Returns:
        Non-trivial factor of n or None if no such factor was found.
        Factor k of n is trivial if it is 1 or n.
    """
    # If the number is prime, there are no non-trivial factors.
    if sympy.isprime(n):
        print("n is prime!")
        return None

    # If the number is even, two is a non-trivial factor.
    if n % 2 == 0:
        return 2

    # If n is a prime power, we can find a non-trivial factor efficiently.
    c = find_factor_of_prime_power(n)
    if c is not None:
        return c

    for _ in range(max_attempts):
        # Choose a random number between 2 and n - 1.
        x = random.randint(2, n - 1)

        # Most likely x and n will be relatively prime.
        c = math.gcd(x, n)

        # If x and n are not relatively prime, we got lucky and found
        # a non-trivial factor.
        if 1 < c < n:
            return c

        # Compute the order r of x modulo n using the order finder.
        r = order_finder(x, n)

        # If the order finder failed, try again.
        if r is None:
            continue

        # If the order r is odd, try again: x**(r/2) is needed below.
        if r % 2 != 0:
            continue

        # y is a square root of 1 modulo n. Three-argument pow keeps the
        # intermediate values below n instead of building x**(r/2) in full.
        y = pow(x, r // 2, n)

        # y == 1 means r was not the true order of x (for example a multiple
        # of it), so this x gives no factor; try another one.
        if y == 1:
            continue

        # Compute the non-trivial factor. When y is n - 1 the GCD is trivial
        # and the loop simply moves on to the next attempt.
        c = math.gcd(y - 1, n)
        if 1 < c < n:
            return c

    print(f"Failed to find a non-trivial factor in {max_attempts} attempts.")
    return None

In [25]:
"""Example of factoring via Shor's algorithm (order finding)."""

# Number to factor
n = 184573

# Attempt to find a factor
p = find_factor(n, order_finder=classical_order_finder)

print("Factoring n = pq =", n)
if p is None:
    # find_factor returns None when n is prime or when every attempt failed,
    # in which case there is no cofactor to compute.
    q = None
    print("No non-trivial factor was found.")
else:
    q = n // p
    print("p =", p)
    print("q =", q)

Factoring n = pq = 184573
p = 487
q = 379


In [26]:
"""Check the answer is correct."""

# Multiplying the two factors back together must reproduce n.
p * q == n

True