## AMP FOR SPIKED-TENSOR-MODELS

This is a PyTorch implementation of the AMP algorithm for the mixed spiked tensor model (2+3), as described in https://arxiv.org/pdf/1812.09066.pdf. It runs on CPU by default, but porting it to GPU is trivial.

Here is the algorithm, written with a Gauss-Bernoulli prior (which reduces to a Gaussian prior when rho = 1):

In [1]:
import torch
from math import sqrt

# Priors
def prior_gb(A, B, prmts):
    """Return the denoiser output and its averaged variance for a
    Gauss-Bernoulli prior.

    The prior puts weight ``rho`` on a Gaussian with mean ``mu`` and variance
    ``sig`` and weight ``1 - rho`` on zero; it is combined with a quadratic
    field described by ``A`` (curvature) and ``B`` (linear term).

    Parameters
    ----------
    A : torch.Tensor
        Curvature of the field (must be a tensor, since it goes through
        ``torch.sqrt``).
    B : torch.Tensor
        Linear term of the field, one entry per coordinate.
    prmts : tuple
        ``(rho, mu, sig)`` as described above.

    Returns
    -------
    a : torch.Tensor
        Posterior mean, with the broadcast shape of ``A`` and ``B``.
    c : torch.Tensor
        Posterior variance averaged over all coordinates (0-d tensor).
    """
    rho, mu, sig = prmts

    # Mean and variance of the Gaussian (non-zero) component given the field.
    gain = 1. + A * sig
    m = (B * sig + mu) / gain
    v = sig / gain

    # Posterior probability that a coordinate comes from the Gaussian component.
    log_ratio = -.5 * m ** 2 / v + .5 * mu ** 2 / sig
    keep = rho + (1 - rho) * torch.sqrt(gain) * torch.exp(log_ratio)
    p_s = rho / keep

    # Moments of the resulting two-component mixture.
    a = p_s * m
    c = p_s * v + p_s * (1. - p_s) * m ** 2
    return a, torch.mean(c)


# AMP Solver
def amp(Y, T, Delta2, Delta3,
        prior=prior_gb, prior_prmts=None, true_coef=None,
        max_iter=250, tol=1e-13, verbose=1):
    """Run damped AMP iterations for the mixed (order 2 + order 3) spiked model.

    Parameters
    ----------
    Y : torch.Tensor, shape (N, N)
        Order-2 (matrix) observation.
    T : torch.Tensor
        Order-3 (tensor) observation; expected to be (N, N, N) so that
        ``a^T @ T @ a`` yields N values.
    Delta2, Delta3 : float
        Noise variances of the matrix and tensor channels.
    prior : callable
        Called as ``prior(A, B, prior_prmts)``; must return the new estimate
        (N, 1) and its variance (scalar, or a tensor that is averaged here).
    prior_prmts : object
        Passed through unchanged to ``prior``. The default ``prior_gb``
        expects a ``(rho, mu, sig)`` tuple, so ``None`` will not work with it.
    true_coef : torch.Tensor, optional
        Ground truth (N, 1); only used to report errors when ``verbose > 0``.
    max_iter : int
        Maximum number of iterations.
    tol : float
        Stop as soon as the mean squared change of the estimate is below it.
    verbose : int
        If > 0, print a header and one line of metrics per iteration.

    Returns
    -------
    torch.Tensor
        The final (damped) estimate, shape (N, 1), with the dtype and device
        of ``Y``.
    """

    N, _ = Y.shape

    # Allocate the state with Y's dtype/device: the matrix products below
    # require matching operands, so default-typed tensors would fail for a
    # float64 or non-default-device Y.
    opts = dict(dtype=Y.dtype, device=Y.device)
    a = torch.randn(N, 1, **opts)
    a_old = torch.zeros(N, 1, **opts)
    c = 1.
    S2 = sqrt(2)
    SN = sqrt(N)
    K = S2 / (Delta3 * N)

    # Error metrics are only ever printed, so prepare them only when needed;
    # the ground-truth outer product does not change between iterations.
    report_truth = verbose > 0 and true_coef is not None
    if report_truth:
        true_outer = true_coef @ torch.transpose(true_coef, 0, 1)

    if verbose > 0:
        print("time ; convergence mse, Matrix-mmse ")
    for t in range(max_iter):
        # Messages/estimates on x from likelihood
        q = torch.mean(a * a)

        # Matrix channel: linear term with its memory (Onsager-type)
        # correction on the previous estimate, and the curvature.
        B2 = (1. / (SN * Delta2)) * Y @ a - (c / Delta2) * a_old
        A2 = (1. / Delta2) * q

        # Tensor channel: same structure, quadratic in the current estimate.
        B3 = (K * torch.transpose(a, 0, 1) @ T @ a).view(N, -1)
        B3 = B3 - (2 / Delta3) * c * q * a_old
        A3 = (1. / Delta3) * q ** 2

        a_old = torch.clone(a)
        a, cc = prior(A2 + A3, B2 + B3, prior_prmts)
        # A custom prior may return per-coordinate variances; average them.
        c = torch.mean(cc)
        # Damping: move only halfway towards the new estimate.
        a = 0.5 * a + 0.5 * a_old

        # Compute metrics
        conv = torch.mean((a - a_old) ** 2)
        if verbose > 0:
            if report_truth:
                mse = torch.mean((a - true_coef) ** 2)
                Mmse = torch.mean(
                    (a @ torch.transpose(a, 0, 1) - true_outer) ** 2)
            else:
                mse = Mmse = 0.
            print("t = %d; conv = %g, mse = %g, Mmse = %g" % (t, conv, mse, Mmse))
        if conv < tol:
            break
    return a

