In [1]:
""" This module implements Shor factorization algorithm"""
import fractions
import random

from math import gcd
from sympy import isprime
from sympy.physics.quantum.qapply import qapply
from sympy.physics.quantum.qubit import Qubit, matrix_to_qubit, \
    IntQubit, measure_partial_oneshot, measure_all_oneshot
from sympy.physics.quantum import TensorProduct
from sympy.physics.quantum.represent import represent
from sympy.physics.quantum.qft import IQFT
from oracle import oracle
from util.util import hn

In [2]:
def fan(x, N, a):
    """
    Function to calculate a**x%N
    :param x: |x> qubit, given as the bit values accepted by ``Qubit(*x)``
    :param N: product we want to find the factors of
    :param a: relatively prime number to N
    :returns: a**x % N as a plain Python int
    """
    # as_int() may hand back a sympy integer; three-argument pow needs a
    # builtin int as exponent
    exponent = int(IntQubit(Qubit(*x)).as_int())
    # modular exponentiation: reduces modulo N at every step instead of
    # building the complete power a**x first
    return pow(a, exponent, N)

## The problem:
Factor an integer N into its prime components.

## The Solution:
```
                     +-----------------+
  |x>  |0>-/m-H*m----|                 |--/m--QFT+---M |x>
                     |       U_f_a,N   |
  |y>  |0>-/n--------|                 |--/n--------  |y XOR f(x, a, N)>
                     +-----------------+
```

Quantum portion of Shor's algorithm

In [3]:
def shor_period(f, m, n, N, a):
    """
    Shor's algorithm, find the period quantumly

    Builds the |x> and |y> registers, passes them to ``oracle`` together
    with ``f``, measures the |y> register, applies the inverse QFT to what
    is left of |x>, measures it and derives a candidate period from the
    measured value with ``Fraction.limit_denominator``.

    :param func f: fan function returning a**x%N
    :param m: |x> bit number
    :param n: |y> bit number
    :param N: product we want to find the factors of
    :param a: relatively prime number to N
    :returns int period, or None when the measurement gives no usable
        candidate (phase numerator 0, or a denominator r with
        a**r % N != 1)
    :raises ValueError: if a term of the measured state holds no qubit

    """
    # only the |x> register goes through hn(m); |y> stays in |0...0>
    x = qapply(hn(m) * Qubit('0' * m))
    print(f"|x>: {x}")
    y = Qubit('0' * n)
    print(f"|y>: {y}")
    # the combined state is built for display only; oracle gets x and y
    xy = TensorProduct(x, y)
    print(f"|xy>: {xy}")
    xy = matrix_to_qubit(represent(xy))
    print(f"|xy>: {xy}")

    # apply oracle
    xy_xor_fx = oracle(x, y, f, N, a)
    print(f"|xy_xor_fx>: {xy_xor_fx}")

    # measure the n lowest bits, i.e. the |y> register
    xy_xor_fx = measure_partial_oneshot(xy_xor_fx, range(n))
    print(f"|xy_xor_fx>: {xy_xor_fx} n = {n}")

    # create a new state without y qubit
    x_state = 0
    # a state with a single term is not an Add, and a term with
    # coefficient 1 is not a Mul; normalise both to a sequence
    terms = xy_xor_fx.args if xy_xor_fx.is_Add else (xy_xor_fx,)
    for term in terms:
        coeff = 1
        ket = None
        parts = term.args if term.is_Mul else (term,)
        for part in parts:
            if isinstance(part, Qubit):
                ket = part
            else:
                coeff *= part
        if ket is None:
            raise ValueError(
                f"term {term} of the measured state has no qubit factor")
        # throw y part and just keep x
        x_state += coeff * Qubit(*ket.qubit_values[0:ket.dimension - n])

    print(f"New state without y |y_xor_fx>: {x_state}")

    # apply the inverse QFT to the remaining |x> register
    x_state = qapply(IQFT(0, m).decompose() * x_state)
    print(f"After IQFT |y_xor_fx>={x_state}")

    mea = measure_all_oneshot(x_state)
    outcome = int(IntQubit(mea).as_int())
    print(f"measure(state) = {mea} as int = {outcome}")

    # outcome / 2**m approximates j / r; the closest fraction with a
    # denominator of at most N gives the candidate period r. The Fraction
    # is built from integers, so no float rounding is involved.
    phase = fractions.Fraction(outcome, 2 ** m).limit_denominator(N)
    r = phase.denominator

    if phase.numerator == 0:
        return None

    # keep r only if it really is a period of a**x % N
    if pow(a, r, N) != 1:
        return None

    return r

