# PCA from decimated data (per time point)

Starting from the **decimated** disruptive and clear shots, fit **PCA per time index** (across the 160 channels), keep the **top N** components, and save the reduced data as `*_pca` (e.g. `dsrpt_decimated_pca`, `clear_decimated_pca`).

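- At each time index `t`, we collect the 160-D channel vector from every shot (disruptive and clear pooled together) whose length is > t, fit PCA on those vectors, and keep the top N components.
- Each shot is then transformed to shape `(N, T_shot)` and saved. Time indices reached by fewer than 2 shots have no fitted PCA, so their output column is all zeros.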
- Requires decimated data to exist (run preprocessing_viz 4b first, or point to existing decimated dirs).

In [None]:
import numpy as np
import pandas as pd
import h5py
from pathlib import Path
from sklearn.decomposition import PCA
from tqdm import tqdm

# Paths — decimated inputs (disruptive + clear)
DECIMATED_ROOT = Path('/home/idies/workspace/Storage/yhuang2/persistent/ecei/dsrpt_decimated')
CLEAR_DECIMATED_ROOT = Path('/home/idies/workspace/Storage/yhuang2/persistent/ecei/clear_decimated')
# Outputs (will be created). Each output sits next to its own input directory,
# so repointing one input root also moves its output.
OUT_DSRPT_PCA = DECIMATED_ROOT.parent / 'dsrpt_decimated_pca'
OUT_CLEAR_PCA = CLEAR_DECIMATED_ROOT.parent / 'clear_decimated_pca'

N_COMPONENTS = 16   # top N PCA components per time point (1–160)
CHANNELS = 20 * 8   # 160: the (20, 8) LFS channel array flattened per time point

## 1. List shots and their lengths

In [None]:
def get_shot_lengths(root: Path) -> dict[int, int]:
    """Map each numerically named ``<shot>.h5`` file in *root* to its length.

    The length is the size of the last axis of the file's ``LFS`` dataset.
    Files whose stem is not purely digits are ignored. Entries appear in the
    order ``root.glob`` yields the files.
    """
    lengths: dict[int, int] = {}
    for h5_path in root.glob('*.h5'):
        stem = h5_path.stem
        if stem.isdigit():
            with h5py.File(h5_path, 'r') as h5:
                lengths[int(stem)] = h5['LFS'].shape[-1]
    return lengths

# A missing input directory is treated as an empty set rather than an error.
dsrpt_lengths = {}
if DECIMATED_ROOT.exists():
    dsrpt_lengths = get_shot_lengths(DECIMATED_ROOT)
clear_lengths = {}
if CLEAR_DECIMATED_ROOT.exists():
    clear_lengths = get_shot_lengths(CLEAR_DECIMATED_ROOT)

print(f'Disruptive: {len(dsrpt_lengths)} shots')
print(f'Clear:      {len(clear_lengths)} shots')
# Report the range of shot lengths, only for sets that actually have shots.
if dsrpt_lengths:
    Ts = [*dsrpt_lengths.values()]
    print(f'  Disruptive T: min={min(Ts)}, max={max(Ts)}')
if clear_lengths:
    Ts = [*clear_lengths.values()]
    print(f'  Clear T:      min={min(Ts)}, max={max(Ts)}')

## 2. Build (time_index -> list of 160-D vectors) for fitting

In [None]:
def load_channel_vector(root: Path, shot: int, t: int) -> np.ndarray:
    """Read time index *t* of shot *shot*'s ``LFS`` dataset as one flat float32 vector.

    Every axis except the last (time) is flattened in C order,
    e.g. (20, 8) -> (160,).
    """
    path = root / f'{shot}.h5'
    with h5py.File(path, 'r') as h5:
        frame = np.asarray(h5['LFS'][..., t], dtype=np.float32)
    return frame.ravel()

def collect_vectors_at_t(roots_with_shots: list[tuple[Path, list[tuple[int, int]]]], t: int) -> np.ndarray:
    """Stack the time-*t* channel vectors of every shot long enough to reach *t*.

    *roots_with_shots* is a list of ``(root, [(shot, T), ...])`` pairs. Shots
    with ``T <= t`` are skipped. Returns an ``(n, 160)`` array, or an empty
    ``(0, CHANNELS)`` float32 array when no shot reaches *t*.
    """
    rows = [
        load_channel_vector(root, shot, t)
        for root, shots_with_T in roots_with_shots
        for shot, T in shots_with_T
        if T > t
    ]
    if rows:
        return np.stack(rows, axis=0)
    return np.empty((0, CHANNELS), dtype=np.float32)

# Pairs (root, [(shot, T), ...]) for fitting; a set with no shots is omitted.
dsrpt_list = [(DECIMATED_ROOT, list(dsrpt_lengths.items()))] if dsrpt_lengths else []
clear_list = [(CLEAR_DECIMATED_ROOT, list(clear_lengths.items()))] if clear_lengths else []
all_roots_shots = dsrpt_list + clear_list
# Longest shot length over both sets; 0 when no shots were found at all.
max_T = max((T for _, shots in all_roots_shots for _, T in shots), default=0)
print(f'Max T over all shots: {max_T}')

## 3. Fit PCA at each time index (top N)

## 3b. How many components for ≥99% variance?

At a sample of time indices, fit PCA with **all** components and find the smallest number of components whose cumulative explained variance is ≥ 99%. Use this to choose `N_COMPONENTS` (set in the first code cell) if you want to retain 99% of the variance; rerun the per-time-point fit in the next cell after changing it.

In [None]:
target_variance = 0.99
pct = f'{target_variance*100:.0f}%'  # label used in all messages, e.g. '99%'
# Fitting a full PCA at every t is expensive, so sample about 50 evenly spaced
# time indices. step floors to 1 when max_T < 100, i.e. every index is checked.
step = max(1, max_T // 50)
t_sample = list(range(0, max_T, step))

n_for_99 = []  # components needed at each sampled t that had >= 2 shots
for t in tqdm(t_sample, desc=f'{pct} variance check'):
    X = collect_vectors_at_t(all_roots_shots, t)
    if X.shape[0] < 2:
        # PCA needs at least two samples. Skipped t values mean n_for_99 can
        # be shorter than t_sample.
        continue
    # Keep every possible component so the cumulative ratio can reach 1.0.
    n_full = min(X.shape[0], X.shape[1])
    pca_full = PCA(n_components=n_full)
    pca_full.fit(X)
    cumvar = np.cumsum(pca_full.explained_variance_ratio_)
    # searchsorted gives the first index with cumvar >= target; +1 turns it
    # into a component count. Clamp in case round-off keeps cumvar[-1] below target.
    k = int(np.searchsorted(cumvar, target_variance)) + 1
    k = min(k, len(cumvar))
    n_for_99.append(k)

if n_for_99:
    n_for_99 = np.array(n_for_99)
    median_k = np.median(n_for_99)
    print(f'At {pct} explained variance (over {len(n_for_99)} of {len(t_sample)} sampled time indices):')
    print(f'  Components needed: min={n_for_99.min()}, max={n_for_99.max()}, mean={n_for_99.mean():.1f}, median={median_k:.0f}')
    print(f'  → Consider N_COMPONENTS >= {int(n_for_99.max())} to retain ≥{pct} at all sampled t.')
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.hist(n_for_99, bins=min(50, len(np.unique(n_for_99))), color='steelblue', edgecolor='white')
    ax.axvline(median_k, color='red', linestyle='--', label=f'median={median_k:.0f}')
    ax.set_xlabel(f'Number of components for ≥{pct} variance')
    ax.set_ylabel('Count (time samples)')
    ax.legend()
    ax.set_title(f'Per-time-point PCA: components needed for ≥{pct} variance')
    plt.tight_layout()
    plt.show()
else:
    print('Not enough data to compute (need at least 2 samples at some t).')

In [None]:
# Never ask for more components than there are channels.
n_comp = min(N_COMPONENTS, CHANNELS)
# Depending on the data shape, sklearn's svd_solver='auto' may choose the
# randomized solver. A fixed seed keeps the fitted bases (and therefore every
# saved *_pca file) identical across reruns.
PCA_RANDOM_STATE = 0
# NOTE(review): each t gets its own independent PCA basis, so component k at t
# and component k at t+1 need not align (sign or ordering may differ between
# neighbouring t) — confirm this is acceptable for downstream time-series use.
pca_per_t = []  # pca_per_t[t] = fitted PCA, or None if fewer than 2 shots reach t

for t in tqdm(range(max_T), desc='Fit PCA per t'):
    X = collect_vectors_at_t(all_roots_shots, t)
    if X.shape[0] < 2:
        pca_per_t.append(None)
        continue
    # PCA can extract at most min(n_samples, n_features) components.
    n = min(n_comp, X.shape[0], X.shape[1])
    pca = PCA(n_components=n, random_state=PCA_RANDOM_STATE)
    pca.fit(X)
    pca_per_t.append(pca)

n_fitted = sum(p is not None for p in pca_per_t)
print(f'Fitted PCA at {n_fitted}/{max_T} time indices')

## 4. Transform and save disruptive shots -> dsrpt_decimated_pca

In [None]:
def transform_shot(root: Path, shot: int, T: int, pca_per_t: list) -> np.ndarray:
    """Project one shot onto the per-time-index PCA bases.

    Parameters
    ----------
    root : directory containing ``{shot}.h5`` with an ``LFS`` dataset whose
        last axis is time.
    shot : shot id (the file stem).
    T : number of time indices to transform (normally the shot's length).
    pca_per_t : ``pca_per_t[t]`` is the PCA fitted at time index ``t``, or None.

    Returns
    -------
    float32 array of shape ``(N_COMPONENTS, T)``. Column ``t`` holds the
    projection onto ``pca_per_t[t]``. Rows beyond that PCA's ``n_components_``
    stay 0. The whole column is 0 when no PCA exists for ``t`` (``t`` is past
    ``len(pca_per_t)``, or too few shots reached ``t`` at fit time).
    """
    out = np.zeros((N_COMPONENTS, T), dtype=np.float32)
    # Read the shot once instead of reopening the file for every t. Flattening
    # all non-time axes in C order makes column t equal to
    # load_channel_vector(root, shot, t), e.g. (20, 8, T_file) -> (160, T_file).
    with h5py.File(root / f'{shot}.h5', 'r') as f:
        lfs = np.asarray(f['LFS'][...], dtype=np.float32)
    channels_by_t = lfs.reshape(-1, lfs.shape[-1])
    for t in range(min(T, len(pca_per_t))):
        pca = pca_per_t[t]
        if pca is None:
            continue  # column is already zero
        k = pca.n_components_
        out[:k, t] = pca.transform(channels_by_t[:, t].reshape(1, -1)).ravel()
    return out

OUT_DSRPT_PCA.mkdir(parents=True, exist_ok=True)

# Copy the metadata table next to the reduced shots when the input set has one.
dsrpt_meta = DECIMATED_ROOT / 'meta.csv'
if dsrpt_meta.exists():
    import shutil
    shutil.copy(dsrpt_meta, OUT_DSRPT_PCA / 'meta.csv')

# Reduce each disruptive shot with the shared per-t PCA fits and write LFS as (N_COMPONENTS, T).
for shot, T in tqdm(list(dsrpt_lengths.items()), desc='Save disruptive PCA'):
    data = transform_shot(DECIMATED_ROOT, shot, T, pca_per_t)
    dsrpt_out_path = OUT_DSRPT_PCA / f'{shot}.h5'
    with h5py.File(dsrpt_out_path, 'w') as f:
        f.create_dataset('LFS', data=data, dtype=np.float32)

n_dsrpt_saved = len(dsrpt_lengths)
print(f'Saved {n_dsrpt_saved} shots to {OUT_DSRPT_PCA}')

## 5. Transform and save clear shots -> clear_decimated_pca

In [None]:
OUT_CLEAR_PCA.mkdir(parents=True, exist_ok=True)

# Copy the metadata table next to the reduced shots when the clear set has one.
clear_meta = CLEAR_DECIMATED_ROOT / 'meta.csv'
if CLEAR_DECIMATED_ROOT.exists() and clear_meta.exists():
    import shutil
    shutil.copy(clear_meta, OUT_CLEAR_PCA / 'meta.csv')

# Reduce each clear shot with the shared per-t PCA fits and write LFS as (N_COMPONENTS, T).
for shot, T in tqdm(list(clear_lengths.items()), desc='Save clear PCA'):
    data = transform_shot(CLEAR_DECIMATED_ROOT, shot, T, pca_per_t)
    clear_out_path = OUT_CLEAR_PCA / f'{shot}.h5'
    with h5py.File(clear_out_path, 'w') as f:
        f.create_dataset('LFS', data=data, dtype=np.float32)

n_clear_saved = len(clear_lengths)
print(f'Saved {n_clear_saved} shots to {OUT_CLEAR_PCA}')

## 6. Sanity: shapes

In [None]:
# Spot-check the first saved file of each non-empty set and print its LFS shape.
for label, lengths, out_dir in (
    ('Disruptive PCA example: ', dsrpt_lengths, OUT_DSRPT_PCA),
    ('Clear PCA example:      ', clear_lengths, OUT_CLEAR_PCA),
):
    if not lengths:
        continue
    shot0 = next(iter(lengths))
    with h5py.File(out_dir / f'{shot0}.h5', 'r') as f:
        sh = f['LFS'].shape
    print(f'{label}shot {shot0} LFS shape = {sh} (N_components, T)')