# PCA from decimated data — **global** PCA (recommended)

**Why global PCA:** A per-sequence or per–time-index PCA gives different coordinate systems per sample (PC1 in shot A ≠ PC1 in shot B). That breaks learnability: the NN sees incompatible features. A **global** PCA gives a consistent embedding: every time point in every sequence is projected onto the same axes.

**Pipeline:**
- Split train/val/test at the **sequence (shot) level**.
- Fit PCA **only on training time points**: stream through training shots, accumulating the mean μ ∈ ℝ^160 and covariance C ∈ ℝ^(160 × 160) in a single pass with a batched online (Welford/Chan) merge. The running statistics need O(160²) memory, plus one shot (T_n × 160) loaded at a time. Eigendecompose C and take the top K eigenvectors W_K.
- Transform **every** sequence: z_t = (x_t − μ) W_K → each shot becomes Z^(n) ∈ ℝ^(T_n × K). Save as `*_pca{N}` (`dsrpt_decimated_pca{N}`, `clear_decimated_pca{N}`), where N is the number of leading PCs written.

In [None]:
import numpy as np
import pandas as pd
import h5py
from pathlib import Path
from sklearn.decomposition import PCA
from tqdm import tqdm

# Paths — decimated inputs (disruptive + clear)
DECIMATED_ROOT = Path('/home/idies/workspace/Storage/yhuang2/persistent/ecei/dsrpt_decimated')
CLEAR_DECIMATED_ROOT = Path('/home/idies/workspace/Storage/yhuang2/persistent/ecei/clear_decimated')
# Outputs: N_PCA_OUT in {1, 4, 8, 16} → dirs dsrpt_decimated_pca{N}, clear_decimated_pca{N}
N_PCA_OUT = 1   # number of PCs to write (training default is 1)
# Each output directory is placed next to its own input directory.
OUT_DSRPT_PCA = DECIMATED_ROOT.parent / f'dsrpt_decimated_pca{N_PCA_OUT}'
OUT_CLEAR_PCA = CLEAR_DECIMATED_ROOT.parent / f'clear_decimated_pca{N_PCA_OUT}'

N_COMPONENTS = 16   # fit top K (global PCA); must be >= N_PCA_OUT
CHANNELS = 20 * 8   # 160

# Fail fast on an inconsistent configuration instead of producing
# fewer PCs than requested further down the notebook.
if not 1 <= N_PCA_OUT <= N_COMPONENTS:
    raise ValueError(
        f'N_PCA_OUT={N_PCA_OUT} must be between 1 and N_COMPONENTS={N_COMPONENTS}'
    )

## 1. List shots, lengths, and splits (train/val/test)

In [None]:
def get_shot_lengths(root: Path) -> dict[int, int]:
    """Return ``{shot_id: T}`` for each ``<digits>.h5`` file in *root*.

    T is the size of the last axis of the file's ``LFS`` dataset. Files whose
    stem is not all digits are skipped. Shots are returned in ascending
    shot-number order, so later iteration (and the default train/test split)
    does not depend on the order in which the filesystem lists files.
    """
    # Sort on the integer shot number so '100' comes after '99'.
    numbered = sorted(
        (int(p.stem), p) for p in root.glob('*.h5') if p.stem.isdigit()
    )
    out = {}
    for shot, path in numbered:
        with h5py.File(path, 'r') as f:
            out[shot] = f['LFS'].shape[-1]
    return out

def get_splits(root: Path, lengths: dict) -> dict[int, str]:
    """Assign each shot to a data split.

    If ``root/meta.csv`` exists and has a ``split`` column, its ``shot`` and
    ``split`` columns are used. The result is restricted to shots present in
    *lengths*, so every returned shot can be looked up there.

    Otherwise the shots in *lengths* are sorted by shot number and split
    deterministically: the first 80% are 'train' and the rest 'test'. The
    default never produces 'val'.

    Returns:
        ``{shot_id: split_name}``.
    """
    meta_path = root / 'meta.csv'
    if meta_path.exists():
        meta = pd.read_csv(meta_path)
        if 'split' in meta.columns:
            pairs = zip(meta['shot'].astype(int), meta['split'].astype(str))
            # Drop shots listed in meta.csv that have no .h5 file; otherwise a
            # later lengths[shot] lookup would raise KeyError.
            return {int(s): sp for s, sp in pairs if int(s) in lengths}
    # Default split: sort so the result is reproducible across filesystems.
    shots = sorted(lengths)
    n_train = int(0.8 * len(shots))
    return {s: ('train' if i < n_train else 'test') for i, s in enumerate(shots)}

