# Subject Data Diagnostics

This notebook runs diagnostic checks to identify potential data quality issues with specific subjects.

## Motivation
Subject 7 (A07T) appears to have issues affecting model performance. We run two specific checks:
- **Check A**: Channel Variance ("Loose Wire" Test) - Detects channels with abnormal signal power
- **Check B**: Ghost Label Test - Verifies that labels actually separate the data


In [None]:
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')


## Load Subject Data

We compare a "good" subject (A03T) against the potentially "bad" subject (A07T).


In [None]:
# Load "Good" Subject vs "Bad" Subject
sub_good = np.load('../data/processed/A03T.npz')
sub_bad  = np.load('../data/processed/A07T.npz')

X_good = sub_good['X']  # (Trials, Channels, Time)
X_bad  = sub_bad['X']

print(f"A03T (Good): X shape = {X_good.shape}, y shape = {sub_good['y'].shape}")
print(f"A07T (Bad):  X shape = {X_bad.shape}, y shape = {sub_bad['y'].shape}")


---
## Check A: Channel Variance ("Loose Wire" Test)

Compare the signal power of each channel between subjects. A bad channel (loose electrode, disconnection, or artifact) typically shows variance that is 10x-100x higher or lower than the other channels.

**What to look for:** If one point on the red line is 10x-100x higher (or lower) than the others, that channel is likely distorting Euclidean Alignment.
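
A fixed 10x threshold can miss outliers when channel variances are naturally spread out. As a complementary sketch (not one of the original checks), the snippet below scores each channel's log-variance with a median/MAD-based modified z-score. The helper name `robust_variance_outliers`, the 3.5 cutoff, and the synthetic data are assumptions made for illustration.

```python
import numpy as np

def robust_variance_outliers(X, z_thresh=3.5):
    """Flag channels whose log-variance is far from the median (hypothetical helper)."""
    log_var = np.log10(np.var(X, axis=(0, 2)))    # per-channel variance over trials and time
    med = np.median(log_var)
    mad = np.median(np.abs(log_var - med))        # median absolute deviation
    z = 0.6745 * (log_var - med) / (mad + 1e-12)  # modified z-score
    return np.flatnonzero(np.abs(z) > z_thresh), z

# Synthetic demo (assumed layout: trials x channels x time)
rng = np.random.default_rng(0)
X_demo = rng.standard_normal((20, 22, 100))
X_demo[:, 5, :] *= 10.0   # simulated "exploding" channel
X_demo[:, 12, :] *= 0.1   # simulated "dead" channel
flagged, z_scores = robust_variance_outliers(X_demo)
print("Flagged channels:", flagged)
```

Working on the log scale treats "10x too high" and "10x too low" symmetrically, which matches the two-sided wording of the check above.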


In [None]:
# Compute variance per channel (averaged over trials)
var_good = np.var(X_good, axis=(0, 2))
var_bad  = np.var(X_bad, axis=(0, 2))

# Plot
plt.figure(figsize=(10, 5))
plt.plot(var_good, label='A03T (Good)', marker='o')
plt.plot(var_bad, label='A07T (Bad)', marker='x', color='red')
plt.xlabel('Channel Index')
plt.ylabel('Signal Variance')
plt.title('Diagnostic: Detecting Exploding Channels')
plt.legend()
plt.grid(True)
plt.yscale('log')  # Log scale is crucial to see outliers
plt.show()

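# Print offenders: channels more than 10x above or below the median variance
ratio = var_bad / np.median(var_bad)
for bad_ch in np.flatnonzero((ratio > 10) | (ratio < 0.1)):
    print(f"⚠️ DETECTED BAD CHANNEL IN A07T: Index {bad_ch} has variance "
          f"{var_bad[bad_ch]:.2f} ({ratio[bad_ch]:.2g}x median)")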


---
## Check B: The "Ghost Label" Test

Check if the labels are actually separating the data by plotting average ERPs (Event-Related Potentials) for different classes.

**What to look for:**
- **Good (A03):** Distinct differences between Left and Right class lines
- **Bad (A07):** If the lines overlap almost perfectly, the labels might be shuffled or the signal might be pure noise
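
Visual overlap is hard to judge by eye, so one option is a label permutation test: shuffle the labels many times and check whether the observed class difference is larger than what shuffled labels produce. This is a sketch of that idea, not something the notebook currently runs. The helper name `erp_label_permutation_test`, the RMS-difference statistic, and the synthetic data are assumptions; labels are assumed to be coded 0 and 1 as in the cells below.

```python
import numpy as np

def erp_label_permutation_test(X, y, channel, n_perm=1000, seed=0):
    """Permutation test on the RMS difference between class-average ERPs (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    x = X[:, channel, :]  # (trials, time) for the chosen channel

    def rms_erp_difference(labels):
        diff = x[labels == 0].mean(axis=0) - x[labels == 1].mean(axis=0)
        return np.sqrt(np.mean(diff ** 2))

    observed = rms_erp_difference(y)
    null = np.array([rms_erp_difference(rng.permutation(y)) for _ in range(n_perm)])
    # Add-one correction so the p-value is never exactly zero
    p_value = (1 + np.sum(null >= observed)) / (1 + n_perm)
    return observed, p_value

# Synthetic demo: class 1 carries a constant offset on channel 0
rng_demo = np.random.default_rng(1)
y_perm_demo = np.repeat([0, 1], 20)
X_perm_demo = rng_demo.standard_normal((40, 3, 50))
X_perm_demo[y_perm_demo == 1, 0, :] += 1.0
obs, p = erp_label_permutation_test(X_perm_demo, y_perm_demo, channel=0, n_perm=200)
print(f"observed RMS difference = {obs:.3f}, p = {p:.4f}")
```

A large p-value on real data would be consistent with shuffled labels or a signal that carries no class information on that channel, but it does not distinguish between the two.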


In [None]:
# Plot Average ERP for Left vs Right
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
# Channel 7 (C3) is usually motor dominant
plt.plot(X_bad[sub_bad['y']==0].mean(axis=0)[7], label='Left (Class 0)') 
plt.plot(X_bad[sub_bad['y']==1].mean(axis=0)[7], label='Right (Class 1)')
plt.title('Subject 7: Class Separation (C3)')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(X_good[sub_good['y']==0].mean(axis=0)[7], label='Left') 
plt.plot(X_good[sub_good['y']==1].mean(axis=0)[7], label='Right')
plt.title('Subject 3: Class Separation (C3)')
plt.legend()
plt.show()


---
## Quantify Class Separation

Measure the Pearson correlation between the class-average ERPs; lower correlation means more distinct classes.


In [None]:
# Quantify class separation using correlation
erp_bad_0 = X_bad[sub_bad['y']==0].mean(axis=0)[7]
erp_bad_1 = X_bad[sub_bad['y']==1].mean(axis=0)[7]
erp_good_0 = X_good[sub_good['y']==0].mean(axis=0)[7]
erp_good_1 = X_good[sub_good['y']==1].mean(axis=0)[7]

corr_bad = np.corrcoef(erp_bad_0, erp_bad_1)[0, 1]
corr_good = np.corrcoef(erp_good_0, erp_good_1)[0, 1]

print("=" * 50)
print("CLASS SEPARATION ANALYSIS")
print("=" * 50)
print(f"\nCorrelation between class ERPs (lower = more distinct):")
print(f"  A03T (Good): {corr_good:.4f}")
print(f"  A07T (Bad):  {corr_bad:.4f}")

if corr_bad > 0.95:
    print(f"\n⚠️  WARNING: A07T class ERPs are nearly identical (corr={corr_bad:.3f})")
    print("    This suggests labels may be shuffled or signal contains no class info!")
elif corr_bad > 0.8:
    print(f"\n⚠️  CAUTION: A07T shows weak class separation (corr={corr_bad:.3f})")
else:
    print(f"\n✓ A07T shows reasonable class separation (corr={corr_bad:.3f})")


---
## Summary

Based on the diagnostics above, determine:

1. **Bad Channels**: Are there channels with abnormal variance that need interpolation or removal?
2. **Label Issues**: Do the labels actually separate the neural signals?
3. **Next Steps**: Based on findings, consider:
   - Robust preprocessing (channel rejection, artifact removal)
   - Subject-specific normalization
   - Excluding problematic subjects from training


In [None]:
# Check 1: Covariance Matrix Structure
# Euclidean Alignment depends on the covariance matrix - let's compare them

def compute_trial_covariance(X):
    """Compute average covariance matrix across trials"""
    n_trials, n_channels, n_times = X.shape
    covs = []
    for trial in range(n_trials):
        cov = np.cov(X[trial])  # (channels, channels)
        covs.append(cov)
    return np.mean(covs, axis=0)

cov_good = compute_trial_covariance(X_good)
cov_bad = compute_trial_covariance(X_bad)

# Eigenvalue analysis - if eigenvalues are very different, EA will struggle
eig_good = np.linalg.eigvalsh(cov_good)
eig_bad = np.linalg.eigvalsh(cov_bad)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Plot eigenvalue spectra
ax = axes[0]
ax.semilogy(sorted(eig_good, reverse=True), 'o-', label='A03T', markersize=5)
ax.semilogy(sorted(eig_bad, reverse=True), 'x-', label='A07T', color='red', markersize=5)
ax.set_xlabel('Eigenvalue Index')
ax.set_ylabel('Eigenvalue (log scale)')
ax.set_title('Covariance Eigenvalue Spectrum')
ax.legend()
ax.grid(True, alpha=0.3)

# Condition number (ratio of max/min eigenvalue) - high = ill-conditioned
# Eigenvalues below 1e-10 are treated as numerically zero and excluded
cond_good = np.max(eig_good) / np.min(eig_good[eig_good > 1e-10])
cond_bad = np.max(eig_bad) / np.min(eig_bad[eig_bad > 1e-10])

# Plot covariance matrices
ax = axes[1]
im = ax.imshow(cov_good, cmap='RdBu_r', aspect='auto')
ax.set_title(f'A03T Covariance')
plt.colorbar(im, ax=ax)

ax = axes[2]
im = ax.imshow(cov_bad, cmap='RdBu_r', aspect='auto')
ax.set_title(f'A07T Covariance')
plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

print(f"Condition number (higher = harder to align):")
print(f"  A03T: {np.max(eig_good)/np.min(eig_good[eig_good > 0]):.2f}")
print(f"  A07T: {np.max(eig_bad)/np.min(eig_bad[eig_bad > 0]):.2f}")


In [None]:
# Check 2: Signal-to-Noise Ratio (Fisher's criterion)
# Measures how separable the classes are

def fisher_score_per_channel(X, y):
    """Compute Fisher's discriminant ratio for each channel"""
    classes = np.unique(y)
    n_channels = X.shape[1]
    
    scores = []
    for ch in range(n_channels):
        # Flatten time dimension
        x_ch = X[:, ch, :].mean(axis=1)  # Average over time -> (trials,)
        
        # Between-class variance
        class_means = [x_ch[y == c].mean() for c in classes]
        overall_mean = x_ch.mean()
        between_var = sum([(m - overall_mean)**2 for m in class_means])
        
        # Within-class variance
        within_var = sum([x_ch[y == c].var() for c in classes])
        
        scores.append(between_var / (within_var + 1e-10))
    
    return np.array(scores)

fisher_good = fisher_score_per_channel(X_good, sub_good['y'])
fisher_bad = fisher_score_per_channel(X_bad, sub_bad['y'])

plt.figure(figsize=(10, 5))
plt.bar(np.arange(len(fisher_good)) - 0.2, fisher_good, 0.4, label='A03T (Good)', alpha=0.8)
plt.bar(np.arange(len(fisher_bad)) + 0.2, fisher_bad, 0.4, label='A07T (Bad)', alpha=0.8, color='red')
plt.xlabel('Channel Index')
plt.ylabel('Fisher Score (higher = more discriminative)')
plt.title('Per-Channel Class Discriminability')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print(f"Total Fisher Score (sum across channels):")
print(f"  A03T: {fisher_good.sum():.4f}")
print(f"  A07T: {fisher_bad.sum():.4f}")
print(f"  Ratio (Good/Bad): {fisher_good.sum()/fisher_bad.sum():.2f}x")


In [None]:
# Check 3: Compare ALL subjects to see where Subject 7 falls

subjects = ['A01T', 'A02T', 'A03T', 'A04T', 'A05T', 'A06T', 'A07T', 'A08T', 'A09T']
fisher_scores = []
cov_condition_numbers = []

for subj in subjects:
    data = np.load(f'../data/processed/{subj}.npz')
    X, y = data['X'], data['y']
    
    # Fisher score
    fs = fisher_score_per_channel(X, y).sum()
    fisher_scores.append(fs)
    
    # Covariance condition number
    cov = compute_trial_covariance(X)
    eigs = np.linalg.eigvalsh(cov)
    cond = np.max(eigs) / np.min(eigs[eigs > 0])
    cov_condition_numbers.append(cond)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Fisher scores
ax = axes[0]
colors = ['red' if 'A07' in s else 'steelblue' for s in subjects]
bars = ax.bar(subjects, fisher_scores, color=colors)
ax.set_ylabel('Total Fisher Score')
ax.set_title('Class Discriminability by Subject\n(Higher = Better Signal)')
ax.axhline(np.mean(fisher_scores), color='gray', linestyle='--', label='Mean')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

# Condition numbers  
ax = axes[1]
bars = ax.bar(subjects, cov_condition_numbers, color=colors)
ax.set_ylabel('Condition Number')
ax.set_title('Covariance Matrix Condition Number\n(Higher = Harder to Align)')
ax.axhline(np.mean(cov_condition_numbers), color='gray', linestyle='--', label='Mean')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Summary table
print("\n" + "="*60)
print("SUBJECT COMPARISON SUMMARY")
print("="*60)
print(f"{'Subject':<10} {'Fisher Score':<15} {'Cov Condition':<15}")
print("-"*60)
for i, subj in enumerate(subjects):
    marker = " ⚠️" if 'A07' in subj else ""
    print(f"{subj:<10} {fisher_scores[i]:<15.4f} {cov_condition_numbers[i]:<15.2f}{marker}")


---
## Interpretation Guide

### What the deeper diagnostics reveal:

1. **Covariance Eigenvalue Spectrum**: If A07T has a very different spectrum (much steeper or flatter), Euclidean Alignment will have trouble mapping it to other subjects.

2. **Condition Number**: A high condition number means the covariance matrix is nearly singular, which makes EA's inverse matrix square root numerically unstable.

3. **Fisher Score**: A low Fisher score means the subject's brain signals do not clearly differentiate between left and right motor imagery. This is often called **BCI illiteracy**; roughly 15-30% of people are commonly reported to have weak motor imagery signals.
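
The Fisher score in this notebook is computed on the time-averaged amplitude of each trial. If the processed data are band-pass filtered (an assumption, since the preprocessing is not shown here), those averages sit close to zero and class differences in signal power can go unnoticed. A possible variant, sketched below, applies the same ratio to per-trial log-variance features instead. The helper name `fisher_score_logvar` and the synthetic data are illustrative assumptions.

```python
import numpy as np

def fisher_score_logvar(X, y):
    """Fisher ratio per channel on log-variance features (hypothetical variant)."""
    feats = np.log(np.var(X, axis=2))  # (trials, channels): log power per trial
    classes = np.unique(y)
    overall_mean = feats.mean(axis=0)
    # Between-class spread of the class means, summed over classes
    between = sum((feats[y == c].mean(axis=0) - overall_mean) ** 2 for c in classes)
    # Within-class variance, summed over classes
    within = sum(feats[y == c].var(axis=0) for c in classes)
    return between / (within + 1e-10)

# Synthetic demo: channel 2 has doubled amplitude for class 1, zero mean everywhere
rng_fs = np.random.default_rng(2)
y_fs = np.tile([0, 1], 30)
X_fs = rng_fs.standard_normal((60, 4, 200))
X_fs[y_fs == 1, 2, :] *= 2.0
scores_logvar = fisher_score_logvar(X_fs, y_fs)
print("Most discriminative channel:", int(np.argmax(scores_logvar)))
```

In this demo the class difference is purely in power, so a score based on time-averaged amplitude would have little to detect, while the log-variance version singles out channel 2.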

### Likely Conclusions:
- If A07T has **low Fisher score** → The subject is a weak motor imagery performer (nothing wrong with data, just biology)
- If A07T has **high condition number** → EA is numerically unstable for this subject
- If A07T looks **normal** on all metrics → The issue may be model-specific, not data-specific
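
To make the link between condition number and EA concrete, here is a minimal sketch of Euclidean Alignment as it is commonly described: compute a reference matrix as the mean of the per-trial second-moment matrices, then multiply every trial by its inverse square root. This is not necessarily the implementation used in this project; the function name `euclidean_alignment` and the division by the number of time samples are assumptions.

```python
import numpy as np

def euclidean_alignment(X):
    """Whiten trials with the inverse square root of the mean trial covariance (sketch)."""
    n_trials, n_channels, n_times = X.shape
    # Reference matrix: mean over trials of X_i @ X_i.T / n_times
    R = np.einsum('tcs,tds->cd', X, X) / (n_trials * n_times)
    eigvals, eigvecs = np.linalg.eigh(R)
    # R^(-1/2) = V diag(1/sqrt(lambda)) V^T; tiny eigenvalues are strongly amplified here
    R_inv_sqrt = (eigvecs / np.sqrt(eigvals)) @ eigvecs.T
    return np.einsum('cd,tds->tcs', R_inv_sqrt, X)

# Synthetic demo: correlated channels produced by a random mixing matrix
rng_ea = np.random.default_rng(3)
mixing = rng_ea.standard_normal((4, 4))
X_ea = np.einsum('cd,tds->tcs', mixing, rng_ea.standard_normal((30, 4, 200)))
X_aligned = euclidean_alignment(X_ea)
mean_cov_aligned = np.einsum('tcs,tds->cd', X_aligned, X_aligned) / (30 * 200)
print(np.round(mean_cov_aligned, 3))  # approximately the identity matrix
```

The division by `np.sqrt(eigvals)` is where a high condition number hurts: directions with very small eigenvalues are scaled up the most, so noise in those directions can dominate the aligned signal.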


---
## Deeper Diagnostic: All Subjects Comparison

Since no obvious issues were found, let's compare ALL subjects to see if A07T is truly an outlier or just a naturally "hard" subject (BCI illiteracy).