Now, here is a short demo of how the algorithm works:

In [2]:
import torch
# We keep CPU by default, if you want GPU uncomment the next line
# torch.set_default_tensor_type('torch.cuda.FloatTensor')#GPU by default
from math import sqrt
import time

# Problem size and noise variances of the matrix / tensor channels
N = 100
DELTA2 = 0.01
DELTA3 = 1

# Planted signal and its noiseless rank-one matrix / order-3 tensor
X0 = torch.randn(N, 1)
Y = X0 @ X0.t()
T = Y.unsqueeze(-1) @ X0.t()  # T[i, j, k] = X0[i] * X0[j] * X0[k]

# Rescale the signal parts and add Gaussian noise to both observations
Y = Y / sqrt(N) + sqrt(DELTA2) * torch.randn(N, N)
T = (sqrt(2) / N) * T + sqrt(DELTA3) * torch.randn(N, N, N)

# Run AMP with rho = 1, mu = 0, sig = 1 (standard Gaussian prior) and time it
t = time.time()
Xhat = amp(
    Y, T, DELTA2, DELTA3,
    prior=prior_gb,
    prior_prmts=(1, 0, 1),
    true_coef=X0,
    max_iter=250,
)
elapsed = time.time() - t

time ; convergence mse, Matrix-mmse 
t = 0; conv = 0.226766, mse = 1.104, Mmse = 0.91156
t = 1; conv = 0.0729921, mse = 0.848262, Mmse = 0.85193
t = 2; conv = 0.228364, mse = 0.266017, Mmse = 0.494889
t = 3; conv = 0.311169, mse = 0.023434, Mmse = 0.0577346
t = 4; conv = 0.00832498, mse = 0.0104917, Mmse = 0.0193538
t = 5; conv = 0.000734208, mse = 0.00923649, Mmse = 0.0170892
t = 6; conv = 0.000189386, mse = 0.00919293, Mmse = 0.017017
t = 7; conv = 4.99531e-05, mse = 0.00933875, Mmse = 0.0172903
t = 8; conv = 1.33684e-05, mse = 0.00946125, Mmse = 0.0175169
t = 9; conv = 3.62473e-06, mse = 0.00953846, Mmse = 0.0176592
t = 10; conv = 9.92879e-07, mse = 0.00958249, Mmse = 0.0177403
t = 11; conv = 2.74189e-07, mse = 0.00960644, Mmse = 0.0177843
t = 12; conv = 7.62091e-08, mse = 0.00961909, Mmse = 0.0178076
t = 13; conv = 2.1289e-08, mse = 0.00962566, Mmse = 0.0178196
t = 14; conv = 5.97307e-09, mse = 0.00962901, Mmse = 0.0178258
t = 15; conv = 1.6811e-09, mse = 0.0096307, Mmse = 0.017828