In [10]:
!pip install numpy scipy scikit-learn mne 


Looking in indexes: https://pypi.org/simple, https://cmind.intern%40gmail.com:****@charlesyin.pkgs.visualstudio.com/Cmind%20Inc/_packaging/Cmind-internal-python-packages/pypi/simple/

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [12]:
import numpy as np
from sklearn.model_selection import train_test_split
from scipy.signal import butter, lfilter
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import accuracy_score, confusion_matrix

In [13]:
train_cnt_file = "train/Competition_train_cnt.txt"
train_lab_file = "train/Competition_train_lab.txt"

train_cnt = np.loadtxt(train_cnt_file)
Y_train = np.loadtxt(train_lab_file)

print("Raw shape of train_cnt:", train_cnt.shape)


Raw shape of train_cnt: (17792, 3000)


In [14]:
n_trials_train = 278
n_channels = 64
n_samples = 3000

X_train = train_cnt.reshape(n_trials_train, n_channels, n_samples)

print("X_train final shape:", X_train.shape)
print("Y_train shape:", Y_train.shape)

X_train final shape: (278, 64, 3000)
Y_train shape: (278,)


In [15]:
test_file = "test.txt"
test_label_file = "test_label.txt"

test_cnt = np.loadtxt(test_file)
Y_test = np.loadtxt(test_label_file)

print("Raw shape of test_cnt:", test_cnt.shape)

Raw shape of test_cnt: (6400, 3000)


In [16]:
n_trials_test = 100
X_test = test_cnt.reshape(n_trials_test, n_channels, n_samples)

print("test final shape:", X_test.shape)

test final shape: (100, 64, 3000)


In [17]:
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.2, random_state=42)

print("Train Set Shape:", X_train.shape, Y_train.shape)
print("Validation Set Shape:", X_val.shape, Y_val.shape)

Train Set Shape: (222, 64, 3000) (222,)
Validation Set Shape: (56, 64, 3000) (56,)


In [18]:
def bandpass_filter(data, lowcut=8, highcut=30, fs=1000, order=5):
    """
    Apply a bandpass filter to the input data.

    Parameters:
        data (ndarray): EEG/ECoG data of shape (trials, channels, samples)
        lowcut (float): Lower cutoff frequency (default: 8 Hz)
        highcut (float): Upper cutoff frequency (default: 30 Hz)
        fs (int): Sampling frequency (default: 1000 Hz)
        order (int): Order of the filter (default: 5)

    Returns:
        ndarray: Filtered data (same shape as input)
    """
    nyquist = 0.5 * fs
    low = lowcut / nyquist
    high = highcut / nyquist
    b, a = butter(order, [low, high], btype='band')

    filtered_data = lfilter(b, a, data, axis=-1)
    return filtered_data


In [19]:
X_train_filt = bandpass_filter(X_train)
X_val_filt = bandpass_filter(X_val)
X_test_filt = bandpass_filter(X_test)

In [20]:
def compute_covariance(trial_data):
    """
    Compute the covariance matrix for a given trial.
    
    Parameters:
        trial_data (ndarray): Shape (channels, samples)
    
    Returns:
        ndarray: Covariance matrix (channels x channels)
    """
    return np.cov(trial_data)

def csp_fit(X, y, n_components=3):
    """
    Fit CSP spatial filters.

    Parameters:
        X (ndarray): Shape (trials, channels, samples)
        y (ndarray): Labels (-1 or 1), shape (trials,)
        n_components (int): Number of CSP components to retain per class

    Returns:
        ndarray: CSP spatial filters (channels, 2 * n_components)
    """

    X_class1 = X[y == -1]
    X_class2 = X[y == 1]

    cov_class1 = np.mean([compute_covariance(x) for x in X_class1], axis=0)
    cov_class2 = np.mean([compute_covariance(x) for x in X_class2], axis=0)

    from scipy.linalg import eigh
    w, v = eigh(cov_class1, cov_class2)

    idx = np.argsort(w)[::-1]
    v = v[:, idx]
    filters = np.hstack([v[:, :n_components], v[:, -n_components:]])
    
    return filters

def csp_transform(X, filters):
    """
    Apply CSP transformation to extract log-variance features.

    Parameters:
        X (ndarray): Shape (trials, channels, samples)
        filters (ndarray): CSP spatial filters (channels, 2*n_components)

    Returns:
        ndarray: CSP features (trials, 2*n_components)
    """
    n_trials, _, _ = X.shape
    n_filters = filters.shape[1]
    features = np.zeros((n_trials, n_filters))

    for i in range(n_trials):
        projected = filters.T @ X[i] 
        var = np.var(projected, axis=1)
        features[i, :] = np.log(var)

    return features


