In [8]:
import os
from io import BytesIO
from pathlib import Path

import pandas as pd
from google.cloud import storage


def load_environment(env_path: str = ".env") -> tuple[str, str]:
    """Load GCP credentials and bucket name from a .env file.

    The file is optional. When it exists, every ``KEY=value`` line is copied
    into ``os.environ`` unless the key is already set there (the process
    environment takes precedence). Blank lines, ``#`` comments and lines
    without ``=`` are skipped. A leading ``export `` and one pair of matching
    surrounding quotes around the value are tolerated, so files written in
    shell syntax are read correctly.

    Args:
        env_path: Path of the .env file to read.

    Returns:
        ``(credentials_path, bucket_name)`` where ``credentials_path`` is the
        path of the service-account key file (with ``~`` expanded).

    Raises:
        EnvironmentError: If ``GOOGLE_APPLICATION_CREDENTIALS`` or
            ``GCP_BUCKET_NAME`` is missing or empty after loading.
        FileNotFoundError: If the credentials path is not an existing file.
    """
    env_file = Path(env_path)
    if env_file.is_file():
        for line in env_file.read_text(encoding="utf-8").splitlines():
            stripped = line.strip()
            if not stripped or stripped.startswith("#") or "=" not in stripped:
                continue
            # Accept shell-style `export KEY=value` lines.
            if stripped.startswith("export "):
                stripped = stripped[len("export "):].lstrip()
            key, value = stripped.split("=", 1)
            key, value = key.strip(), value.strip()
            if not key:
                # "=value" has no usable name; os.environ rejects empty keys.
                continue
            # Drop one pair of matching quotes: KEY="a b" -> a b
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            os.environ.setdefault(key, value)

    credentials_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
    bucket_name = os.environ.get("GCP_BUCKET_NAME")

    if not credentials_path or not bucket_name:
        raise EnvironmentError(
            "GOOGLE_APPLICATION_CREDENTIALS and GCP_BUCKET_NAME must be set in the .env file."
        )

    # A .env file is not processed by a shell, so "~" must be expanded here.
    credentials_file = Path(credentials_path).expanduser()
    if not credentials_file.is_file():
        raise FileNotFoundError(f"Credentials file not found at {credentials_file}")

    # Ensure the env var is exported for the Google client to pick up.
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = str(credentials_file)
    return str(credentials_file), bucket_name


def load_emg_dataframe(
    env_path=".env",
    blob_path="EMG-nature/Clean_df/emg_trial_level_df.pkl",
):
    """Download a pickled DataFrame from the configured GCS bucket and load it.

    Args:
        env_path: Path of the .env file passed to ``load_environment``.
        blob_path: Object name of the pickle inside the bucket. Defaults to
            the trial-level EMG frame.

    Returns:
        Whatever object ``pd.read_pickle`` reconstructs from the blob.

    Raises:
        EnvironmentError, FileNotFoundError: Propagated from
            ``load_environment`` when configuration is incomplete.
    """
    _, bucket_name = load_environment(env_path)

    # Credentials are resolved by the client from GOOGLE_APPLICATION_CREDENTIALS,
    # which load_environment() has just exported.
    client = storage.Client()
    blob = client.bucket(bucket_name).blob(blob_path)

    # Download into memory so no temporary file is left on disk.
    buffer = BytesIO()
    blob.download_to_file(buffer)
    buffer.seek(0)  # rewind: the download leaves the cursor at the end

    # SECURITY: unpickling can execute arbitrary code. Only load blobs from a
    # bucket whose contents you control and trust.
    return pd.read_pickle(buffer)



# Fetch the trial-level EMG frame using the default ".env" configuration.
emg_df = load_emg_dataframe()


In [9]:
# Preview the loaded frame (one row per trial; 'signal' holds the raw array).
emg_df

