# Estimate sleep State from Video

In [None]:
import gc
import glob
from pathlib import Path
import numpy as np
import pandas as pd
import scipy.stats as stats
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
#For Arial font
#!conda install -c conda-forge -y mscorefonts
##-> The below was also needed in matplotlib 3.4.2
#import shutil
#import matplotlib
#shutil.rmtree(matplotlib.get_cachedir())
import warnings
warnings.filterwarnings('ignore')
from IPython.display import display
import time
#For exporting .pdf file with editable text
import matplotlib
from scipy.stats import mannwhitneyu
from statsmodels.stats.multitest import multipletests
import scikit_posthocs as sp

matplotlib.rcParams['pdf.fonttype']=42
matplotlib.rcParams['ps.fonttype']=42
import sys
from decimal import Decimal, ROUND_HALF_UP
import re
import os
from statsmodels.stats.multicomp import pairwise_tukeyhsd

#Create a new conda environment as follows; otherwise, an issue happened (probably due to pip vs. conda)
#!conda create --channel=conda-forge --strict-channel-priority --name=mne-py3 -y mne
##–> Then, use the mne-py3 kernel
import mne
from mne.preprocessing import ICA, create_ecg_epochs, find_ecg_events
from mne.preprocessing import compute_proj_ecg

import scipy.signal as signal

#!pip install meegkit
#import meegkit

#!pip install spkit
#import spkit

#!conda install -c conda-forge -y gwpy
##To avoid the plotting error induced by importing gwpy, matplotlib must be downgraded to 3.2.2?
#!conda install -c conda-forge -y matplotlib==3.2.2
#!pip install matplotlib==3.2.2
#import gwpy.timeseries as gpts

#!pip install autoreject
#import autoreject

#!pip install colorcet
##-> To avoid "AttributeError: `np.mat` was removed in the NumPy 2.0 release.", use the previous version
#!pip install numpy==1.26.4
from preraulab.multitaper_toolbox.python import multitaper_spectrogram_python as ms



> **Warning**  
> This project has two incompatible dependencies:  
> - MNE-Python requires **Python 3.10 or later**. 
> - sleepens requires **Python 3.7**.
>  
> Because these requirements conflict, they cannot be satisfied in a single environment.


# 1. Prepare data as the mne.io.Raw object
memo  
以下の内容は予測に過ぎない
- .tsp: 動画のフレームの秒数
- .dat: eeg, emgが保存されているであろうデータ
- .ini, .meta: eeg, emgを保存した際の設定関係の値
  
---

.meta
> Number of recorded channels = 12

.ini 
> aiString=0:7  
> ChanList4=0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9, 11, -1, -1, -1, ... , -1  
> doCtlChan=0  
> doSettleChan=2  
> doRECLEDChan=4  
> acqPDChan=4  


> 1,2 ch 筋電  
> 3, 4 ch 頭蓋脳波  
> 5, 6 ch 海馬 (HPC)  
> 7, 8 ch 一次体性感覚野（S1）  
> 9, 10, 11, 12 ch なし（空）
---

sleep ens
>ブレグマから AP +1 mm、ML ±1 mm、およびブレグマから AP -2 mm、ML ±3 mm であった。
>前頭骨と頭頂骨の両側に移植した（ブレグマから前方1.5 mm、外側±1.5 mm、後方2 mm、外側±2.75 mm）
>
## 1-1. load config



In [None]:
def load_meta(path):
    """Parse a ``key = value`` text file into a dictionary.

    Each line is stripped and split at the first occurrence of ``" = "``;
    lines that are empty or do not contain that separator are skipped.
    Values that ``float()`` accepts are stored as floats, everything else is
    kept as the stripped string. Undecodable bytes in the file are ignored.

    Parameters
    ----------
    path : str or path-like
        Location of the metadata file.

    Returns
    -------
    dict
        Mapping from stripped key to float or string value.
    """
    separator = " = "
    parsed = {}
    with open(path, encoding="utf-8", errors="ignore") as handle:
        for raw_line in handle:
            name, found, text = raw_line.strip().partition(separator)
            if not found:
                # Blank line or a line without the separator.
                continue
            name, text = name.strip(), text.strip()
            try:
                parsed[name] = float(text)
            except ValueError:
                # Not numeric: keep the text as-is.
                parsed[name] = text
    return parsed


In [None]:
# Recording session folder and the .meta file that describes the acquisition.
DATA_FOLDER = Path("/home/data/sleepStateExperiment_from_video/")
meta_path = DATA_FOLDER / "20250917-001.meta"

meta = load_meta(meta_path)

# Name of the recorded data file and its sampling rate, as stored in the .meta file.
FILE_PATH = Path(meta["Filename"])
SAMPLING_RATE = int(meta["Sampling rate"])

# Recording start/end time stamps (computer clock, ms).
START_TS = meta["TimeStamp of the start of recording (computer clock - ms)"]
END_TS = meta["TimeStamp of the end of recording (computer clock - ms)"]


## 1-2. load data

In [None]:
# Strict loader: read the whole .dat file and interpret it as frames of
# N_CHANNELS interleaved samples (one row per frame after the reshape).
dat_path = Path(DATA_FOLDER, FILE_PATH.name)

N_CHANNELS = 12
# Little-endian signed 16-bit samples.
# NOTE(review): an earlier comment described the samples as uint16; switch to
# '<u2' if the recorder actually writes unsigned values.
dtype = np.dtype('<i2')

array = np.fromfile(dat_path, dtype=dtype)
if array.size % N_CHANNELS != 0:
    # An incomplete trailing frame means the reshape below cannot work.
    raise ValueError(
        f"{dat_path}: sample count {array.size} is not a multiple of "
        f"N_CHANNELS={N_CHANNELS}"
    )

n_frames = array.size // N_CHANNELS
data = array.reshape(n_frames, N_CHANNELS)

# Keep only the first 8 channels; the remaining 4 are unused.
data = data[:, :8]
N_CHANNELS -= 4




In [None]:
# Lenient loader: use this cell when the strict loader rejects the file because
# its sample count is not a multiple of N_CHANNELS. The incomplete trailing
# frame is dropped before reshaping.
dat_path = Path(DATA_FOLDER, FILE_PATH.name)

N_CHANNELS = 12
# Little-endian signed 16-bit samples.
# NOTE(review): an earlier comment described the samples as uint16 - confirm.
dtype = np.dtype('<i2')

array = np.fromfile(dat_path, dtype=dtype)
n_frames = array.size // N_CHANNELS
array = array[: n_frames * N_CHANNELS]  # whole frames only
data = array.reshape(n_frames, N_CHANNELS)

# Keep only the first 8 channels; the remaining 4 are unused.
data = data[:, :8]
N_CHANNELS -= 4




In [None]:
# Sanity check: shape of the loaded array after dropping the unused channels.
display(data.shape)

## 1-3. convert to V

In [None]:
# The array is too large to convert in a single expression such as
#   data = data.astype(np.float16)/(1 << 15) * 5 / (10 ** 3)
# so the conversion is done in chunks of rows.
#
# Each chunk is scaled in float32 and only rounded to float16 when stored:
# casting int16 straight to float16 rounds every count above 2048 before the
# scaling is even applied.
# TODO: is it better to store data_v as float32 as well? (doubles its memory)
chunk = 50_000_000  # rows (frames) converted per iteration
scale = np.float16(1 << 15)
# Same factor as "/ scale * 5 / (10 ** 3)", folded into one float32 constant.
# NOTE(review): presumably full scale (2**15 counts) corresponds to 5 V before
# a x1000 gain - confirm against the acquisition settings.
counts_to_volts = np.float32(5 / (10 ** 3)) / np.float32(scale)
data_v = np.empty_like(data, dtype=np.float16)
n_rows = data.shape[0]  # slicing below is by row, so iterate over rows (not data.size)
for start in range(0, n_rows, chunk):
    end = min(start + chunk, n_rows)
    tmp_data = data[start:end].astype(np.float32)
    tmp_data *= counts_to_volts  # in place: no second temporary
    data_v[start:end] = tmp_data  # rounded to float16 once, on assignment

In [None]:
# Quick visual check: plot rows 1,500,000-2,000,000 of every channel,
# one figure per channel.
tmp = data_v[1500000:2000000]
for i in range(8):
    plt.figure(figsize=(10, 3))
    plt.title(f"Channel {i}")
    plt.xlabel("Index")
    plt.ylabel("Value")
    plt.plot(tmp[:, i])
    plt.tight_layout()
    plt.show()

## 1-4 create mne.io.Raw object

In [None]:
# Channel layout of the 8 kept channels, in file order.
CHANNEL_NAMES = [
    "EMG-1", "EMG-2",
    "Skull-1", "Skull-2",
    "HPC-1", "HPC-2",
    "S1-1", "S1-2",
]
# The two EMG channels are typed "emg"; every other channel is typed "eeg".
CHANNEL_TYPES = ["emg" if name.startswith("EMG") else "eeg" for name in CHANNEL_NAMES]
bad_channels = []

# Build the measurement info and register the (currently empty) bad-channel list.
mneInfo = mne.create_info(
    ch_names=CHANNEL_NAMES,
    sfreq=SAMPLING_RATE,
    ch_types=CHANNEL_TYPES,
)
mneInfo["bads"] = bad_channels

display(mneInfo)

In [None]:
# Prepare the array for mne.io.RawArray, which is fed (n_channels, n_times)
# data in the cells below.
print('Original np.array:', data_v.shape)

# Transpose, then take a full slice (no subject/channel selection is applied).
tempA = data_v.T[:]
print('-> After transpose and selection:', tempA.shape)

# No unit conversion here: data_v was already converted to volts above.
# tempA = tempA / (10**6)#From uV to V



In [None]:
# # with lots of memory, we can use this code
# # or no split needed if memory is sufficient.

# # #Create an mne.io.Raw 
# MNE_DIR = Path(DATA_FOLDER, "mneRaw")
# n_t = tempA.shape[1]

# # separate with (55 minutes 30 sec) * 3 . this value almost corresponds to 2GB(mne limit) * 3
# # this separation is due to memory crash
# seg_len = SAMPLING_RATE * (60 * 55 + 30) * 3
# seg_len = min(seg_len, n_t)

# raw_written = False
# for part, start in enumerate(range(0, n_t, seg_len)):
#     end = min(start + seg_len, n_t)

#     seg = np.ascontiguousarray(tempA[:, start:end], dtype=np.float32)

#     raw = mne.io.RawArray(seg, mneInfo, first_samp=start)

#     # f_path = Path(MNE_DIR, f"{FILE_PATH.stem}{f'-{part}' if part > 0 else ''}.fif")
#     f_path = Path(MNE_DIR, f"{FILE_PATH.stem}-{part}.fif")
#     raw.save(f_path, overwrite=True)
#     print(f"Saved {f_path}")

#     del raw, seg
#     gc.collect()

In [None]:


