## MAP-AMP FOR SPIKED-TENSOR-MODELS

This is a PyTorch implementation of the AMP algorithm for the mixed spiked tensor model (2+3), as described in the MAP-AMP paper (to appear). It runs on CPU here, but porting it to GPU is trivial.

Here is the algorithm, written with a spherical prior:

In [23]:
import torch
from math import sqrt

# Priors
def prior_l2(B):
    """Rescale ``B`` so that its entries have unit mean square (l2 prior).

    Args:
        B: tensor of any shape. An all-zero ``B`` gives a zero root mean
            square, so the division below yields non-finite values.

    Returns:
        A pair ``(a, c)`` where ``a = B / rms`` has the same shape as ``B``
        and ``c = 1 / rms`` is a 0-dim tensor, with
        ``rms = sqrt(mean(B * B))``.
    """
    # Compute the root mean square once; both outputs are derived from it.
    rms = torch.sqrt(torch.mean(B * B))
    return B / rms, 1 / rms


# AMP Solver
def amp(Y, T, Delta2, Delta3,
        prior=prior_l2, true_coef=None,
        max_iter=250, tol=1e-13, verbose=1):
    """Iterate the AMP equations for the mixed (2+3) spiked tensor model.

    Args:
        Y: 2-D tensor; its first dimension defines N and it is multiplied
            with an (N, 1) iterate, so it is expected to be (N, N).
        T: order-3 observation; ``a^T @ T @ a`` must reshape to (N, 1),
            so it is expected to be (N, N, N).
        Delta2: scalar dividing the matrix-channel terms.
        Delta3: scalar dividing the tensor-channel terms.
        prior: callable mapping the combined field ``B2 + B3`` to a pair
            ``(a, c)``: the new estimate and the scalar used in the
            correction terms of the next iteration.
        true_coef: optional (N, 1) ground-truth tensor. When given (and
            ``verbose > 0``), the vector and matrix mean squared errors
            are printed at every iteration; otherwise they print as 0.
        max_iter: maximum number of iterations.
        tol: stop once the mean squared change between successive
            (damped) iterates falls below this value.
        verbose: print a header and per-iteration metrics when > 0.

    Returns:
        The final (N, 1) estimate ``a``.
    """
    N, _ = Y.shape

    # Allocate the iterates with Y's dtype/device so they always match the data.
    a = torch.randn(N, 1, dtype=Y.dtype, device=Y.device)
    a_old = torch.zeros(N, 1, dtype=Y.dtype, device=Y.device)
    c = 1.
    SN = sqrt(N)
    K = sqrt(2) / (Delta3 * N)

    report = verbose > 0
    have_truth = true_coef is not None
    # The ground-truth outer product never changes: build it once, and only
    # when it will actually be printed.
    true_outer = None
    if report and have_truth:
        true_outer = true_coef @ torch.transpose(true_coef, 0, 1)

    if report:
        print("time ; convergence mse, Matrix-mmse ")
    for t in range(max_iter):
        # Field from the matrix channel, minus a correction built from the
        # previous iterate.
        B2 = (1. / (SN * Delta2)) * Y @ a - (c / Delta2) * a_old
        # Field from the tensor channel (T contracted twice with a), minus
        # its own correction built from the previous iterate.
        B3 = (K * torch.transpose(a, 0, 1) @ T @ a).view(N, -1)
        B3 = B3 - (2 / Delta3) * c * torch.mean(a * a_old) * a_old

        a_old = torch.clone(a)
        # Use the caller-supplied prior (previously prior_l2 was hard-coded
        # here and the `prior` argument was ignored).
        a, c = prior(B2 + B3)
        # Damping: average the new estimate with the previous iterate.
        a = 0.5 * a + 0.5 * a_old

        # Convergence is measured on the damped iterate.
        conv = torch.mean((a - a_old) ** 2)
        if report:
            if have_truth:
                mse = torch.mean((a - true_coef) ** 2)
                Mmse = torch.mean((a @ torch.transpose(a, 0, 1) - true_outer) ** 2)
            else:
                mse = 0.
                Mmse = 0.
            print("t = %d; conv = %g, mse = %g, Mmse = %g" % (t, conv, mse, Mmse))
        if conv < tol:
            break
    return a

Now, here is a short demo of how the algorithm works:

In [26]:
import torch
# We keep CPU by default, if you want GPU uncomment the next line
# torch.set_default_tensor_type('torch.cuda.FloatTensor')#GPU by default
from math import sqrt
import time

# Problem size and the noise variances of the matrix and tensor channels.
N = 300
DELTA2 = 0.01
DELTA3 = 1