Full Shor's algorithm

In [4]:
def _half_order_root(a, r, N):
    """
    Return a**(r/2) % N if it is a non-trivial square root of 1 modulo N
    :param a: base, relatively prime to N
    :param r: period of a modulo N (a**r % N == 1)
    :param N: product we want to find the factors of
    :returns int in [2, N - 2], or None when r is odd and a is no perfect
        square, or when the root is 1 or N - 1
    """
    if r % 2 == 0:
        half = pow(a, r // 2, N)
    else:
        # odd period: a**(r/2) is an integer only when a is a perfect square
        root = int(round(a ** 0.5))
        if root * root != a:
            return None
        half = pow(root, r, N)

    # 1 and N - 1 only lead to the trivial divisors 1 and N
    if half in (1, N - 1):
        return None
    return half


def shor(N, *args):
    """
    Shor's algorithm, non-quantum part
    :param N: product we want to find the factors of
    :param args: optional fixed base a (only args[0] is used, must be
        smaller than N); without it a random base from [2, N - 1] is
        drawn for every try
    :returns [int] factors; None when N is prime; an empty list when no
        usable period was found within the allowed number of tries
    """
    if isprime(N):
        print(f"{N} is prime!")
        return None

    max_tries = 30
    # the register sizes depend on N only, so they are computed once
    m = n = len(bin(N)[2:]) + 1

    r = None
    half = None
    for i in range(max_tries):
        if args:
            a = args[0]
            assert a < N
        else:
            a = random.randint(2, N - 1)

        common = gcd(a, N)
        if 1 < common < N:
            print(f"a = {a} and N = {N} are not relatively prime")
            # integer division keeps the cofactor exact for any size of N
            return [common, N // common]

        print(f"{'=' * 40}")
        print(f"Shore N={N}, m={m}, n={n}, a={a} iter {i}")
        r = shor_period(fan, m, n, N, a)
        if r is None:
            # the measurement gave no period, try again
            continue

        half = _half_order_root(a, r, N)
        if half is not None:
            break

        print(f"period r = {r} of a = {a} gives no non-trivial factor")
        if args:
            # the base is fixed by the caller, another run would end in
            # the same dead end
            break
        r = None

    print(f"period r = {r}")
    if half is None:
        return []

    # half**2 == 1 (mod N) with half != +-1, so N divides
    # (half - 1) * (half + 1) without dividing either factor
    return [gcd(half - 1, N), gcd(half + 1, N)]

## Tests

In [5]:
def test_0():
    N = 15
    factors = shor(N, 2)
    print("factors = {factors}".format(factors=factors))
    truth = [3, 5]
    assert truth == sorted(factors)


test_0()

Shore N=15, m=5, n=5, a=2 iter 0
|x>: sqrt(2)*|00000>/8 + sqrt(2)*|00001>/8 + sqrt(2)*|00010>/8 + sqrt(2)*|00011>/8 + sqrt(2)*|00100>/8 + sqrt(2)*|00101>/8 + sqrt(2)*|00110>/8 + sqrt(2)*|00111>/8 + sqrt(2)*|01000>/8 + sqrt(2)*|01001>/8 + sqrt(2)*|01010>/8 + sqrt(2)*|01011>/8 + sqrt(2)*|01100>/8 + sqrt(2)*|01101>/8 + sqrt(2)*|01110>/8 + sqrt(2)*|01111>/8 + sqrt(2)*|10000>/8 + sqrt(2)*|10001>/8 + sqrt(2)*|10010>/8 + sqrt(2)*|10011>/8 + sqrt(2)*|10100>/8 + sqrt(2)*|10101>/8 + sqrt(2)*|10110>/8 + sqrt(2)*|10111>/8 + sqrt(2)*|11000>/8 + sqrt(2)*|11001>/8 + sqrt(2)*|11010>/8 + sqrt(2)*|11011>/8 + sqrt(2)*|11100>/8 + sqrt(2)*|11101>/8 + sqrt(2)*|11110>/8 + sqrt(2)*|11111>/8
|y>: |00000>
|xy>: (sqrt(2)*|00000>/8 + sqrt(2)*|00001>/8 + sqrt(2)*|00010>/8 + sqrt(2)*|00011>/8 + sqrt(2)*|00100>/8 + sqrt(2)*|00101>/8 + sqrt(2)*|00110>/8 + sqrt(2)*|00111>/8 + sqrt(2)*|01000>/8 + sqrt(2)*|01001>/8 + sqrt(2)*|01010>/8 + sqrt(2)*|01011>/8 + sqrt(2)*|01100>/8 + sqrt(2)*|01101>/8 + sqrt(2)*|01110>/8 + sqrt

In [6]:
def test_1():
    N = 15
    factors = shor(N, 7)
    print("factors = {factors}".format(factors=factors))
    truth = [3, 5]
    assert truth == sorted(factors)


test_1()

Shore N=15, m=5, n=5, a=7 iter 0
|x>: sqrt(2)*|00000>/8 + sqrt(2)*|00001>/8 + sqrt(2)*|00010>/8 + sqrt(2)*|00011>/8 + sqrt(2)*|00100>/8 + sqrt(2)*|00101>/8 + sqrt(2)*|00110>/8 + sqrt(2)*|00111>/8 + sqrt(2)*|01000>/8 + sqrt(2)*|01001>/8 + sqrt(2)*|01010>/8 + sqrt(2)*|01011>/8 + sqrt(2)*|01100>/8 + sqrt(2)*|01101>/8 + sqrt(2)*|01110>/8 + sqrt(2)*|01111>/8 + sqrt(2)*|10000>/8 + sqrt(2)*|10001>/8 + sqrt(2)*|10010>/8 + sqrt(2)*|10011>/8 + sqrt(2)*|10100>/8 + sqrt(2)*|10101>/8 + sqrt(2)*|10110>/8 + sqrt(2)*|10111>/8 + sqrt(2)*|11000>/8 + sqrt(2)*|11001>/8 + sqrt(2)*|11010>/8 + sqrt(2)*|11011>/8 + sqrt(2)*|11100>/8 + sqrt(2)*|11101>/8 + sqrt(2)*|11110>/8 + sqrt(2)*|11111>/8
|y>: |00000>
|xy>: (sqrt(2)*|00000>/8 + sqrt(2)*|00001>/8 + sqrt(2)*|00010>/8 + sqrt(2)*|00011>/8 + sqrt(2)*|00100>/8 + sqrt(2)*|00101>/8 + sqrt(2)*|00110>/8 + sqrt(2)*|00111>/8 + sqrt(2)*|01000>/8 + sqrt(2)*|01001>/8 + sqrt(2)*|01010>/8 + sqrt(2)*|01011>/8 + sqrt(2)*|01100>/8 + sqrt(2)*|01101>/8 + sqrt(2)*|01110>/8 + sqrt

In [7]:
def test_2():
    N = 15
    factors = shor(N, 13)
    print("factors = {factors}".format(factors=factors))
    truth = [3, 5]
    assert truth == sorted(factors)


test_2()

Shore N=15, m=5, n=5, a=13 iter 0
|x>: sqrt(2)*|00000>/8 + sqrt(2)*|00001>/8 + sqrt(2)*|00010>/8 + sqrt(2)*|00011>/8 + sqrt(2)*|00100>/8 + sqrt(2)*|00101>/8 + sqrt(2)*|00110>/8 + sqrt(2)*|00111>/8 + sqrt(2)*|01000>/8 + sqrt(2)*|01001>/8 + sqrt(2)*|01010>/8 + sqrt(2)*|01011>/8 + sqrt(2)*|01100>/8 + sqrt(2)*|01101>/8 + sqrt(2)*|01110>/8 + sqrt(2)*|01111>/8 + sqrt(2)*|10000>/8 + sqrt(2)*|10001>/8 + sqrt(2)*|10010>/8 + sqrt(2)*|10011>/8 + sqrt(2)*|10100>/8 + sqrt(2)*|10101>/8 + sqrt(2)*|10110>/8 + sqrt(2)*|10111>/8 + sqrt(2)*|11000>/8 + sqrt(2)*|11001>/8 + sqrt(2)*|11010>/8 + sqrt(2)*|11011>/8 + sqrt(2)*|11100>/8 + sqrt(2)*|11101>/8 + sqrt(2)*|11110>/8 + sqrt(2)*|11111>/8
|y>: |00000>
|xy>: (sqrt(2)*|00000>/8 + sqrt(2)*|00001>/8 + sqrt(2)*|00010>/8 + sqrt(2)*|00011>/8 + sqrt(2)*|00100>/8 + sqrt(2)*|00101>/8 + sqrt(2)*|00110>/8 + sqrt(2)*|00111>/8 + sqrt(2)*|01000>/8 + sqrt(2)*|01001>/8 + sqrt(2)*|01010>/8 + sqrt(2)*|01011>/8 + sqrt(2)*|01100>/8 + sqrt(2)*|01101>/8 + sqrt(2)*|01110>/8 + sqr

In [8]:
def test_3():
    N = 3*7
    factors = shor(N, 13)
    print("factors = {factors}".format(factors=factors))
    truth = [3, 7]
    assert truth == sorted(factors)


test_3()

Shore N=21, m=6, n=6, a=13 iter 0
|x>: |000000>/8 + |000001>/8 + |000010>/8 + |000011>/8 + |000100>/8 + |000101>/8 + |000110>/8 + |000111>/8 + |001000>/8 + |001001>/8 + |001010>/8 + |001011>/8 + |001100>/8 + |001101>/8 + |001110>/8 + |001111>/8 + |010000>/8 + |010001>/8 + |010010>/8 + |010011>/8 + |010100>/8 + |010101>/8 + |010110>/8 + |010111>/8 + |011000>/8 + |011001>/8 + |011010>/8 + |011011>/8 + |011100>/8 + |011101>/8 + |011110>/8 + |011111>/8 + |100000>/8 + |100001>/8 + |100010>/8 + |100011>/8 + |100100>/8 + |100101>/8 + |100110>/8 + |100111>/8 + |101000>/8 + |101001>/8 + |101010>/8 + |101011>/8 + |101100>/8 + |101101>/8 + |101110>/8 + |101111>/8 + |110000>/8 + |110001>/8 + |110010>/8 + |110011>/8 + |110100>/8 + |110101>/8 + |110110>/8 + |110111>/8 + |111000>/8 + |111001>/8 + |111010>/8 + |111011>/8 + |111100>/8 + |111101>/8 + |111110>/8 + |111111>/8
|y>: |000000>
|xy>: (|000000>/8 + |000001>/8 + |000010>/8 + |000011>/8 + |000100>/8 + |000101>/8 + |000110>/8 + |000111>/8 + |00100