# Schmidt decomposition of a 4x4 matrix

Given a 4x4 unitary matrix, represent it as a linear combination of tensor products of 2x2 unitary matrices.

This problem can be solved using the KAK decomposition, as described in [this paper](https://arxiv.org/ftp/arxiv/papers/1006/1006.3412.pdf).

In [1]:
import numpy as np
import cirq
from scipy.stats import unitary_group

In [2]:
# Single-qubit Pauli matrices in the fixed order I, X, Y, Z. The order matters:
# schmidt_decomposition indexes this list by the position of each KAK term.
# Uses the public cirq.unitary protocol rather than the private `_unitary_` method.
PAULI_BASIS = [cirq.unitary(op) for op in (cirq.I, cirq.X, cirq.Y, cirq.Z)]

def is_unitary(U):
    """Return True if U is a square matrix with U @ U^† ≈ I.

    For a square matrix, U U^† = I alone implies U^† U = I, so a single product
    suffices. Non-square inputs are rejected explicitly, because a wide matrix
    with orthonormal rows also satisfies U U^† = I without being unitary.

    Args:
        U: 2-D array-like to check.

    Returns:
        True if U is square and U @ U^† is close to the identity under the
        default ``np.allclose`` tolerances, False otherwise.
    """
    U = np.asarray(U)
    if U.ndim != 2 or U.shape[0] != U.shape[1]:
        return False
    return np.allclose(U @ U.conj().T, np.eye(U.shape[0]))

def schmidt_decomposition(U, atol=1e-9):
    """Calculate the operator Schmidt decomposition of a 4x4 unitary matrix.

    Represents the unitary matrix U as a linear combination of tensor products
    of 2x2 unitaries:
        U = sum_i z_i * A_i ⊗ B_i,
    where A_i, B_i are 2x2 unitary matrices and z_i are positive reals with
    sum_i |z_i|^2 = 1. The sum has 1, 2, or 4 terms.

    Method: the KAK decomposition factors U (up to the global phase g) into
    single-qubit operations around an interaction exp(i(x·XX + y·YY + z·ZZ)).
    That interaction expands over the Paulis P ∈ {I, X, Y, Z} as
    sum_P w_P · P ⊗ P. Each Pauli P therefore contributes the term
    w_P · (g · A0 P B0) ⊗ (A1 P B1).

    Args:
        U: Unitary matrix to decompose, shape (4, 4).
        atol: Terms whose coefficient has absolute value at most `atol` are
            dropped. Defaults to 1e-9.

    Returns:
        Dict with keys `first_qubit_ops`, `second_qubit_ops` and `koeffs`,
        containing the list of A_i, the list of B_i and the array of z_i
        respectively.

    Raises:
        ValueError: If U is not a 4x4 unitary matrix. (Explicit checks instead of
            `assert`, so validation still happens under `python -O`.)
    """
    U = np.asarray(U)
    if U.shape != (4, 4):
        raise ValueError(f"Expected a 4x4 matrix, got shape {U.shape}.")
    if not is_unitary(U):
        raise ValueError("Matrix is not unitary.")

    kak = cirq.kak_decomposition(U)
    # The closed-form coefficients below are written in terms of
    # c_k = 2 * (k-th KAK interaction coefficient).
    c1, c2, c3 = [2 * c for c in kak.interaction_coefficients]
    B0, B1 = kak.single_qubit_operations_before
    A0, A1 = kak.single_qubit_operations_after
    g = kak.global_phase

    # Complex coefficients w_P of the interaction term, in Pauli order I, X, Y, Z.
    plus = np.exp(0.5j * c1)
    minus = np.exp(-0.5j * c1)
    half_diff = 0.5 * (c3 - c2)
    half_sum = 0.5 * (c3 + c2)
    z = [
        0.5 * (plus * np.cos(half_diff) + minus * np.cos(half_sum)),
        0.5 * (plus * np.cos(half_diff) - minus * np.cos(half_sum)),
        -0.5j * (plus * np.sin(half_diff) - minus * np.sin(half_sum)),
        0.5j * (plus * np.sin(half_diff) + minus * np.sin(half_sum)),
    ]

    first_qubit_ops, second_qubit_ops, koeffs = [], [], []
    for i, w in enumerate(z):
        magnitude = abs(w)
        # Throw away (numerically) zero terms.
        if magnitude <= atol:
            continue
        pauli = PAULI_BASIS[i]
        # Absorb the complex phase of the coefficient (and the global phase g)
        # into the first-qubit factor so that the coefficient is real positive.
        # A unit-modulus scalar times a unitary is still unitary.
        phase = w / magnitude
        first_qubit_ops.append(phase * (g * A0 @ pauli @ B0))
        second_qubit_ops.append(A1 @ pauli @ B1)
        koeffs.append(magnitude)

    return {
        'first_qubit_ops': first_qubit_ops,
        'second_qubit_ops': second_qubit_ops,
        'koeffs': np.array(koeffs),
    }

In [3]:
def test_schmidt_decomposition(U):
    """Assert that `schmidt_decomposition(U)` is a valid decomposition of U.

    Checks that the number of terms is 1, 2 or 4 and that both operator lists
    have that length. Checks that every factor is unitary and every coefficient
    is real and non-negative. Checks that the coefficients have unit Euclidean
    norm and that recombining the terms reproduces U.
    """
    result = schmidt_decomposition(U)
    firsts = result['first_qubit_ops']
    seconds = result['second_qubit_ops']
    coeffs = result['koeffs']
    terms = len(coeffs)

    assert terms in (1, 2, 4)
    assert len(firsts) == terms
    assert len(seconds) == terms
    for first, second, coeff in zip(firsts, seconds, coeffs):
        assert is_unitary(first)
        assert is_unitary(second)
        # A complex number equals its modulus only if it is real and >= 0.
        assert np.allclose(coeff, np.abs(coeff))
    assert np.allclose(np.linalg.norm(coeffs), 1)

    restored = sum(c * np.kron(f, s) for f, s, c in zip(firsts, seconds, coeffs))
    assert np.allclose(U, restored)

CNOT = np.array([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 0, 1],
    [0, 0, 1, 0],
])
test_schmidt_decomposition(CNOT)

# Tensor products of two Paulis.
for mx1 in PAULI_BASIS:
    for mx2 in PAULI_BASIS:
        test_schmidt_decomposition(np.kron(mx1, mx2))

# Random trials: tensor products of random 2x2 unitaries, and random 4x4
# unitaries. A seeded generator makes any failing matrix reproducible.
rng = np.random.default_rng(0)
for _ in range(50):
    U = np.kron(
        unitary_group.rvs(2, random_state=rng),
        unitary_group.rvs(2, random_state=rng),
    )
    test_schmidt_decomposition(U)
    U = unitary_group.rvs(4, random_state=rng)
    test_schmidt_decomposition(U)

print("OK")

OK