Unnamed: 0,participant,day,block,trial_id,position,grasp,signal
0,1,1,1,1,2,3,"[[3.763498e-05, 1.9842508e-05, 9.071698e-06, 1..."
1,1,1,1,2,2,3,"[[1.0537988e-05, 1.153949e-05, 1.18090165e-05,..."
2,1,1,1,3,2,3,"[[1.6977565e-05, 1.9937088e-05, 2.1830994e-05,..."
3,1,1,1,4,2,3,"[[3.6807487e-06, 3.2587977e-06, 2.339907e-06, ..."
4,1,1,1,5,2,3,"[[1.5383765e-05, 1.8471881e-05, 1.6300444e-05,..."
...,...,...,...,...,...,...,...
4795,8,2,2,146,9,2,"[[1.59362e-05, 1.7032327e-05, 1.913987e-05, 2...."
4796,8,2,2,147,9,2,"[[1.3964933e-06, 1.4381011e-06, 5.3344643e-06,..."
4797,8,2,2,148,9,2,"[[9.107072e-06, 1.3961595e-05, 1.9139401e-05, ..."
4798,8,2,2,149,9,2,"[[4.061428e-05, 3.564699e-05, 3.1583953e-05, 2..."


In [10]:
# Inspect the signal array stored under index label 0 of the 'signal' column.
emg_df['signal'][0]

array([[ 3.76349817e-05,  1.98425078e-05,  9.07169760e-06, ...,
        -7.89036676e-06, -1.49826537e-05, -1.43748375e-05],
       [ 2.48200813e-05,  2.62004924e-05,  2.57970714e-05, ...,
         2.36397518e-05,  2.26163229e-05,  2.11052029e-05],
       [ 7.70358292e-06,  9.89300042e-06,  1.07674778e-05, ...,
         1.03989842e-05,  9.96852668e-06,  9.72673752e-06],
       ...,
       [-8.37383777e-06, -1.24599992e-05, -1.28278034e-05, ...,
        -1.55725484e-05, -1.43470406e-05, -1.06217649e-05],
       [-1.57952236e-05, -1.61320277e-05, -1.29005557e-05, ...,
        -9.94938546e-06, -9.96866675e-06, -1.09153243e-05],
       [-1.30936751e-05, -1.76599133e-05, -1.61400549e-05, ...,
        -1.25840334e-05, -1.46363518e-05, -1.61670978e-05]],
      shape=(16, 9980), dtype=float32)

In [11]:
# Count how many trials have each signal array shape.
emg_df['signal'].apply(lambda x: x.shape).value_counts()

signal
(16, 10000)    1556
(16, 9980)     1505
(16, 10020)    1073
(16, 9960)      544
(16, 10040)      90
(16, 10060)      22
(16, 9940)        6
(16, 9920)        2
(16, 10080)       1
(16, 9840)        1
Name: count, dtype: int64

In [12]:
import numpy as np

# Number of time samples every trial is padded or trimmed to.
TARGET_LEN = 10000


def fix_length(arr, target_len=TARGET_LEN):
    """Force the last axis of ``arr`` to exactly ``target_len`` samples.

    Shorter arrays are zero-padded at the end of the last axis; longer arrays
    are truncated at the end. All leading axes (e.g. channels) are untouched.

    Args:
        arr: Array-like with at least one dimension; the last axis is the one
            that is resized.
        target_len: Desired length of the last axis (non-negative).

    Returns:
        An array whose last axis has length ``target_len``. When no padding is
        needed the result is ``arr`` itself or a view of it, not a copy, so
        writing to the result may modify the input.

    Raises:
        ValueError: If ``arr`` is zero-dimensional or ``target_len`` is
            negative.
    """
    if target_len < 0:
        # A negative slice bound would silently trim from the end instead.
        raise ValueError(f"target_len must be non-negative, got {target_len}.")

    arr = np.asarray(arr)
    if arr.ndim == 0:
        raise ValueError("fix_length expects an array with at least one dimension.")

    current_len = arr.shape[-1]
    if current_len == target_len:
        return arr

    if current_len < target_len:
        # Pad only the trailing side of the last axis; leading axes get (0, 0).
        pad_spec = [(0, 0)] * (arr.ndim - 1) + [(0, target_len - current_len)]
        return np.pad(arr, pad_spec, mode='constant')

    # Too long: keep the first target_len samples.
    return arr[..., :target_len]

# Apply fix_length to every trial and store the result in a new
# 'signal_fixed' column (this adds the column to emg_df in place).
emg_df['signal_fixed'] = emg_df['signal'].apply(fix_length)


