In [1]:
"""
Convert .mat file to compact JSON with pre-computed bandpower and YASA analysis
This creates a comprehensive JSON with spectrograms, hypnogram, sleep stats, etc.
"""

import base64
import json
import os
import sys
from io import BytesIO

import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import scipy.io as sio
import yasa
from scipy import signal as sp_signal

def fig_to_base64(fig):
    """Render a matplotlib figure as a base64-encoded PNG string.

    Parameters
    ----------
    fig : object
        Either an object exposing ``savefig`` (a Figure), or an object
        exposing ``get_figure()`` (e.g. an Axes), in which case the figure
        it returns is the one rendered.

    Returns
    -------
    str
        Base64 text of the PNG, saved at 100 dpi with a tight bounding box.

    Notes
    -----
    Only the figure that was rendered is closed afterwards. Closing every
    open figure (``plt.close('all')``) would also destroy unrelated figures
    the notebook user still has open. The close happens in ``finally`` so the
    figure is released even when ``savefig`` raises.
    """
    # Axes-like objects have no ``savefig``; resolve them to their parent figure.
    figure = fig if hasattr(fig, 'savefig') else fig.get_figure()

    buf = BytesIO()
    try:
        figure.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    finally:
        # Release this figure from pyplot's registry to prevent memory build-up
        # when the function is called once per epoch.
        plt.close(figure)

    return base64.b64encode(buf.getvalue()).decode()

def _epoch_relative_bandpower(epoch_uv, sf, ch_names, bands):
    """Return ``{channel: {band: relative power}}`` for one epoch of EEG.

    ``epoch_uv`` is a 2-D array (channels x samples) and ``bands`` is a list
    of ``(low, high, name)`` tuples as accepted by ``yasa.bandpower``.
    """
    bp = yasa.bandpower(epoch_uv, sf=sf, ch_names=ch_names, bands=bands, relative=True)
    # Rows of ``bp`` are addressed by position, in the same order as ``ch_names``.
    return {
        ch_name: {band_name: float(bp[band_name].iloc[ch_idx]) for _, _, band_name in bands}
        for ch_idx, ch_name in enumerate(ch_names)
    }


def _render_epoch_spectrogram(epoch_signal, sf, title):
    """Plot a 0-30 Hz spectrogram of one single-channel epoch; return it as base64 PNG."""
    fig, ax = plt.subplots(figsize=(8, 4))
    # 2 s windows with 1.5 s overlap.
    freqs, times, power = sp_signal.spectrogram(
        epoch_signal, sf, nperseg=int(2 * sf), noverlap=int(1.5 * sf)
    )
    # Limit to 30 Hz
    keep = freqs <= 30
    # The 1e-10 offset keeps log10 finite when a bin has zero power.
    mesh = ax.pcolormesh(times, freqs[keep], 10 * np.log10(power[keep] + 1e-10),
                         shading='auto', cmap='RdBu_r')
    ax.set_ylabel('Frequency (Hz)')
    ax.set_xlabel('Time (s)')
    ax.set_title(title)
    plt.colorbar(mesh, ax=ax, label='Power (dB)')
    return fig_to_base64(fig)


