In [1]:
import jax
# Make CPU the default JAX platform; GPU devices are requested explicitly
# through jax.devices("gpu") where the PCA objects are constructed.
# NOTE(review): both config updates run before `import spax` -- presumably so
# that spax picks up these settings at import time; confirm before reordering.
jax.config.update("jax_platform_name", "cpu")
# Enable 64-bit (double precision) types in JAX.
jax.config.update("jax_enable_x64", True)
import numpy as np
import spax
from datetime import datetime

# Principal Component Analysis (PCA)
Let $X$ be an $n \times m$ matrix, where $n$ is the dimensionality of the data and $m$ the number of samples.
To prepare the data for PCA, it should be centered, `X -= X.mean(axis = 1, keepdims = True)`, and possibly whitened, `X /= X.std(axis = 1, keepdims = True)`. Whitening scales every dimension to unit variance, which is especially useful if the data dimensions are in different units or represent different observables. PCA then consists of finding a rotation matrix that makes the covariance of the data diagonal. Specifically, PCA of order $N$ keeps only the rotated directions that belong to the $N$ largest eigenvalues of the covariance matrix.

## Singular Value Decomposition (SVD)
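Every real $n \times m$ matrix (square or not) can be decomposed as $$X = U \, S \, V^T \, ,$$ where, for $n \ge m$, $U$ is an $n \times m$ matrix with orthonormal columns ($U^T \, U = I$), $V$ is an $m \times m$ orthogonal matrix ($V^T \, V = I$) and $S$ is a diagonal $m\times m$ matrix of singular values (for $n < m$ the roles are exchanged: $U$ and $S$ are $n \times n$, and $V$ is an $m \times n$ matrix with orthonormal columns). From such a decomposition one can write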
\begin{align}
X \, X^T &= U \, S^2 \, U^T  \, , \\
X^T X &= V \, S^2 \, V^T \, .
\end{align}
$X \, X^T$ and $X^T X$ have the same non-zero eigenvalues; the larger of the two matrices has additional eigenvalues, all of which are equal to $0$.

## Case #1: $n \le m$
The covariance of the (centered) data can be written as $$ C = \frac{1}{m-1} X \, X^T = U \, \frac{S^2}{m-1} \, U^T \, .$$ Therefore, by solving the eigenvalue problem for $C$, we can find $U$. By picking the $N$ eigenvectors that belong to the largest eigenvalues, we construct $\widetilde{U}$, which is used for PCA.

For some matrix $X_0$ of size $n \times m_0$, the PCA projection is simply $$ Y_0 = \widetilde{U}^T X_0 \, .$$

In [2]:
# Case #1 demo: n = 8 dimensions, m = 10**6 samples of independent Gaussian
# data in which the i-th dimension has variance i (i = 1..8).
pca = spax.PCA_m(5, devices = jax.devices("gpu"))
row_std = np.sqrt(np.arange(1, 9))[:, np.newaxis]
data = np.random.normal(0, 1, size = (8, 1000000)) * row_std

# Constraint noted by the author: N_dim % (N_devices * batch_size) == 0
pca.fit(data, batch_size = 2)

# Sample from the fitted PCA, transform the samples and compare the variances
sampled_data = pca.sample(1000000)
projected = pca.transform(sampled_data)
print(np.std(projected, axis = 1, ddof = 1)**2) # should be [4, 5, 6, 7, 8]
print(pca.λ ** 2) # should be the same
print(np.round(pca.U.T, 1)) # should be a +-unit matrix on last 5 dimensions

[4.0087023 5.0025253 6.008477  6.993668  7.9928746]
[3.9976556 5.001418  6.008531  6.989715  8.001953 ]
[[-0. -0. -0.  1.  0. -0. -0. -0.]
 [ 0. -0. -0. -0.  1. -0. -0.  0.]
 [ 0.  0.  0.  0.  0.  1.  0.  0.]
 [-0. -0.  0.  0.  0. -0.  1. -0.]
 [-0.  0. -0.  0. -0. -0.  0.  1.]]


## Case #2: $n \ge m$
In this case, it is better to work with the smaller $m \times m$ matrix and write: $$D = \frac{1}{n} X^T \, X = V \, \frac{S^2}{n} \, V^T \, .$$
Solving the eigenvalue problem for $D$ gives us $V$ and $S$. Then, the rotation matrix can be computed as $$ U = X \, V \, S^{-1} \, .$$ The rest is the same as in the previous case.

In [3]:
# Case #2 demo: n = 10**5 dimensions, m = 10**4 samples, stored as float32.
pca = spax.PCA_m(5, devices = jax.devices("gpu"))

# Draw the samples directly as float32. np.random.normal(...) would first
# allocate a float64 array of 10**9 elements (~8 GB) only to cast it afterwards.
rng = np.random.default_rng(0) # fixed seed so that the cell is reproducible
data = rng.standard_normal(size = (10**5, 10**4), dtype = np.float32)
# Scale the i-th dimension to variance i (i = 1..10**5); done in place so that
# no second ~4 GB copy of the array is created.
data *= np.sqrt(np.arange(1, 10**5 + 1)).astype(np.float32)[:, np.newaxis]

In [4]:
# Time the fit on the large data set.
# Constraint noted by the author: N_dim % (N_devices * batch_size) == 0
# NOTE(review): centering_data = "CPU" presumably selects centering on the host; confirm in spax.
tic = datetime.now()
pca.fit(data, batch_size = 5 * 10**4, centering_data = "CPU")
elapsed = datetime.now() - tic
print("DURATION:", elapsed)

# Sample from the fitted PCA, transform the samples and print their variances
# next to the squared λ values of the fit.
sampled_data = pca.sample(10**4)
projected = pca.transform(sampled_data)
print(np.std(projected, axis = 1, ddof = 1)**2)
print(pca.λ ** 2)

# Only the last 10 columns of U.T are shown
last_columns = pca.U.T[:, -10:]
print(np.round(last_columns, 3))

DURATION: 0:00:45.662287
[941537.75 923491.56 916237.8  974216.4  917588.  ]
[937998.07002875 938286.72793243 939743.96863667 940160.34365363
 942715.25324863]
[[-0.006 -0.003 -0.     0.009  0.002  0.001 -0.001  0.002  0.004  0.001]
 [-0.001  0.005 -0.005  0.003  0.003  0.     0.001 -0.006 -0.002 -0.005]
 [ 0.007 -0.002  0.011  0.003 -0.008 -0.003 -0.006  0.005  0.001  0.011]
 [-0.004  0.002 -0.001 -0.008 -0.005  0.     0.006  0.006 -0.003 -0.007]
 [-0.003 -0.003 -0.006 -0.004  0.003 -0.001 -0.003  0.001  0.001 -0.001]]


## Test save + load

In [5]:
# Round-trip check: a PCA restored from disk must transform data the same way
# as the object it was saved from.
pca.save("test.hdf5")
pca_new = spax.PCA_m(5, devices = jax.devices("gpu"))
pca_new.load("test.hdf5")

x0 = np.arange(100000)[:, np.newaxis] # one column vector, shape (100000, 1)
y_ref = pca.transform(x0)
y_new = pca_new.transform(x0)
print(y_ref)
print(y_new)
# Fail loudly on a mismatch instead of relying on comparing the printouts by eye
np.testing.assert_allclose(np.asarray(y_new), np.asarray(y_ref), rtol = 1e-5)

[[ 37038.668]
 [-98532.23 ]
 [-24165.244]
 [ 91972.45 ]
 [-44349.188]]
[[ 37038.668]
 [-98532.23 ]
 [-24165.244]
 [ 91972.45 ]
 [-44349.188]]
