## Tensor factorization with AMP for an order p = 3 tensor

In [1]:
import torch
from math import sqrt


# Priors
def prior_gb(A, B, prmts):
    """Compute f_a and f_c for a Gauss-Bernoulli prior.

    The prior mixes a point mass at zero (weight ``1 - rho``) with a Gaussian
    of mean ``mu`` and variance ``sig`` (weight ``rho``). Combined with the
    field ``exp(B * x - A * x**2 / 2)`` it gives a posterior whose
    element-wise mean and variance are computed below.

    Args:
        A: scalar precision of the field (Python number or 0-d tensor).
        B: tensor of linear fields, one entry per variable.
        prmts: tuple ``(rho, mu, sig)``: weight, mean and variance of the
            Gaussian component.

    Returns:
        ``(a, c)``: the posterior mean (same shape as ``B``) and the mean,
        over all entries, of the posterior variance.
    """
    rho, mu, sig = prmts

    B = torch.as_tensor(B)
    if not B.is_floating_point():
        B = B.to(torch.get_default_dtype())
    # Accept plain Python numbers for A and rho: torch.sqrt/log need tensors.
    A = torch.as_tensor(A, dtype=B.dtype, device=B.device)
    rho = torch.as_tensor(rho, dtype=B.dtype, device=B.device)

    # Mean and variance of the Gaussian component after seeing the field.
    m = (B * sig + mu) / (1. + A * sig)
    v = sig / (1. + A * sig)

    # Posterior weight of the Gaussian component,
    #   p_s = rho / (rho + (1 - rho) * sqrt(1 + A sig) * exp(-m^2/2v + mu^2/2sig)),
    # rewritten as a sigmoid of the log-odds. This never overflows exp() and
    # gives exactly 1 for rho == 1 (instead of 0 * inf = NaN) and 0 for rho == 0.
    log_odds = (torch.log(rho) - torch.log1p(-rho)
                - .5 * torch.log1p(A * sig)
                + .5 * m ** 2 / v - .5 * mu ** 2 / sig)
    p_s = torch.sigmoid(log_odds)

    a = p_s * m
    c = p_s * v + p_s * (1. - p_s) * m ** 2
    return a, torch.mean(c)


# Solver
def amp(T, Delta3,
        prior=prior_gb, prior_prmts=None, true_coef=None,
        max_iter=250, tol=1e-13, verbose=1, damping=0.5):
    """Iterate damped AMP equations for an order-3 (p = 3) tensor.

    Each iteration does four things:
      1. It builds the field ``B3_i = K * sum_jk a_j T_ijk a_k`` with
         ``K = sqrt(2) / (Delta3 * N)``.
      2. It subtracts an Onsager reaction term built from the previous
         estimate.
      3. It forms the scalar precision ``A3 = mean(a**2)**2 / Delta3``.
      4. It calls ``prior`` to get a new estimate, which is damped against
         the previous one.

    The starting point is drawn with ``torch.randn``, so results depend on
    the global torch RNG state.

    Args:
        T: observed tensor of shape (N, N, N).
        Delta3: noise variance used to scale the fields.
        prior: callable ``prior(A, B, prmts) -> (a, c)`` returning the new
            estimate (same shape as ``B``) and a scalar ``c`` that enters the
            Onsager term.
        prior_prmts: parameters forwarded unchanged to ``prior``.
        true_coef: optional (N, 1) ground truth. It is used only for the
            printed ``mse`` / ``Mmse`` diagnostics; both are 0 when omitted.
        max_iter: maximum number of iterations.
        tol: stop once the mean squared change of the estimate drops below it.
        verbose: print one diagnostic line per iteration when > 0.
        damping: weight of the fresh prior output in
            ``a = damping * a_new + (1 - damping) * a_old``. The default 0.5
            is the fixed damping used previously.

    Returns:
        The final estimate ``a``, shape (N, 1).
    """
    N, _, _ = T.shape

    # Random start; a_old = 0 makes the first Onsager correction vanish.
    a = torch.randn(N, 1)
    a_old = torch.zeros(N, 1)
    c = 1.
    K = sqrt(2) / (Delta3 * N)

    # Loop-invariant ground-truth outer product for the matrix MSE.
    outer_true = (true_coef @ torch.transpose(true_coef, 0, 1)
                  if true_coef is not None else None)

    if verbose > 0:
        print("time ; convergence mse, Matrix-mmse ")
    for t in range(max_iter):
        # Messages/estimates on x from the likelihood:
        # (1, N) @ (N, N, N) @ (N, 1) -> (N, 1, 1), i.e. sum_jk a_j T_ijk a_k.
        B3 = (K * torch.transpose(a, 0, 1) @ T @ a).view(N, -1)
        # Onsager reaction term, using the previous estimate and prior variance.
        B3 = B3 - (2 / Delta3) * c * torch.mean(a * a_old) * a_old
        A3 = (1. / Delta3) * torch.mean(a * a) ** 2
        a_old = torch.clone(a)

        a, c = prior(A3, B3, prior_prmts)
        a = damping * a + (1. - damping) * a_old

        # Metrics
        conv = torch.mean((a - a_old) ** 2)
        if true_coef is not None:
            mse = torch.mean((a - true_coef) ** 2)
            Mmse = torch.mean((a @ torch.transpose(a, 0, 1) - outer_true) ** 2)
        else:
            mse = Mmse = 0.
        if verbose > 0:
            print("t = %d; conv = %g, mse = %g, Mmse = %g" % (t, conv, mse, Mmse))
        if conv < tol:
            break
    return a

