In [None]:
# ============================================================
# Sub-1 ONLY (subject_1) — robust, no dimension errors
# Keeps your published cleaning/filtering/window settings SAME
# Exports: eeg_sub1.csv, emg_sub1.csv, labels_sub1.csv, combined_sub1.csv
# ============================================================

import os
import re
import warnings
import numpy as np
import pandas as pd
from scipy.io import loadmat
from scipy.signal import butter, sosfilt, filtfilt, iirnotch, resample
from sklearn.preprocessing import MinMaxScaler

# ----------------------------
# CONFIG (published settings)
# ----------------------------
# Subject to process; files are read from "subject_<SUBJECT_ID>" under each
# data root.
# NOTE(review): the export file names and the "[SAVED SUB-1]" log text are
# hard-coded to "sub1" even though this is set to 11 — confirm that is intended.
SUBJECT_ID = 11
# Number of gesture classes; the pipeline loops labels 0..NO_GESTURE-1 and
# looks for files tagged G1..G<NO_GESTURE>.
NO_GESTURE = 7

# Native sampling rates of the two recordings (passed as `fs` to the filters;
# presumably Hz — confirm against the dataset description).
EMG_FS = 200
EEG_FS = 250

# Mains-interference notch: centre frequency and quality factor for iirnotch.
NOTCH_FREQ = 60
QUALITY_FACTOR = 30

# Band-pass corner frequencies: *_FC is used as the low cut, *_FH as the
# high cut.
EMG_FC, EMG_FH = 5, 50
EEG_FC, EEG_FH = 5, 50

# Butterworth order, window length (milliseconds), window overlap (percent),
# and the common rate both modalities are resampled to before windowing.
ORDER = 4
WINDOW_TIME_MS = 1000
OVERLAP_PERCENT = 80
TARGET_FS = 200

# Dataset locations; a "mat_data" sub-folder is used instead when it exists.
EEG_BASE = "/home/tsultan1/paper-2/Dataset-2/EEG_DATA/EEG_DATA/data"
EMG_BASE = "/home/tsultan1/paper-2/Dataset-2/EMG_DATA/EMG_DATA/data"

# Export folder; created at import time if it does not exist yet.
OUT_DIR = "/home/tsultan1/paper-2/Dataset-2/final_exports-sub11"
os.makedirs(OUT_DIR, exist_ok=True)

# Channel counts every trial is coerced to (see FORCE_FIXED_CHANNELS).
EXPECTED_EEG_CH = 8
EXPECTED_EMG_CH = 8
# True: trials with extra channels are cropped and trials with too few are
# zero-padded; False: a channel mismatch raises instead.
FORCE_FIXED_CHANNELS = True     # prevents crashes due to odd channel counts
DO_AUGMENT = False              # keep False for published-style exports

# ============================================================
# YOUR ORIGINAL FILTERING FUNCTIONS (kept same)
# ============================================================

def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a digital Butterworth band-pass filter.

    Args:
        lowcut: Lower corner frequency, in the same units as ``fs``.
        highcut: Upper corner frequency, in the same units as ``fs``.
        fs: Sampling rate of the signal the filter will be applied to.
        order: Butterworth order handed to ``scipy.signal.butter``.

    Returns:
        The filter as an array of second-order sections.
    """
    nyquist = fs / 2.0
    # butter() expects corner frequencies normalised to the Nyquist rate.
    normalised_band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, normalised_band, btype="bandpass", analog=False, output="sos")

def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass filter every column of ``data`` with a causal Butterworth filter.

    The input is transposed first, so each *column* of ``data`` is treated as
    one signal and is filtered along the *row* direction.

    Args:
        data: 2-D array; filtering runs along axis 0.
        lowcut: Lower corner frequency.
        highcut: Upper corner frequency.
        fs: Sampling rate.
        order: Butterworth order.

    Returns:
        Array with the same shape and dtype as ``data``.
    """
    sections = butter_bandpass(lowcut, highcut, fs, order=order)
    signals = data.T
    # Output buffer inherits the input dtype, exactly like the input layout.
    result = np.zeros_like(signals)
    for row, signal in enumerate(signals):
        result[row, :] = sosfilt(sections, signal)
    return result.T

