In [None]:
from qiskit import QuantumCircuit
from math import pi
import numpy as np
from scipy.linalg import sqrtm, fractional_matrix_power
from scipy.linalg import fractional_matrix_power
from qiskit.quantum_info import Operator
from qiskit.circuit.library import UnitaryGate, RXGate
from qiskit.quantum_info import random_unitary

def zyz_decompose(U: np.ndarray):
    """Decompose a 2x2 unitary as ``U = exp(i*alpha) Rz(beta) Ry(gamma) Rz(delta)``.

    Uses Qiskit's rotation conventions,
    ``Rz(t) = diag(exp(-i t/2), exp(i t/2))`` and
    ``Ry(t) = [[cos t/2, -sin t/2], [sin t/2, cos t/2]]``.
    The returned angles reproduce ``U`` exactly, including its global phase.
    The global phase matters once ``U`` is used as a controlled gate.

    Args:
        U: A 2x2 unitary matrix (anything ``np.asarray`` accepts).

    Returns:
        ``(alpha, beta, gamma, delta)`` as floats, with ``alpha`` in
        ``(-pi/2, pi/2]``, ``gamma`` in ``[0, pi]`` and ``beta``/``delta``
        in ``[-2*pi, 2*pi]``.

    Raises:
        ValueError: If ``U`` is not a 2x2 matrix.
    """
    U = np.asarray(U, dtype=complex)
    if U.shape != (2, 2):
        raise ValueError(f"expected a 2x2 matrix, got shape {U.shape}")

    # Strip the global phase so that V = exp(-i*alpha) * U has determinant 1.
    alpha = np.angle(np.linalg.det(U)) / 2.0
    V = U * np.exp(-1j * alpha)

    # For V = Rz(beta) Ry(gamma) Rz(delta):
    #   V[0,0] = exp(-i(beta+delta)/2) * cos(gamma/2)
    #   V[1,0] = exp(+i(beta-delta)/2) * sin(gamma/2)
    # Read the *half*-angle sums straight off these entries.
    # Recovering beta±delta as full angles (known only mod 2*pi) and halving
    # them leaves a pi ambiguity, which can flip the sign of gamma or of V.
    a, c = V[0, 0], V[1, 0]
    gamma = 2.0 * np.arctan2(np.abs(c), np.abs(a))
    half_sum = -np.angle(a)  # (beta + delta) / 2; irrelevant when a == 0
    half_diff = np.angle(c)  # (beta - delta) / 2; irrelevant when c == 0
    beta = half_sum + half_diff
    delta = half_sum - half_diff

    # The angles are deliberately NOT wrapped modulo 2*pi: Rz and Ry have
    # period 4*pi (e.g. Rz(t + 2*pi) = -Rz(t)), so such wrapping would flip
    # the sign of the reconstructed matrix.
    return float(alpha), float(beta), float(gamma), float(delta)

def controlled_U(U) -> QuantumCircuit:
    """Build a singly-controlled ``U`` on two qubits (control q0, target q1).

    Uses the ABC construction on top of
    ``U = e^{i alpha} Rz(beta) Ry(gamma) Rz(delta)`` from ``zyz_decompose``:
    - the target sees ``A X B X C`` when the control is 1, and ``A B C = I``
      when it is 0;
    - a phase gate on the control supplies the ``e^{i alpha}`` factor.
    """
    alpha, beta, gamma, delta = zyz_decompose(U)
    control, target = 0, 1
    circuit = QuantumCircuit(2)

    # C = Rz((delta - beta) / 2)
    circuit.rz((delta - beta) / 2, target)
    circuit.cx(control, target)

    # B = Ry(-gamma / 2) Rz(-(delta + beta) / 2); gates listed in time order
    circuit.rz(-(delta + beta) / 2, target)
    circuit.ry(-gamma / 2, target)
    circuit.cx(control, target)

    # A = Rz(beta) Ry(gamma / 2); gates listed in time order
    circuit.ry(gamma / 2, target)
    circuit.rz(beta, target)

    # Controlled global phase e^{i alpha}
    circuit.p(alpha, control)
    return circuit

