## Quantum Programming Project - Modular Exponentiation

Project-03

- Lea Jesenkovic
- Lars Herbold

In [None]:
from qiskit import QuantumCircuit
from logic_gates import LogicGatesExtension    

### MONKEY PATCHING ###
# Expose the helpers from logic_gates.LogicGatesExtension as QuantumCircuit
# methods so the arithmetic routines below can write e.g.
# circuit.and_gate(a, b, target). As they are used below, the last argument
# appears to be the target and, for the controlled variants, the first argument
# is the control qubit.
# NOTE(review): every "CLEANUP" section below re-applies the same gate to reset
# its target, which presumes each gate XORs its result into the target rather
# than overwriting it — confirm against logic_gates.
QuantumCircuit.and_gate = LogicGatesExtension.apply_and
QuantumCircuit.or_gate = LogicGatesExtension.apply_or
QuantumCircuit.xor_gate = LogicGatesExtension.apply_xor
QuantumCircuit.controlled_and_gate = LogicGatesExtension.apply_controlled_and
QuantumCircuit.controlled_or_gate = LogicGatesExtension.apply_controlled_or
QuantumCircuit.controlled_xor_gate = LogicGatesExtension.apply_controlled_xor

def set_bits(circuit, A, X):
    """XOR the classical bit string ``X`` into register ``A`` (loads X if A is |0...0>).

    ``X`` is read as a binary literal, most-significant bit first -- the format
    produced by ``format(value, f"0{n}b")``. Registers in this file are
    little-endian: the ripple-carry adders propagate carries from index 0
    upwards, so ``A[0]`` is the least-significant bit. The LAST character of
    ``X`` therefore goes to ``A[0]``.

    Only X gates are applied, so calling it a second time with the same string
    undoes the load.

    Args:
        circuit: circuit to append gates to.
        A: register (sequence of qubits), at least ``len(X)`` long.
        X: string of '0'/'1' characters, MSB first.
    """
    for i, bit in enumerate(reversed(X)):
        if bit == '1':
            circuit.x(A[i])

def copy(circuit, A, B):
    """XOR register ``A`` into register ``B``, one CNOT per bit.

    If ``B`` starts at |0...0> this copies ``A``; applying it again clears ``B``.
    Only the first ``len(A)`` qubits of ``B`` are touched.
    """
    for idx, src in enumerate(A):
        circuit.cx(src, B[idx])

def controlled_copy(circuit, c, A, B):
    """When control ``c`` is |1>, XOR register ``A`` into ``B`` (one Toffoli per bit)."""
    for idx, src in enumerate(A):
        circuit.ccx(c, src, B[idx])

def controlled_not(circuit, c, q):
    """Flip qubit ``q`` when control ``c`` is |1> (thin alias for ``circuit.cx``)."""
    apply_cnot = circuit.cx
    apply_cnot(c, q)


def full_adder(circuit, a, b, r, c_in, c_out, AUX):
    """One-bit full adder built from the patched logic gates.

    Assuming each gate XORs its result into its last argument, this performs
        r     ^= a XOR b XOR c_in
        c_out ^= (a AND b) OR ((a XOR b) AND c_in)
    and returns the three scratch qubits AUX[0], AUX[1], AUX[2] to their
    initial state.
    """
    partial, generate, propagate = AUX[0], AUX[1], AUX[2]

    # propagate = a XOR b
    circuit.xor_gate(a, b, propagate)
    # sum bit
    circuit.xor_gate(propagate, c_in, r)

    # carry = (propagate AND c_in) OR (a AND b)
    circuit.and_gate(propagate, c_in, partial)
    circuit.and_gate(a, b, generate)
    circuit.or_gate(partial, generate, c_out)

    # undo the scratch work in reverse order
    circuit.and_gate(a, b, generate)
    circuit.and_gate(propagate, c_in, partial)
    circuit.xor_gate(a, b, propagate)

