In [2]:
from qiskit import QuantumCircuit
from math import pi
import numpy as np
from scipy.linalg import sqrtm, fractional_matrix_power
from scipy.linalg import fractional_matrix_power
from qiskit.quantum_info import Operator
from qiskit.circuit.library import UnitaryGate
from qiskit.quantum_info import random_unitary

def zyz_decompose(U, eps=1e-12):
    """
    Return (alpha, beta, gamma, delta) such that

        U = e^{i alpha} Rz(beta) Ry(gamma) Rz(delta),

    where Rz(t) = diag(e^{-it/2}, e^{it/2}) and
    Ry(t) = [[cos(t/2), -sin(t/2)], [sin(t/2), cos(t/2)]].

    Works for all 2x2 unitaries, including the degenerate cases gamma ~ 0 and
    gamma ~ pi, where only beta + delta (resp. beta - delta) is determined and
    delta is fixed to 0.

    Args:
        U: 2x2 unitary matrix; any array-like accepted by np.asarray
            (ndarray, nested lists, ...).
        eps: |U00| or |U01| (after removing the global phase) below this
            threshold is treated as exactly 0.

    Returns:
        Four floats. Every angle is wrapped to [-pi, pi) (the modular wrap
        sends +pi to -pi); alpha is shifted by pi when necessary so that the
        product reproduces U itself rather than -U.

    Raises:
        ValueError: if U is not a 2x2 matrix.
    """
    U = np.asarray(U, dtype=complex)
    if U.shape != (2, 2):
        raise ValueError(f"zyz_decompose expects a 2x2 matrix, got shape {U.shape}")

    def wrap(x):
        # Map an angle into [-pi, pi); note that +pi is mapped to -pi.
        return (x + np.pi) % (2*np.pi) - np.pi

    def Rz(t):
        t2 = t/2
        return np.array([[np.exp(-1j*t2), 0], [0, np.exp(1j*t2)]], complex)

    def Ry(t):
        t2 = t/2
        c, s = np.cos(t2), np.sin(t2)
        return np.array([[c, -s], [s, c]], complex)

    # Global phase: dividing out e^{i alpha} with alpha = arg(det U)/2 gives
    # det(W) = 1, so W is in SU(2) and has the form
    #   W = [[ cos(g/2) e^{-i(b+d)/2}, -sin(g/2) e^{-i(b-d)/2}],
    #        [ sin(g/2) e^{ i(b-d)/2},  cos(g/2) e^{ i(b+d)/2}]].
    alpha = np.angle(np.linalg.det(U)) / 2.0
    W = U * np.exp(-1j*alpha)

    W00, W01 = W[0, 0], W[0, 1]
    W10, W11 = W[1, 0], W[1, 1]

    # Magnitudes give gamma directly (clip guards against rounding above 1).
    c = np.clip(abs(W00), 0.0, 1.0)  # = cos(gamma/2)
    s = np.clip(abs(W01), 0.0, 1.0)  # = sin(gamma/2)
    gamma = 2*np.arctan2(s, c)

    if s < eps:
        # gamma ~ 0: W is diagonal; only beta + delta is fixed, choose delta = 0.
        gamma = 0.0
        beta = wrap(np.angle(W11) - np.angle(W00))
        delta = 0.0
    elif c < eps:
        # gamma ~ pi: W is anti-diagonal; only beta - delta is fixed, choose delta = 0.
        gamma = np.pi
        beta = wrap(np.angle(W10) - np.angle(W01) + np.pi)
        delta = 0.0
    else:
        # Generic case: read the phases off W00, W01, W11; the +-pi undoes the
        # minus sign carried by the W01 entry.
        delta = wrap(np.angle(W01) - np.angle(W00) - np.pi)
        beta = wrap(np.angle(W11) - np.angle(W01) + np.pi)

    alpha, beta, gamma, delta = map(wrap, (alpha, beta, gamma, delta))

    # Wrapping shifts angles by multiples of 2*pi, and Rz(t + 2*pi) = -Rz(t),
    # Ry(t + 2*pi) = -Ry(t), so the product may now equal -U. Absorb that sign
    # into the global phase.
    Urec = np.exp(1j*alpha) * (Rz(beta) @ Ry(gamma) @ Rz(delta))
    if np.allclose(-U, Urec, atol=1e-10):
        alpha = wrap(alpha + np.pi)

    return alpha, beta, gamma, delta

