In [1]:
import numpy as np
import pickle as pkl

In [2]:
def save(filename, obj):
    """Serialize ``obj`` with pickle and write it to ``filename``.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten.
    """
    with open(filename, 'wb') as handle:
        handle.write(pkl.dumps(obj))

def load(filename):
    """Read ``filename`` and return the first pickled object stored in it.

    NOTE: unpickling can execute arbitrary code; only use this on files
    that come from a trusted source.
    """
    with open(filename, 'rb') as handle:
        payload = handle.read()
    return pkl.loads(payload)

In [None]:
def load_data(fname):
    """Load a comma-separated image dataset.

    Every row of the file holds a label in the first column followed by
    784 pixel values, which are reshaped into a 28x28 image.

    Parameters
    ----------
    fname : path or file-like accepted by ``np.loadtxt``.

    Returns
    -------
    X : ndarray of shape (n_samples, 28, 28)
        The images.
    Y : ndarray of shape (n_samples,)
        The labels (as floats, because the whole file is parsed as floats).
    """
    # ndmin=2 keeps the row axis even when the file contains a single sample;
    # without it loadtxt returns a 1-D array and the column slicing fails.
    data = np.loadtxt(fname, delimiter=',', ndmin=2)
    X, Y = data[:, 1:].reshape(-1, 28, 28), data[:, 0]
    return X, Y

In [4]:
# Addition of matrix product states (MPS) with open boundary conditions

def dsum(A, B):
    """Direct sum of two rank-3 tensors over their first and last axes.

    ``A`` is placed in the upper-left block and ``B`` in the lower-right
    block of the (first axis, last axis) plane; the middle axis is shared
    and must have the same size in both tensors. The off-diagonal blocks
    are zero.

    Returns an array of shape
    ``(A.shape[0] + B.shape[0], d, A.shape[2] + B.shape[2])``.
    """
    d = A.shape[1]
    assert d == B.shape[1]

    # At least float64 (as before), but promoted to complex when an input is
    # complex so that imaginary parts are not silently discarded.
    dtype = np.result_type(A, B, np.float64)

    out = np.zeros((A.shape[0] + B.shape[0], d, A.shape[2] + B.shape[2]), dtype=dtype)
    out[:A.shape[0], :, :A.shape[2]] = A
    out[A.shape[0]:, :, A.shape[2]:] = B

    return out


def row(A, B):
    """Concatenate two rank-3 tensors along their last axis.

    Both tensors must have a first axis of size 1 and the same middle-axis
    size. This is the form used for the first site when adding two MPS.

    Returns an array of shape ``(1, d, A.shape[2] + B.shape[2])``.
    """
    assert A.shape[0] == B.shape[0] == 1
    d = A.shape[1]
    assert d == B.shape[1]

    # At least float64 (as before), but promoted to complex when an input is
    # complex so that imaginary parts are not silently discarded.
    dtype = np.result_type(A, B, np.float64)

    out = np.zeros((1, d, A.shape[2] + B.shape[2]), dtype=dtype)
    out[:, :, :A.shape[2]] = A
    out[:, :, A.shape[2]:] = B

    return out


def col(A, B):
    """Stack two rank-3 tensors along their first axis.

    Both tensors must have a last axis of size 1 and the same middle-axis
    size. This is the form used for the last site when adding two MPS.

    Returns an array of shape ``(A.shape[0] + B.shape[0], d, 1)``.
    """
    assert A.shape[2] == B.shape[2] == 1
    d = A.shape[1]
    assert d == B.shape[1]

    # At least float64 (as before), but promoted to complex when an input is
    # complex so that imaginary parts are not silently discarded.
    dtype = np.result_type(A, B, np.float64)

    out = np.zeros((A.shape[0] + B.shape[0], d, 1), dtype=dtype)
    out[:A.shape[0], :, :] = A
    out[A.shape[0]:, :, :] = B

    return out


def add(mps_a, mps_b):
    """Add two MPS with open boundary conditions.

    Both inputs are sequences of rank-3 tensors of equal length. The first
    site is combined with ``row``, the last with ``col`` and every site in
    between with ``dsum``, so the bond dimensions of the result are the
    sums of the input bond dimensions.

    A single-site MPS has no internal bond; its tensors are simply added
    element-wise and the result again has exactly one tensor.

    Returns a new list of tensors of the same length as the inputs.
    """
    L = len(mps_a)
    assert len(mps_b) == L

    if L == 1:
        # The only site is first and last at once; row + col would wrongly
        # produce a two-site result.
        A, B = mps_a[0], mps_b[0]
        assert A.shape[0] == B.shape[0] == 1
        assert A.shape[2] == B.shape[2] == 1
        assert A.shape[1] == B.shape[1]
        # Same dtype rule as the multi-site path: at least float64.
        return [(A + B).astype(np.result_type(A, B, np.float64), copy=False)]

    first = row(mps_a[0], mps_b[0])
    bulk = [dsum(mps_a[i], mps_b[i]) for i in range(1, L - 1)]
    last = col(mps_a[L - 1], mps_b[L - 1])

    return [first] + bulk + [last]


