In [44]:
from qiskit import QuantumCircuit, transpile
from qiskit_aer import Aer
from math import pi
import numpy as np
from scipy.linalg import sqrtm, fractional_matrix_power
from scipy.linalg import fractional_matrix_power
from qiskit.quantum_info import Operator
from qiskit.circuit.library import UnitaryGate, RXGate

def zyz_decompose(U: np.ndarray):
    """Decompose a 2x2 unitary as ``U = exp(i*alpha) * Rz(beta) @ Ry(gamma) @ Rz(delta)``.

    The conventions are ``Rz(t) = diag(exp(-i*t/2), exp(i*t/2))`` and
    ``Ry(t) = [[cos(t/2), -sin(t/2)], [sin(t/2), cos(t/2)]]``.

    The returned angles rebuild ``U`` *exactly*, including its overall sign.
    That matters for callers that build a controlled version of ``U``: there
    a sign error on the target becomes a relative phase on the control.

    Args:
        U: A 2x2 unitary matrix (anything ``np.asarray`` can convert).

    Returns:
        ``(alpha, beta, gamma, delta)`` as Python floats, with
        ``alpha`` in ``(-pi/2, pi/2]``, ``gamma`` in ``[0, pi]`` and
        ``beta``/``delta`` in ``(-2*pi, 2*pi]``.

    Raises:
        ValueError: If ``U`` is not of shape ``(2, 2)``.
    """
    U = np.asarray(U, dtype=complex)
    if U.shape != (2, 2):
        raise ValueError(f"expected a 2x2 matrix, got shape {U.shape}")

    # Strip the phase so that V has determinant 1 (V is in SU(2)).
    alpha = np.angle(np.linalg.det(U)) / 2.0
    V = U * np.exp(-1j * alpha)

    # For V in SU(2):
    #   V[1, 0] = exp(+i*(beta - delta)/2) * sin(gamma/2)
    #   V[1, 1] = exp(+i*(beta + delta)/2) * cos(gamma/2)
    # and the first row is fixed by V[0, 0] = conj(V[1, 1]), V[0, 1] = -conj(V[1, 0]).
    c, d = V[1, 0], V[1, 1]
    gamma = 2 * np.arctan2(np.abs(c), np.abs(d))

    # Read the *half*-angles straight from the entry phases.  Deriving them
    # from products such as a*conj(d) only fixes beta+delta modulo 2*pi, which
    # leaves a sign ambiguity in exp(i*(beta+delta)/2) and yields a wrong matrix
    # for inputs like the Hadamard gate or RX with a negative angle.
    # When an entry is zero its phase is arbitrary, but it is multiplied by a
    # zero magnitude in the reconstruction, so the choice is harmless.
    half_sum = np.angle(d)
    half_diff = np.angle(c)
    beta = half_sum + half_diff
    delta = half_sum - half_diff

    # Deliberately no wrapping into [-pi, pi): Rz and Ry are 4*pi-periodic, so
    # shifting an angle by 2*pi would flip the sign of the rebuilt matrix.
    return float(alpha), float(beta), float(gamma), float(delta)

def controlled_U(U) -> QuantumCircuit:
    """Build a two-qubit circuit for ``U`` on qubit 1, controlled by qubit 0.

    The angles come from ``zyz_decompose(U)``.  The target receives three
    rotation groups separated by two CNOTs; the Rz angles of the three groups
    sum to zero and the two Ry angles cancel, so without the CNOTs the target
    rotations multiply to the identity.  The phase ``alpha`` is applied to the
    control qubit with a phase gate.

    Args:
        U: 2x2 matrix passed unchanged to ``zyz_decompose``.

    Returns:
        A ``QuantumCircuit`` on two qubits (qubit 0 = control, qubit 1 = target).
    """
    phase, rz_outer, ry_angle, rz_inner = zyz_decompose(U)
    control, target = 0, 1

    circuit = QuantumCircuit(2)

    # First group (applied first): a single Rz.
    circuit.rz((rz_inner - rz_outer) / 2, target)
    circuit.cx(control, target)

    # Middle group: Rz followed by Ry, both with negated half-angles.
    circuit.rz(-(rz_inner + rz_outer) / 2, target)
    circuit.ry(-ry_angle / 2, target)
    circuit.cx(control, target)

    # Last group: Ry followed by Rz.
    circuit.ry(ry_angle / 2, target)
    circuit.rz(rz_outer, target)

    # The phase factor goes on the control qubit.
    circuit.p(phase, control)

    return circuit

