<a href="https://colab.research.google.com/github/ge96lip/Quantum-Computing/blob/main/QC_Shor's_Factorization_Algorithm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prerequisites

In [238]:
!pip install qiskit
!pip install pylatexenc
!pip install qiskit_aer



In [239]:
import math

from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit.circuit.library import XGate, CSwapGate, QFTGate
from qiskit import transpile
from qiskit_aer import AerSimulator

# General Functions

**Here are some values of N to try:**

15, 21, 35, 39, 51, 55, 69, 77, 85, 87, 91, 93, 95, 111, 115, 117,
119, 123, 133, 155, 187, 203, 221, 247, 259, 287, 341, 451

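**Larger numbers require more bits of precision.** Note: the quantum implementation below only handles N = 15 and N = 21; every other N is routed to `ShorQPU_WithModulo`, which is not working yet.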

N = 15    precision_bits >= 4

N = 21    precision_bits >= 5

N = 35    precision_bits >= 6

N = 123   precision_bits >= 7

N = 341   precision_bits >= 8  time: about 6 seconds

N = 451   precision_bits >= 9  time: about 23 seconds

In [240]:
def shor_sample(N=15, precision_bits=4, coprime=2):
    """Factor N with the currently defined Shor() and print the outcome.

    The defaults reproduce the original demo (N = 15, 4 precision bits,
    coprime 2); other values from the table above can be passed explicitly.

    Args:
        N: the number to factor.
        precision_bits: width of the precision register used for period finding.
        coprime: base whose repeat period modulo N is searched for.
    """
    result = Shor(N, precision_bits, coprime)

    if result is not None:
        print('Success! '+str(N)+'='+str(result[0])+'*'+str(result[1])+'\n')
    else:
        print('Failure: No non-trivial factors were found.\n')

In [241]:
def ShorLogic(N, repeat_period_candidates, coprime):
    """Turn candidate repeat periods into candidate factor pairs (classical part of Shor).

    For every even candidate period r, with a = coprime**(r/2), the pair
    [gcd(N, a - 1), gcd(N, a + 1)] is produced. Odd periods are skipped because
    coprime**(r/2) is then not an integer and yields no usable factors.

    The power is computed as an exact integer modulo N. This is valid because
    gcd(N, x) == gcd(N, x mod N). It also avoids the float precision loss of
    coprime ** (r / 2.0) for large periods.

    Args:
        N: the number being factored.
        repeat_period_candidates: iterable of int candidate periods.
        coprime: the base whose period was estimated.

    Returns:
        list of [factor1, factor2] int pairs, one per even candidate.
    """
    print('Repeat period candidates: '+str(repeat_period_candidates)+'\n')
    factor_candidates = []
    for repeat_period in repeat_period_candidates:
        if repeat_period % 2:
            continue  # Shor's post-processing needs an even period
        ar2 = pow(coprime, repeat_period // 2, N)
        factor1 = math.gcd(N, ar2 - 1)
        factor2 = math.gcd(N, ar2 + 1)
        factor_candidates.append([factor1, factor2])
    return factor_candidates

In [242]:
def gcd(a, b):
    """Return the greatest common divisor of a and b via Euclid's algorithm."""
    while b:
        a, b = b, a % b
    return a

In [243]:
def check_result(N, factor_candidates):
    """Return the first candidate pair that is a non-trivial factorisation of N.

    A pair qualifies when its product equals N and neither member is 1.
    Returns that pair object itself, or None if no candidate qualifies.
    """
    return next(
        (pair for pair in factor_candidates
         if pair[0] * pair[1] == N and pair[0] != 1 and pair[1] != 1),
        None,
    )

# Shor's Algorithm with a Classical Period Finder

In [245]:
def Shor(N, precision_bits, coprime):
    """Factor N, using ShorNoQPU as a classical stand-in for period finding.

    Returns a non-trivial factor pair [p, q] with p * q == N, or None.
    """
    periods = ShorNoQPU(N, precision_bits, coprime)  # stand-in for the quantum part
    pairs = ShorLogic(N, periods, coprime)           # classical post-processing
    return check_result(N, pairs)

In [246]:
def ShorNoQPU(N, precision_bits, coprime):
    """Classical replacement for the quantum period-finding step of Shor.

    Steps through coprime**k mod N for k = 1 .. 2**precision_bits and returns
    every exponent k at which the running value is back to 1. When coprime is
    coprime to N these are the multiples of its repeat period in that range.
    """
    periods = []
    value = 1
    for exponent in range(1, 2 ** precision_bits + 1):
        value = value * coprime % N
        if value == 1:
            periods.append(exponent)
    return periods

In [247]:
# Run the N = 15 demo; at this point Shor() is the version above that finds the period classically via ShorNoQPU.
shor_sample()

Repeat period candidates: [4, 8, 12, 16]

Success! 15=3*5



# Shor's Factorization Algorithm

In [248]:
def Shor(N, precision_bits, coprime):
    """Factor N with the quantum period finder ShorQPU.

    This redefines the classical Shor() from the previous section; only the
    period-finding step changes. Returns a non-trivial pair [p, q] or None.
    """
    periods = ShorQPU(N, precision_bits, coprime)  # quantum part
    pairs = ShorLogic(N, periods, coprime)         # classical post-processing
    return check_result(N, pairs)

In [249]:
def ShorQPU(N, precision_bits, coprime):
    """Quantum part of Shor's algorithm: estimate candidate repeat periods.

    Only coprime = 2 is supported, because the circuits implement
    multiplication by 2 as a cyclic bit shift (RollLeft). The caller's
    ``coprime`` is also used by the classical post-processing in ShorLogic.
    Silently substituting 2 here would therefore pair a base-2 period with a
    different base and produce meaningless factors, so other values are
    rejected.

    For some numbers (like 15 and 21) the "mod" in a^x mod N is not needed,
    because a^x wraps neatly around. That path is simpler and easier to follow.

    Raises:
        ValueError: if coprime is not 2.
    """
    if coprime != 2:
        raise ValueError('ShorQPU only supports coprime=2, got ' + str(coprime))

    if N == 15 or N == 21:
        return ShorQPU_WithoutModulo(N, precision_bits, coprime)
    return ShorQPU_WithModulo(N, precision_bits, coprime)

In [250]:
def RollLeft(qc, work, num_shifts, control):
    """Append a controlled cyclic left rotation of ``work`` by ``num_shifts`` qubits.

    Treat work[0] as the least-significant bit. When ``control`` is |1>, the
    content of qubit i moves to qubit (i + num_shifts) % work.size. For an
    integer value v < 2**size - 1, this maps v to v * 2**num_shifts mod
    (2**size - 1).

    The rotation splits into gcd(size, shift) independent cycles. A cycle of
    length L is realised with L - 1 controlled swaps, which is the minimum.
    A single pass of swaps (j, j + num_shifts) is only a rotation for some
    shifts: it fails e.g. for a shift of 4 on 6 qubits. It also does nothing
    once num_shifts >= size, whereas the rotation must use num_shifts % size.

    Args:
        qc: circuit to append to.
        work: register to rotate (uses ``work.size`` and indexing).
        num_shifts: number of positions to rotate by.
        control: control qubit for every swap.

    Returns:
        qc, for chaining.
    """
    size = work.size
    if size == 0:
        return qc
    shift = num_shifts % size
    if shift == 0:
        return qc  # rotating by a multiple of the width is the identity

    num_cycles = math.gcd(size, shift)
    cycle_len = size // num_cycles
    for start in range(num_cycles):
        cycle = [(start + step * shift) % size for step in range(cycle_len)]
        # Swapping from the end of the cycle backwards moves every qubit's
        # content one step forward along the cycle (last one wraps to first).
        for pos in range(cycle_len - 2, -1, -1):
            qc.cswap(control, work[cycle[pos]], work[cycle[pos + 1]])
    return qc

In [251]:
# This is the short/simple version of ShorQPU() where we can perform a^x and
# don't need to be concerned with performing a quantum int modulus.
def ShorQPU_WithoutModulo(N, precision_bits, coprime):
    """Quantum period finding for N = 15 or 21, where no explicit modulo is needed.

    The work register starts in |1>, and precision qubit i controls a cyclic
    RollLeft by 2**i. A cyclic rotation of an n-qubit register multiplies by 2
    modulo 2**n - 1. With n = 4 for N = 15 (2**4 - 1 = 15) and n = 6 for N = 21
    (63 = 3 * 21), the work register cycles with the same period as 2**x mod N.

    A QFT on the precision register turns that periodicity into spikes. The
    most frequent measurement outcome is then converted into candidate periods
    by estimate_num_spikes.

    Args:
        N: the number being factored (15 or 21).
        precision_bits: width of the precision register.
        coprime: not used here; the circuit hard-codes multiplication by 2.

    Returns:
        list of candidate repeat periods.
    """
    N_bits = 1
    while (1 << N_bits) < N:
        N_bits += 1
    if N != 15:  # For this implementation, numbers other than 15 need an extra bit
        N_bits += 1

    # Set up the QPU and the working registers
    work = QuantumRegister(N_bits, name="work")
    precision = QuantumRegister(precision_bits, name="precision")
    classical = ClassicalRegister(precision_bits, name="precision_measure")
    qc = QuantumCircuit(work, precision, classical)

    # Work register = |1>, precision register = uniform superposition over all x
    qc.initialize(1, work)
    qc.initialize(0, precision)
    qc.h(precision)

    # Perform 2^x for all values of x in superposition: qubit i contributes 2^(2^i)
    for bit in range(precision_bits):
        qc = RollLeft(qc, work, 1 << bit, precision[bit])

    # Without this QFT the precision register measures as uniform noise and the
    # period cannot be recovered.
    qc.append(QFTGate(precision_bits), list(precision))
    print(qc)

    read_result = read_unsigned(qc, precision, classical)
    print('QPU read result: '+str(read_result)+'\n')
    return estimate_num_spikes(read_result, 1 << precision_bits)

In [253]:
################################################################################
# This function is not working yet ! ###########################################
################################################################################

# This is the complicated version of ShorQPU() where we DO
# need to be concerned with performing a quantum int modulus.
# That's a complicated operation, and it also requires us to
# do the shifts one at a time.
def ShorQPU_WithModulo(N, precision_bits, coprime):
    """Quantum period finding for N where 2**x has to be reduced modulo N.

    WORK IN PROGRESS: the reduction relies on append_modulo_N, which does not
    yet implement a controlled subtraction of N. This function therefore
    cannot produce meaningful periods yet.

    Plan:
      1. Start the work register at |1> and put the precision register into
         uniform superposition.
      2. For each precision qubit i, multiply the work register by 2 (a
         controlled one-bit RollLeft) 2**i times, reducing modulo N after each
         doubling once the running value can reach N.
      3. Apply a QFT to the precision register to expose the period, then read
         it out and turn it into candidates with estimate_num_spikes.

    Args:
        N: the number being factored.
        precision_bits: width of the precision register.
        coprime: not used here; the circuit hard-codes multiplication by 2.

    Returns:
        list of candidate repeat periods.
    """
    N_bits = 1
    while (1 << N_bits) < N:
        N_bits += 1
    # Presumably the extra bit lets a doubled value (< 2N) fit before reduction.
    if N != 15:  # For this implementation, numbers other than 15 need an extra bit
        N_bits += 1

    # Set up the QPU and the working registers
    work = QuantumRegister(N_bits, name="work")
    precision = QuantumRegister(precision_bits, name="precision")
    scratch = QuantumRegister(1, name="scratch")
    classical = ClassicalRegister(precision_bits, name="precision_measure")
    qc = QuantumCircuit(work, precision, scratch, classical)

    # Work register = |1>, scratch = |0>, precision register in superposition
    qc.initialize(1, work)
    qc.initialize(0, precision)
    qc.initialize(0, scratch)
    qc.h(precision)

    print(qc)

    # Classical bound on the largest value the work register can hold so far;
    # once it can reach N, every further doubling needs a modular reduction.
    max_value = 1
    mod_engaged = False
    for bit in range(precision_bits):
        condition = precision[bit]
        for _ in range(1 << bit):
            qc = RollLeft(qc, work, 1, condition)  # multiply by the coprime (2)
            max_value <<= 1
            if max_value >= N:
                mod_engaged = True

            if mod_engaged:
                # Conditional subtract-N / add-back (see append_modulo_N).
                qc = append_modulo_N(qc, N, work, scratch, condition)
                # Clear the wrap flag: the result is odd iff N was subtracted.
                # So flip scratch when the result is even, but only for the
                # branch where this step actually ran (controlled on `condition`).
                qc.x(work[0])
                qc.ccx(condition, work[0], scratch[0])
                qc.x(work[0])
                # Reference pseudocode for the modulo step:
                #   wrap_mask = scratch.bits()
                #   wrap_mask_with_condition = scratch.bits()
                #   wrap_mask_with_condition.orEquals(condition)
                #   num.subtract(N, condition)  # subtract N, causing this to go negative if we HAVEN'T wrapped.
                #   scratch.cnot(N_sign_bit_with_condition)  # Skim off the sign bit
                #   num.add(N, wrap_mask_with_condition)  # If we went negative, undo the subtraction.
                #   num.not(1)
                #   scratch.cnot(num, 1, condition)  # If it's odd, then we wrapped, so clear the wrap bit
                #   num.not(1)

    # Quantum Fourier Transform on the precision register
    qc.append(QFTGate(precision_bits), list(precision))

    read_result = read_unsigned(qc, precision, classical)
    print('QPU read result: '+str(read_result)+'\n')
    return estimate_num_spikes(read_result, 1 << precision_bits)

In [254]:
def read_unsigned(qc, qreg, creg, shots=1024):
    """Measure ``qreg`` into ``creg``, simulate ``qc`` on Aer, and return the most frequent outcome.

    Note that the measurement is appended to ``qc`` itself, so the caller's
    circuit is modified.

    Args:
        qc: circuit containing ``qreg`` and ``creg``.
        qreg: quantum register to read.
        creg: classical register receiving the measurement (same width as qreg).
        shots: number of simulator shots (default 1024).

    Returns:
        The most frequently observed value as a non-negative int, masked to
        ``qreg.size`` bits.
    """
    qc.measure(qreg, creg)

    simulator = AerSimulator()
    job = simulator.run(transpile(qc, simulator), shots=shots)
    counts = job.result().get_counts(qc)

    # Count keys are bit strings; pick the most frequent one.
    most_likely = max(counts, key=counts.get)
    value = int(most_likely, 2)

    # Keep only as many bits as the measured register holds.
    return value & ((1 << qreg.size) - 1)

In [255]:
def estimate_num_spikes(spike, spike_range):
    """Estimate candidate repeat periods from a measured spike position.

    Ideally, after the QFT the precision register peaks near multiples of
    spike_range / r. This scans denominators d = 1 .. spike - 1 for rational
    approximations numerator/d of spike/spike_range. A denominator becomes a
    candidate when its approximation error is a strict local minimum and is no
    worse than the best error seen so far.

    Args:
        spike: measured precision-register value (expected 0 <= spike < spike_range).
        spike_range: 2**precision_bits.

    Returns:
        list of int candidate periods, in increasing order.
    """
    # Treat spike and spike_range - spike as equivalent: fold into the upper half.
    if spike < spike_range / 2:
        spike = spike_range - spike
    actual = spike / spike_range
    best_error = 1.0
    # Sliding window over the last three errors; e1 is the one being tested.
    e0 = e1 = e2 = 0
    candidates = []
    for denominator in range(1, spike):
        numerator = round(denominator * actual)
        e0, e1, e2 = e1, e2, abs(numerator / denominator - actual)
        # e1 belongs to denominator - 1. Require a strict local minimum that
        # also beats (or ties) the best error so far; all conditions must hold.
        if e1 <= best_error and e1 < e0 and e1 < e2:
            candidates.append(denominator - 1)
            best_error = e1
    return candidates

In [256]:
################################################################################
# This function is not working yet ! ###########################################
################################################################################

def append_modulo_N(qc, N, work, scratch, condition=None):
    # Intended: a (controlled) modular reduction of the 'work' register with
    # respect to N, i.e. work <- work - N when work >= N. The one-qubit
    # 'scratch' register serves as the sign/wrap flag, following the reference
    # pseudocode in ShorQPU_WithModulo. Appends gates to qc and returns it.
    #
    # NOTE(review): the current gate sequence does NOT implement that:
    #   * Step 1 only toggles scratch[0] controlled on selected work bits (a
    #     parity). It never modifies 'work', so nothing is subtracted.
    #   * Step 3 passes work[i] as both a control and the target of ccx.
    #     Qiskit is expected to reject duplicate qubit arguments with a
    #     CircuitError - confirm.
    #   A real implementation needs a controlled constant subtractor/adder
    #   (e.g. a QFT-based or ripple-carry adder).

    N_bits = work.size
    N_binary = format(N, 'b').zfill(N_bits)  # N as a bit string (MSB first), left-padded to the width of work

    # Step 1: intended to conditionally subtract N (controlled by `condition`).
    # As written: for every bit position i where N has a '1', toggle scratch[0]
    # controlled on work[i] (and on `condition` when one is given).
    for i, bit in enumerate(reversed(N_binary)):  # reversed -> i is the bit weight, matching work[i]
        if bit == '1':  # Only operate on bits where N has a '1'
            if condition:
                # Python truthiness test on the argument: False for the default None;
                # a Qubit argument is presumably truthy - confirm.
                qc.ccx(condition, work[i], scratch[0])  # toggle scratch if condition and work[i]
            else:
                qc.cx(work[i], scratch[0])

    # Step 2: intended to record the sign bit of (work - N) in scratch.
    # As written: XOR the most significant qubit of `work` into scratch[0]
    # (not controlled by `condition`).
    qc.cx(work[-1], scratch[0])

    # Step 3: intended to add N back when the sign bit (scratch) indicates the
    # subtraction went negative.
    # NOTE(review): ccx(scratch[0], work[i], work[i]) uses work[i] as control and
    # target simultaneously, which is not a valid gate application.
    for i, bit in enumerate(reversed(N_binary)):
        if bit == '1':  # Only bits where N has a '1'
            qc.ccx(scratch[0], work[i], work[i])  # invalid: duplicate qubit (see note above)

    # Step 4: intended to clear the wrap flag.
    # As written: toggle scratch[0] whenever work[0] is 1 (odd value), not controlled by `condition`.
    qc.cx(work[0], scratch[0])

    return qc

In [257]:
# Run the same N = 15 demo again; Shor() has now been redefined to use the quantum period finder ShorQPU.
shor_sample()

                     ┌────────────────┐                    
             work_0: ┤0               ├────────────X─────X─
                     │                │            │     │ 
             work_1: ┤1               ├─────────X──X──X──┼─
                     │  Initialize(1) │         │  │  │  │ 
             work_2: ┤2               ├──────X──X──┼──┼──X─
                     │                │      │  │  │  │  │ 
             work_3: ┤3               ├──────X──┼──┼──X──┼─
                     ├────────────────┤┌───┐ │  │  │  │  │ 
        precision_0: ┤0               ├┤ H ├─■──■──■──┼──┼─
                     │                │├───┤          │  │ 
        precision_1: ┤1               ├┤ H ├──────────■──■─
                     │  Initialize(0) │├───┤               
        precision_2: ┤2               ├┤ H ├───────────────
                     │                │├───┤               
        precision_3: ┤3               ├┤ H ├───────────────
                     └────────────────┘└