def controlled_full_adder(circuit, c, a, b, r, c_in, c_out, AUX):
    """Full adder gated by control qubit ``c``.

    Assuming the patched gates XOR into their target, this performs
        r     ^= c AND (a XOR b XOR c_in)
        c_out ^= c AND ((a AND b) OR ((a XOR b) AND c_in))
    and leaves AUX[0], AUX[1], AUX[2] as they were. The intermediate
    a XOR b is computed without the control; it is undone at the end anyway.
    """
    partial, generate, propagate = AUX[0], AUX[1], AUX[2]

    circuit.xor_gate(a, b, propagate)

    # sum bit, gated by c: r ^= c & propagate, then r ^= c & c_in
    circuit.ccx(c, propagate, r)
    circuit.ccx(c, c_in, r)

    # gated carry
    circuit.controlled_and_gate(c, propagate, c_in, partial)
    circuit.controlled_and_gate(c, a, b, generate)
    circuit.controlled_or_gate(c, partial, generate, c_out)

    # undo the scratch work in reverse order
    circuit.controlled_and_gate(c, a, b, generate)
    circuit.controlled_and_gate(c, propagate, c_in, partial)
    circuit.xor_gate(a, b, propagate)


def add(circuit, A, B, R, AUX):
    """Ripple-carry addition: ``R ^= (A + B) mod 2**n`` with n = len(A).

    Registers are little-endian (index 0 is the least-significant bit).

    AUX layout:
        AUX[0:n+1]    carry chain. AUX[0] is the carry-in (never written here,
                      so it should be |0> for a plain addition); AUX[n] receives
                      the carry-out.
        AUX[n+1:n+4]  scratch for ``full_adder``.

    The intermediate carries AUX[1:n] are cleared again before returning. Only
    the final carry-out is left XOR-ed into AUX[n]. As a result, ``add`` is its
    own inverse: calling it twice with the same arguments restores R and AUX,
    which is how callers such as ``add_mod`` uncompute it.
    """
    n = len(A)
    carries = AUX[:n+1]
    fa_internal_aux = AUX[n+1:n+1+3]

    def uncompute_carry(i):
        # carries[i+1] ^= MAJ(A[i], B[i], carries[i]): the carry half of
        # full_adder, leaving R[i] untouched.
        u0, u1, t = fa_internal_aux[0], fa_internal_aux[1], fa_internal_aux[2]
        circuit.xor_gate(A[i], B[i], t)
        circuit.and_gate(t, carries[i], u0)
        circuit.and_gate(A[i], B[i], u1)
        circuit.or_gate(u0, u1, carries[i+1])
        circuit.and_gate(A[i], B[i], u1)
        circuit.and_gate(t, carries[i], u0)
        circuit.xor_gate(A[i], B[i], t)

    # Forward pass: sum bits into R, carries ripple into carries[1..n].
    for i in range(n):
        full_adder(
            circuit,
            A[i],
            B[i],
            R[i],
            carries[i],
            carries[i+1],
            fa_internal_aux
        )

    # Backward pass: clear carries[n-1..1]. Going top-down keeps carries[i]
    # intact while carries[i+1] is being cleared. carries[n] keeps the carry-out.
    for i in range(n - 2, -1, -1):
        uncompute_carry(i)

