# PCA from decimated data — **global** PCA (recommended)

**Why global PCA:** A per-sequence or per–time-index PCA gives different coordinate systems per sample (PC1 in shot A ≠ PC1 in shot B). That breaks learnability: the NN sees incompatible features. A **global** PCA gives a consistent embedding: every time point in every sequence is projected onto the same axes.

**Pipeline:**
- **SNR > 3 (paper):** Optionally add `snr_min` to meta (from shot lists or computed from decimated data). Fit PCA and write output **only for shots with SNR > 3**.
- Split train/val/test at the **sequence (shot) level**.
- Fit PCA **only on training time points** (and only on training shots with SNR > 3): stream through those shots, accumulate mean μ ∈ ℝ¹⁶⁰ and covariance C ∈ ℝ¹⁶⁰×¹⁶⁰ with a single-pass online algorithm (Welford). Memory stays O(160²). Eigendecompose C and take top K eigenvectors W_K.
- Transform and save **only SNR > 3 shots**: z_t = (x_t − μ) W_K. Output dirs get `meta.csv` with `snr_min` column so downstream training (e.g. `run_fusion_soen --snr-min-threshold 3.0`) can use the same filter.

In [None]:
import numpy as np
import pandas as pd
import h5py
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from sklearn.decomposition import PCA
from tqdm import tqdm

# Paths — decimated inputs (disruptive + clear)
DECIMATED_ROOT = Path('/home/idies/workspace/Storage/yhuang2/persistent/ecei/dsrpt_decimated')
CLEAR_DECIMATED_ROOT = Path('/home/idies/workspace/Storage/yhuang2/persistent/ecei/clear_decimated')
# Outputs: N_PCA_OUT in {1, 4, 8, 16} → dirs dsrpt_decimated_pca{N}, clear_decimated_pca{N}
N_PCA_OUT = 1   # number of PCs to write (training default is 1)
# Both output dirs are created as siblings of DECIMATED_ROOT.
OUT_DSRPT_PCA = DECIMATED_ROOT.parent / f'dsrpt_decimated_pca{N_PCA_OUT}'
OUT_CLEAR_PCA = DECIMATED_ROOT.parent / f'clear_decimated_pca{N_PCA_OUT}'

N_COMPONENTS = 16   # fit top K (global PCA); must be >= N_PCA_OUT
CHANNELS = 20 * 8   # 160

# Enforce the invariant stated above: slicing a projection matrix with fewer than
# N_PCA_OUT columns would silently write fewer PCs than the directory name promises.
if not 1 <= N_PCA_OUT <= N_COMPONENTS:
    raise ValueError(
        f'N_PCA_OUT={N_PCA_OUT} must satisfy 1 <= N_PCA_OUT <= N_COMPONENTS={N_COMPONENTS}'
    )

# Paper (Churchill et al.): "good ECEi data (SNR > 3)" — fit PCA and write output only for these shots
SNR_MIN_THRESHOLD = 3.0
# If True, computed SNR is 20*log10(ratio) so threshold 3 means 3 dB. If False, SNR = ratio (linear); paper may use linear.
SNR_IN_DB = False
# Baseline fraction of time used as "noise" (first part of shot). Rest = "signal".
SNR_BASELINE_FRAC = 0.1
# Optional: DisruptCNN-format shot lists (columns: Shot, ..., SNR min at index 5). If None, SNR is computed from decimated H5.
SHOT_LIST_DSRPT = None   # e.g. Path('disruptcnn/shots/d3d_disrupt_ecei.final.txt')
SHOT_LIST_CLEAR = None   # e.g. Path('disruptcnn/shots/d3d_clear_ecei.final.txt')

## 0. SNR > 3 (paper): add snr_min to meta, use only for PCA fit and output

Load SNR from DisruptCNN-format shot list files (column "SNR min") if set; otherwise **compute** it from the decimated H5: for each shot, **noise** = first `SNR_BASELINE_FRAC` of time (e.g. 10%), **signal** = the rest. Per channel: ratio = std(signal)/std(noise); **snr_min = min over 160 channels**. Set `SNR_IN_DB=True` to store 20×log10(ratio) (then threshold 3 = 3 dB). **Computed values may not match the paper's curator SNR** — use shot lists when possible. Missing snr_min (NaN) → shot **kept**; only shots with a known snr_min ≤ threshold are dropped. Computation runs in parallel over shots. **Current state:** the cell below does not apply this filter yet — it keeps every shot listed in `meta.csv` (or every decimated H5 when there is no meta) and leaves the SNR tables empty, so `snr_min` is NaN downstream.

In [None]:
def _get_shot_lengths(root: Path, meta_path: Path = None) -> dict[int, int]:
    """Return ``{shot: T}`` for the decimated shots under ``root``.

    Lengths are taken from ``meta_path`` when that CSV exists and has a
    ``shot`` column plus one of the known length columns (first match wins).
    Otherwise every ``<digits>.h5`` file in ``root`` is opened and the last
    axis of its ``LFS`` dataset shape is used, without loading the data.

    Args:
        root: Directory holding ``<shot>.h5`` files; may be missing.
        meta_path: Optional path to a meta CSV.

    Returns:
        Mapping from shot number to number of time samples. Files that cannot
        be read are skipped with a warning rather than aborting the scan.
    """
    import warnings

    if meta_path and meta_path.exists():
        meta = pd.read_csv(meta_path)
        if 'shot' in meta.columns:
            for col in ('n_samples', 'T', 'length', 'n_samples_decimated'):
                if col in meta.columns:
                    return dict(zip(meta['shot'].astype(int).tolist(),
                                    meta[col].astype(int).tolist()))

    lengths = {}
    if not root.exists():
        return lengths
    # Sort by shot number so the result does not depend on directory listing order.
    shot_files = sorted(
        (p for p in root.glob('*.h5') if p.stem.isdigit()),
        key=lambda p: int(p.stem),
    )
    for h5_path in shot_files:
        try:
            with h5py.File(h5_path, 'r') as f:
                lengths[int(h5_path.stem)] = f['LFS'].shape[-1]
        except Exception as exc:
            # Best-effort scan: keep going, but make the skipped file visible.
            warnings.warn(f'Skipping unreadable shot file {h5_path}: {exc!r}')
    return lengths

# Per-shot sample counts; an input directory that does not exist yields an empty mapping.
dsrpt_lengths = {}
if DECIMATED_ROOT.exists():
    dsrpt_lengths = _get_shot_lengths(DECIMATED_ROOT, DECIMATED_ROOT / 'meta.csv')
clear_lengths = {}
if CLEAR_DECIMATED_ROOT.exists():
    clear_lengths = _get_shot_lengths(CLEAR_DECIMATED_ROOT, CLEAR_DECIMATED_ROOT / 'meta.csv')

# Meta tables; an empty frame with the expected columns stands in when meta.csv is absent.
if (DECIMATED_ROOT / 'meta.csv').exists():
    meta_d = pd.read_csv(DECIMATED_ROOT / 'meta.csv')
else:
    meta_d = pd.DataFrame(columns=['shot', 'split', 't_disruption'])
if CLEAR_DECIMATED_ROOT.exists() and (CLEAR_DECIMATED_ROOT / 'meta.csv').exists():
    meta_c = pd.read_csv(CLEAR_DECIMATED_ROOT / 'meta.csv')
else:
    meta_c = pd.DataFrame(columns=['shot', 'split'])

# Good shots = every shot listed in meta (no SNR filter). When meta is empty or has
# no 'shot' column, fall back to every shot found in the lengths mapping.
if meta_d.empty or 'shot' not in meta_d.columns:
    good_shots_d = set(dsrpt_lengths)
else:
    good_shots_d = set(meta_d['shot'].astype(int))
if meta_c.empty or 'shot' not in meta_c.columns:
    good_shots_c = set(clear_lengths)
else:
    good_shots_c = set(meta_c['shot'].astype(int))

# No SNR is computed here; the tables stay empty so the output meta gets snr_min=NaN.
snr_d, snr_c = {}, {}

print(f'Disruptive: {len(dsrpt_lengths)} shots → {len(good_shots_d)} kept (no SNR filter).')
print(f'Clear:      {len(clear_lengths)} shots → {len(good_shots_c)} kept (no SNR filter).')

### Shot list vs. data: `d3d_disrupt_ecei.final.txt`

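Check whether the **disrupt** and **clear** shots used in this notebook (from the decimated dirs / meta) appear in the DisruptCNN-format shot list file. By default `SHOT_LIST_FINAL_TXT` points into `DISRUPTCNN_ROOT` (an absolute path); if you set it to a relative path instead, it is resolved against the current working directory. Only the *disrupt* list is loaded here, so the overlap reported for clear shots is normally 0.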

In [None]:
# DisruptCNN checkout (idies scratch) that holds the reference disrupt shot list.
DISRUPTCNN_ROOT = Path('/home/idies/workspace/Temporary/dpark1/scratch/soen_fusion_zero/disruptcnn')
SHOT_LIST_FINAL_TXT = DISRUPTCNN_ROOT.joinpath('shots', 'd3d_disrupt_ecei.final.txt')
# A relative path is anchored at the current working directory; an absolute one is kept as is.
SHOT_LIST_FINAL_TXT = (
    SHOT_LIST_FINAL_TXT if SHOT_LIST_FINAL_TXT.is_absolute() else Path.cwd() / SHOT_LIST_FINAL_TXT
)

def load_shots_from_final_txt(path: Path) -> set:
    """Load the shot IDs (column 0) from a DisruptCNN-format shot list.

    The first line of the file is skipped as a header. Only column 0 is
    parsed, so the remaining columns never need to be converted.

    Args:
        path: Location of the shot list text file.

    Returns:
        Set of shot numbers as Python ``int``; empty when the file does not
        exist or contains no data rows.
    """
    if not path.exists():
        return set()
    # ndmin=1 keeps a single-row file as a length-1 array instead of a 0-d scalar.
    shot_column = np.loadtxt(path, skiprows=1, usecols=0, ndmin=1)
    return {int(shot) for shot in shot_column}

shots_in_final_txt = load_shots_from_final_txt(SHOT_LIST_FINAL_TXT)

# Shots this notebook works with: those listed in meta when it has a usable 'shot'
# column, otherwise every shot found in the lengths mapping.
if meta_d.empty or 'shot' not in meta_d.columns:
    disrupt_shots_we_use = set(dsrpt_lengths)
else:
    disrupt_shots_we_use = set(meta_d['shot'].astype(int))
if meta_c.empty or 'shot' not in meta_c.columns:
    clear_shots_we_use = set(clear_lengths)
else:
    clear_shots_we_use = set(meta_c['shot'].astype(int))

# --- Shot list file summary ---
print(f"Shot list file: {SHOT_LIST_FINAL_TXT}")
print(f"  Exists: {SHOT_LIST_FINAL_TXT.exists()}")
print(f"  Shots in file: {len(shots_in_final_txt)}")
print()

# --- Disruptive shots versus the list (details only when the list is non-empty) ---
print("Disrupt shots (in our data):")
print(f"  Total: {len(disrupt_shots_we_use)}")
if shots_in_final_txt:
    in_both_d = disrupt_shots_we_use.intersection(shots_in_final_txt)
    only_ours_d = disrupt_shots_we_use.difference(shots_in_final_txt)
    only_file_d = shots_in_final_txt.difference(disrupt_shots_we_use)
    print(f"  In shot list (included): {len(in_both_d)}")
    print(f"  NOT in shot list (missing from file): {len(only_ours_d)}")
    if only_ours_d:
        print(f"    Example shots only in our data: {sorted(only_ours_d)[:10]}{'...' if len(only_ours_d) > 10 else ''}")
    print(f"  In shot list but NOT in our data: {len(only_file_d)}")
    if only_file_d:
        print(f"    Example shots only in file: {sorted(only_file_d)[:10]}{'...' if len(only_file_d) > 10 else ''}")
print()

# --- Clear shots versus the same (disrupt) list ---
print("Clear shots (in our data):")
print(f"  Total: {len(clear_shots_we_use)}")
if shots_in_final_txt:
    in_both_c = clear_shots_we_use.intersection(shots_in_final_txt)
    only_ours_c = clear_shots_we_use.difference(shots_in_final_txt)
    print(f"  In shot list (disrupt file; clear usually 0): {len(in_both_c)}")
    print(f"  NOT in shot list: {len(only_ours_c)}")

## 1. List shots, lengths, and splits (train/val/test)

In [None]:
def get_shot_lengths(root: Path) -> dict[int, int]:
    """Return ``{shot_id: T}`` for each ``<digits>.h5`` file in ``root``.

    ``T`` is the size of the last axis of the file's ``LFS`` dataset, read
    from the dataset shape. Files whose stem is not all digits are ignored.

    Shots are inserted in ascending shot-number order so that callers relying
    on dict order (e.g. a positional train/test split) get the same result on
    every run, whatever order the directory listing comes back in.
    """
    shot_files = sorted(
        (p for p in root.glob('*.h5') if p.stem.isdigit()),
        key=lambda p: int(p.stem),
    )
    lengths = {}
    for h5_path in shot_files:
        with h5py.File(h5_path, 'r') as f:
            lengths[int(h5_path.stem)] = f['LFS'].shape[-1]
    return lengths

def get_splits(root: Path, lengths: dict) -> dict[int, str]:
    """Return ``{shot_id: split}`` from ``root/meta.csv`` or a default 80/20 split.

    When ``meta.csv`` exists and has both ``shot`` and ``split`` columns, its
    assignment is returned as is (whatever labels the file uses). Otherwise
    the shots in ``lengths`` are sorted by shot number, the first 80% become
    ``'train'`` and the rest ``'test'``; the default never yields ``'val'``.

    Args:
        root: Directory that may contain ``meta.csv``.
        lengths: Mapping whose keys are the shot ids to split by default.
    """
    meta_path = root / 'meta.csv'
    if meta_path.exists():
        meta = pd.read_csv(meta_path)
        # Both columns are required; a meta without 'shot' falls through to the default.
        if 'shot' in meta.columns and 'split' in meta.columns:
            return dict(zip(meta['shot'].astype(int), meta['split'].astype(str)))

    # Sorting makes the default split reproducible regardless of dict order.
    shots = sorted(lengths)
    n_train = int(0.8 * len(shots))
    splits = dict.fromkeys(shots[:n_train], 'train')
    splits.update(dict.fromkeys(shots[n_train:], 'test'))
    return splits

# Re-scan the H5 files for lengths; a missing input directory gives an empty mapping.
dsrpt_lengths = {}
if DECIMATED_ROOT.exists():
    dsrpt_lengths = get_shot_lengths(DECIMATED_ROOT)
clear_lengths = {}
if CLEAR_DECIMATED_ROOT.exists():
    clear_lengths = get_shot_lengths(CLEAR_DECIMATED_ROOT)

# Splits are only looked up when at least one shot was found.
dsrpt_splits = {}
if dsrpt_lengths:
    dsrpt_splits = get_splits(DECIMATED_ROOT, dsrpt_lengths)
clear_splits = {}
if clear_lengths:
    clear_splits = get_splits(CLEAR_DECIMATED_ROOT, clear_lengths)

# Training shots for the PCA fit: split == 'train' and member of good_shots (section 0).
train_shots_d = [s for s, sp in dsrpt_splits.items() if sp == 'train' and s in good_shots_d]
train_shots_c = [s for s, sp in clear_splits.items() if sp == 'train' and s in good_shots_c]
print(f'Disruptive: {len(dsrpt_lengths)} shots  (train with SNR>3={len(train_shots_d)})')
print(f'Clear:      {len(clear_lengths)} shots  (train with SNR>3={len(train_shots_c)})')

## 2. Streaming mean and covariance (training time points only)

In [None]:
# Welford in batches: merge (n, mean, M2) with batch (B, 160). Memory O(160²) + O(B*160).
BATCH_SIZE = 10_000   # time points per batch (vectorized update)

def welford_merge(n, mean, M2, batch: np.ndarray):
    """Merge a batch of samples into running mean / scatter statistics.

    Args:
        n: Number of samples accumulated so far.
        mean: Running per-feature mean, shape (D,).
        M2: Running scatter matrix (sum of outer products of deviations from
            the mean), shape (D, D).
        batch: New samples, shape (B, D), one sample per row.

    Returns:
        ``(n_new, mean_new, M2_new)``. ``M2_new / (n_new - 1)`` is the sample
        covariance of everything merged so far. An empty batch returns the
        inputs unchanged.
    """
    B = len(batch)
    if B == 0:
        return n, mean, M2
    batch_mean = batch.mean(axis=0)
    # Centre once and reuse: avoids building the (B, D) difference twice.
    centered = batch - batch_mean
    batch_M2 = centered.T @ centered
    n_new = n + B
    mean_new = (n * mean + B * batch_mean) / n_new
    delta = mean - batch_mean
    # Cross term accounts for the offset between the two means; it is zero when n == 0.
    M2_new = M2 + batch_M2 + (n * B / n_new) * np.outer(delta, delta)
    return n_new, mean_new, M2_new

def load_shot_flat(root: Path, shot: int, T: int) -> np.ndarray:
    """Read ``root/<shot>.h5`` and return its ``LFS`` data as a (T, 160) float64 array.

    The dataset is flattened to ``CHANNELS`` rows and transposed, so each
    returned row is one time point. ``T`` is accepted for call-site symmetry
    but not used; the time length comes from the file itself.
    """
    shot_path = root / f'{shot}.h5'
    with h5py.File(shot_path, 'r') as h5:
        cube = np.asarray(h5['LFS'][:], dtype=np.float64)  # (20, 8, T)
    channels_by_time = cube.reshape(CHANNELS, -1)  # (160, T)
    return channels_by_time.T  # (T, 160)

def transform_shot_global(root: Path, shot: int, T: int, mu: np.ndarray, W_K: np.ndarray) -> np.ndarray:
    """Project one shot onto the global PCA axes: z_t = (x_t - μ) W_K.

    Loads the shot via ``load_shot_flat``, subtracts ``mu`` from every time
    point and multiplies by ``W_K``. The result is returned as float32 with
    shape (K, T), i.e. components first, matching the LFS layout.
    """
    samples = load_shot_flat(root, shot, T)  # (T, 160)
    centered = samples - mu
    scores = centered @ W_K  # (T, K)
    return scores.T.astype(np.float32)  # (K, T)

# Training shots only: (root, shot, T).
# A shot can be listed as 'train' in meta without having an .h5 file; the save
# cells only iterate shots that have files, so the fit does the same instead of
# failing with a bare KeyError on the lengths lookup.
train_tuples = []
for _root, _shots, _lengths in (
    (DECIMATED_ROOT, train_shots_d, dsrpt_lengths),
    (CLEAR_DECIMATED_ROOT, train_shots_c, clear_lengths),
):
    _missing = [s for s in _shots if s not in _lengths]
    if _missing:
        print(f'Skipping {len(_missing)} training shots without an .h5 file under {_root}')
    train_tuples += [(_root, s, _lengths[s]) for s in _shots if s in _lengths]

# Running statistics over all training time points.
n_total = 0
mean = np.zeros(CHANNELS, dtype=np.float64)
M2 = np.zeros((CHANNELS, CHANNELS), dtype=np.float64)

for root, shot, T in tqdm(train_tuples, desc='Streaming μ and M2 (train only)'):
    X = load_shot_flat(root, shot, T)  # (T, 160)
    # Merge in fixed-size slices so the temporaries stay bounded by BATCH_SIZE rows.
    for start in range(0, len(X), BATCH_SIZE):
        batch = X[start:start + BATCH_SIZE]
        n_total, mean, M2 = welford_merge(n_total, mean, M2, batch)

print(f'Training time points: {n_total}')
# The covariance below divides by (n_total - 1).
if n_total < 2:
    raise ValueError('Need at least 2 training time points to fit PCA')

## 3. Eigendecompose C and take top K eigenvectors

In [None]:
# Sample covariance from the accumulated scatter matrix.
C = M2 / (n_total - 1)

# eigh handles the symmetric matrix; reorder so the largest eigenvalue comes first.
_evals, _evecs = np.linalg.eigh(C)
idx = np.argsort(_evals)[::-1]
eigenvalues = _evals[idx]
eigenvectors = _evecs[:, idx]

# Keep the leading K directions as the projection matrix.
K = min(N_COMPONENTS, CHANNELS, len(eigenvalues))
mu = np.array(mean, copy=True)
W_K = eigenvectors[:, :K].astype(np.float32)  # (160, K)

# Fraction of total variance carried by each component.
var_explained = eigenvalues / np.sum(eigenvalues)
print(f'Global PCA: μ ∈ R^{CHANNELS}, W_K ∈ R^{CHANNELS}×{K}')
print(f'Cumulative variance (top {K}): {var_explained[:K].sum():.4f}')

## 3b. How many components for ≥99% variance? (from global C)

In [None]:
target_variance = 0.99
cumvar = var_explained.cumsum()
# First index whose cumulative variance reaches the target, as a count (+1),
# capped at the number of available components.
k_99 = min(int(np.searchsorted(cumvar, target_variance)) + 1, len(cumvar))
print(f'Components needed for ≥{target_variance*100:.0f}% variance (global PCA): {k_99}')
print(f'  → Consider N_COMPONENTS >= {k_99} to retain ≥99% of the information.')

## 3c. Save PCA result (top 16 PCs; write 1/4/8/16 via N_PCA_OUT in 3d)

In [None]:
N_PC_SAVE = 16   # save top 16 so we can write _pca1, _pca4, _pca8, _pca16
pca_save_path = Path('pca_global_top16.npz')

# Leading eigenvectors as float32: (160, 16).
W_full = eigenvectors[:, :N_PC_SAVE].astype(np.float32)

# Persist the mean, the projection matrix and the matching variance fractions together.
np.savez(
    pca_save_path,
    mu=mu.astype(np.float32),
    W_K=W_full,
    var_explained=var_explained[:N_PC_SAVE].astype(np.float32),
)

info_saved = var_explained[:N_PC_SAVE].sum()
print(f'Saved top {N_PC_SAVE} PCs to {pca_save_path}')
print(f'Variance (information) retained: {info_saved*100:.2f}%')
print(f'Per-component: {var_explained[:N_PC_SAVE]}')

## 3d. Save transformed data (top N_PCA_OUT PCs → *_pca1, *_pca4, *_pca8, *_pca16)

In [None]:
import shutil

# Projection onto the first N_PCA_OUT saved components: (160, N_PCA_OUT).
W_save = W_full[:, :N_PCA_OUT]
# Slicing past the saved columns would silently yield fewer PCs than the
# output directory name (…_pca{N_PCA_OUT}) promises.
if W_save.shape[1] != N_PCA_OUT:
    raise ValueError(f'Requested N_PCA_OUT={N_PCA_OUT} PCs but only {W_save.shape[1]} are available in W_full')


def _build_output_meta(meta, good_shots, snr, splits):
    """Return the meta table to write beside the projected shots.

    Rows of ``meta`` whose shot is in ``good_shots`` are kept. When that
    leaves nothing (no usable meta), a table is built from ``good_shots``
    with ``split`` taken from ``splits`` (default ``'train'``). In both cases
    the result has a ``snr_min`` column, filled from ``snr`` and NaN where
    the shot has no entry.
    """
    has_shot_col = not meta.empty and 'shot' in meta.columns
    kept = meta[meta['shot'].astype(int).isin(good_shots)].copy() if has_shot_col else pd.DataFrame()
    if kept.empty:
        if not good_shots:
            return kept
        shots = sorted(good_shots)
        kept = pd.DataFrame({
            'shot': shots,
            'snr_min': [snr.get(s, np.nan) for s in shots],
            'split': [splits.get(s, 'train') for s in shots],
        })
        # Carry disruption times over when the source meta provides them.
        if has_shot_col and 't_disruption' in meta.columns:
            kept['t_disruption'] = kept['shot'].map(meta.set_index('shot')['t_disruption'])
    elif 'snr_min' not in kept.columns:
        kept['snr_min'] = [snr.get(int(s), np.nan) for s in kept['shot']]
    return kept


def _save_projected_shots(src_root, out_dir, lengths, good_shots, desc):
    """Project every good shot that has a file and write it; return the count written."""
    shots = [s for s in lengths if s in good_shots]
    for shot in tqdm(shots, desc=desc):
        data = transform_shot_global(src_root, shot, lengths[shot], mu, W_save)
        with h5py.File(out_dir / f'{shot}.h5', 'w') as f:
            f.create_dataset('LFS', data=data, dtype=np.float32)
    return len(shots)


# Only write shots in good_shots; write meta with snr_min column
OUT_DSRPT_PCA.mkdir(parents=True, exist_ok=True)
meta_d_out = _build_output_meta(meta_d, good_shots_d, snr_d, dsrpt_splits)
if not meta_d_out.empty:
    meta_d_out.to_csv(OUT_DSRPT_PCA / 'meta.csv', index=False)
_n_saved_d = _save_projected_shots(
    DECIMATED_ROOT, OUT_DSRPT_PCA, dsrpt_lengths, good_shots_d,
    f'Save disruptive {N_PCA_OUT}-PC (SNR>3)',
)
# Report the number of files actually written, not the size of the good-shot set.
print(f'Saved {_n_saved_d} disruptive shots (SNR>3) to {OUT_DSRPT_PCA} (shape {N_PCA_OUT}×T per shot)')

OUT_CLEAR_PCA.mkdir(parents=True, exist_ok=True)
meta_c_out = _build_output_meta(meta_c, good_shots_c, snr_c, clear_splits)
if not meta_c_out.empty:
    meta_c_out.to_csv(OUT_CLEAR_PCA / 'meta.csv', index=False)
_n_saved_c = _save_projected_shots(
    CLEAR_DECIMATED_ROOT, OUT_CLEAR_PCA, clear_lengths, good_shots_c,
    f'Save clear {N_PCA_OUT}-PC (SNR>3)',
)
print(f'Saved {_n_saved_c} clear shots (SNR>3) to {OUT_CLEAR_PCA} (shape {N_PCA_OUT}×T per shot)')

## 4. Transform and save disruptive shots -> `dsrpt_decimated_pca{N_PCA_OUT}`

In [None]:
OUT_DSRPT_PCA.mkdir(parents=True, exist_ok=True)

# OUT_DSRPT_PCA is named after N_PCA_OUT, so write exactly that many components.
# Projecting with the full W_K would put K-component data into the
# `..._pca{N_PCA_OUT}` directory.
_W_out_d = W_K[:, :N_PCA_OUT]
if _W_out_d.shape[1] != N_PCA_OUT:
    raise ValueError(f'Requested N_PCA_OUT={N_PCA_OUT} PCs but W_K only has {W_K.shape[1]}')

meta_d_out = meta_d[meta_d['shot'].astype(int).isin(good_shots_d)].copy() if not meta_d.empty else meta_d
if not meta_d_out.empty:
    # Downstream filtering expects a snr_min column; NaN marks "not computed".
    if 'snr_min' not in meta_d_out.columns:
        meta_d_out['snr_min'] = [snr_d.get(int(s), np.nan) for s in meta_d_out['shot']]
    meta_d_out.to_csv(OUT_DSRPT_PCA / 'meta.csv', index=False)

# Only shots that both have a file (dsrpt_lengths) and are in good_shots_d are written.
_shots_to_save_d = [s for s in dsrpt_lengths if s in good_shots_d]
for shot in tqdm(_shots_to_save_d, desc='Save disruptive PCA (SNR>3)'):
    data = transform_shot_global(DECIMATED_ROOT, shot, dsrpt_lengths[shot], mu, _W_out_d)
    with h5py.File(OUT_DSRPT_PCA / f'{shot}.h5', 'w') as f:
        f.create_dataset('LFS', data=data, dtype=np.float32)
# Report the number of files actually written.
print(f'Saved {len(_shots_to_save_d)} disruptive shots (SNR>3) to {OUT_DSRPT_PCA}')

## 5. Transform and save clear shots -> `clear_decimated_pca{N_PCA_OUT}`

In [None]:
OUT_CLEAR_PCA.mkdir(parents=True, exist_ok=True)

# OUT_CLEAR_PCA is named after N_PCA_OUT, so write exactly that many components.
# Projecting with the full W_K would put K-component data into the
# `..._pca{N_PCA_OUT}` directory.
_W_out_c = W_K[:, :N_PCA_OUT]
if _W_out_c.shape[1] != N_PCA_OUT:
    raise ValueError(f'Requested N_PCA_OUT={N_PCA_OUT} PCs but W_K only has {W_K.shape[1]}')

meta_c_out = meta_c[meta_c['shot'].astype(int).isin(good_shots_c)].copy() if not meta_c.empty else meta_c
if not meta_c_out.empty:
    # Downstream filtering expects a snr_min column; NaN marks "not computed".
    if 'snr_min' not in meta_c_out.columns:
        meta_c_out['snr_min'] = [snr_c.get(int(s), np.nan) for s in meta_c_out['shot']]
    meta_c_out.to_csv(OUT_CLEAR_PCA / 'meta.csv', index=False)

# Only shots that both have a file (clear_lengths) and are in good_shots_c are written.
_shots_to_save_c = [s for s in clear_lengths if s in good_shots_c]
for shot in tqdm(_shots_to_save_c, desc='Save clear PCA (SNR>3)'):
    data = transform_shot_global(CLEAR_DECIMATED_ROOT, shot, clear_lengths[shot], mu, _W_out_c)
    with h5py.File(OUT_CLEAR_PCA / f'{shot}.h5', 'w') as f:
        f.create_dataset('LFS', data=data, dtype=np.float32)
# Report the number of files actually written.
print(f'Saved {len(_shots_to_save_c)} clear shots (SNR>3) to {OUT_CLEAR_PCA}')

## 6. Sanity: shapes

In [None]:
# Spot-check one written file per class: open it and report the LFS dataset shape.
if good_shots_d:
    shot0 = next(iter(good_shots_d))
    _example_path = OUT_DSRPT_PCA / f'{shot0}.h5'
    with h5py.File(_example_path, 'r') as h5:
        sh = h5['LFS'].shape
    print(f'Disruptive PCA example: shot {shot0} LFS shape = {sh} (N_components, T)')

if good_shots_c:
    shot0 = next(iter(good_shots_c))
    _example_path = OUT_CLEAR_PCA / f'{shot0}.h5'
    with h5py.File(_example_path, 'r') as h5:
        sh = h5['LFS'].shape
    print(f'Clear PCA example:      shot {shot0} LFS shape = {sh} (N_components, T)')