def process_mat_file(input_path, output_path, hypno_path=None, max_epoch_spectrograms=None):
    """Process a .mat file and save a compact JSON with the full YASA analysis.

    The .mat file must contain ``sig1`` .. ``sig4``; they are stacked so that
    axis 1 indexes epochs and the trailing axis is flattened into a continuous
    signal. If the file also contains ``labels``, epochs labelled ``-1`` are
    dropped before anything else is computed, so the "continuous" signal and
    every derived statistic are built from the remaining epochs concatenated
    back to back (they are not necessarily adjacent in time).

    The hypnogram is taken, in order of preference, from:
      1. ``labels`` in the .mat file (remapped to the YASA convention),
      2. the CSV at ``hypno_path`` (first column, truncated to the epoch count),
      3. ``yasa.SleepStaging`` run on the first channel.

    Parameters
    ----------
    input_path : str
        Path of the .mat file to read.
    output_path : str
        Path of the JSON file to write.
    hypno_path : str, optional
        CSV hypnogram, only consulted when the .mat file has no usable labels.
    max_epoch_spectrograms : int, optional
        Upper bound on the number of per-epoch spectrogram images embedded in
        the JSON (taken from the first epochs). ``None`` (default) renders one
        for every epoch, which dominates both run time and output size.

    Returns
    -------
    None
        The result is written to ``output_path``; progress is printed.

    Raises
    ------
    KeyError
        If a required ``sigN`` variable is missing, or a label other than
        -1/0/1/2/3/4 is present in ``labels``.
    """
    print(f"Loading {input_path}...")

    # Load .mat file
    mat = sio.loadmat(input_path)

    # Stack 4 EEG channels -> (4, n_epochs, samples_per_epoch)
    signals = np.stack((mat["sig1"], mat["sig2"], mat["sig3"], mat["sig4"]))

    # Sleep stage mapping from .mat file to YASA standard
    # Original: 0=N3, 1=N2, 2=N1, 3=REM, 4=Wake, -1=Artifact/Unknown
    # YASA:     0=Wake, 1=N1, 2=N2, 3=N3, 4=REM
    STAGE_REMAP = {
        0: 3,  # N3 (Deep) -> N3
        1: 2,  # N2 (Light) -> N2
        2: 1,  # N1 (Light) -> N1
        3: 4,  # REM -> REM
        4: 0   # Wake -> Wake
        # -1: excluded (artifact/unknown)
    }
    stage_labels = ['Wake', 'N1', 'N2', 'N3', 'REM']  # indexed by YASA stage code

    epoch_len = 30  # seconds
    sf = 100  # Hz
    epoch_samples = epoch_len * sf
    n_epochs_raw = signals.shape[1]  # Number of epochs in raw data

    # Check for labels and filter out artifacts
    labels_raw = None
    if 'labels' in mat:
        labels_raw = mat['labels'].flatten()[:n_epochs_raw]

        # Identify valid epochs (not -1)
        valid_mask = labels_raw != -1
        n_artifacts = np.sum(~valid_mask)

        print(f"Found {n_epochs_raw} total epochs in .mat file")
        if n_artifacts > 0:
            print(f"Excluding {n_artifacts} artifact/unknown epochs (label = -1)")

        # Filter signals to keep only valid epochs
        signals = signals[:, valid_mask, :]
        labels_raw = labels_raw[valid_mask]
    else:
        print(f"No labels found in .mat file - will use all {n_epochs_raw} epochs")

    # Reshape into continuous data (channels x samples)
    signals_continuous = signals.reshape(signals.shape[0], -1)

    # Create MNE Raw object
    ch_names = ["F3-C3", "C3-O1", "F4-C4", "C4-O2"]
    info = mne.create_info(ch_names, sf, ch_types=["eeg"] * 4)
    # NOTE(review): the 1e-6 factor presumably converts microvolts to volts
    # for MNE -- confirm the units stored in the .mat file.
    raw = mne.io.RawArray(signals_continuous * 1e-6, info)

    print(f"Processing {raw.n_times / sf / 60:.1f} minutes of valid data...")

    # Fetch the samples once; ``get_data()`` returns a fresh copy on every call.
    eeg_volts = raw.get_data()

    # Define frequency bands
    bands = [
        (0.5, 4, "Delta"),
        (4, 8, "Theta"),
        (8, 12, "Alpha"),
        (12, 16, "Sigma"),
        (16, 30, "Beta"),
        (30, 45, "Gamma")
    ]

    # Calculate bandpower for each epoch
    n_epochs = raw.n_times // epoch_samples

    print(f"Computing bandpower for {n_epochs} epochs...")

    bandpower_list = []
    for i in range(n_epochs):
        if i % 100 == 0:
            print(f"  Processing epoch {i}/{n_epochs}...")

        first, last = i * epoch_samples, (i + 1) * epoch_samples
        # Slice the epoch directly instead of copying and cropping the whole
        # recording each time. The slice is rescaled by 1e6 so the values fed
        # to YASA match what it derives from an MNE Raw object.
        epoch_uv = eeg_volts[:, first:last] * 1e6
        bandpower_list.append({
            'epoch': i,
            'channels': _epoch_relative_bandpower(epoch_uv, sf, ch_names, bands),
        })

    # Load or generate hypnogram
    print("\nProcessing hypnogram...")

    # First check if labels are in the .mat file
    if labels_raw is not None and len(labels_raw) > 0:
        print("  Using labels from .mat file (artifacts already excluded)")

        # Remap to YASA standard
        print("  Remapping sleep stages to YASA standard format...")
        hypno = np.array([STAGE_REMAP[int(label)] for label in labels_raw])

        # Show distribution
        stages_present, stage_counts = np.unique(labels_raw, return_counts=True)
        print("  Original stage distribution (valid epochs only):")
        stage_names_original = {0: 'N3', 1: 'N2', 2: 'N1', 3: 'REM', 4: 'Wake'}
        for stage, count in zip(stages_present, stage_counts):
            print(f"    {stage_names_original.get(int(stage), 'Unknown')}: {count} epochs")

    elif hypno_path and os.path.exists(hypno_path):
        print(f"  Loading hypnogram from {hypno_path}")
        # ``read_csv(squeeze=True)`` was removed in pandas 2.0; squeezing the
        # resulting frame is the documented replacement.
        hypno = pd.read_csv(hypno_path).squeeze("columns")
        if isinstance(hypno, pd.DataFrame):
            hypno = hypno.iloc[:, 0]
        hypno = hypno.values[:n_epochs]  # Ensure same length
    else:
        print("  No hypnogram provided. Using automatic sleep staging...")
        print("  This may take a few minutes...")
        sls = yasa.SleepStaging(raw, eeg_name=ch_names[0])
        hypno_pred = sls.predict()
        hypno = yasa.hypno_str_to_int(hypno_pred)
        print("  Automatic staging complete!")

    # Upsample hypnogram (one value per epoch) to match data
    hypno_up = yasa.hypno_upsample_to_data(hypno, sf_hypno=1 / epoch_len, data=raw)

    # Generate full-night spectrogram (first channel only)
    print("\nGenerating full-night spectrogram...")
    fig = yasa.plot_spectrogram(eeg_volts[0], sf, hypno_up, win_sec=30, fmax=30)
    spectrogram_fullnight = fig_to_base64(fig)

    # Generate hypnogram plot
    print("Generating hypnogram plot...")
    fig = yasa.plot_hypnogram(hypno)
    hypnogram_plot = fig_to_base64(fig)

    # Calculate sleep statistics
    print("Calculating sleep statistics...")
    sleep_stats = yasa.sleep_statistics(hypno, sf_hyp=1 / epoch_len)
    # Convert numpy types to native Python types so json can serialise them
    sleep_stats = {k: float(v) if isinstance(v, (np.integer, np.floating)) else v
                   for k, v in sleep_stats.items()}

    # Calculate transition matrix (only the probabilities are exported)
    print("Calculating stage-transition matrix...")
    _, transition_probs = yasa.transition_matrix(hypno)
    transition_matrix = {
        'labels': transition_probs.index.tolist(),
        'values': transition_probs.values.tolist()
    }

    # Generate epoch spectrograms (first channel), optionally capped
    print("Generating epoch spectrograms...")
    if max_epoch_spectrograms is None:
        n_spectrograms = n_epochs
    else:
        n_spectrograms = max(0, min(n_epochs, max_epoch_spectrograms))

    epoch_spectrograms = {}
    for i in range(n_spectrograms):
        stage_name = stage_labels[int(hypno[i])]
        first, last = i * epoch_samples, (i + 1) * epoch_samples
        epoch_spectrograms[str(i)] = _render_epoch_spectrogram(
            eeg_volts[0, first:last], sf, f'Epoch {i} - Stage: {stage_name}'
        )

    # Store hypnogram
    hypno_data = {
        'values': hypno.tolist(),
        'labels': list(stage_labels)
    }

    # Create output structure
    output = {
        'n_epochs': n_epochs,
        'channels': ch_names,
        'bands': [band_name for _, _, band_name in bands],
        'sampling_rate': sf,
        'epoch_length': epoch_len,
        'data': bandpower_list,
        'hypnogram': hypno_data,
        'hypnogram_plot': hypnogram_plot,
        'spectrogram_fullnight': spectrogram_fullnight,
        'epoch_spectrograms': epoch_spectrograms,
        'sleep_statistics': sleep_stats,
        'transition_matrix': transition_matrix
    }

    # Save to JSON
    print(f"\nSaving to {output_path}...")
    with open(output_path, 'w') as f:
        json.dump(output, f)

    # Get file size
    file_size = os.path.getsize(output_path)
    print(f"Done! Output file size: {file_size / (1024*1024):.1f} MB")
    print(f"\nSummary:")
    print(f"  - {n_epochs} valid epochs processed")

    # Show artifact info if applicable
    if 'labels' in mat:
        labels_all = mat['labels'].flatten()
        n_artifacts = np.sum(labels_all == -1)
        n_total = len(labels_all)
        if n_artifacts > 0:
            print(f"  - {n_artifacts} artifact epochs excluded (label = -1)")
            print(f"  - Total epochs in file: {n_total} ({n_epochs} valid + {n_artifacts} artifacts)")

    # Determine hypnogram source
    if labels_raw is not None:
        print(f"  - Hypnogram: extracted from .mat file and remapped")
    elif hypno_path:
        print(f"  - Hypnogram: loaded from {hypno_path}")
    else:
        print(f"  - Hypnogram: auto-generated using YASA ML model")

    print(f"  - Full-night spectrogram: ✓")
    print(f"  - {len(epoch_spectrograms)} epoch spectrograms: ✓")
    print(f"  - Sleep statistics: ✓")
    print(f"  - Transition matrix: ✓")