def controlled_add(circuit, c, A, B, R, AUX):
    """Controlled ripple-carry addition: if ``c`` is |1>, ``R ^= (A + B) mod 2**n``.

    n = len(A); registers are little-endian (index 0 = least-significant bit).

    AUX layout:
        AUX[0:n+1]    carry chain. AUX[0] is the carry-in (never written here);
                      AUX[n] receives the carry-out (gated by c).
        AUX[n+1:n+4]  scratch for ``controlled_full_adder``.

    Every carry update is gated by ``c``, so nothing changes when c is |0>.
    The intermediate carries AUX[1:n] are cleared again before returning, and
    only the carry-out stays XOR-ed into AUX[n]. The routine is therefore its
    own inverse.
    """
    n = len(A)
    carries = AUX[:n+1]
    fa_aux  = AUX[n+1:n+1+3]

    def uncompute_carry(i):
        # carries[i+1] ^= c AND MAJ(A[i], B[i], carries[i]): the carry half of
        # controlled_full_adder, leaving R[i] untouched.
        u0, u1, t = fa_aux[0], fa_aux[1], fa_aux[2]
        circuit.xor_gate(A[i], B[i], t)
        circuit.controlled_and_gate(c, t, carries[i], u0)
        circuit.controlled_and_gate(c, A[i], B[i], u1)
        circuit.controlled_or_gate(c, u0, u1, carries[i+1])
        circuit.controlled_and_gate(c, A[i], B[i], u1)
        circuit.controlled_and_gate(c, t, carries[i], u0)
        circuit.xor_gate(A[i], B[i], t)

    for i in range(n):
        controlled_full_adder(circuit, c, A[i], B[i], R[i], carries[i], carries[i+1], fa_aux)

    # Clear intermediate carries top-down so carries[i] is still intact when
    # carries[i+1] is cleared; carries[n] keeps the carry-out.
    for i in range(n - 2, -1, -1):
        uncompute_carry(i)


def subtract(circuit, A, B, R, AUX):
    """Ripple-carry subtraction: ``R ^= (A - B) mod 2**n``, computed as A + ~B + 1.

    n = len(A); registers are little-endian (index 0 = least-significant bit).
    B is temporarily inverted and the carry-in AUX[0] is temporarily flipped to 1.
    Both are restored before returning.

    AUX layout:
        AUX[0:n+1]    carry chain; AUX[n] receives the carry-out
        AUX[n+1:n+4]  scratch for ``full_adder``

    The intermediate carries AUX[1:n] are cleared again. The final carry-out is
    left XOR-ed into AUX[n]; for unsigned A and B it is 1 exactly when no borrow
    occurs, i.e. A >= B, and ``greater_or_eq`` reads it from there. Because only
    that bit is left behind, calling ``subtract`` twice with the same arguments
    restores R and AUX.
    """
    n = len(A)
    carries = AUX[:n+1]
    fa_internal_aux = AUX[n+1:n+1+3]

    def uncompute_carry(i):
        # carries[i+1] ^= MAJ(A[i], B[i], carries[i]) with B still inverted:
        # the carry half of full_adder, leaving R[i] untouched.
        u0, u1, t = fa_internal_aux[0], fa_internal_aux[1], fa_internal_aux[2]
        circuit.xor_gate(A[i], B[i], t)
        circuit.and_gate(t, carries[i], u0)
        circuit.and_gate(A[i], B[i], u1)
        circuit.or_gate(u0, u1, carries[i+1])
        circuit.and_gate(A[i], B[i], u1)
        circuit.and_gate(t, carries[i], u0)
        circuit.xor_gate(A[i], B[i], t)

    # Two's complement: invert B and set carry-in = 1.
    for b_qubit in B:
        circuit.x(b_qubit)

    circuit.x(carries[0])

    for i in range(n):
        full_adder(
            circuit,
            A[i],
            B[i],
            R[i],
            carries[i],
            carries[i+1],
            fa_internal_aux
        )

    # Clear intermediate carries top-down (carries[i] is still intact when
    # carries[i+1] is cleared); carries[n] keeps the carry-out.
    for i in range(n - 2, -1, -1):
        uncompute_carry(i)

    # CLEANUP: restore carry-in and B
    circuit.x(carries[0])
    for b_qubit in B:
        circuit.x(b_qubit)

