## Tensor factorization with MAP AMP with p = 3 

This notebook implements the MAP (maximum a posteriori) version of AMP (approximate message passing) for a tensor factorization problem of order p = 3.

In [37]:
import torch
from math import sqrt


def prior_l2(B):
    """Rescale ``B`` to unit root-mean-square (l2-norm prior).

    Parameters
    ----------
    B : torch.Tensor
        Floating-point tensor to normalize. The mean is taken over all of
        its entries, so any shape is accepted.

    Returns
    -------
    tuple
        ``(B / rms, 1 / rms)`` where ``rms = sqrt(mean(B * B))``. The first
        element has the same shape as ``B``; the second is a 0-dim tensor.

    Notes
    -----
    If every entry of ``B`` is zero, ``rms`` is zero and the division
    yields NaN / inf; no guard is applied.
    """
    # Compute the RMS once and reuse it for both returned quantities.
    rms = torch.sqrt(torch.mean(B * B))
    return B / rms, 1 / rms

# Solver
def amp(T, DELTA3,
        prior=prior_l2, true_coef=None,
        max_iter=250, tol=1e-13, verbose=1):
    """Iterate the AMP equations for an order-3 tensor.

    Parameters
    ----------
    T : torch.Tensor
        Observed tensor of shape ``(N, N, N)``.
    DELTA3 : float
        Noise parameter; it enters the scaling ``sqrt(2) / (DELTA3 * N)``
        of the tensor contraction and the ``2 / DELTA3`` factor of the
        memory term.
    prior : callable
        Maps the field ``B3`` of shape ``(N, 1)`` to a pair ``(a, c)``:
        the new estimate and the scalar used in the memory term of the
        next iteration.
    true_coef : torch.Tensor, optional
        Ground-truth vector of shape ``(N, 1)``. When given, the error
        metrics printed at each iteration are computed against it;
        otherwise they are reported as 0.
    max_iter : int
        Maximum number of iterations.
    tol : float
        Stop once the mean squared change of the estimate between two
        consecutive iterations drops below this value.
    verbose : int
        Print a header and one line of metrics per iteration when > 0.

    Returns
    -------
    torch.Tensor
        The last estimate ``a`` (shape ``(N, 1)`` with the default prior).
    """

    N, _, _ = T.shape

    # Initialize variables on the same dtype/device as T so that the
    # matrix products below never mix incompatible tensors.
    a = torch.randn(N, 1, dtype=T.dtype, device=T.device)
    a_old = torch.zeros(N, 1, dtype=T.dtype, device=T.device)
    c = 1.
    K = sqrt(2) / (DELTA3 * N)
    if verbose > 0:
        print("time ; convergence mse, Matrix-mmse ")
    for t in range(max_iter):
        # Field from the likelihood: B3[i] = K * sum_{j,k} a[j] T[i,j,k] a[k]
        B3 = (K * torch.transpose(a, 0, 1)@T@a).view(N, -1)
        # Onsager-style memory term built from the previous estimate
        B3 = B3 - 2 / (DELTA3) * c * (torch.mean(a * a_old)) * a_old
        a_old = torch.clone(a)
        a, c = prior(B3)
        # Compute metrics
        conv = torch.mean((a - a_old) ** 2)
        if true_coef is not None:
            mse = torch.mean((a - true_coef) ** 2)
            # Error on the rank-one matrices a a^T vs. true_coef true_coef^T
            Mmse = torch.mean(
                (a@torch.transpose(a, 0, 1)
                 - true_coef@torch.transpose(true_coef, 0, 1)) ** 2)
        else:
            mse = 0.
            Mmse = 0.
        if verbose > 0:
            print("t = %d; conv = %g, mse = %g, Mmse = %g" % (t, conv, mse, Mmse))
        if conv < tol:
            break
    return a

In [40]:
import torch
# We keep CPU by default, if you want GPU uncomment the next line
# torch.set_default_tensor_type('torch.cuda.FloatTensor')#GPU by default

from math import sqrt
import time

# Problem size: X0 has N entries and the tensor built below is N x N x N
N = 300
# Noise parameter: the noise tensor below is multiplied by sqrt(DELTA3)
DELTA3 = 0.001

# Random signal X0 and its rank-one order-3 tensor,
# T[i, j, k] = X0[i] * X0[j] * X0[k]
X0 = torch.randn(N,1)
Y = X0@torch.transpose(X0, 0, 1)
T = Y.view(N, N, 1)@torch.transpose(X0, 0, 1)

# Symmetric noise: R0 summed over all six permutations of its three
# indices, divided by sqrt(6)
R0 = torch.randn(N,N,N)
R1 = torch.transpose(R0,0,1)
R2 = torch.transpose(R0,0,2)
R3 = torch.transpose(R0,1,2)
R4 = torch.transpose(R1,1,2)
R5 = torch.transpose(R2,1,2)
RUMORE = (R0 + R1 + R2 + R3 + R4 +R5)/sqrt(6)

# Observed tensor: signal scaled by sqrt(2) / N plus the scaled noise
T = (sqrt(2) / N) * T + RUMORE * sqrt(DELTA3)

# Run the solver and measure its wall-clock time (seconds, in `elapsed`)
t = time.time()
Xhat = amp(T, DELTA3,
       prior=prior_l2, 
           true_coef=X0, max_iter=250)
elapsed = time.time() - t

time ; convergence mse, Matrix-mmse 
t = 0; conv = 1.81819, mse = 1.6485, Mmse = 2.02794
t = 1; conv = 1.513, mse = 0.361101, Mmse = 0.677777
t = 2; conv = 0.343269, mse = 0.00251946, Mmse = 0.00660462
t = 3; conv = 0.000917501, mse = 0.00156153, Mmse = 0.0046394
t = 4; conv = 2.35715e-06, mse = 0.00155284, Mmse = 0.00462158
t = 5; conv = 8.23503e-09, mse = 0.00155289, Mmse = 0.0046217
t = 6; conv = 3.38215e-11, mse = 0.00155287, Mmse = 0.00462165
t = 7; conv = 1.5962e-13, mse = 0.00155288, Mmse = 0.00462166
t = 8; conv = 7.1977e-15, mse = 0.00155288, Mmse = 0.00462165