def compress(mps, χ_max, normalize=False):
    """Truncate the bond dimensions of an MPS with a single left-to-right SVD sweep.

    Each tensor is reshaped to a matrix, decomposed with an SVD and
    truncated to at most ``χ_max`` singular values. The left factor stays
    on the site; the singular values and right factor are absorbed into
    the next tensor. What is left after the last site is a single scalar
    factor, returned as ``norm``.

    Parameters
    ----------
    mps : list of rank-3 arrays ``(left bond, physical, right bond)``.
        The list itself is not modified.
    χ_max : int
        Maximum number of singular values kept on every bond.
    normalize : bool
        If False, ``norm`` is spread evenly over all tensors (each is
        multiplied by ``norm**(1/L)``); if True the tensors are returned
        without that factor.

    Returns
    -------
    (mps, norm) : the compressed list of tensors and the leftover factor
    as a non-negative array of shape ``(1, 1, 1)``.
    """
    L = len(mps)
    norm = np.array([[[1]]])
    # The extra trailing tensor collects whatever the sweep pushes past the last site.
    mps = mps + [norm]
    for n in range(L):

        # SVD and truncation
        vL, d, vR = mps[n].shape
        M = mps[n].reshape(vL*d, vR)
        U, S, Vh = np.linalg.svd(M, full_matrices=False)
        U, S, Vh = U[:, :χ_max], S[:χ_max], Vh[:χ_max, :]
        χ = len(S)
        mps[n] = U.reshape(vL, d, χ)

        # absorb SVD tensors into next tensor
        vL, d, vR = mps[n+1].shape
        M = mps[n+1].reshape(vL, d*vR)
        mps[n+1] = ((S[:, None] * Vh) @ M).reshape(χ, d, vR)

    mps, norm = mps[:-1], mps[-1]

    # The sign (or complex phase) of SVD factors is not fixed by NumPy, so the
    # leftover factor may come out negative. A fractional power of a negative
    # real number is NaN, hence move the phase into the last tensor and keep
    # only the magnitude as the norm.
    magnitude = np.abs(norm)
    if L > 0 and magnitude.item() > 0:
        phase = norm.item() / magnitude.item()
        if phase != 1:
            mps[-1] = mps[-1] * phase
    norm = magnitude

    if not normalize and L > 0:
        scale = norm**(1/L)
        mps = [M * scale for M in mps]

    return mps, norm

In [None]:
# MPS model functions

def mean_filter(x):
    """Downsample an image by averaging non-overlapping 2x2 blocks.

    ``x`` must hold 784 values (it is reshaped to 14 x 2 x 14 x 2); the
    result has shape ``(14, 14)``.
    """
    # Axes after the reshape: (block row, row in block, block column, column in block).
    tiles = x.reshape(14, 2, 14, 2)
    # Move the two within-block axes to the end so they can be averaged together.
    tiles = tiles.transpose(0, 2, 1, 3)
    return tiles.mean(axis=(2, 3))


def image_to_mps(x):
    """Encode an image as a product-state MPS, one site per pixel.

    Each pixel value ``p`` becomes a ``(1, 2, 1)`` float tensor holding
    ``cos(pi/2 * p/255)`` and ``sin(pi/2 * p/255)``. Pixels are visited in
    the order given by ``x.flatten()``.

    Returns a list with one tensor per pixel.
    """
    def encode(pixel):
        angle = np.pi/2*pixel/255
        site = np.zeros([1, 2, 1], float)
        site[0, :, 0] = np.cos(angle), np.sin(angle)
        return site

    return [encode(pixel) for pixel in x.flatten()]