def controlled_subtract(circuit, c, A, B, R, AUX):
    """Controlled subtraction: if ``c`` is |1>, ``R ^= (A - B) mod 2**n`` (A + ~B + 1).

    n = len(A); registers are little-endian (index 0 = least-significant bit).
    When c is |0> every step is a no-op: B is not inverted, the carry-in stays 0,
    and all adder stages are gated off.

    AUX layout:
        AUX[0:n+1]    carry chain; AUX[n] receives the carry-out (gated by c)
        AUX[n+1:n+4]  scratch for ``controlled_full_adder``

    The intermediate carries AUX[1:n] are cleared again, and only the carry-out
    stays XOR-ed into AUX[n]. The routine is therefore its own inverse.
    """
    n = len(A)
    carries = AUX[:n+1]
    fa_aux  = AUX[n+1:n+1+3]

    def uncompute_carry(i):
        # carries[i+1] ^= c AND MAJ(A[i], B[i], carries[i]) with B still
        # (conditionally) inverted, leaving R[i] untouched.
        u0, u1, t = fa_aux[0], fa_aux[1], fa_aux[2]
        circuit.xor_gate(A[i], B[i], t)
        circuit.controlled_and_gate(c, t, carries[i], u0)
        circuit.controlled_and_gate(c, A[i], B[i], u1)
        circuit.controlled_or_gate(c, u0, u1, carries[i+1])
        circuit.controlled_and_gate(c, A[i], B[i], u1)
        circuit.controlled_and_gate(c, t, carries[i], u0)
        circuit.xor_gate(A[i], B[i], t)

    # Controlled negate B (only if c=1)
    for qb in B:
        controlled_not(circuit, c, qb)

    # Controlled set carry-in = 1 (only if c=1)
    controlled_not(circuit, c, carries[0])

    for i in range(n):
        controlled_full_adder(circuit, c, A[i], B[i], R[i], carries[i], carries[i+1], fa_aux)

    # Clear intermediate carries top-down; carries[n] keeps the carry-out.
    for i in range(n - 2, -1, -1):
        uncompute_carry(i)

    # Cleanup: restore carry-in and B
    controlled_not(circuit, c, carries[0])
    for qb in B:
        controlled_not(circuit, c, qb)

def greater_or_eq(circuit, A, B, r, AUX):
    """Flip ``r`` iff A >= B (unsigned), leaving A, B and AUX as they were.

    The difference A - B is computed into a scratch register. The subtraction's
    final carry AUX[n] (1 exactly when A + ~B + 1 produces no borrow, i.e.
    A >= B) is CNOT-ed onto r. The same subtraction is then run again to
    uncompute, which relies on ``subtract`` being its own inverse.

    AUX layout: AUX[0:n+4] is used by ``subtract``; AUX[n+4:2n+4] holds the
    scratch difference.
    """
    width = len(A)
    scratch_diff = AUX[width + 4 : 2 * width + 4]

    # compute A - B, carry-out lands in AUX[width]
    subtract(circuit, A, B, scratch_diff, AUX)

    # r ^= carry-out
    circuit.cx(AUX[width], r)

    # undo the subtraction
    subtract(circuit, A, B, scratch_diff, AUX)

def controlled_greater_or_eq(circuit, c, A, B, r, AUX):
    """Flip ``r`` iff control ``c`` is |1> and A >= B.

    AUX[0] temporarily holds the unconditional comparison result. AUX[1:] is
    handed to ``greater_or_eq`` as its scratch space.
    """
    flag, scratch = AUX[0], AUX[1:]

    greater_or_eq(circuit, A, B, flag, scratch)   # flag = (A >= B)
    circuit.ccx(c, flag, r)                       # r ^= c AND flag
    greater_or_eq(circuit, A, B, flag, scratch)   # flag back to its start value

# TEST greater_or_eq
# qc = QuantumCircuit(4 + 40)
# A = [0,1]
# B = [2,3]
# r = 4
# AUX = list(range(5, 45))

# set_bits(qc, A, "10")
# set_bits(qc, B, "01")

# greater_or_eq(qc, A, B, r, AUX)
# print(qc.draw(fold=120))

