# Modular Exponentiation
## General Notes
    - The goal is to compute X**Y mod N
    - Central subroutine of **Shor's period finding algorithm** used for integer factorization

## Goal
Implement a quantum circuit for modular exponentiation from scratch.
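
As a classical point of reference, X**Y mod N can be computed by repeated squaring, which reduces the exponentiation to a sequence of modular multiplications. The sketch below is plain Python, not a quantum circuit, and the name `mod_exp_reference` is an illustrative assumption rather than part of the assignment.

```python
def mod_exp_reference(x, y, n):
    """Classical reference: compute x**y mod n by repeated squaring."""
    if y < 0 or n <= 0:
        raise ValueError("y must be non-negative and n must be positive")
    result = 1 % n
    base = x % n
    while y > 0:
        # Multiply in the current power of x whenever the low bit of y is set
        if y & 1:
            result = (result * base) % n
        base = (base * base) % n
        y >>= 1
    return result


print(mod_exp_reference(7, 4, 15))  # 1
```

Values produced by a circuit can be compared against this function (or against Python's built-in `pow(x, y, n)`).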

## Limitations
    - Unless stated otherwise, you are only allowed to use the gates X, CNOT, CCNOT and Multicontrolled-NOT. See Appendix A for an example of use of the Multicontrolled-NOT gate.

    - You may implement additional auxiliary functions as long as the functions below are implemented.

    - For each implemented function, please give evidence that the implementation is correct by:
        - initializing each input register with a number of up to 4 bits,
        - initializing each auxiliary register to |0>,
        - measuring the output register to verify that the value is as expected.

    - Reuse qubits from auxiliary registers as much as possible.
        - It is crucial that auxiliary registers are equal to |0> both at the beginning and at the end of computation of each function

In [38]:
from qiskit_ibm_runtime import QiskitRuntimeService
from qiskit_ibm_runtime import SamplerV2 as Sampler
from qiskit import QuantumCircuit
from qiskit.circuit.library import MCXGate
from qiskit.providers.basic_provider import BasicSimulator
from qiskit_aer import AerSimulator
from qiskit import transpile
import numpy as np

# 1.1 Initialization
The function `set_bits(circuit,A,X)` initializes the bits of register `A` with the binary string `X`.

For each i in `range(len(X))`, if `X[i]=1`, then the function applies the **X**-gate to `A[i]`.

Otherwise, it does nothing.

Assume `len(A)=len(X)`.

If `A = [2,4,3,7,5]` and `X = 01011`, the **X**-gate is applied to qubits 4, 7, and 5.
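
The index-wise mapping in this example can be checked without building a circuit. The helper below is a minimal sketch in plain Python; `flipped_qubits` is a hypothetical name used only for illustration.

```python
def flipped_qubits(A, X):
    """Return the qubits of register A that would receive an X gate
    for the bit string X (character X[i] is paired with qubit A[i])."""
    if len(A) != len(X):
        raise ValueError("len(A) must equal len(X)")
    return [qubit for qubit, bit in zip(A, X) if bit == "1"]


print(flipped_qubits([2, 4, 3, 7, 5], "01011"))  # [4, 7, 5]
```

Under this convention `A[0]` receives the leftmost character of `X`. The adder checks in later sections set `A[0]` as the least significant bit, so it may be necessary to pass the register in reversed order when combining the two.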

In [16]:
def set_bits(circuit,A,X):

    # Width is determined by A
    w = len(A)
    
    # Check if X is read-in as a string or an integer
    # If X is written as int, X = 11
    if isinstance(X, int):

        # Check not negative
        if X < 0:
            raise ValueError("X must be a non-negative integer.")
        
        # Convert to binary
        X = bin(X)
        # Remove leading 0b
        X = X[2:]
        # Pad
        X = X.zfill(w)

    elif isinstance(X, str):
            
            # String handling
            if X.startswith("0b"):

                # X has already been wrapped by bin(), ie X = bin(11)
                # Strip 0b
                X = X[2:]
                # Pad
                X = X.zfill(w)
        
            else:

                # X is just the string "01011" or "1011".
                # Strings are always read as binary digits; the decimal number 1,011
                # is expected as the int X=1011, not as its binary string 1111110011.
                X = X.zfill(w)

    else:

        raise TypeError("X must be an int or str")
    
    if len(X) > w:

        raise ValueError("Binary value doesn't fit in target register")

    # Enforce width
    X = X.zfill(w)

    # Apply X gate
    for i in range(w):
        
        if X[i]=="1":
            circuit.x(A[i])
        

    circuit.barrier()

    return circuit

In [17]:
# Check 1.1 initialization
circuit = QuantumCircuit(8, 0) 
A = [2,4,3,7,5] # qubits
# X = 01011, which is 11 in decimal
# X = bin(11) 
X = "1011"

# print(X)
# print(len(X))
# print(A)
# print(len(A))

set_bits(circuit,A,X)
print(circuit)

           ░ 
q_0: ──────░─
           ░ 
q_1: ──────░─
           ░ 
q_2: ──────░─
           ░ 
q_3: ──────░─
     ┌───┐ ░ 
q_4: ┤ X ├─░─
     ├───┤ ░ 
q_5: ┤ X ├─░─
     └───┘ ░ 
q_6: ──────░─
     ┌───┐ ░ 
q_7: ┤ X ├─░─
     └───┘ ░ 


# 1.2 Copy
The function `copy(circuit,A,B)` copies the binary string `bin(A)` to register B.

Assume that `len(A)=len(B)` and before application of function, B is initialized to |0>

**Hint: use CNOT gates**
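
On computational basis states a CNOT acts like a classical XOR, so the copy can be traced on a plain list of bits. This is a sketch under that assumption; `cnot` and `copy_bits` are illustrative names and the list `bits` stands in for the qubit register.

```python
def cnot(bits, control, target):
    """Classical action of CNOT on basis states: target ^= control."""
    bits[target] ^= bits[control]


def copy_bits(bits, A, B):
    """XOR each bit of register A into the matching bit of register B."""
    for a, b in zip(A, B):
        cnot(bits, a, b)


# Register A (positions 0-3) holds 1, 0, 1, 1; register B (positions 4-7) starts at 0
bits = [1, 0, 1, 1, 0, 0, 0, 0]
copy_bits(bits, [0, 1, 2, 3], [4, 5, 6, 7])
print(bits)  # [1, 0, 1, 1, 1, 0, 1, 1]
```

Applying `copy_bits` a second time XORs the same values in again and returns B to zero, which is why the same gates can be used to uncompute a copy.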

In [18]:
def copy(circuit,A,B):

    # Copy binary string bin(A) to register B using CNOT gates
    for i in range(len(A)):
        circuit.cx(A[i], B[i])

    return circuit

In [19]:
# Check 1.2 copy()
circuit = QuantumCircuit(8, 4)
A = [0, 1, 2, 3]
B = [4, 5, 6, 7]

# Use set_bits() (testing different inputs just in case)
set_bits(circuit, A, 11)
# set_bits(circuit, A, "1011")
# set_bits(circuit, A, "01011")

# Copy A to B
copy(circuit,A,B)

print(A)
print(B)
print(circuit)

[0, 1, 2, 3]
[4, 5, 6, 7]
     ┌───┐ ░                     
q_0: ┤ X ├─░───■─────────────────
     └───┘ ░   │                 
q_1: ──────░───┼────■────────────
     ┌───┐ ░   │    │            
q_2: ┤ X ├─░───┼────┼────■───────
     ├───┤ ░   │    │    │       
q_3: ┤ X ├─░───┼────┼────┼────■──
     └───┘ ░ ┌─┴─┐  │    │    │  
q_4: ──────░─┤ X ├──┼────┼────┼──
           ░ └───┘┌─┴─┐  │    │  
q_5: ──────░──────┤ X ├──┼────┼──
           ░      └───┘┌─┴─┐  │  
q_6: ──────░───────────┤ X ├──┼──
           ░           └───┘┌─┴─┐
q_7: ──────░────────────────┤ X ├
           ░                └───┘
c: 4/════════════════════════════
                                 


# 1.3 Full Adder
The function `full_adder(circuit,a,b,r,c_in,c_out,AUX)` implements a full adder.

Registers:

`a` and `b` store the bits to be added

`c_in` stores the carry-in bit

`c_out` stores the carry-out bit

`r` stores the result of the sum

`AUX` is the auxiliary register
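
One way to realise these registers with only CNOT and CCNOT is to XOR `a`, `b`, and `c_in` into `r` and to XOR the three pairwise ANDs into `c_out`. The sketch below checks that choice on all eight inputs using a classical bit list; the helper names are assumptions for illustration, and `r` and `c_out` are assumed to start at 0.

```python
from itertools import product


def cx(bits, c, t):
    """Classical action of CNOT: flip t when c is 1."""
    bits[t] ^= bits[c]


def ccx(bits, c1, c2, t):
    """Classical action of CCNOT: flip t when both controls are 1."""
    bits[t] ^= bits[c1] & bits[c2]


def full_adder_bits(bits, a, b, r, c_in, c_out):
    # Sum bit: r ^= a ^ b ^ c_in
    cx(bits, a, r)
    cx(bits, b, r)
    cx(bits, c_in, r)
    # Carry bit: c_out ^= majority(a, b, c_in)
    ccx(bits, a, b, c_out)
    ccx(bits, a, c_in, c_out)
    ccx(bits, b, c_in, c_out)


for a, b, c in product((0, 1), repeat=3):
    bits = [a, b, c, 0, 0]          # positions: a, b, c_in, r, c_out
    full_adder_bits(bits, 0, 1, 3, 2, 4)
    assert bits[3] + 2 * bits[4] == a + b + c
print("full adder matches a + b + c_in on all 8 inputs")
```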

In [20]:
def full_adder(circuit,a,b,r,c_in,c_out,AUX):

    # Sum with XOR: r ^ (a ^ b ^ c_in)
    # r <- r ^ a
    circuit.cx(a, r)
    # r <- r ^ b
    circuit.cx(b, r)
    # r <- r ^ c_in
    circuit.cx(c_in, r)

    # Carry with CCNOT: c_out <- c_out ^ (a & b) ^ (a & c_in) ^ (b & c_in)
    circuit.ccx(a, b, c_out)
    circuit.ccx(a, c_in, c_out)
    circuit.ccx(b, c_in, c_out)

    # AUX is not touched here; c_out keeps the carry, so the caller must uncompute it

    return circuit

In [21]:
# 1.3 Full Adder check
circuit = QuantumCircuit(5, 2)

# Assign qubits
a = 0
b = 1
c_in = 2
r = 3
c_out = 4
AUX = []

# Initialize a = 1, b = 1, c_in = 0
circuit.x(a)
circuit.x(b)

# Run full adder
full_adder(circuit, a, b, r, c_in, c_out, AUX)

# Results
circuit.measure(r, 0)
circuit.measure(c_out, 1)

simulator = BasicSimulator()
result = simulator.run(circuit).result()
counts = result.get_counts()

print(counts)
print(circuit)

{'10': 1024}
     ┌───┐                                    
q_0: ┤ X ├──■──────────────■────■─────────────
     ├───┤  │              │    │             
q_1: ┤ X ├──┼────■─────────■────┼───────■─────
     └───┘  │    │         │    │       │     
q_2: ───────┼────┼────■────┼────■───────■─────
          ┌─┴─┐┌─┴─┐┌─┴─┐  │    │  ┌─┐  │     
q_3: ─────┤ X ├┤ X ├┤ X ├──┼────┼──┤M├──┼─────
          └───┘└───┘└───┘┌─┴─┐┌─┴─┐└╥┘┌─┴─┐┌─┐
q_4: ────────────────────┤ X ├┤ X ├─╫─┤ X ├┤M├
                         └───┘└───┘ ║ └───┘└╥┘
c: 2/═══════════════════════════════╩═══════╩═
                                    0       1 


# 1.4 Addition
The function `add(circuit,A,B,R,AUX)` implements a circuit that adds `number(A)` to `number(B)` and stores the result at register `R`.

Assume `len(A)==len(B)==len(R)`

The circuit is obtained by creating a cascade of `full_adder` circuits.

The carry bits are part of the auxiliary register AUX.

Note the carry-in bit of the first adder (from right to left) is set to 0.
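
The cascade, together with a reverse pass that clears the carry bits, can be traced classically before it is built as a circuit. This is a minimal sketch that assumes index 0 is the least significant bit and that `AUX` has `len(A) + 1` bits; `add_bits` and `run_add` are illustrative names.

```python
def cx(bits, c, t):
    bits[t] ^= bits[c]


def ccx(bits, c1, c2, t):
    bits[t] ^= bits[c1] & bits[c2]


def carry(bits, a, b, c_in, c_out):
    """XOR majority(a, b, c_in) into c_out; applying it twice undoes it."""
    ccx(bits, a, b, c_out)
    ccx(bits, a, c_in, c_out)
    ccx(bits, b, c_in, c_out)


def add_bits(bits, A, B, R, AUX):
    n = len(A)
    # Forward pass: sum bits into R, carries into AUX[1..n]
    for i in range(n):
        cx(bits, A[i], R[i])
        cx(bits, B[i], R[i])
        cx(bits, AUX[i], R[i])
        carry(bits, A[i], B[i], AUX[i], AUX[i + 1])
    # Reverse pass: clear the carries from the top down
    for i in reversed(range(n)):
        carry(bits, A[i], B[i], AUX[i], AUX[i + 1])


def run_add(a, b, n=4):
    A = list(range(0, n))
    B = list(range(n, 2 * n))
    R = list(range(2 * n, 3 * n))
    AUX = list(range(3 * n, 4 * n + 1))
    bits = [0] * (4 * n + 1)
    for i in range(n):
        bits[A[i]] = (a >> i) & 1
        bits[B[i]] = (b >> i) & 1
    add_bits(bits, A, B, R, AUX)
    result = sum(bits[R[i]] << i for i in range(n))
    aux_clean = all(bits[q] == 0 for q in AUX)
    return result, aux_clean


print(run_add(3, 5))  # (8, True)
```

Because the final carry is cleared rather than stored, the result in this sketch is the sum modulo 2**n.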

In [22]:
def add(circuit,A,B,R,AUX):

    # R <- R ^ (number(A) + number(B))
    # AUX is where carry bits live

    # Sanity checks
    n = len(A)
    if len(B) != n or len(R) != n:
        raise ValueError("len(A), len(B), and len(R) must be the same length")
    if len(AUX) < n + 1:
        raise ValueError("AUX must have at least len(A) + 1 qubits.")

    # Computing sum
    for i in range(len(A)):

        c_in = AUX[i]
        c_out = AUX[i+1]

        full_adder(circuit, A[i], B[i], R[i], c_in, c_out, AUX)

    # Resetting AUX by doing reverse
    for i in range(n-1, -1, -1):
        
        c_in = AUX[i]
        c_out = AUX[i+1]

        circuit.ccx(B[i], c_in, c_out)
        circuit.ccx(A[i], c_in, c_out)
        circuit.ccx(A[i], B[i], c_out)

    return circuit

In [23]:
# 1.4 Check Addition
circuit = QuantumCircuit(17, 4)

# Number of 4-bit numbers
n = 4

# Registers
A = [0, 1, 2, 3]
B = [4, 5, 6, 7]
R = [8, 9, 10, 11]
AUX = [12, 13, 14, 15, 16]

# Let A = 3 (0011)
circuit.x(A[0])
circuit.x(A[1])

# Let B = 5 (0101)
circuit.x(B[0])
circuit.x(B[2])

# Compute addition
add(circuit, A, B, R, AUX)

for i in range(n):
    circuit.measure(R[i], i)

simulator = BasicSimulator()
result = simulator.run(circuit).result()
counts = result.get_counts()

print(counts)
print(circuit)

{'1000': 1024}
      ┌───┐                                                                 »
 q_0: ┤ X ├────────────■────────────────────────────────────────────■───────»
      ├───┤            │                                            │       »
 q_1: ┤ X ├────────────┼────■───────────────────────────────────────┼────■──»
      └───┘            │    │                                       │    │  »
 q_2: ───────■─────────┼────┼────────────────────────■──────────────┼────┼──»
             │         │    │                        │              │    │  »
 q_3: ───────┼────■────┼────┼────────────────────────┼────■─────────┼────┼──»
      ┌───┐  │    │    │    │                        │    │         │    │  »
 q_4: ┤ X ├──┼────┼────┼────┼──────────────■─────────┼────┼─────────■────┼──»
      └───┘  │    │    │    │              │         │    │         │    │  »
 q_5: ───────┼────┼────┼────┼──────────────┼────■────┼────┼─────────┼────■──»
      ┌───┐  │    │    │    │              │    │

# 1.5 Subtraction
The function `subtract(circuit,A,B,R,AUX)` implements a circuit that subtracts `Number(B)` from `Number(A)` and stores the result in the register `R`.

Assume that `len(A)=len(B)=len(R)`.

Such a circuit can be obtained by negating each bit stored in B, and applying the adder circuit with the first carry-in bit set to 1 instead of 0.
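
The identity behind this construction is two's complement: A - B equals A + NOT(B) + 1 modulo 2**n. The short check below is a classical sketch of that arithmetic; `subtract_reference` is an illustrative name.

```python
def subtract_reference(a, b, n_bits):
    """Compute (a - b) mod 2**n_bits as a + NOT(b) + 1."""
    mask = (1 << n_bits) - 1
    not_b = ~b & mask            # negate each of the n_bits bits of b
    return (a + not_b + 1) & mask  # the +1 plays the role of carry-in = 1


print(subtract_reference(3, 5, 4))  # 14
print(subtract_reference(5, 3, 4))  # 2
```

A negative difference wraps around, so 3 - 5 appears as 14 in a 4-bit register.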

In [24]:
def subtract(circuit,A,B,R,AUX):

    # R = Number(A) - Number(B)
    # Using:
    # R <- R ^ (A - B mod 2**n)
    # R <- A + !B + 1 (two's complement)
    
    # Negate B
    for i in range(len(B)):
        circuit.x(B[i])

    # Set AUX[0] to 1
    circuit.x(AUX[0])

    # Add negated B
    add(circuit, A, B, R, AUX)

    # Reset AUX[0] to 0 by applying X again
    circuit.x(AUX[0])

    # Unnegate B
    for i in range(len(B)):
        circuit.x(B[i])

    # AUX is reset to 0 within add()

    return circuit

In [25]:
# 1.5 Check subtraction
circuit = QuantumCircuit(17, 4)

n = 4

# Registers
A = [0, 1, 2, 3]
B = [4, 5, 6, 7]
R = [8, 9, 10, 11]
AUX = [12, 13, 14, 15, 16]

# Let A = 3 (0011)
circuit.x(A[0])
circuit.x(A[1])

# Let B = 5 (0101)
circuit.x(B[0])
circuit.x(B[2])

# Run subtraction
subtract(circuit, A, B, R, AUX)

# Results
for i in range(n):
    circuit.measure(R[i], i)

simulator = BasicSimulator()
result = simulator.run(circuit).result()
counts = result.get_counts()

print(counts)
print(circuit)

{'1110': 1024}
      ┌───┐                                                                 »
 q_0: ┤ X ├─────────────────■──────────────────────────────────■────────────»
      ├───┤                 │                                  │            »
 q_1: ┤ X ├─────────────────┼────■─────────────────────────────┼────■───────»
      └───┘                 │    │                             │    │       »
 q_2: ───────■──────────────┼────┼─────────────────────────────┼────┼────■──»
             │              │    │                             │    │    │  »
 q_3: ───────┼────■─────────┼────┼───────────────────■─────────┼────┼────┼──»
      ┌───┐  │    │  ┌───┐  │    │                   │         │    │    │  »
 q_4: ┤ X ├──┼────┼──┤ X ├──┼────┼────■──────────────┼─────────■────┼────┼──»
      ├───┤  │    │  └───┘  │    │    │              │         │    │    │  »
 q_5: ┤ X ├──┼────┼─────────┼────┼────┼────■─────────┼─────────┼────■────┼──»
      ├───┤  │    │  ┌───┐  │    │    │    │     

# 1.6 Comparison
The function `greater_or_eq(circuit,A,B,r,AUX)` implements a circuit that tests whether `number(A)` is greater than or equal to `number(B)`.

The result is stored in register r.

Assume `len(A)=len(B)`.
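
For unsigned n-bit numbers, one way to obtain this test is to look at the final carry-out of A + NOT(B) + 1: the sum equals A - B + 2**n, so bit n is set exactly when A >= B. The sketch below checks that claim classically; `greater_or_eq_reference` is an illustrative name.

```python
def greater_or_eq_reference(a, b, n_bits):
    """Return the carry-out of a + NOT(b) + 1, which is 1 iff a >= b."""
    mask = (1 << n_bits) - 1
    total = a + (~b & mask) + 1
    return (total >> n_bits) & 1


print(greater_or_eq_reference(3, 5, 4))  # 0
print(greater_or_eq_reference(5, 3, 4))  # 1
print(greater_or_eq_reference(4, 4, 4))  # 1
```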

In [26]:
def greater_or_eq(circuit,A,B,r,AUX):

    # Assumes unsigned integers; A and B are left unchanged and AUX must end at |0>
    # A >= B exactly when the final carry-out of A + ~B + 1 (i.e. A - B) is 1

    # Sanity checks
    # TODO: put variations of these in previous functions
    n = len(A)

    if len(B) != n:
        raise ValueError("len(A) must equal len(B)")
    
    if len(AUX) < n + 1:
        raise ValueError("AUX register too small.")
    
    # Allow for flexible size bits, previously only worked for 4
    # TODO: add to previous fns
    A = A[:n]
    B = B[:n]
    AUX = AUX[:n + 1]

    # Subtract A - B but re-written since `subtract(circuit, A, B, R, AUX)` uses list R instead of single qubit register r
    # Complement B
    for i in range(n):
        circuit.x(B[i])

    # Set initial carry-in to 1
    circuit.x(AUX[0])

    # Going left to right, compute carry chain into AUX indices 1 to n
    for i in range(n):
        c_in = AUX[i]
        c_out = AUX[i + 1]

        # Carry with CCNOT: c_out <- c_out ^ (A[i] & !B[i]) ^ (A[i] & c_in) ^ (!B[i] & c_in)
        # B was already negated above; the negation is written out here only to show the overall equation
        circuit.ccx(A[i], B[i], c_out)
        circuit.ccx(A[i], c_in, c_out)
        circuit.ccx(B[i], c_in, c_out)

    # Check the final carry out val and put in r
    circuit.cx(AUX[n], r)

    # Set AUX back to zero by going right to left (reverse)
    for i in range(n - 1, -1, -1):
        c_in = AUX[i]
        c_out = AUX[i + 1]

        circuit.ccx(B[i], c_in, c_out)
        circuit.ccx(A[i], c_in, c_out)
        circuit.ccx(A[i], B[i], c_out)

    circuit.x(AUX[0])

    # Reverse the complement of B
    for i in range(n):
        circuit.x(B[i])
    
    
    return circuit

In [27]:
# 1.6 Check greater or equal
circuit = QuantumCircuit(17, 1) # does classical bit count matter as much on the previous fns

n = 4

# Registers
A = [0, 1, 2, 3]
B = [4, 5, 6, 7]
# r is a single qubit; it is measured into the only classical bit
r = 8
# recall len(AUX) = n + 1
AUX = [9, 10, 11, 12, 13]


# Let A = 3 (0011)
circuit.x(A[0])
circuit.x(A[1])

# Let B = 5 (0101)
circuit.x(B[0])
circuit.x(B[2])

# Run comparison
greater_or_eq(circuit, A, B, r, AUX)

# Results
circuit.measure(r, 0)

simulator = BasicSimulator()
result = simulator.run(circuit).result()
counts = result.get_counts()

print(counts)
print(circuit)

{'0': 1024}
      ┌───┐                                                                 »
 q_0: ┤ X ├─────────────────■─────────■─────────────────────────────────────»
      ├───┤                 │         │                                     »
 q_1: ┤ X ├───────■─────────┼─────────┼─────────■───────────────────────────»
      └───┘       │         │         │         │                           »
 q_2: ────────────┼─────────┼────■────┼─────────┼─────────■─────────────────»
                  │         │    │    │         │         │                 »
 q_3: ────────────┼────■────┼────┼────┼─────────┼─────────┼─────────■───────»
      ┌───┐┌───┐  │    │    │    │    │         │         │         │       »
 q_4: ┤ X ├┤ X ├──┼────┼────■────┼────┼────■────┼─────────┼─────────┼───────»
      ├───┤└───┘  │    │    │    │    │    │    │         │         │       »
 q_5: ┤ X ├───────■────┼────┼────┼────┼────┼────┼────■────┼─────────┼───────»
      ├───┤┌───┐  │    │    │    │    │    │    │   

# 1.7 Addition Modulo N
The function `add_mod(circuit,N,A,B,R,aux)` implements a circuit that adds `number(A)` to `number(B)` modulo `number(N)`.

Result is stored at register `R`.

Assume `len(A)==len(B)==len(N)`, `number(A) < number(N)`, and `number(B) < number(N)`

**HINT**
Since the numbers stored in `A` and `B` are smaller than `N`, the sum will be smaller than 2*N.

It is enough to: 1. add both numbers, 2. test whether the result is greater than or equal to `N` (using `greater_or_eq`), 3. if it is, subtract `N` from the result.

Step 3 can be done by applying a controlled subtraction (the control bit is the output bit of the comparison function).
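
The three steps in the hint can be written out classically to produce expected values for the circuit check. This is a sketch at the integer level, not a circuit; `add_mod_reference` is an illustrative name and `flag` stands for the output bit of the comparison.

```python
def add_mod_reference(a, b, n_mod):
    """Classical reference for (a + b) mod n_mod, assuming a, b < n_mod."""
    if not (0 <= a < n_mod and 0 <= b < n_mod):
        raise ValueError("a and b must be in the range [0, n_mod)")
    total = a + b                  # step 1: plain addition, total < 2 * n_mod
    flag = int(total >= n_mod)     # step 2: comparison bit
    return total - flag * n_mod    # step 3: subtraction controlled by flag


print(add_mod_reference(3, 5, 7))  # 1
```

Two points may matter when turning this into a circuit. The intermediate sum can exceed 2**n - 1 for n-bit inputs, so it probably needs one more bit than `A` and `B`. Also, the comparison bit equals 1 exactly when the final result is smaller than `a`, which suggests one possible way to return that bit to |0> afterwards.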

In [None]:
def add_mod(circuit, N, A, B, R, aux):

    
 
    return circuit

In [60]:
# 1.7 Check add_mod
n = 4

aux_len = (2*n + 2)
n_qubits = 4*n + aux_len
n_classical = n + aux_len

circuit = QuantumCircuit(n_qubits, n_classical)

# Registers
A = list(range(0, n))
B = list(range(n, 2*n))
N = list(range(2*n, 3*n))
R = list(range(3*n, 4*n))
aux = list(range(4*n, 4*n + aux_len))


# Let A = 3 (0011)
circuit.x(A[0])
circuit.x(A[1])

# Let B = 5 (0101)
circuit.x(B[0])
circuit.x(B[2])

# Let N = 7 (0111)
circuit.x(N[0])
circuit.x(N[1])
circuit.x(N[2])


# Run add mod
add_mod(circuit, N, A, B, R, aux)

# Results
for i in range(n):
    circuit.measure(R[i], i)

for j in range(aux_len):
    circuit.measure(aux[j], n + j)

# Too many qubits for BasicSimulator
simulator = AerSimulator()
transpiled_circuit = transpile(circuit, simulator)
result = simulator.run(transpiled_circuit).result()
counts = result.get_counts()

print(counts)
# print(transpiled_circuit)

# bitstring = next(iter(counts))
# r_bits = bitstring[-n:]          # last 4 bits
# aux_bits = bitstring[:-n]        # first 10 bits
# print("aux:", aux_bits, "R:", r_bits)


{'00000000000001': 1024}


# 1.8 Multiplication by Two Modulo N
The function `times_two_mod(circuit,N,A,R,AUX)` implements a circuit that doubles `number(A)` modulo `number(N)`

Result is stored in register `R`

**HINT** copy the register `A` and then compute `number(A) + number(A)`  modulo `number(N)`
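
Expected values for this check can be generated classically. The sketch below doubles a number and subtracts `N` once if needed, which is sufficient when the input is already smaller than `N`; `times_two_mod_reference` is an illustrative name.

```python
def times_two_mod_reference(a, n_mod):
    """Classical reference for (2 * a) mod n_mod, assuming a < n_mod."""
    if not 0 <= a < n_mod:
        raise ValueError("a must be in the range [0, n_mod)")
    doubled = a + a                # same as adding a to a copy of itself
    return doubled - n_mod if doubled >= n_mod else doubled


print(times_two_mod_reference(3, 15))  # 6
print(times_two_mod_reference(5, 7))   # 3
```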

In [None]:
def times_two_mod(circuit,N,A,R,AUX):

    
    
    return circuit

In [70]:
# 1.8 Check times_two_mod
n = 4

aux_len = (2*n + 2)
n_qubits = 3*n + aux_len
n_classical = n + aux_len

circuit = QuantumCircuit(n_qubits, n_classical)

# Registers
A = list(range(0, n))
N = list(range(n, 2*n))
R = list(range(2*n, 3*n))
AUX = list(range(3*n, 3*n + aux_len))


# Let A = 3 (0011)
circuit.x(A[0])
circuit.x(A[1])

# Let N = 7 (0111)
# circuit.x(N[0])
# circuit.x(N[1])
# circuit.x(N[2])

# Let N = 15 (1111)
for i in range(n):
    circuit.x(N[i])


# Run times_two_mod
times_two_mod(circuit, N, A, R, AUX)

# Results
for i in range(n):
    circuit.measure(R[i], i)

for j in range(aux_len):
    circuit.measure(AUX[j], n + j)

# Too many qubits for BasicSimulator
simulator = AerSimulator()
transpiled_circuit = transpile(circuit, simulator)
result = simulator.run(transpiled_circuit).result()
counts = result.get_counts()

print(counts)
# print(transpiled_circuit)

bitstring = next(iter(counts))
r_bits = bitstring[-n:]          # last 4 bits
aux_bits = bitstring[:-n]        # first 10 bits
print("aux:", aux_bits, "R:", r_bits)

{'00000000000101': 1024}
aux: 0000000000 R: 0101


# 1.9 Multiplication by a Power of Two Modulo N
The function `times_two_power_mod(circuit,N,A,k,R,AUX)` implements a circuit that multiplies `number(A)` by 2**k modulo `number(N)`.

**HINT**: apply the function `times_two_mod(circuit,N,A,R,AUX)` k times in a row
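
The repeated doubling in the hint can be mirrored classically to obtain the value the circuit should produce. This is a sketch with illustrative names; it assumes the input is already smaller than `N`.

```python
def times_two_mod_reference(a, n_mod):
    """Classical reference for (2 * a) mod n_mod, assuming a < n_mod."""
    doubled = a + a
    return doubled - n_mod if doubled >= n_mod else doubled


def times_two_power_mod_reference(a, k, n_mod):
    """Classical reference for (a * 2**k) mod n_mod via k doublings."""
    if not 0 <= a < n_mod or k < 0:
        raise ValueError("need 0 <= a < n_mod and k >= 0")
    value = a
    for _ in range(k):
        value = times_two_mod_reference(value, n_mod)
    return value


print(times_two_power_mod_reference(3, 2, 7))  # 5
```

In a circuit each doubling writes into a register that must start at |0>, so the intermediate values would presumably need their own scratch qubits or an uncompute step to leave `AUX` at |0>.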

In [None]:
def times_two_power_mod(circuit, N, A, k, R, AUX):

    

    return circuit

In [66]:
# 1.9 Check times_two_power_mod
n = 4
k = 2

aux_len = (2*n + 2)
n_qubits = 3*n + aux_len
n_classical = n + aux_len

circuit = QuantumCircuit(n_qubits, n_classical)

# Registers
A = list(range(0, n))
N = list(range(n, 2*n))
R = list(range(2*n, 3*n))
AUX = list(range(3*n, 3*n + aux_len))


# Let A = 3 (0011)
circuit.x(A[0])
circuit.x(A[1])

# Let N = 7 (0111)
circuit.x(N[0])
circuit.x(N[1])
circuit.x(N[2])


# Run times_two_power_mod()
times_two_power_mod(circuit, N, A, k, R, AUX)

# Results
for i in range(n):
    circuit.measure(R[i], i)

for j in range(aux_len):
    circuit.measure(AUX[j], n + j)

# Too many qubits for BasicSimulator
simulator = AerSimulator()
transpiled_circuit = transpile(circuit, simulator)
result = simulator.run(transpiled_circuit).result()
counts = result.get_counts()

print(counts)
# print(transpiled_circuit)

bitstring = next(iter(counts))
r_bits = bitstring[-n:]          # last 4 bits
aux_bits = bitstring[:-n]        # first 10 bits
print("aux:", aux_bits, "R:", r_bits)

{'10010000000100': 1024}
aux: 1001000000 R: 0100
