# A faster way of doing TFA

The basic TFA (topographic factor analysis) generative model is:

$$
Y \mid W, F \sim \mathcal{MN}(WF, \sigma^2 I, I), 
$$

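where $F$ is the $k \times v$ matrix of factors (one spatial factor per row) and $W$ is the $t \times k$ matrix of weights. If we put a trivial i.i.d. normal prior on $W$ with the same variance as the noise, $W_{ij} \sim \mathcal{N}(0, \sigma^2)$, we can integrate $W$ out and get the following marginal: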

$$
Y \mid F \sim \mathcal{MN}(0, \sigma^2 I , F^T F + I). 
$$
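The column covariance $F^T F + I$ is $v \times v$, which is expensive to invert. The Woodbury identity and the matrix determinant lemma let us replace its inverse and determinant with those of the much smaller $k \times k$ matrix $FF^T + I$, so the marginal log-likelihood is cheap to compute. We can then autodiff through it and optimize it with L-BFGS-B.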
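A minimal NumPy sketch of that log-likelihood, assuming the marginal above (rows of $Y$ i.i.d. $\mathcal{N}(0, \sigma^2(F^T F + I))$). It is not FastTFA's own code, and `marginal_loglik` is a hypothetical name. It relies on the two identities

$$
(I_v + F^T F)^{-1} = I_v - F^T (I_k + FF^T)^{-1} F, \qquad \det(I_v + F^T F) = \det(I_k + FF^T).
$$

```python
import numpy as np


def marginal_loglik(Y, F, sigma2):
    """Hypothetical sketch: log p(Y | F) for Y ~ MN(0, sigma2 * I_t, F^T F + I_v).

    Y: (t, v) data, F: (k, v) factors, sigma2: noise variance (also the prior
    variance of W). Only k x k matrices are factored.
    """
    t, v = Y.shape
    k = F.shape[0]
    # Small k x k matrix M = I_k + F F^T stands in for the v x v matrix I_v + F^T F
    M = np.eye(k) + F @ F.T
    L = np.linalg.cholesky(M)
    # Determinant lemma: log det(I_v + F^T F) = log det(M) = 2 * sum(log diag(L))
    logdet = v * np.log(sigma2) + 2.0 * np.sum(np.log(np.diag(L)))
    # Woodbury: tr(Y (I_v + F^T F)^{-1} Y^T) = ||Y||^2 - ||L^{-1} F Y^T||^2
    A = np.linalg.solve(L, F @ Y.T)  # shape (k, t)
    quad = (np.sum(Y ** 2) - np.sum(A ** 2)) / sigma2
    return -0.5 * (t * v * np.log(2 * np.pi) + t * logdet + quad)
```

The cost is dominated by forming $FF^T$ and $FY^T$, roughly $O(k^2 v + kvt)$ plus $O(k^3)$ for the Cholesky factorization, instead of the $O(v^3)$ needed to factor the $v \times v$ covariance directly.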

In contrast to vanilla TFA, this way of doing things: 
- Gets rid of the coordinate ascent step completely: we find $F$ by marginalizing over $W$, then compute $W$ in closed form once at the end.
- Gets rid of finite differences, which are slow.
- Gets rid of a massive Jacobian that needs to be subsampled (there's a tradeoff here: we're no longer exploiting the least-squares structure).
- Can run on a GPU if we'd like.
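A minimal sketch of the "compute $W$ in closed form once" step, assuming the prior $W_{ij} \sim \mathcal{N}(0, \sigma^2)$ with the same variance as the noise. Under that assumption the posterior mean of each row of $W$ is a ridge regression with penalty 1 ($\sigma^2$ cancels), $\hat{W} = Y F^T (FF^T + I)^{-1}$, which again only needs a $k \times k$ solve. Whether FastTFA uses the posterior mean or some other estimator isn't stated here, and `posterior_mean_W` is a hypothetical name.

```python
import numpy as np


def posterior_mean_W(Y, F):
    """Hypothetical sketch: posterior mean of W given Y (t, v) and F (k, v).

    Assumes W_ij ~ N(0, sigma2) and noise variance sigma2, so sigma2 cancels.
    Returns a (t, k) array equal to Y F^T (F F^T + I_k)^{-1}.
    """
    k = F.shape[0]
    # Same small k x k matrix as in the marginal likelihood
    M = np.eye(k) + F @ F.T
    # Solve M X = F Y^T rather than forming M^{-1}; M is symmetric so X^T = Y F^T M^{-1}
    return np.linalg.solve(M, F @ Y.T).T
```

By the push-through identity $F^T (I_k + FF^T)^{-1} = (I_v + F^T F)^{-1} F^T$, this equals the "large" form $Y (I_v + F^T F)^{-1} F^T$ without ever touching a $v \times v$ matrix.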

And it's still theoretically compatible with HTFA (just not implemented yet). Let's see how it goes! 

In [1]:
import numpy as np
from brainiak.factoranalysis.tfa import TFA
from brainiak.factoranalysis.fast_tfa import FastTFA, get_factormat

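def benchmark_tfa(v, t, k, noise, skip_slow=False):
    """Compare Fast TFA against subsampled and (unless skip_slow) full regular TFA
    on synthetic data with v voxels, t TRs and k factors; prints timings and MSE."""
    # Synthetic voxel coordinates (v x 3), factor centers (k x 3) and widths (k x 1)
    R = np.random.randint(2, high=102, size=(v, 3)).astype("float64")
    centers = np.random.randint(2, high=102, size=(k, 3)).astype("float64")
    widths = np.abs(np.random.normal(loc=0, scale=np.std(R)**2, size=(k, 1)))
    # Ground truth: factors F (k x v), weights W (t x k), noisy data X (t x v)
    F = get_factormat(R, centers, widths).numpy()
    W = np.random.normal(size=(t, k))
    X = W @ F + np.random.normal(size=(t, v))*noise

    fast_tfa = FastTFA(k=k)
    # Regular TFA subsampling t//10 TRs and v//20 voxels, vs. using all of them
    tfa_subsamp = TFA(K=k, max_num_tr=t//10, max_num_voxel=v//20)
    tfa_full = TFA(K=k, max_num_tr=t, max_num_voxel=v)

    # FastTFA takes data as TRs x voxels and reconstructs it as W_ @ F_
    print("\nFast TFA:")
    %time fast_tfa.fit(X, R)
    fast_tfa_mse = np.mean((X - fast_tfa.W_ @ fast_tfa.F_)**2)
    print(f"MSE={fast_tfa_mse}")

    # Regular TFA takes voxels x TRs (hence X.T) and reconstructs it as F_ @ W_
    print("\nSubsampled regular TFA:")
    %time tfa_subsamp.fit(X.T, R)
    tfa_subsamp_mse = np.mean((X.T - tfa_subsamp.F_ @ tfa_subsamp.W_)**2)
    print(f"MSE={tfa_subsamp_mse}")
    if skip_slow:
        return
    print("\nFull TFA")
    %time tfa_full.fit(X.T, R)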
    tfa_full_mse = np.mean((X.T - tfa_full.F_ @ tfa_full.W_)**2)
    print(f"MSE={tfa_full_mse}")
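
`get_factormat` isn't defined in this notebook. In TFA each factor is a radial basis function with a center $c_j$ and a width $\lambda_j$; a hedged NumPy sketch of such a factor matrix, assuming the Gaussian form $F_{jn} = \exp(-\lVert r_n - c_j \rVert^2 / \lambda_j)$, looks like this. The exact form is our assumption and has not been checked against `get_factormat`; `rbf_factor_matrix` is a hypothetical name.

```python
import numpy as np


def rbf_factor_matrix(R, centers, widths):
    """Hypothetical stand-in for get_factormat: Gaussian RBF factors (assumed form).

    R: (v, 3) voxel coordinates, centers: (k, 3), widths: (k, 1).
    Returns F with shape (k, v), F[j, n] = exp(-||R[n] - centers[j]||^2 / widths[j]).
    """
    # Squared distance between every center and every voxel, shape (k, v)
    sq_dist = np.sum((centers[:, None, :] - R[None, :, :]) ** 2, axis=-1)
    # widths has shape (k, 1), so it broadcasts across the voxel axis
    return np.exp(-sq_dist / widths)
```

Under this assumption the widths are in squared-distance units, which would fit the benchmark drawing them on the scale of `np.std(R)**2`.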

In [2]:
benchmark_tfa(v=100, t=50, k=3, noise=0.1) # tiny problem with basically no noise, sanity check


Fast TFA:
CPU times: user 609 ms, sys: 46.9 ms, total: 656 ms
Wall time: 529 ms
MSE=0.01873000207327146

Subsampled regular TFA:
CPU times: user 5.98 s, sys: 1.81 s, total: 7.8 s
Wall time: 1.68 s
MSE=0.06499834081823767

Full TFA
CPU times: user 38.1 s, sys: 27.3 s, total: 1min 5s
Wall time: 9.08 s
MSE=0.010293511521374237


In [3]:
benchmark_tfa(v=100, t=50, k=3, noise=1) # same tiny problem with a bit more noise


Fast TFA:
CPU times: user 312 ms, sys: 344 ms, total: 656 ms
Wall time: 271 ms
MSE=0.9772830747585769

Subsampled regular TFA:
CPU times: user 2.91 s, sys: 1.59 s, total: 4.5 s
Wall time: 945 ms
MSE=1.001510387817818

Full TFA
CPU times: user 2.27 s, sys: 1.78 s, total: 4.05 s
Wall time: 597 ms
MSE=1.002109741576849


In [4]:
benchmark_tfa(v=500, t=200, k=15, noise=1, skip_slow=True) # still beats subsampling


Fast TFA:
CPU times: user 7.12 s, sys: 1.16 s, total: 8.28 s
Wall time: 1.82 s
MSE=0.977909290668985

Subsampled regular TFA:
CPU times: user 20.9 s, sys: 13.3 s, total: 34.2 s
Wall time: 4.65 s
MSE=1.0278639367093334


In [5]:
benchmark_tfa(v=1000, t=200, k=15, noise=1, skip_slow=True) # larger problem, subsampling starts really helping


Fast TFA:
CPU times: user 32 s, sys: 1.7 s, total: 33.7 s
Wall time: 5.35 s
MSE=0.9931124439167173

Subsampled regular TFA:
CPU times: user 3.38 s, sys: 1.83 s, total: 5.2 s
Wall time: 751 ms
MSE=1.030644716273053


In [6]:
benchmark_tfa(v=2000, t=400, k=20, noise=1, skip_slow=True) # ok, now it gets really bad for our "fast" TFA


Fast TFA:
CPU times: user 8min 4s, sys: 8.11 s, total: 8min 13s
Wall time: 1min 16s
MSE=0.9904045683317866

Subsampled regular TFA:
CPU times: user 4.59 s, sys: 1.47 s, total: 6.06 s
Wall time: 855 ms
MSE=1.0340202418558002


# Conclusion 
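Fast TFA reconstructs the data more accurately than subsampled regular TFA in every run above (by not subsampling and by integrating over $W$), and at noise=1 it also edges out full TFA, although full TFA wins the nearly noiseless sanity check (MSE 0.0103 vs. 0.0187). MSE is measured against the noisy data, so at noise=1 every method sits close to the noise variance of 1. Fast TFA is faster than full TFA and even beats subsampled TFA at v=500, but it scales much worse: at v=1000 it is about 7× slower than subsampled TFA in wall time (5.35 s vs. 751 ms), and at v=2000, t=400, k=20 it takes 1 min 16 s versus 855 ms. Next step: probably SGD of some sort.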