def mains_removal(data, fs, notch_freq, quality_factor):
    """Remove mains interference with a zero-phase IIR notch filter.

    The input is transposed first, so each *column* of ``data`` is treated as
    one signal and is filtered along the *row* direction.

    Args:
        data: 2-D array; filtering runs along axis 0.
        fs: Sampling rate.
        notch_freq: Centre frequency of the notch.
        quality_factor: Quality factor passed to ``scipy.signal.iirnotch``.

    Returns:
        Filtered array with the same shape as ``data``.
    """
    numerator, denominator = iirnotch(notch_freq, quality_factor, fs)
    signals = data.T
    try:
        # Gustafsson's method avoids the padding-based edge handling.
        cleaned = filtfilt(numerator, denominator, signals, axis=1, method="gust")
    except Exception:
        # Fall back to the default padded forward-backward filter.
        cleaned = filtfilt(numerator, denominator, signals, axis=1)
    return cleaned.T

def preprocess_data(data, fs, notch_freq, quality_factor, lowcut, highcut, order, target_fs=None):
    """Notch-filter, band-pass filter and optionally resample a recording.

    ``data`` is laid out as (channels, samples): the resampling step below and
    the callers of this function both treat axis 1 as time.  The two filter
    helpers, however, transpose their input and filter along axis 0, i.e.
    they expect (samples, channels).  Passing ``data`` to them unchanged made
    the notch and band-pass run *across channels* instead of along time, so
    the recording is transposed on the way in and back on the way out.

    Args:
        data: 2-D array shaped (channels, samples).
        fs: Sampling rate of ``data``.
        notch_freq: Centre frequency of the mains notch.
        quality_factor: Quality factor of the notch.
        lowcut: Lower band-pass corner frequency.
        highcut: Upper band-pass corner frequency.
        order: Butterworth order of the band-pass.
        target_fs: If truthy and different from ``fs``, the filtered signal is
            resampled to this rate along the time axis.

    Returns:
        2-D array shaped (channels, samples'), where samples' is
        ``int(samples * target_fs / fs)`` when resampling is applied and
        ``samples`` otherwise.
    """
    # Helpers work on (samples, channels); hand them the transposed view.
    samples_first = data.T
    notched = mains_removal(
        samples_first, fs=fs, notch_freq=notch_freq, quality_factor=quality_factor
    )
    filtered_data = butter_bandpass_filter(
        notched, lowcut=lowcut, highcut=highcut, fs=fs, order=order
    ).T  # back to (channels, samples)

    if target_fs and fs != target_fs:
        num_samples = int(data.shape[1] * target_fs / fs)
        return resample(filtered_data, num=num_samples, axis=1)
    return filtered_data

def truncate_to_min_length(data1, data2):
    """Crop two 2-D arrays along axis 1 to the shorter of their lengths.

    Returns:
        A 2-tuple ``(data1_cropped, data2_cropped)``; both keep their leading
        columns and end up with the same size along axis 1.
    """
    shared_length = min(data1.shape[1], data2.shape[1])
    return tuple(block[:, :shared_length] for block in (data1, data2))

def window_with_overlap(data, sampling_frequency, window_time, overlap, no_channel):
    """Cut a (channels, samples) recording into fixed-length overlapping windows.

    The window length and hop are derived from floating-point expressions.
    A plain ``int()`` on those truncates values that sit a hair below the
    intended integer: ``200 * (1 - 80 / 100)`` evaluates to
    ``39.99999999999999``, which gave a hop of 39 samples instead of 40 for
    an 80 % overlap.  Both quantities are therefore floored with a small
    tolerance so that exact ratios yield the exact sample count.

    Args:
        data: 2-D array shaped (channels, samples).
        sampling_frequency: Sampling rate of ``data``.
        window_time: Window length in milliseconds.
        overlap: Overlap between consecutive windows, in percent.
        no_channel: Number of channels in the output windows.

    Returns:
        Array shaped (num_windows, no_channel, samples_per_window) with the
        dtype of ``data``; ``num_windows`` is 0 when the recording is shorter
        than one window.
    """
    # Tolerance that absorbs binary floating-point round-off before flooring.
    eps = 1e-9
    samples_per_window = int(sampling_frequency * window_time / 1000 + eps)
    step_size = int(samples_per_window * (100 - overlap) / 100 + eps)
    # An overlap of 100 % (or more) would stall; always advance by one sample.
    step_size = max(step_size, 1)

    total_samples = data.shape[1]
    if total_samples < samples_per_window:
        return np.zeros((0, no_channel, samples_per_window), dtype=data.dtype)

    num_windows = (total_samples - samples_per_window) // step_size + 1
    windows = np.zeros((num_windows, no_channel, samples_per_window), dtype=data.dtype)
    for i in range(num_windows):
        start = i * step_size
        windows[i] = data[:, start:start + samples_per_window]
    return windows

