# Modular Exponentiation
## General Notes
    - Goal: compute X**Y mod N
    - Central subroutine of **Shor's period finding algorithm** used for integer factorization

## Goal
Implement a quantum circuit for modular exponentiation from scratch.

# Limitations
    - Unless stated otherwise, you are only allowed to use the gates X, CNOT, CCNOT and Multicontrolled-NOT. See Appendix A for an example of use of the Multicontrolled-NOT gate.

    - Can implement additional auxiliary functions as long as the functions below are implemented.

    - For each implemented function, please give evidence that the implementation is correct by:
        - initializing each input register with a number of up to 4 bits,
        - each auxiliary register with |0>
        - measuring the output register to verify if the value is as expected

    - Reuse qubits from auxiliary registers as much as possible.
        - It is crucial that auxiliary registers are equal to |0> both at the beginning and at the end of computation of each function

In [236]:
from qiskit_ibm_runtime import QiskitRuntimeService
from qiskit_ibm_runtime import SamplerV2 as Sampler
from qiskit import QuantumCircuit
from qiskit.circuit.library import MCXGate
from qiskit.providers.basic_provider import BasicSimulator
from qiskit_aer import AerSimulator
from qiskit import transpile
import numpy as np

# 1.1 Initialization
The function `set_bits(circuit,A,X)` initializes the bits of register `A` with the binary string `X`.

For each index `i` in `range(len(X))`, if `X[i]=1`, the function applies the **X**-gate to `A[i]`.

Otherwise, it does nothing.

Assume `len(A)=len(X)`.

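If `A = [2,4,3,7,5]` and `X = 01011`, the **X**-gate is applied to qubits 4, 7, and 5. (Note: the implementation below instead treats `A[0]` as the least-significant bit, which is the convention used throughout the rest of this notebook, so it applies the **X**-gate to qubits 2, 4, and 7, as the printed circuit shows.)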

In [237]:
def set_bits(circuit, A, X):
    """Apply X gates so that register ``A`` holds the value ``X``.

    ``X`` may be given as
      * a non-negative ``int`` (e.g. ``11``),
      * a ``bin()`` string (e.g. ``"0b1011"``), or
      * a plain binary string (e.g. ``"1011"`` or ``"01011"``).
    Strings shorter than ``len(A)`` are left-padded with zeros; extra leading
    zeros beyond ``len(A)`` are accepted as long as the value itself fits.

    Bit ordering: the binary string is read most-significant-bit first and
    ``A[0]`` receives the least-significant bit, so ``set_bits(qc, A, 11)``
    flips ``A[0]``, ``A[1]`` and ``A[3]``.
    NOTE(review): the task statement describes an index-aligned rule
    (``X[i]`` -> ``A[i]``). This little-endian mapping is what every other
    cell of this notebook relies on (``A[0]`` is the LSB), so it is kept.

    Args:
        circuit: circuit to append the gates to.
        A: qubit indices (or qubits) of the target register, LSB first.
        X: value to load, see above.

    Returns:
        ``circuit``, with the X gates followed by a barrier.

    Raises:
        TypeError: if ``X`` is neither ``int`` nor ``str``.
        ValueError: if ``X`` is negative, contains characters other than
            ``0``/``1``, or needs more than ``len(A)`` bits.
    """
    width = len(A)

    if isinstance(X, int):
        if X < 0:
            raise ValueError("X must be a non-negative integer.")
        bits = bin(X)[2:]
    elif isinstance(X, str):
        # Accept both "0b1011" (output of bin()) and bare "1011"
        bits = X[2:] if X.startswith("0b") else X
        # Previously stray characters were silently treated as 0
        if set(bits) - {"0", "1"}:
            raise ValueError("X must contain only the characters '0' and '1'")
    else:
        raise TypeError("X must be an int or str")

    # Leading zeros carry no value, so only the significant bits must fit
    significant = bits.lstrip("0")
    if len(significant) > width:
        raise ValueError("Binary value doesn't fit in target register")

    # Reverse the MSB-first string so position i has weight 2**i, i.e. maps to A[i]
    for qubit, bit in zip(A, reversed(significant.zfill(width))):
        if bit == "1":
            circuit.x(qubit)

    circuit.barrier()

    return circuit

In [238]:
# Check 1.1 initialization
# Load "1011" (11 in decimal; zero-padded to "01011" for the 5-qubit register)
# into register A of an 8-qubit circuit and print the resulting X gates.
# Passing the int 11 or the string bin(11) to set_bits builds the same circuit.
circuit = QuantumCircuit(8, 0)
A = [2,4,3,7,5] # qubit indices of register A (A[0] = least-significant bit)
X = "1011"

set_bits(circuit,A,X)
print(circuit)

           ░ 
q_0: ──────░─
           ░ 
q_1: ──────░─
     ┌───┐ ░ 
q_2: ┤ X ├─░─
     └───┘ ░ 
q_3: ──────░─
     ┌───┐ ░ 
q_4: ┤ X ├─░─
     └───┘ ░ 
q_5: ──────░─
           ░ 
q_6: ──────░─
     ┌───┐ ░ 
q_7: ┤ X ├─░─
     └───┘ ░ 


# 1.2 Copy
The function `copy(circuit,A,B)` copies the bit string stored in register `A` into register `B`.

Assume that `len(A)=len(B)` and before application of function, B is initialized to |0>

**Hint: use CNOT gates**

In [239]:
def copy(circuit, A, B):
    """XOR register ``A`` into register ``B`` with one CNOT per bit (B <- B ^ A).

    When ``B`` starts in |0...0> it ends up holding a copy of ``number(A)``.
    Applying ``copy`` a second time with the same arguments returns ``B`` to
    its previous state, which is how scratch copies are uncomputed.

    Args:
        circuit: circuit to append the CNOTs to.
        A: source register (used only as controls, left unchanged).
        B: destination register, same length as ``A``.

    Returns:
        ``circuit``.

    Raises:
        ValueError: if ``len(A) != len(B)``. This is checked before any gate
            is appended.
    """
    # Without this check a longer B was silently half-copied and a shorter B
    # raised IndexError after some CNOTs were already appended
    if len(A) != len(B):
        raise ValueError("len(A) must equal len(B)")

    for source, target in zip(A, B):
        circuit.cx(source, target)

    return circuit

In [240]:
# Check 1.2 copy()
# Load 11 into register A with set_bits, then CNOT-copy A into B
# (B starts in |0000>), and print both qubit lists and the circuit.
circuit = QuantumCircuit(8, 4)
A = [0, 1, 2, 3]
B = [4, 5, 6, 7]

# Use set_bits(); equivalent encodings of 11 are kept for reference
set_bits(circuit, A, 11)
# set_bits(circuit, A, "1011")
# set_bits(circuit, A, "01011")

# Copy A to B
copy(circuit,A,B)

print(A)
print(B)
print(circuit)

[0, 1, 2, 3]
[4, 5, 6, 7]
     ┌───┐ ░                     
q_0: ┤ X ├─░───■─────────────────
     ├───┤ ░   │                 
q_1: ┤ X ├─░───┼────■────────────
     └───┘ ░   │    │            
q_2: ──────░───┼────┼────■───────
     ┌───┐ ░   │    │    │       
q_3: ┤ X ├─░───┼────┼────┼────■──
     └───┘ ░ ┌─┴─┐  │    │    │  
q_4: ──────░─┤ X ├──┼────┼────┼──
           ░ └───┘┌─┴─┐  │    │  
q_5: ──────░──────┤ X ├──┼────┼──
           ░      └───┘┌─┴─┐  │  
q_6: ──────░───────────┤ X ├──┼──
           ░           └───┘┌─┴─┐
q_7: ──────░────────────────┤ X ├
           ░                └───┘
c: 4/════════════════════════════
                                 


# 1.3 Full Adder
The function `full_adder(circuit,a,b,r,c_in,c_out,AUX)` implements a full adder.

Registers:

`a` and `b` store the bits to be added

`c_in` stores the carry-in bit

`c_out` stores the carry-out bit

`r` stores the result of the sum

`AUX` is the auxiliary register

In [241]:
def full_adder(circuit,a,b,r,c_in,c_out,AUX):
    """One-bit reversible full adder.

        r     <- r     ^ a ^ b ^ c_in           (sum bit)
        c_out <- c_out ^ MAJ(a, b, c_in)        (carry bit)

    The XOR of the three pairwise ANDs (a&b) ^ (a&c_in) ^ (b&c_in) equals the
    majority function, so three Toffolis produce the carry. ``a``, ``b`` and
    ``c_in`` act only as controls and are left unchanged.

    ``AUX`` is accepted for the required signature but no gate touches it,
    so an auxiliary register that starts in |0> also ends in |0>.

    Returns:
        ``circuit``.
    """
    # Sum: XOR each input bit into r
    for control in (a, b, c_in):
        circuit.cx(control, r)

    # Carry: one Toffoli per pair of inputs
    for left, right in ((a, b), (a, c_in), (b, c_in)):
        circuit.ccx(left, right, c_out)

    return circuit

In [242]:
# 1.3 Full Adder check
# Inputs a = 1, b = 1, c_in = 0, so the expected outputs are r = 0, c_out = 1.
# Qiskit prints classical bit 0 rightmost, so the expected count key is '10'
# (c_out on the left, r on the right).
circuit = QuantumCircuit(5, 2)

# Assign qubits
a = 0
b = 1
c_in = 2
r = 3
c_out = 4
AUX = []  # full_adder uses no auxiliary qubits

# Initialize a = 1, b = 1, c_in = 0
circuit.x(a)
circuit.x(b)

# Run full adder
full_adder(circuit, a, b, r, c_in, c_out, AUX)

# Results: classical bit 0 <- r, classical bit 1 <- c_out
circuit.measure(r, 0)
circuit.measure(c_out, 1)

simulator = BasicSimulator()
result = simulator.run(circuit).result()
counts = result.get_counts()

print(counts)
print(circuit)

{'10': 1024}
     ┌───┐                                    
q_0: ┤ X ├──■──────────────■────■─────────────
     ├───┤  │              │    │             
q_1: ┤ X ├──┼────■─────────■────┼───────■─────
     └───┘  │    │         │    │       │     
q_2: ───────┼────┼────■────┼────■───────■─────
          ┌─┴─┐┌─┴─┐┌─┴─┐  │    │  ┌─┐  │     
q_3: ─────┤ X ├┤ X ├┤ X ├──┼────┼──┤M├──┼─────
          └───┘└───┘└───┘┌─┴─┐┌─┴─┐└╥┘┌─┴─┐┌─┐
q_4: ────────────────────┤ X ├┤ X ├─╫─┤ X ├┤M├
                         └───┘└───┘ ║ └───┘└╥┘
c: 2/═══════════════════════════════╩═══════╩═
                                    0       1 


# 1.4 Addition
The function `add(circuit,A,B,R,AUX)` implements a circuit that adds `number(A)` to `number(B)` and stores the result at register `R`.

Assume `len(A)==len(B)==len(R)`

The circuit is obtained by creating a cascade of `full_adder` circuits.

The carry bits are part of the auxiliary register AUX.

Note that the carry-in bit of the first (rightmost, least-significant) adder is set to 0.

In [243]:
def add(circuit,A,B,R,AUX):
    """Ripple-carry adder: R <- R ^ ((number(A) + number(B) + AUX[0]) mod 2**n).

    A cascade of ``full_adder`` calls. Bit ``i`` uses ``AUX[i]`` as carry-in
    and ``AUX[i + 1]`` as carry-out. ``AUX[0]`` is the carry-in of the first
    (least-significant) adder: it is |0> for plain addition, and
    ``subtract`` sets it to 1.

    After the sum bits are written, each bit's three carry Toffolis are
    replayed from the most-significant bit down. At that point ``AUX[i]``
    still holds the carry into bit ``i``, so each replay XORs the same value
    again. This returns ``AUX[1..n]`` to their initial values (|0> in normal
    use). ``AUX[0]``, ``A`` and ``B`` are only ever used as controls.

    Args:
        circuit: circuit to append gates to.
        A, B: n-qubit operand registers, LSB first.
        R: n-qubit result register, XORed with the sum.
        AUX: at least n + 1 qubits; see above.

    Returns:
        ``circuit``.

    Raises:
        ValueError: if the register lengths differ or ``AUX`` is too short.
    """
    width = len(A)
    if len(B) != width or len(R) != width:
        raise ValueError("len(A), len(B), and len(R) must be the same length")
    if len(AUX) < width + 1:
        raise ValueError("AUX must have at least len(A) + 1 qubits.")

    # Forward sweep: sum bit into R[i], carry into AUX[i + 1]
    for i, (a_bit, b_bit, r_bit) in enumerate(zip(A, B, R)):
        full_adder(circuit, a_bit, b_bit, r_bit, AUX[i], AUX[i + 1], AUX)

    # Backward sweep: undo the carry Toffolis in mirrored order to clear AUX[1..n]
    for i in reversed(range(width)):
        carry_in, carry_out = AUX[i], AUX[i + 1]
        for left, right in ((B[i], carry_in), (A[i], carry_in), (A[i], B[i])):
            circuit.ccx(left, right, carry_out)

    return circuit

In [244]:
# 1.4 Check Addition
# 3 + 5 = 8, so the single expected outcome for R is '1000'.
# The AUX carry qubits are returned to |0> by add() itself.
circuit = QuantumCircuit(17, 4)

# Width of each operand in bits
n = 4

# Registers (index 0 of each register is the least-significant bit)
A = [0, 1, 2, 3]
B = [4, 5, 6, 7]
R = [8, 9, 10, 11]
AUX = [12, 13, 14, 15, 16]  # n + 1 carry qubits

# Let A = 3 (0011)
circuit.x(A[0])
circuit.x(A[1])

# Let B = 5 (0101)
circuit.x(B[0])
circuit.x(B[2])

# Compute addition
add(circuit, A, B, R, AUX)

# Measure R[i] into classical bit i
for i in range(n):
    circuit.measure(R[i], i)

simulator = BasicSimulator()
result = simulator.run(circuit).result()
counts = result.get_counts()

print(counts)
print(circuit)

{'1000': 1024}
      ┌───┐                                                                 »
 q_0: ┤ X ├────────────■────────────────────────────────────────────■───────»
      ├───┤            │                                            │       »
 q_1: ┤ X ├────────────┼────■───────────────────────────────────────┼────■──»
      └───┘            │    │                                       │    │  »
 q_2: ───────■─────────┼────┼────────────────────────■──────────────┼────┼──»
             │         │    │                        │              │    │  »
 q_3: ───────┼────■────┼────┼────────────────────────┼────■─────────┼────┼──»
      ┌───┐  │    │    │    │                        │    │         │    │  »
 q_4: ┤ X ├──┼────┼────┼────┼──────────────■─────────┼────┼─────────■────┼──»
      └───┘  │    │    │    │              │         │    │         │    │  »
 q_5: ───────┼────┼────┼────┼──────────────┼────■────┼────┼─────────┼────■──»
      ┌───┐  │    │    │    │              │    │

# 1.5 Subtraction
The function `subtract(circuit,A,B,R,AUX)` implements a circuit that subtracts `Number(B)` from `Number(A)` and stores the result in the register `R`.

Assume that `len(A)=len(B)=len(R)`.

Such a circuit can be obtained by negating each bit stored in B and applying the adder circuit with the first carry-in bit set to 1 instead of 0.

In [245]:
def subtract(circuit, A, B, R, AUX):
    """R <- R ^ ((number(A) - number(B)) mod 2**n) using two's complement.

    Uses A - B = A + (NOT B) + 1 (mod 2**n):
      1. flip every bit of ``B`` and set the carry-in ``AUX[0]`` to 1,
      2. run ``add`` (which also returns ``AUX[1..n]`` to their initial state),
      3. flip ``AUX[0]`` and ``B`` back, so both end as they started.

    Args:
        circuit: circuit to append gates to.
        A: minuend register (n qubits, unchanged).
        B: subtrahend register (n qubits, temporarily negated, then restored).
        R: result register (n qubits), XORed with the difference.
        AUX: at least n + 1 qubits in |0>.

    Returns:
        ``circuit``.

    Raises:
        ValueError: if the register lengths differ or ``AUX`` has fewer than
            n + 1 qubits. This is checked before any gate is appended.
    """
    n = len(A)
    # add() performs the same checks, but only after the X gates below were
    # already appended, which left a half-built circuit behind on error
    if len(B) != n or len(R) != n:
        raise ValueError("len(A), len(B), and len(R) must be the same length")
    if len(AUX) < n + 1:
        raise ValueError("AUX must have at least len(A) + 1 qubits.")

    # B <- NOT B
    for qubit in B:
        circuit.x(qubit)

    # Carry-in 1 supplies the "+1" of the two's complement
    circuit.x(AUX[0])

    # R <- R ^ (A + NOT B + 1)
    add(circuit, A, B, R, AUX)

    # Restore the carry-in and B
    circuit.x(AUX[0])
    for qubit in B:
        circuit.x(qubit)

    return circuit

In [246]:
# 1.5 Check subtraction
# 3 - 5 = -2, which is 14 modulo 2**4, so the expected outcome for R is '1110'.
circuit = QuantumCircuit(17, 4)

n = 4

# Registers (index 0 of each register is the least-significant bit)
A = [0, 1, 2, 3]
B = [4, 5, 6, 7]
R = [8, 9, 10, 11]
AUX = [12, 13, 14, 15, 16]  # n + 1 carry qubits

# Let A = 3 (0011)
circuit.x(A[0])
circuit.x(A[1])

# Let B = 5 (0101)
circuit.x(B[0])
circuit.x(B[2])

# Run subtraction
subtract(circuit, A, B, R, AUX)

# Results: measure R[i] into classical bit i
for i in range(n):
    circuit.measure(R[i], i)

simulator = BasicSimulator()
result = simulator.run(circuit).result()
counts = result.get_counts()

print(counts)
print(circuit)

{'1110': 1024}
      ┌───┐                                                                 »
 q_0: ┤ X ├─────────────────■──────────────────────────────────■────────────»
      ├───┤                 │                                  │            »
 q_1: ┤ X ├─────────────────┼────■─────────────────────────────┼────■───────»
      └───┘                 │    │                             │    │       »
 q_2: ───────■──────────────┼────┼─────────────────────────────┼────┼────■──»
             │              │    │                             │    │    │  »
 q_3: ───────┼────■─────────┼────┼───────────────────■─────────┼────┼────┼──»
      ┌───┐  │    │  ┌───┐  │    │                   │         │    │    │  »
 q_4: ┤ X ├──┼────┼──┤ X ├──┼────┼────■──────────────┼─────────■────┼────┼──»
      ├───┤  │    │  └───┘  │    │    │              │         │    │    │  »
 q_5: ┤ X ├──┼────┼─────────┼────┼────┼────■─────────┼─────────┼────■────┼──»
      ├───┤  │    │  ┌───┐  │    │    │    │     

# 1.6 Comparison
The function `greater_or_eq(circuit,A,B,r,AUX)` implements a circuit that tests whether `number(A)` is greater than or equal to `number(B)`

The result is stored in qubit `r`.

Assume `len(A)=len(B)`

In [247]:
def greater_or_eq(circuit,A,B,r,AUX):
    """r <- r ^ [number(A) >= number(B)] for unsigned n-bit registers.

    A - B is computed as A + (NOT B) + 1. For unsigned operands, the carry
    out of the top bit of that sum is 1 exactly when A >= B. Only the carry
    chain is built (no sum bits):
      * ``AUX[i + 1]`` receives MAJ(A[i], NOT B[i], AUX[i]), starting from a
        carry-in ``AUX[0] = 1``;
      * the final carry ``AUX[n]`` is copied into ``r``;
      * the chain is then uncomputed.
    ``A``, ``B`` and ``AUX`` end exactly as they started.

    Args:
        circuit: circuit to append gates to.
        A, B: n-qubit registers to compare (B is negated temporarily).
        r: single target qubit, XORed with the comparison result.
        AUX: at least n + 1 qubits in |0>; only AUX[0..n] are touched.

    Returns:
        ``circuit``.

    Raises:
        ValueError: if ``len(B) != len(A)`` or ``AUX`` is too short.
    """
    width = len(A)
    if len(B) != width:
        raise ValueError("len(A) must equal len(B)")
    if len(AUX) < width + 1:
        raise ValueError("AUX register too small.")

    def flip_b():
        # B <- NOT B (called twice, so B is restored at the end)
        for qubit in B:
            circuit.x(qubit)

    def carry_pairs(i):
        # Control pairs whose Toffolis XOR MAJ(A[i], B[i], AUX[i]) into AUX[i + 1]
        return ((A[i], B[i]), (A[i], AUX[i]), (B[i], AUX[i]))

    flip_b()
    circuit.x(AUX[0])  # carry-in of 1 = the "+1" of two's complement

    # Build the carry chain from the least-significant bit upwards
    for i in range(width):
        for left, right in carry_pairs(i):
            circuit.ccx(left, right, AUX[i + 1])

    # Final carry is 1 iff A >= B
    circuit.cx(AUX[width], r)

    # Tear the chain down in mirrored order
    for i in reversed(range(width)):
        for left, right in reversed(carry_pairs(i)):
            circuit.ccx(left, right, AUX[i + 1])

    circuit.x(AUX[0])
    flip_b()

    return circuit

In [248]:
# 1.6 Check greater or equal
# 3 >= 5 is false, so the expected outcome for r is '0'.
circuit = QuantumCircuit(17, 1) # one classical bit: the measured comparison result

n = 4

# Registers (index 0 of each register is the least-significant bit)
A = [0, 1, 2, 3]
B = [4, 5, 6, 7]
# r is the output qubit; it is measured into the only classical bit
r = 8
# recall len(AUX) = n + 1
AUX = [9, 10, 11, 12, 13]


# Let A = 3 (0011)
circuit.x(A[0])
circuit.x(A[1])

# Let B = 5 (0101)
circuit.x(B[0])
circuit.x(B[2])

# Run comparison: r <- r ^ [A >= B]
greater_or_eq(circuit, A, B, r, AUX)

# Results
circuit.measure(r, 0)

simulator = BasicSimulator()
result = simulator.run(circuit).result()
counts = result.get_counts()

print(counts)
print(circuit)

{'0': 1024}
      ┌───┐                                                                 »
 q_0: ┤ X ├─────────────────■─────────■─────────────────────────────────────»
      ├───┤                 │         │                                     »
 q_1: ┤ X ├───────■─────────┼─────────┼─────────■───────────────────────────»
      └───┘       │         │         │         │                           »
 q_2: ────────────┼─────────┼────■────┼─────────┼─────────■─────────────────»
                  │         │    │    │         │         │                 »
 q_3: ────────────┼────■────┼────┼────┼─────────┼─────────┼─────────■───────»
      ┌───┐┌───┐  │    │    │    │    │         │         │         │       »
 q_4: ┤ X ├┤ X ├──┼────┼────■────┼────┼────■────┼─────────┼─────────┼───────»
      ├───┤└───┘  │    │    │    │    │    │    │         │         │       »
 q_5: ┤ X ├───────■────┼────┼────┼────┼────┼────┼────■────┼─────────┼───────»
      ├───┤┌───┐  │    │    │    │    │    │    │   

# 1.7 Addition Modulo N
The function `add_mod(circuit,N,A,B,R,aux)` implements a circuit that adds `number(A)` to `number(B)` modulo `number(N)`.

Result is stored at register `R`.

Assume `len(A)==len(B)==len(N)`, `number(A) < number(N)`, and `number(B) < number(N)`

**HINT**
Since the numbers stored in `A` and `B` are smaller than `N`, their sum is SMALLER than 2*N.

It is enough to (1) add both numbers, (2) test whether the result is `greater_or_eq` to N, and (3) if it is, subtract N from the result.

Step (3) can be done by applying a controlled subtraction whose control bit is the output bit of the comparison function.

## Helper functions
These helpers keep the code concise. The goal is that every carry bit is computed correctly and reversibly, so that AUX always returns to 0 regardless of the values of A and B.

### Not controlled

In [249]:
# Bit-wise operations to be used in add_update
def carry_forward(circuit, a, b, aux_carry):

    # Goes one bit position at a time. 
    # Computes the information to the next bit byt oesn't finalize sum yet

    circuit.cx(aux_carry, b)
    circuit.cx(aux_carry, a)
    circuit.ccx(a, b, aux_carry)

def carry_backward(circuit, a, b, aux_carry):

    # Reverse order
    # Final sum written into target register & restores AUX
    circuit.ccx(a, b, aux_carry)
    circuit.cx(aux_carry, a)
    circuit.cx(a, b)
    
def add_update(circuit, A, B, aux_carry):

    # Full n-bit addtion s.t. B <- B + A mod 2**n
    
    n = len(A)
    if len(B) != n: raise ValueError("len(A) must equal len(B)")

    for i in range(n):
        carry_forward(circuit, A[i], B[i], aux_carry)

    for i in range(n - 1, -1, -1):
        carry_backward(circuit, A[i], B[i], aux_carry)

    return circuit
    

### Controlled
These only act if the control qubit = 1.

In [250]:
# One shared 3-control NOT (MCX) gate object, used by the controlled carry
# helpers below with controls (flag, a, b) and the carry qubit as target.
num_controls = 3
mcx_gate = MCXGate(num_controls)

def carry_forward_controlled(circuit, a, b, aux_carry, flag_control):
    """``carry_forward`` with every gate additionally controlled on ``flag_control``.

    Each gate gains one extra control on ``flag_control``: the CNOTs become
    Toffolis and the Toffoli becomes a 3-control MCX. When the flag is 0
    nothing changes; when it is 1 the effect is that of ``carry_forward``.
    """
    for target in (b, a):
        circuit.ccx(flag_control, aux_carry, target)
    circuit.append(mcx_gate, [flag_control, a, b, aux_carry])

def carry_backward_controlled(circuit, a, b, aux_carry, flag_control):
    """``carry_backward`` with every gate additionally controlled on ``flag_control``."""
    circuit.append(mcx_gate, [flag_control, a, b, aux_carry])
    circuit.ccx(flag_control, aux_carry, a)
    circuit.ccx(flag_control, a, b)

def add_update_controlled(circuit, A, B, aux_carry, flag_control):
    """If ``flag_control`` is 1: B <- (number(B) + number(A) + aux_carry) mod 2**n.

    Same forward/backward sweep as ``add_update``, built from the controlled
    per-bit steps. With the flag at 0 every gate is inactive.

    Raises:
        ValueError: if ``len(B) != len(A)``.
    """
    n = len(A)
    if len(B) != n:
        raise ValueError("len(A) must equal len(B)")

    bit_pairs = list(zip(A, B))
    for a_bit, b_bit in bit_pairs:
        carry_forward_controlled(circuit, a_bit, b_bit, aux_carry, flag_control)
    for a_bit, b_bit in reversed(bit_pairs):
        carry_backward_controlled(circuit, a_bit, b_bit, aux_carry, flag_control)

    return circuit

def subtract_update_controlled(circuit, R, N, flag_compare, aux_carry):
    """If ``flag_compare`` is 1: R <- (number(R) - number(N)) mod 2**n, in place.

    Uses R - N = R + (NOT N) + 1. The negation of N and the "+1" carry-in are
    both applied with CNOTs controlled on ``flag_compare``. The addition itself
    is ``add_update_controlled``. Both the negation and the carry-in are then
    undone, so N and ``aux_carry`` end as they started (``aux_carry`` is
    expected to start in |0>). With the flag at 0 no register changes.

    Raises:
        ValueError: if ``len(N) != len(R)``.
    """
    width = len(R)
    if len(N) != width:
        raise ValueError("len(R) must equal len(N)")

    def toggle_negation():
        # Controlled NOT on every bit of N (self-inverse, so called twice)
        for qubit in N:
            circuit.cx(flag_compare, qubit)

    toggle_negation()
    circuit.cx(flag_compare, aux_carry)          # controlled carry-in of 1
    add_update_controlled(circuit, N, R, aux_carry, flag_compare)
    circuit.cx(flag_compare, aux_carry)          # clear the carry-in again
    toggle_negation()

    return circuit

In [251]:
# def add_mod(circuit, N, A, B, R, aux):

#     n = len(A)
#     if len(B) != n or len(N) != n or len(R) != n: raise ValueError("All registers must be the same length")
#     if len(aux) < (2*n + 2): raise ValueError("aux must have at least n + 2 qubits")

#     AUX = aux[: n + 1]
#     flag_compare = aux[n + 1]
#     aux_carry = AUX[0]

#     #  Temp register to store R = A + B (the unreduced sum)
#     temp_R = aux[n + 2: n + 2 +n]
    

#     # Implementing R <- (A + B) mod N
#     # First, R <- A + B add(circuit, A, B, R, AUX_carry)
#     # Sum: R[i] <- R[i] ^ A[i] ^ B[i] ^ c_in
#     # Carry with CCNOT: c_out <- c_out ^ (A[i] & B[i]) ^ (A[i] & c_in) ^ (B[i] & c_in)
#     add(circuit, A, B, R, AUX)

#     # Save copy of result to uncompute flag later
#     # temp_R <- temp_R ^ R
#     copy(circuit, R, temp_R)

#     # Then find flag from greater_or_eq(circuit, R, N, flag_compare, aux_compare)
#     # Carry with CCNOT: c_out <- c_out ^ (R[i] & !N[i]) ^ (R[i] & c_in) ^ (!N[i] & c_in)
#     # Comparison goes into flag_compare <- R >= N 
#     greater_or_eq(circuit, temp_R, N, flag_compare, AUX)

#     # If flag_compare == 1: R <- R - N
#     # Reduces based on modulo
#     subtract_update_controlled(circuit, R, N, flag_compare, aux_carry)

#     # Set flag_compare back to 0
#     # greater_or_eq(circuit, temp_R, N, flag_compare, AUX)

#     # Clear temp_R
#     # temp_R <- temp_R ^ (A + B) = 0
#     add(circuit, A, B, temp_R, AUX)

    
#     # Recompute temp_R so we can uncompute the flag
#     add(circuit, A, B, temp_R, AUX)

#     # Set flag back to 0 
#     greater_or_eq(circuit, temp_R, N, flag_compare, AUX)

#     # Uncompute temp_R
#     add(circuit, A, B, temp_R, AUX)
#     # copy(circuit, R, temp_R)

#     # Reapply reductoin to return to A+B mod N
#     # subtract_update_controlled(circuit, R, N, flag_compare, aux_carry)
 
#     return circuit

In [252]:
def _xor_carry_out(circuit, A, B, AUX, target):
    """target <- target ^ carry_out(number(A) + number(B)).

    Builds the ripple carry chain of A + B in AUX[1..n], with carry-in
    AUX[0] = 0. It copies the final carry AUX[n] into ``target`` and then
    uncomputes the chain. A, B and AUX (expected all |0>) are left unchanged.
    """
    n = len(A)
    # Forward: AUX[i + 1] <- MAJ(A[i], B[i], AUX[i])
    for i in range(n):
        circuit.ccx(A[i], B[i], AUX[i + 1])
        circuit.ccx(A[i], AUX[i], AUX[i + 1])
        circuit.ccx(B[i], AUX[i], AUX[i + 1])

    circuit.cx(AUX[n], target)

    # Backward: the same Toffolis in mirrored order clear AUX[1..n]
    for i in range(n - 1, -1, -1):
        circuit.ccx(B[i], AUX[i], AUX[i + 1])
        circuit.ccx(A[i], AUX[i], AUX[i + 1])
        circuit.ccx(A[i], B[i], AUX[i + 1])

def add_mod(circuit, N, A, B, R, aux):
    """Modular addition: R <- R ^ ((number(A) + number(B)) mod number(N)).

    Preconditions: number(A) < number(N), number(B) < number(N), R starts in
    |0>, and every used aux qubit starts in |0>. A, B and N are left
    unchanged, and aux is returned to |0>.

    aux layout (only the first n + 2 qubits are used; extra ones are untouched):
        aux[0 : n+1]  carry register AUX (AUX[0] doubles as the carry qubit
                      of the controlled subtraction)
        aux[n+1]      flag qubit: 1 iff N was subtracted

    Steps:
      1. R <- A + B (mod 2**n) via ``add``.
      2. flag <- [R >= N] ^ carry_out(A + B).
         A + B < 2N can still need n + 1 bits (e.g. 14 + 14 with n = 4),
         and ``add`` drops that carry. Comparing only the n-bit R would
         therefore miss the reduction (giving 12 instead of 13 for
         14 + 14 mod 15). When the sum overflows, R = A + B - 2**n < N, so
         the two terms are never both 1 and the XOR acts as an OR.
      3. If flag: R <- R - N (mod 2**n). This gives A + B - N exactly,
         because 0 <= A + B - N < N.
      4. Uncompute the flag. After step 3, R < A holds iff the reduction
         happened (this needs B < N). So flag ^= [R >= A] followed by X(flag)
         returns the flag to 0, without keeping a copy of the unreduced sum.

    Args:
        circuit: circuit to append gates to.
        N: modulus register (n qubits).
        A, B: operand registers (n qubits each).
        R: result register (n qubits).
        aux: auxiliary qubits, at least n + 2.

    Returns:
        ``circuit``.

    Raises:
        ValueError: if the register lengths differ or aux has fewer than
            n + 2 qubits (previously 2n + 2 were required; larger aux
            registers are still accepted).
    """
    n = len(A)
    if len(B) != n or len(N) != n or len(R) != n:
        raise ValueError("All registers must be the same length")
    if len(aux) < n + 2:
        raise ValueError("aux must have at least n + 2 qubits")

    AUX = aux[: n + 1]
    flag_compare = aux[n + 1]
    # AUX[0] is |0> between the steps below, so it can serve as the single
    # carry qubit of the controlled subtraction
    aux_carry = AUX[0]

    # 1) R <- R ^ ((A + B) mod 2**n)
    add(circuit, A, B, R, AUX)

    # 2) flag <- [R >= N] ^ overflow(A + B)  ==  [A + B >= N]
    greater_or_eq(circuit, R, N, flag_compare, AUX)
    _xor_carry_out(circuit, A, B, AUX, flag_compare)

    # 3) If flag == 1: R <- R - N (mod 2**n)
    subtract_update_controlled(circuit, R, N, flag_compare, aux_carry)

    # 4) Now flag == 1 - [R >= A]; XOR in [R >= A] and flip to reset it to 0
    greater_or_eq(circuit, R, A, flag_compare, AUX)
    circuit.x(flag_compare)

    return circuit


In [253]:
# 1.7 Check add_mod
# (3 + 5) mod 7 = 1, so R should read '0001'. All aux qubits are measured too,
# to show they return to |0>.
n = 4

aux_len = (2*n + 2)
n_qubits = 4*n + aux_len
n_classical = n + aux_len

circuit = QuantumCircuit(n_qubits, n_classical)

# Registers (index 0 of each register is the least-significant bit)
A = list(range(0, n))
B = list(range(n, 2*n))
N = list(range(2*n, 3*n))
R = list(range(3*n, 4*n))
aux = list(range(4*n, 4*n + aux_len))


# Let A = 3 (0011)
circuit.x(A[0])
circuit.x(A[1])

# Let B = 5 (0101)
circuit.x(B[0])
circuit.x(B[2])

# Let N = 7 (0111)
circuit.x(N[0])
circuit.x(N[1])
circuit.x(N[2])


# Run add mod
add_mod(circuit, N, A, B, R, aux)

# Results: R -> classical bits 0..n-1, aux -> classical bits n..n+aux_len-1
for i in range(n):
    circuit.measure(R[i], i)

for j in range(aux_len):
    circuit.measure(aux[j], n + j)

# 26 qubits is too many for BasicSimulator, so use Aer instead
simulator = AerSimulator()
transpiled_circuit = transpile(circuit, simulator)
result = simulator.run(transpiled_circuit).result()
counts = result.get_counts()

print(counts)

# The input is a basis state and every gate is classical (X/CX/CCX/MCX), so
# one outcome is expected; take it and split it (classical bit 0 is rightmost)
bitstring = next(iter(counts))
r_bits = bitstring[-n:]          # last 4 characters: R
aux_bits = bitstring[:-n]        # first 10 characters: aux
print("aux:", aux_bits, "R:", r_bits)


{'00000000000001': 1024}
aux: 0000000000 R: 0001


# 1.8 Multiplication by Two Modulo N
The function `times_two_mod(circuit,N,A,R,AUX)` implements a circuit that doubles `number(A)` modulo `number(N)`

Result is stored in register `R`

**HINT** copy the register `A` and then compute `number(A) + number(A)`  modulo `number(N)`

In [254]:
def times_two_mod(circuit, N, A, R, AUX):
    """Modular doubling: R <- R ^ ((2 * number(A)) mod number(N)).

    ``add_mod`` needs two distinct operand registers. So ``A`` is first
    copied into scratch ``temp_A = AUX[:n]``, then
    ``add_mod(N, A, temp_A, R, AUX[n:])`` computes A + A mod N, and finally
    the copy is undone so ``temp_A`` is back in |0>.

    Preconditions: number(A) < number(N) (required by ``add_mod``) and every
    AUX qubit starts in |0>. AUX ends in |0> again.

    Args:
        circuit: circuit to append gates to.
        N: modulus register (n qubits, unchanged).
        A: input register (n qubits, unchanged).
        R: output register (n qubits), XORed with the result.
        AUX: at least 3n + 2 qubits (n for ``temp_A``, the rest handed to
            ``add_mod``), matching the 3n + 2 that ``times_two_power_mod``
            reserves for this function.

    Returns:
        ``circuit``.

    Raises:
        ValueError: if ``len(N)`` or ``len(R)`` differs from ``len(A)``, or
            AUX is too short. This is checked before any gate is appended.
    """
    n = len(A)
    if len(N) != n or len(R) != n:
        raise ValueError("All registers must have the same length")
    # Checked here so that no copy gates are appended if the call would
    # fail part-way through
    if len(AUX) < 3 * n + 2:
        raise ValueError("AUX must have at least 3n + 2 qubits")

    temp_A = AUX[:n]
    add_aux = AUX[n:]

    # temp_A <- temp_A ^ A  (= A, since temp_A starts in |0>)
    copy(circuit, A, temp_A)

    # R <- R ^ ((A + temp_A) mod N)  ==  R ^ (2A mod N)
    add_mod(circuit, N, A, temp_A, R, add_aux)

    # temp_A <- temp_A ^ A  (= 0)
    copy(circuit, A, temp_A)

    return circuit

In [255]:
# 1.8 Check times_two_mod
# (2 * 3) mod 15 = 6, so R should read '0110'. All AUX qubits are measured too,
# to show they return to |0>.
n = 4

aux_len = (3*n + 2)
n_qubits = 3*n + aux_len
n_classical = n + aux_len

circuit = QuantumCircuit(n_qubits, n_classical)

# Registers (index 0 of each register is the least-significant bit)
A = list(range(0, n))
N = list(range(n, 2*n))
R = list(range(2*n, 3*n))
AUX = list(range(3*n, 3*n + aux_len))


# Let A = 3 (0011)
circuit.x(A[0])
circuit.x(A[1])

# Alternative modulus N = 7 (0111), kept for reference:
# circuit.x(N[0])
# circuit.x(N[1])
# circuit.x(N[2])

# Let N = 15 (1111)
for i in range(n):
    circuit.x(N[i])


# Run times_two_mod: R <- 2 * A mod N
times_two_mod(circuit, N, A, R, AUX)

# Results: R -> classical bits 0..n-1, AUX -> classical bits n..n+aux_len-1
for i in range(n):
    circuit.measure(R[i], i)

for j in range(aux_len):
    circuit.measure(AUX[j], n + j)

# 26 qubits is too many for BasicSimulator, so use Aer instead
simulator = AerSimulator()
transpiled_circuit = transpile(circuit, simulator)
result = simulator.run(transpiled_circuit).result()
counts = result.get_counts()

print(counts)

# Single expected outcome; classical bit 0 is the rightmost character
bitstring = next(iter(counts))
r_bits = bitstring[-n:]          # last 4 characters: R
aux_bits = bitstring[:-n]        # first 14 characters: AUX
print("aux:", aux_bits, "R:", r_bits)

{'000000000000000110': 1024}
aux: 00000000000000 R: 0110


# 1.9 Multiplication by a Power of Two Modulo N
The function `times_two_power_mod(circuit,N,A,k,R,AUX)` implements a circuit that multiplies `number(A)` by 2**k modulo `number(N)`.

**HINT**: apply the function `times_two_mod(circuit,N,A,R,AUX)` k times in a row

In [None]:
# Original
# def times_two_power_mod(circuit, N, A, k, R, AUX):

#     # Sanity checks
#     if k < 0: raise ValueError("k must be positive")
#     if not isinstance(k, int): raise TypeError("k must be an integer")

#     # times_two_mod requires aux length 3n+2 & need k+1 registers inside AUX to store X0 thru Xk
#     n = len(A)
#     chain_len = (k + 1) * n
#     aux_two_len = 3*n + 2

#     # Split AUX into chain registers and scrach space used by times_two_mod
#     chain_bits = AUX[:chain_len]
#     aux_two_mod = AUX[chain_len : chain_len + aux_two_len]

#     # Build list of registers X_reg[0...k] each length n
#     X_regs = []
#     for i in range(k + 1):
#         X_regs.append(chain_bits[i*n : (i+1)*n])

#     # X0 <- A
#     copy(circuit, A, X_regs[0])

#     # Forward: k times_two_mod calls in a row
#     for i in range(k):
#         times_two_mod(circuit, N, X_regs[i], X_regs[i + 1], aux_two_mod)

#     # Put final value in R
#     # R <- R ^ Xk
#     copy(circuit, X_regs[k], R)

#     # Backward to uncompute Xk through X 1 back to 0
#     for i in range(k - 1, -1, -1):
#         times_two_mod(circuit, N, X_regs[i], X_regs[i + 1], aux_two_mod)

#     # Set X0 back to 0
#     copy(circuit, A, X_regs[0])
#     return circuit

In [264]:
def times_two_power_mod(circuit, N, A, k, R, AUX):
    """Multiply by a power of two modulo N: R <- R ^ ((2**k * number(A)) mod number(N)).

    Builds the chain X_0 = A, X_{i+1} = 2 * X_i mod N in scratch registers by
    applying ``times_two_mod`` k times in a row, then XORs X_k into R.
    Afterwards the k steps are run in reverse using the inverse instruction,
    and X_0 is un-copied, so every AUX qubit returns to |0>.

    ``times_two_mod`` is built once on a local sub-circuit and appended as a
    custom instruction so that its inverse is available for uncomputation.

    AUX layout (n = len(A)):
        AUX[0 : (k+1)n]                X_0, X_1, ..., X_k (n qubits each)
        AUX[(k+1)n : (k+1)n + 3n + 2]  scratch for times_two_mod

    Args:
        circuit: circuit to append gates to.
        N: modulus register (n qubits, unchanged).
        A: input register (n qubits, unchanged); number(A) < number(N).
        k: non-negative integer exponent (Python or NumPy integer).
        R: output register (n qubits), XORed with the result.
        AUX: at least (k + 1) * n + 3n + 2 qubits in |0>.

    Returns:
        ``circuit``.

    Raises:
        TypeError: if ``k`` is not an integer.
        ValueError: if ``k`` is negative, if ``N`` or ``R`` differ in length
            from ``A``, or if ``AUX`` is too short. All checks run before any
            gate is appended.
    """
    # Type check first: the old order evaluated `k < 0` before it
    if not isinstance(k, (int, np.integer)):
        raise TypeError("k must be an integer")
    k = int(k)
    if k < 0:
        raise ValueError("k must be non-negative")

    n = len(A)
    if len(N) != n or len(R) != n:
        raise ValueError("All registers must have the same length")

    chain_len = (k + 1) * n
    aux_two_len = 3 * n + 2
    if len(AUX) < chain_len + aux_two_len:
        raise ValueError("AUX must have at least (k + 1) * n + 3n + 2 qubits")

    # Plain lists so the qubit lists below can be concatenated with `+`
    n_qubits = list(N)
    chain_bits = list(AUX[:chain_len])
    aux_two_mod = list(AUX[chain_len : chain_len + aux_two_len])
    X_regs = [chain_bits[i * n : (i + 1) * n] for i in range(k + 1)]

    # X_0 <- A
    copy(circuit, A, X_regs[0])

    # Build times_two_mod once on the local layout [ N | X_in | X_out | scratch ]
    sub = QuantumCircuit(3 * n + aux_two_len)
    times_two_mod(
        sub,
        list(range(0, n)),
        list(range(n, 2 * n)),
        list(range(2 * n, 3 * n)),
        list(range(3 * n, 3 * n + aux_two_len)),
    )
    T2 = sub.to_instruction(label="times_two_mod")
    T2_inv = T2.inverse()  # loop-invariant, so build it once

    # Forward: X_{i+1} <- X_{i+1} ^ (2 * X_i mod N)
    for i in range(k):
        circuit.append(T2, n_qubits + X_regs[i] + X_regs[i + 1] + aux_two_mod)

    # R <- R ^ X_k
    copy(circuit, X_regs[k], R)

    # Backward: clear X_k, ..., X_1
    for i in range(k - 1, -1, -1):
        circuit.append(T2_inv, n_qubits + X_regs[i] + X_regs[i + 1] + aux_two_mod)

    # X_0 <- X_0 ^ A  (= 0)
    copy(circuit, A, X_regs[0])
    return circuit

In [265]:
# 1.9 Check times_two_power_mod
# (2**2 * 3) mod 7 = 12 mod 7 = 5, so R should read '0101'. All AUX qubits
# are measured too, to show they return to |0>.
from qiskit.transpiler import CouplingMap

n = 4
k = 2

# (k + 1) chain registers of n qubits + 3n + 2 scratch qubits for times_two_mod
aux_len = (k + 1) * n + (3*n + 2)
n_qubits = 3*n + aux_len
n_classical = n + aux_len

circuit = QuantumCircuit(n_qubits, n_classical)

# Registers (index 0 of each register is the least-significant bit)
A = list(range(0, n))
N = list(range(n, 2*n))
R = list(range(2*n, 3*n))
AUX = list(range(3*n, 3*n + aux_len))


# Let A = 3 (0011)
circuit.x(A[0])
circuit.x(A[1])

# Let N = 7 (0111)
circuit.x(N[0])
circuit.x(N[1])
circuit.x(N[2])


# Run times_two_power_mod()
times_two_power_mod(circuit, N, A, k, R, AUX)

# Results: R -> classical bits 0..n-1, AUX -> classical bits n..n+aux_len-1
for i in range(n):
    circuit.measure(R[i], i)

for j in range(aux_len):
    circuit.measure(AUX[j], n + j)

# 38 qubits is too many for a dense simulation. With only X/CX/CCX/MCX gates
# acting on a basis-state input, the state stays a product state, which the
# matrix_product_state method handles cheaply.
simulator = AerSimulator(method="matrix_product_state")

nq = circuit.num_qubits
coupling_map = CouplingMap.from_full(nq)

# Transpile to a u/cx basis; this also unrolls the custom times_two_mod
# instructions and their inverses
transpiled_circuit = transpile(circuit, basis_gates=["u", "cx"], coupling_map=coupling_map, optimization_level=0)
result = simulator.run(transpiled_circuit, shots=1024).result()
counts = result.get_counts()
print(counts)

# Single expected outcome; classical bit 0 is the rightmost character
bitstring = next(iter(counts))
r_bits = bitstring[-n:]          # last 4 characters: R
aux_bits = bitstring[:-n]        # first 26 characters: AUX
print("aux:", aux_bits, "R:", r_bits)

{'000000000000000000000000000101': 1024}
aux: 00000000000000000000000000 R: 0101