def controlled_U(U) -> QuantumCircuit:
    """Build a 2-qubit circuit applying the single-qubit unitary U controlled on qubit 0.

    Uses the ZYZ angles of U = e^{i alpha} Rz(beta) Ry(gamma) Rz(delta) and the
    A-X-B-X-C construction: on the target (qubit 1) apply C, CNOT, B, CNOT, A,
    then a phase gate P(alpha) on the control (qubit 0).
    """
    alpha, beta, gamma, delta = zyz_decompose(U)
    control, target = 0, 1

    circuit = QuantumCircuit(2)

    # C = Rz((delta - beta) / 2)
    circuit.rz((delta - beta) / 2, target)
    circuit.cx(control, target)

    # B = Ry(-gamma/2) Rz(-(delta + beta)/2); the right-hand factor acts first
    circuit.rz(-(delta + beta) / 2, target)
    circuit.ry(-gamma / 2, target)
    circuit.cx(control, target)

    # A = Rz(beta) Ry(gamma/2); again the right-hand factor acts first
    circuit.ry(gamma / 2, target)
    circuit.rz(beta, target)

    # Global phase of U, kicked back onto the |1> branch of the control
    circuit.p(alpha, control)

    return circuit

def Rx(theta):
    """Return the 2x2 complex matrix of an X-axis rotation by theta.

    Rx(theta) = cos(theta/2) I - i sin(theta/2) X.
    """
    half = theta / 2
    on_diag = np.cos(half)
    off_diag = -1j * np.sin(half)
    return np.array([[on_diag, off_diag], [off_diag, on_diag]], dtype=complex)

def P(n: int) -> QuantumCircuit:
    """Build the P layer (for an Rx(pi) target) on n + 1 qubits.

    For control qubit j = n-1, n-2, ..., 1 it appends a singly-controlled
    Rx(pi / 2^(n-j)) with target qubit n. Empty when n < 2.
    """
    circuit = QuantumCircuit(n + 1, name="P")
    for control in range(n - 1, 0, -1):
        # angle goes pi/2, pi/4, ..., pi/2^(n-1) as the control index decreases
        angle = pi / (2**(n - control))
        circuit.append(controlled_U(Rx(angle)), [control, n])
    return circuit

def Q(n: int) -> QuantumCircuit:
    """Build the Q layer on n + 1 qubits.

    For k = n-1, n-2, ..., 1 it appends multi_controlled_RX(k) on qubits
    0..k (controls 0..k-1, target k). Empty when n < 2.
    """
    circuit = QuantumCircuit(n + 1, name="Q")
    for target in range(n - 1, 0, -1):
        circuit.append(multi_controlled_RX(target), range(target + 1))
    return circuit

def multi_controlled_RX(n: int) -> QuantumCircuit:
    """Build the linear-depth n-controlled Rx(pi) circuit on n + 1 qubits.

    Same construction as multi_controlled_U with U = Rx(pi): P(n), a singly-
    controlled Rx(pi / 2^(n-1)) from qubit 0 to qubit n, Q(n), then the
    inverses of P(n) and Q(n). Qubits 0..n-1 are controls, qubit n is the
    target. For n = 1, P and Q are empty and this is just a controlled Rx(pi).

    Args:
        n: number of control qubits (>= 1).

    Returns:
        QuantumCircuit named "LDD mcRX".
    """
    # Build each layer once and reuse it for its inverse. Q recurses back into
    # this function, so rebuilding it would duplicate work at every level.
    p_layer = P(n)
    q_layer = Q(n)

    qc = QuantumCircuit(n + 1, name="LDD mcRX")
    qc.append(p_layer, range(n + 1))
    qc.append(controlled_U(Rx(pi / (2**(n - 1)))), [0, n])
    qc.append(q_layer, range(n + 1))
    qc.append(p_layer.inverse(), range(n + 1))
    qc.append(q_layer.inverse(), range(n + 1))

    return qc

def P_U(n: int, U) -> QuantumCircuit:
    """Build the P_U layer of the linear-depth n-controlled-U decomposition.

    For control qubit j = n-1, n-2, ..., 1 it appends a singly-controlled
    U^(1/2^(n-j)) with target qubit n. The roots are iterated principal square
    roots (scipy.linalg.sqrtm): U^(1/2), U^(1/4), ..., U^(1/2^(n-1)).

    Args:
        n: number of control qubits.
        U: 2x2 unitary matrix.

    Returns:
        QuantumCircuit on n + 1 qubits named "P_U" (empty when n < 2).
    """
    qc = QuantumCircuit(n + 1, name="P_U")
    root = U
    for control in range(n - 1, 0, -1):
        # Each step halves the exponent of the previous root.
        root = sqrtm(root)
        # A plain controlled gate, as in P; no need to go through the
        # multi-controlled builder for a single control.
        qc.append(controlled_U(root), [control, n])

    return qc