In [6]:
# Convert one recording. The output name is derived by swapping only the file
# extension: a substring replace of ".mat" would also rewrite a matching
# directory name, and would leave the path unchanged (so the input itself would
# be overwritten) when the file does not end in ".mat".
input_file = "../../preprocessed_data_201_N1.mat"
output_file = os.path.splitext(input_file)[0] + "_full.json"
process_mat_file(input_file, output_file)

Loading ../../preprocessed_data_201_N1.mat...
Found 1487 total epochs in .mat file
Excluding 422 artifact/unknown epochs (label = -1)
Creating RawArray with float64 data, n_channels=4, n_times=3195000
    Range : 0 ... 3194999 =      0.000 ... 31949.990 secs
Ready.
Processing 532.5 minutes of valid data...
Computing bandpower for 1065 epochs...
  Processing epoch 0/1065...
  Processing epoch 100/1065...
  Processing epoch 200/1065...
  Processing epoch 300/1065...
  Processing epoch 400/1065...
  Processing epoch 500/1065...
  Processing epoch 600/1065...
  Processing epoch 700/1065...
  Processing epoch 800/1065...
  Processing epoch 900/1065...
  Processing epoch 1000/1065...

Processing hypnogram...
  Using labels from .mat file (artifacts already excluded)
  Remapping sleep stages to YASA standard format...
  Original stage distribution (valid epochs only):
    N3: 7 epochs
    N2: 668 epochs
    N1: 57 epochs
    REM: 191 epochs
    Wake: 142 epochs

Generating full-night spectrog

  freq_str = pd.tseries.frequencies.to_offset(pd.Timedelta(1 / sf_hypno, "S")).freqstr



Saving to ../../preprocessed_data_201_N1_full.json...
Done! Output file size: 35.7 MB

Summary:
  - 1065 valid epochs processed
  - 422 artifact epochs excluded (label = -1)
  - Total epochs in file: 1487 (1065 valid + 422 artifacts)
  - Hypnogram: extracted from .mat file and remapped
  - Full-night spectrogram: ✓
  - 1065 epoch spectrograms: ✓
  - Sleep statistics: ✓
  - Transition matrix: ✓