# ============================================================
# Robust MAT loading (prevents dimension errors)
# ============================================================

def resolve_mat_root(base_dir: str) -> str:
    """Return ``<base_dir>/mat_data`` when that folder exists, else ``base_dir``."""
    nested = os.path.join(base_dir, "mat_data")
    if os.path.isdir(nested):
        return nested
    return base_dir

# Folders that actually hold the per-subject directories: the "mat_data"
# sub-folder of each base path when present, otherwise the base path itself.
# Resolved once at import time.
EEG_ROOT = resolve_mat_root(EEG_BASE)
EMG_ROOT = resolve_mat_root(EMG_BASE)

def load_mat_data(filepath: str) -> np.ndarray:
    """Load the signal matrix from a MATLAB ``.mat`` file as a 2-D float64 array.

    The variable named ``"data"`` is used when present; otherwise the first
    non-private variable holding a numeric ndarray is taken.

    Args:
        filepath: Path of the ``.mat`` file.

    Returns:
        2-D ``float64`` array.  Singleton dimensions are squeezed out, a 1-D
        result becomes a single row, and arrays with more than two dimensions
        are flattened to (first_dim, -1).  Complex data keeps its real part only.

    Raises:
        ValueError: If the file has no ``"data"`` variable and no numeric array.
    """
    contents = loadmat(filepath)

    if "data" in contents:
        raw = contents["data"]
    else:
        numeric_arrays = (
            value
            for key, value in contents.items()
            if not key.startswith("__")  # skip loadmat's header entries
            and isinstance(value, np.ndarray)
            and np.issubdtype(value.dtype, np.number)
        )
        raw = next(numeric_arrays, None)
        if raw is None:
            raise ValueError(f"No numeric data found in {filepath}")

    matrix = np.squeeze(np.asarray(raw))

    if matrix.ndim == 1:
        matrix = matrix.reshape(1, -1)
    if matrix.ndim != 2:
        matrix = matrix.reshape(matrix.shape[0], -1)

    if np.iscomplexobj(matrix):
        matrix = np.real(matrix)

    return matrix.astype(np.float64, copy=False)

def ensure_channels_by_time(arr: np.ndarray, expected_ch: int, force_fixed: bool) -> np.ndarray:
    """Orient a 2-D array as (channels, time) with exactly ``expected_ch`` rows.

    Orientation: an axis whose size equals ``expected_ch`` is taken as the
    channel axis (axis 0 wins when both match); otherwise the axis whose size
    is closer to ``expected_ch`` is used, with ties keeping the input as is.

    Args:
        arr: 2-D input array.
        expected_ch: Required number of channels.
        force_fixed: When the channel count still differs after orientation,
            True crops surplus channels or appends all-zero channels, False
            raises instead.

    Returns:
        Array shaped (expected_ch, time).

    Raises:
        ValueError: On a channel mismatch while ``force_fixed`` is False.
    """
    n_rows = arr.shape[0]
    if n_rows == expected_ch:
        needs_transpose = False
    else:
        n_cols = arr.shape[1]
        needs_transpose = (
            n_cols == expected_ch
            or abs(n_cols - expected_ch) < abs(n_rows - expected_ch)
        )
    oriented = arr.T if needs_transpose else arr

    ch, n_time = oriented.shape
    if ch == expected_ch:
        return oriented
    if not force_fixed:
        raise ValueError(f"Channel mismatch: got {ch}, expected {expected_ch}")
    if ch > expected_ch:
        # Too many channels: keep the leading ones.
        return oriented[:expected_ch, :]
    # Too few channels: fill the missing ones with zeros.
    filler = np.zeros((expected_ch - ch, n_time), dtype=oriented.dtype)
    return np.vstack([oriented, filler])

