# Data & normalization visualization

Inspect how **decimated** and **PCA** ECEi data look, and how **per-channel normalization** (used before training) changes the signal. We use **one example shot** and compare raw vs normalized for both representations.

In [None]:
import os
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Root of the ECEi data on persistent storage. Set the ECEI_BASE environment
# variable to point the notebook at a different location.
BASE = Path(os.environ.get('ECEI_BASE', '/home/idies/workspace/Storage/yhuang2/persistent/ecei'))
DECIMATED_ROOT = BASE / 'dsrpt_decimated'
CLEAR_DECIMATED_ROOT = BASE / 'clear_decimated'

# PCA variant to inspect: 'pca1' by default; can switch to 'pca4', 'pca8', 'pca16'.
# The data directory and the matching norm-stats file are both derived from this
# single setting so they cannot drift apart when the variant is changed.
PCA_VARIANT = 'pca1'
PCA_ROOT = BASE / f'dsrpt_decimated_{PCA_VARIANT}'

# Norm stats (computed from training split in real training; here we load if present)
NORM_STATS_PATH = 'norm_stats.npz'
NORM_STATS_PCA_PATH = f'norm_stats_{PCA_VARIANT}.npz'

# Time window for plotting (samples at 100 kHz decimated)
PLOT_SAMPLES = 50_000   # 0.5 s at 100 kHz

## 1. Pick one example shot

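Use the first disruptive shot listed in `meta.csv` that has an `.h5` file in the decimated directory (if `meta.csv` is absent, fall back to an `.h5` file found in that directory).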

In [None]:
# Choose the example shot. Prefer the row order of meta.csv (which also carries
# the disruption time); otherwise fall back to the lowest-numbered .h5 file.
example_shot = None
t_disrupt_ms = None   # None means "unknown"; never fabricate 0 ms

meta_path = DECIMATED_ROOT / 'meta.csv'
if meta_path.exists():
    meta = pd.read_csv(meta_path)
    # first shot that has an .h5 file
    for _, row in meta.iterrows():
        shot = int(row['shot'])
        if (DECIMATED_ROOT / f'{shot}.h5').exists():
            example_shot = shot
            # A missing column or a NaN entry is reported as unknown, not as t = 0.
            t_val = row.get('t_disruption')
            if t_val is not None and pd.notna(t_val):
                t_disrupt_ms = float(t_val)
            break
elif DECIMATED_ROOT.exists():
    # glob() order is filesystem-dependent: pick the lowest numeric stem so the
    # choice is reproducible, skipping any non-numeric file names.
    example_shot = min(
        (int(p.stem) for p in DECIMATED_ROOT.glob('*.h5') if p.stem.isdigit()),
        default=None,
    )

if example_shot is None:
    raise FileNotFoundError(f'No .h5 found under {DECIMATED_ROOT}')
print(f'Example shot: {example_shot}')
if t_disrupt_ms is not None:
    print(f't_disruption (ms): {t_disrupt_ms}')

## 2. Load raw decimated data (20×8×T)

One shot: shape `(20, 8, T)` at 100 kHz (already offset-removed and decimated in preprocessing).

In [None]:
# Read the whole LFS dataset of the example shot into memory as float32.
dec_h5_path = DECIMATED_ROOT / f'{example_shot}.h5'
with h5py.File(dec_h5_path, 'r') as h5:
    lfs_dec = np.asarray(h5['LFS'][...], dtype=np.float32)

T_dec = lfs_dec.shape[-1]
print(f'Decimated LFS shape: {lfs_dec.shape}  (20, 8, T)')

# Contiguous plotting window of up to PLOT_SAMPLES samples, centred on sample
# T_dec // 4 and clipped to the range [0, T_dec).
start_t = T_dec // 4 - PLOT_SAMPLES // 2
if start_t < 0:
    start_t = 0
end_t = start_t + PLOT_SAMPLES
if end_t > T_dec:
    end_t = T_dec
lfs_dec_win = lfs_dec[:, :, start_t:end_t]
print(f'Plot window: samples [{start_t}, {end_t})')

## 3. Normalization for decimated data

Per-channel mean and std over time (computed from training split in practice). Here we either **load** saved stats or **compute from this shot** for illustration.

In [None]:
# Per-channel normalization stats for the decimated data: load saved stats if
# available, otherwise compute them from this single shot for illustration.
norm_path = Path(NORM_STATS_PATH)
if norm_path.exists():
    # NpzFile keeps the archive open until closed; the context manager closes it.
    with np.load(norm_path) as npz:
        mean_dec = npz['mean'].astype(np.float32)   # expected (20, 8)
        std_dec = npz['std'].astype(np.float32)
    print(f'Loaded norm stats from {norm_path}')
else:
    # Compute from this shot (for viz only; training uses many shots)
    mean_dec = lfs_dec.mean(axis=-1).astype(np.float32)
    std_dec = lfs_dec.std(axis=-1).astype(np.float32)
    print('Computed norm stats from this shot (demo only)')

# Clamp for both sources so zero-variance channels never produce inf/NaN.
std_dec = np.maximum(std_dec, 1e-8)

# Stats must have one entry per channel; a mismatched shape would otherwise
# broadcast silently into a wrong-shaped array below.
if mean_dec.shape != lfs_dec.shape[:-1] or std_dec.shape != lfs_dec.shape[:-1]:
    raise ValueError(
        f'Norm stats shape mean={mean_dec.shape}, std={std_dec.shape} '
        f'does not match data channels {lfs_dec.shape[:-1]}'
    )

lfs_dec_norm = (lfs_dec - mean_dec[..., np.newaxis]) / std_dec[..., np.newaxis]
lfs_dec_norm_win = lfs_dec_norm[:, :, start_t:end_t]

## 4. Load PCA data (C×T) and normalize

If `dsrpt_decimated_pca1` (or pca4/8/16) exists, load the same shot; shape `(C, T)`. Apply per-component normalization the same way.

In [None]:
# Load the PCA representation of the same shot (if present) and normalize each
# component with the same scheme used for the decimated data.
pca_path = PCA_ROOT / f'{example_shot}.h5'
has_pca = pca_path.exists()

if has_pca:
    with h5py.File(pca_path, 'r') as f:
        # atleast_2d: tolerate a single component stored flat as (T,) — assumed
        # layout otherwise is (C, T); TODO confirm against the PCA preprocessing.
        lfs_pca = np.atleast_2d(np.asarray(f['LFS'][...], dtype=np.float32))
    print(f'PCA LFS shape: {lfs_pca.shape}  (C, T)')
    C = lfs_pca.shape[0]
    lfs_pca_win = lfs_pca[:, start_t:end_t]

    norm_pca_path = Path(NORM_STATS_PCA_PATH)
    if norm_pca_path.exists():
        # Close the archive deterministically; flatten so scalar or (C, 1)
        # stats become one value per component.
        with np.load(norm_pca_path) as npz:
            mean_pca = npz['mean'].astype(np.float32).reshape(-1)
            std_pca = npz['std'].astype(np.float32).reshape(-1)
        print(f'Loaded PCA norm stats from {norm_pca_path}')
    else:
        mean_pca = lfs_pca.mean(axis=-1).astype(np.float32)
        std_pca = lfs_pca.std(axis=-1).astype(np.float32)
        print('Computed PCA norm stats from this shot (demo only)')

    # Clamp for both sources so a constant component never divides by zero.
    std_pca = np.maximum(std_pca, 1e-8)
    if mean_pca.shape != (C,) or std_pca.shape != (C,):
        raise ValueError(
            f'PCA norm stats shape mean={mean_pca.shape}, std={std_pca.shape} '
            f'does not match {C} components'
        )

    lfs_pca_norm = (lfs_pca - mean_pca[:, np.newaxis]) / std_pca[:, np.newaxis]
    lfs_pca_norm_win = lfs_pca_norm[:, start_t:end_t]
else:
    print(f'PCA path not found: {pca_path}')
    C = 1

## 5. Visualization: raw vs normalized (one example)

**Top:** Decimated data — one channel (0,0) raw vs normalized.  
**Bottom:** PCA data (PC1) raw vs normalized, if available.  

Normalization puts each channel/component on a comparable scale (roughly zero mean, unit variance) so the model is not dominated by high-variance channels.

In [None]:
# Figure: raw vs normalized for decimated channel [0,0] (row 0) and PC1 (row 1).
nrows = 2 if has_pca else 1
# squeeze=False always yields a 2-D (nrows, 2) array of Axes, even for one row.
fig, axes = plt.subplots(nrows, 2, figsize=(14, 3 * nrows), sharex='col', squeeze=False)

# Time in s at 100 kHz, counted from the first sample of the record (offset by
# start_t) so the x-axis shows where in the shot the window lies.
t_axis = (start_t + np.arange(lfs_dec_win.shape[-1])) / 100_000.0

# Row 0: Decimated — channel (0,0)
ax0_raw, ax0_norm = axes[0, 0], axes[0, 1]
ch0 = lfs_dec_win[0, 0, :]
ch0_norm = lfs_dec_norm_win[0, 0, :]
ax0_raw.plot(t_axis, ch0, color='#2e7d32', alpha=0.9, linewidth=0.6)
ax0_raw.set_ylabel('Amplitude')
ax0_raw.set_title('Decimated (raw) — channel [0,0]')
ax0_raw.grid(True, alpha=0.3)

ax0_norm.plot(t_axis, ch0_norm, color='#1565c0', alpha=0.9, linewidth=0.6)
ax0_norm.set_ylabel('Normalized (std units)')
ax0_norm.set_title('Decimated (normalized) — channel [0,0]')
ax0_norm.grid(True, alpha=0.3)

if has_pca:
    ax1_raw, ax1_norm = axes[1, 0], axes[1, 1]
    pc1 = lfs_pca_win[0, :]
    pc1_norm = lfs_pca_norm_win[0, :]
    # If the PCA record has fewer samples than end_t its window is shorter;
    # trim the time axis to match instead of failing inside plot().
    t_pc1 = t_axis[:pc1.shape[-1]]
    ax1_raw.plot(t_pc1, pc1, color='#6a1b9a', alpha=0.9, linewidth=0.6)
    ax1_raw.set_ylabel('Amplitude')
    ax1_raw.set_title('PCA (raw) — PC1')
    ax1_raw.grid(True, alpha=0.3)

    ax1_norm.plot(t_pc1, pc1_norm, color='#00838f', alpha=0.9, linewidth=0.6)
    ax1_norm.set_ylabel('Normalized (std units)')
    ax1_norm.set_title('PCA (normalized) — PC1')
    ax1_norm.grid(True, alpha=0.3)

for ax in axes[-1, :]:
    ax.set_xlabel('Time (s)')
fig.suptitle(f'Shot {example_shot} — raw vs per-channel normalized', y=1.02, fontsize=12)
plt.tight_layout()
plt.show()

## 6. How decimated vs PCA differ (same shot)

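Decimated: 160 channels (20×8). PCA: 1 (or 4/8/16) components. Below: **mean across all decimated channels** vs **PC1** (raw and normalized) to show the reduction in dimensionality and scale. Note that the sign of a principal component is arbitrary, so PC1 may appear inverted relative to the channel mean.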

In [None]:
# Mean over 160 decimated channels vs PC1
dec_mean_raw = lfs_dec_win.reshape(-1, lfs_dec_win.shape[-1]).mean(axis=0)
dec_mean_norm = lfs_dec_norm_win.reshape(-1, lfs_dec_norm_win.shape[-1]).mean(axis=0)

fig2, axes2 = plt.subplots(2, 1, figsize=(12, 5), sharex=True)
axes2[0].plot(t_axis, dec_mean_raw, color='#2e7d32', alpha=0.8, label='Decimated (raw), mean over 160 ch')
if has_pca:
    axes2[0].plot(t_axis, pc1, color='#6a1b9a', alpha=0.8, label='PCA (raw), PC1')
axes2[0].set_ylabel('Amplitude')
axes2[0].set_title('Raw: decimated mean vs PC1')
axes2[0].legend(loc='upper right', fontsize=9)
axes2[0].grid(True, alpha=0.3)

axes2[1].plot(t_axis, dec_mean_norm, color='#1565c0', alpha=0.8, label='Decimated (normalized), mean')
if has_pca:
    axes2[1].plot(t_axis, pc1_norm, color='#00838f', alpha=0.8, label='PCA (normalized), PC1')
axes2[1].set_ylabel('Amplitude')
axes2[1].set_xlabel('Time (s)')
axes2[1].set_title('Normalized: decimated mean vs PC1')
axes2[1].legend(loc='upper right', fontsize=9)
axes2[1].grid(True, alpha=0.3)
fig2.suptitle(f'Shot {example_shot} — decimated (160 ch) vs PCA', y=1.02)
plt.tight_layout()
plt.show()

## 7. Summary: scale before vs after normalization

In training, norm stats are computed over the **training split** (many shots), not one shot. Here we used this shot only for illustration. Normalized signals have per-channel std ≈ 1 over the data the stats were computed from; the std of the plotted window alone can differ from 1, since the window covers only part of the record.

In [None]:
# Window std shows the visible change of scale in the plots. Full-record std
# checks the "std ≈ 1" property, which holds only when the stats were computed
# from this shot (loaded training-split stats need not give exactly 1).
print('Decimated (window):')
print(f'  Raw  — channel [0,0] std = {ch0.std():.4f}')
print(f'  Norm — channel [0,0] std = {ch0_norm.std():.4f}')
print('Decimated (full record):')
print(f'  Norm — channel [0,0] std = {lfs_dec_norm[0, 0].std():.4f}')
if has_pca:
    print('PCA (window):')
    print(f'  Raw  — PC1 std = {pc1.std():.4f}')
    print(f'  Norm — PC1 std = {pc1_norm.std():.4f}')
    print('PCA (full record):')
    print(f'  Norm — PC1 std = {lfs_pca_norm[0].std():.4f}')