def inner_product(mps_a: list, mps_b: list):
    """Contract two MPS of equal length into the scalar <a|b>.

    The tensors of ``mps_a`` are complex-conjugated. The first tensors of
    both MPS must have a left bond of size 1, and the contraction must end
    in a single number, which is returned as a Python scalar.
    """
    n_sites = len(mps_a)
    assert n_sites == len(mps_b)

    # First site: contract the physical legs and drop the two trivial left bonds,
    # leaving an environment indexed by (right bond of b, right bond of a).
    env = np.tensordot(mps_b[0], mps_a[0].conj(), axes=[1, 1])
    env = env.squeeze(axis=(0, 2))

    for site in range(1, n_sites):
        ket = mps_b[site]
        bra = mps_a[site].conj()

        # Attach the ket tensor to the b-bond, then close with the conjugated
        # bra tensor over the a-bond and the physical leg.
        env = np.tensordot(env, ket, axes=[0, 0])
        env = np.tensordot(env, bra, axes=[[0, 1], [0, 1]])

    return env.item()


def train(X, Y, χ_max=10):
    """Build one MPS per digit class by summing the encoded training images.

    For every digit 0-9 the samples ``X[Y == digit]`` are mean-filtered,
    encoded as product states and added up one by one. Once the running
    sum holds ``χ_max`` images it is compressed back to bond dimension
    ``χ_max`` after every further addition. The final state of each class
    is always compressed with ``normalize=True``, so all class states are
    normalized, including classes with fewer than ``χ_max`` samples.

    Parameters
    ----------
    X : array of images, indexable by a boolean mask over samples.
    Y : array of labels, compared against the integers 0-9.
    χ_max : int
        Maximum bond dimension of the class states.

    Returns
    -------
    list of 10 MPS (each a list of tensors), indexed by digit.

    Raises
    ------
    ValueError
        If a digit has no training samples.
    """
    model = list()
    for digit in range(10):

        samples = X[Y == digit]
        # Taken once per class; re-masking X inside the loop would copy the
        # whole class subset for every single sample.
        n_samples = len(samples)
        if n_samples == 0:
            # Without this the loop body never runs and the state of the
            # previous digit would be appended again (or Ψ would be undefined).
            raise ValueError(f"no training samples for digit {digit}")

        for i, x in enumerate(samples):

            φ = image_to_mps(mean_filter(x))
            Ψ = add(Ψ, φ) if i > 0 else φ

            is_last = i + 1 == n_samples
            if is_last or i > χ_max - 2:
                # Intermediate compressions keep the norm inside the tensors;
                # only the final state of the class is normalized.
                Ψ, _ = compress(Ψ, χ_max, normalize=is_last)

        model.append(Ψ)

    return model


def project(model, x):
    """Return the absolute overlap of image ``x`` with every state in ``model``.

    The image is mean-filtered and encoded as an MPS once, then contracted
    with each model state via ``inner_product``. The result is a list with
    one non-negative number per model state, in model order.
    """
    # The encoding does not depend on the model state, so build it a single time.
    φ = image_to_mps(mean_filter(x))
    return [abs(inner_product(ψ, φ)) for ψ in model]


def predict(model, x):
    """Return the index of the model state that overlaps most with image ``x``."""
    return np.argmax(project(model, x))


def accuracies(model, X, Y):
    """Compute the classification accuracy separately for every digit 0-9.

    For each digit, all samples ``X[Y == digit]`` are projected onto the
    model states and the state with the largest projection is taken as the
    prediction. The accuracy is the fraction of those samples predicted as
    ``digit``.

    Returns a list of 10 accuracies indexed by digit. A digit without any
    samples yields ``nan`` instead of raising.
    """
    per_digit = list()
    for digit in range(10):

        samples = X[Y == digit]
        if len(samples) == 0:
            # Nothing to evaluate: argmax over an empty projection list would fail.
            per_digit.append(np.nan)
            continue

        projections = [project(model, x) for x in samples]
        predictions = np.argmax(projections, axis=1)
        per_digit.append(np.mean(predictions == digit))

    return per_digit

In [10]:
# Load the training and test sets; the paths are relative to the working directory.
X, Y = load_data('data/mnist_train.csv')
X_test, Y_test = load_data('data/mnist_test.csv')

In [11]:
# Build one MPS per digit class from the training set, compressing with χ_max = 10.
model = train(X, Y, χ_max=10)

In [12]:
# Store the result under its own name: assigning it to `accuracies` would
# overwrite the function and make this cell fail when it is run again.
test_accuracies = accuracies(model, X_test, Y_test)
print(test_accuracies)

[0.7979591836734694, 0.9682819383259912, 0.5959302325581395, 0.6356435643564357, 0.659877800407332, 0.7130044843049327, 0.9039665970772442, 0.566147859922179, 0.7135523613963038, 0.7611496531219029]