def files_for_subject_gesture(base_path: str, subject_id: int, gesture_idx_1based: int):
    """List the ``.mat`` files of one subject that belong to one gesture.

    A file belongs to gesture *k* when its name contains ``G<k>`` (case
    insensitive) that is not immediately followed by another digit.  The
    digit guard keeps ``G1`` from also picking up ``G10``, ``G11``, ... files.

    NOTE(review): this assumes the gesture number is not directly followed by
    further digits in the file name (e.g. a trial number glued on) — confirm
    against the dataset's naming scheme.

    Args:
        base_path: Folder containing the ``subject_<id>`` directories.
        subject_id: Subject number used to build the directory name.
        gesture_idx_1based: Gesture number as it appears in the file names.

    Returns:
        Sorted list of full file paths; empty when the subject directory does
        not exist or nothing matches.
    """
    subj_dir = os.path.join(base_path, f"subject_{subject_id}")
    if not os.path.isdir(subj_dir):
        return []

    pat = re.compile(rf"G{gesture_idx_1based}(?!\d)", re.IGNORECASE)
    # os.listdir() order is arbitrary; sort for a reproducible result.
    return sorted(
        os.path.join(subj_dir, name)
        for name in os.listdir(subj_dir)
        if name.lower().endswith(".mat") and pat.search(name)
    )

def stack_trials_timewise(file_list, expected_ch: int, force_fixed: bool):
    """Load trial files and join them end-to-end along the time axis.

    Files are processed in sorted path order.  A file that fails to load or
    to be coerced to (expected_ch, time) is skipped with a warning; trials
    with zero samples are dropped silently.

    Args:
        file_list: Paths of the ``.mat`` files to load.
        expected_ch: Channel count every trial is coerced to.
        force_fixed: Forwarded to ``ensure_channels_by_time``.

    Returns:
        Array shaped (expected_ch, total_time), or ``None`` when no trial
        could be used.
    """
    trials = []
    for fp in sorted(file_list):
        try:
            trial = ensure_channels_by_time(load_mat_data(fp), expected_ch, force_fixed)
        except Exception as e:
            warnings.warn(f"Skipping file: {fp} | reason: {e}")
            continue
        if trial.shape[1] > 0:
            trials.append(trial)

    # column_stack on 2-D inputs concatenates along axis 1, i.e. along time.
    return np.column_stack(trials) if trials else None

# ============================================================
# Optional augmentation (OFF by default)
# ============================================================

def add_noise(x, noise_level=0.01):
    """Return ``x`` plus zero-mean Gaussian noise of standard deviation ``noise_level``.

    Noise is drawn from NumPy's global random state, so results are only
    reproducible when that state is seeded by the caller.
    """
    noise = np.random.normal(0, noise_level, size=x.shape)
    return x + noise

def scale_data(x, scale_factor=1.1):
    """Return ``x`` multiplied by ``scale_factor`` (amplitude-scaling augmentation)."""
    return x * scale_factor

def time_shift(x, shift=5):
    """Circularly shift ``x`` by ``shift`` positions along its first axis.

    Elements pushed past the end wrap around to the start.
    NOTE(review): the shift is always applied on axis 0; for a flattened
    multi-channel row this moves samples across channel boundaries — confirm
    that is the intended augmentation.
    """
    return np.roll(x, shift=shift, axis=0)

def flip_data(x):
    """Return the sign-inverted copy of ``x`` (polarity-flip augmentation)."""
    return -x

