# Schmidt decomposition of a vector

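The Schmidt decomposition can be computed with the singular value decomposition (SVD). Reshape the vector $v$ of length $d_1 d_2$ into a $d_1 \times d_2$ matrix $A$ with $A_{ij} = v_{i d_2 + j}$ and factor it as $A = U \Sigma V^\dagger$. Then $v = \sum_k \sigma_k \, u_k \otimes w_k$, where $u_k$ is the $k$-th column of $U$, $w_k$ is the $k$-th row of $V^\dagger$, and the singular values $\sigma_k$ are the Schmidt coefficients.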

In [1]:
import numpy as np
import cirq

# Returns the Schmidt decomposition of a vector, s.t. vec = sum_i k_i * kron(a_i, b_i).
def schmidt_decompose(vec, dim1, dim2, tol=1e-9):
    """Schmidt decomposition of a unit vector.

    Reshapes ``vec`` (length ``dim1 * dim2``) row-major into a ``dim1 x dim2``
    matrix ``A`` and takes its SVD ``A = U @ diag(k) @ Vh``. Because
    ``vec[i * dim2 + j] == A[i, j]``, this yields
    ``vec == sum_i k[i] * np.kron(U[:, i], Vh[i, :])``.

    Args:
        vec: unit-norm 1-D array-like (real or complex) of length ``dim1 * dim2``.
        dim1: dimension of the first factor.
        dim2: dimension of the second factor.
        tol: singular values smaller than this are treated as zero and dropped.

    Returns:
        ``(a, b, k)``: ``a`` is a list of length-``dim1`` arrays (columns of U),
        ``b`` is a list of length-``dim2`` arrays (rows of Vh), and ``k`` is the
        array of retained singular values (Schmidt coefficients), largest first.

    Raises:
        AssertionError: if ``vec`` has the wrong shape or is not unit norm
            (note: these checks are skipped when Python runs with ``-O``).
    """
    vec = np.asarray(vec)
    assert vec.shape == (dim1 * dim2, )
    assert np.allclose(np.linalg.norm(vec), 1), "Not unit vector."
    A = vec.reshape(dim1, dim2)
    # Thin SVD: at most min(dim1, dim2) terms can appear, so the full square
    # U (dim1 x dim1) and Vh (dim2 x dim2) would only waste memory/time.
    u, k, vh = np.linalg.svd(A, full_matrices=False)
    # numpy returns singular values in descending order, so the kept ones are a prefix.
    size = int(np.count_nonzero(k >= tol))
    a = [u[:, i] for i in range(size)]
    b = [vh[i, :] for i in range(size)]
    return a, b, k[:size]

In [2]:

def test_schmidt_decompose(vec):
    """Check schmidt_decompose on every bipartition dim1 x dim2 of len(vec).

    For each divisor dim1 of the vector length, asserts that the Schmidt
    coefficients have unit 2-norm, that both lists of Schmidt vectors are
    orthonormal, and that the decomposition reproduces the input vector.
    """
    def _orthonormal(rows):
        # Rows are orthonormal iff their Gram matrix M @ M^dagger is the identity.
        mat = np.array(rows)
        gram = mat @ mat.T.conj()
        return np.allclose(gram, np.eye(len(rows)))

    target = np.array(vec)
    total = len(target)
    splits = [(d, total // d) for d in range(1, total + 1) if total % d == 0]
    for left, right in splits:
        a, b, k = schmidt_decompose(target, left, right)

        assert np.allclose(np.linalg.norm(k), 1)
        assert _orthonormal(a)
        assert _orthonormal(b)

        rebuilt = sum(coef * np.kron(u, w) for coef, u, w in zip(k, a, b))
        assert np.allclose(target, rebuilt)

# Deterministic cases: a trivial length-1 vector, a product state
# (Schmidt rank 1) and a Bell state (Schmidt rank 2 for the 2 x 2 split).
test_schmidt_decompose([1.0])
test_schmidt_decompose([0.5, 0.5, 0.5, 0.5])
test_schmidt_decompose(np.array([1.0, 0.0, 0.0, 1.0]) / np.sqrt(2))

# Random complex unit vectors of length 8 (splits 1x8, 2x4, 4x2, 8x1).
# A fixed seed makes any failing case reproducible; standard_normal gives
# components of both signs, unlike random(), which only samples [0, 1).
rng = np.random.default_rng(seed=0)
for _ in range(100):
    vec = rng.standard_normal(8) + 1j * rng.standard_normal(8)
    vec /= np.linalg.norm(vec)
    test_schmidt_decompose(vec)
print("OK")

OK
