In [None]:
import json
import os

import heartpy as hp
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import butter, filtfilt, iirnotch, welch

try:
    from scipy.integrate import simps
except ImportError:  # `simps` was removed in SciPy 1.14; `simpson` replaces it (pass x= by keyword)
    from scipy.integrate import simpson as simps
# Apply matplotlib's built-in 'ggplot' style sheet to every figure drawn in this notebook.
plt.style.use('ggplot')

In [None]:
file_path = 'muse_data.npz'

if not os.path.exists(file_path):
    print(f"Error: {file_path} not found.")
    print("Please make sure the .npz file is in the same directory as this script.")
else:
    print(f"Loading data from {file_path}...")
    # allow_pickle=True is required because some entries (e.g. the sampling-rate dict unwrapped
    # with .item() below) are stored as object arrays.
    # NOTE(review): unpickling can execute arbitrary code -- only load archives from a trusted source.
    # The `with` block closes the archive's underlying file handle; the arrays pulled out inside it
    # are read on access, so they remain usable after the block exits.
    with np.load(file_path, allow_pickle=True) as data:
        # See what's inside
        print(f"\nKeys in the .npz file: {list(data.keys())}\n")

        # Extract the data streams
        eeg_data = data['eeg.npy']
        ppg_data = data['ppg.npy'].flatten()  # Flatten to 1D array
        acc_data = data['accel.npy']
        gyro_data = data['gyro.npy']
        timestamps = data['timestamp.npy']

        # Extract metadata
        # .item() is used to extract scalar/dict objects from 0-dim arrays
        sampling_rates = data['sampling_rates.npy'].item()
        channel_names = data['channel_names.npy']
        duration_sec = data['duration_seconds.npy'].item()

    eeg_sfreq = sampling_rates.get('eeg', 256)  # Default to 256 if not found
    ppg_sfreq = sampling_rates.get('ppg', 64)  # Default to 64 if not found

    # Print a report
    print("--- Data Report ---")
    print(f"Duration: {duration_sec} seconds")
    print(f"EEG Shape: {eeg_data.shape} (Samples, Channels)")
    print(f"EEG Sampling Rate: {eeg_sfreq} Hz")
    print(f"EEG Channel Names: {channel_names}")
    print(f"PPG Shape: {ppg_data.shape} (Samples,)")
    print(f"PPG Sampling Rate: {ppg_sfreq} Hz")
    print(f"Accelerometer Shape: {acc_data.shape}")
    print(f"Gyroscope Shape: {gyro_data.shape}")

In [None]:
def _channel_index(channel_names, name):
    """Return the column index of `name` in `channel_names` (list or ndarray), or None if absent."""
    matches = np.flatnonzero(np.asarray(channel_names) == name)
    return int(matches[0]) if matches.size else None


def _resample_mask(mask, n_target):
    """Map a boolean mask onto `n_target` samples, assuming both span the same time interval."""
    mask = np.asarray(mask, dtype=bool)
    n_source = mask.shape[0]
    if n_source == n_target:
        return mask
    resampled = np.zeros(n_target, dtype=bool)
    if n_source == 0 or n_target == 0:
        return resampled
    # Every target sample looks up the source sample covering it (upsampling) ...
    resampled |= mask[(np.arange(n_target) * n_source) // n_target]
    # ... and every flagged source sample marks the target sample it falls in (downsampling).
    resampled[(np.flatnonzero(mask) * n_target) // n_source] = True
    return resampled


def preprocess_eeg(eeg_data, acc_data, gyro_data, sfreq, channel_names):
    """
    Apply the full pre-processing pipeline to raw EEG data.

    Steps: optional re-referencing to the TP9/TP10 mean, 1-45 Hz band-pass,
    50 Hz notch, then artifact flagging from motion (accelerometer/gyroscope
    magnitude outliers) and large sample-to-sample jumps on AF7/AF8.

    Parameters
    ----------
    eeg_data : ndarray, shape (n_samples, n_channels)
    acc_data, gyro_data : ndarray, shape (n_motion_samples, n_axes)
        Motion streams. They may have a different number of rows than the
        EEG; their flags are then stretched onto the EEG sample grid.
    sfreq : float
        EEG sampling rate in Hz.
    channel_names : sequence of str
        One name per EEG column (list or ndarray).

    Returns
    -------
    filtered_eeg : ndarray, same shape as eeg_data
    artifact_mask : ndarray of bool, shape (n_samples,)
        True where a sample is at (or within the dilation window of) an artifact.
    """
    channel_names = np.asarray(channel_names)  # so list input works with the == lookups below
    n_samples = eeg_data.shape[0]

    # --- Step 1: Re-referencing ---
    # NOTE(review): referencing to the TP9/TP10 mean leaves TP9 == -TP10 == (TP9 - TP10) / 2,
    # so those two channels stop being independent -- keep in mind when scoring from them.
    eeg_referenced = eeg_data
    tp9_idx = _channel_index(channel_names, 'TP9')
    tp10_idx = _channel_index(channel_names, 'TP10')
    if tp9_idx is not None and tp10_idx is not None:
        try:
            tp_avg = eeg_data[:, [tp9_idx, tp10_idx]].mean(axis=1, keepdims=True)
            eeg_referenced = eeg_data - tp_avg
            print("Re-referenced EEG to average of TP9 and TP10.")
        except Exception as e:
            print(f"Could not re-reference: {e}. Using original data.")
            eeg_referenced = eeg_data
    else:
        print("TP9 or TP10 not found. Skipping re-referencing.")

    # --- Step 2: Filtering ---
    # The 1 Hz high edge of the band-pass also removes the DC offset (baseline).
    bp_low = 1.0
    bp_high = 45.0
    nyquist = 0.5 * sfreq
    b, a = butter(N=4, Wn=[bp_low / nyquist, bp_high / nyquist], btype='bandpass')
    eeg_bandpassed = filtfilt(b, a, eeg_referenced, axis=0)

    # Mains notch; iirnotch only accepts frequencies below Nyquist, so skip it otherwise.
    notch_freq = 50.0
    Q = 30
    if notch_freq < nyquist:
        b_notch, a_notch = iirnotch(notch_freq, Q, fs=sfreq)
        filtered_eeg = filtfilt(b_notch, a_notch, eeg_bandpassed, axis=0)
    else:
        filtered_eeg = eeg_bandpassed

    # --- Step 3: Artifact Removal ---
    acc_mag = np.linalg.norm(acc_data, axis=1)
    gyro_mag = np.linalg.norm(gyro_data, axis=1)

    # Heuristic thresholds (mean + 3 std) - these may need tuning!
    acc_thresh = np.mean(acc_mag) + 3 * np.std(acc_mag)
    gyro_thresh = np.mean(gyro_mag) + 3 * np.std(gyro_mag)
    # The motion streams may have fewer rows than the EEG, so map each flag array onto the
    # EEG grid before combining (a plain `|` would fail on mismatched lengths).
    # NOTE(review): assumes motion and EEG streams cover the same time span -- confirm.
    motion_mask = (_resample_mask(acc_mag > acc_thresh, n_samples)
                   | _resample_mask(gyro_mag > gyro_thresh, n_samples))

    # Blink artifact detection (simple threshold on frontal channels)
    blink_mask = np.zeros(n_samples, dtype=bool)
    af7_idx = _channel_index(channel_names, 'AF7')
    af8_idx = _channel_index(channel_names, 'AF8')
    if af7_idx is not None and af8_idx is not None:
        try:
            frontal = filtered_eeg[:, [af7_idx, af8_idx]]
            # Prepending the first row makes the first difference 0 rather than the first value itself.
            frontal_diff = np.diff(frontal, axis=0, prepend=frontal[:1])
            blink_thresh_diff = 50  # 50 uV change in 1 sample (assumes EEG is in microvolts)
            blink_mask = (np.abs(frontal_diff) > blink_thresh_diff).any(axis=1)
            print("Calculated blink artifacts from AF7 and AF8.")
        except Exception as e:
            print(f"Could not calculate blink artifacts: {e}")
    else:
        print("AF7 or AF8 not in channels. Skipping blink detection.")

    artifact_mask = motion_mask | blink_mask

    # Dilate: spread each flagged sample over a ~0.5 s window centred on it.
    dilation_width = max(1, int(sfreq * 0.5))
    artifact_mask_expanded = np.convolve(artifact_mask, np.ones(dilation_width), mode='same') > 0

    return filtered_eeg, artifact_mask_expanded

def extract_eeg_features(filtered_eeg, artifact_mask, sfreq, channel_names, epoch_sec, overlap_sec,
                         bands=None):
    """
    Split cleaned EEG into overlapping epochs and compute per-channel band powers.

    An epoch of `epoch_sec` seconds starts every `epoch_sec - overlap_sec`
    seconds. It is skipped when more than 25% of its samples are flagged in
    `artifact_mask`. For each kept epoch, the Welch PSD of every channel is
    integrated with Simpson's rule over each frequency band.

    Parameters
    ----------
    filtered_eeg : ndarray, shape (n_samples, n_channels)
    artifact_mask : ndarray of bool, shape (n_samples,)
    sfreq : float
        Sampling rate in Hz.
    channel_names : sequence of str
        One name per column of `filtered_eeg`; used as the per-channel key.
    epoch_sec, overlap_sec : float
        Epoch length and overlap between consecutive epochs, in seconds.
    bands : dict of str -> (low_hz, high_hz), optional
        Bands to integrate (inclusive edges). Defaults to delta/theta/alpha/beta/gamma over 1-45 Hz.

    Returns
    -------
    list of dict
        One dict per clean epoch: {'start_time_sec': float, <channel>: {<band>: power}}.

    Raises
    ------
    ValueError
        If an epoch is shorter than one sample or overlap_sec >= epoch_sec
        (zero or negative step between epochs).
    """
    # `simps` was removed in SciPy 1.14; `simpson` (SciPy >= 1.6) replaces it.
    try:
        from scipy.integrate import simpson as integrate_band
    except ImportError:
        from scipy.integrate import simps as integrate_band

    if bands is None:
        bands = {
            'delta': [1, 4],
            'theta': [4, 8],
            'alpha': [8, 12],
            'beta': [12, 30],
            'gamma': [30, 45]
        }

    epoch_samples = int(epoch_sec * sfreq)
    overlap_samples = int(overlap_sec * sfreq)
    step_samples = epoch_samples - overlap_samples
    if epoch_samples < 1 or step_samples < 1:
        raise ValueError(
            f"Invalid epoching: epoch_sec={epoch_sec}, overlap_sec={overlap_sec} gives "
            f"{epoch_samples}-sample epochs with a {step_samples}-sample step."
        )

    num_channels = filtered_eeg.shape[1]
    clean_epochs = []

    for start in range(0, filtered_eeg.shape[0] - epoch_samples + 1, step_samples):
        end = start + epoch_samples

        # --- Artifact Check ---
        # Reject epoch if > 25% of it is an artifact
        if artifact_mask[start:end].mean() > 0.25:
            print(f"Skipping epoch {start / sfreq:.2f}s (too many artifacts)")
            continue

        epoch_data = filtered_eeg[start:end, :]
        epoch_features = {'start_time_sec': start / sfreq}

        for ch_idx in range(num_channels):
            freqs, psd = welch(epoch_data[:, ch_idx], fs=sfreq, nperseg=epoch_samples)
            ch_features = {}

            for band, (low, high) in bands.items():
                band_mask = (freqs >= low) & (freqs <= high)
                # The area over fewer than two frequency bins is zero.
                if np.count_nonzero(band_mask) < 2:
                    band_power = 0.0
                else:
                    band_power = integrate_band(psd[band_mask], x=freqs[band_mask])
                ch_features[str(band)] = band_power

            epoch_features[channel_names[ch_idx]] = ch_features

        clean_epochs.append(epoch_features)

    return clean_epochs

In [None]:
def preprocess_ppg_and_extract_hrv(ppg_data, sfreq, calc_freq=True):
    """
    Run HeartPy on a raw PPG signal to find beats/IBIs and extract HRV measures.

    Parameters
    ----------
    ppg_data : 1-D array-like
        Raw PPG samples.
    sfreq : float
        PPG sampling rate in Hz.
    calc_freq : bool, default True
        Passed to `hp.process`; HeartPy only adds frequency-domain measures
        (including 'lf/hf') when this is True. If that run raises, the
        analysis is retried without frequency-domain measures so time-domain
        results are still returned.

    Returns
    -------
    (working_data, measures) : tuple of dict
        HeartPy's outputs, or ({}, {}) if every attempt failed.
    """
    print(f"Processing PPG data (length {len(ppg_data)} samples) at {sfreq} Hz...")

    working_data = {}
    measures = {}

    try:
        # hp.process performs peak detection, IBI calculation and measure extraction.
        # No band-pass filtering is applied to ppg_data in this function.
        working_data, measures = hp.process(ppg_data, sample_rate=sfreq, calc_freq=calc_freq)
    except Exception as e:
        print(f"HeartPy processing failed: {e}")
        if not calc_freq:
            return working_data, measures
        # Presumably the frequency-domain step is the fragile part on short recordings;
        # retry without it so BPM/RMSSD-style measures survive.
        print("Retrying HeartPy without frequency-domain measures...")
        try:
            working_data, measures = hp.process(ppg_data, sample_rate=sfreq)
        except Exception as e:
            print(f"HeartPy processing failed: {e}")
            return {}, {}

    if 'lf/hf' in measures:
        print(f"Successfully computed HRV. LF/HF Ratio: {measures['lf/hf']:.3f}")
    else:
        print("HRV computed, but 'lf/hf' was not found.")

    return working_data, measures


In [None]:
def get_relative_band_power(band_powers, target_band):
    """
    Fraction of the canonical EEG power that falls in `target_band`.

    The denominator sums only the delta/theta/alpha/beta/gamma entries of
    `band_powers`; any other keys are ignored. Returns 0.0 when that sum is
    zero, and treats a missing `target_band` as zero power.
    """
    canonical = {'delta', 'theta', 'alpha', 'beta', 'gamma'}
    denominator = sum(power for name, power in band_powers.items() if name in canonical)
    return 0.0 if denominator == 0 else band_powers.get(target_band, 0.0) / denominator

def normalize_score(value, min_val, max_val):
    """
    Linearly rescale `value` from [min_val, max_val] onto a 0-100 score.

    The result is clamped to [0, 100] and rounded to two decimals. Passing
    min_val > max_val inverts the scale. Returns None for None or NaN input.
    """
    if value is None:
        return None
    if np.isnan(value):
        return None
    span = max_val - min_val
    fraction = (value - min_val) / span
    return round(100 * max(0.0, min(1.0, fraction)), 2)

def interpret_epoch_insights_v1(epoch_features, hrv_measures, channel_names):
    """
    Turn one epoch's EEG band powers plus HRV measures into 0-100 heuristic scores.

    Parameters
    ----------
    epoch_features : dict
        One epoch as produced by extract_eeg_features: 'start_time_sec' plus
        {channel_name: {band_name: power}}.
    hrv_measures : dict
        HRV measures; only 'lf/hf' is read.
    channel_names : container of str
        Channels present in the recording.

    Returns
    -------
    dict
        {'timestamp': start_time_sec, 'scores': {'stress', 'relaxation', 'focus'}}.
        A score is None when none of its inputs are usable for this epoch.
    """
    final_insights = {
        'timestamp': epoch_features.get('start_time_sec'),
        'scores': {'stress': None, 'relaxation': None, 'focus': None},
    }
    scores = final_insights['scores']

    # --- 1. Stress Score (from PPG): LF/HF mapped from [0.5, 2.5] onto 0-100 ---
    # hrv_measures is not indexed by epoch, so unless the caller passes per-epoch
    # measures this score is the same for every epoch.
    lf_hf_ratio = hrv_measures.get('lf/hf')
    if lf_hf_ratio is not None and not np.isnan(lf_hf_ratio):
        STRESS_MIN_VAL = 0.5
        STRESS_MAX_VAL = 2.5
        scores['stress'] = normalize_score(lf_hf_ratio, STRESS_MIN_VAL, STRESS_MAX_VAL)

    # --- 2. Relaxation Score (from EEG): mean relative alpha over TP9/TP10 ---
    # Only channels that have features in this epoch are averaged; a missing channel
    # would otherwise count as 0% alpha and drag the score down.
    relative_alphas = [
        get_relative_band_power(epoch_features[ch], 'alpha')
        for ch in ('TP9', 'TP10')
        if ch in channel_names and epoch_features.get(ch)
    ]
    if relative_alphas:
        avg_relative_alpha = sum(relative_alphas) / len(relative_alphas)
        RELAX_MIN_VAL = 0.15  # 15%
        RELAX_MAX_VAL = 0.40  # 40%
        scores['relaxation'] = normalize_score(avg_relative_alpha, RELAX_MIN_VAL, RELAX_MAX_VAL)

    # --- 3. Focus Score (from EEG): theta/beta ratio (TBR) over AF7/AF8 ---
    # Channels without positive beta power are skipped. Counting them as TBR 0 would
    # push the average toward 0, i.e. toward maximum focus.
    tbr_values = []
    for ch in ('AF7', 'AF8'):
        if ch not in channel_names:
            continue
        ch_powers = epoch_features.get(ch, {})
        beta = ch_powers.get('beta', 0.0)
        if beta > 0:
            tbr_values.append(ch_powers.get('theta', 0.0) / beta)
    if tbr_values:
        avg_tbr = sum(tbr_values) / len(tbr_values)
        FOCUS_MIN_VAL = 4.0  # High TBR = Low Focus
        FOCUS_MAX_VAL = 1.5  # Low TBR = High Focus
        # min_val > max_val inverts the scale inside normalize_score.
        scores['focus'] = normalize_score(avg_tbr, FOCUS_MIN_VAL, FOCUS_MAX_VAL)

    return final_insights

In [None]:
def mask_to_runs(mask):
    """
    Return the contiguous True runs of a 1-D boolean mask as (start, stop) index pairs.

    `stop` is exclusive. Used to draw one shaded span per artifact run instead of one per sample.
    """
    padded = np.concatenate(([0], np.asarray(mask, dtype=np.int8), [0]))
    edges = np.flatnonzero(np.diff(padded))
    # With False padding on both ends, edges alternate: run start, run stop, run start, ...
    return list(zip(edges[::2].tolist(), edges[1::2].tolist()))


print("\nRunning EEG Pre-processing...")
if 'eeg_data' in locals():  # Check if data loaded successfully
    filtered_eeg, artifact_mask = preprocess_eeg(eeg_data,
                                                 acc_data,
                                                 gyro_data,
                                                 eeg_sfreq,
                                                 channel_names)
    print("EEG Pre-processing complete.")
    print(f"{artifact_mask.sum() / len(artifact_mask) * 100:.2f}% of samples flagged as artifacts.")

    time_axis = np.arange(eeg_data.shape[0]) / eeg_sfreq
    num_channels = eeg_data.shape[1]

    # --- EEG Visualization 1: Raw vs. Filtered ---
    print("\nPlotting Raw vs. Filtered EEG...")
    # squeeze=False always returns a 2-D grid of Axes, so indexing also works for a single channel.
    fig, axes = plt.subplots(num_channels, 1, figsize=(15, 10), sharex=True, squeeze=False)
    axes = axes[:, 0]
    fig.suptitle('Raw vs. Filtered EEG Data', fontsize=16)

    for i, ax in enumerate(axes):
        ch_name = channel_names[i]

        # Mean-centre the raw trace so it shares a baseline with the band-passed (DC-free) trace.
        raw_mean = np.mean(eeg_data[:, i])
        ax.plot(time_axis, eeg_data[:, i] - raw_mean, 'b', alpha=0.5, label='Raw (mean-centered)')
        ax.plot(time_axis, filtered_eeg[:, i], 'r', alpha=0.8, label='Filtered')

        ax.set_ylabel(f"{ch_name}\n(uV)")
        ax.legend(loc='upper right')

    axes[-1].set_xlabel('Time (seconds)')
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

    # --- EEG Visualization 2: Filtered Data with Artifacts ---
    print("\nPlotting Filtered EEG with Artifact Mask...")
    # The mask is shared by all channels: compute its runs once, then draw one span per run.
    artifact_runs = mask_to_runs(artifact_mask)

    fig, axes = plt.subplots(num_channels, 1, figsize=(15, 10), sharex=True, squeeze=False)
    axes = axes[:, 0]
    fig.suptitle('Filtered EEG with Detected Artifacts', fontsize=16)

    for i, ax in enumerate(axes):
        ch_name = channel_names[i]

        ax.plot(time_axis, filtered_eeg[:, i], 'g', label='Filtered Data')
        for run_start, run_stop in artifact_runs:
            ax.axvspan(run_start / eeg_sfreq, run_stop / eeg_sfreq, color='red', alpha=0.2)

        ax.set_ylabel(f"{ch_name}\n(uV)")

        # Add a custom legend entry for the artifact shading
        if i == 0:
            ax.plot([], [], color='red', alpha=0.2, linewidth=10, label='Artifact Detected')
        ax.legend(loc='upper right')

    axes[-1].set_xlabel('Time (seconds)')
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

else:
    print("Skipping EEG processing as data was not loaded.")


# %% [markdown]
# ## 4. Run EEG Feature Extraction
# 
# Now we'll segment the 10-second cleaned signal into smaller epochs (e.g., 2 seconds) and extract the band power features.

# %% [code]
# ==================================
# Cell 7: Run EEG Feature Extraction
# ==================================
if 'filtered_eeg' not in locals():  # pre-processing cell did not produce output
    print("Skipping EEG feature extraction as pre-processing did not run.")
    eeg_features = []  # keep downstream cells working with an empty feature list
else:
    epoch_sec, overlap_sec = 2.0, 1.0

    print(f"\nExtracting EEG features ({epoch_sec}s epochs, {overlap_sec}s overlap)...")

    eeg_features = extract_eeg_features(
        filtered_eeg,
        artifact_mask,
        eeg_sfreq,
        channel_names,
        epoch_sec,
        overlap_sec,
    )

    print(f"\nExtracted features from {len(eeg_features)} clean epochs.")

    if not eeg_features:
        print("No clean epochs were found in this 10-second sample.")
    else:
        print("\n--- Example features from first clean epoch ---")
        print(json.dumps(eeg_features[0], indent=2))


# %% [markdown]
# ## 5. Run PPG Pipeline & Visualize Results
# 
# Next, we process the PPG data. 
# 
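# **CRITICAL NOTE:** Your data sample is only ~10 seconds long. **This is too short for a reliable HRV frequency-domain (LF/HF) analysis.**
# 
# The LF band begins at 0.04 Hz (a 25-second period), so a 10-second window cannot contain even one full LF cycle. Standard guidance recommends about 5 minutes of data for short-term frequency-domain HRV, and 60 seconds is a practical bare minimum. The `heartpy` library will therefore likely fail to calculate the `lf/hf` ratio, or return it as `NaN`. This is **expected behavior**. For a real application, you would buffer at least 60 seconds (ideally several minutes) of PPG data *before* running this analysis.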

# %% [code]
# ==================================
# Cell 8: Run PPG Pipeline
# ==================================
if 'ppg_data' in locals():
    # --- PPG Visualization 1: Raw Data ---
    print("\nPlotting Raw PPG Data...")
    ppg_time_axis = np.arange(ppg_data.shape[0]) / ppg_sfreq

    plt.figure(figsize=(15, 3))
    plt.title("Raw PPG Signal")
    plt.plot(ppg_time_axis, ppg_data, label="Raw PPG")
    plt.xlabel("Time (seconds)")
    plt.ylabel("Amplitude")
    plt.legend()
    plt.show()

    # --- Run PPG Pipeline ---
    ppg_working_data, ppg_measures = preprocess_ppg_and_extract_hrv(ppg_data, ppg_sfreq)

    # --- PPG Visualization 2: HeartPy Plotter ---
    # hp.plotter draws the detected peaks, so require 'peaklist'.
    # NOTE(review): the previous guard tested for a 'filtered' key, which HeartPy's process()
    # does not appear to produce (so the plot was always skipped) -- confirm against the installed version.
    if ppg_working_data and 'peaklist' in ppg_working_data:
        print("\nPlotting HeartPy Processing Results...")
        try:
            hp.plotter(ppg_working_data, ppg_measures, figsize=(15, 6))
            plt.show()
        except Exception as e:
            print(f"HeartPy plotter failed: {e}. (Often due to no peaks found in short signal)")
    else:
        print("HeartPy processing failed, skipping plot.")

    print("\n--- Key PPG/HRV Measures (from 10s sample) ---")
    print(f"  BPM (mean): {ppg_measures.get('bpm', 'N/A')}")
    print(f"  RMSSD (Time Domain HRV): {ppg_measures.get('rmssd', 'N/A')}")
    print(f"  LF/HF Ratio (Freq Domain): {ppg_measures.get('lf/hf', 'N/A')} <-- SEE NOTE ABOVE")
else:
    print("Skipping PPG processing as data was not loaded.")
    ppg_measures = {}  # Define as empty dict so the insights cell still runs


In [None]:
print("\n--- FINAL INSIGHTS (from V1 Engine) ---")
all_insights = []
if 'eeg_features' not in locals() or 'ppg_measures' not in locals():
    print("Skipping insights engine as features were not extracted.")
else:
    # One insight dict per clean EEG epoch; every epoch is scored against the same HRV measures.
    all_insights.extend(
        interpret_epoch_insights_v1(epoch_row, ppg_measures, channel_names)
        for epoch_row in eeg_features
    )
    print(json.dumps(all_insights, indent=2))