def augment_rows(Xeeg_2d, Xemg_2d, y, subj):
    """Expand every sample into five variants: original, noisy, scaled, shifted, flipped.

    Each transform is applied to the EEG row and to the matching EMG row, and
    the label and subject id are repeated for every variant.  The variants of
    one sample are emitted consecutively, in the order listed above.

    Args:
        Xeeg_2d: 2-D array, one flattened EEG sample per row.
        Xemg_2d: 2-D array, one flattened EMG sample per row.
        y: Per-row labels.
        subj: Per-row subject ids.

    Returns:
        Tuple ``(eeg, emg, labels, subjects)`` with five times as many rows
        as the input; features are ``float32``, labels and ids ``int64``.
    """
    # Order matters: it fixes both the row layout and the sequence of draws
    # from the random generator inside add_noise.
    transforms = (
        lambda row: row,
        add_noise,
        scale_data,
        lambda row: time_shift(row, shift=5),
        flip_data,
    )

    eeg_rows, emg_rows, labels, subjects = [], [], [], []
    for idx, eeg_row in enumerate(Xeeg_2d):
        emg_row = Xemg_2d[idx]
        label = int(y[idx])
        subject = int(subj[idx])
        for transform in transforms:
            eeg_rows.append(transform(eeg_row))
            emg_rows.append(transform(emg_row))
            labels.append(label)
            subjects.append(subject)

    return (
        np.asarray(eeg_rows, dtype=np.float32),
        np.asarray(emg_rows, dtype=np.float32),
        np.asarray(labels, dtype=np.int64),
        np.asarray(subjects, dtype=np.int64),
    )

# ============================================================
# Sub-1 pipeline
# ============================================================

def preprocess_subject(subject_id: int):
    """Build paired EMG/EEG windows and labels for one subject.

    For every gesture the matching EMG and EEG trial files are concatenated
    in time, filtered and resampled to ``TARGET_FS``, cropped to a common
    length and cut into overlapping windows.  Gestures with no files, no
    loadable trials, or too little data for a single window are skipped.

    Args:
        subject_id: Subject number; selects the ``subject_<id>`` folders.

    Returns:
        Tuple ``(X_emg, X_eeg, y, subj, total_windows)``: the two window
        arrays shaped (N, channels, T) as ``float32``, the zero-based gesture
        labels and the subject ids as ``int64`` arrays of length N, and the
        number of windows collected.

    Raises:
        RuntimeError: If no gesture produced any window.
    """
    emg_blocks, eeg_blocks = [], []
    labels, subject_ids = [], []

    for label in range(NO_GESTURE):      # labels 0..6
        gesture_no = label + 1           # file pattern G1..G7

        emg_paths = files_for_subject_gesture(EMG_ROOT, subject_id, gesture_no)
        eeg_paths = files_for_subject_gesture(EEG_ROOT, subject_id, gesture_no)
        if not (emg_paths and eeg_paths):
            continue

        emg_signal = stack_trials_timewise(emg_paths, EXPECTED_EMG_CH, FORCE_FIXED_CHANNELS)
        eeg_signal = stack_trials_timewise(eeg_paths, EXPECTED_EEG_CH, FORCE_FIXED_CHANNELS)
        if emg_signal is None or eeg_signal is None:
            continue

        emg_clean = preprocess_data(
            emg_signal, EMG_FS, NOTCH_FREQ, QUALITY_FACTOR, EMG_FC, EMG_FH, ORDER, TARGET_FS
        )
        eeg_clean = preprocess_data(
            eeg_signal, EEG_FS, NOTCH_FREQ, QUALITY_FACTOR, EEG_FC, EEG_FH, ORDER, TARGET_FS
        )

        # Both modalities now share TARGET_FS; make them equally long too.
        emg_clean, eeg_clean = truncate_to_min_length(emg_clean, eeg_clean)

        emg_windows = window_with_overlap(
            emg_clean, TARGET_FS, WINDOW_TIME_MS, OVERLAP_PERCENT, emg_clean.shape[0]
        )
        eeg_windows = window_with_overlap(
            eeg_clean, TARGET_FS, WINDOW_TIME_MS, OVERLAP_PERCENT, eeg_clean.shape[0]
        )

        pair_count = min(emg_windows.shape[0], eeg_windows.shape[0])
        if pair_count == 0:
            continue

        emg_blocks.append(emg_windows[:pair_count].astype(np.float32))
        eeg_blocks.append(eeg_windows[:pair_count].astype(np.float32))
        labels += [label] * pair_count
        subject_ids += [subject_id] * pair_count

    if not emg_blocks or not eeg_blocks:
        raise RuntimeError(f"No windows produced for subject_{subject_id}. Check folder/files.")

    X_emg = np.vstack(emg_blocks)  # (N, 8, T)
    X_eeg = np.vstack(eeg_blocks)  # (N, 8, T)
    y = np.asarray(labels, dtype=np.int64)
    subj = np.asarray(subject_ids, dtype=np.int64)
    total_windows = len(labels)

    # Safety net: never return arrays of different lengths.
    keep = min(len(y), X_emg.shape[0], X_eeg.shape[0], len(subj))
    return X_emg[:keep], X_eeg[:keep], y[:keep], subj[:keep], total_windows

