In [1]:
import jax
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
import numpy as np
np.set_printoptions(precision=5, linewidth=150)
import spax
from datetime import datetime
from sklearn.decomposition import PCA

# Principle Component Analysis (PCA)
Let X be $n \times m$ matrix, where $n$ represents dimensionality of the data and $m$ number of samples.
To prepare data for the PCA, it should be centered `X -= X.mean(axis = 1)`, and possibly whitened `X /= X.std(axis = 1)`. Whitening is done to scale dimensions to unit variance. This is specifically useful if data dimensions are in different units or represent different observables. PCA then consists of finding a rotation matrix which makes covariance of the data diagonal. Specifically, PCA of order $N$ does this rotation for $N$ largest eigenvalues of the covariance matrix.

## Singular Value Decomposition (SVD)
Every non-quadratic matrix can be decomposed as $$X = U \, S \, V^T \, ,$$ where $U$ is $n \times m$ matrix with orthonormal columns ($U^T \, U = I$), $V$ is $m \times m$ orthonormal matrix ($V^T \, V = I$) and $S$ a diagonal $m\times m$ matrix of singular values. From such decomposition one can write
\begin{align}
X \, X^T &= U \, S^2 \, U^T  \, , \\
X^T X &= V \, S^2 \, V^T \, .
\end{align}
Both $X \, X^T$ and $X^T X$ have the same eigenvalues, with larger of the two having the rest equal to $0$.

## Case #1: $n \le m$
Covariance of the data can be written as $$ C = \frac{1}{m-1} X \, X^T = U \, \frac{S^2}{m-1} \, U^T \, .$$ Therefore, by solving an eigenvalue problem for $C$, we can find $U$. By picking $N$ eigenvectors in the directions of the largest eigenvalues, we construct $\widetilde{U}$ used for PCA.

For some matrix $X_0$ of size $n \times m_0$, PCA is simply $$ Y_0 = \widetilde{U}^T X_0$$.