# Shot lengths per dataset; an absent input directory yields an empty dict.
dsrpt_lengths = {}
if DECIMATED_ROOT.exists():
    dsrpt_lengths = get_shot_lengths(DECIMATED_ROOT)
clear_lengths = {}
if CLEAR_DECIMATED_ROOT.exists():
    clear_lengths = get_shot_lengths(CLEAR_DECIMATED_ROOT)

# Split assignment is only computed for datasets that have shots.
dsrpt_splits = {}
if dsrpt_lengths:
    dsrpt_splits = get_splits(DECIMATED_ROOT, dsrpt_lengths)
clear_splits = {}
if clear_lengths:
    clear_splits = get_splits(CLEAR_DECIMATED_ROOT, clear_lengths)

# Shots used to fit the global PCA.
train_shots_d = [shot_id for shot_id, split in dsrpt_splits.items() if split == 'train']
train_shots_c = [shot_id for shot_id, split in clear_splits.items() if split == 'train']
print(f'Disruptive: {len(dsrpt_lengths)} shots  (train={len(train_shots_d)})')
print(f'Clear:      {len(clear_lengths)} shots  (train={len(train_shots_c)})')

## 2. Streaming mean and covariance (training time points only)

In [None]:
# Batched Welford/Chan update: fold each (B, 160) batch into the running
# (count, mean, M2) statistics. The statistics themselves are O(160²).
BATCH_SIZE = 10_000   # time points per batch (vectorized update)

def welford_merge(n, mean, M2, batch: np.ndarray):
    """Combine running statistics with one more batch of observations.

    Uses the pairwise (Chan et al.) merge of two sample sets. M2 is the
    scatter matrix (sum of outer products of deviations from the mean), so
    the unbiased covariance is ``M2 / (n - 1)``.

    Args:
        n: number of observations merged so far.
        mean: running mean, shape (D,).
        M2: running scatter matrix, shape (D, D).
        batch: new observations, shape (B, D).

    Returns:
        ``(n_new, mean_new, M2_new)``. The inputs are returned unchanged when
        the batch is empty.
    """
    count_b = len(batch)
    if count_b == 0:
        return n, mean, M2

    mean_b = batch.mean(axis=0)
    dev_b = batch - mean_b
    scatter_b = dev_b.T @ dev_b

    n_new = n + count_b
    shift = mean - mean_b
    mean_new = (n * mean + count_b * mean_b) / n_new
    # The cross term corrects for the two sets having different means.
    M2_new = M2 + scatter_b + (n * count_b / n_new) * np.outer(shift, shift)
    return n_new, mean_new, M2_new

def load_shot_flat(root: Path, shot: int, T: int) -> np.ndarray:
    """Load ``root/<shot>.h5`` (dataset ``LFS``) as a ``(T, CHANNELS)`` float64 array.

    All leading axes are flattened (C order) into CHANNELS rows, and the result
    is transposed so that each row is one time point.

    Args:
        root: directory containing ``<shot>.h5``.
        shot: shot number (the file stem).
        T: expected length of the last (time) axis of ``LFS``.

    Raises:
        ValueError: if the last axis is not length T, or the leading axes do
            not hold exactly CHANNELS values. Without this check a mismatched
            layout would be reshaped silently into wrong data.
    """
    with h5py.File(root / f'{shot}.h5', 'r') as f:
        data = np.asarray(f['LFS'][:], dtype=np.float64)  # expected (20, 8, T)
    if data.shape[-1] != T or data.size != CHANNELS * T:
        raise ValueError(
            f'shot {shot}: LFS shape {data.shape} does not match '
            f'({CHANNELS} channels, T={T})'
        )
    return data.reshape(CHANNELS, T).T  # (T, 160)

def transform_shot_global(root: Path, shot: int, T: int, mu: np.ndarray, W_K: np.ndarray) -> np.ndarray:
    """Project one shot onto the global principal axes.

    Computes z_t = (x_t - mu) @ W_K for every time point of the shot loaded by
    ``load_shot_flat``. The scores are returned transposed and cast to float32,
    giving shape (K, T) to match the channel-first ``LFS`` layout.
    """
    centered = load_shot_flat(root, shot, T) - mu  # (T, 160), mean-removed
    scores = centered @ W_K                         # (T, K)
    return scores.T.astype(np.float32)              # (K, T)

# Training shots only, as (root, shot, T) triples: disruptive shots first, then clear.
train_tuples = [(DECIMATED_ROOT, s, dsrpt_lengths[s]) for s in train_shots_d] + [
    (CLEAR_DECIMATED_ROOT, s, clear_lengths[s]) for s in train_shots_c
]

# Running statistics: count, mean vector, and scatter matrix M2 (sum of outer
# products of deviations). Only one shot is held in memory at a time.
n_total = 0
mean = np.zeros(CHANNELS, dtype=np.float64)
M2 = np.zeros((CHANNELS, CHANNELS), dtype=np.float64)