def P(n: int) -> QuantumCircuit:
    """Build the "P" ladder of singly-controlled RX gates on ``n + 1`` qubits.

    Qubit ``n`` is the target.  Controls run from qubit ``n - 1`` down to
    qubit 1, and the rotation angle halves at each step: ``pi/2`` for the
    first gate appended, down to ``pi / 2**(n - 1)`` for the last.
    Qubit 0 is not used here.

    Args:
        n: Index of the target qubit; the circuit has ``n + 1`` qubits.

    Returns:
        A ``QuantumCircuit`` named ``"P"``.
    """
    target = n
    circuit = QuantumCircuit(n + 1, name="P")
    for exponent, control in enumerate(range(n - 1, 0, -1), start=1):
        denominator = 2**exponent
        rotation = RXGate(pi / denominator, label=f"Rx(pi/{denominator})")
        circuit.append(rotation.control(1), [control, target])

    return circuit

def Q(n: int) -> QuantumCircuit:
    """Build the "Q" block: a descending stack of ``multi_controlled_RX`` circuits.

    For ``k = n - 1, n - 2, ..., 1`` the circuit ``multi_controlled_RX(k)`` is
    appended on qubits ``0..k``.  Qubit ``n`` is allocated but never acted on.

    Args:
        n: The circuit has ``n + 1`` qubits.

    Returns:
        A ``QuantumCircuit`` named ``"Q"``.
    """
    circuit = QuantumCircuit(n + 1, name="Q")
    num_controls = n - 1
    while num_controls >= 1:
        wires = range(num_controls + 1)
        circuit.append(multi_controlled_RX(num_controls), wires)
        num_controls -= 1

    return circuit

def multi_controlled_RX(n: int) -> QuantumCircuit:
    """Assemble the "LDD mcRX" circuit on ``n + 1`` qubits.

    The sequence is: ``P(n)``, a singly-controlled ``RX(pi / 2**(n - 1))``
    from qubit 0 to qubit ``n``, ``Q(n)``, then the inverses of ``P(n)`` and
    ``Q(n)``.

    Args:
        n: Index of the target qubit; qubits ``0..n-1`` are the remaining wires.

    Returns:
        A ``QuantumCircuit`` named ``"LDD mcRX"``.
    """
    wires = range(n + 1)
    p_block = P(n)
    q_block = Q(n)

    denominator = 2**(n - 1)
    crx_gate = RXGate(pi / denominator, label=f"Rx(pi/{denominator})").control(1)

    circuit = QuantumCircuit(n + 1, name="LDD mcRX")
    circuit.append(p_block, wires)
    circuit.append(crx_gate, [0, n])
    circuit.append(q_block, wires)
    circuit.append(p_block.inverse(), wires)
    circuit.append(q_block.inverse(), wires)

    return circuit

def P_U(n: int, U) -> QuantumCircuit:
    """Build the "P_U" ladder of singly-controlled roots of ``U``.

    Controls run from qubit ``n - 1`` down to qubit 1, all acting on target
    qubit ``n``.  ``sqrtm`` is applied once more before each gate, so the
    successive gates use ``sqrtm(U)``, ``sqrtm(sqrtm(U))``, and so on.  Each
    controlled gate is produced by ``multi_controlled_U(1, root)``.

    Args:
        n: Index of the target qubit; the circuit has ``n + 1`` qubits.
        U: Matrix whose repeated square roots are applied.

    Returns:
        A ``QuantumCircuit`` named ``"P_U"``.
    """
    target = n
    circuit = QuantumCircuit(n + 1, name="P_U")
    root = U
    for control in range(n - 1, 0, -1):
        root = sqrtm(root)
        controlled_root = multi_controlled_U(1, root)
        circuit.append(controlled_root, [control, target])

    return circuit