def export_subject_csvs(X_emg, X_eeg, y, subj):
    """Flatten, min-max scale and write one subject's windows to CSV files.

    Every window is flattened to one row, each feature column is scaled with
    its own ``MinMaxScaler`` fitted on all rows, and (when ``DO_AUGMENT`` is
    set) the rows are augmented.  Four files are written into ``OUT_DIR``:
    EEG features, EMG features, subject/label pairs, and a combined table
    with ``subject_id``, ``Label``, ``eeg_*`` and ``emg_*`` columns.

    Args:
        X_emg: EMG windows, first axis indexes the windows.
        X_eeg: EEG windows, first axis indexes the windows.
        y: Per-window labels.
        subj: Per-window subject ids.
    """
    eeg_flat = X_eeg.reshape(X_eeg.shape[0], -1)
    emg_flat = X_emg.reshape(X_emg.shape[0], -1)

    eeg_scaled = MinMaxScaler().fit_transform(eeg_flat).astype(np.float32)
    emg_scaled = MinMaxScaler().fit_transform(emg_flat).astype(np.float32)

    if DO_AUGMENT:
        eeg_out, emg_out, y_out, subj_out = augment_rows(eeg_scaled, emg_scaled, y, subj)
    else:
        eeg_out, emg_out, y_out, subj_out = eeg_scaled, emg_scaled, y, subj

    eeg_path = os.path.join(OUT_DIR, "eeg_sub1.csv")
    emg_path = os.path.join(OUT_DIR, "emg_sub1.csv")
    lab_path = os.path.join(OUT_DIR, "labels_sub1.csv")
    comb_path = os.path.join(OUT_DIR, "combined_sub1.csv")

    # Per-modality files keep pandas' default integer column names.
    pd.DataFrame(eeg_out).to_csv(eeg_path, index=False)
    pd.DataFrame(emg_out).to_csv(emg_path, index=False)
    pd.DataFrame({"subject_id": subj_out, "Label": y_out}).to_csv(lab_path, index=False)

    # Combined table: subject_id, Label, then named EEG and EMG features.
    combined = pd.DataFrame(
        eeg_out, columns=[f"eeg_{i}" for i in range(eeg_out.shape[1])]
    )
    combined.insert(0, "subject_id", subj_out)
    combined.insert(1, "Label", y_out)
    emg_named = pd.DataFrame(
        emg_out, columns=[f"emg_{i}" for i in range(emg_out.shape[1])]
    )
    pd.concat([combined, emg_named], axis=1).to_csv(comb_path, index=False)

    print("\n[SAVED SUB-1]")
    for written in (eeg_path, emg_path, lab_path, comb_path):
        print(" ", written)
    print(f"\nRows: {len(y_out)} | label set: {np.unique(y_out)}")

# Script entry point: preprocess the configured subject, report how many
# windows were produced, then write the CSV exports into OUT_DIR.
if __name__ == "__main__":
    X_emg, X_eeg, y, subj, total = preprocess_subject(SUBJECT_ID)
    print(f"[OK] subject_{SUBJECT_ID}: windows={total}")
    export_subject_csvs(X_emg, X_eeg, y, subj)


[OK] subject_11: windows=1106

[SAVED SUB-1]
  /home/tsultan1/paper-2/Dataset-2/final_exports-sub11/eeg_sub1.csv
  /home/tsultan1/paper-2/Dataset-2/final_exports-sub11/emg_sub1.csv
  /home/tsultan1/paper-2/Dataset-2/final_exports-sub11/labels_sub1.csv
  /home/tsultan1/paper-2/Dataset-2/final_exports-sub11/combined_sub1.csv

Rows: 1106 | label set: [0 1 2 3 4 5 6]