def add_mod(circuit, N, A, B, R, AUX):
    """XOR (A + B) mod N into R, using and then cleaning up scratch qubits in AUX.

    n = len(A). Registers are indexed like the adders in this file (index 0 is
    the bit the carry chain starts from). Steps:
        SUM = A + B, flag = (SUM >= N), DIFF = SUM - N,
        R ^= DIFF if flag else SUM,
        then uncompute DIFF, flag and SUM in reverse order.

    NOTE(review): a single conditional subtraction of N only gives
    (A + B) mod N when A, B < N — callers are presumably expected to ensure it.

    AUX layout:
        AUX[0:n]      SUM
        AUX[n:2n]     DIFF
        AUX[2n]       flag
        AUX[2n+1:]    scratch shared by add / subtract / greater_or_eq
                      (greater_or_eq touches aux_rest[0:2n+4])

    NOTE(review): add() writes its final carry-out into aux_rest[n] (its
    carries[n]). greater_or_eq() copies its comparison result out of that same
    qubit (its AUX[n]). The flag is therefore XOR-ed with the overflow of A + B.
    For A, B < N an overflowing sum always leaves SUM < N, so the XOR equals
    (A + B >= N). This sharing is what makes sums >= 2**n come out right.
    Presumably this only works if add/subtract leave nothing else behind in
    aux_rest and are their own inverses (the uncompute steps below simply call
    them a second time). Keep the shared scratch if refactoring, and verify
    those contracts.
    """

    n = len(A)

    SUM  = AUX[0:n]           # will hold S = A + B (low n bits)
    DIFF = AUX[n:2*n]         # will hold D = S - N
    flag = AUX[2*n]           # 1 if A + B >= N else 0 (see docstring note)
    aux_rest = AUX[2*n+1:]    # scratch for add/sub/greater_or_eq

    # helpers: controlled copy using Toffoli (XORs SRC into DST when ctrl is 1).
    # Shadows the module-level controlled_copy, whose signature also takes circuit.
    def controlled_copy(ctrl, SRC, DST):
        for i in range(len(SRC)):
            circuit.ccx(ctrl, SRC[i], DST[i])

    # SUM = A + B
    add(circuit, A, B, SUM, aux_rest)

    # flag = (SUM >= N)
    greater_or_eq(circuit, SUM, N, flag, aux_rest)

    # DIFF = SUM - N
    subtract(circuit, SUM, N, DIFF, aux_rest)

    # Write result to R:
    # if flag==0 -> R ^= SUM
    # if flag==1 -> R ^= DIFF
    
    circuit.x(flag)                 # invert so the next copy fires when flag==0
    controlled_copy(flag, SUM, R)  
    circuit.x(flag)                 # restore flag

    controlled_copy(flag, DIFF, R)

    # Uncompute DIFF back to 0
    subtract(circuit, SUM, N, DIFF, aux_rest)

    # Uncompute flag back to 0
    greater_or_eq(circuit, SUM, N, flag, aux_rest)

    # Uncompute SUM back to 0
    add(circuit, A, B, SUM, aux_rest)

def controlled_add_mod(circuit, c, N, A, B, R, AUX):
    """If control ``c`` is |1>, XOR (A + B) mod N into R.

    The modular sum is computed unconditionally into AUX[:n] (n = len(A)),
    copied into R under control of ``c``, and removed again by a second
    ``add_mod`` call. AUX[n:] is scratch for ``add_mod``.
    """
    width = len(A)
    staged, workspace = AUX[:width], AUX[width:]

    add_mod(circuit, N, A, B, staged, workspace)
    controlled_copy(circuit, c, staged, R)
    add_mod(circuit, N, A, B, staged, workspace)

def times_two_mod(circuit, N, A, R, AUX):
    """XOR (2 * A) mod N into R, computed as ``add_mod`` of A with a copy of A.

    AUX[:n] temporarily holds the copy of A (n = len(A)) and is cleared again
    before returning. AUX[n:] is scratch for ``add_mod``.
    NOTE(review): inherits add_mod's precondition, presumably A < N.
    """
    width = len(A)
    duplicate, workspace = AUX[:width], AUX[width:]

    copy(circuit, A, duplicate)
    add_mod(circuit, N, A, duplicate, R, workspace)
    copy(circuit, A, duplicate)