for root, shot, T in tqdm(train_tuples, desc='Streaming μ and M2 (train only)'):
    X = load_shot_flat(root, shot, T)  # (T, 160)
    # Feed the shot in BATCH_SIZE-row slices; the last slice may be shorter.
    batch_starts = range(0, X.shape[0], BATCH_SIZE)
    for start in batch_starts:
        batch = X[start:start + BATCH_SIZE]
        n_total, mean, M2 = welford_merge(n_total, mean, M2, batch)

print(f'Training time points: {n_total}')
if n_total < 2:
    # The unbiased covariance M2 / (n - 1) needs at least two samples.
    raise ValueError('Need at least 2 training time points to fit PCA')

## 3. Eigendecompose C and take top K eigenvectors

In [None]:
C = M2 / (n_total - 1)
# eigh is for symmetric matrices; it returns eigenvalues in ascending order.
eigenvalues, eigenvectors = np.linalg.eigh(C)
idx = np.argsort(eigenvalues)[::-1]
eigenvalues = eigenvalues[idx]
eigenvectors = eigenvectors[:, idx]

# C is positive semi-definite in exact arithmetic, but rounding can produce
# tiny negative eigenvalues. Clip them so variance fractions are non-negative
# and the cumulative sum is monotone.
eigenvalues = np.clip(eigenvalues, 0.0, None)

# Eigenvectors are only defined up to sign, and the sign can differ between
# LAPACK builds or refits. Flip each column so its largest-magnitude entry is
# positive (the same convention as sklearn's svd_flip); this keeps projected
# features reproducible.
max_abs_rows = np.argmax(np.abs(eigenvectors), axis=0)
signs = np.sign(eigenvectors[max_abs_rows, np.arange(eigenvectors.shape[1])])
signs[signs == 0] = 1.0
eigenvectors = eigenvectors * signs

K = min(N_COMPONENTS, CHANNELS, len(eigenvalues))
mu = mean.copy()
W_K = eigenvectors[:, :K].astype(np.float32)  # (160, K)

total_variance = eigenvalues.sum()
if total_variance <= 0:
    raise ValueError('Training data has zero total variance; PCA is undefined')
var_explained = eigenvalues / total_variance
print(f'Global PCA: μ ∈ R^{CHANNELS}, W_K ∈ R^{CHANNELS}×{K}')
print(f'Cumulative variance (top {K}): {var_explained[:K].sum():.4f}')

## 3b. How many components for ≥99% variance? (from global C)

In [None]:
target_variance = 0.99
cumvar = np.cumsum(var_explained)
# searchsorted (side='left') finds the first index whose cumulative share
# reaches the target. Adding 1 turns that 0-based index into a component
# count, capped at the number of components available.
k_99 = min(int(np.searchsorted(cumvar, target_variance)) + 1, len(cumvar))
print(f'Components needed for ≥{target_variance*100:.0f}% variance (global PCA): {k_99}')
print(f'  → Consider N_COMPONENTS >= {k_99} to retain ≥99% of the information.')

## 3c. Save PCA result (top 16 PCs; write 1/4/8/16 via N_PCA_OUT in 3d)

In [None]:
N_PC_SAVE = 16   # save top 16 so we can write _pca1, _pca4, _pca8, _pca16
if N_PC_SAVE > eigenvectors.shape[1]:
    raise ValueError(
        f'N_PC_SAVE={N_PC_SAVE} exceeds the {eigenvectors.shape[1]} available components'
    )
W_full = eigenvectors[:, :N_PC_SAVE].astype(np.float32)  # (160, N_PC_SAVE)
# Derive the file name from N_PC_SAVE so it always describes what is stored
# (N_PC_SAVE = 16 gives pca_global_top16.npz).
pca_save_path = Path(f'pca_global_top{N_PC_SAVE}.npz')
np.savez(pca_save_path, mu=mu.astype(np.float32), W_K=W_full, var_explained=var_explained[:N_PC_SAVE].astype(np.float32))

info_saved = var_explained[:N_PC_SAVE].sum()
print(f'Saved top {N_PC_SAVE} PCs to {pca_save_path}')
print(f'Variance (information) retained: {info_saved*100:.2f}%')
print(f'Per-component: {var_explained[:N_PC_SAVE]}')

## 3d. Save transformed data (top N_PCA_OUT PCs → *_pca1, *_pca4, *_pca8, *_pca16)

In [None]:
# Imported unconditionally: the clear-shot branch below also needs it, even
# when the disruptive meta.csv is absent.
import shutil

