# A faster way of doing TFA

The basic TFA generative model is: 

$$
Y \mid W, F \sim \mathcal{MN}(WF, \sigma^2 I, I), 
$$

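where $Y$ is the $T \times V$ data matrix (TRs by voxels), $W$ is the $T \times K$ matrix of weights, and $F$ is the $K \times V$ matrix containing the factors. We can put a trivial i.i.d. normal prior on $W$ with the same variance as the noise, $W \sim \mathcal{MN}(0, \sigma^2 I, I)$, giving us the following marginal: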

$$
Y \mid F \sim \mathcal{MN}(0, \sigma^2 I , F^T F + I). 
$$
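
To spell out where that marginal comes from, assume the prior on $W$ has i.i.d. $\mathcal{N}(0, \sigma^2)$ entries (the same variance as the noise, which is what the $\sigma^2$ scaling of the marginal requires) and look at one row (TR) at a time, with $y_t$, $w_t$, $e_t$ the $t$-th rows of $Y$, $W$ and the noise:

$$
\begin{aligned}
y_t &= w_t F + e_t, \qquad w_t \sim \mathcal{N}(0, \sigma^2 I_K), \quad e_t \sim \mathcal{N}(0, \sigma^2 I_V), \\
\operatorname{Cov}(y_t) &= F^T \operatorname{Cov}(w_t) F + \operatorname{Cov}(e_t) = \sigma^2 \left(F^T F + I_V\right).
\end{aligned}
$$

The rows are independent, so stacking them gives exactly $Y \mid F \sim \mathcal{MN}(0, \sigma^2 I, F^T F + I)$.
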
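Using the matrix determinant lemma and the Woodbury matrix inversion lemma, the log-determinant and inverse of the $V \times V$ column covariance $F^T F + I$ (with $V$ voxels and $K$ factors) can be computed from the $K \times K$ matrix $I + FF^T$ instead, which is much cheaper when there are far fewer factors than voxels. That makes the marginal log-likelihood cheap to evaluate, so we can autodiff through it and optimize it with L-BFGS-B.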
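
A minimal NumPy/SciPy sketch of that computation, assuming the marginal above with a known $\sigma^2$. The function name `marginal_loglik` is ours, not FastTFA's, and an autodiff version would be written in whatever framework FastTFA uses. It relies on $(I_V + F^TF)^{-1} = I_V - F^T (I_K + FF^T)^{-1} F$ and $\det(I_V + F^TF) = \det(I_K + FF^T)$, so the only factorization is a $K \times K$ Cholesky, and it checks the result against a naive dense $V \times V$ evaluation on a toy problem:

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve
from scipy.stats import multivariate_normal


def marginal_loglik(Y, F, sigma2):
    """log p(Y | F) for Y ~ MN(0, sigma2 * I_T, F^T F + I_V), using only K x K solves.

    Y: T x V data, F: K x V factors, sigma2: noise (and prior) variance.
    """
    T, V = Y.shape
    K = F.shape[0]
    A = np.eye(K) + F @ F.T                            # K x K stand-in for the V x V F^T F + I
    cho = cho_factor(A, lower=True)
    logdet_A = 2.0 * np.sum(np.log(np.diag(cho[0])))  # equals log det(I_V + F^T F)
    B = F @ Y.T                                        # K x T
    # Woodbury: tr(Y (I_V + F^T F)^{-1} Y^T) = ||Y||_F^2 - tr(B^T A^{-1} B)
    quad = np.sum(Y ** 2) - np.sum(B * cho_solve(cho, B))
    return -0.5 * (T * V * np.log(2 * np.pi * sigma2) + T * logdet_A + quad / sigma2)


# Toy check against the naive V x V computation (rows of Y are i.i.d. given F)
rng = np.random.default_rng(0)
T, V, K, sigma2 = 4, 6, 2, 0.5
F = rng.normal(size=(K, V))
Y = rng.normal(size=(T, V))
dense_cov = sigma2 * (F.T @ F + np.eye(V))
dense = multivariate_normal(mean=np.zeros(V), cov=dense_cov).logpdf(Y).sum()
print(np.isclose(marginal_loglik(Y, F, sigma2), dense))  # True
```

The cost is dominated by forming $FF^T$ and $FY^T$, which is linear in $V$; the dense version needs a $V \times V$ factorization.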

In contrast to vanilla TFA, this way of doing things: 
- Gets rid of the coordinate ascent bit completely: we find $F$ marginalizing over $W$ and then compute $W$ in closed form once. 
- Gets rid of finite differences, which are slow (we get exact gradients from autodiff instead).
- Gets rid of a massive Jacobian that needs to be subsampled (there's a tradeoff here: we're no longer exploiting the least-squares structure).
- Can be computed on GPU if we'd like. 
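
Once $F$ is fitted, the closed-form step for $W$ is cheap. Under the prior $W \sim \mathcal{MN}(0, \sigma^2 I, I)$, one natural choice is the posterior mean $\hat W = Y F^T (FF^T + I)^{-1}$, i.e. ridge regression of each TR onto the factors with unit penalty ($\sigma^2$ cancels). This is a sketch of that estimator only; we haven't confirmed it is exactly what FastTFA computes, and `posterior_mean_W` is a hypothetical name:

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve


def posterior_mean_W(Y, F):
    """Hypothetical helper: E[W | Y, F] = Y F^T (F F^T + I)^{-1} under W ~ MN(0, sigma^2 I, I).

    Y: T x V data, F: K x V factors. Returns a T x K weight matrix.
    """
    K = F.shape[0]
    A = F @ F.T + np.eye(K)  # K x K, symmetric positive definite
    # Solve A W^T = F Y^T instead of forming A^{-1} explicitly
    return cho_solve(cho_factor(A), F @ Y.T).T


rng = np.random.default_rng(1)
F = rng.normal(size=(3, 50))
W_true = rng.normal(size=(20, 3))
Y = W_true @ F + 0.1 * rng.normal(size=(20, 50))
W_hat = posterior_mean_W(Y, F)
print(W_hat.shape)  # (20, 3)
```

Only a $K \times K$ system is solved, so this one-off step is negligible next to fitting $F$.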

It's also still compatible with naive voxelwise and TRwise subsampling as in regular TFA. 
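
Under the marginal model, naive subsampling is also principled: dropping TRs just drops i.i.d. rows, and dropping voxels keeps exactly the matching sub-block of the column covariance, $(F^TF + I)_{[v, v]} = F_v^T F_v + I$, where $F_v$ is $F$ restricted to the sampled voxels. So the subsampled objective is the exact marginal likelihood of the subsampled data. A sketch with a hypothetical helper `subsample`, repeating the Woodbury log-likelihood so the block runs on its own (how FastTFA's `subsamp_size_v`/`subsamp_size_t` actually draw their subsets is not something we've checked):

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve


def marginal_loglik(Y, F, sigma2):
    """log p(Y | F) for Y ~ MN(0, sigma2 * I, F^T F + I) via K x K Woodbury solves."""
    T, V = Y.shape
    A = np.eye(F.shape[0]) + F @ F.T
    cho = cho_factor(A, lower=True)
    B = F @ Y.T
    quad = np.sum(Y ** 2) - np.sum(B * cho_solve(cho, B))
    logdet = 2.0 * np.sum(np.log(np.diag(cho[0])))
    return -0.5 * (T * V * np.log(2 * np.pi * sigma2) + T * logdet + quad / sigma2)


def subsample(Y, F, n_t, n_v, rng):
    """Hypothetical helper: random TRs (rows of Y) and voxels (columns of Y and F), no replacement."""
    t_idx = rng.choice(Y.shape[0], size=n_t, replace=False)
    v_idx = rng.choice(Y.shape[1], size=n_v, replace=False)
    return Y[np.ix_(t_idx, v_idx)], F[:, v_idx], v_idx


rng = np.random.default_rng(2)
F = rng.normal(size=(4, 300))
Y = rng.normal(size=(100, 300))
Y_sub, F_sub, v_idx = subsample(Y, F, n_t=50, n_v=60, rng=rng)

# The voxel sub-block of the full column covariance is exactly F_sub^T F_sub + I
full_cov = F.T @ F + np.eye(300)
print(np.allclose(full_cov[np.ix_(v_idx, v_idx)], F_sub.T @ F_sub + np.eye(60)))  # True
print(marginal_loglik(Y_sub, F_sub, sigma2=1.0))  # objective for one subsample
```

Drawing a fresh subsample per optimizer iteration would turn this into a stochastic objective; a single fixed subsample keeps it deterministic, which suits L-BFGS-B better.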

And it's still theoretically compatible with HTFA (just not implemented yet). Let's see how it goes! 

In [1]:
```python
import numpy as np
from brainiak.factoranalysis.tfa import TFA
from brainiak.factoranalysis.fast_tfa import FastTFA, get_factormat

def benchmark_tfa(v, t, k, noise, maxv, maxt, skip_slow=False, skip_subsampled=False):
    # Simulate v voxels at random 3D coordinates, k factors, and t TRs
    R = np.random.randint(2, high=102, size=(v, 3)).astype("float64")
    centers = np.random.randint(2, high=102, size=(k, 3)).astype("float64")
    widths = np.abs(np.random.normal(loc=0, scale=np.std(R)**2, size=(k, 1)))
    F = get_factormat(R, centers, widths).numpy()    # k x v factor matrix
    W = np.random.normal(size=(t, k))                 # t x k weights
    X = W @ F + np.random.normal(size=(t, v))*noise   # t x v data

    if not skip_slow:
        print("\nFull Fast TFA:")
        fast_tfa = FastTFA(k=k)
        %timeit -n 1 -r 5 fast_tfa.fit(X, R)
        fast_tfa_mse = np.mean((X - fast_tfa.W_ @ fast_tfa.F_)**2)
        print(f"MSE={fast_tfa_mse}")

        # brainiak's TFA expects voxels x TRs, hence X.T (and F_ @ W_ rather than W_ @ F_)
        print("\nFull TFA")
        tfa_full = TFA(K=k, max_num_tr=t, max_num_voxel=v)
        %timeit -n 1 -r 5 tfa_full.fit(X.T, R)
        tfa_full_mse = np.mean((X.T - tfa_full.F_ @ tfa_full.W_)**2)
        print(f"MSE={tfa_full_mse}")

    if not skip_subsampled:
        print("\nSubsampled Fast TFA:")
        fast_tfa = FastTFA(k=k)
        %timeit -n 1 -r 5 fast_tfa.fit(X, R, subsamp_size_v=maxv, subsamp_size_t=maxt, n_iter=10)
        fast_tfa_mse = np.mean((X - fast_tfa.W_ @ fast_tfa.F_)**2)
        print(f"MSE={fast_tfa_mse}")

        print("\nSubsampled regular TFA:")
        tfa_subsamp = TFA(K=k, max_num_tr=maxt, max_num_voxel=maxv)
        %timeit -n 1 -r 5 tfa_subsamp.fit(X.T, R)
        tfa_subsamp_mse = np.mean((X.T - tfa_subsamp.F_ @ tfa_subsamp.W_)**2)
        print(f"MSE={tfa_subsamp_mse}")

# Warm-up fit on a tiny problem, presumably so that any JIT compilation happens here
# rather than inside the timed runs below
R = np.random.randint(2, high=102, size=(10, 3)).astype("float64")
centers = np.random.randint(2, high=102, size=(3, 3)).astype("float64")
widths = np.abs(np.random.normal(loc=0, scale=np.std(R)**2, size=(3, 1)))
F = get_factormat(R, centers, widths).numpy()
W = np.random.normal(size=(5, 3))
X = W @ F + np.random.normal(size=(5, 10))
ftfa = FastTFA(k=3)
_ = ftfa.fit(X, R)
```
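
`get_factormat` comes from the FastTFA module and its exact definition isn't shown here. In the TFA papers each factor is a Gaussian radial basis function of voxel location, $F_{kv} = \exp(-\lVert r_v - \mu_k \rVert^2 / \lambda_k)$, and the benchmark draws `widths` on the scale of `np.std(R)**2`, which fits that reading. A hypothetical NumPy stand-in of that form (assuming `widths` plays the role of $\lambda_k$) would be:

```python
import numpy as np


def rbf_factormat(R, centers, widths):
    """Hypothetical stand-in for get_factormat: Gaussian RBF factors.

    R: V x 3 voxel coordinates, centers: K x 3, widths: K x 1.
    Returns a K x V matrix with F[k, v] = exp(-||R[v] - centers[k]||^2 / widths[k]).
    """
    sq_dist = ((centers[:, None, :] - R[None, :, :]) ** 2).sum(axis=-1)  # K x V
    return np.exp(-sq_dist / widths)  # (K, 1) widths broadcast across voxels


R = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 4.0, 0.0]])
centers = np.array([[0.0, 0.0, 0.0]])
widths = np.array([[25.0]])
print(rbf_factormat(R, centers, widths))
```

Each factor equals 1 at its center and decays with squared distance, with the width controlling how fast.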

In [2]:
```python
# tiny problem with basically no noise, sanity check, subsampling doesn't make sense
benchmark_tfa(v=200, t=100, k=5, noise=0.1, maxv=200, maxt=100, skip_subsampled=True)
```


```
Full Fast TFA:
235 ms ± 103 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=0.016805094136183363

Full TFA
21.3 s ± 146 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=0.017925624628788255
```


In [3]:
```python
# same tiny problem with a bit more noise
benchmark_tfa(v=200, t=100, k=5, noise=1, maxv=200, maxt=100, skip_subsampled=True)
```


```
Full Fast TFA:
The slowest run took 6.14 times longer than the fastest. This could mean that an intermediate result is being cached.
168 ms ± 119 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=0.9855854473253423

Full TFA
7.82 s ± 99.1 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=0.9857954102123114
```


In [4]:
```python
# modest subsampling, no benefit for fast TFA but big benefit for regular TFA
benchmark_tfa(v=500, t=200, k=15, noise=1, maxv=200, maxt=200)
```


```
Full Fast TFA:
2.71 s ± 295 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=0.9719086009432499

Full TFA
1min 35s ± 1.48 s per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=1.011639365409478

Subsampled Fast TFA:
4.66 s ± 463 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=0.9825072316601544

Subsampled regular TFA:
19.9 s ± 98.9 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=1.016532650335811
```


In [6]:
```python
# no patience to try the non-subsampled one
benchmark_tfa(v=1000, t=200, k=15, noise=1, skip_slow=True, maxv=250, maxt=200)
```


```
Subsampled Fast TFA:
4.91 s ± 845 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=0.995978002703193

Subsampled regular TFA:
38.8 s ± 6.82 s per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=1.0244633334503621
```


In [7]:
```python
# ROI size
benchmark_tfa(v=2000, t=500, k=20, noise=1, skip_slow=True, maxv=250, maxt=200)
```


```
Subsampled Fast TFA:
5.62 s ± 1.63 s per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=1.0009771281581075

Subsampled regular TFA:
53.8 s ± 2.38 s per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=1.0193063341196682
```


In [8]:
```python
# full brain size
benchmark_tfa(v=100000, t=1250, k=50, noise=1, skip_slow=True, maxv=250, maxt=200)
```


```
Subsampled Fast TFA:
21 s ± 1.29 s per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=1.0219173418617433

Subsampled regular TFA:
2min 54s ± 12 s per loop (mean ± std. dev. of 5 runs, 1 loop each)
MSE=1.029504454343515
```


# Conclusion 
Fast TFA is consistently more accurate at reconstruction (by integrating over $W$). Without subsampling it is far faster than regular TFA without subsampling (roughly 35 to 90x in the runs above), but not nearly as fast as the subsampled version of regular TFA for more practical problems.

With subsampling, Fast TFA is still more accurate than regular TFA (subsampled or not) and substantially faster than either version. For the larger problems above (1,000 voxels and up), it is roughly 8 to 10x faster than subsampled regular TFA with the same subsample sizes.
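
A quick check of the subsampled speedups against the mean timings recorded in the `%timeit` outputs above. These are means over only 5 runs, and some standard deviations are large (e.g. ±6.82 s for subsampled regular TFA at 1000 voxels), so treat the ratios as rough. Keys are the `(v, t, k)` arguments of each benchmark cell:

```python
# Mean wall-clock times in seconds, copied from the %timeit outputs above:
# (subsampled Fast TFA, subsampled regular TFA), keyed by (v, t, k)
subsampled_times = {
    (500, 200, 15): (4.66, 19.9),
    (1000, 200, 15): (4.91, 38.8),
    (2000, 500, 20): (5.62, 53.8),
    (100000, 1250, 50): (21.0, 174.0),  # 2min 54s
}

# Speedup = regular time / fast time for each problem size
speedups = {size: regular / fast for size, (fast, regular) in subsampled_times.items()}
for size, s in speedups.items():
    print(size, f"{s:.1f}x")
```

This prints ratios of about 4.3x, 7.9x, 9.6x and 8.3x for the four configurations, in that order.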