def multi_controlled_U(n: int, U) -> QuantumCircuit:
    """Assemble the "LDD mcu" circuit on ``n + 1`` qubits.

    The sequence is: ``P_U(n, U)``, a singly-controlled ``U ** (2 ** (1 - n))``
    from qubit 0 to qubit ``n`` (built with ``controlled_U``), ``Q(n)``, then
    the inverses of ``P_U(n, U)`` and ``Q(n)``.

    Args:
        n: Index of the target qubit; qubits ``0..n-1`` are the remaining wires.
        U: Matrix to apply to the target.

    Returns:
        A ``QuantumCircuit`` named ``"LDD mcu"``.
    """
    wires = range(n + 1)
    circuit = QuantumCircuit(n + 1, name="LDD mcu")

    p_block = P_U(n, U)
    q_block = Q(n)
    root_U = fractional_matrix_power(U, 2**(-n + 1))

    circuit.append(p_block, wires)
    circuit.append(controlled_U(root_U), [0, n])
    circuit.append(q_block, wires)
    circuit.append(p_block.inverse(), wires)
    circuit.append(q_block.inverse(), wires)

    return circuit

In [45]:
def unitary(circ: QuantumCircuit) -> np.ndarray:
    """Return the ``.data`` array of ``Operator(circ)`` for the given circuit."""
    operator = Operator(circ)
    return operator.data

def equal_up_to_global_phase(A: np.ndarray, B: np.ndarray, atol=1e-8) -> bool:
    """Check whether ``A`` and ``B`` agree up to a single complex phase factor.

    The phase is estimated from the inner product ``sum(conj(A_ij) * B_ij)``;
    ``B`` is rotated back by that phase and compared with ``A`` element-wise.

    Args:
        A: First matrix.
        B: Second matrix.
        atol: Absolute tolerance passed to ``np.allclose`` for the comparison.

    Returns:
        ``True`` if ``A`` is close to ``exp(-i*phi) * B`` for the estimated ``phi``.
    """
    # No phase can be estimated from an all-zero matrix; compare directly.
    either_is_zero = np.allclose(A, 0) or np.allclose(B, 0)
    if either_is_zero:
        return np.allclose(A, B, atol=atol)

    overlap = np.vdot(A.ravel(), B.ravel())  # vdot conjugates its first argument
    phase_factor = np.exp(-1j * np.angle(overlap))
    return np.allclose(A, B * phase_factor, atol=atol)

def reference_mcu(n_controls: int, U2x2: np.ndarray):
    """Wrap ``U2x2`` in a ``UnitaryGate`` and add ``n_controls`` controls via ``.control``."""
    base_gate = UnitaryGate(U2x2)
    return base_gate.control(n_controls)

# ----------- Test setup -----------
n_controls = 2                          # number of control qubits (try 1, 2, 3, ...)
theta = 0.73
U = RXGate(theta).to_matrix()           # 2x2 target unitary: an X-rotation by theta

# Gate under test and the reference produced by Qiskit's own control().
ut_gate = multi_controlled_U(n_controls, U)
ref_gate = reference_mcu(n_controls, U)

# Wrap each gate in a circuit with the same wire order: [controls..., target].
qc_ut, qc_ref = QuantumCircuit(n_controls + 1), QuantumCircuit(n_controls + 1)
qc_ut.append(ut_gate, list(range(n_controls + 1)))
qc_ref.append(ref_gate, list(range(n_controls + 1)))

# ----------- Matrix comparison -----------
U_ut, U_ref = unitary(qc_ut), unitary(qc_ref)

ok = equal_up_to_global_phase(U_ut, U_ref, atol=1e-8)
print("Matrix equivalence (up to global phase):", ok)
assert ok, "multi_controlled_U does not match Qiskit’s reference control()."

print("✅ Test passed.")

Matrix equivalence (up to global phase): True
✅ Test passed.
