# Multi-controlled gate decomposition

In this notebook we implement an algorithm from [this paper](https://arxiv.org/pdf/quant-ph/9503016.pdf) that decomposes a multi-controlled gate into a sequence of single-qubit gates and CNOT gates.


In [1]:
import numpy as np
import cirq
from scipy.stats import unitary_group

In [125]:
def is_unitary(U):
    """Return True if ``U`` is a square matrix satisfying ``U @ U^† = I``.

    Args:
        U: Matrix to check (anything ``np.asarray`` accepts).

    Returns:
        ``True`` when ``U`` is two-dimensional, square, and ``U @ U.conj().T``
        equals the identity within the default ``np.allclose`` tolerances;
        ``False`` otherwise.
    """
    U = np.asarray(U)
    # A rectangular matrix with orthonormal rows also satisfies U @ U^† = I
    # (for the smaller dimension), so the shape must be checked explicitly.
    if U.ndim != 2 or U.shape[0] != U.shape[1]:
        return False
    return np.allclose(U @ U.conj().T, np.eye(U.shape[0]))

def is_special_unitary(A):
    """Return whether ``A`` is unitary and has determinant 1.

    Both conditions are evaluated with the default ``np.allclose`` tolerances.
    """
    if not is_unitary(A):
        return False
    determinant = np.linalg.det(A)
    return np.allclose(determinant, 1.0)

def unitary_power(U, p):
    """Raise the unitary matrix ``U`` to the (possibly fractional) power ``p``.

    The power is taken on the spectrum: writing ``U = Z diag(e^{i phi_k}) Z^†``
    the result is ``Z diag(e^{i p phi_k}) Z^†``, where each phase ``phi_k`` is
    the principal value returned by ``np.angle``.

    Args:
        U: Unitary matrix.
        p: Exponent.

    Returns:
        Complex array with the same shape as ``U``.

    Raises:
        AssertionError: If ``U`` is not unitary.
    """
    # Local import: only ``scipy.stats`` is imported at the top of the file.
    from scipy.linalg import schur

    assert is_unitary(U)
    # ``np.linalg.eig`` does not guarantee orthonormal eigenvectors inside a
    # degenerate eigenspace, in which case ``Q^†`` is not ``Q^{-1}``. The
    # complex Schur form U = Z T Z^† always has a unitary Z, and T is diagonal
    # (up to rounding) because a unitary matrix is normal.
    T, Z = schur(np.asarray(U, dtype=np.complex128), output='complex')
    phases = np.angle(np.diag(T))
    return Z @ np.diag(np.exp(p * 1j * phases)) @ Z.conj().T


def decompose_fc_x(controls, target, free_qubits=None):
    """Decompose a multi-controlled X gate into a list of operations.

    The construction depends on the number of controls ``m`` and on how many
    free qubits are supplied:

    * ``m`` in {0, 1, 2}: a single X, CNOT or CCNOT.
    * at least ``m - 2`` free qubits: the CCNOT ladder of Lemma 7.2.
    * at least one free qubit: the split of Lemma 7.3, which recurses into
      two smaller multi-controlled X gates, each applied twice.
    * no free qubit: delegates to ``decompose_fc_gate`` with the Pauli-X
      matrix. Only the first three cases yield purely CNOT/CCNOT/X output.

    Args:
        controls: Sequence of control qubits.
        target: Qubit the X gate acts on.
        free_qubits: Optional sequence of additional qubits the construction
            may use as scratch space. Defaults to no free qubits.

    Returns:
        List of cirq operations.
    """
    # Work on list copies: the branches below concatenate slices with ``+``,
    # which would fail for tuples, and a ``None`` default avoids sharing one
    # mutable list object between calls.
    controls = list(controls)
    free_qubits = [] if free_qubits is None else list(free_qubits)

    m = len(controls)
    if m == 0:
        return [cirq.X.on(target)]
    if m == 1:
        return [cirq.CNOT.on(controls[0], target)]
    if m == 2:
        return [cirq.CCNOT.on(controls[0], controls[1], target)]

    # Total number of qubits available: controls, target and free qubits.
    n = m + 1 + len(free_qubits)
    if n >= 2 * m - 1:
        # Lemma 7.2: needs free_qubits[0 .. m-3], i.e. at least m - 2 of them.
        seq1 = [
            cirq.CCNOT(controls[m - 2 - i], free_qubits[m - 4 - i], free_qubits[m - 3 - i])
            for i in range(m - 3)
        ]
        seq2 = seq1 + [cirq.CCNOT(controls[0], controls[1], free_qubits[0])] + seq1[::-1]
        first_gate = cirq.CCNOT(controls[m - 1], free_qubits[m - 3], target)
        return [first_gate] + seq2 + [first_gate] + seq2

    if len(free_qubits) >= 1:
        # Lemma 7.3: route through free_qubits[0]; every qubit not involved in
        # a sub-gate is handed to the recursive call as a free qubit.
        m1 = n // 2
        seq1 = decompose_fc_x(
            controls[:m1],
            free_qubits[0],
            free_qubits=controls[m1:] + [target] + free_qubits[1:],
        )
        seq2 = decompose_fc_x(
            controls[m1:] + [free_qubits[0]],
            target,
            free_qubits=controls[:m1] + free_qubits[1:],
        )
        return seq1 + seq2 + seq1 + seq2

    # No free qubit: fall back to the general controlled-gate decomposition.
    X = np.array([[0, 1], [1, 0]])
    return decompose_fc_gate(X, controls, target)


def test_fc_x_decomposition():
    """Check ``decompose_fc_x`` against cirq's own controlled-X unitary.

    For every register size from 2 to 11 qubits and every number of controls
    from 0 up to (register size - 1), the first ``m`` qubits act as controls,
    qubit ``m`` is the target and the remaining qubits are passed as free
    qubits. The unitary of the decomposed circuit must match the unitary of
    ``cirq.ControlledGate(cirq.X, num_controls=m)``; one line is printed per
    passing case.
    """
    for total_qubits in range(2, 12):
        wires = cirq.LineQubit.range(total_qubits)

        def padded_circuit():
            # Identity on every wire; presumably so that both circuits span
            # the full register and their unitaries have equal dimensions.
            return cirq.Circuit([cirq.I.on(w) for w in wires])

        for num_controls in range(total_qubits):
            ops = decompose_fc_x(
                wires[:num_controls],
                wires[num_controls],
                free_qubits=wires[num_controls + 1:],
            )
            actual = padded_circuit()
            actual.append(ops)
            result_matrix = actual.unitary()

            reference = padded_circuit()
            reference += cirq.ControlledGate(cirq.X, num_controls=num_controls).on(
                *wires[:num_controls + 1]
            )
            expected_matrix = reference.unitary()

            assert np.allclose(expected_matrix, result_matrix)

            print('n=%d, m=%d, len=%d OK' % (total_qubits, num_controls, len(ops)))
            
test_fc_x_decomposition()

n=2, m=0, len=1 OK
n=2, m=1, len=1 OK
n=3, m=0, len=1 OK
n=3, m=1, len=1 OK
n=3, m=2, len=1 OK
n=4, m=0, len=1 OK
n=4, m=1, len=1 OK
n=4, m=2, len=1 OK
n=4, m=3, len=39 OK
n=5, m=0, len=1 OK
n=5, m=1, len=1 OK
n=5, m=2, len=1 OK
n=5, m=3, len=4 OK
n=5, m=4, len=61 OK
n=6, m=0, len=1 OK
n=6, m=1, len=1 OK
n=6, m=2, len=1 OK
n=6, m=3, len=4 OK
n=6, m=4, len=10 OK
n=6, m=5, len=95 OK
n=7, m=0, len=1 OK
n=7, m=1, len=1 OK
n=7, m=2, len=1 OK
n=7, m=3, len=4 OK
n=7, m=4, len=8 OK
n=7, m=5, len=16 OK
n=7, m=6, len=137 OK
n=8, m=0, len=1 OK
n=8, m=1, len=1 OK
n=8, m=2, len=1 OK
n=8, m=3, len=4 OK
n=8, m=4, len=8 OK
n=8, m=5, len=18 OK
n=8, m=6, len=24 OK
n=8, m=7, len=203 OK
n=9, m=0, len=1 OK
n=9, m=1, len=1 OK
n=9, m=2, len=1 OK
n=9, m=3, len=4 OK
n=9, m=4, len=8 OK
n=9, m=5, len=12 OK
n=9, m=6, len=24 OK
n=9, m=7, len=32 OK
n=9, m=8, len=269 OK
n=10, m=0, len=1 OK
n=10, m=1, len=1 OK
n=10, m=2, len=1 OK
n=10, m=3, len=4 OK
n=10, m=4, len=8 OK
n=10, m=5, len=12 OK
n=10, m=6, len=26 OK
n=10, 

In [124]:
        
def is_identity(M):
    """Return whether ``M`` equals the identity matrix of size ``M.shape[0]``.

    The comparison uses the default ``np.allclose`` tolerances.
    """
    dim = M.shape[0]
    identity = np.identity(dim)
    return np.allclose(M, identity)
    
def decompose_fc_gate(U, controls, target, free_qubits=None, matrix_pow=1.0):
    """Decompose the multi-controlled gate ∧_m(U^matrix_pow).

    The gate ``M = U ** matrix_pow`` is applied to ``target`` when all
    ``controls`` are set. ``M`` is written as ``exp(i*delta) * M_su`` with
    ``M_su`` special unitary, and ``M_su`` is factored as ``A X B X C`` with
    ``A B C = I`` (notation of the paper, chapter 4).

    * One control: the explicit CNOT / rz / ry sequence of chapter 5.1, with
      operations whose unitary is the identity removed.
    * Several controls, ``delta == 0``: Lemma 7.9.
    * Several controls, otherwise: recursion on the square root of ``M``.

    Args:
        U: 2x2 unitary matrix (NumPy array).
        controls: Non-empty sequence of control qubits; the code indexes
            ``controls[-1]``.
        target: Qubit the gate acts on.
        free_qubits: Optional sequence of extra qubits, forwarded to
            ``decompose_fc_x`` as scratch space in the general branch.
        matrix_pow: Power to which ``U`` is raised before decomposing.

    Returns:
        List of cirq operations.

    Raises:
        AssertionError: If ``U`` is not a 2x2 unitary matrix or an internal
            consistency check of the factorization fails.
    """
    # Validate before any numerical work is done on U.
    assert U.shape == (2, 2)
    assert is_unitary(U)

    # List copies: slices are concatenated with ``+`` below, and a ``None``
    # default avoids sharing one mutable list object between calls.
    controls = list(controls)
    free_qubits = [] if free_qubits is None else list(free_qubits)

    M = unitary_power(U, matrix_pow)

    def _Ry(theta):
        return np.array([[np.cos(theta/2), np.sin(theta/2)], [-np.sin(theta/2), np.cos(theta/2)]])

    def _Rz(alpha):
        return np.diag(np.exp([0.5j * alpha, -0.5j * alpha]))

    # Notation of the paper, see chapter 4.
    # Split off the phase so that M = exp(i*delta) * M_su with det(M_su) = 1.
    delta = np.angle(np.linalg.det(M)) * 0.5
    M_su = M / np.exp(1j * delta)
    assert is_special_unitary(M_su)

    # Rounding can push |M_su[0, 0]| marginally above 1, for which arccos
    # would return NaN; clamp it to the valid domain.
    theta = 2 * np.arccos(np.minimum(np.abs(M_su[0, 0]), 1.0))
    alpha = np.angle(M_su[0, 0]) + np.angle(M_su[0, 1])
    beta = np.angle(M_su[0, 0]) - np.angle(M_su[0, 1])

    A = _Rz(alpha) @ _Ry(theta/2)
    B = _Ry(-theta/2) @ _Rz(-(alpha+beta)/2)
    C = _Rz((beta-alpha)/2)
    X = np.array([[0, 1], [1, 0]])
    # Sanity checks of the factorization.
    assert np.allclose(A @ B @ C, np.eye(2))
    assert np.allclose(A @ X @ B @ X @ C, M_su)

    m = len(controls)
    if m == 1:
        # Chapter 5.1.
        result = [
            cirq.ZPowGate(exponent=delta/np.pi).on(controls[0]),
            cirq.rz(-0.5*(beta-alpha)).on(target),
            cirq.CNOT.on(controls[0], target),
            cirq.rz(0.5*(beta+alpha)).on(target),
            cirq.ry(0.5*theta).on(target),
            cirq.CNOT.on(controls[0], target),
            cirq.ry(-0.5*theta).on(target),
            cirq.rz(-alpha).on(target),
        ]

        # Remove no-ops.
        return [g for g in result if not is_identity(cirq.unitary(g))]

    gate_is_special_unitary = np.allclose(delta, 0)

    if gate_is_special_unitary:
        # Lemma 7.9: controlled C, B and A on the last control, interleaved
        # with X on the target controlled by the remaining controls.
        cnot_seq = decompose_fc_x(controls[:-1], target, free_qubits=[controls[-1]])
        result = []
        result += decompose_fc_gate(C, [controls[-1]], target)
        result += cnot_seq
        result += decompose_fc_gate(B, [controls[-1]], target)
        result += cnot_seq
        result += decompose_fc_gate(A, [controls[-1]], target)
        return result

    # General unitary: with V = U^(matrix_pow/2), apply controlled-V and
    # controlled-V^-1 from the last control around multi-controlled X gates,
    # then V controlled by the remaining controls.
    # NOTE(review): the lemma number was left blank in the original comment;
    # likely Lemma 7.5 of the paper - confirm.
    cnot_seq = decompose_fc_x(controls[:-1], controls[-1], free_qubits=free_qubits+[target])
    part1 = decompose_fc_gate(U, [controls[-1]], target, matrix_pow=0.5*matrix_pow)
    part2 = decompose_fc_gate(U, [controls[-1]], target, matrix_pow=-0.5*matrix_pow)
    part3 = decompose_fc_gate(U, controls[:-1], target, free_qubits=free_qubits+[controls[-1]], matrix_pow=0.5*matrix_pow)
    return part1 + cnot_seq + part2 + cnot_seq + part3

verify_decomposition(random_unitary())

m=1 OK gates=8, z=71, k1=2.666667, k2=7.920792
m=2 OK gates=26, z=53, k1=6.500000, k2=6.483791
m=3 OK gates=44, z=51, k1=8.800000, k2=4.883463
m=4 OK gates=68, z=65, k1=11.333333, k2=4.247345
m=5 OK gates=104, z=95, k1=14.857143, k2=4.158337
m=6 OK gates=148, z=141, k1=18.500000, k2=4.109969
m=7 OK gates=216, z=203, k1=24.000000, k2=4.407264
m=8 OK gates=284, z=281, k1=28.400000, k2=4.436807
m=9 OK gates=384, z=375, k1=34.909091, k2=4.740156
m=10 OK gates=476, z=485, k1=39.666667, k2=4.759524
m=11 OK gates=608, z=611, k1=46.769231, k2=5.024378
m=12 OK gates=724, z=753, k1=51.714286, k2=5.027429
m=13 OK gates=888, z=911, k1=59.200000, k2=5.254127
m=14 OK gates=1028, z=1085, k1=64.250000, k2=5.244630
m=15 OK gates=1221, z=1275, k1=71.823529, k2=5.426425
m=16 OK gates=1383, z=1481, k1=76.833333, k2=5.402133
m=17 OK gates=1606, z=1703, k1=84.526316, k2=5.556901
m=18 OK gates=1790, z=1941, k1=89.500000, k2=5.524521
m=19 OK gates=2046, z=2195, k1=97.428571, k2=5.667433


In [121]:

def verify_decomposition(U):
    """Run ``decompose_fc_gate`` for 1 to 19 controls and print statistics.

    For each number of controls ``m`` the gate is decomposed on ``m + 1`` line
    qubits, the last one being the target. For ``m <= 9`` the circuit unitary
    is compared with the expected matrix: the identity with ``U`` placed in
    the bottom-right 2x2 block. One line is printed per ``m`` with the gate
    count, the value ``8*m**2 - 42*m + 105`` (labelled ``z``) and the gate
    count divided by ``m + 2`` and by ``m*m + 0.01``.

    Args:
        U: 2x2 unitary matrix to decompose.
    """
    for num_controls in range(1, 20):
        wires = cirq.LineQubit.range(num_controls + 1)
        ops = decompose_fc_gate(U, wires[:-1], wires[-1])

        # TODO: verify that all gates are either CNOT or Ry/Rz acting on last qubit.

        if num_controls <= 9:
            dim = 2 ** (num_controls + 1)
            expected_matrix = np.eye(dim, dtype=np.complex128)
            expected_matrix[dim - 2:dim, dim - 2:dim] = U
            result_matrix = cirq.Circuit(ops).unitary()

            assert np.allclose(expected_matrix, result_matrix)

        gate_count = len(ops)
        stats = (
            num_controls,
            gate_count,
            8 * num_controls ** 2 - 42 * num_controls + 105,
            gate_count / (num_controls + 2),
            gate_count / (num_controls * num_controls + 0.01),
        )
        print("m=%d OK gates=%d, z=%d, k1=%f, k2=%f" % stats)
      
    
def random_unitary():
    """Return a fixed pseudo-random 2x2 unitary matrix.

    The matrix is drawn from ``scipy.stats.unitary_group`` with seed 555, so
    every call returns the same matrix.
    """
    # Seeding through ``random_state`` keeps the draw reproducible without
    # reseeding NumPy's global random generator as a side effect.
    return unitary_group.rvs(2, random_state=555)
    
def random_special_unitary():
    """Return a fixed pseudo-random 2x2 special unitary matrix.

    A unitary is drawn from ``scipy.stats.unitary_group`` with seed 55 and
    rescaled to determinant 1, so every call returns the same matrix.
    """
    # Seeding through ``random_state`` keeps the draw reproducible without
    # reseeding NumPy's global random generator as a side effect.
    U = unitary_group.rvs(2, random_state=55)
    # For a 2x2 matrix det(U / s) = det(U) / s**2, so dividing by
    # s = sqrt(det(U)) normalizes the determinant to 1.
    return U / np.sqrt(np.linalg.det(U))