In [4]:
import torch
# We keep CPU by default, if you want GPU uncomment the next line
# torch.set_default_tensor_type('torch.cuda.FloatTensor')#GPU by default

from math import sqrt
import time

N = 300
DELTA3 = 0.001

# Planted signal X0 and the rank-one order-3 tensor T_ijk = X0_i X0_j X0_k:
# (N, N, 1) @ (1, N) -> (N, N, N).
X0 = torch.randn(N, 1)
Y = X0 @ torch.transpose(X0, 0, 1)
T = Y.view(N, N, 1) @ torch.transpose(X0, 0, 1)

# Symmetric Gaussian noise. R0..R5 are the 6 axis permutations of one i.i.d.
# standard-normal tensor (the transposes are views, not copies). Summing them
# and dividing by sqrt(6) gives a fully symmetric tensor whose entries with
# distinct indices have unit variance.
R0 = torch.randn(N, N, N)
R1 = torch.transpose(R0, 0, 1)
R2 = torch.transpose(R0, 0, 2)
R3 = torch.transpose(R0, 1, 2)
R4 = torch.transpose(R1, 1, 2)
R5 = torch.transpose(R2, 1, 2)
RUMORE = (R0 + R1 + R2 + R3 + R4 + R5) / sqrt(6)

# Observation: scaled signal plus noise of variance DELTA3.
T = (sqrt(2) / N) * T + RUMORE * sqrt(DELTA3)

# perf_counter is monotonic and meant for measuring durations (time.time is not).
t = time.perf_counter()
Xhat = amp(T, DELTA3,
           prior=prior_gb, prior_prmts=(1, 0, 1),
           true_coef=X0, max_iter=250)
elapsed = time.perf_counter() - t

time ; convergence mse, Matrix-mmse 
t = 0; conv = 0.244524, mse = 1.34885, Mmse = 1.18319
t = 1; conv = 0.0626158, mse = 1.18884, Mmse = 1.12823
t = 2; conv = 0.0632123, mse = 1.29586, Mmse = 1.17292
t = 3; conv = 0.0115716, mse = 1.24559, Mmse = 1.1487
t = 4; conv = 0.0121815, mse = 1.33051, Mmse = 1.18281
t = 5; conv = 0.00332024, mse = 1.28977, Mmse = 1.16627
t = 6; conv = 0.00325358, mse = 1.32554, Mmse = 1.18391
t = 7; conv = 0.0014046, mse = 1.3099, Mmse = 1.17731
t = 8; conv = 0.00129132, mse = 1.32577, Mmse = 1.18566
t = 9; conv = 0.00075553, mse = 1.32122, Mmse = 1.18365
t = 10; conv = 0.000646842, mse = 1.32878, Mmse = 1.18744
t = 11; conv = 0.000467813, mse = 1.32866, Mmse = 1.18714
t = 12; conv = 0.000400123, mse = 1.33298, Mmse = 1.18895
t = 13; conv = 0.000325376, mse = 1.33443, Mmse = 1.1892
t = 14; conv = 0.000283105, mse = 1.3374, Mmse = 1.19018
t = 15; conv = 0.00024339, mse = 1.33926, Mmse = 1.19056
t = 16; conv = 0.000215587, mse = 1.34158, Mmse = 1.19117
t = 17; c

t = 144; conv = 1.41532e-05, mse = 1.39538, Mmse = 1.20398
t = 145; conv = 1.38815e-05, mse = 1.39528, Mmse = 1.20402
t = 146; conv = 1.35943e-05, mse = 1.39518, Mmse = 1.20407
t = 147; conv = 1.3293e-05, mse = 1.39507, Mmse = 1.20411
t = 148; conv = 1.2979e-05, mse = 1.39496, Mmse = 1.20415
t = 149; conv = 1.26538e-05, mse = 1.39485, Mmse = 1.20418
t = 150; conv = 1.23191e-05, mse = 1.39473, Mmse = 1.20422
t = 151; conv = 1.19765e-05, mse = 1.39461, Mmse = 1.20426
t = 152; conv = 1.16277e-05, mse = 1.39448, Mmse = 1.2043
t = 153; conv = 1.12745e-05, mse = 1.39435, Mmse = 1.20433
t = 154; conv = 1.09183e-05, mse = 1.39422, Mmse = 1.20437
t = 155; conv = 1.05609e-05, mse = 1.39409, Mmse = 1.2044
t = 156; conv = 1.02038e-05, mse = 1.39396, Mmse = 1.20444
t = 157; conv = 9.84843e-06, mse = 1.39382, Mmse = 1.20447
t = 158; conv = 9.49618e-06, mse = 1.39369, Mmse = 1.2045
t = 159; conv = 9.14834e-06, mse = 1.39355, Mmse = 1.20453
t = 160; conv = 8.80598e-06, mse = 1.39341, Mmse = 1.20456
t 