In [14]:
# Sanity check: count the distinct shapes left after length fixing.
emg_df['signal_fixed'].apply(lambda x: x.shape).value_counts()

signal_fixed
(16, 10000)    4800
Name: count, dtype: int64

In [15]:
# Display the frame again, now including the 'signal_fixed' column.
emg_df

Unnamed: 0,participant,day,block,trial_id,position,grasp,signal,signal_fixed
0,1,1,1,1,2,3,"[[3.763498e-05, 1.9842508e-05, 9.071698e-06, 1...","[[3.763498e-05, 1.9842508e-05, 9.071698e-06, 1..."
1,1,1,1,2,2,3,"[[1.0537988e-05, 1.153949e-05, 1.18090165e-05,...","[[1.0537988e-05, 1.153949e-05, 1.18090165e-05,..."
2,1,1,1,3,2,3,"[[1.6977565e-05, 1.9937088e-05, 2.1830994e-05,...","[[1.6977565e-05, 1.9937088e-05, 2.1830994e-05,..."
3,1,1,1,4,2,3,"[[3.6807487e-06, 3.2587977e-06, 2.339907e-06, ...","[[3.6807487e-06, 3.2587977e-06, 2.339907e-06, ..."
4,1,1,1,5,2,3,"[[1.5383765e-05, 1.8471881e-05, 1.6300444e-05,...","[[1.5383765e-05, 1.8471881e-05, 1.6300444e-05,..."
...,...,...,...,...,...,...,...,...
4795,8,2,2,146,9,2,"[[1.59362e-05, 1.7032327e-05, 1.913987e-05, 2....","[[1.59362e-05, 1.7032327e-05, 1.913987e-05, 2...."
4796,8,2,2,147,9,2,"[[1.3964933e-06, 1.4381011e-06, 5.3344643e-06,...","[[1.3964933e-06, 1.4381011e-06, 5.3344643e-06,..."
4797,8,2,2,148,9,2,"[[9.107072e-06, 1.3961595e-05, 1.9139401e-05, ...","[[9.107072e-06, 1.3961595e-05, 1.9139401e-05, ..."
4798,8,2,2,149,9,2,"[[4.061428e-05, 3.564699e-05, 3.1583953e-05, 2...","[[4.061428e-05, 3.564699e-05, 3.1583953e-05, 2..."


In [16]:
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class SignalExtractor(BaseEstimator, TransformerMixin):
    """Extract and reshape EMG signals from a DataFrame column.

    Each cell of ``signal_col`` must hold a 2-D array of shape
    ``expected_shape`` (channels, samples). ``transform`` stacks them into a
    single float32 array of shape ``(n_rows, samples, channels, 1)``.

    Args:
        signal_col: Name of the column holding the per-trial signal arrays.
        expected_shape: Required ``(channels, samples)`` shape of every
            signal. Defaults to ``(16, 10000)``.
    """

    def __init__(
        self,
        signal_col: str = "signal_fixed",
        expected_shape: tuple = (16, 10000),
    ):
        self.signal_col = signal_col
        self.expected_shape = expected_shape

    def fit(self, X, y=None):
        """Stateless transformer: nothing is learned; returns ``self``."""
        return self

    def transform(self, X):
        """Return the signals of ``X`` as a ``(N, samples, channels, 1)`` array.

        Raises:
            ValueError: If any row's signal does not have ``expected_shape``.
                The reported row number is the position within ``X``, not the
                DataFrame index label.
        """
        expected = tuple(self.expected_shape)
        signals = []
        for idx, item in enumerate(X[self.signal_col].to_list()):
            arr = np.asarray(item, dtype=np.float32)
            if arr.shape != expected:
                raise ValueError(
                    f"Row {idx} in '{self.signal_col}' has shape {arr.shape}, expected {expected}."
                )
            # (channels, samples) -> (samples, channels, 1)
            signals.append(arr.T[:, :, None])

        if not signals:
            # np.stack rejects an empty list; return a correctly shaped empty batch.
            channels, samples = expected
            return np.empty((0, samples, channels, 1), dtype=np.float32)
        return np.stack(signals, axis=0)


