In [1]:
import jax
# Select CPU as JAX's default platform; the cells below request GPU devices explicitly via jax.devices("gpu").
jax.config.update("jax_platform_name", "cpu")
# Turn on 64-bit precision in JAX.
jax.config.update("jax_enable_x64", True)
import numpy as np
# Compact array printing so the comparison rows in the outputs fit on one line each.
np.set_printoptions(precision=5, linewidth=150)
# NOTE(review): spax is imported after the jax.config.update calls — presumably on purpose; keep this order.
import spax
from datetime import datetime
from sklearn.decomposition import PCA

# Principal Component Analysis (PCA)
Let $X$ be an $n \times m$ matrix, where $n$ is the dimensionality of the data and $m$ is the number of samples.
To prepare data for the PCA, it should be centered, `X -= X.mean(axis = 1, keepdims = True)`, and possibly whitened, `X /= X.std(axis = 1, keepdims = True)`. Whitening scales every dimension to unit variance. This is especially useful if the data dimensions are in different units or represent different observables. PCA then consists of finding a rotation matrix which makes the covariance of the data diagonal. Specifically, PCA of order $N$ keeps only the $N$ rotated directions that belong to the $N$ largest eigenvalues of the covariance matrix.

## Singular Value Decomposition (SVD)
Every real $n \times m$ matrix $X$ (not necessarily square) can be decomposed as $$X = U \, S \, V^T \, ,$$ where, with $r = \min(n, m)$, $U$ is an $n \times r$ matrix with orthonormal columns ($U^T \, U = I$), $V$ is an $m \times r$ matrix with orthonormal columns ($V^T \, V = I$) and $S$ is a diagonal $r \times r$ matrix of singular values. From such a decomposition one can write
\begin{align}
X \, X^T &= U \, S^2 \, U^T  \, , \\
X^T X &= V \, S^2 \, V^T \, .
\end{align}
Both $X \, X^T$ and $X^T X$ have the same nonzero eigenvalues; the larger of the two matrices has its remaining eigenvalues equal to $0$.

## Case #1: $n < m$
Covariance of the data can be written as $$ C = \frac{1}{m-1} X \, X^T = U \, \frac{S^2}{m-1} \, U^T \, .$$ Therefore, by solving an eigenvalue problem for $C$, we can find $U$. By picking $N$ eigenvectors in the directions of the largest eigenvalues, we construct $\widetilde{U}$ used for PCA.

For some matrix $X_0$ of size $n \times m_0$, PCA is simply $$ Y_0 = \widetilde{U}^T X_0 \, .$$