def P(n: int) -> QuantumCircuit:
    """Return the ``P`` block used by ``multi_controlled_RX(n)``, on ``n + 1`` qubits.

    For ``j = 1 .. n-1`` it appends a controlled ``Rx(pi / 2**j)``.
    Its control is qubit ``n - j`` and its target is qubit ``n``.
    Qubit ``n - j`` therefore contributes ``2**(n - j - 1)`` units of
    ``pi / 2**(n - 1)``. Qubit 0 is not used here.
    """
    circuit = QuantumCircuit(n + 1, name="P")
    for j in range(1, n):
        denom = 2**j
        controlled_rx = RXGate(pi / denom, label=f"Rx(pi/{denom})").control(1)
        circuit.append(controlled_rx, [n - j, n])
    return circuit

def Q(n: int) -> QuantumCircuit:
    """Return the ``Q`` block on ``n + 1`` qubits: a cascade of multi-controlled Rx gates.

    For ``k = n-1`` down to ``1`` it appends ``multi_controlled_RX(k)`` on
    qubits ``0..k``, so qubit ``k`` is rotated conditioned on all lower qubits.
    Qubit ``n`` is left untouched.

    Each ``multi_controlled_RX(k)`` acts as a k-controlled ``Rx(pi)`` (= -i X).
    The cascade therefore adds q0 to the binary number held on q1..q_{n-1},
    up to phases that ``Q(n).inverse()`` undoes.
    """
    circuit = QuantumCircuit(n + 1, name="Q")
    k = n - 1
    while k >= 1:
        circuit.append(multi_controlled_RX(k), range(k + 1))
        k -= 1
    return circuit

def multi_controlled_RX(n: int) -> QuantumCircuit:
    """Build an n-controlled ``Rx(pi)`` (controls q0..q_{n-1}, target q_n).

    Linear-depth decomposition ("LDD"), in four steps:
    1. ``P(n)`` plus a controlled ``Rx(pi / 2**(n-1))`` on q0 rotate the
       target by ``pi / 2**(n-1)`` times ``x0 + sum_{j>=1} x_j * 2**(j-1)``.
    2. ``Q(n)`` adds ``x0`` (mod ``2**(n-1)``) to the number held on
       q1..q_{n-1}.
    3. ``P(n).inverse()`` subtracts the rotation for that new number.
    4. ``Q(n).inverse()`` restores the controls.

    The net target rotation is ``Rx(pi)`` only when every control is 1.

    Args:
        n: Number of control qubits (expected >= 1).

    Returns:
        A circuit on ``n + 1`` qubits named ``"LDD mcRX"``.
    """
    circuit = QuantumCircuit(n + 1, name="LDD mcRX")
    wires = range(n + 1)

    # Build each sub-block once and reuse it for its inverse.
    # Q(n) recurses into multi_controlled_RX(k) for every k < n, so building
    # it twice would roughly triple the construction cost per recursion level.
    p_block = P(n)
    q_block = Q(n)
    base_crx = RXGate(pi / (2**(n - 1)), label=f"Rx(pi/{2**(n - 1)})").control(1)

    circuit.append(p_block, wires)
    circuit.append(base_crx, [0, n])
    circuit.append(q_block, wires)
    circuit.append(p_block.inverse(), wires)
    circuit.append(q_block.inverse(), wires)
    return circuit

def P_U(n: int, U) -> QuantumCircuit:
    """Return the ``P`` block used by ``multi_controlled_U(n, U)``, on ``n + 1`` qubits.

    For ``j = 1 .. n-1`` it appends a singly-controlled ``U**(1/2**j)``.
    Its control is qubit ``n - j`` and its target is qubit ``n``.
    Each root is the matrix square root (``scipy.linalg.sqrtm``) of the
    previous one. The caller's ``U`` is not modified.
    """
    circuit = QuantumCircuit(n + 1, name="P_U")
    root = U
    for j in range(1, n):
        root = sqrtm(root)
        circuit.append(multi_controlled_U(1, root), [n - j, n])
    return circuit