In [2]:
N_dim, N_samples = (8, 10**5)
pca = spax.PCA_m(5, devices = jax.devices("gpu"))
data = np.random.normal(0, 1, size = (N_dim, N_samples)) * np.sqrt(np.arange(1, N_dim + 1))[:, np.newaxis]
tic = datetime.now()
pca.fit(data, batch_size = N_dim // 2) # N_dim % (N_devices * batch_size) == 0
print("DURATION:", datetime.now() - tic)
sampled_data = pca.sample(N_samples)
print(np.std(pca.transform(sampled_data), axis = 1)**2) # should be [8, 7, 6, 5, 4]
print(pca.λ ** 2) # should be the same
print(np.round(pca.U.T, 1)) # should be a +-unit matrix on last 5 dimensions

DURATION: 0:00:03.698620
[8.12089 7.02699 5.98402 4.98791 3.96009]
[8.06495 7.02951 5.96893 4.9869  3.98807]
[[-0.  0.  0. -0.  0.  0.  0.  1.]
 [-0. -0.  0.  0. -0. -0.  1. -0.]
 [ 0.  0. -0.  0. -0.  1.  0. -0.]
 [ 0. -0. -0.  0.  1.  0.  0. -0.]
 [ 0.  0.  0.  1. -0. -0. -0.  0.]]


In [3]:
#testing the result with scikit-learn
pca_sk = PCA(n_components = 5)
tic = datetime.now()
pca_sk.fit(data.T)
print("DURATION:", datetime.now() - tic)
print(np.std(pca_sk.transform(sampled_data.T), axis = 0)**2)
print(pca_sk.singular_values_**2 / (N_samples - 1))
print(np.round(pca_sk.components_, 1))

DURATION: 0:00:00.174155
[8.12079 7.02662 5.98345 4.98718 3.96037]
[8.0649  7.02917 5.96963 4.98692 3.98784]
[[ 0.  0.  0. -0.  0.  0.  0.  1.]
 [ 0.  0. -0. -0.  0.  0. -1.  0.]
 [ 0. -0.  0. -0.  0. -1. -0.  0.]
 [ 0.  0. -0. -0.  1.  0.  0. -0.]
 [-0. -0. -0. -1. -0.  0.  0. -0.]]


## Case #2 $n \ge m$
In this case, it is better to write: $$D = \frac{1}{n} X^T \, X = V \, \frac{S^2}{n} \, V^T \, .$$
Solving eigenvector problem for $D$ gives us $V$ and $S$. Then, rotation matrix can be computed as $$ U = X \, V \, S^{-1} \, .$$ The rest is the same as in previous case.

In [4]:
N_dim, N_samples = (10**5, 8)
pca = spax.PCA_m(5, devices = jax.devices("gpu"))
data = np.random.normal(0, 1, size = (N_dim, N_samples)) * np.sqrt(np.arange(1, N_dim + 1))[:, np.newaxis]
tic = datetime.now()
pca.fit(data, batch_size = N_dim // 2, centering_data = "CPU") # N_dim % (N_devices * batch_size) == 0
print("DURATION:", datetime.now() - tic)
sampled_data = pca.sample(N_samples)
print(np.std(pca.transform(sampled_data), axis = 1)**2)
print(pca.λ ** 2)
print(pca.U.T)

DURATION: 0:00:01.556548
[1.23227e+09 6.08565e+08 5.36228e+08 4.96250e+08 4.15412e+08]
[7.22188e+08 7.18227e+08 7.16831e+08 7.11275e+08 7.10198e+08]
[[-1.17121e-05 -1.80550e-05  2.99666e-05 ...  4.43493e-03  8.86737e-03 -3.01205e-03]
 [-1.57520e-05  2.42868e-05  9.46663e-06 ...  1.42047e-03 -5.75001e-04  2.65901e-03]
 [ 1.41894e-05 -7.96396e-06 -1.87146e-05 ... -3.91604e-03  3.46383e-03  6.83253e-03]
 [ 1.75867e-05 -5.89193e-07  1.60927e-05 ... -1.38143e-03 -2.13178e-03  1.18323e-02]
 [-9.72305e-06  1.98087e-06  5.18939e-06 ...  2.10904e-03  4.51607e-04 -1.85308e-04]]


In [5]:
pca_sk = PCA(n_components = 5)
tic = datetime.now()
pca_sk.fit(data.T)
print("DURATION:", datetime.now() - tic)
print(np.std(pca_sk.transform(sampled_data.T), axis = 0)**2)
print(pca_sk.singular_values_**2 / (N_samples - 1))
print(pca_sk.components_)

DURATION: 0:00:00.118940
[1.23228e+09 6.08513e+08 5.36278e+08 4.96249e+08 4.15405e+08]
[7.22188e+08 7.18227e+08 7.16831e+08 7.11274e+08 7.10197e+08]
[[-1.17127e-05 -1.80534e-05  2.99671e-05 ...  4.43506e-03  8.86741e-03 -3.01168e-03]
 [ 1.57517e-05 -2.42885e-05 -9.46558e-06 ... -1.42032e-03  5.75860e-04 -2.65893e-03]
 [ 1.41891e-05 -7.96249e-06 -1.87133e-05 ... -3.91605e-03  3.46364e-03  6.83318e-03]
 [ 1.75858e-05 -5.89092e-07  1.60938e-05 ... -1.38124e-03 -2.13209e-03  1.18320e-02]
 [-9.72343e-06  1.98092e-06  5.18916e-06 ...  2.10920e-03  4.51624e-04 -1.85396e-04]]


## Test save + load

In [6]:
pca.save("test.hdf5")
pca_new = spax.PCA_m(5, devices = jax.devices("gpu"))
pca_new.load("test.hdf5")

x0 = np.arange(100000)[:, np.newaxis]
print(pca.transform(x0).flatten())
print(pca_new.transform(x0).flatten())
print(pca_sk.transform(x0.T)[:, ::-1].flatten())

[  2248.0647 -28313.857   20754.691   68855.44   -59270.133 ]
[  2248.0647 -28313.857   20754.691   68855.44   -59270.133 ]
[-59274.09708  68853.53357  20756.58006  28313.33214   2244.92584]
