In [1]:
import numpy as np

In [2]:
def load_data(fname):
    data = np.loadtxt(fname, delimiter=',')
    X, Y = data[:, 1:].reshape(-1, 28, 28), data[:, 0]
    return X, Y

In [3]:
def mean_filter(x):
    return x.reshape(14, 2, 14, 2).transpose(0, 2, 1, 3).mean(axis=(2, 3))


def image_to_mps(x):
    mps = []
    for pixel in x.flatten():
        
        M = np.zeros([1, 2, 1], float)
        M[0, 0, 0] = np.cos(np.pi/2*pixel/255)
        M[0, 1, 0] = np.sin(np.pi/2*pixel/255)
        mps.append(M)

    return mps

In [4]:
def dsum(A, B):
    d = A.shape[1]
    assert d == B.shape[1]
    
    dsum = np.zeros((A.shape[0] + B.shape[0], d, A.shape[2] + B.shape[2]))
    dsum[:A.shape[0], :, :A.shape[2]] = A
    dsum[A.shape[0]:, :, A.shape[2]:] = B

    return dsum


def row(A, B):
    assert A.shape[0] == B.shape[0] == 1
    d = A.shape[1]
    assert d == B.shape[1]
    
    row = np.zeros((1, d, A.shape[2] + B.shape[2]))
    row[:, :, :A.shape[2]] = A
    row[:, :, A.shape[2]:] = B

    return row


def col(A, B):
    assert A.shape[2] == B.shape[2] == 1
    d = A.shape[1]
    assert d == B.shape[1]
    
    col = np.zeros((A.shape[0] + B.shape[0], d, 1))
    col[:A.shape[0], :, :] = A
    col[A.shape[0]:, :, :] = B

    return col


def add(mps_a, mps_b):
    L = len(mps_a)
    assert len(mps_b) == L
    
    mps = [row(mps_a[0], mps_b[0])] + [dsum(mps_a[i], mps_b[i]) for i in range(1, L - 1)] + [col(mps_a[L - 1], mps_b[L - 1])]

    return mps

In [5]:
def inner_product(mps_a: list, mps_b: list):

    L = len(mps_a)
    assert L == len(mps_b)

    t = np.tensordot(mps_b[0], mps_a[0].conj(), axes=[1, 1]) # vL [j] vR, vL* [j*] vR*
    t = t.squeeze(axis=(0, 2))  # vR vR*

    for n in range(1, L):

        t = np.tensordot(t, mps_b[n], axes=[0, 0]) # [vR] vR*, [vL] j vR
        t = np.tensordot(t, mps_a[n].conj(), axes=[[0, 1], [0, 1]]) # [vR*] [j] vR, [vL*] [j*] vR*

    return t.item()

In [10]:
def compress(mps, χ_max):
    
    L = len(mps)
    mps = mps + [np.array([[[1]]])]
    for n in range(L):
        
        # SVD and truncation
        vL, d, vR = mps[n].shape
        M = mps[n].reshape(vL*d, vR)
        U, S, Vh = np.linalg.svd(M, full_matrices=False)
        U, S, Vh = U[:, :χ_max], S[:χ_max], Vh[:χ_max, :]
        χ = len(S)
        mps[n] = U.reshape(vL, d, χ)

        # absorb SVD tensors into next tensor
        vL, d, vR = mps[n+1].shape
        M = mps[n+1].reshape(vL, d*vR)
        mps[n+1] = ((S[:, None] * Vh)@M).reshape(χ, d, vR)

    norm = mps[-1]
    mps = [M * norm**(1/L) for M in mps[:-1]]

    return mps, norm

In [7]:
X, Y = load_data('data/mnist_train.csv')
X_test, Y_test = load_data('data/mnist_test.csv')

In [8]:
digit = 2
X_mps = [image_to_mps(mean_filter(x)) for x in X[Y == digit]]

In [11]:
χ_max = 10
Ψ = X_mps[0]
for i, x_mps in enumerate(X_mps[1:]):
    Ψ = add(Ψ, x_mps)
    if i > χ_max - 2:
        Ψ, norm = compress(Ψ, χ_max)
    if i == len(X_mps)-1:
        L = 196
        Ψ / norm**(1/L)

In [17]:
X_test_mps = list()
for digit in range(10):
    X_test_mps.append([image_to_mps(mean_filter(x)) for x in X_test[Y_test == digit]])

In [21]:
mean_overlaps = []
for digit in range(10):
 print(digit)
 mean_overlaps.append(np.mean([abs(inner_product(Ψ, x)) for x in X_test_mps[digit]]))

0
1
2
3
4
5
6
7
8
9


In [24]:
for mean_overlap in mean_overlaps:
    print(mean_overlap)

2.4262893272384197e-10
2.977318722909857e-10
8.198235687279568e-06
3.177051687172096e-10
1.5774381819483206e-10
1.9570984332002365e-10
2.4718265024979617e-11
7.342532837739642e-10
1.2636397427327328e-10
1.9051323873635279e-10


In [23]:
inner_product(Ψ, Ψ)

1.0012985714637612

In [25]:
len(X[Y==2]), len(X_test[Y_test==2])

(5958, 1032)