In [21]:
csp_filters = csp_fit(X_train_filt, Y_train, n_components=3)

X_train_csp = csp_transform(X_train_filt, csp_filters)
X_val_csp = csp_transform(X_val_filt, csp_filters)
X_test_csp = csp_transform(X_test_filt, csp_filters)

clf = LinearDiscriminantAnalysis()
clf.fit(X_train_csp, Y_train)

pred_val = clf.predict(X_val_csp)
acc_val = accuracy_score(Y_val, pred_val)
print(f"Validation Accuracy: {acc_val:.4f}")


Validation Accuracy: 0.8929


In [22]:
pred_noadapt = clf.predict(X_test_csp)

acc_noadapt = accuracy_score(Y_test, pred_noadapt)
print(f"No Adaptation Test Accuracy: {acc_noadapt:.4f}")

print("Confusion Matrix (No Adaptation):")
print(confusion_matrix(Y_test, pred_noadapt))


No Adaptation Test Accuracy: 0.7600
Confusion Matrix (No Adaptation):
[[50  0]
 [24 26]]


In [25]:
# Define frequency bands (e.g., 4-8 Hz, 8-12 Hz, etc.)
frequency_bands = [(4, 8), (8, 12), (12, 16), (16, 20), (20, 24), (24, 28), (28, 32), (32, 36), (36, 40)]

# Apply bandpass filter for each frequency band
X_train_filt = np.array([bandpass_filter(X_train, low, high, fs=1000) for (low, high) in frequency_bands])
X_val_filt = np.array([bandpass_filter(X_val, low, high, fs=1000) for (low, high) in frequency_bands])
X_test_filt = np.array([bandpass_filter(X_test, low, high, fs=1000) for (low, high) in frequency_bands])

# Reshape to combine frequency bands
# New shape: (trials, channels, samples, frequency_bands)
X_train_filt = np.transpose(X_train_filt, (1, 2, 3, 0))
X_val_filt = np.transpose(X_val_filt, (1, 2, 3, 0))
X_test_filt = np.transpose(X_test_filt, (1, 2, 3, 0))


In [27]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, BatchNormalization, AveragePooling2D, Dropout, Flatten, Dense, Lambda

# Define the TA-CSPNN model with a corrected activation function
model = Sequential([
    # Temporal Convolution
    Conv2D(filters=8, kernel_size=(1, 63), padding='same', input_shape=(n_channels, n_samples, len(frequency_bands))),
    BatchNormalization(),
    
    # Spatial Convolution (Depthwise)
    DepthwiseConv2D(kernel_size=(n_channels, 1), depth_multiplier=2, depthwise_constraint=tf.keras.constraints.max_norm(1.)),
    BatchNormalization(),
    
    # Custom Squaring Activation
    Lambda(lambda x: tf.math.square(x)),  # Square activation function
    
    # Average Pooling
    AveragePooling2D(pool_size=(1, n_samples)),
    Dropout(0.25),
    
    # Flatten and Fully Connected Layer
    Flatten(),
    Dense(units=4, activation='softmax')  # Assuming 4 classes for motor imagery
])

# Compile the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])


In [29]:
# Convert labels from (-1, 1) to (0, 1)
Y_train = (Y_train + 1) // 2  # Convert -1 -> 0 and 1 -> 1
Y_val = (Y_val + 1) // 2
Y_test = (Y_test + 1) // 2


In [30]:
# Train the model
model.fit(X_train_filt, Y_train, epochs=100, batch_size=32, validation_data=(X_val_filt, Y_val))

# Evaluate the model
test_loss, test_accuracy = model.evaluate(X_test_filt, Y_test)
print(f'Adaptation Test Accuracy: {test_accuracy:.4f}')

Epoch 1/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 5s/step - accuracy: 0.3996 - loss: 1.2914 - val_accuracy: 0.4464 - val_loss: 1.6084
Epoch 2/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 4s/step - accuracy: 0.4896 - loss: 1.2964 - val_accuracy: 0.5536 - val_loss: 1.3844
Epoch 3/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 4s/step - accuracy: 0.4782 - loss: 1.2665 - val_accuracy: 0.5357 - val_loss: 1.2872
Epoch 4/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 4s/step - accuracy: 0.5242 - loss: 1.2247 - val_accuracy: 0.5179 - val_loss: 1.2665
Epoch 5/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 4s/step - accuracy: 0.5138 - loss: 1.2271 - val_accuracy: 0.4821 - val_loss: 1.3534
Epoch 6/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 4s/step - accuracy: 0.5286 - loss: 1.2979 - val_accuracy: 0.4821 - val_loss: 1.3465
Epoch 7/100
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━