def multi_controlled_U(n: int, U) -> QuantumCircuit:
    """Build an n-controlled version of the single-qubit unitary U (linear depth).

    Layout: P_U(n, U), a singly-controlled U^(1/2^(n-1)) from qubit 0 to
    qubit n, Q(n), then the inverses of P_U(n, U) and Q(n). Qubits 0..n-1 are
    controls, qubit n is the target.

    Args:
        n: number of control qubits (>= 1).
        U: 2x2 unitary matrix.

    Returns:
        QuantumCircuit on n + 1 qubits named "LDD mcu".
    """
    # The middle root must be the same matrix as the smallest root used in
    # P_U, which P_U builds by iterated sqrtm. Deriving it the same way keeps
    # the roots consistent by construction. fractional_matrix_power uses a
    # different algorithm and may resolve the branch differently for
    # eigenvalues at or near -1.
    root_U = np.asarray(U, dtype=complex)
    for _ in range(n - 1):
        root_U = sqrtm(root_U)

    # Build each layer once and reuse it for its inverse.
    p_layer = P_U(n, U)
    q_layer = Q(n)

    qc = QuantumCircuit(n + 1, name="LDD mcu")
    qc.append(p_layer, range(n + 1))
    qc.append(controlled_U(root_U), [0, n])
    qc.append(q_layer, range(n + 1))
    qc.append(p_layer.inverse(), range(n + 1))
    qc.append(q_layer.inverse(), range(n + 1))

    return qc

def test_mcu(n, U, *, rtol=1e-5, atol=1e-8):
    """Return True if multi_controlled_U(n, U) matches Qiskit's reference.

    Compares the full operator, global phase included, against
    UnitaryGate(U).control(n) on the same qubits (controls 0..n-1, target n).

    Args:
        n: number of control qubits.
        U: 2x2 unitary matrix.
        rtol, atol: tolerances forwarded to np.allclose. They are passed by
            keyword because np.allclose's third positional parameter is rtol,
            not an absolute tolerance.
    """
    qc_mcu = QuantumCircuit(n + 1)
    qc_mcu.append(multi_controlled_U(n, U), range(n + 1))
    matrix_mcu = Operator(qc_mcu).data

    qc_ref_mcu = QuantumCircuit(n + 1)
    qc_ref_mcu.append(UnitaryGate(U).control(n), range(n + 1))
    matrix_ref_mcu = Operator(qc_ref_mcu).data

    return np.allclose(matrix_mcu, matrix_ref_mcu, rtol=rtol, atol=atol)

def run_test_mcu(n_max, m_max, *, seed=None):
    """Print pass/fail lines comparing multi_controlled_U with Qiskit's reference.

    For each trial m in 0..m_max-1 and each control count n in 1..n_max, draws
    a fresh random 2x2 unitary with qiskit.quantum_info.random_unitary and
    checks it with test_mcu.

    Args:
        n_max: largest number of control qubits to test.
        m_max: number of random unitaries (trials) per control count.
        seed: optional int or np.random.Generator so a failing run can be
            reproduced; None draws from fresh OS entropy.
    """
    rng = np.random.default_rng(seed)
    for m in range(m_max):
        for n in range(1, n_max + 1):
            U = random_unitary(2, seed=rng).data
            status = "✅ Test passed" if test_mcu(n, U) else "❌ Test failed"
            print(f"{status}: n_controls = {n}, on unitary number {m}")

run_test_mcu(n_max=4, m_max=6)

✅ Test passed: n_controls = 1, on unitary number 0
✅ Test passed: n_controls = 2, on unitary number 0
✅ Test passed: n_controls = 3, on unitary number 0
✅ Test passed: n_controls = 4, on unitary number 0
✅ Test passed: n_controls = 1, on unitary number 1
✅ Test passed: n_controls = 2, on unitary number 1
✅ Test passed: n_controls = 3, on unitary number 1
✅ Test passed: n_controls = 4, on unitary number 1
✅ Test passed: n_controls = 1, on unitary number 2
✅ Test passed: n_controls = 2, on unitary number 2
✅ Test passed: n_controls = 3, on unitary number 2
✅ Test passed: n_controls = 4, on unitary number 2
✅ Test passed: n_controls = 1, on unitary number 3
✅ Test passed: n_controls = 2, on unitary number 3
✅ Test passed: n_controls = 3, on unitary number 3
✅ Test passed: n_controls = 4, on unitary number 3
✅ Test passed: n_controls = 1, on unitary number 4
✅ Test passed: n_controls = 2, on unitary number 4
✅ Test passed: n_controls = 3, on unitary number 4
✅ Test passed: n_controls = 4, 