class LabelExtractor(BaseEstimator, TransformerMixin):
    """Extract grasp labels and convert to zero-based integers.

    Args:
        label_col: Name of the column holding the one-based class labels.
        n_classes: Number of valid classes; raw labels must lie in
            ``1..n_classes``. Defaults to 6.
    """

    def __init__(self, label_col: str = "grasp", n_classes: int = 6):
        self.label_col = label_col
        self.n_classes = n_classes

    def fit(self, X, y=None):
        """Stateless transformer: nothing is learned; returns ``self``."""
        return self

    def transform(self, X):
        """Return the labels of ``X`` shifted to the range ``0..n_classes-1``.

        Raises:
            ValueError: If any shifted label falls outside ``0..n_classes-1``.
        """
        labels = X[self.label_col].astype(int).to_numpy() - 1
        if labels.size == 0:
            # min()/max() are undefined on an empty array; nothing to validate.
            return labels

        upper = self.n_classes - 1
        if labels.min() < 0 or labels.max() > upper:
            raise ValueError(
                f"Grasp labels must be in the range 1-{self.n_classes} "
                f"before conversion to 0-{upper}."
            )
        return labels


class EMGNormalizer(BaseEstimator, TransformerMixin):
    """Z-score normalize EMG signals per channel using training statistics.

    Args:
        epsilon: Lower bound applied to the per-channel standard deviation so
            that constant channels do not cause a division by zero.

    Attributes (set by ``fit``):
        mean_: Mean over axes 0 and 1 of the training array, kept
            broadcastable against the input.
        std_: Standard deviation over the same axes, floored at ``epsilon``.
    """

    def __init__(self, epsilon: float = 1e-8):
        # Only hyper-parameters are stored here. The learned attributes
        # (mean_, std_) are created in fit(), so scikit-learn's fitted-state
        # detection, which looks for trailing-underscore attributes, does not
        # report an unfitted instance as fitted.
        self.epsilon = epsilon

    def fit(self, X, y=None):
        """Compute the normalisation statistics from ``X``; returns ``self``."""
        # X shape: (N, time, channels, 1)
        self.mean_ = X.mean(axis=(0, 1), keepdims=True)
        std = X.std(axis=(0, 1), keepdims=True)
        # Floor tiny deviations to avoid dividing by (near) zero.
        self.std_ = np.where(std < self.epsilon, self.epsilon, std)
        return self

    def transform(self, X):
        """Return ``(X - mean_) / std_`` using the statistics learned in ``fit``.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        mean = getattr(self, "mean_", None)
        std = getattr(self, "std_", None)
        if mean is None or std is None:
            raise RuntimeError("EMGNormalizer must be fitted before calling transform().")
        return (X - mean) / std


In [30]:
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

# Hold-out configuration; the fixed seed keeps the split reproducible.
TEST_SIZE = 0.3
SPLIT_SEED = 42

# Extract zero-based labels once; they drive both stratification and the targets.
labeler = LabelExtractor(label_col="grasp")
labels = labeler.transform(emg_df)

# y_train / y_test returned here are already row-aligned with train_df / test_df,
# so they are used directly instead of being recomputed from the split frames.
train_df, test_df, y_train, y_test = train_test_split(
    emg_df,
    labels,
    test_size=TEST_SIZE,
    random_state=SPLIT_SEED,
    stratify=labels,
)

# Build preprocessing pipeline for signals
preprocess = Pipeline(
    steps=[
        ("signals", SignalExtractor(signal_col="signal_fixed")),
        ("normalize", EMGNormalizer()),
    ]
)

# Normalisation statistics are fitted on the training split only and then
# reused for the test split, so no test information leaks into them.
X_train = preprocess.fit_transform(train_df, y_train)
X_test = preprocess.transform(test_df)

# Guard against features and targets drifting out of alignment.
assert len(X_train) == len(y_train), "train features/labels are misaligned"
assert len(X_test) == len(y_test), "test features/labels are misaligned"

print("X_train shape:", X_train.shape)
print("X_test shape:", X_test.shape)
print("y_train shape:", y_train.shape)
print("y_test shape:", y_test.shape)


X_train shape: (3360, 10000, 16, 1)
X_test shape: (1440, 10000, 16, 1)
y_train shape: (3360,)
y_test shape: (1440,)