def controlled_times_two_mod(circuit, c, N, A, R, AUX):
    """If control ``c`` is |1>, XOR (2 * A) mod N into R.

    The doubled value is computed unconditionally into AUX[:n] (n = len(A)),
    copied into R under control of ``c``, then uncomputed. AUX[n:] is scratch.
    """
    width = len(A)
    staged, workspace = AUX[:width], AUX[width:]

    times_two_mod(circuit, N, A, staged, workspace)
    controlled_copy(circuit, c, staged, R)
    times_two_mod(circuit, N, A, staged, workspace)

# TEST times_two_mod
# qc = QuantumCircuit(8 + 30)

# N   = [0, 1]
# A   = [2, 3]
# R   = [4, 5]
# AUX = list(range(6, 36))

# set_bits(qc, N, "11")   # N = 3
# set_bits(qc, A, "10")   # A = 2

# times_two_mod(qc, N, A, R, AUX)

# print(qc.draw())

def times_two_power_mod(circuit, N, A, k, R, AUX):
    """XOR (2**k * A) mod N into R using k successive modular doublings.

    AUX[:(k+1)*n] (n = len(A)) is split into k+1 registers holding
    A, 2A mod N, ..., 2**k A mod N. The last one is copied into R, and the
    whole chain is then uncomputed in reverse order. AUX[(k+1)*n:] is scratch
    for ``times_two_mod``.
    """
    width = len(A)
    chain_len = (k + 1) * width
    chain_qubits = AUX[:chain_len]
    workspace = AUX[chain_len:]

    stages = [chain_qubits[j * width:(j + 1) * width] for j in range(k + 1)]

    # build the doubling chain
    copy(circuit, A, stages[0])
    for j in range(1, k + 1):
        times_two_mod(circuit, N, stages[j - 1], stages[j], workspace)

    copy(circuit, stages[k], R)

    # tear it down again, last stage first
    for j in reversed(range(1, k + 1)):
        times_two_mod(circuit, N, stages[j - 1], stages[j], workspace)
    copy(circuit, A, stages[0])

def controlled_times_two_power_mod(circuit, c, N, A, k, R, AUX):
    """If control ``c`` is |1>, XOR (2**k * A) mod N into R.

    The value is computed unconditionally into AUX[:n] (n = len(A)), copied
    into R under control of ``c``, then uncomputed. AUX[n:] is scratch.
    """
    width = len(A)
    staged, workspace = AUX[:width], AUX[width:]

    times_two_power_mod(circuit, N, A, k, staged, workspace)
    controlled_copy(circuit, c, staged, R)
    times_two_power_mod(circuit, N, A, k, staged, workspace)


# TEST times_two_power_mod
# qc = QuantumCircuit(8 + 60)   # 8 data qubits + lots of AUX

# N   = [0, 1]
# A   = [2, 3]
# R   = [4, 5]
# AUX = list(range(6, 68))      # 62 AUX qubits

# set_bits(qc, N, "11")   # N = 3
# set_bits(qc, A, "01")   # A = 1

# times_two_power_mod(qc, N, A, k=2, R=R, AUX=AUX)

# print(qc.draw(fold=120))