if N_PCA_OUT > W_full.shape[1]:
    raise ValueError(f'N_PCA_OUT={N_PCA_OUT} exceeds the {W_full.shape[1]} saved PCs')
W_save = W_full[:, :N_PCA_OUT]


def _write_pca_split(src_root, out_root, lengths, mu_vec, W, label):
    """Project each shot in *lengths* and write ``out_root/<shot>.h5``.

    Each file holds dataset ``LFS`` (float32, shape (W.shape[1], T)). When
    ``src_root/meta.csv`` exists, it is copied into *out_root*.
    """
    out_root.mkdir(parents=True, exist_ok=True)
    meta_path = src_root / 'meta.csv'
    if meta_path.exists():
        shutil.copy(meta_path, out_root / 'meta.csv')
    for shot, T in tqdm(list(lengths.items()), desc=f'Save {label} {N_PCA_OUT}-PC'):
        data = transform_shot_global(src_root, shot, T, mu_vec, W)
        with h5py.File(out_root / f'{shot}.h5', 'w') as f:
            f.create_dataset('LFS', data=data, dtype=np.float32)
    print(f'Saved {len(lengths)} shots to {out_root} (shape {N_PCA_OUT}×T per shot)')


_write_pca_split(DECIMATED_ROOT, OUT_DSRPT_PCA, dsrpt_lengths, mu, W_save, 'disruptive')
_write_pca_split(CLEAR_DECIMATED_ROOT, OUT_CLEAR_PCA, clear_lengths, mu, W_save, 'clear')

## 4. Transform and save disruptive shots -> dsrpt_decimated_pca

In [None]:
# Write all K fitted components to a K-suffixed directory (…_pca{K}).
# Writing them to OUT_DSRPT_PCA (…_pca{N_PCA_OUT}) would overwrite the
# N_PCA_OUT-component files from section 3d with K-component data.
OUT_DSRPT_PCA_K = DECIMATED_ROOT.parent / f'dsrpt_decimated_pca{K}'
OUT_DSRPT_PCA_K.mkdir(parents=True, exist_ok=True)
if (DECIMATED_ROOT / 'meta.csv').exists():
    import shutil
    shutil.copy(DECIMATED_ROOT / 'meta.csv', OUT_DSRPT_PCA_K / 'meta.csv')

for shot, T in tqdm(list(dsrpt_lengths.items()), desc='Save disruptive PCA'):
    data = transform_shot_global(DECIMATED_ROOT, shot, T, mu, W_K)  # (K, T)
    with h5py.File(OUT_DSRPT_PCA_K / f'{shot}.h5', 'w') as f:
        f.create_dataset('LFS', data=data, dtype=np.float32)

print(f'Saved {len(dsrpt_lengths)} shots to {OUT_DSRPT_PCA_K}')

## 5. Transform and save clear shots -> clear_decimated_pca

In [None]:
# Write all K fitted components to a K-suffixed directory (…_pca{K}).
# Writing them to OUT_CLEAR_PCA (…_pca{N_PCA_OUT}) would overwrite the
# N_PCA_OUT-component files from section 3d with K-component data.
OUT_CLEAR_PCA_K = CLEAR_DECIMATED_ROOT.parent / f'clear_decimated_pca{K}'
OUT_CLEAR_PCA_K.mkdir(parents=True, exist_ok=True)
if CLEAR_DECIMATED_ROOT.exists() and (CLEAR_DECIMATED_ROOT / 'meta.csv').exists():
    import shutil
    shutil.copy(CLEAR_DECIMATED_ROOT / 'meta.csv', OUT_CLEAR_PCA_K / 'meta.csv')

for shot, T in tqdm(list(clear_lengths.items()), desc='Save clear PCA'):
    data = transform_shot_global(CLEAR_DECIMATED_ROOT, shot, T, mu, W_K)  # (K, T)
    with h5py.File(OUT_CLEAR_PCA_K / f'{shot}.h5', 'w') as f:
        f.create_dataset('LFS', data=data, dtype=np.float32)

print(f'Saved {len(clear_lengths)} shots to {OUT_CLEAR_PCA_K}')

## 6. Sanity: shapes

In [None]:
# Print the stored LFS shape of the first shot in each non-empty output directory.
for lengths, out_dir, prefix in (
    (dsrpt_lengths, OUT_DSRPT_PCA, 'Disruptive PCA example: '),
    (clear_lengths, OUT_CLEAR_PCA, 'Clear PCA example:      '),
):
    if not lengths:
        continue
    shot0 = next(iter(lengths))
    with h5py.File(out_dir / f'{shot0}.h5', 'r') as f:
        sh = f['LFS'].shape
    print(f'{prefix}shot {shot0} LFS shape = {sh} (N_components, T)')