# Shor's Algorithm with Qiskit Aer

This notebook demonstrates Shor's Algorithm for integer factorization using Qiskit and the Aer Simulator. 

**Goal**: Factor a composite integer $N$ (e.g., $N=15$) into its prime factors.  
**Note**: To change $N$ from 15, modify the `N` variable in the `__main__` block.

**Method**: 
1.  **Classical Part**: Reduce the factoring problem to the problem of finding the period of a function $f(x) = a^x \pmod N$.
2.  **Quantum Part**: Use Quantum Phase Estimation (QPE) to find this period $r$.
3.  **Classical Post-processing**: If $r$ is even and $a^{r/2} \not\equiv -1 \pmod N$, compute $\text{gcd}(a^{r/2} \pm 1, N)$ to obtain non-trivial factors of $N$. Otherwise, retry with a different $a$.
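
As a quick sanity check of the reduction, the period can be brute-forced classically for a tiny case. This is only a sketch: `classical_period` is a hypothetical helper, not part of the notebook, and it takes time exponential in the bit length of $N$, which is exactly the step the quantum circuit replaces.

```python
import math

def classical_period(a: int, N: int) -> int:
    """Brute-force the smallest r > 0 with a^r = 1 (mod N). Hypothetical helper."""
    if math.gcd(a, N) != 1:
        raise ValueError("a and N must be coprime")
    r, value = 1, a % N
    while value != 1:
        value = (value * a) % N
        r += 1
    return r

N, a = 15, 7
r = classical_period(a, N)          # powers of 7 mod 15: 7, 4, 13, 1
half = pow(a, r // 2, N)            # a^(r/2) mod N
factors = (math.gcd(half - 1, N), math.gcd(half + 1, N))
print(r, factors)
```

For $N=15$, $a=7$ the period is even and $a^{r/2} \not\equiv -1 \pmod N$, so both gcds are non-trivial factors.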

**Note about N selection** (I have not yet verified all of these; 143 is where I stopped):

The circuit uses $3n$ qubits for an $n$-bit $N$ ($2n$ counting qubits plus $n$ target qubits). The memory figures correspond to a statevector of $2^{3n}$ amplitudes at 16 bytes each.

| $N$ | Bits ($n$) | Qubits ($3n$) | Memory | Speed |
|---|---|---|---|---|
| 15 | 4 | $3 \times 4 = 12$ | Negligible (KB) | Instant |
| 35 | 6 | $3 \times 6 = 18$ | ~4 MB | Very fast (seconds) |
| 143 | 8 | $3 \times 8 = 24$ | ~256 MB | Slower (minutes) |
| 511 | 9 | $3 \times 9 = 27$ | ~2 GB | Heavy (ugh) |
| 1023 | 10 | $3 \times 10 = 30$ | ~16 GB | Heavier? (Good luck) |
| 2047 | 11 | $3 \times 11 = 33$ | ~128 GB | Lol (I wish I had this computer) |
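
The qubit counts and memory figures in the note above can be reproduced with a few lines of Python. This is a rough sketch: `estimate_resources` is a hypothetical helper, and it assumes a full statevector with 16 bytes per amplitude (double-precision complex), ignoring the memory needed for the gate matrices themselves.

```python
def estimate_resources(N: int, bytes_per_amplitude: int = 16) -> tuple[int, int]:
    """Return (qubits, statevector_bytes) for an n-bit N, assuming 3n qubits."""
    n = N.bit_length()              # size of the target register
    qubits = 3 * n                  # 2n counting qubits + n target qubits
    return qubits, bytes_per_amplitude * 2 ** qubits

for N in (15, 35, 143, 511, 1023, 2047):
    qubits, size = estimate_resources(N)
    print(f"N={N}: {qubits} qubits, {size / 2**20:,.2f} MiB")
```

Each extra bit in $N$ adds three qubits, so the statevector grows by a factor of 8 per bit.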

In [1]:
from qiskit import QuantumCircuit, transpile
from qiskit.circuit.library import QFTGate, UnitaryGate
from qiskit_aer import AerSimulator
from fractions import Fraction
import numpy as np
import random
import math
import sympy
import matplotlib.pyplot as plt

## 1. The Quantum Circuit

We define `ShorCircuit`, a subclass of `QuantumCircuit`, that implements the Quantum Phase Estimation routine.

### Key Components:
1.  **Superposition**: Initialize counting qubits to $|+\rangle$ state.
2.  **Modular Exponentiation**: Apply controlled unitary gates $U^{2^j}$ where $U|y\rangle = |a y \pmod N\rangle$. We construct these matrices manually to ensure correctness for small $N$.
3.  **Inverse QFT**: Transform the phase information from the Fourier basis back to the computational basis, so that measuring the counting register gives an estimate of the phase $s/r$, from which the period $r$ is recovered.
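
To make the modular-multiplication operator concrete, here is a small NumPy-only sketch for $a=7$, $N=15$. It is not the notebook's code: `multiplication_matrix` is a hypothetical helper that builds only the uncontrolled $U$, using the same convention of leaving states $y \ge N$ untouched.

```python
import numpy as np

def multiplication_matrix(a: int, N: int) -> np.ndarray:
    """Permutation matrix for |y> -> |a*y mod N>; identity on y >= N. Hypothetical helper."""
    dim = 2 ** N.bit_length()
    U = np.zeros((dim, dim))
    for y in range(dim):
        target = (a * y) % N if y < N else y
        U[target, y] = 1
    return U

U = multiplication_matrix(7, 15)
# A permutation matrix is unitary; this holds only when gcd(a, N) == 1.
is_unitary = np.allclose(U.T @ U, np.eye(U.shape[0]))

# Apply U repeatedly to |1> and record which basis state is occupied.
state = np.zeros(U.shape[0])
state[1] = 1
orbit = []
for _ in range(4):
    state = U @ state
    orbit.append(int(np.argmax(state)))
print(is_unitary, orbit)
```

The state returns to $|1\rangle$ after four applications, which is the period that phase estimation is meant to reveal.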

In [2]:
class ShorCircuit(QuantumCircuit):
    def __init__(self, a, N):
        self.n_target = N.bit_length()
        # We need enough counting qubits to estimate the phase with sufficient precision.
        # 2 * n_target is a standard choice.
        self.n_count = 2 * self.n_target 
        total_qubits = self.n_count + self.n_target 
        
        super().__init__(total_qubits, self.n_count)
        
        self.a = a
        self.N = N
        
        self._create_circuit()

    def _get_controlled_unitary_matrix(self, power_of_a):
        """
        Creates the matrix for the operation U^x |y> = |(a^x * y) mod N>.
        We build this manually as a UnitaryGate.
        """
        dim_target = 2 ** self.n_target
        U_matrix = np.zeros((dim_target, dim_target), dtype=complex)
        
        # Calculate a^(power_of_a) mod N efficiently
        effective_multiplier = pow(self.a, power_of_a, self.N)
        
        for y in range(dim_target):
            if y < self.N:
                target_y = (effective_multiplier * y) % self.N
            else:
                # Identity for states >= N (which shouldn't be populated ideally)
                target_y = y 
            U_matrix[target_y, y] = 1
            
        # Create the Controlled-U matrix: block diagonal [I, 0; 0, U]
        CU_matrix = np.block([
            [np.eye(dim_target), np.zeros((dim_target, dim_target))],
            [np.zeros((dim_target, dim_target)), U_matrix]
        ])
            
        return UnitaryGate(CU_matrix, label=f"C-{self.a}^{power_of_a}")

    def _create_circuit(self):
        """
        Creates the quantum circuit for Shor's algorithm.
        """
        # 1. Initialize counting qubits to superposition |+>
        self.h(range(self.n_count))
        
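        # 2. Initialize target register to |1>, which is an equal superposition of the
        #    eigenstates of U. Qiskit is little-endian, so the least significant target
        #    qubit is the first one after the counting register.
        self.x(self.n_count)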
        
        # 3. Apply Controlled-U operations (Phase Kickback)
        for i in range(self.n_count):
            power_of_a = 2**i
            CU_gate = self._get_controlled_unitary_matrix(power_of_a)
            
            target_qubits = list(range(self.n_count, self.num_qubits))
            control_qubit = i
            
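            # The last qubit in the list is the most significant one in the gate's matrix,
            # which is where the block form [I, 0; 0, U] expects the control.
            self.append(CU_gate, target_qubits + [control_qubit])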

        # 4. Apply Inverse QFT to the counting qubits
        qft_gate = QFTGate(self.n_count).inverse()
        self.append(qft_gate, range(self.n_count))
        
        # 5. Measure the counting qubits
        self.measure(range(self.n_count), range(self.n_count))

    def run_simulation(self, simulator):
        """
        Transpiles and runs the circuit on the provided simulator.
        """
        transpiled_circuit = transpile(self, simulator)
        result = simulator.run(transpiled_circuit, shots=1, memory=True).result()
        return result

## 2. The Algorithm Driver

This class manages the classical logic:
1.  Check if $N$ is even, prime, or a perfect power (classical shortcuts).
2.  Pick a random $a$ with $1 < a < N$.
3.  Check $\text{gcd}(a, N)$. If $>1$, we found a factor by luck!
4.  If not, run the quantum circuit to find period $r$.
5.  Use $r$ to calculate factors.
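
Steps 4 and 5 can be sketched without a simulator by starting from a measured value. The numbers below are taken from the first attempt in the run output of this notebook ($a=111$, $N=143$, measured $54613$ on 16 counting qubits). The helper `period_candidates` is hypothetical and goes one step beyond what the driver does: when the measured phase $s/r$ appears in reduced form, the denominator is only a divisor of $r$, so the sketch also tries small multiples instead of discarding the attempt.

```python
from fractions import Fraction
import math

def period_candidates(measured: int, n_count: int, N: int, max_multiple: int = 4):
    """Yield candidate periods from one QPE measurement. Hypothetical helper."""
    phase = Fraction(measured, 2 ** n_count)
    base = phase.limit_denominator(N).denominator   # continued-fraction step
    for k in range(1, max_multiple + 1):
        if base * k < N:
            yield base * k

a, N = 111, 143
measured, n_count = 54613, 16

# Keep the first candidate that really satisfies a^r = 1 (mod N).
r = next((c for c in period_candidates(measured, n_count, N) if pow(a, c, N) == 1), None)

factors = None
if r is not None and r % 2 == 0:
    half = pow(a, r // 2, N)
    found = {math.gcd(half - 1, N), math.gcd(half + 1, N)} - {1, N}
    factors = tuple(sorted(found))
print(r, factors)
```

Here the continued-fraction step gives $5/6$, the candidate $6$ fails the check, and its multiple $12$ passes, so the same measurement that the driver rejected is enough to factor $143$.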

In [3]:
class ShorAlgorithm:
    def __init__(self, N, max_attempts=-1, random_coprime_only=False, simulator=None):
        self.N = N
        self.simulator = simulator
        self.max_attempts = max_attempts
        self.random_coprime_only = random_coprime_only
        self.chosen_a = None
        self.r = None
        self.qpe_circuit = None

    def execute(self):
        # 1. Sanity Checks
        is_N_invalid = self._is_N_invalid()
        if is_N_invalid: 
            print(f"[INFO] N={self.N} is trivially factorable: {is_N_invalid}")
            return is_N_invalid
        
        # 2. Generate valid 'a' candidates
        a_values = [a for a in range(2, self.N) if not self.random_coprime_only or (math.gcd(a, self.N) == 1)]
        print(f'[INFO] {len(a_values)} possible values of a: {a_values}')

        if not a_values: return None
        
        limit = len(a_values) if self.max_attempts <= -1 else min(self.max_attempts, len(a_values))
        attempts_count = 0

        # 3. Main Loop
        while attempts_count < limit:
            print(f'\n===== Attempt {attempts_count + 1} =====')
            attempts_count += 1
            
            self.chosen_a = random.choice(a_values)
            self.r = None
            print(f'[START] Chosen base a: {self.chosen_a}')
            
            # A. Classical GCD Shortcut (The "Lucky Guess")
            if not self.random_coprime_only:
                gcd = math.gcd(self.chosen_a, self.N)
                if gcd != 1:
                    factor2 = self.N // gcd
                    print(f'=> Lucky Guess! {self.chosen_a} shares a factor with {self.N}')
                    print(f'[SUCCESS] Found factors: {gcd} and {factor2}')
                    return gcd, factor2

            # B. Quantum Period Finding
            print(f'>>> {self.chosen_a} and {self.N} are coprime. Running Quantum Circuit...')
            success = self._quantum_period_finding()
            
            if not success:
                if self.chosen_a in a_values: a_values.remove(self.chosen_a)
                continue

            # C. Classical Post-Processing (Period -> Factors)
            factors = self._classical_postprocess()
            if factors: 
                return factors
            
            if self.chosen_a in a_values: a_values.remove(self.chosen_a)
            
        print(f'[FAIL] No factors found after {limit} attempts.')
        return None

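    def _is_N_invalid(self):
        """Returns a factor pair if N can be handled classically, otherwise False."""
        if self.N <= 3: return 1, self.N
        if self.N % 2 == 0: return 2, self.N // 2
        if sympy.isprime(self.N): return 1, self.N
        # Perfect power check: N = p^k for some k >= 2
        for k in range(int(math.log2(self.N)), 1, -1):
            p = round(self.N ** (1 / k))
            # Return a factor pair (p, N // p), consistent with the other branches
            if p ** k == self.N: return p, self.N // p
        return False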
    
    def _quantum_period_finding(self):
        self.qpe_circuit = ShorCircuit(self.chosen_a, self.N)
        
        try:
            result = self.qpe_circuit.run_simulation(self.simulator)
        except Exception as e:
            print(f"[ERR] Simulation failed: {e}")
            import traceback
            traceback.print_exc()
            return False

        state_bin = result.get_memory()[0]
        state_dec = int(state_bin, 2)
        
        # Convert measured integer to phase: phase = measured / 2^n_count
        phase = state_dec / (2 ** self.qpe_circuit.n_count)
        
        # Use continued fractions to find the rational approximation s/r closest to the phase;
        # the denominator r is the period candidate
        self.r = Fraction(phase).limit_denominator(self.N).denominator
        
        print(f'   -> Measured: |{state_dec}⟩ (binary {state_bin})')
        print(f'   -> Phase: {phase:.4f}')
        print(f'   -> Estimated Period r: {self.r}')

        if self.r > self.N or self.r == 1:
            print(f'[WARN] Invalid period r={self.r}. Retrying...')
            return False
            
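        # Verify that r is a period: a^r = 1 (mod N). This check fails when s/r was
        # measured in reduced form (r is then only a divisor of the true period)
        # or when the measurement landed away from a peak.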
        if pow(self.chosen_a, self.r, self.N) != 1:
            print(f'[WARN] a^r != 1 (mod N) [a={self.chosen_a}, r={self.r}]. Phase estimation failed. Retrying...')
            return False

        return True

    def _classical_postprocess(self):
        if self.r % 2 != 0:
            print(f'[INFO] Period r={self.r} is odd. Cannot split to find factors.')
            return None

        guess_1 = pow(self.chosen_a, self.r // 2, self.N) - 1
        guess_2 = pow(self.chosen_a, self.r // 2, self.N) + 1
        
        factor1 = math.gcd(guess_1, self.N)
        factor2 = math.gcd(guess_2, self.N)
        
        if factor1 not in [1, self.N]: 
            print(f'[SUCCESS] Found factors: {factor1} and {self.N // factor1}')
            return factor1, self.N // factor1
        if factor2 not in [1, self.N]: 
            print(f'[SUCCESS] Found factors: {factor2} and {self.N // factor2}')
            return factor2, self.N // factor2
        return None

## 3. Running the Algorithm

In [None]:
if __name__ == "__main__":
    N = 143
    
    # Initialize the simulator
    simulator = AerSimulator()
    
    shor = ShorAlgorithm(N, simulator=simulator)
    factors = shor.execute()
    
    print(f"\nFinal Result: {factors}")
    
    # Visualize the circuit from the last attempt
    if shor.qpe_circuit:
        print("\nCircuit Diagram (Last Attempt):")
        display(shor.qpe_circuit.draw(output='mpl', fold=-1, style="iqp"))

[INFO] 141 possible values of a: [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142]

===== Attempt 1 =====
[START] Chosen base a: 111
>>> 111 and 143 are coprime. Running Quantum Circuit...
   -> Measured: |54613⟩ (binary 1101010101010101)
   -> Phase: 0.8333
   -> Estimated Period r: 6
[WARN] a^r != 1 (mod N) [a=111, r=6]. Phase estimation failed. Retrying...

===== Attempt 2 =====
[START] Chosen base a: 81
>>> 81 and 143 are coprime. Running Qua