def multiply_mod(circuit, N, A, B, R, AUX):
    """XOR (A * B) mod N into R by shift-and-add over the bits of B.

    For each bit B[j] the term T = B[j] * 2**j * A mod N is formed with
    ``controlled_times_two_power_mod``. Partial sums S[j+1] = S[j] + T mod N
    are accumulated, with S[0] left untouched at its initial |0...0>. S[n] is
    copied into R, and every partial sum is then uncomputed in reverse order.

    AUX layout (n = len(A)): n+1 registers of n qubits for S[0..n], then n
    qubits for T, then scratch.

    Raises:
        ValueError: if AUX cannot hold the S and T registers (the scratch
            requirement of the helpers is not checked here).
    """
    width = len(A)
    sums_len = (width + 1) * width
    needed = sums_len + width
    if len(AUX) < needed:
        raise ValueError(f"multiply_mod: AUX too small, need at least {needed}, got {len(AUX)}")

    sum_qubits = AUX[:sums_len]
    term = AUX[sums_len:sums_len + width]
    workspace = AUX[sums_len + width:]

    partial = [sum_qubits[j * width:(j + 1) * width] for j in range(width + 1)]

    def accumulate(j):
        # S[j+1] ^= (S[j] + B[j] * 2**j * A) mod N; T is cleared again afterwards.
        controlled_times_two_power_mod(circuit, B[j], N, A, j, term, workspace)
        add_mod(circuit, N, partial[j], term, partial[j + 1], workspace)
        controlled_times_two_power_mod(circuit, B[j], N, A, j, term, workspace)

    for j in range(width):
        accumulate(j)

    copy(circuit, partial[width], R)

    # the same step applied again removes S[j+1]; go from the last sum down
    for j in reversed(range(width)):
        accumulate(j)


def multiply_mod_fixed(circuit, N, X, B, AUX, N_int):
    """In place: B <- (X * B) mod N for a classical constant X.

    The constants X mod N and X^-1 mod N are loaded into ancilla registers. The
    product is computed out of place into R and swapped into B. The old B, now
    in R, is then erased by XOR-ing X^-1 * (new B) mod N into R: for B < N this
    equals the old B. Finally the constants are unloaded.

    Args:
        circuit: circuit to append gates to.
        N: quantum register holding the modulus.
        X: classical multiplier; must be invertible modulo N_int.
        B: register to multiply in place (n = len(B) qubits, index 0 = LSB).
        AUX: ancillas, expected |0>. AUX[0:n] = X mod N, AUX[n:2n] = X^-1 mod N,
            AUX[2n:3n] = product register R, AUX[3n:] = scratch for multiply_mod.
        N_int: the modulus as a Python int (should match the value held in N).

    Raises:
        ValueError: if X has no inverse modulo N_int (raised by ``pow``), or if
            X mod N or X^-1 mod N does not fit into n qubits.
    """
    n = len(B)
    X_inv = pow(X, -1, N_int)
    X_mod = X % N_int
    for label, value in (("X mod N", X_mod), ("X^-1 mod N", X_inv)):
        if value >> n:
            raise ValueError(
                f"multiply_mod_fixed: {label} = {value} does not fit in {n} qubits"
            )

    A1 = AUX[0:n]
    A2 = AUX[n:2*n]
    R  = AUX[2*n:3*n]
    scratch = AUX[3*n:]

    def load_constant(reg, value):
        # XOR a classical constant into reg, little-endian (reg[0] = LSB) to
        # match the carry direction of the adders. This is done bit-by-bit
        # rather than via an MSB-first format() string so the bit order is
        # explicit. Applying it again with the same value unloads the constant.
        for i in range(n):
            if (value >> i) & 1:
                circuit.x(reg[i])

    def swap_regs(Xreg, Yreg):
        for i in range(n):
            circuit.cx(Xreg[i], Yreg[i])
            circuit.cx(Yreg[i], Xreg[i])
            circuit.cx(Xreg[i], Yreg[i])

    # Load constants into A1 and A2 (un-loaded again at the end)
    load_constant(A1, X_mod)
    load_constant(A2, X_inv)

    # R ^= (X * B) mod N   (R starts |0>, so R becomes the product)
    multiply_mod(circuit, N, A1, B, R, scratch)

    # Swap: put product into B (in place); R now holds the old B
    swap_regs(B, R)

    # Erase old B from R: R ^= X^-1 * (X * B) mod N = old B
    multiply_mod(circuit, N, A2, B, R, scratch)

    # Unload constants (return A1 and A2 to |0>)
    load_constant(A2, X_inv)
    load_constant(A1, X_mod)