def multi_controlled_U(n: int, U) -> QuantumCircuit:
    """Build an n-controlled ``U`` (controls q0..q_{n-1}, target q_n) for a 2x2 unitary ``U``.

    Same linear-depth structure as ``multi_controlled_RX``, with ``Rx(pi)``
    replaced by ``U``. Let ``W = U**(1/2**(n-1))``:
    1. ``P_U(n, U)`` plus a controlled ``W`` on q0 apply
       ``W**(x0 + sum_{j>=1} x_j * 2**(j-1))`` to the target.
    2. ``Q(n)`` adds ``x0`` to the number on q1..q_{n-1}.
    3. ``P_U(n, U).inverse()`` removes ``W`` raised to that new number.
    4. ``Q(n).inverse()`` restores the controls.

    This leaves ``W**(2**(n-1)) = U`` exactly when all controls are 1.
    It assumes the roots from ``sqrtm`` and ``fractional_matrix_power`` are
    mutually consistent powers of one ``W``.

    Args:
        n: Number of control qubits (expected >= 1).
        U: 2x2 unitary matrix.

    Returns:
        A circuit on ``n + 1`` qubits named ``"LDD mcu"``.
    """
    circuit = QuantumCircuit(n + 1, name="LDD mcu")
    wires = range(n + 1)

    # Build the expensive sub-blocks once and reuse them for the inverses.
    # P_U runs a sqrtm chain plus one ZYZ decomposition per root, and Q(n)
    # is built recursively.
    p_block = P_U(n, U)
    q_block = Q(n)
    root_U = fractional_matrix_power(U, 2**(-n + 1))

    circuit.append(p_block, wires)
    circuit.append(controlled_U(root_U), [0, n])
    circuit.append(q_block, wires)
    circuit.append(p_block.inverse(), wires)
    circuit.append(q_block.inverse(), wires)
    return circuit

def test_mcu(n, U):
    """Return True if ``multi_controlled_U(n, U)`` matches Qiskit's own n-controlled ``U``.

    Both circuits act on ``n + 1`` qubits, with controls q0..q_{n-1} and
    target q_n. Their full unitaries (global phase included) are compared
    using ``rtol=1e-5`` and numpy's default ``atol``.
    """
    width = n + 1
    wires = range(width)

    candidate = QuantumCircuit(width)
    candidate.append(multi_controlled_U(n, U), wires)

    reference = QuantumCircuit(width)
    reference.append(UnitaryGate(U).control(n), wires)

    return np.allclose(Operator(candidate).data, Operator(reference).data, rtol=1e-5)

def run_test_mcu(n_max, m_max, seed=None):
    """Check ``multi_controlled_U`` against Qiskit for random unitaries and print results.

    Draws ``m_max`` random 2x2 unitaries and tests each one with
    ``n = 1 .. n_max`` controls. Prints one pass/fail line per
    (n, unitary) pair.

    Args:
        n_max: Largest number of control qubits to test.
        m_max: Number of random unitaries to draw.
        seed: Optional int or ``numpy.random.Generator`` for reproducible
            draws. ``None`` (the default) uses fresh randomness.
    """
    rng = np.random.default_rng(seed)
    for m in range(m_max):
        # Draw once per m, so "unitary number m" really is the same matrix
        # for every n.
        U = random_unitary(2, seed=rng).data
        for n in range(1, n_max + 1):
            if test_mcu(n, U):
                print(f"✅ Test passed: n = {n}, on unitary number {m}")
            else:
                print(f"❌ Test failed: n = {n}, on unitary number {m}")

if __name__ == "__main__":
    # Guarded so that importing this module does not launch the (slow) test
    # sweep. In a notebook or a direct script run, __name__ is "__main__",
    # so the sweep still runs.
    run_test_mcu(3, 5)

❌ Test failed: n = 1, on unitary number 0
✅ Test passed: n = 2, on unitary number 0
❌ Test failed: n = 3, on unitary number 0
✅ Test passed: n = 1, on unitary number 1
❌ Test failed: n = 2, on unitary number 1
❌ Test failed: n = 3, on unitary number 1
❌ Test failed: n = 1, on unitary number 2
✅ Test passed: n = 2, on unitary number 2
❌ Test failed: n = 3, on unitary number 2
❌ Test failed: n = 1, on unitary number 3
✅ Test passed: n = 2, on unitary number 3
❌ Test failed: n = 3, on unitary number 3
✅ Test passed: n = 1, on unitary number 4
✅ Test passed: n = 2, on unitary number 4
❌ Test failed: n = 3, on unitary number 4