# Planted signal and its rank-one order-2 / order-3 spikes.
X0 = torch.randn(N, 1)
X0_row = X0.T
spike2 = X0 @ X0_row
spike3 = spike2.view(N, N, 1) @ X0_row

# Matrix observation: scaled spike plus symmetric Gaussian noise of variance DELTA2.
M_N = torch.randn(N, N)
Matrix_noise = (M_N + M_N.T) / sqrt(2)
Y = spike2 / sqrt(N) + Matrix_noise * sqrt(DELTA2)

# Tensor observation: the noise is symmetrised by summing T_N over all six
# index permutations (the identity plus the five listed below).
T_N = torch.randn(N, N, N)
Tensor_noise = T_N
for axes in ((1, 0, 2), (2, 1, 0), (0, 2, 1), (1, 2, 0), (2, 0, 1)):
    Tensor_noise = Tensor_noise + T_N.permute(*axes)
Tensor_noise = Tensor_noise / sqrt(6)
T = (sqrt(2) / N) * spike3 + Tensor_noise * sqrt(DELTA3)

# Run AMP against the known signal and record the wall-clock time.
t = time.time()
Xhat = amp(Y, T, DELTA2, DELTA3, prior=prior_l2, true_coef=X0, max_iter=250)
elapsed = time.time() - t

time ; convergence mse, Matrix-mmse 
t = 0; conv = 0.504207, mse = 1.14215, Mmse = 1.20047
t = 1; conv = 0.292416, mse = 0.371058, Mmse = 0.638189
t = 2; conv = 0.0863816, mse = 0.105761, Mmse = 0.242267
t = 3; conv = 0.0215336, mse = 0.0361818, Mmse = 0.089527
t = 4; conv = 0.00588286, mse = 0.0166941, Mmse = 0.039646
t = 5; conv = 0.00164738, mse = 0.0111061, Mmse = 0.0244347
t = 6; conv = 0.000468327, mse = 0.00947392, Mmse = 0.019875
t = 7; conv = 0.000135182, mse = 0.00898894, Mmse = 0.0184921
t = 8; conv = 3.96345e-05, mse = 0.00884319, Mmse = 0.0180606
t = 9; conv = 1.1802e-05, mse = 0.00879957, Mmse = 0.0179212
t = 10; conv = 3.56629e-06, mse = 0.00878707, Mmse = 0.0178748
t = 11; conv = 1.09252e-06, mse = 0.008784, Mmse = 0.0178591
t = 12; conv = 3.38828e-07, mse = 0.00878365, Mmse = 0.0178539
t = 13; conv = 1.06222e-07, mse = 0.00878397, Mmse = 0.0178524
t = 14; conv = 3.36134e-08, mse = 0.00878436, Mmse = 0.0178521
t = 15; conv = 1.07271e-08, mse = 0.00878468, Mmse = 0.01785

In [28]:
# Display the estimate returned by amp().
Xhat

tensor([[-1.5269e+00],
        [ 7.7240e-01],
        [ 1.1646e-02],
        [ 2.8466e-01],
        [ 1.0193e+00],
        [ 1.1542e+00],
        [-1.8794e-01],
        [ 1.7347e-01],
        [-1.7144e+00],
        [ 2.7817e+00],
        [-1.5013e+00],
        [ 2.9375e+00],
        [ 8.6295e-02],
        [-3.6084e-01],
        [-5.3652e-01],
        [-8.8199e-02],
        [-1.2263e+00],
        [-2.0780e+00],
        [-5.2413e-01],
        [-1.3485e+00],
        [-6.3506e-01],
        [-6.5632e-01],
        [ 1.3542e+00],
        [ 9.4088e-01],
        [ 4.5054e-01],
        [ 1.2298e+00],
        [-5.2991e-01],
        [ 1.3896e+00],
        [-8.0773e-01],
        [ 1.1107e+00],
        [-1.7560e-01],
        [ 3.9748e-01],
        [-4.5544e-01],
        [-1.0322e+00],
        [-1.3727e+00],
        [ 9.6144e-01],
        [ 2.1044e+00],
        [-1.7616e+00],
        [-3.8328e-01],
        [-6.0522e-01],
        [ 1.8037e+00],
        [ 9.2775e-01],
        [ 3.4811e-01],
        [-1

In [22]:
# Sum of the squared entries of the estimate, divided by N = 300.
# NOTE(review): the recorded output below (300.0001) looks like it came from a
# version of this cell without the division by 300 - re-run to confirm.
sum(Xhat*Xhat)/300

tensor([300.0001])