def rows_for_2gb_segment(
    n_ch: int,
    sampling_rate: int,
    write_dtype: np.dtype = np.dtype(np.float32),
    target_bytes: int = 2 * 1024**3,
    header_margin: int = 16 * 1024**2,
    round_sec: int = 5,
) -> int:
    """Return how many samples per channel fit in one size-limited file.

    The result is the largest multiple of ``round_sec`` seconds of data whose
    raw payload (``n_ch`` values of ``write_dtype`` per sample) stays within
    ``target_bytes - header_margin``.

    Parameters
    ----------
    n_ch : int
        Number of channels written per sample.
    sampling_rate : int
        Samples per second.
    write_dtype : dtype-like
        On-disk sample type; anything ``np.dtype()`` accepts
        (``np.dtype("float32")``, ``np.float32``, ``"float64"``, ...).
    target_bytes : int
        Maximum file size.
    header_margin : int
        Bytes reserved for header/metadata overhead.
    round_sec : int
        Segment length is rounded down to a multiple of this many seconds.

    Returns
    -------
    int
        Number of samples (per channel) per segment; always positive.

    Raises
    ------
    ValueError
        If ``n_ch`` or ``sampling_rate * round_sec`` is not positive, or if
        not even one ``round_sec`` block fits in the size budget (a zero
        length would otherwise break ``range(..., step=0)`` in the caller).
    """
    rows_per_block = int(sampling_rate * round_sec)
    if n_ch <= 0 or rows_per_block <= 0:
        raise ValueError(
            "n_ch and sampling_rate * round_sec must be positive "
            f"(got n_ch={n_ch}, sampling_rate={sampling_rate}, round_sec={round_sec})"
        )

    # Bytes needed for one block of round_sec seconds across all channels.
    bytes_per_row = n_ch * np.dtype(write_dtype).itemsize
    bytes_per_block = bytes_per_row * rows_per_block

    # Effective size budget (avoid header/metadata overhead).
    effective_max = target_bytes - header_margin

    max_rows = (effective_max // bytes_per_block) * rows_per_block
    if max_rows <= 0:
        raise ValueError(
            f"size budget of {effective_max} bytes cannot hold one "
            f"{round_sec}-second block ({bytes_per_block} bytes)"
        )
    return max_rows


WRITE_DTYPE = np.float32  # recommended by mne
MNE_DIR = Path(DATA_FOLDER, "mneRaw")
# raw.save() does not create missing folders, so make sure the target exists.
MNE_DIR.mkdir(parents=True, exist_ok=True)
n_ch, n_t = tempA.shape

# Samples per output file, sized for the dtype that is actually written.
seg_len = rows_for_2gb_segment(n_ch, SAMPLING_RATE, write_dtype=np.dtype(WRITE_DTYPE))

# Save the recording as consecutive segments: <stem>-0.fif, <stem>-1.fif, ...
for part, start in enumerate(range(0, n_t, seg_len)):
    end = min(start + seg_len, n_t)

    # Copy only the current time window, as a contiguous array of WRITE_DTYPE.
    seg = np.ascontiguousarray(tempA[:, start:end], dtype=WRITE_DTYPE)

    # first_samp records where this segment starts within the full recording.
    raw = mne.io.RawArray(seg, mneInfo, first_samp=start)

    f_path = Path(MNE_DIR, f"{FILE_PATH.stem}-{part}.fif")
    raw.save(f_path, overwrite=True, fmt="single") # single means float32. defaults to single
    print(f"Saved {f_path}  |  samples={seg.shape[1]}  seconds≈{seg.shape[1]/SAMPLING_RATE:.1f}")

    # Release the segment before building the next one to keep peak memory down.
    del raw, seg
    gc.collect()


# 2. Preprocessing

In [None]:
# Load one saved segment for preprocessing.
# Alternatives kept for reference:
#   - the unsplit file: Path(MNE_DIR, FILE_PATH.with_suffix(".fif").name), preload=False
#   - another segment:  Path(DATA_FOLDER, "mneRaw", "20250917-001-0.fif")
# To develop faster, use a small file.
segment_path = Path(DATA_FOLDER) / "mneRaw" / "20250917-001-3.fif"
mneR = mne.io.read_raw_fif(segment_path)

# Keep a copy of the measurement info for the RawArray objects built later.
mneI = mneR.info.copy()
mneR


In [None]:
def display_psdPlot(legends, raws, f_range=(0, 200), t_range=(600, 610)):
    """Plot the multitaper power spectral density of each recording.

    One figure is produced per ``(legend, raw)`` pair. Each spectrum is
    computed from the corresponding item of ``raws`` (previously the global
    ``mneR`` was used for every item, so all figures were identical).

    Parameters
    ----------
    legends : sequence of str
        Figure titles, paired with ``raws`` by position.
    raws : sequence of mne.io.Raw
        Recordings to analyse.
    f_range : tuple (fmin, fmax), optional
        Frequency range in Hz. Default (0, 200).
    t_range : tuple (tmin, tmax), optional
        Time window in seconds used for the estimate. Default (600, 610).
    """
    f_min, f_max = f_range  # Target frequency [Hz]
    t_min, t_max = t_range  # Target time range [sec]
    picks = ["eeg", "emg"]
    sns.set(style='ticks', context='notebook')
    for title, raw in zip(legends, raws):
        psd = raw.compute_psd(method='multitaper', fmin=f_min, fmax=f_max, tmin=t_min, tmax=t_max, picks=picks)
        fig = psd.plot(picks=picks, exclude="bads", amplitude=False, average=False, dB=True)
        fig.suptitle(title, size='xx-large', weight='bold')
    # scalling = dict(eeg=500e-3, emg=500e-3)
    # for data in raws:
    #     fig = data.plot(duration=1, n_channels=8, scalings=scalling)


In [None]:
# PSD of the untouched recording.
display_psdPlot(legends=["Original"], raws=[mneR])

In [None]:
def display_waveforms(raw, title_text="", span=(0, -1), unit="sec"):
    """
    Display waveform traces for each channel in an MNE Raw object.

    Only the requested time window is read from ``raw`` (via
    ``raw.get_data(start=..., stop=...)``), so a short span of a long,
    non-preloaded recording does not load the whole file.

    Parameters
    ----------
    raw : mne.io.Raw
        MNE Raw with time-series data.
    title_text : str, optional
        Title displayed on the first subplot.
    span : tuple (start, end), optional
        Time range to plot. Interpreted in seconds if unit="sec",
        or in minutes if unit="min". ``None`` as start means 0; ``None`` or
        -1 as end means "until the last sample". Default (0, -1) is the
        full range. The window is clipped to the recording.
    unit : {"sec", "min"}, optional
        Controls x-axis conversion and downsampling policy.
        - "sec": no downsampling (native sampling)
        - "min": keep every 60th sample (plain decimation, no anti-alias
          filtering) and show the x-axis in minutes

    Raises
    ------
    ValueError
        If ``unit`` is neither "sec" nor "min".
    """

    # -------- Settings & inputs --------
    if unit not in {"sec", "min"}:
        raise ValueError("unit must be one of {'sec', 'min'}")

    sfreq = float(raw.info["sfreq"])  # samples per second
    n_times = raw.n_times
    n_ch = raw.info["nchan"]
    last_time = (n_times - 1) / sfreq  # time of the final sample [sec]
    sec_per_unit = 60.0 if unit == "min" else 1.0

    # -------- Resolve time window (always in seconds) --------
    start, end = span
    start_sec = 0.0 if start is None else float(start) * sec_per_unit
    to_end = end is None or end == -1
    end_sec = last_time if to_end else float(end) * sec_per_unit

    # Clip to the recording.
    start_sec = max(start_sec, 0.0)
    end_sec = min(end_sec, last_time)
    to_end = to_end or end_sec >= last_time

    # Convert to sample indices (half-open [start_idx, stop_idx)).
    # When the window runs to the end, use n_times directly so the final
    # sample is included instead of being lost to truncation.
    start_idx = int(start_sec * sfreq)
    stop_idx = n_times if to_end else int(end_sec * sfreq)
    stop_idx = max(stop_idx, start_idx)

    # -------- Load only the requested window --------
    if stop_idx > start_idx:
        data = raw.get_data(start=start_idx, stop=stop_idx)  # shape: (n_ch, n_times)
    else:
        # Empty window: plot empty axes rather than failing.
        data = np.empty((n_ch, 0))
    times_sec = np.arange(start_idx, stop_idx) / sfreq

    # -------- Unit-dependent x-axis conversion & downsampling --------
    if unit == "min":
        step = 60  # keep every 60th sample
        data = data[:, ::step]
        times = times_sec[::step] / 60.0  # seconds -> minutes for x-axis
        x_label = 'Time [min]'
    else:
        # "sec": no downsampling, x-axis remains in seconds
        times = times_sec
        x_label = 'Time [sec]'

    # -------- Plot --------
    sns.set(style='ticks', context='notebook')

    fig, axes = plt.subplots(
        nrows=n_ch, ncols=1,
        figsize=(10, max(2, n_ch / 2)),
        sharex=True, sharey=True,
        gridspec_kw={'height_ratios': [1] * n_ch, 'hspace': 0.0}
    )

    # plt.subplots returns a bare Axes (not an array) for a single row.
    if n_ch == 1:
        axes = [axes]

    # Axis limits (based on converted times and sliced data); all channels
    # share one y-range.
    xlim = (times[0], times[-1]) if times.size else (0, 0)
    ylim = (data.min() if data.size else -1, data.max() if data.size else 1)
    plt.setp(axes, xlim=xlim, ylim=ylim)

    for ax_i, ax in enumerate(axes):
        # Plot waveform (x is in sec or min depending on unit)
        ax.plot(times, data[ax_i], color='k', linewidth=0.25)

        # Clean axes
        ax.spines.top.set_visible(False)
        ax.spines.right.set_visible(False)
        plt.setp(ax.get_yticklabels(), visible=False)
        ax.set_yticks([])

        # Only the bottom subplot keeps its x-axis.
        if ax_i == n_ch - 1:
            ax.set_xlabel(x_label)
        else:
            ax.spines.bottom.set_visible(False)
            ax.get_xaxis().set_visible(False)

        # Channel name as a horizontal, right-aligned y-label
        ax.set_ylabel(
            raw.ch_names[ax_i],
            rotation=0,
            horizontalalignment='right',
            verticalalignment='center',
            rotation_mode='anchor'
        )

        # Title on the first subplot
        if ax_i == 0 and title_text:
            ax.set_title(title_text)

    plt.show()


In [None]:
# Full-range overview of the loaded segment (decimated, x-axis in minutes).
# The subject/condition/trial fields are still placeholders.
title_text = f'Subject: {None}, Condition: {None}, Trial: {None} — LFP waveform (full range)'
display_waveforms(raw=mneR, title_text=title_text, unit="min")

## 2-1. Detrending
Detrending can be interpreted as subtracting a least-squares polynomial fit.  


This part is also included in sleepens; see the [`process` function](https://github.com/paradoxysm/sleepens/blob/master/sleepens/protocols/sleepens4/processor.py).

### 2-1-1. Detrending by myself

In [None]:
%%time
print('Before:')
# describe() prints its table itself and returns None, so it is called
# directly (wrapping it in display() only added a stray "None").
mneR.describe()

# Remove a piecewise-linear trend along time, with a break point every
# SAMPLING_RATE samples.
tempA = signal.detrend(mneR.get_data().T, axis=0, type='linear', bp=np.arange(0, mneR.n_times, SAMPLING_RATE))
mneR_dt = mne.io.RawArray(tempA.T, mneI)
print('After:')
mneR_dt.describe()

# NOTE(review): the earlier "Changed unit: uV" remark looks stale - nothing in
# this cell rescales the data.
display(mneR_dt.get_data())


In [None]:
# Overview of the detrended signal (decimated, x-axis in minutes).
display_waveforms(raw=mneR_dt, unit="min")

In [None]:
# Compare spectra before and after detrending.
display_psdPlot(legends=["Raw", "Detrended"], raws=[mneR, mneR_dt])

### 2-1-2. Detrending by sleepens
this part is in `python/sleepens/detrend.py`

In [None]:
# Load the same segment from the "mneRaw_detrend" folder.
detrended_path = Path(DATA_FOLDER) / "mneRaw_detrend" / "20250917-001-3.fif"
mneR_temp = mne.io.read_raw_fif(detrended_path)

In [None]:
# Overview of the sleepens-detrended signal (decimated, x-axis in minutes).
display_waveforms(raw=mneR_temp, unit="min")

In [None]:
# Compare spectra of the raw segment and the sleepens-detrended one.
display_psdPlot(legends=["Raw", "Detrended"], raws=[mneR, mneR_temp])


## 2-2. Remove DC component(i.e., 0 Hz noise)

In [None]:
%%time
# Calculate the DC component (mean over all samples) of each channel.
# NOTE(review): this rebinds the notebook-level name `data`, which earlier
# cells use for the raw int16 array - re-run the loading cells before
# repeating the conversion steps.
data = mneR_dt.get_data()
tempS = data.mean(axis=1, keepdims=True)
print('DC components:')
display(tempS)

# Subtract them from the detrended data
print('Before:')
# describe() prints its table itself and returns None, so it is called
# directly (wrapping it in display() only added a stray "None").
mneR_dt.describe()

data -= tempS

# Re-create an mne.io.Raw object
mneR_dc = mne.io.RawArray(data, mneI, first_samp=0, copy='auto')

print('After:')
mneR_dc.describe()
# NOTE(review): the earlier "Changed unit: uV" remark looks stale - this cell
# only subtracts the mean and does not rescale the data.
display(mneR_dc.get_data())

In [None]:
# DC-removed signal at native sampling (x-axis in seconds).
display_waveforms(raw=mneR_dc)

In [None]:
# Compare spectra before and after DC removal.
display_psdPlot(legends=["Detrended", "DC-removed"], raws=[mneR_dt, mneR_dc])

## ~~2-3. Eliminate the power line noises using MEEGkit~~

Utilize a Denoising Source Separation (DSS) technique: Zapline paper (de Cheveigné, A. Neuroimage 2020).  
> meegkit.dss.dss_line_iter():  
> Remove power line artifact iteratively. This method applies dss_line() until the artifact has been smoothed out from the spectrum.  
> Parameters  
> * data : data, shape=(n_samples, n_chans, n_trials)  
Input data.  
> * fline : float  
> Line frequency.  
> * sfreq : float  
> Sampling frequency.  
> * win_sz : float  
> Half of the width of the window around the target frequency used to fit the polynomial (default=10).  
> * spot_sz : float  
> Half of the width of the window around the target frequency used to remove the peak and interpolate (default=2.5).  
> * nfft : int  
> FFT size for the internal PSD calculation (default=512).  
> * show: bool  
> Produce a visual output of each iteration (default=False).  
> * prefix : str  
> Path and first part of the visualisation output file "{prefix}\_{iteration number}.png" (default="dss_iter").  
> * n_iter_max : int  
> Maximum number of iterations (default=100).  
>
> Returns  
> * data : array, shape=(n_samples, n_chans, n_trials)  
> Denoised data.  
> * iterations : int  
> Number of iterations.  

To increase the computational speed, the overall data is split into dummy windows/epochs/trials while discarding the last remainder samples. Note that the splitting method doesn't matter because the power line noises were constantly observed during recordings.  

> ***–> Skip in this version!***  

## ~~2-4. ATAR algorithm~~
Automatic and Tunable Artifact Removal (ATAR) algorithm (Bajaj, N. et al. Biomed. Signal Process. Control 2020) is used to remove artifacts. This algorithm is based on wavelet packet decomposition (WPD), and provided by the Python spkit package (https://spkit.github.io/guide/index.html).  
> - Docs: https://spkit.readthedocs.io/en/latest/index.html  
> - Post about ICA vs. ATAR: https://medium.com/@nikeshbajaj/artifacts-in-eeg-and-how-to-remove-them-atar-algorithm-ica-fbb91ea8485a  
> - GitHub: https://github.com/Nikeshbajaj/spkit  

    ATAR_mCh  
    ========================================================
    Apply ATAR on short windows of signal (multiple channels:):

    Signal is decomposed in smaller overlapping windows and reconstructed after correcting using overlap-add method.
    ----------------

    input
    -----
    X: input multi-channel signal of shape (n,ch)

    Wavelet family:
    wv = ['db3'.....'db38', 'sym2.....sym20', 'coif1.....coif17', 'bior1.1....bior6.8', 'rbio1.1...rbio6.8', 'dmey']
         :'db3'(default)

    Threshold Computation method:
    thr_method : None (default), 'ipr'
           : None: fixed threshold theta_a is applied
           : ipr : applied with theta_a, bf , gf, beta, k1, k2 and OptMode
           : theta_b = bf*theta_a
           : theta_g = gf*theta_a

    Operating modes:
    OptMode = ['soft','elim','linAtten']
             : default 'soft'
             : use 'elim' with globalgood

    Wavelet Decomposition modes:
    wpd_mode = ['zero', 'constant', 'symmetric', 'periodic', 'smooth', 'periodization']
                default 'symmetric'

    Reconstruction Method - Overlap-Add method
    ReconMethod :  None, 'custom', 'HamWin'
    for 'custom': window[0] is used and applied after denoising is window[1] is True else
    windowing applied before denoising

    output
    ------
    XR: corrected signal of same shape as input X
    '''

> ***–> Skip in this version!*** 

## 2-5. Notch filter
remove AC noise(50Hz)

In [None]:
print("Before: ")
# describe() prints its table itself and returns None, so it is called
# directly (wrapping it in display() only added a stray "None").
mneR_dc.describe()
# notch_filter() modifies the object it is called on and returns it, so work
# on a copy; otherwise mneR_notch and mneR_dc are the same (filtered) object
# and the before/after comparisons below show identical data.
mneR_notch = mneR_dc.copy().notch_filter(
    freqs=50,                  # fixed 50 Hz line frequency (automatic detection would need freqs=None)
    picks=["eeg", "emg"],
    method="spectrum_fit",
    filter_length="10s",       # longer window to stabilise the estimate
    mt_bandwidth=3.0,          # multitaper bandwidth
    p_value=0.05
)
print("After: ")
mneR_notch.describe()
mneR_notch

In [None]:
# Notch-filtered signal at native sampling (x-axis in seconds).
display_waveforms(raw=mneR_notch)

In [None]:
# Compare spectra before and after the 50 Hz notch.
display_psdPlot(legends=["DC-removed", "notch-50 Hz"], raws=[mneR_dc, mneR_notch])

## ~~2-6. ecg filter~~

filtering heart beats noise from EMG
using [SSP approach](https://mne.tools/stable/auto_tutorials/preprocessing/50_artifact_correction_ssp.html#sphx-glr-auto-tutorials-preprocessing-50-artifact-correction-ssp-py).  
ICA approach is impossible because emg has only 2 channels.

> ***–> Skip in this version!***

In this part, we only applied it to EMG, but since some noise can be seen in EEG, it may be possible to apply it to EEG as well.

In [None]:
# import numpy as np
# import mne
# from mne.preprocessing import find_ecg_events

# # --- 0) I/O handles ---
# raw = mneR_notch.copy()  # if you already applied notch elsewhere, keep consistent
# emg_ch = 'EMG-1'         # EMG channel with strong ECG contamination

# # --- 1) (Optional but recommended) Notch to suppress power-line harmonics on the DETECTION COPY ---
# # If your region is 50 Hz: use np.arange(50, 251, 50); if 60 Hz: np.arange(60, 301, 60)
# line = 50  # or 60 depending on region
# raw_det = raw.copy()  # .notch_filter(freqs=np.arange(line, 6*line+1, line),
#                                  # picks=[emg_ch], method='spectrum_fit')

# # --- 2) (Optional) Light band-pass to emphasize QRS for detection only ---
# raw_det.filter(5., 35., picks=[emg_ch], phase='zero', verbose=False)

# # --- 3) Detect R-peaks on the detection copy using the EMG channel as pseudo-ECG ---
# events, _, average_pulse, ecg_by_mne = find_ecg_events(raw_det, ch_name=emg_ch, event_id=999, return_ecg=True)
# if len(events) == 0:
#     raise RuntimeError("No ECG-like events detected on the EMG channel")
# print(f"average ecg pulse: {average_pulse}")
# print(events)
# # --- 4) Build epochs on the ORIGINAL (broadband) raw so that removal works across all frequencies ---
# tmin, tmax = -0.2, 0.4  # seconds around R
# epochs = mne.Epochs(raw, events, event_id=999, tmin=tmin, tmax=tmax,
#                     picks=[emg_ch], baseline=(None, 0), preload=True,
#                     reject=None, reject_by_annotation=False)
# if len(epochs) == 0:
#     raise RuntimeError("Epochs are empty; check event_id/tmin-tmax/annotations")

# # --- 5) Average template (Evoked) and extract ndarray waveform ---
# template = epochs.average(picks=[emg_ch]).data[0]  # ndarray, shape=(n_times,)

# # --- 6) Overlap-add to reconstruct ECG-like trace over the full recording (broadband estimate) ---
# sfreq = raw.info['sfreq']
# n_times = raw.n_times
# n_temp = len(template)
# offset = int(round(-tmin * sfreq))
# ecg_est = np.zeros(n_times, dtype=np.float64)

# for ev in events[:, 0]:
#     start = ev - offset
#     end = start + n_temp
#     s0 = max(start, 0); e0 = min(end, n_times)
#     ts = s0 - start; te = n_temp - (end - e0)
#     if s0 < e0:
#         ecg_est[s0:e0] += template[ts:te]

# # --- 7a) (Option A) Add estimated ECG as a new channel (for inspection/regression later) ---
# ecg_data = np.vstack([ecg_est, ecg_by_mne]).astype('float64')  # shape: (2, n_times)
# info = mne.create_info(['ECG_by_TS', 'ECG_by_MNE'], sfreq=raw.info['sfreq'], ch_types=['ecg', 'ecg'])
# raw_ecg = mne.io.RawArray(ecg_data, info)
# mneR_with_ecg = raw.copy().add_channels([raw_ecg], force_update_info=True)

# # --- 7b) (Option B) Direct template subtraction on the EMG channel to remove ECG across ALL bands ---
# #       This realizes broadband removal because template/placement are made on the original raw.
# emg_idx = mne.pick_channels(raw.ch_names, [emg_ch])[0]
# raw_broadband_clean = raw.copy()
# raw_broadband_clean._data[emg_idx, :] -= ecg_est  # subtract artifact estimate


# def remove_ecg_from_emg(raw):
#     # filter ecg from emg
#     raw_emg = raw.copy().pick("emg")
#     raw_ecg = raw_emg.copy()
#     raw_emg_ch_names = raw_emg.ch_names.copy()

#     # temporary treat as eeg.
#     raw_emg.set_channel_types(
#         {ch: 'eeg' for ch in raw_emg_ch_names}
#     )
#     # # use emg as ecg
#     raw_ecg.set_channel_types(
#         {ch: 'ecg' for ch in raw_emg_ch_names}
#     )
#     ecg_ch_names = [f"{i}-ecg" for i in raw_emg_ch_names]
#     raw_ecg.rename_channels({i: f"{i}-ecg" for i in raw_emg_ch_names})
#     raw_emg.add_channels([raw_ecg], force_update_info=True)


#     projs, events = compute_proj_ecg(raw_emg, n_grad=0, n_mag=0, n_eeg=2)
#     raw_emg.add_proj(projs)
#     raw_emg_clean = raw_emg.copy().apply_proj()

#     # Restore channel types back to EMG
#     raw_emg_clean.set_channel_types(
#         {ch: "emg" for ch in raw_emg_clean.ch_names}
#     )

#     # Merge back with original raw
#     raw_no_emg = raw.copy().drop_channels(raw_emg_ch_names)
#     raw_emg_clean.drop_channels(ecg_ch_names)
#     raw_out = raw_emg_clean.add_channels([raw_no_emg], force_update_info=True)

#     return raw_out


# def remove_ecg_from_emg(raw):

#     raw = raw.copy()
#     raw_ecg = raw.copy()
#     raw_emg_ch_names = raw.copy().pick("emg").ch_names.copy()

#     # temporary treat as eeg.
#     raw.set_channel_types(
#         {ch: 'eeg' for ch in raw.ch_names}
#     )
#     # # use as ecg
#     raw_ecg.set_channel_types(
#         {ch: 'ecg' for ch in raw.ch_names}
#     )
#     ecg_ch_names = [f"{i}-ecg" for i in raw.ch_names]
#     raw_ecg.rename_channels({i: f"{i}-ecg" for i in raw.ch_names})
#     raw.add_channels([raw_ecg], force_update_info=True)


#     projs, events = compute_proj_ecg(raw, n_grad=0, n_mag=0, n_eeg=2)
#     raw.add_proj(projs)
#     raw_clean = raw.copy().apply_proj()

#     # Restore channel types back to EMG
#     raw_clean.set_channel_types(
#         {ch: "emg" for ch in raw_emg_ch_names}
#     )

#     # Merge back with original raw
#     raw_clean.drop_channels(ecg_ch_names)

#     return raw_clean


# mneR_ecg = remove_ecg_from_emg(mneR_notch)
# mneR_ecg


## 2-7. Band-pass filter
Use an eighth-order zero-phase-lag Butterworth filter.  

In [A novel machine learning system for identifying sleep–wake states in mice](https://academic.oup.com/sleep/article/46/6/zsad101/7109541), they used between 0.3 and 100 Hz for EEG, between 30 and 30k Hz for EMG.
For now, use between 0.3 and 100 Hz for EEG and between 30 and 9999 Hz for EMG.



In [None]:
print("Before:")
# describe() prints its table itself and returns None, so it is called
# directly (wrapping it in print() only added a stray "None").
mneR_notch.describe()

# Prepare the IIR filters: Butterworth band-pass designs in second-order
# sections, one for EEG and one for EMG.
# NOTE(review): the text above says 0.3-100 Hz for EEG but this cell uses
# 0.3-120 Hz - confirm which one is intended.
# NOTE(review): with phase='zero' the filter is presumably run forward and
# backward, so the effective order is higher than the order=8 design -
# confirm this matches the "eighth order" description.
# NOTE(review): the 9999 Hz EMG edge requires SAMPLING_RATE > 19998 Hz.
EEG_D = mne.filter.construct_iir_filter(
    iir_params=dict(order=8, ftype='butter', output='sos'),
    f_pass=[0.3, 120],
    f_stop=None,  # Not used if 'order' is specified in iir_params
    sfreq=SAMPLING_RATE,
    btype='bandpass',
    phase='zero',
    return_copy=False
)
EMG_D = mne.filter.construct_iir_filter(
    iir_params=dict(order=8, ftype='butter', output='sos'),
    f_pass=[30, 9999],
    f_stop=None,  # Not used if 'order' is specified in iir_params
    sfreq=SAMPLING_RATE,
    btype='bandpass',
    phase='zero',
    return_copy=False
)

# Filter a copy of the notch-filtered data: EEG channels with EEG_D, then EMG
# channels with EMG_D. filter() works in place on the copy.
mneR_butter = mneR_notch.copy()
for ch_type, designed_filter in (("eeg", EEG_D), ("emg", EMG_D)):
    mneR_butter.filter(
        l_freq=None, h_freq=None,  # band edges come from the pre-designed filter
        picks=ch_type,  # only channels of this type
        filter_length='auto',  # For FIR filter
        l_trans_bandwidth='auto', h_trans_bandwidth='auto',  # For FIR filter
        n_jobs=None,  # For FIR filter
        method='iir',
        iir_params=designed_filter,
        phase='zero',
        fir_window='hamming', fir_design='firwin',  # For FIR filter
        skip_by_annotation=('edge', 'bad_acq_skip'),
        pad='reflect_limited',  # For FIR filter
        verbose=None
    )
print("After: ")
mneR_butter.describe()
display(mneR_butter)

In [None]:
# Band-pass-filtered signal at native sampling (x-axis in seconds).
display_waveforms(raw=mneR_butter)

In [None]:
# Compare spectra before and after the Butterworth band-pass.
display_psdPlot(legends=["notch-50 Hz", "Butterworth"], raws=[mneR_notch, mneR_butter])

## ~~2-8. Spectral whitening~~
> *"A whitening transformation or sphering transformation is a linear transformation that transforms a vector of random variables with a known covariance matrix into a set of new variables whose covariance is the identity matrix, meaning that they are uncorrelated and each have variance 1"* (Wikipedia).  

In signal processing, it usually means a transformation to generate flat Fourier spectrum for a given signal, which is originally colored (not white). This processing tends to sharpen signal, as well as the noise. The whitening process is often used for ambient vibration data before stacking waveforms for cross-correlation.  
> - The Dr. Takeuchi's code (mtspectrumc_whiten.m) just multiplies amplitude/power by frequency (after FFT)... It would surely reduce lower frequency, but it changes the unit, doesn't it!?  
> - The MATLAB whitening function (https://www.mathworks.com/matlabcentral/fileexchange/65345-spectral-whitening) process is simple as Fourier transforming the signal after applying Hann window, then normalizing its magnitude, and then inverse Fourier transforming it. The normalization is forcedly setting magnitudes (within the range) with 1.  

–> After searching, I found that GWpy (a signal processing package for gravitational-wave detectors; https://gwpy.github.io/docs/stable/) contains spectral whitening function. Based on its source code, it normalizes the Fourier spectrum by a convolution with a whitening filter (i.e., multiplying the inverse of amplitude). So, it would be applicable to EEG data!  
> - Tutorial: https://gwpy.github.io/docs/stable/examples/timeseries/whiten/  
> - Function: https://gwpy.github.io/docs/stable/api/gwpy.timeseries.TimeSeries/#gwpy.timeseries.TimeSeries.whiten  

        """Whiten this `TimeSeries` using inverse spectrum truncation  

        Parameters
        ----------
        fftlength : `float`, optional
            FFT integration length (in seconds) for ASD estimation,
            default: choose based on sample rate

        overlap : `float`, optional
            number of seconds of overlap between FFTs, defaults to the
            recommended overlap for the given window (if given), or 0

        method : `str`, optional
            FFT-averaging method (default: ``'median'``)

        window : `str`, `numpy.ndarray`, optional
            window function to apply to timeseries prior to FFT,
            default: ``'hann'``
            see :func:`scipy.signal.get_window` for details on acceptable
            formats

        detrend : `str`, optional
            type of detrending to do before FFT (see `~TimeSeries.detrend`
            for more details), default: ``'constant'``

        asd : `~gwpy.frequencyseries.FrequencySeries`, optional
            the amplitude spectral density using which to whiten the data,
            overrides other ASD arguments, default: `None`

        fduration : `float`, optional
            duration (in seconds) of the time-domain FIR whitening filter,
            must be no longer than `fftlength`, default: 2 seconds

        highpass : `float`, optional
            highpass corner frequency (in Hz) of the FIR whitening filter,
            default: `None`

        **kwargs
            other keyword arguments are passed to the `TimeSeries.asd`
            method to estimate the amplitude spectral density
            `FrequencySeries` of this `TimeSeries`

        Returns
        -------
        out : `TimeSeries`
            a whitened version of the input data with zero mean and unit
            variance

        See also
        --------
        TimeSeries.asd
            for details on the ASD calculation
        TimeSeries.convolve
            for details on convolution with the overlap-save method
        gwpy.signal.filter_design.fir_from_transfer
            for FIR filter design through spectrum truncation

        Notes
        -----
        The accepted ``method`` arguments are:

        - ``'bartlett'`` : a mean average of non-overlapping periodograms
        - ``'median'`` : a median average of overlapping periodograms
        - ``'welch'`` : a mean average of overlapping periodograms

        The ``window`` argument is used in ASD estimation, FIR filter design,
        and in preventing spectral leakage in the output.

        Due to filter settle-in, a segment of length ``0.5*fduration`` will be
        corrupted at the beginning and end of the output. See
        `~TimeSeries.convolve` for more details.

        The input is detrended and the output normalised such that, if the
        input is stationary and Gaussian, then the output will have zero mean
        and unit variance.

        For more on inverse spectrum truncation, see arXiv:gr-qc/0509116.
        """

> ***–> Skip in this version!***  

# 3. Multitaper spectral analysis

## 3-1. Multitaper spectrum estimate

The multitaper spectrogram code (multitaper_spectrogram_python.py) is ready to use. Check their descriptions about the parameters  
> README.md:  
* **data**: 1-dimensional time series data  
* **Fs**: Frequency at which the data was sampled in Hz  
* **frequency_range**: \[min frequency, max frequency\] Range of frequencies (Hz) across which to compute the spectrum. The default for all implementations is [0, fs/2].  
* **taper_params**: \[time-halfbandwidth product, number of tapers\] The time-half bandwidth product (TW) can be computed as N*(BW/2) where N is the length of the window (seconds) and BW is the bandwidth of the main lobe. The bandwidth of the main lobe is also called the frequency resolution because it dictates the minimum difference in frequency that can be detected. "Number of tapers" is the number of DPSS tapers to be used to compute the spectrum. The optimal number of tapers is 2*(TW)-1. The default for all implementations is \[5, 9\].  
* **window_params**: \[window size (seconds), step size (seconds)\] These parameters dictate the temporal resolution of the analysis. The multitaper spectrum is computed for a single window of data, then the window moves based on step size and the spectrum will be computed again on the next window until the whole data array has been covered. The default for all implementations is \[5, 1\].  
* **min_nfft**: Multitaper spectrum computation relies on the Fourier Transform to transform the data from the time domain to the frequency domain. The Fast Fourier Transform (FFT) is a very computationally efficient algorithm to compute the Fourier Transform, and it works most efficiently when the number of data points in the given time series is a power of 2. Therefore, we want to pad the data with 0s to make it reach the closest power of 2. This implementation pads with zeros to the nearest power of 2 automatically, but if a specific power of 2 above the closest power of 2 is desired, use this parameter. The default for all implementations is 0.  
* **weighting**: The DPSS tapers can be weighted differently, and we have included 2 weighting method options - 'eigen' and 'adaptive' - along with the uniformly weighted option 'unity' which is the default for all implementations. Eigenvalue weighting weights the contribution of each taper to the spectrum by its eigenvalue (frequency concentration). In most cases this makes little difference because most tapers' eigenvalues are very close to one anyway. The adaptive weighting method weights the tapers in such a way as to reduce the broadband leakage of non-white ('colored') noise. This method is adapted from pages 368-370 of Percival and Walden's "Spectral Analysis for Physical Applications: Multitaper and Conventional Univariate Techniques". In practice, the adaptive method does not change the results much at all but is provided here for the sake of completeness.  
* **detrend_opt**: Each window of data can be detrended to remove very low frequency oscillation artifacts that can come from a variety of sources. In "linear" detrending, a linear model is fit to the window then subtracted out, while in "constant" detrending the data is set to be zero mean. The default for all implementations is "linear", and the options are "linear", "constant", and "off".  
> multitaper_spectrogram_python.py  
* Arguments:  
data (1d np.array): time series data -- required  
fs (float): sampling frequency in Hz  -- required  
frequency_range (list): 1x2 list - \[<min frequency>, <max frequency>\] (default: \[0 nyquist\])  
time_bandwidth (float): time-half bandwidth product (window duration\*half bandwidth of main lobe) (default: 5 Hz\*s)  
num_tapers (int): number of DPSS tapers to use (default: \[will be computed as floor(2*time_bandwidth - 1)\])  
window_params (list): 1x2 list - \[window size (seconds), step size (seconds)\] (default: \[5 1\])  
detrend_opt (string): detrend data window ('linear' (default), 'constant', 'off') (Default: 'linear')  
min_nfft (int): minimum allowable NFFT size, adds zero padding for interpolation (closest 2^x) (default: 0)  
multiprocess (bool): Use multiprocessing to compute multitaper spectrogram (default: False)  
n_jobs (int): Number of cpus to use if multiprocess = True (default: None). Note: if default is left as None and multiprocess = True, the number of cpus used for multiprocessing will be all available - 1.  
weighting (str): weighting of tapers ('unity' (default), 'eigen', 'adapt');  
plot_on (bool): plot results (default: True)  
return_fig (bool): return plotted spectrogram (default: False)  
clim_scale (bool): automatically scale the colormap on the plotted spectrogram (default: True)  
verbose (bool): display spectrogram properties (default: True)  
xyflip (bool): transpose the mt_spectrogram output (default: False)  
ax (axes): a matplotlib axes to plot the spectrogram on (default: None)  
* Returns:  
mt_spectrogram (TxF np array): spectral power matrix  
stimes (1xT np array): timepoints (s) in mt_spectrogram  
sfreqs (1xF np array): frequency values (Hz) in mt_spectrogram  

In [None]:
%%time

#Compute the multitaper spectrogram
def compute_multipaper_spectogram(ch_idx):
    """Run the multitaper spectrogram on one channel of ``mneR_butter``.

    The signal is read from the notebook-level ``mneR_butter`` and the
    sampling frequency from the notebook-level ``SAMPLING_RATE``.

    Parameters
    ----------
    ch_idx : int
        Row of ``mneR_butter.get_data()`` to analyse.

    Returns
    -------
    spect, stimes, sfreqs
        The three values returned by ``ms.multitaper_spectrogram``
        (called with ``xyflip=False``).
    psdDF : pandas.DataFrame
        ``spect`` transposed, with ``stimes`` as the index (``'Time[sec]'``)
        and ``sfreqs`` as the columns (``'Frequency[Hz]'``).  The table and
        its ``describe()`` summary are also displayed.
    """
    # The same 0-130 Hz band is analysed whatever the channel index is.
    band_hz = [0, 130]
    trace = mneR_butter.get_data()[ch_idx]

    # Alternative taper setting kept for reference:
    # time_bandwidth=2.5 with num_tapers=4 (2.5 * 2 - 1).
    mt_settings = {
        'frequency_range': band_hz,
        'time_bandwidth': 5,
        'num_tapers': 9,  # 2 * time_bandwidth - 1
        'window_params': [3, 1.5],  # [window size (s), step size (s)]
        'detrend_opt': 'linear',
        'min_nfft': 0,
        'multiprocess': True,
        'n_jobs': 10,
        'weighting': 'unity',
        'plot_on': False,  # plotting is customised in later cells
        'return_fig': False,
        'clim_scale': True,
        'verbose': True,
        'xyflip': False,
        'ax': None,
    }
    spect, stimes, sfreqs = ms.multitaper_spectrogram(
        data=trace, fs=SAMPLING_RATE, **mt_settings
    )

    # Tidy the result into a time x frequency table
    time_axis = pd.Index(stimes, name='Time[sec]')
    freq_axis = pd.Index(sfreqs, name='Frequency[Hz]')
    psdDF = pd.DataFrame(spect.T, index=time_axis, columns=freq_axis)
    display(psdDF)
    display(psdDF.describe())

    return spect, stimes, sfreqs, psdDF

# Analyse channel index 2 and keep the results as notebook-level variables;
# the plotting functions in later cells read `stimes` and `sfreqs` directly.
ch_idx = 2
spect, stimes, sfreqs, psdDF = compute_multipaper_spectogram(ch_idx)

In [None]:
# Show the PSD table (index: 'Time[sec]', columns: 'Frequency[Hz]')
psdDF

## 3-2. Power spectrum density (PSD)
### 3-2-1. Spectrogram visualization in the time–frequency domain

In [None]:

def display_spectrogram_dB(spect, title_text=""):
    """Plot a power spectrogram, converted to dB, as a time-frequency image.

    Customised from the plotting part of the toolbox's
    ``multitaper_spectrogram()``.  The axis ranges are taken from the
    notebook-level ``stimes`` (seconds) and ``sfreqs`` (Hz), so these must
    come from the same run that produced ``spect``.

    Parameters
    ----------
    spect : array-like
        Power matrix; rows are drawn along the frequency axis and columns
        along the time axis.  Converted with ``ms.nanpow2db`` before plotting.
    title_text : str, optional
        Title of the figure.
    """
    # Convert from power to dB
    power_db = ms.nanpow2db(spect)

    # Ranges of the x-axis (minutes) and y-axis (Hz)
    minutes = stimes / 60
    dx = minutes[1] - minutes[0]
    dy = sfreqs[1] - sfreqs[0]
    extent = [minutes[0] - dx, minutes[-1] + dx, sfreqs[-1] + dy, sfreqs[0] - dy]

    # Plot spectrogram
    # sns.set(style='ticks', font='Arial', context='notebook')
    sns.set(style='ticks', context='notebook')
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 3))
    im = ax.imshow(power_db, extent=extent, aspect='auto')
    # im.set_cmap(plt.cm.get_cmap('cet_rainbow4'))  # older matplotlib
    im.set_cmap(plt.colormaps.get_cmap('cet_rainbow4'))

    # Colour scale from the 5th to the 98th percentile.  The dB matrix may
    # contain NaN entries; np.percentile would then return NaN limits and
    # blank the image, so the NaN-aware variant is used.
    clim = np.nanpercentile(power_db, [5, 98])
    im.set_clim(clim)

    ax.invert_yaxis()
    ax.set_xlabel('Time [min]')
    ax.set_ylabel('Frequency [Hz]')
    ax.set_title(title_text)
    fig.colorbar(im, ax=ax, label='PSD [' + r'$\mathsf{μV^2/Hz}$' + ', dB]')
    plt.show()



# Figure title shared by the following cells: file stem plus channel name
title = f"{FILE_PATH.stem}, Channel: {mneR_butter.ch_names[ch_idx]}"
display_spectrogram_dB(spect, title)

### 3-2-2. Power spectrum visualization at the frequency domain

In [None]:
def display_powerSpectrum_dB(psdDFs, subjects, title_text=""):
    """Plot the PSD (in dB) against frequency, one line per subject.

    Parameters
    ----------
    psdDFs : iterable of pandas.DataFrame
        PSD tables indexed by ``'Time[sec]'`` with one column per frequency.
    subjects : iterable
        Legend label for each table, paired with ``psdDFs`` by position.
    title_text : str, optional
        Title of the figure.
    """
    # Build one long-format table; sns.lineplot() then computes the mean
    # and its error band per frequency while drawing.
    long_frames = []
    for frame, label in zip(psdDFs, subjects):
        frame_db = ms.nanpow2db(frame)  # power -> dB
        long_df = frame_db.reset_index().melt(
            id_vars=['Time[sec]'],
            var_name='Frequency[Hz]',
            value_name='PSD[dB]',
        )
        long_df['Subject'] = label
        long_frames.append(long_df)
    plot_df = pd.concat(long_frames, axis=0)
    freq_lo = plot_df['Frequency[Hz]'].min()
    freq_hi = plot_df['Frequency[Hz]'].max()

    # Plot
    # sns.set(style='ticks', font='Arial', context='notebook')
    sns.set(style='ticks', context='notebook')
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 3))
    sns.lineplot(
        data=plot_df, x='Frequency[Hz]', y='PSD[dB]',
        hue='Subject', palette='tab10', hue_order=None,
        estimator='mean', ci=95, n_boot=1000, seed=123, sort=True,
        err_style='band', err_kws=None, legend='auto', ax=ax,
    )
    sns.despine()
    ax.set_xlim(freq_lo, freq_hi)

    # Logarithmic frequency axis
    ax.set_xscale('log', base=10)
    ax.minorticks_off()  # set_xscale() switches minor ticks on
    ax.set_xlim(1, freq_hi)  # set_xscale() widens the range; start at 1 Hz
    # delta, theta, alpha, sigma, beta & gamma ranges
    ax.set_xticks([1, 4, 8, 12, 15, 30, 100])
    ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())

    ax.set_xlabel('Frequency [Hz] (log-scale)')
    ax.set_ylabel('PSD ['+r'$\mathsf{μV^2/Hz}$'+', dB]\n(Mean with 95% CI)')
    ax.set_title(title_text)
    ax.legend(title='Subject', bbox_to_anchor=(1.0, 0.5), loc='center left')
    plt.show()


def round_freq_bins(df, step=50.0, agg="mean"):
    """Merge frequency columns into bins of width ``step``.

    Every column label ``f`` is assigned to the bin starting at
    ``floor(f / step) * step`` (e.g. 123 Hz -> 100 Hz for ``step=50``), and
    the columns of each bin are combined with ``agg``.

    The bin labels stay floating point throughout, so fractional steps keep
    distinct bins and the output columns are in ascending numeric order.

    Parameters
    ----------
    df : pandas.DataFrame
        Table whose column labels are frequencies convertible to float.
    step : float, optional
        Bin width, in the unit of the column labels.  Must be positive.
    agg : str or callable, optional
        Aggregation passed to ``GroupBy.agg`` (default ``"mean"``).

    Returns
    -------
    pandas.DataFrame
        Same index as ``df``; one float-labelled column per occupied bin.

    Raises
    ------
    ValueError
        If ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError("step must be positive")

    freqs = np.asarray(df.columns, dtype=float)
    bin_starts = np.floor(freqs / step) * step

    # Group the transposed table by rows instead of using the deprecated
    # groupby(axis=1), then transpose back.
    df_out = df.T.groupby(bin_starts).agg(agg).T
    df_out.columns = pd.Index(
        df_out.columns.astype(float), name=df.columns.name
    )

    return df_out


# Reduce the table size if needed: average the PSD within 2 Hz wide bins
_psdDF = round_freq_bins(psdDF, step=2.0, agg="mean")
# Alternative: plot at full frequency resolution
# _psdDF = psdDF

display_powerSpectrum_dB([_psdDF], [mneR_butter.ch_names[ch_idx]], title)

## 3-3. Power spectrum (PS)
> According to Dr. Takeuchi's "whitening" processing, not power spectrum density (PSD, uV^2 / Hz) but power spectrum (PS, uV^2) is used as the final presentation. Also, no dB conversion is needed.  

In [None]:
# Convert from PSD to PS: every frequency column is multiplied by its own
# column label (the frequency value), producing a new table.
tempDF = psdDF.multiply(psdDF.columns.to_numpy(), axis='columns')

display(tempDF)
display(tempDF.describe())

psDF = tempDF

### 3-3-1. Spectrogram visualization in the time–frequency domain

In [None]:
def display_spectrogram_uv2(psDF, title_text=""):
    """Plot a power-spectrum table as a time-frequency image (no dB step).

    Customised from the plotting part of the toolbox's
    ``multitaper_spectrogram()``.  The axis ranges come from the
    notebook-level ``stimes`` (seconds) and ``sfreqs`` (Hz).

    Parameters
    ----------
    psDF : pandas.DataFrame
        Power table; it is transposed before drawing, so its columns run
        along the frequency axis and its rows along the time axis.
    title_text : str, optional
        Title of the figure.
    """
    # Values are plotted as they are; no power -> dB conversion this time
    power = psDF.to_numpy().transpose()

    # Axis ranges: time in minutes, frequency in Hz
    minutes = stimes / 60
    t_step = minutes[1] - minutes[0]
    f_step = sfreqs[1] - sfreqs[0]
    left, right = minutes[0] - t_step, minutes[-1] + t_step
    bottom, top = sfreqs[-1] + f_step, sfreqs[0] - f_step

    # Plot spectrogram
    # sns.set(style='ticks', font='Arial', context='notebook')
    sns.set(style='ticks', context='notebook')
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 3))
    im = ax.imshow(power, extent=[left, right, bottom, top], aspect='auto')
    # im.set_cmap(plt.cm.get_cmap('cet_rainbow4'))  # older matplotlib
    im.set_cmap(plt.colormaps.get_cmap('cet_rainbow4'))
    # Colour scale from the 5th to the 98th percentile
    im.set_clim(np.percentile(power, [5, 98]))
    ax.invert_yaxis()
    ax.set(xlabel='Time [min]', ylabel='Frequency [Hz]', title=title_text)
    fig.colorbar(im, ax=ax, label='Power [' + r'$\mathsf{μV^2}$' + ']')
    plt.show()


# Spectrogram of the power spectrum for the channel selected by `ch_idx`
display_spectrogram_uv2(psDF, title)

In [None]:
#Visualize both waveform and spectrogram
def display_waveforms_and_spectrogram(ch_idx, psDF, title_text=""):
    """Show one channel's waveform above its power spectrogram.

    The upper panel plots the full trace of channel ``ch_idx`` from the
    notebook-level ``mneR_butter``.  The lower panel draws ``psDF`` as an
    image on a logarithmic frequency axis, using the notebook-level
    ``stimes`` and ``sfreqs`` for its extent.  Both panels share the time
    axis, which is in minutes.

    Parameters
    ----------
    ch_idx : int
        Row of ``mneR_butter.get_data()`` to plot.
    psDF : pandas.DataFrame
        Power table; transposed before drawing.
    title_text : str, optional
        Title placed above the upper panel.
    """
    # sns.set(style='ticks', font='Arial', context='notebook')
    sns.set(style='ticks', context='notebook')
    fig, axes = plt.subplots(
        nrows=2, ncols=1, figsize=(7.5, 6.5), sharex=True, sharey=False,
        gridspec_kw={'height_ratios': [1, 1], 'hspace': 0.1}
    )
    wave_ax, spec_ax = axes.flat

    # ---- Upper panel: waveform trace over the full recording ----
    # Spines are hidden by hand because sns.despine() seems to override
    # the bottom spine setting.
    wave_ax.spines.top.set_visible(False)
    wave_ax.spines.right.set_visible(False)
    traces = mneR_butter.get_data()
    t_minutes = mneR_butter.times.copy() / 60  # sec -> min
    wave_ax.plot(t_minutes, traces[ch_idx], color='k', linewidth=0.1)
    plt.setp(axes, xlim=(t_minutes.min(), t_minutes.max()))  # both panels
    wave_ax.yaxis.set_major_formatter(
        matplotlib.ticker.StrMethodFormatter('{x:,.0f}')
    )
    plt.setp(wave_ax.get_xticklabels(), visible=False)
    wave_ax.set_xlabel('')
    wave_ax.set_ylabel('LFP [μV]')
    wave_ax.set_title(title_text)  # overall title

    # ---- Lower panel: spectrogram ----
    # (customised from the original multitaper_spectrogram() code;
    # no power -> dB conversion this time)
    spec_ax.spines.top.set_visible(False)
    spec_ax.spines.right.set_visible(False)
    power = psDF.to_numpy().T
    s_minutes = stimes / 60  # sec -> min
    t_step = s_minutes[1] - s_minutes[0]
    f_step = sfreqs[1] - sfreqs[0]
    im = spec_ax.imshow(
        power,
        extent=[
            s_minutes[0] - t_step, s_minutes[-1] + t_step,
            sfreqs[-1] + f_step, sfreqs[0] - f_step,
        ],
        aspect='auto',
    )
    # im.set_cmap(plt.cm.get_cmap('cet_rainbow4'))  # older matplotlib
    im.set_cmap(plt.colormaps.get_cmap('cet_rainbow4'))
    # Colour scale from the 5th to the 98th percentile
    im.set_clim(np.percentile(power, [5, 98]))
    spec_ax.invert_yaxis()

    # Logarithmic frequency axis
    spec_ax.set_yscale('log', base=10)
    spec_ax.minorticks_off()  # set_yscale() switches minor ticks on
    spec_ax.set_ylim(1, sfreqs[-1] + f_step)  # set_yscale() widens the range
    # delta, theta, alpha, sigma, beta & gamma ranges
    spec_ax.set_yticks([1, 4, 8, 12, 15, 30, 100])
    spec_ax.get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
    spec_ax.set_xlabel('Time [min]')
    spec_ax.set_ylabel('Frequency [Hz] (log-scale)')

    # Colorbar on a hand-placed axes so that its label position can be set
    cax = fig.add_axes([0.5, -0.025, 0.4, 0.03])  # manual adjustment
    fig.colorbar(im, cax=cax, orientation='horizontal')
    cax.set_ylabel(
        'Power [' + r'$\mathsf{μV^2}$' + ']',
        rotation=0, horizontalalignment='right',
        verticalalignment='center', rotation_mode='anchor'
    )
    plt.show()


# Waveform and power spectrogram of the channel selected by `ch_idx`
display_waveforms_and_spectrogram(ch_idx, psDF, title)

### 3-3-2. Power spectrum visualization at the frequency domain

In [None]:
def display_powerSpectrum_uv2(psDFs, subjects, title_text=""):
    """Plot the power spectrum against frequency on log-log axes.

    Parameters
    ----------
    psDFs : iterable of pandas.DataFrame
        Power tables indexed by ``'Time[sec]'`` with one column per
        frequency.  Values are plotted as they are (no dB conversion).
    subjects : iterable
        Legend label for each table, paired with ``psDFs`` by position.
    title_text : str, optional
        Title of the figure.
    """
    # Build one long-format table; sns.lineplot() then computes the mean
    # and its error band per frequency while drawing.
    long_frames = []
    for frame, label in zip(psDFs, subjects):
        long_df = frame.reset_index().melt(
            id_vars=['Time[sec]'],
            var_name='Frequency[Hz]',
            value_name='Power[uV^2]',
        )
        long_df['Subject'] = label
        long_frames.append(long_df)
    plot_df = pd.concat(long_frames, axis=0)
    freq_lo = plot_df['Frequency[Hz]'].min()
    freq_hi = plot_df['Frequency[Hz]'].max()

    # Plot
    # sns.set(style='ticks', font='Arial', context='notebook')
    sns.set(style='ticks', context='notebook')
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 3))
    sns.lineplot(
        data=plot_df, x='Frequency[Hz]', y='Power[uV^2]',
        hue='Subject', palette='tab10', hue_order=None,
        estimator='mean', ci=95, n_boot=1000, seed=123, sort=True,
        err_style='band', err_kws=None, legend='auto', ax=ax,
    )
    sns.despine()
    ax.set_xlim(freq_lo, freq_hi)

    # Logarithmic frequency axis
    ax.set_xscale('log', base=10)
    ax.minorticks_off()  # set_xscale() switches minor ticks on
    ax.set_xlim(1, freq_hi)  # set_xscale() widens the range; start at 1 Hz
    # delta, theta, alpha, sigma, beta & gamma ranges
    ax.set_xticks([1, 4, 8, 12, 15, 30, 100])
    ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())

    # Logarithmic power axis and a dashed marker line at 50 Hz
    ax.set_yscale('log', base=10)
    ax.axvline(x=50, linestyle='--', color='k')
    # ax.set_ylim(50, 500000)  # to fix the range among channels/samples

    ax.set_xlabel('Frequency [Hz] (log-scale)')
    ax.set_ylabel('Power [' + r'$\mathsf{μV^2}$' + ']\n(Mean with 95% CI)')
    ax.set_title(title_text)
    ax.legend(title='Subject', bbox_to_anchor=(1.0, 0.5), loc='center left')
    plt.show()


# Reduce the table size if needed: average the power within 2 Hz wide bins
_psDF = round_freq_bins(psDF, step=2.0, agg="mean")
# Alternative: plot at full frequency resolution
# _psDF = psDF

display_powerSpectrum_uv2([_psDF], [mneR_butter.ch_names[ch_idx]], title)

## Cf. 1. Different channel 1

In [None]:
# Repeat the whole section-3 pipeline for one channel.
# NOTE(review): this "different channel" cell uses ch_idx = 2, the same index
# as the main analysis above - presumably another channel index was intended;
# confirm before interpreting these figures.
ch_idx = 2
spect, stimes, sfreqs, psdDF = compute_multipaper_spectogram(ch_idx)
title = f"{FILE_PATH.stem}, Channel: {mneR_butter.ch_names[ch_idx]}"
display_spectrogram_dB(spect, title)
# The line plots here use the full-resolution tables (no round_freq_bins step)
display_powerSpectrum_dB([psdDF], [mneR_butter.ch_names[ch_idx]], title)

tempDF = psdDF.copy()

# Convert from PSD to PS: multiply each column by its frequency label
for col_n in tempDF.columns:
    tempDF[col_n] = tempDF[col_n] * col_n

display(tempDF)
display(tempDF.describe())

psDF = tempDF
display_spectrogram_uv2(psDF, title)
display_waveforms_and_spectrogram(ch_idx, psDF, title)
display_powerSpectrum_uv2([psDF], [mneR_butter.ch_names[ch_idx]], title)



## Cf. 2. Different channel 2

In [None]:
# Repeat the whole section-3 pipeline for one channel.
# NOTE(review): this second "different channel" cell also uses ch_idx = 2,
# so it repeats the same channel as the main analysis - presumably another
# channel index was intended; confirm before interpreting these figures.
ch_idx = 2
spect, stimes, sfreqs, psdDF = compute_multipaper_spectogram(ch_idx)
title = f"{FILE_PATH.stem}, Channel: {mneR_butter.ch_names[ch_idx]}"
display_spectrogram_dB(spect, title)
# The line plots here use the full-resolution tables (no round_freq_bins step)
display_powerSpectrum_dB([psdDF], [mneR_butter.ch_names[ch_idx]], title)

tempDF = psdDF.copy()

# Convert from PSD to PS: multiply each column by its frequency label
for col_n in tempDF.columns:
    tempDF[col_n] = tempDF[col_n] * col_n

display(tempDF)
display(tempDF.describe())

psDF = tempDF
display_spectrogram_uv2(psDF, title)
display_waveforms_and_spectrogram(ch_idx, psDF, title)
display_powerSpectrum_uv2([psDF], [mneR_butter.ch_names[ch_idx]], title)



In [None]:
# Inspect the frequency values returned by the last spectrogram run
sfreqs

# 4. SleepEns

## 4-1. Run Sleep Ensemble
reference: [https://github.com/paradoxysm/sleepens/blob/master/sleepens/main.py](https://github.com/paradoxysm/sleepens/blob/master/sleepens/main.py)

We run SleepENS on Python 3.7 with pinned dependencies, using the system Python and a virtual environment (venv).  
Outputs are written to `$DATA_FOLDER/sleepEnsOutput/`, and scripts are located in `./python/sleepens/`.


## 4-2. check result
'SCORE_MAP': { 'AW': 0, 'QW': 1, 'NR': 2, 'R': 3 },

In [None]:
# To develop faster, work on one small segment file, e.g.
# Path(DATA_FOLDER, "mneRaw", "20250917-001-0.fif")
file_name = "20250917-001-0{}.fif"  # "{}" is filled with a suffix below

# SleepEns output for this segment (file name suffix "-predictions")
mneR_ens = mne.io.read_raw_fif(
    Path(DATA_FOLDER, "sleepEnsOutput", file_name.format("-predictions"))
)
mneI_ens = mneR_ens.info.copy()
display(mneR_ens)
display(mneR_ens.annotations)
print(mneR_ens.ch_names)
print("=" * 30)

# The matching segment from the "mneRaw_band_notch" folder (no suffix)
mneR_orig = mne.io.read_raw_fif(
    Path(DATA_FOLDER, "mneRaw_band_notch", file_name.format(""))
)
display(mneR_orig)
print(mneR_orig.ch_names)


In [None]:


# Hard-coded mapping between state names and their integer codes
SCORE_MAP = dict(AW=0, QW=1, NREM=2, REM=3)
INV_SCORE_MAP = dict(zip(SCORE_MAP.values(), SCORE_MAP.keys()))

# One 8-bit RGB colour per state code
_STATE_COLORS = {
    SCORE_MAP['AW']: (162, 53, 47),
    SCORE_MAP['QW']: (236, 197, 72),
    SCORE_MAP['NREM']: (81, 158, 89),
    SCORE_MAP['REM']: (52, 71, 113),
}

# The same colours scaled to the 0-1 range expected by Matplotlib
STATE_COLORS = {
    code: tuple(np.asarray(rgb, dtype=float) / 255.0)
    for code, rgb in _STATE_COLORS.items()
}


def display_waveforms_with_states_simple(
    raw: mne.io.Raw,
    labels_raw: mne.io.Raw,     # 1-channel Raw; integer labels 0..3 per sample
    title_text: str = "",
    span: tuple = (0, -1),      # (start, end) in sec or min
    unit: str = "sec",          # "sec", "min", or "time"
    alpha_band: float = 0.15
):
    """
    Plot multi-channel waveforms and overlay 4-state categorical bands.
    Labels are provided as a 1-channel mne.io.Raw, with integer states:
      0=AW, 1=QW, 2=NR, 3=R.

    Parameters
    ----------
    raw : mne.io.Raw
        Source time-series to plot.
    labels_raw : mne.io.Raw
        1-channel Raw containing per-sample integer labels (0..3).
        If its number of samples differs from `raw` (e.g. because of a
        different sfreq), the labels are mapped onto the samples of `raw`
        by nearest-neighbour lookup; both recordings are assumed to cover
        the same time span.
    title_text : str
        Title for the figure.
    span : tuple
        Time range to plot (start, end) interpreted in seconds or minutes.
        Use -1 or None for end to indicate the end of recording.  The range
        is clamped to the recording.
    unit : {"sec","min","time"}
        Unit for `span`. "time" only changes the tick labels to HH:MM:SS;
        `span` is still interpreted in seconds.
    alpha_band : float
        Alpha value for the background state bands.

    Raises
    ------
    ValueError
        If `unit` is not one of the accepted values, if `labels_raw` does
        not have exactly one channel, or if it contains no samples.
    """
    # Imported here so the function does not rely on a notebook-level import
    # of these two names.
    from matplotlib.patches import Patch
    from matplotlib.ticker import FuncFormatter

    # ---- Validate inputs ----
    if unit not in {"sec", "min", "time"}:
        raise ValueError("unit must be 'sec' or 'min' or 'time'")
    if labels_raw.info["nchan"] != 1:
        raise ValueError("labels_raw must have exactly 1 channel (integer labels 0..3).")

    # ---- Extract and sanitize label vector ----
    # Round to the nearest integer and clip to 0..3
    lab = np.clip(np.rint(labels_raw.get_data()[0]).astype(int), 0, 3)
    if lab.size == 0:
        raise ValueError("labels_raw contains no samples.")

    # Align the labels to the samples of `raw`.  The labels are categorical
    # codes, so they are picked by nearest neighbour; a filter-based
    # resampling would blend neighbouring codes around state transitions.
    if lab.shape[0] != raw.n_times:
        idx = np.linspace(0, lab.shape[0] - 1, raw.n_times)
        lab = lab[np.rint(idx).astype(int)]

    # ---- Determine plotting window in seconds ----
    t_first, t_last = float(raw.times[0]), float(raw.times[-1])
    to_sec = 60.0 if unit == "min" else 1.0  # "sec" and "time" are seconds
    start, end = span
    # The end sentinel (-1 / None) is resolved BEFORE the unit scaling;
    # otherwise -1 min would become -60 s and no longer be recognised.
    start = t_first if start is None else float(start) * to_sec
    end = t_last if end is None or end == -1 else float(end) * to_sec
    # Clamp into the recording and keep end >= start
    start = min(max(start, t_first), t_last)
    end = min(max(end, start), t_last)

    s_idx, e_idx = (int(i) for i in raw.time_as_index([start, end]))
    n_ch = raw.info["nchan"]
    if e_idx > s_idx:
        # Read only the requested window instead of the whole recording
        data = raw.get_data(start=s_idx, stop=e_idx)
    else:
        data = np.empty((n_ch, 0))
    times = raw.times[s_idx:e_idx]
    labs_win = lab[s_idx:e_idx]

    # ---- Find contiguous segments of same label ----
    def contiguous_segments(lbl: np.ndarray):
        """Return (first index, end index (exclusive), label) per run."""
        if lbl.size == 0:
            return []
        edges = np.where(np.diff(lbl) != 0)[0] + 1
        idxs = np.r_[0, edges, lbl.size]
        return [(idxs[i], idxs[i+1], int(lbl[idxs[i]])) for i in range(len(idxs)-1)]

    segs = contiguous_segments(labs_win)
    # Each band runs up to the first sample of the next segment, so adjacent
    # bands touch and a one-sample segment still has a visible width.
    bands = [
        (times[s], times[e] if e < times.size else times[-1], STATE_COLORS[lab_val])
        for s, e, lab_val in segs
    ]

    # ---- Create figure ----
    fig_h = max(1.6, n_ch / 1.5)
    fig, axes = plt.subplots(
        nrows=n_ch, ncols=1,
        figsize=(8.5, fig_h),
        sharex=True, sharey=True,
        gridspec_kw={'height_ratios': [1]*n_ch, 'hspace': 0.0}
    )
    if n_ch == 1:
        axes = [axes]

    # For unit == "time", show the ticks as HH:MM:SS (display only)
    if unit == "time":
        def _sec_to_hhmmss(x, pos):  # x is a time in seconds (float)
            s = int(max(0, round(x)))
            h = s // 3600
            m = (s % 3600) // 60
            ss = s % 60
            return f"{h:02d}:{m:02d}:{ss:02d}"
        for ax in axes:
            ax.xaxis.set_major_formatter(FuncFormatter(_sec_to_hhmmss))

    # Axis limits
    if times.size == 0 or data.size == 0:
        xlim = (0, 0)
        ylim = (-1, 1)
    else:
        xlim = (times[0], times[-1])
        dmin, dmax = float(np.nanmin(data)), float(np.nanmax(data))
        if dmin == dmax:
            dmin, dmax = dmin - 1.0, dmax + 1.0
        ylim = (dmin, dmax)
    for ax in axes:
        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)

    # ---- Draw channels with background state bands ----
    for i, ax in enumerate(axes):
        # Background bands per contiguous state
        for x0, x1, band_color in bands:
            ax.axvspan(x0, x1, color=band_color, alpha=alpha_band)

        # Waveform trace
        ax.plot(times, data[i], linewidth=0.35, color='k')

        # Cosmetics
        ax.spines.top.set_visible(False)
        ax.spines.right.set_visible(False)
        ax.set_yticks([])
        ax.set_ylabel(
            raw.ch_names[i],
            rotation=0, ha='right', va='center',
            rotation_mode='anchor'
        )
        if i == 0 and title_text:
            ax.set_title(title_text)
        if i == n_ch - 1:
            if unit == "time":
                ax.set_xlabel("Time [HH:MM:SS]")
            else:
                ax.set_xlabel("Time [sec]")
        else:
            ax.spines.bottom.set_visible(False)
            ax.get_xaxis().set_visible(False)

    # Legend, placed outside the axes on the right
    handles = [Patch(color=STATE_COLORS[k], alpha=alpha_band, label=INV_SCORE_MAP[k]) for k in range(4)]
    plt.subplots_adjust(right=0.80)
    axes[0].legend(
        handles=handles,
        loc="upper left",
        bbox_to_anchor=(1.02, 1.0),
        borderaxespad=0.0,
        frameon=False,
        title="State"
    )
    plt.show()


In [None]:
# Overview of the SleepEns output (display_waveforms is defined in an earlier cell)
display_waveforms(mneR_ens, unit="min")


In [None]:
# Same view restricted to span=(3600, 4800) with unit="sec"
display_waveforms(mneR_ens,span=(60 * 60, 60 * 80),unit="sec")

In [None]:
# Keep only the "state" channel of the SleepEns output (on a copy)
state = mneR_ens.copy().pick("state")

In [None]:
# Whole recording (default span) with the state bands overlaid
display_waveforms_with_states_simple(
    mneR_orig,
    state,)

In [None]:
# 6300-6400 s, tick labels shown as HH:MM:SS
display_waveforms_with_states_simple(
    mneR_orig,
    state,
    span=(6300, 6400),
    unit="time"
)

In [None]:
# 6000-6400 s, tick labels in seconds (default unit)
display_waveforms_with_states_simple(
    mneR_orig,
    state,
    span=(6000, 6400)
)

In [None]:
# 6100-6200 s, tick labels shown as HH:MM:SS
display_waveforms_with_states_simple(
    mneR_orig,
    state,
    span=(6100, 6200),
    unit="time"
)

In [None]:
# 1990-2500 s, tick labels shown as HH:MM:SS
display_waveforms_with_states_simple(
    mneR_orig,
    state,
    span=(1990, 2500),
    unit="time"
)

In [None]:
# 1990-2500 s again (same arguments as the previous cell)
display_waveforms_with_states_simple(
    mneR_orig,
    state,
    span=(1990, 2500),
    unit="time"
)

In [None]:
# 2330-2400 s, tick labels shown as HH:MM:SS
display_waveforms_with_states_simple(
    mneR_orig,
    state,
    span=(2330, 2400),
    unit="time"
)

In [None]:
# 450-600 s, tick labels shown as HH:MM:SS
display_waveforms_with_states_simple(
    mneR_orig,
    state,
    span=(450, 600),
    unit="time"
)

In [None]:
# The value is kept under its own name: assigning it to `time` would rebind
# the `time` module imported at the top of the notebook for every later cell.
time_sec = 6300
sec_to_time(time_sec)

# 5. Check mp4
## 5-1. load .tsp

In [None]:


# Load the .tsp file as a float64 array
# (presumably one timestamp in ms per video frame - confirm)
filepath = DATA_FOLDER / "20250917-001.tsp"
timestamps = np.loadtxt(filepath, dtype=np.float64)

# Values copied by hand from the recording's metadata:
# TimeStamp of the end of recording (computer clock - ms) = 541695087
# TimeStamp of the start of recording (computer clock - ms) = 514427685
start_ms = 514427685
df = pd.DataFrame({"timestamp_ms": timestamps})
# Elapsed time since the start of the recording, in seconds and as text
df["elapsed_sec"] = (df["timestamp_ms"] - start_ms) / 1000.0
df["elapsed_time"] = pd.to_timedelta(df["elapsed_sec"], unit="s").astype(str)





In [None]:
# Last rows: elapsed time at the end of the timestamp file
df.tail()

The time 07:34:27.005000 differs from the time confirmed in the video viewer.
This is almost certainly due to the video not being at 30fps.


In [None]:
# Add the elapsed time each row would have if rows were exactly 1/30 s apart.
# The row position is taken from df.index, i.e. this relies on the default
# 0..N-1 index.
fps = 30.0
df["elapsed_sec_30fps"] = df.index / fps
df["elapsed_time_30fps"] = pd.to_timedelta(df["elapsed_sec_30fps"], unit="s").astype(str)

In [None]:
# Compare the timestamp-based and the 30 fps-based elapsed times at the end
df.tail()

In [None]:
def get_misinterpreted_time(df: pd.DataFrame, *time_sec: int):
    """Print the 30 fps-based time that corresponds to each elapsed time.

    For every value in ``time_sec`` the row whose ``"elapsed_sec"`` is
    closest to it is looked up, and that row's ``"elapsed_time_30fps"`` is
    printed as ``<value> : <time>``.  Nothing is returned.
    """
    elapsed = df["elapsed_sec"]
    for target in time_sec:
        nearest_row = (elapsed - target).abs().idxmin()
        print(target, ":", df.loc[nearest_row, "elapsed_time_30fps"])


In [None]:
# Look-up table for 6000 s to 6490 s in 10 s steps
get_misinterpreted_time(df, *[6000 + i * 10 for i in range(50)])

In [None]:
# Look-up table for 2000 s to 2490 s in 10 s steps
get_misinterpreted_time(df, *[2000 + i * 10 for i in range(50)])

## 5-2. mp4 to mkv
mp4 cannot contain the .tsp data, so convert to mkv and load the .tsp to fix the time stamps.
This part is done in the shell; the script is saved here.

This script was run in `sleepStateExperiment_from_video`.

```.sh
mkdir -p ./out_mkv
for f in ./out/*.mp4; do
    base=$(basename "$f" .mp4)
    ffmpeg -i "$f" -c copy "./out_mkv/${base}.mkv"
done

mkdir -p ./out_mkv_corrected
```
and 
`sleepStateExperiment_from_video/correct_mkv.sh`

# 6. Run at all EEG channel
## 6-1. create mne.io.Raw object

In [None]:
def load_meta(path):
    """Read a ``key = value`` text file into a dict.

    Lines that do not contain the separator ``" = "`` (after stripping
    surrounding whitespace) are skipped.  Only the first separator of a
    line splits it.  Values that parse as ``float`` are stored as floats,
    all others as stripped strings.  Undecodable bytes are ignored.
    """
    separator = " = "
    meta = {}
    with open(path, encoding="utf-8", errors="ignore") as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            if separator not in entry:  # also covers empty lines
                continue
            key, _, text = entry.partition(separator)
            key, text = key.strip(), text.strip()
            try:
                meta[key] = float(text)
            except ValueError:
                meta[key] = text
    return meta


def load_data(dat_path, eeg_ch_idx, n_channels=12):
    """Load an interleaved int16 recording and keep three of its channels.

    The file is read as little-endian signed 16-bit integers and reshaped
    to ``(n_frames, n_channels)``.  Columns 0 and 1 and column
    ``eeg_ch_idx`` are kept, in that order.

    Parameters
    ----------
    dat_path : path-like
        Binary file to read.
    eeg_ch_idx : int
        Column index of the third channel to keep.
    n_channels : int, optional
        Number of interleaved channels stored in the file (default 12).

    Returns
    -------
    data : numpy.ndarray
        Array of shape ``(n_frames, 3)`` with dtype ``'<i2'``.
    n_channels : int
        Number of channels kept (3).

    Raises
    ------
    ValueError
        If the number of samples in the file is not a multiple of
        ``n_channels``.
    """
    # Little-endian signed 16-bit; switch to '<u2' if the file turns out
    # to be unsigned.
    dtype = np.dtype('<i2')

    array = np.fromfile(dat_path, dtype=dtype)
    if array.size % n_channels != 0:
        raise ValueError(
            f"{dat_path}: {array.size} samples cannot be divided evenly "
            f"into {n_channels} channels"
        )

    # Drop the unused channels
    data = array.reshape(-1, n_channels)[:, [0, 1, eeg_ch_idx]]

    return data, data.shape[1]


def convert_to_Volt(data, dtype=np.float16, chunk=50_000_000):
    """Scale raw integer samples by ``5 / 2**15 / 10**3``, chunk by chunk.

    The scaling presumably corresponds to a +/-5 V full-scale int16 ADC
    and a x1000 gain - confirm against the acquisition settings.

    The input is too large to convert in one expression, so it is
    processed ``chunk`` rows at a time.  Each chunk is computed in at
    least float32 and rounded once into the output dtype; doing the
    arithmetic in float16 would round the int16 input before scaling.

    Parameters
    ----------
    data : numpy.ndarray
        Integer samples; chunks are taken along the first axis.
    dtype : numpy dtype, optional
        Dtype of the returned array (default ``numpy.float16``).
    chunk : int, optional
        Number of rows converted per step.

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``data`` and dtype ``dtype``.
    """
    work_dtype = np.result_type(np.float32, dtype)
    factor = work_dtype.type(5.0 / (1 << 15) / (10 ** 3))

    data_v = np.empty_like(data, dtype=dtype)
    # Chunks are row slices, so the loop bound is the number of rows
    # (not the total number of elements).
    n_rows = data.shape[0]
    for start in range(0, n_rows, chunk):
        end = min(start + chunk, n_rows)
        data_v[start:end] = data[start:end].astype(work_dtype) * factor

    return data_v



def rows_for_2gb_segment(
    n_ch: int,
    sampling_rate: int,
    write_dtype: np.dtype = np.dtype(np.float32),
    target_bytes: int = 2 * 1024**3,
    header_margin: int = 16 * 1024**2,
    round_sec: int = 5,
) -> int:
    """Return how many rows (samples per channel) fit into one segment.

    The byte budget is ``target_bytes`` minus ``header_margin`` (room for
    header/metadata overhead).  The result is rounded down to a whole
    number of ``round_sec``-second blocks.

    Parameters
    ----------
    n_ch : int
        Number of channels per row.
    sampling_rate : int
        Rows per second.
    write_dtype : numpy.dtype
        Dtype used when writing; only its ``itemsize`` is used.
    target_bytes : int
        Size limit of one segment in bytes.
    header_margin : int
        Bytes reserved out of ``target_bytes``.
    round_sec : int
        Block length in seconds that the result is a multiple of.
    """
    rows_per_block = int(sampling_rate * round_sec)
    bytes_per_block = n_ch * write_dtype.itemsize * rows_per_block

    budget = target_bytes - header_margin
    whole_blocks = budget // bytes_per_block

    return whole_blocks * rows_per_block


def create_and_save_mneRaw(data, save_path, seg_len):
    """Split ``data`` into segments and save each one as a FIF file.

    Relies on the module-level ``mneInfo`` (channel description) and
    ``SAMPLING_RATE`` (used only for the progress message).

    Parameters
    ----------
    data : np.ndarray
        Array of shape ``(n_samples, n_channels)``.
    save_path : str or Path
        Output path prefix; segment ``k`` is written to ``{save_path}-{k}.fif``.
    seg_len : int
        Maximum number of samples per segment.

    Returns
    -------
    list of Path
        Paths of the files written, in segment order.
    """
    # Transpose to (n_channels, n_samples), the layout passed to RawArray.
    tempA = data.transpose()
    n_t = tempA.shape[1]

    # Make sure the target directory exists so a fresh data folder does not
    # abort the export on the first save.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    saved_paths = []
    for part, start in enumerate(range(0, n_t, seg_len)):
        end = min(start + seg_len, n_t)

        seg = np.ascontiguousarray(tempA[:, start:end], dtype=np.float32)

        # first_samp records the segment's offset within the full recording.
        raw = mne.io.RawArray(seg, mneInfo, first_samp=start)

        f_path = Path(f"{save_path}-{part}.fif")
        raw.save(f_path, overwrite=True, fmt="single")
        print(f"Saved {f_path}  |  samples={seg.shape[1]}  seconds≈{seg.shape[1]/SAMPLING_RATE:.1f}")

        saved_paths.append(f_path)

        # Release the segment before building the next one to limit peak memory.
        del raw, seg
        gc.collect()
    return saved_paths



In [None]:
# Root directory of the recording; later cells read inputs from it and write outputs under it.
DATA_FOLDER = Path("/home/data/sleepStateExperiment_from_video/")

In [None]:
# Convert the binary recording into per-EEG-channel FIF segment files.
meta_path = DATA_FOLDER / "20250917-001.meta"
# Load the recording metadata (key/value pairs parsed by load_meta).
meta = load_meta(meta_path)
START_TS = meta["TimeStamp of the start of recording (computer clock - ms)"]
END_TS = meta["TimeStamp of the end of recording (computer clock - ms)"]
SAMPLING_RATE = int(meta["Sampling rate"])
FILE_PATH = Path(meta["Filename"])

# Load data.
# Channel layout used here:
#   channels 1-2 (columns 0, 1) are the two EMG channels
#   channels 3-8 (columns 2..7) are the EEG channels

EEG_CH_NAMES = [
    "Skull-1", "Skull-2",
    "HPC-1", "HPC-2",
    "S1-1", "S1-2"
]
EEG_CH_IDX = list(range(2, 8))
# Only the file name from the metadata is used; the file is looked up in DATA_FOLDER.
dat_path = Path(DATA_FOLDER, FILE_PATH.name)
# One pass per EEG channel: each output holds EMG-1, EMG-2 and that single EEG channel.
for ch_idx, ch_name in zip(EEG_CH_IDX, EEG_CH_NAMES):

    # The whole file is re-read on every iteration; only three columns are kept.
    data, N_CHANNELS = load_data(dat_path, ch_idx)

    # convert to Volt
    data_v = convert_to_Volt(data)

    # create an mne object
    CHANNEL_NAMES = ["EMG-1", "EMG-2", ch_name]
    CHANNEL_TYPES = ["emg"] * 2 + ["eeg"] * 1
    MNE_DIR = Path(DATA_FOLDER, "mneRaw")  # not used in this cell
    # create an mne.Info object (read as a global by create_and_save_mneRaw)
    mneInfo = mne.create_info(CHANNEL_NAMES, SAMPLING_RATE, ch_types=CHANNEL_TYPES)

    # create an mne.io.Raw object; outputs are named "<stem><channel>-<part>.fif"
    save_path = Path(DATA_FOLDER, "analysis_tmp", FILE_PATH.stem + ch_name)
    seg_len = rows_for_2gb_segment(N_CHANNELS, SAMPLING_RATE)
    create_and_save_mneRaw(data_v, save_path, seg_len)



## 6-2. Preprocessing

In [None]:
DATA_FOLDER = Path("/home/data/sleepStateExperiment_from_video/")
def apply_band_path_filter(mneR:mne.io.Raw):
    """Return a band-pass filtered copy of ``mneR``.

    An 8th-order Butterworth IIR band-pass (second-order sections, zero phase)
    is designed for each channel type: 0.3-120 Hz for the "eeg" picks and
    30-9999 Hz for the "emg" picks. The input object is copied first, so
    ``mneR`` itself is left untouched.
    """
    sampling_rate = mneR.info["sfreq"]

    def _design(band):
        # f_stop is not used when 'order' is given in iir_params.
        return mne.filter.construct_iir_filter(
            iir_params=dict(order=8, ftype='butter', output='sos'),
            f_pass=band,
            f_stop=None,
            sfreq=sampling_rate,
            btype='bandpass',
            phase='zero',
            return_copy=False
        )

    eeg_design = _design([0.3, 120])
    emg_design = _design([30, 9999])

    # Arguments shared by both filter calls. The filter_length,
    # *_trans_bandwidth, fir_window, fir_design and pad entries are FIR
    # settings kept as in the original call; the method here is 'iir'.
    common = dict(
        l_freq=None, h_freq=None,
        filter_length='auto',
        l_trans_bandwidth='auto', h_trans_bandwidth='auto',
        n_jobs=24,
        method='iir',
        phase='zero',
        fir_window='hamming', fir_design='firwin',
        skip_by_annotation=('edge', 'bad_acq_skip'),
        pad='reflect_limited',
        verbose=None
    )

    # EEG channels first, then EMG channels, each with its own design.
    filtered = mneR.copy().filter(picks="eeg", iir_params=eeg_design, **common)
    filtered = filtered.filter(picks="emg", iir_params=emg_design, **common)
    return filtered


def apply_notch_filter(mneR):
    """Remove the 50 Hz line component from the EEG and EMG channels.

    Calls ``mneR.notch_filter`` with the ``spectrum_fit`` method and returns
    its result. The notch frequency is fixed at 50 Hz (no automatic detection,
    no harmonics).
    NOTE(review): ``notch_filter`` is called on ``mneR`` directly, without a
    copy -- confirm the caller does not need the un-notched data afterwards.
    """
    notch_kwargs = dict(
        freqs=50,                  # fixed line frequency
        picks=["eeg", "emg"],
        method="spectrum_fit",
        filter_length="10s",       # window length used for the spectral fit
        mt_bandwidth=3.0,          # multitaper bandwidth
        p_value=0.05,
        n_jobs=24
    )
    return mneR.notch_filter(**notch_kwargs)




In [None]:
# Band-pass + notch filter every raw FIF segment and store the result under
# analysis_tmp/band_notch, skipping segments that were already processed.
mne_file_list = glob.glob(str(Path(DATA_FOLDER, "analysis_tmp", "*.fif")))
mne_file_list = sorted(mne_file_list)
for mneR_path in mne_file_list:
    mneR_path = Path(mneR_path)
    # Output keeps the input file name inside the band_notch sub-folder.
    f_path = Path(
        mneR_path.parent.parent,
        "analysis_tmp", "band_notch", mneR_path.name)
    # Skip if the filtered file already exists.
    if os.path.exists(f_path):
        print("continue: ", f_path)
        continue
    # Create the output folder on first use so the save below cannot fail
    # merely because band_notch has not been created yet.
    f_path.parent.mkdir(parents=True, exist_ok=True)

    # Load the file fully into memory (filtering needs preloaded data).
    mneR = mne.io.read_raw_fif(mneR_path, preload=True)
    mneI = mneR.info.copy()

    mneR_butter = apply_band_path_filter(mneR)
    mneR_notch = apply_notch_filter(mneR_butter)

    mneR_notch.save(f_path, overwrite=True)



## 6-3. Sleepens
### 6-3-1. Run Sleep Ensemble
See `python/sleepens/mySleepEns.py`.

### 6-3-2. Check results

In [None]:
# Plot the filtered signal of the first segment ("-0") of every channel
# together with the matching predicted states.
# To develop faster, use a small file, e.g.
    # Path(DATA_FOLDER, "mneRaw", "20250917-001-0.fif")
file_name = "20250917-001-0{}.fif"  # not used below

orig_path_list = glob.glob(
    str(Path(DATA_FOLDER,"analysis_tmp", "band_notch", "20250917-001*-*-0.fif"))
)
ens_path_list = glob.glob(
    str(Path(DATA_FOLDER, "sleepEnsOutput", "20250917-001*-*-0-predictions.fif"))
)

# Sort both lists so that zip() pairs each signal file with its prediction file.
# NOTE(review): the pairing is by position only -- a missing file on either
# side silently shifts the pairs; verify with the printed names.
orig_path_list.sort()
ens_path_list.sort()

for orig_path, mneR_ens_path in zip(orig_path_list, ens_path_list):
    print(orig_path.split("/")[-1], mneR_ens_path.split("/")[-1])
    mneR_orig = mne.io.read_raw_fif(
        orig_path
    )
    mneR_ens = mne.io.read_raw_fif(
        mneR_ens_path
    )
    # mneI_ens = mneR_ens.info.copy()
    # Keep only the "state" channel of the prediction file.
    state = mneR_ens.pick("state")
    display_waveforms_with_states_simple(
        mneR_orig,
        state,
        span=(450, 600),
        unit="time"
    )

## 6-4. Sampling

In [None]:
def sample_intervals(
    total_time_s: float,
    segment_len: int = 180,
    step: int = 5,
    n_segments: int = 10,
    seed=None,
):
    """Randomly pick non-overlapping fixed-length intervals from a recording.

    Candidate start times are ``0, step, 2*step, ...`` up to the last start
    for which a full segment still fits. They are shuffled and accepted
    greedily as long as they are at least ``segment_len`` seconds away from
    every start already accepted.

    Parameters
    ----------
    total_time_s : float
        Total duration available, in seconds.
    segment_len : int
        Length of each interval in seconds (default 180, i.e. 3 min).
    step : int
        Grid spacing of candidate start times in seconds (default 5).
    n_segments : int
        Maximum number of intervals to return (default 10).
    seed : int or None
        If given, a private ``random.Random(seed)`` is used so the selection
        is reproducible; if None, the global ``random`` state is used.

    Returns
    -------
    list of (float, float)
        ``(start, end)`` pairs sorted by start. Fewer than ``n_segments``
        pairs are returned when not enough non-overlapping candidates exist;
        an empty list when the recording is shorter than one segment.
    """
    # Imported here so the function does not depend on an earlier cell.
    import random

    possible_starts = list(range(0, int(total_time_s - segment_len) + 1, step))
    if seed is None:
        random.shuffle(possible_starts)
    else:
        random.Random(seed).shuffle(possible_starts)

    selected = []
    for s in possible_starts:
        if len(selected) >= n_segments:
            break
        # Accept only starts whose segment cannot overlap an accepted one.
        if all(abs(s - prev_s) >= segment_len for prev_s in selected):
            selected.append(s)

    selected.sort()
    return [(float(s), float(s + segment_len)) for s in selected]

In [None]:
# Draw the random sample intervals once and persist them, so the later cells
# that cut the signal files and the prediction files use the same spans.
import csv  # imported here because this cell is the only user

total_time = 7 * 3600 + 34 * 60  # 7 h 34 min, in seconds
samples = sample_intervals(total_time)
csv_path = Path(DATA_FOLDER, "sleepEns_sampling", "sample_meta.csv")
# Create the output folder on first use.
csv_path.parent.mkdir(parents=True, exist_ok=True)
with open(csv_path, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["id", "start", "end"])
    for i, (s, e) in enumerate(samples):
        writer.writerow([i, s, e])
print(samples)

In [None]:

# Cut the sampled intervals out of the filtered signal files.
file_name = "20250917-001{}.fif"

path_list = glob.glob(
    str(Path(DATA_FOLDER, "analysis_tmp", "band_notch", file_name.format("*-*-*")))
)
key = "orig"
# Number of trailing characters dropped from a file name to get the channel
# prefix (removes "-<part>.fif"; assumes a single-digit part index).
slice_idx = -6

samples_df = pd.read_csv(DATA_FOLDER / "sleepEns_sampling/sample_meta.csv")
samples = [tuple(row) for row in samples_df[["start", "end"]].to_numpy()]

path_list.sort()
print("path_list: ")
print(path_list)

# The sorted list is consumed as 6 channels x 4 consecutive segment files.
for i in range(6):
    paths = path_list[4 * i:4 * i + 4]

    raws = [mne.io.read_raw_fif(f, preload=True, verbose=True) for f in paths]

    raw = mne.concatenate_raws(raws, verbose=True)
    del raws

    # save sampled span
    name = paths[0].split("/")[-1][:slice_idx]
    # A separate index name is used so the outer loop variable is not shadowed.
    for sample_idx, (tmin, tmax) in enumerate(samples):
        sub = raw.copy().crop(tmin=float(tmin), tmax=float(tmax), include_tmax=True)
        sub = sub.resample(sfreq=20000)
        out_path = Path(DATA_FOLDER, "sleepEns_sampling", f"{key}-{name}-sample{sample_idx}_raw.fif")
        sub.save(out_path.as_posix(), overwrite=True, verbose=True)
        # Freed inside the loop: no NameError when `samples` is empty.
        del sub
    del raw


In [None]:

# Cut the sampled intervals out of the prediction files.
file_name = "20250917-001{}.fif"
path_list = glob.glob(
    str(Path(DATA_FOLDER, "sleepEnsOutput", file_name.format("*-*-*-predictions")))
)
key = "pred"
# Number of trailing characters dropped from a file name to get the channel
# prefix (removes "-<part>-predictions.fif"; assumes a single-digit part index).
slice_idx = -18

samples_df = pd.read_csv(DATA_FOLDER / "sleepEns_sampling/sample_meta.csv")
samples = [tuple(row) for row in samples_df[["start", "end"]].to_numpy()]

path_list.sort()
print("path_list: ")
print(path_list)

# The sorted list is consumed as 6 channels x 4 consecutive segment files.
for i in range(6):
    paths = path_list[4 * i:4 * i + 4]

    raws = [mne.io.read_raw_fif(f, preload=True, verbose=True) for f in paths]

    raw = mne.concatenate_raws(raws, verbose=True)
    del raws

    # save sampled span
    name = paths[0].split("/")[-1][:slice_idx]
    # A separate index name is used so the outer loop variable is not shadowed.
    for sample_idx, (tmin, tmax) in enumerate(samples):
        sub = raw.copy().crop(tmin=float(tmin), tmax=float(tmax), include_tmax=True)
        sub = sub.resample(sfreq=20000)
        out_path = Path(DATA_FOLDER, "sleepEns_sampling", f"{key}-{name}-sample{sample_idx}_raw.fif")
        sub.save(out_path.as_posix(), overwrite=True, verbose=True)
        # Freed inside the loop: no NameError when `samples` is empty.
        del sub
    del raw

In [None]:
# Hard-coded state mapping: state name -> integer code, and its inverse.
SCORE_MAP = {'AW': 0, 'QW': 1, 'NREM': 2, 'REM': 3}
INV_SCORE_MAP = {v: k for k, v in SCORE_MAP.items()}
# RGB colours (0-255) per state code.
_STATE_COLORS = {
    0: (162, 53, 47),   # AW
    1: (236, 197, 72),  # QW
    2: (81, 158, 89),   # NREM
    3: (52, 71, 113),   # REM
}

# Normalize to 0–1 for Matplotlib
STATE_COLORS = {k: tuple(np.array(v, dtype=float) / 255.0) for k, v in _STATE_COLORS.items()}



def display_waveforms_with_states_simple_for_sample(
    raw: mne.io.Raw,
    labels_raw: mne.io.Raw,     # 1-channel Raw; integer labels 0..3 per sample
    title_text: str = "",
    start: float = 0.0,         # time (seconds) displayed for the first sample
    alpha_band: float = 0.15
):
    """
    Plot multi-channel waveforms and overlay 4-state categorical bands over the FULL duration.
    X-axis is always HH:MM:SS. The first data point is displayed at time `start` (seconds).

    Parameters
    ----------
    raw : mne.io.Raw
        Source time-series to plot.
    labels_raw : mne.io.Raw
        1-channel Raw containing per-sample integer labels (0..3).
        If sfreq differs from `raw`, it will be resampled to match.
    title_text : str
        Title for the figure.
    start : float
        Start time of the x-axis in seconds; sample 0 of the data is shown at
        this time.
    alpha_band : float
        Alpha value for the background state bands.

    Raises
    ------
    ValueError
        If `labels_raw` does not have exactly one channel.
    """
    # Imported locally so the function does not rely on another cell having
    # imported these names.
    from matplotlib.patches import Patch
    from matplotlib.ticker import FixedLocator, FuncFormatter

    # Input check: labels_raw must be single-channel.
    if labels_raw.info["nchan"] != 1:
        raise ValueError("labels_raw must have exactly 1 channel (integer labels 0..3).")

    # Match the sampling frequency of the labels to the signal (on a copy).
    sfreq = raw.info["sfreq"]
    labs = labels_raw.copy()
    if not np.isclose(labs.info["sfreq"], sfreq):
        labs.resample(sfreq, npad="auto")

    # Extract the label vector; round and clip so every value is a valid code.
    lab = labs.get_data()[0]
    lab = np.rint(lab).astype(int)
    lab = np.clip(lab, 0, 3)

    # Match the length to raw by nearest-index lookup.
    if lab.shape[0] != raw.n_times:
        idx = np.linspace(0, lab.shape[0] - 1, raw.n_times)
        lab = lab[np.rint(idx).astype(int)]

    # Use the whole recording (there is no span selection in this variant).
    data = raw.get_data()
    # The x-axis is in seconds, shifted so that the first sample sits at `start`.
    times = raw.times + float(start)
    labs_win = lab

    def contiguous_segments(lbl: np.ndarray):
        # Return (start_idx, end_idx_exclusive, label) for each run of equal labels.
        if lbl.size == 0:
            return []
        edges = np.where(np.diff(lbl) != 0)[0] + 1
        idxs = np.r_[0, edges, lbl.size]
        return [(idxs[i], idxs[i+1], int(lbl[idxs[i]])) for i in range(len(idxs)-1)]

    segs = contiguous_segments(labs_win)

    # Figure: one row per channel, stacked without vertical gaps.
    n_ch = raw.info["nchan"]
    fig_h = max(1.6, n_ch / 1.5)
    fig, axes = plt.subplots(
        nrows=n_ch, ncols=1,
        figsize=(8.5, fig_h),
        sharex=True, sharey=True,
        gridspec_kw={'height_ratios': [1]*n_ch, 'hspace': 0.0}
    )
    if n_ch == 1:
        axes = [axes]

    def _sec_to_hhmmss(x, pos):
        # Tick formatter: seconds -> "HH:MM:SS" (negative values shown as 0).
        s = int(max(0, round(x)))
        h = s // 3600
        m = (s % 3600) // 60
        ss = s % 60
        return f"{h:02d}:{m:02d}:{ss:02d}"

    # Axis limits and tick positions. The ticks are derived inside the
    # non-empty branch: indexing times[0] before this guard made the
    # empty-data fallback unreachable.
    tick_step = 20.0
    if times.size == 0 or data.size == 0:
        xlim = (0, 0); ylim = (-1, 1)
        ticks = np.array([])
    else:
        xlim = (times[0], times[-1])
        dmin, dmax = float(np.nanmin(data)), float(np.nanmax(data))
        if dmin == dmax:
            dmin, dmax = dmin - 1.0, dmax + 1.0
        ylim = (dmin, dmax)
        # Ticks every 20 s, counted from the displayed start time.
        ticks = np.arange(times[0], times[-1] + 1e-9, tick_step)
    for ax in axes:
        ax.xaxis.set_major_formatter(FuncFormatter(_sec_to_hhmmss))
        ax.xaxis.set_major_locator(FixedLocator(ticks))
        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)

    # Background state bands and waveforms.
    for i, ax in enumerate(axes):
        for s, e, lab_val in segs:
            x0 = times[s]
            x1 = times[e-1] if e-1 < times.size else times[-1]
            ax.axvspan(x0, x1, color=STATE_COLORS[lab_val], alpha=alpha_band)

        ax.plot(times, data[i], linewidth=0.35, color='k')

        ax.spines.top.set_visible(False)
        ax.spines.right.set_visible(False)
        ax.set_yticks([])
        ax.set_ylabel(
            raw.ch_names[i],
            rotation=0, ha='right', va='center',
            rotation_mode='anchor'
        )
        if i == 0 and title_text:
            ax.set_title(title_text)
        # Only the bottom row shows the time axis.
        if i == n_ch - 1:
            ax.set_xlabel("Time [HH:MM:SS]")
        else:
            ax.spines.bottom.set_visible(False)
            ax.get_xaxis().set_visible(False)

    # Legend: one patch per state, placed to the right of the first row.
    handles = [Patch(color=STATE_COLORS[k], alpha=alpha_band, label=INV_SCORE_MAP[k]) for k in range(4)]
    plt.subplots_adjust(right=0.80)
    axes[0].legend(
        handles=handles,
        loc="upper left",
        bbox_to_anchor=(1.02, 1.0),
        borderaxespad=0.0,
        frameon=False,
        title="State"
    )
    plt.show()

In [None]:
# Open one sampled signal file and its matching prediction file (HPC-1, sample 5).
raw_path = DATA_FOLDER / "sleepEns_sampling/orig-20250917-001HPC-1-sample5_raw.fif"

raw = mne.io.read_raw_fif(raw_path)

label_path = DATA_FOLDER / "sleepEns_sampling/pred-20250917-001HPC-1-sample5_raw.fif"
label_raw = mne.io.read_raw_fif(label_path)



In [None]:
# Plot every sampled interval for every EEG channel with the predicted states.
eeg_type_list = ["HPC-1", "HPC-2", "S1-1", "S1-2", "Skull-1", "Skull-2"]
# File-name template: {0} = "orig"/"pred", {1} = channel name, {2} = sample index.
base_path = str(DATA_FOLDER / "sleepEns_sampling/{0}-20250917-001{1}-sample{2}_raw.fif")

sample_meta_path = DATA_FOLDER / "sleepEns_sampling/sample_meta.csv"
sample_meta_df = pd.read_csv(sample_meta_path)
for sample_idx in range(10):
    for eeg_type in eeg_type_list:
        # Column 1 of the metadata table is the interval start.
        start_time = sample_meta_df.iloc[sample_idx, 1]

        # The figure title is derived from the same file-name template.
        templated = base_path.format("", f", EEG: {eeg_type},", ": "+str(sample_idx))
        _title = templated.split("/")[-1][1:].replace("_raw.fif","")
        title = "DATE: " + _title.replace("-sample", " sample")

        data_path, label_path = (
            base_path.format(kind, eeg_type, sample_idx) for kind in ("orig", "pred")
        )

        data_raw = mne.io.read_raw_fif(data_path, verbose=False)
        label_raw = mne.io.read_raw_fif(label_path, verbose=False)

        display_waveforms_with_states_simple_for_sample(
            data_raw,
            label_raw.pick("state"),
            start=start_time,
            title_text=title
        )
# Drop the last pair of Raw objects once all figures are drawn.
del data_raw, label_raw,

In [None]:
# `data_raw` is deleted at the end of the plotting cell above, so reading it
# here raised NameError. Re-open the last plotted signal file instead.
mne.io.read_raw_fif(data_path, verbose=False).info["sfreq"]

# 7. Check results
## 7-1. Check features for one file
Data are saved in `DATA_FOLDER / spectral_band`.

In [None]:
# Inspect the EEG band file of one segment and turn it into a DataFrame.
data_path = DATA_FOLDER / "spectral_band/20250917-001HPC-1-0_eeg_bands.npz"
# NOTE: allow_pickle=True can run arbitrary code from the file; only use it
# on files produced by this project.
npz = np.load(data_path, allow_pickle=True)
print("Keys:", list(npz.keys()))

# Pull the three arrays out of the archive once.
spec, freqs, meta = (npz[name] for name in ("spec", "freqs", "meta"))

print("spec shape:", spec.shape)
print("freqs:", freqs)
print("meta (head):", meta[:10])  # only the first 10 entries

# spec -> DataFrame, with the entries of `freqs` as column labels.
df = pd.DataFrame(spec, columns=freqs)
for part in (df.head(), df.tail()):
    display(part)

In [None]:


# Line plot of every column of `df` against its row index.
fig, ax = plt.subplots(figsize=(10, 4))

# Plot each band as a time series.
df.plot(ax=ax, linewidth=0.8)

sns.set(style='ticks', context='talk')  # quick global style setting
sns.despine()  # show only the left and bottom axes
# Axis labels and title.
ax.set_xlabel("Epoch index")
ax.set_ylabel("Power (a.u.)")
ax.set_title("Spectral Band Power Over Time")
# Legend and layout adjustment.
ax.legend(title="Band")
plt.tight_layout()
sns.move_legend(ax, bbox_to_anchor=(1, 0.5), loc='center left', borderaxespad=1)  # legend position
# Show the figure (for checking in the notebook).
plt.show()


In [None]:
# Open the prediction file of one segment and list its channel types.
DATA_FOLDER = Path("/home/data/sleepStateExperiment_from_video/")
raw = mne.io.read_raw_fif(DATA_FOLDER / "sleepEnsOutput/20250917-001HPC-1-0-predictions.fif")
raw.get_channel_types()

In [None]:
# Sample index of every annotation onset.
starts = raw.time_as_index(raw.annotations.onset)

# Read the event code from the "state" channel at each onset (code 0 included).
state = raw.get_data(picks="state")[0]
codes = state[starts].astype(int)

# MNE-style events array: [sample, 0, code].
events = np.column_stack([starts, np.zeros_like(starts), codes])

# One-column DataFrame of event codes (index is 0, 1, 2, ...).
df_event = pd.DataFrame({
    "event": events[:, 2].astype(int)
})

df_event.head()


In [None]:
print(df_event.shape)
df_event.value_counts()

In [None]:
# Load the EEG band file of one segment into `df_eeg`.
data_path = DATA_FOLDER / "spectral_band/20250917-001HPC-1-0_eeg_bands.npz"
# NOTE: allow_pickle=True can run arbitrary code from the file; only use it
# on files produced by this project.
npz = np.load(data_path, allow_pickle=True)
print("Keys:", list(npz.keys()))

# Pull the three arrays out of the archive once.
spec, freqs, meta = (npz[name] for name in ("spec", "freqs", "meta"))

print("spec shape:", spec.shape)
print("freqs:", freqs)
print("meta (head):", meta[:10])  # only the first 10 entries

# spec -> DataFrame, with the entries of `freqs` as column labels.
df_eeg = pd.DataFrame(spec, columns=freqs)
for part in (df_eeg.head(), df_eeg.tail()):
    display(part)

In [None]:
# Load the EMG entropy file of one segment into `df_emg`.
DATA_FOLDER = Path("/home/data/sleepStateExperiment_from_video/")
data_path = DATA_FOLDER / "spectral_band/20250917-001HPC-1-0_emg_entropy.npz"
# NOTE: allow_pickle=True can run arbitrary code from the file; only use it
# on files produced by this project.
npz = np.load(data_path, allow_pickle=True)
print("Keys:", list(npz.keys()))

# Pull the three arrays out of the archive once.
spec, freqs, meta = (npz[name] for name in ("spec", "freqs", "meta"))

print("spec shape:", spec.shape)
print("freqs:", freqs)
print("meta (head):", meta[:10])  # only the first 10 entries

# spec -> DataFrame, with the entries of `freqs` as column labels.
df_emg = pd.DataFrame(spec, columns=freqs)
for part in (df_emg.head(), df_emg.tail()):
    display(part)

In [None]:
df_orig = pd.concat([df_eeg, df_emg, df_event],axis=1)

In [None]:
df_orig.head()

In [None]:
# Min-max normalise the features, save per-state statistics as CSV, and draw
# one bar chart (mean with SD error bars) per state.
from sklearn.preprocessing import MinMaxScaler

df = df_orig.copy()

# State code -> name mapping.
event_map_id2name = {0: 'AW', 1: 'QW', 2: 'NREM', 3: 'REM'}

# Feature columns to analyse.
features = ["DELTA", "THETA", "ALPHA", "BETA", "EMG ENTROPY"]

scaler = MinMaxScaler()

# Each feature column is rescaled independently over all rows.
df[features] = scaler.fit_transform(df[features])

# Output directory (relative to the current working directory).
outdir = "fig_event_stats"
os.makedirs(outdir, exist_ok=True)

# Statistics (mean, variance, SD) per event x feature, saved as CSV.
stats_mean = df.groupby("event")[features].mean()
stats_var  = df.groupby("event")[features].var(ddof=1)
stats_std  = df.groupby("event")[features].std(ddof=1)

stats_mean.to_csv(os.path.join(outdir, "stats_mean.csv"))
stats_var.to_csv(os.path.join(outdir, "stats_variance.csv"))
stats_std.to_csv(os.path.join(outdir, "stats_std.csv"))

sns.set(style='ticks', context='poster')  # quick global style setting
# One figure per event code present in the data.
for ev_id in sorted(df["event"].unique()):
    ev_name = event_map_id2name.get(int(ev_id), str(ev_id))
    sub = df[df["event"] == ev_id]

    # Mean and standard deviation of each feature within this event.
    means = sub[features].mean().values
    stds  = sub[features].std(ddof=1).values

    # Bar chart with error bars.
    fig, ax = plt.subplots(figsize=(8, 5))
    x = np.arange(len(features))
    bars = ax.bar(x, means, yerr=stds, capsize=4)

    sns.despine()  # show only the left and bottom axes

    # Cosmetics.
    ax.set_xticks(x)
    ax.set_xticklabels(features, rotation=0)
    ax.set_ylabel("Mean ± SD")
    # NOTE(review): the title text still contains a placeholder.
    ax.set_title(f"{ev_name} title??")
    ax.grid(axis="y", linestyle="--", alpha=0.3)

    fig.tight_layout()
    fig.show()
    


In [None]:
# Grouped bar chart of the features per state, with a one-way ANOVA per
# feature. Reads the DataFrame `df` left in scope by an earlier cell.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
import itertools

# State mapping and colours (same values as defined earlier in the notebook).
SCORE_MAP = {'AW': 0, 'QW': 1, 'NREM': 2, 'REM': 3}
INV_SCORE_MAP = {v: k for k, v in SCORE_MAP.items()}
_STATE_COLORS = {
    0: (162, 53, 47),   # AW
    1: (236, 197, 72),  # QW
    2: (81, 158, 89),   # NREM
    3: (52, 71, 113),   # REM
}
STATE_COLORS = {k: tuple(np.array(v, dtype=float)/255.0) for k, v in _STATE_COLORS.items()}

# Feature columns.
features = ["DELTA", "THETA", "ALPHA", "BETA", "EMG ENTROPY"]

# =========================================
# 1) Mean and standard deviation per event
# =========================================
stats_mean = df.groupby("event")[features].mean()
stats_std  = df.groupby("event")[features].std(ddof=1)

# =========================================
# 2) Quick significance check (one-way ANOVA across the four states)
#    Can be replaced by multiple-comparison tests for a rigorous analysis.
# =========================================
pvals = {}
for feat in features:
    groups = [df.loc[df["event"] == ev, feat].values for ev in sorted(SCORE_MAP.values())]
    _, p = stats.f_oneway(*groups)
    pvals[feat] = p

# =========================================
# 3) Plot (four series side by side in one figure)
# =========================================
fig, ax = plt.subplots(figsize=(9, 5))
x = np.arange(len(features))
width = 0.18  # width of each bar
offsets = np.linspace(-1.5*width, 1.5*width, 4)  # horizontal offset of the 4 states

for i, ev_id in enumerate(sorted(SCORE_MAP.values())):
    means = stats_mean.loc[ev_id]
    stds  = stats_std.loc[ev_id]
    ax.bar(x + offsets[i],
           means,
           yerr=stds,
           width=width,
           capsize=3,
           label=INV_SCORE_MAP[ev_id],
           color=STATE_COLORS[ev_id],
           edgecolor='black',
           alpha=0.9)

# =========================================
# 4) Mark features whose ANOVA p-value is below 0.05 with "*"
# =========================================
for i, feat in enumerate(features):
    if pvals[feat] < 0.05:
        ax.text(x[i],
                max(stats_mean[feat]) + max(stats_std[feat])*0.5,
                '*',
                ha='center', va='bottom', fontsize=14, color='black')

# =========================================
# 5) Cosmetics
# =========================================
ax.set_xticks(x)
ax.set_xticklabels(features)
ax.set_ylabel("Mean ± SD (normalized)")
ax.set_title("Feature comparison across sleep states")
ax.legend(title="Event")
ax.grid(axis="y", linestyle="--", alpha=0.3)
fig.tight_layout()

sns.despine()  # show only the left and bottom axes
sns.move_legend(ax, bbox_to_anchor=(1, 0.5), loc='center left', borderaxespad=1)  # legend position

# Show the figure (nothing is saved here).
plt.show()


## 7-2. Check with all data

### 7-2-1. Create merged data

In [None]:
# For every prediction file, build one DataFrame holding the EEG band columns,
# the EMG entropy columns and the predicted event code, and collect them in
# `df_list`.
DATA_FOLDER = Path("/home/data/sleepStateExperiment_from_video/")

files = sorted(glob.glob(str(DATA_FOLDER / "sleepEnsOutput/20250917-001*-predictions.fif")))
df_list = []
for pred_path in files:
    pred_path = Path(pred_path)
    # The feature files share the prediction file's name without "-predictions".
    segment_name = pred_path.stem.replace('-predictions','')
    eeg_path = DATA_FOLDER / f"spectral_band/{segment_name}_eeg_bands.npz"
    emg_path = DATA_FOLDER / f"spectral_band/{segment_name}_emg_entropy.npz"

    # Event codes: value of the "state" channel at each annotation onset
    # (code 0 included).
    raw = mne.io.read_raw_fif(pred_path)
    starts = raw.time_as_index(raw.annotations.onset)
    state = raw.get_data(picks="state")[0]
    codes = state[starts].astype(int)

    # MNE-style events array: [sample, 0, code].
    events = np.column_stack([starts, np.zeros_like(starts), codes])
    df_event = pd.DataFrame({
        "event": events[:, 2].astype(int)
    })

    # EEG band features.
    npz = np.load(eeg_path, allow_pickle=True)
    df_eeg = pd.DataFrame(npz["spec"], columns=npz["freqs"])

    # EMG entropy features.
    npz = np.load(emg_path, allow_pickle=True)
    df_emg = pd.DataFrame(npz["spec"], columns=npz["freqs"])

    # Column-wise concatenation, aligned on the row index.
    df_orig = pd.concat([df_eeg, df_emg, df_event],axis=1)
    df_list.append(df_orig)



In [None]:
for df in df_list:
    display(df.head())

In [None]:
# Tag every DataFrame with its EEG channel name.
# Assumes df_list holds 4 consecutive segment frames per channel, in this
# channel order -- TODO confirm against the sorted file list.
for i, eeg in enumerate(["HPC-1", "HPC-2", "S1-1", "S1-2", "Skull-1", "Skull-2"]):
    for df in df_list[4 * i:4 * (i + 1)]:
        df["eeg"] = eeg

In [None]:
df_list[0].head()

In [None]:
df_all = pd.concat(df_list, axis=0, ignore_index=True)
df_all

In [None]:
# Persist the merged table. The row index is written too, so reading the file
# back with default settings yields an extra unnamed first column.
df_all.to_csv(DATA_FOLDER / "spectral_band/concat/20250917-001.csv")
df_all.head()

In [None]:
# Work on a copy so df_all keeps the per-electrode names.
df_eeg = df_all.copy()

# Replace the values of the "eeg" column: merge the two electrodes of each
# site into one site label.
df_eeg["eeg"] = df_eeg["eeg"].replace({
    "HPC-1": "HPC",
    "HPC-2": "HPC",
    "S1-1": "S1",
    "S1-2": "S1",
    "Skull-1": "Skull",
    "Skull-2": "Skull"
})

display(df_eeg)


### 7-2-2. Violin plot, U test + multiple testing correction

In [None]:
# コメント: マッピングと色
SCORE_MAP = {'AW': 0, 'QW': 1, 'NR': 2, 'R': 3}
INV_SCORE_MAP = {v: k for k, v in SCORE_MAP.items()}

_STATE_COLORS = {
    0: (162, 53, 47),   # AW
    1: (236, 197, 72),  # QW
    2: (81, 158, 89),   # NREM
    3: (52, 71, 113),   # REM
}
STATE_COLORS = {k: tuple(np.array(v, dtype=float)/255.0) for k, v in _STATE_COLORS.items()}
palette = {str(ev): STATE_COLORS[ev] for ev in sorted(STATE_COLORS.keys())}  # コメント: intキー

features = ["ALPHA", "BETA", "DELTA", "THETA", "EMG ENTROPY"]
TITLE_MAP = {
    "ALPHA": "α",
    "BETA": "β",
    "THETA": "θ",
    "DELTA": "δ",
    "EMG ENTROPY": "EMG Entropy"
}

# コメント: アノテーション用ユーティリティ


def add_sig_bracket(ax, x1, x2, y, height=None):
    """Draw a significance bracket with a "*" between two x positions.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    x1, x2 : float
        X positions of the two bracket ends.
    y : float
        Y position of the bracket feet and of the "*" baseline.
    height : float or None
        Height of the bracket ticks above ``y``. Defaults to 1 % of ``|y|``.
        Using the absolute value keeps the bracket pointing upward when ``y``
        is negative (the former ``y * 1.01`` drew it downward).
    """
    if height is None:
        height = 0.01 * abs(y)
    top = y + height
    ax.plot([x1, x1, x2, x2], [y, top, top, y], lw=1.2, c="k")
    ax.text((x1 + x2)/2, y, "*", ha='center', va='bottom', fontsize=12, color='k')



# Violin plots (one row per EEG site, one column per feature) with pairwise
# Mann-Whitney U tests between states, Holm-corrected within each panel.
order = sorted(SCORE_MAP.values())




# Significance level passed to the Holm correction.
p_val = 0.01



# Prepare the data per EEG site.
eeg_groups = list(df_eeg.groupby("eeg"))

n_rows = len(eeg_groups)
n_cols = len(features)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(19, 6.5 * n_rows), sharey=False)
# Keep `axes` two-dimensional even with a single row.
if n_rows == 1:
    axes = np.expand_dims(axes, axis=0)


for row_i, (eeg, df) in enumerate(eeg_groups):
    for col_i, (ax, feature) in enumerate(zip(axes[row_i], features)):
        sub = df[["event", feature]]

        sns.violinplot(
            data=sub,
            x="event", y=feature,
            palette=palette, inner="box",
            ax=ax,
        )



        # --- U test (all pairwise comparisons) ---
        pairs, pvals = [], []

        unique_groups = sorted(sub["event"].unique())

        for i, g1 in enumerate(unique_groups):
            for j, g2 in enumerate(unique_groups):
                if j <= i:
                    continue
                vals1 = sub.loc[sub["event"] == g1, feature]
                vals2 = sub.loc[sub["event"] == g2, feature]
                _, p = mannwhitneyu(vals1, vals2, alternative="two-sided")
                pairs.append((g1, g2))
                pvals.append(p)

        # --- Multiple-testing correction (Holm) ---
        rej, pvals_corr, _, _ = multipletests(pvals, alpha=p_val, method="holm")

        # --- Collect the significant pairs ---
        sig_pairs = []
        for (g1, g2), p_corr, sig in zip(pairs, pvals_corr, rej):
            if sig:
                i, j = int(g1), int(g2)
                sig_pairs.append((min(i, j), max(i, j), p_corr))

        # --- Annotation placement: stack brackets above the data maximum ---
        y_max = np.nanmax(sub[feature].values)
        y_min = np.nanmin(sub[feature].values)
        span = (y_max - y_min) if np.isfinite(y_max - y_min) and (y_max - y_min) > 0 else (abs(y_max) + 1.0)
        base = y_max + 0.05 * span
        step = 0.08 * span

        # Narrow brackets are drawn first (lowest), wider ones above them.
        for k, (i, j, p_adj) in enumerate(sorted(sig_pairs, key=lambda x: (x[1]-x[0], x[0], x[1]))):
            x1 = order.index(i)
            x2 = order.index(j)
            y = base + k * step
            add_sig_bracket(ax, x1, x2, y)

        # === Axis titles and decoration ===
        # NOTE(review): labels assume all four states are present in `sub`.
        ax.set_xticklabels([INV_SCORE_MAP[i] for i in order])
        ax.set_xlabel("")  # no x-axis label needed

        ax.set_title(TITLE_MAP[feature])
        if col_i == 0:
            ax.set_ylabel(r"Power [$\mu V^2$/Hz]")
            # Row label: the EEG site name, written left of the first column.
            ax.text(-0.5, 0.5, eeg, transform=ax.transAxes, rotation=90,
            va='center', ha='right', fontweight='bold')
        elif col_i == 4:
            ax.set_ylabel("Entropy [a.u.]")
        else:
            ax.set_ylabel("")
        sns.despine(ax=ax)

# === Legend and whole-figure adjustment ===
sns.set(style='ticks', context='poster')
# NOTE(review): `sub` here is the frame of the last panel drawn, so the n in
# the legend is that panel's row count only.
legend_handles = [
    matplotlib.lines.Line2D([0], [0], linewidth=0, label=f'n={sub.shape[0]}, *Adjusted P < {p_val}'),
]
leg = fig.legend( handles=legend_handles)
fig.suptitle("EEG (HPC, S1, Skull) Frequency Bands (α, β, θ, δ) and EMG Entropy", y=1.01)

sns.move_legend(fig, bbox_to_anchor=(1, 1), loc='upper right', borderaxespad=1)
plt.tight_layout(w_pad=0)
plt.show()

### 7-2-3. Violin plot, Dunn test

In [None]:
# コメント: マッピングと色
SCORE_MAP = {'AW': 0, 'QW': 1, 'NR': 2, 'R': 3}  # short state labels used on the x-axis
INV_SCORE_MAP = {v: k for k, v in SCORE_MAP.items()}

# RGB colours (0-255) per state code, normalised to 0-1 below.
_STATE_COLORS = {
    0: (162, 53, 47),   # AW
    1: (236, 197, 72),  # QW
    2: (81, 158, 89),   # NREM
    3: (52, 71, 113),   # REM
}
STATE_COLORS = {k: tuple(np.array(v, dtype=float)/255.0) for k, v in _STATE_COLORS.items()}
# Palette keyed by the state code converted to a string ("0".."3").
palette = {str(ev): STATE_COLORS[ev] for ev in sorted(STATE_COLORS.keys())}

# Feature columns, in plotting order, and their panel titles.
features = ["ALPHA", "BETA", "DELTA", "THETA", "EMG ENTROPY"]
TITLE_MAP = {
    "ALPHA": "α",
    "BETA": "β",
    "THETA": "θ",
    "DELTA": "δ",
    "EMG ENTROPY": "EMG Entropy"
}

# コメント: アノテーション用ユーティリティ


def add_sig_bracket(ax, x1, x2, y, height=None):
    """Draw a significance bracket with a "*" between two x positions.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    x1, x2 : float
        X positions of the two bracket ends.
    y : float
        Y position of the bracket feet and of the "*" baseline.
    height : float or None
        Height of the bracket ticks above ``y``. Defaults to 1 % of ``|y|``.
        Using the absolute value keeps the bracket pointing upward when ``y``
        is negative (the former ``y * 1.01`` drew it downward).
    """
    if height is None:
        height = 0.01 * abs(y)
    top = y + height
    ax.plot([x1, x1, x2, x2], [y, top, top, y], lw=1.2, c="k")
    ax.text((x1 + x2)/2, y, "*", ha='center', va='bottom', fontsize=12, color='k')



# Violin plots (one row per EEG site, one column per feature) with Dunn's
# post-hoc test between states, Holm-adjusted within each panel.
order = sorted(SCORE_MAP.values())



# Threshold applied to the adjusted p-values.
p_val = 0.05



# Prepare the data per EEG site.
eeg_groups = list(df_eeg.groupby("eeg"))

n_rows = len(eeg_groups)
n_cols = len(features)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6.5 * n_rows), sharey=False)
# Keep `axes` two-dimensional even with a single row.
if n_rows == 1:
    axes = np.expand_dims(axes, axis=0)


for row_i, (eeg, df) in enumerate(eeg_groups):
    for col_i, (ax, feature) in enumerate(zip(axes[row_i], features)):
        sub = df[["event", feature]]

        sns.violinplot(
            data=sub,
            x="event", y=feature,
            palette=palette, inner="box",
            ax=ax,
        )



        # --- Dunn's test (Holm adjustment) ---
        dunn_res = sp.posthoc_dunn(
            sub, 
            val_col=feature, 
            group_col="event", 
            p_adjust="holm"
        )

        # --- Collect the significant pairs (upper triangle only) ---
        sig_pairs = []
        for i, g1 in enumerate(dunn_res.index):
            for j, g2 in enumerate(dunn_res.columns):
                if j <= i:
                    continue
                p_adj = dunn_res.iloc[i, j]
                if p_adj < p_val:
                    sig_pairs.append((int(g1), int(g2), p_adj))

        # --- Annotation placement: stack brackets above the data maximum ---
        y_max = np.nanmax(sub[feature].values)
        y_min = np.nanmin(sub[feature].values)
        span = (y_max - y_min) if np.isfinite(y_max - y_min) and (y_max - y_min) > 0 else (abs(y_max) + 1.0)
        base = y_max + 0.05 * span
        step = 0.08 * span

        # Narrow brackets are drawn first (lowest), wider ones above them.
        for k, (i, j, p_adj) in enumerate(sorted(sig_pairs, key=lambda x: (x[1]-x[0], x[0], x[1]))):
            x1 = order.index(i)
            x2 = order.index(j)
            y = base + k * step
            add_sig_bracket(ax, x1, x2, y)

        # === Axis titles and decoration ===
        # NOTE(review): labels assume all four states are present in `sub`.
        ax.set_xticklabels([INV_SCORE_MAP[i] for i in order])
        ax.set_xlabel("")  # no x-axis label needed

        # Feature titles only on the top row.
        if row_i == 0:
            ax.set_title(TITLE_MAP[feature])
        if col_i == 0:
            ax.set_ylabel(r"Power [$\mu V^2$/Hz]")
            # Row label: the EEG site name, written left of the first column.
            ax.text(-0.5, 0.5, eeg, transform=ax.transAxes, rotation=90,
            va='center', ha='right', fontweight='bold')
        elif col_i == 4:
            ax.set_ylabel("Entropy [a.u.]")
        else:
            ax.set_ylabel("")
        sns.despine(ax=ax)

# === Legend and whole-figure adjustment ===
sns.set(style='ticks', context='poster')
# NOTE(review): `sub` here is the frame of the last panel drawn, so the n in
# the legend is that panel's row count only.
legend_handles = [
    matplotlib.lines.Line2D([0], [0], linewidth=0, label=f'n={sub.shape[0]}, *Adjusted P < {p_val}'),
]
leg = fig.legend( handles=legend_handles)
fig.suptitle("EEG (HPC, S1, Skull) Frequency Bands (α, β, θ, δ) and EMG Entropy", y=1.01)

sns.move_legend(fig, bbox_to_anchor=(1, 1), loc='upper right', borderaxespad=1) 
plt.tight_layout(w_pad=0)
plt.show()

### 7-2-4. Heat map

In [None]:
df_all = pd.read_csv(DATA_FOLDER / "spectral_band/concat/20250917-001.csv")

In [None]:
df_event = df_all.loc[:,["event","eeg"]]
df_event

In [None]:
# Reshape the long (event, eeg) table into one event column per EEG channel.
dfs = []
for name, sub in df_event.groupby("eeg"):
    # Restart the index at 0 so the channels line up row by row.
    sub = sub.reset_index(drop=True).copy()
    sub.rename(columns={"event": f"event_{name}"}, inplace=True)
    dfs.append(sub[[f"event_{name}"]])  # keep only the event column

# Join side by side on the index (rows missing in a shorter channel become NaN).
df_wide = pd.concat(dfs, axis=1)
df_wide

In [None]:
# Per-row agreement of the predicted state between every pair of channels.
cols = df_wide.columns

n_rows = len(df_wide)   # number of rows (samples)
n_cols = len(cols)      # number of channels

# Output array (bool, shape = (n_cols, n_cols, n_rows)).
matches = np.zeros((n_cols, n_cols, n_rows), dtype=bool)

# For each pair of channels, test whether the event values are equal.
for i, col_i in enumerate(cols):
    for j, col_j in enumerate(cols):
        # Element-wise comparison, one result per row.
        matches[i, j, :] = (df_wide[col_i].values == df_wide[col_j].values)

# Check the result.
print(matches.shape)  # (n_cols, n_cols, n_rows)
print(matches.dtype)  # bool

# Agreement rate per channel pair (optional).
agreement = matches.mean(axis=2)
df_agree = pd.DataFrame(agreement, index=cols, columns=cols)
print(df_agree.round(3))


In [None]:
sum_matches = matches.sum(axis=2, keepdims=False)
sum_matches

In [None]:
(sum_matches / df_wide.shape[0]).mean()

In [None]:
import numpy as np
from statsmodels.stats.proportion import proportion_confint

# Precondition: matches.shape = (n_col, n_col, n_row); sum_matches is the
# keepdims=False sum over the last axis, i.e.
# sum_matches = matches.sum(axis=2)  # shape = (n_col, n_col)

n_row = len(df_wide)
n_col = len(sum_matches)

# Mean agreement rate (same as the formula in the question)
p_mean = np.mean(sum_matches / n_row)

# k and n for the Wilson interval, pooled over all pairs and all samples
k = int(np.sum(sum_matches))   # total number of matches (count of True)
n = n_row * n_col ** 2         # total number of trials (n_col^2 comparisons per sample)

# Pooled proportion and its Wilson 95% CI
p = k / n
lo, hi = proportion_confint(count=k, nobs=n, alpha=0.05, method="wilson")

name = "ALL"
print(f"{name:8s}: p={p:.3f}, 95% CI=({lo:.3f}, {hi:.3f})")


In [None]:
# Variant that drops the diagonal and the mirrored (duplicate) pairs
# sum_matches: shape = (n_col, n_col)
n_row = len(df_wide)
n_col = len(sum_matches)

# Indices of the strict upper triangle (k=1 skips the diagonal)
upper_idx = np.triu_indices(n_col, k=1)
print(upper_idx)
# Match counts for the retained pairs only
valid_sum = sum_matches[upper_idx]
print(valid_sum)
k = int(np.sum(valid_sum))     # total number of matches
n = valid_sum.size * n_row     # total number of trials (pairs × samples)

# Pooled agreement rate
p_hat = k / n

# Wilson confidence interval
ci_low, ci_high = proportion_confint(count=k, nobs=n, alpha=0.05, method="wilson")

print(f"ALL(offdiag): p={p_hat:.3f}, 95% CI=({ci_low:.3f}, {ci_high:.3f})")

In [None]:
# Pairwise agreement rate: match counts divided by the number of samples,
# labelled on both axes with the wide-table column names.
df_agree = pd.DataFrame(
    data=np.divide(sum_matches, len(df_wide)),
    columns=cols,
    index=cols,
)
df_agree

In [None]:
sns.set(style='ticks', context='talk')  # quick global style preset
# Drop the first six characters of every label (the "event_" prefix added when
# the wide table was built) so the heatmap shows the bare channel names.
_dic = {full: full[6:] for full in df_agree.columns}
df_tmp = df_agree.rename(index=_dic, columns=_dic)
sns.heatmap(df_tmp, annot=True)

### 7-2-5. Check with video

In [None]:
# Variant that drops the diagonal, the mirrored (duplicate) pairs, and pairs
# whose two columns belong to the same EEG channel type.
# sum_matches: shape = (n_col, n_col)
n_row = df_wide.shape[0]
n_col = sum_matches.shape[0]
cols = df_wide.columns
_upper_idx = np.triu_indices(n_col, k=1)
print(_upper_idx)

# A pair is "cross-type" when the two column names differ after dropping the
# last character (presumably a replicate digit, e.g. "event_HPC_1" vs
# "event_HPC_2" — confirm against the actual column names).
cross_type = np.array(
    [cols[i][:-1] != cols[j][:-1] for i, j in zip(*_upper_idx)],
    dtype=bool,
)
# Masking the integer index arrays keeps their integer dtype even when nothing
# is selected; rebuilding them from empty lists would yield float arrays that
# cannot be used for indexing.
upper_idx = tuple(idx[cross_type] for idx in _upper_idx)
print(upper_idx)

if not cross_type.any():
    # Without any pair, n would be 0 and the proportion is undefined.
    raise ValueError("No cross-type channel pairs found; cannot estimate agreement.")

# Match counts for the retained pairs only
valid_sum = sum_matches[upper_idx]
print(valid_sum)
k = int(valid_sum.sum())                 # total number of matches
n = int(len(valid_sum) * n_row)          # total number of trials (pairs × samples)

# Pooled agreement rate
p_hat = k / n

# Wilson confidence interval
ci_low, ci_high = proportion_confint(k, n, alpha=0.05, method="wilson")

# Label distinguishes this result from the plain off-diagonal one.
print(f"ALL(offdiag, cross-type): p={p_hat:.3f}, 95% CI=({ci_low:.3f}, {ci_high:.3f})")

In [None]:

# A row counts as full agreement only when every column holds a value AND all
# values are identical. DataFrame.nunique ignores NaN by default, so without
# the notna() guard a row with missing channels would be reported as agreeing,
# which contradicts the pairwise comparison where NaN never equals anything.
row_all_equal = df_wide.notna().all(axis=1) & (df_wide.nunique(axis=1) == 1)  # bool Series

# Count and ratio of fully agreeing rows
agree_count = int(row_all_equal.sum())
agree_ratio = agree_count / len(df_wide)

print(f"全列一致の行数: {agree_count}")
print(f"全列一致の比率: {agree_ratio:.4f}")

In [None]:


from statsmodels.stats.proportion import proportion_confint
import numpy as np

# Data: one list of ten counts per channel
# (presumably misclassified epochs per trial — confirm against the source sheet).
hpc_1   = [4, 1, 2, 0, 1, 3, 0, 0, 0, 0]
hpc_2   = [4, 1, 2, 0, 1, 3, 0, 0, 0, 0]
s1_1    = [3, 1, 2, 0, 1, 4, 0, 0, 0, 0]
s1_2    = [4, 1, 2, 0, 1, 5, 0, 0, 0, 0]
skull_1 = [6, 1, 2, 0, 1, 4, 0, 0, 0, 0]
skull_2 = [5, 1, 2, 0, 1, 3, 0, 0, 0, 0]

# Display label -> counts, in plotting order
groups = dict(
    HPC_1=hpc_1,
    HPC_2=hpc_2,
    S1_1=s1_1,
    S1_2=s1_2,
    Skull_1=skull_1,
    Skull_2=skull_2,
)

n_total = 36  # number of epochs in each trial

# For every channel: pooled proportion and its Wilson 95% interval,
# stored as (p_hat, ci_low, ci_high).
results = {}
for name, counts in groups.items():
    k = np.sum(counts)             # pooled count over all trials
    n = n_total * len(counts)      # pooled number of epochs
    ci_low, ci_high = proportion_confint(count=k, nobs=n, alpha=0.05, method="wilson")
    p_hat = k / n
    results[name] = (p_hat, ci_low, ci_high)

for name, (p, lo, hi) in results.items():
    print(f"{name:8s}: p={p:.3f}, 95% CI=({lo:.3f}, {hi:.3f})")


In [None]:
# Unpack (rate, ci_low, ci_high) per channel into the rate and the two
# asymmetric error-bar lengths expected by `yerr`.
names = list(results)
probs, lower, upper = [], [], []
for label in names:
    rate, ci_lo, ci_hi = results[label]
    probs.append(rate)
    lower.append(rate - ci_lo)   # distance down to the lower bound
    upper.append(ci_hi - rate)   # distance up to the upper bound

x = np.arange(len(names))
fig, ax = plt.subplots(figsize=(8, 6))

ax.bar(x, probs, yerr=[lower, upper], capsize=5, color="skyblue")
ax.set_xticks(x)
ax.set_xticklabels(names, rotation=45, ha="right")
ax.set_ylabel("Misclassification rate")

# y-axis range: 20% headroom above the tallest bar-plus-error
y_top = np.max(np.array(probs) + np.array(upper)) * 1.2
ax.set_ylim(0, y_top)

# === Legend-like annotation in the upper-right corner ===
ax.text(
    0.95,
    0.95,
    "mean ± 95% CI (n=360)",
    ha="right",
    va="top",
    fontsize=24,
    transform=ax.transAxes,
)

plt.tight_layout()
sns.despine()
plt.show()