# TEST multiply_mod_fixed
# qc = QuantumCircuit(2 + 2 + 80)  # example n=2 with lots of AUX
# N   = [0,1]
# B   = [2,3]
# AUX = list(range(4, 4+80))

# set_bits(qc, N, "11")  # N=3
# set_bits(qc, B, "10")  # B=2

# multiply_mod_fixed(qc, N, X=2, B=B, AUX=AUX, N_int=3)

# print(qc.draw(fold=120))

def multiply_mod_fixed_power_2_k(circuit, N, X, B, AUX, k, N_int):
    """In place: B <- (X**(2**k) * B) mod N, with the constant computed classically.

    X**(2**k) mod N is obtained by squaring X mod N k times. The circuit
    therefore contains a single fixed-constant multiplication
    (``multiply_mod_fixed``).
    """
    factor = X % N_int
    for _ in range(k):
        factor = factor * factor % N_int

    multiply_mod_fixed(circuit, N, factor, B, AUX, N_int)

def multiply_mod_fixed_power_Y(circuit, N, X, B, AUX, Y, N_int):
    """In place: B <- (X**Y * B) mod N, where Y is a quantum exponent register.

    Y is little-endian (Y[0] = least-significant bit). For bit Y[k] the constant
    W_k = X**(2**k) mod N is computed classically. B is multiplied by W_k only
    when Y[k] is |1>, using a temporary register B_tmp that starts at |0>:

        1. controlled swap B <-> B_tmp : if Y[k], B_tmp = B and B = 0
        2. B_tmp <- W_k * B_tmp mod N   : 0 stays 0 when Y[k] is |0>
        3. controlled swap B <-> B_tmp : product moves into B, B_tmp = 0 again

    B_tmp is therefore clean after every bit, so no garbage stays entangled
    with Y.

    Args:
        circuit: circuit to append gates to.
        N: quantum register holding the modulus.
        X: classical base; must be invertible modulo N_int.
        B: register multiplied in place (m = len(B) qubits).
        AUX: ancillas, expected |0>. AUX[:m] is B_tmp; AUX[m:] is workspace
            for ``multiply_mod_fixed``.
        Y: exponent register (any length, independent of len(B)).
        N_int: the modulus as a Python int.
    """
    m = len(B)
    B_tmp = AUX[:m]
    work = AUX[m:]

    def controlled_swap(ctrl, Xreg, Yreg):
        # Fredkin gate per bit built from three Toffolis.
        for i in range(m):
            circuit.ccx(ctrl, Xreg[i], Yreg[i])
            circuit.ccx(ctrl, Yreg[i], Xreg[i])
            circuit.ccx(ctrl, Xreg[i], Yreg[i])

    W = X % N_int

    for y_bit in Y:
        # move B out of the way only when this exponent bit is set
        controlled_swap(y_bit, B, B_tmp)

        # multiply the (possibly zero) temporary in place by W
        multiply_mod_fixed(circuit, N, W, B_tmp, work, N_int)

        # move the product back into B; B_tmp returns to |0>
        controlled_swap(y_bit, B, B_tmp)

        W = (W * W) % N_int













      ┌───┐┌───┐                                                                                                    »
 q_0: ┤ X ├┤ X ├────────────────────────────────────────────────────────────────────────────────────────────────────»
      ├───┤├───┤                                                                                                    »
 q_1: ┤ X ├┤ X ├────────────────────────────────────────────────────────────────────────────────────────────────────»
      ├───┤└───┘                                                                                                    »
 q_2: ┤ X ├────────────────────────────────■─────────■──────────────────────────────────────────────────────────────»
      └───┘                                │         │                                                              »
 q_3: ─────────────────────────────────────┼─────────┼──────────────────────────────────────────────────────────────»
      ┌───┐                                │         │  