In [2]:
# Case n < m: few dimensions, many samples.
N_dim, N_samples = (16, 10**5)
# Seed NumPy's global generator so the synthetic data set is identical on every Restart & Run All.
# NOTE(review): pca.sample may draw from its own random source — confirm whether it needs separate seeding.
np.random.seed(0)
pca = spax.PCA_m(5, devices = jax.devices("gpu"))
# Row i (1-based) is scaled by sqrt(i), i.e. it has variance i, so the five leading
# principal components should carry variances 16, 15, 14, 13 and 12.
data = np.random.normal(0, 1, size = (N_dim, N_samples)) * np.sqrt(np.arange(1, N_dim + 1))[:, np.newaxis]
tic = datetime.now()
pca.fit(data, batch_size = N_dim // 2, centering_data = "GPU") # N_dim % (N_devices * batch_size) == 0
print("DURATION:", datetime.now() - tic)
sampled_data = pca.sample(N_samples, batch_size = N_dim // 2)
# Per-component variance of the projected training data and of the projected samples.
print(np.var(pca.transform(data, batch_size = N_dim // 2), axis = 1)) # should be [16, 15, 14, 13, 12]
print(np.var(pca.transform(sampled_data, batch_size = N_dim // 2), axis = 1)) #should be the same
print(pca.λ ** 2) # should be the same
print(np.round(pca.U.T, 1)) # should be a +-unit matrix on last 5 dimensions

DURATION: 0:00:03.665672
[15.86349 15.01729 14.025   12.97456 11.96139]
[15.88904 15.04109 14.00261 12.92626 12.06539]
[15.86442 15.01822 14.02536 12.97476 11.96212]
[[ 0.  -0.  -0.  -0.   0.  -0.   0.  -0.   0.  -0.   0.   0.   0.   0.  -0.1  1. ]
 [-0.  -0.  -0.   0.  -0.  -0.   0.  -0.   0.   0.  -0.  -0.  -0.  -0.1  1.   0.1]
 [-0.  -0.  -0.  -0.  -0.   0.  -0.  -0.   0.  -0.  -0.   0.  -0.1  1.   0.1 -0. ]
 [ 0.  -0.  -0.  -0.   0.  -0.   0.   0.   0.   0.  -0.  -0.   1.   0.1  0.  -0. ]
 [ 0.   0.   0.   0.   0.  -0.  -0.  -0.  -0.  -0.   0.   1.   0.  -0.   0.   0. ]]


In [3]:
# Reference run with scikit-learn on the same data. scikit-learn takes samples as rows,
# so every array is transposed on the way in and statistics are taken along axis 0.
pca_sk = PCA(n_components = 5)
fit_start = datetime.now()
pca_sk.fit(data.T)
print("DURATION:", datetime.now() - fit_start)
# Per-component variance of the projected training data, then of the spax-sampled data.
for samples in (data, sampled_data):
    projected = pca_sk.transform(samples.T)
    print(np.std(projected, axis = 0) ** 2)
# Squared singular values rescaled by (N_samples - 1), for comparison with pca.λ ** 2.
eigenvalues = pca_sk.singular_values_ ** 2 / (N_samples - 1)
print(eigenvalues)
# Principal axes, rounded to one decimal for display.
print(np.round(pca_sk.components_, decimals = 1))

DURATION: 0:00:00.258082
[15.86399 15.01809 14.02532 12.97499 11.96181]
[15.88819 15.03968 14.00197 12.92547 12.06467]
[15.86415 15.01824 14.02546 12.97512 11.96193]
[[-0.  -0.   0.   0.  -0.   0.   0.   0.  -0.   0.  -0.  -0.  -0.  -0.   0.1 -1. ]
 [-0.   0.  -0.   0.  -0.   0.  -0.  -0.  -0.  -0.   0.   0.   0.   0.1 -1.  -0.1]
 [-0.  -0.   0.  -0.  -0.   0.  -0.  -0.   0.  -0.  -0.   0.  -0.1  1.   0.1 -0. ]
 [-0.  -0.   0.  -0.  -0.   0.   0.  -0.  -0.  -0.   0.   0.  -1.  -0.1 -0.   0. ]
 [ 0.  -0.  -0.   0.  -0.  -0.   0.  -0.  -0.  -0.   0.   1.   0.  -0.   0.   0. ]]


## Case #2: $n \ge m$
In this case, it is better to write: $$D = \frac{1}{n} X^T \, X = V \, \frac{S^2}{n} \, V^T \, .$$
Solving the eigenvalue problem for $D$ gives us $V$ and $S$. The rotation matrix can then be computed as $$ U = X \, V \, S^{-1} \, .$$ The rest is the same as in the previous case.

In [4]:
# Case n >= m: many dimensions, few samples.
N_dim, N_samples = (10**5, 16)
# Seed NumPy's global generator so the synthetic data set is identical on every Restart & Run All.
# NOTE(review): pca.sample may draw from its own random source — confirm whether it needs separate seeding.
np.random.seed(1)
pca = spax.PCA_m(5, devices = jax.devices("gpu"))
# Row i (1-based) is scaled by sqrt(i), i.e. it has variance i.
data = np.random.normal(0, 1, size = (N_dim, N_samples)) * np.sqrt(np.arange(1, N_dim + 1))[:, np.newaxis]
tic = datetime.now()
pca.fit(data, batch_size = N_dim // 2, centering_data = "GPU") # N_dim % (N_devices * batch_size) == 0
print("DURATION:", datetime.now() - tic)
sampled_data = pca.sample(10**4, batch_size = N_dim // 2)
# np.var uses ddof=0 by default. With only N_samples = 16 training points the first row is
# therefore smaller than pca.λ ** 2 by the factor (N_samples - 1) / N_samples = 15/16,
# which is the gap visible between the first and third printed rows.
print(np.var(pca.transform(data, batch_size = N_dim // 2), axis = 1))
print(np.var(pca.transform(sampled_data, batch_size = N_dim // 2), axis = 1))
print(pca.λ ** 2)
print(pca.U.T)

DURATION: 0:00:01.705780
[3.20480e+08 3.18571e+08 3.17954e+08 3.16769e+08 3.16200e+08]
[3.41459e+08 3.37940e+08 3.43414e+08 3.37691e+08 3.44295e+08]
[3.41847e+08 3.39811e+08 3.39152e+08 3.37887e+08 3.37281e+08]
[[ 6.33442e-06  8.16398e-06 -3.66614e-06 ...  3.28223e-03 -8.05994e-03  2.22728e-03]
 [ 4.32146e-06  2.09783e-05 -5.39127e-06 ...  4.88083e-03 -3.92408e-03  1.12816e-03]
 [-4.10747e-05  4.59198e-06  3.08347e-05 ...  4.53058e-03 -6.94469e-03 -3.38505e-03]
 [ 4.58250e-07  1.24855e-05  1.09394e-05 ...  5.18696e-03  9.30697e-04 -6.11469e-04]
 [-1.19243e-06 -2.17123e-05 -4.26089e-05 ... -3.43697e-03  1.93229e-03 -2.56672e-03]]


In [5]:
# scikit-learn reference for the n >= m case; samples go in as rows, hence the transposes.
pca_sk = PCA(n_components = 5)
fit_start = datetime.now()
pca_sk.fit(data.T)
elapsed = datetime.now() - fit_start
print("DURATION:", elapsed)
# Per-component variance of the projected training data.
train_projection = pca_sk.transform(data.T)
print(np.std(train_projection, axis = 0) ** 2)
# Per-component variance of the projected spax samples.
sample_projection = pca_sk.transform(sampled_data.T)
print(np.std(sample_projection, axis = 0) ** 2)
# Squared singular values rescaled by (N_samples - 1), for comparison with pca.λ ** 2.
eigenvalues = pca_sk.singular_values_ ** 2 / (N_samples - 1)
print(eigenvalues)
# Principal axes, one per row.
print(pca_sk.components_)

DURATION: 0:00:00.166816
[3.20481e+08 3.18571e+08 3.17954e+08 3.16769e+08 3.16200e+08]
[3.41459e+08 3.37940e+08 3.43414e+08 3.37690e+08 3.44296e+08]
[3.41846e+08 3.39809e+08 3.39151e+08 3.37887e+08 3.37281e+08]
[[ 6.33356e-06  8.16345e-06 -3.66521e-06 ...  3.28216e-03 -8.06003e-03  2.22716e-03]
 [-4.31963e-06 -2.09781e-05  5.39206e-06 ... -4.88114e-03  3.92440e-03 -1.12789e-03]
 [ 4.10754e-05 -4.59016e-06 -3.08327e-05 ... -4.53029e-03  6.94427e-03  3.38534e-03]
 [ 4.60913e-07  1.24849e-05  1.09398e-05 ...  5.18659e-03  9.31135e-04 -6.11179e-04]
 [ 1.18826e-06  2.17128e-05  4.26083e-05 ...  3.43740e-03 -1.93294e-03  2.56633e-03]]


## Test save + load

In [6]:
# Round-trip the fitted model through an HDF5 file and compare the reloaded copy with the original.
pca.save("test.hdf5")
pca_new = spax.PCA_m(5, devices = jax.devices("gpu"))
pca_new.load("test.hdf5")

# A single probe column with one entry per data dimension. Its length is tied to N_dim
# (previously a hard-coded 100000) so the cell keeps working if N_dim is changed above.
x0 = np.arange(N_dim)[:, np.newaxis]
print(pca.transform(x0, batch_size = N_dim // 2).flatten())
print(pca_new.transform(x0, batch_size = N_dim // 2).flatten())
# scikit-learn can differ from spax by a sign per component: principal axes are only defined up to sign.
print(pca_sk.transform(x0.T).flatten())

[-71815.555   29139.97     3841.4258 -87116.92    23392.324 ]
[-71815.555   29139.967    3841.4229 -87116.92    23392.318 ]
[-71812.21997 -29130.60387  -3838.60998 -87106.1801  -23407.71851]
