In [2]:
import pennylane as qml
from pennylane import numpy as np

# --- Compute modular multiplication permutation ---
def modexp_perm(m, modulus=15, base=2, nbits=4):
    return {a: (pow(base, m, modulus) * a) % modulus for a in range(2**nbits)}

def cycles_from_perm(perm):
    visited = set()
    cycles = []
    for start in range(len(perm)):
        if start in visited or perm[start] == start:
            continue
        cycle = []
        current = start
        while current not in visited:
            visited.add(current)
            cycle.append(current)
            current = perm[current]
        if len(cycle) > 1:
            cycles.append(cycle)
    return cycles

# --- Decompose multi-controlled X using ancillas ---
def multi_controlled_x(ctrls, target, ancillas):
    n = len(ctrls)
    if n == 1:
        qml.CNOT(wires=[ctrls[0], target])
    elif n == 2:
        qml.Toffoli(wires=[ctrls[0], ctrls[1], target])
    else:
        qml.Toffoli(wires=[ctrls[0], ctrls[1], ancillas[0]])
        for i in range(2, n - 1):
            qml.Toffoli(wires=[ctrls[i], ancillas[i - 2], ancillas[i - 1]])
        qml.Toffoli(wires=[ctrls[n - 1], ancillas[n - 3], target])
        for i in reversed(range(2, n - 1)):
            qml.Toffoli(wires=[ctrls[i], ancillas[i - 2], ancillas[i - 1]])
        qml.Toffoli(wires=[ctrls[0], ctrls[1], ancillas[0]])

# --- Controlled basis swap via explicit gates ---
def controlled_basis_swap_no_ctrl(i, j, target_wires, control_wire, ancillas):
    n = len(target_wires)
    bi = [int(b) for b in format(i, f"0{n}b")]
    bj = [int(b) for b in format(j, f"0{n}b")]

    for idx in range(n):
        if bi[idx] != bj[idx]:
            ctrl_bits = [target_wires[k] for k in range(n) if k != idx]
            ctrl_vals = [bi[k] for k in range(n) if k != idx]

            for w, b in zip(ctrl_bits, ctrl_vals):
                if b == 0:
                    qml.PauliX(w)

            all_ctrls = [control_wire] + ctrl_bits
            multi_controlled_x(all_ctrls, target_wires[idx], ancillas)

            for w, b in zip(ctrl_bits, ctrl_vals):
                if b == 0:
                    qml.PauliX(w)

# --- Controlled U^m implementation using swaps ---
def controlled_modular_multiplication_no_ctrl(m, control_wire, target_wires, ancillas):
    perm = modexp_perm(m=m, modulus=15, base=2, nbits=len(target_wires))
    cycles = cycles_from_perm(perm)
    for cycle in cycles:
        for i in reversed(range(1, len(cycle))):
            controlled_basis_swap_no_ctrl(cycle[0], cycle[i], target_wires, control_wire, ancillas)
    qml.Barrier(wires=[control_wire] + target_wires + ancillas)


In [3]:
@qml.qnode(qml.device("default.qubit", wires=5))
def demo_circuit():
    qml.Hadamard(wires=0)  # control qubit
    qml.BasisState(np.array([0, 0, 0, 1]), wires=[1, 2, 3, 4])  # initial |1⟩
    controlled_modular_multiplication_no_ctrl(m=3, control_wire=0, target_wires=[1, 2, 3, 4])
    return qml.state()

In [4]:
def verify_controlled_modular_multiplication(m):
    wires = list(range(5))  # 1 control + 4 target

    dev = qml.device("default.qubit", wires=wires)

    @qml.qnode(dev)
    def test_circuit():
        # Prepare |1⟩ control
        qml.PauliX(wires=0)

        # Prepare |1⟩ target (|0001⟩)
        qml.BasisState(np.array([0, 0, 0, 1]), wires=[1, 2, 3, 4])

        # Apply controlled U^m
        controlled_modular_multiplication_no_ctrl(m, control_wire=0, target_wires=[1, 2, 3, 4])

        return qml.state()

    state = test_circuit()

    # Compute expected output basis index
    result = pow(2, m, 15)  # 2^m mod 15
    full_index = (1 << 4) + result  # control qubit in position 0 is 1 -> bit 1xxxx

    # Check maximum amplitude index
    probs = np.abs(state) ** 2
    max_index = int(np.argmax(probs))

    passed = max_index == full_index
    print(f"Test CU^{m}: expected index {full_index} (|1⟩|{result}⟩), got {max_index}")
    return passed


In [5]:
for m in range(4):
    assert verify_controlled_modular_multiplication(m)

NotImplementedError: Only supports up to 2 control wires