# Import Libraries and Setup

In [1]:
#주의 : 2048Hz!! 5s input window!!
#therefore, must change accordingly!

import os
import numpy as np
import glob
import pandas as pd 
import utils
import matplotlib.pyplot as plt
import sys
import torch
from omegaconf import OmegaConf
import torch.nn.functional as F

from scipy.stats import ttest_ind
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.decomposition import PCA

# 전체 배열을 출력하도록 설정
np.set_printoptions(threshold=np.inf)

#### SETUP ####
include_fixation = True # False

org_data_path = "/global/cfs/cdirs/m4750/ECOG_AI/pilot_data/BrainBERT_preprocessed/task_data"
add_folder =  "with_fixation_flattened" if include_fixation else "no_fixation_flattened"
data_path = os.path.join(org_data_path, add_folder)
sampling_freq = 2000

valid_split = 0.0
test_split = 0.35 # minimum I think since there's only 120 samples


# Baseline 1. 정천기 교수님 랩 분석 방법

In [3]:
def apply_stft(data, nperseg, clip_fs):
    f, t, Zxx = utils.get_stft_multi_channel(data, fs=sampling_freq, clip_fs=clip_fs, batch_dim=True, 
                                             nperseg=nperseg, noverlap=nperseg-50, normalizing="zscore", return_onesided=True)
    return f, t, Zxx

def apply_superlet(data):
    return utils.get_superlet_multi_channel(data, fs=sampling_freq, nperseg=400, noverlap=350,
                                            order_min=1, order_max=13, c_1=3, foi=np.arange(1,41), clip=5)

def flatten_and_normalize(data, using_transform, using_segment, time_to_take):
    if not using_transform:
        flattened = data.reshape(data.shape[0], -1)
        return (flattened - np.mean(flattened, axis=0)) / np.std(flattened, axis=0)
    elif using_transform and not using_segment:
        flattened = data[:,:,:,:time_to_take].reshape(data.shape[0], -1)
        return flattened

In [4]:
def load_data(data_path):
    labels_pd = pd.read_csv(os.path.join(data_path, "labels.csv")).sort_index()
    train_pd, _, _ = utils.pd_train_val_test_split(labels_pd, valid_split=0, test_split=0, shuffle_seed=42)
    return utils.arr_from_pd(data_path, train_pd)

def sample_classes(y, n_samples=20, seed=42):
    np.random.seed(seed)
    class_0_indices = np.where(y == 0)[0]
    class_1_indices = np.where(y == 1)[0]
    sampled_class_0_indices = np.random.choice(class_0_indices, n_samples, replace=False)
    sampled_class_1_indices = np.random.choice(class_1_indices, n_samples, replace=False)
    return sampled_class_0_indices, sampled_class_1_indices

def apply_transformations(X, indices, nperseg=4000, clip_fs=-1):
    X_selected = X[indices]
    f, t, X_stft = apply_stft(X_selected, nperseg=nperseg, clip_fs=clip_fs)
    X_bandpower = utils.get_bandpower(X_stft)
    return np.squeeze(X_bandpower, axis=1)

def perform_ttest_and_svm(X_bandpower_ttest, X_bandpower_svm, y_remaining, sampled_class_0_indices, bandname_list):
    t_stats, p_values = [], []
    split_data_ttest = np.split(X_bandpower_ttest, X_bandpower_ttest.shape[1], axis=1)
    split_data_svm = np.split(X_bandpower_svm, X_bandpower_svm.shape[1], axis=1)

    for i, data in enumerate(split_data_ttest):
        data = data.squeeze(axis=1)
        X4ttest_class_0 = data[:len(sampled_class_0_indices)]
        X4ttest_class_1 = data[len(sampled_class_0_indices):]

        if np.isnan(X4ttest_class_0).any() or np.isnan(X4ttest_class_1).any():
            print(f"Band: {bandname_list[i]} contains NaN values.")
            X4ttest_class_0 = np.nan_to_num(X4ttest_class_0)
            X4ttest_class_1 = np.nan_to_num(X4ttest_class_1)

        if np.var(X4ttest_class_0) == 0 or np.var(X4ttest_class_1) == 0:
            print(f"Band: {bandname_list[i]} has zero variance.")
            t_stats.append(np.nan)
            p_values.append(np.nan)
            continue

        mean_class_0 = np.mean(X4ttest_class_0, axis=1)
        mean_class_1 = np.mean(X4ttest_class_1, axis=1)
        t_stat, p_value = ttest_ind(mean_class_0, mean_class_1)
        t_stats.append(t_stat)
        p_values.append(p_value)

        print(f"Band: {bandname_list[i]}:")
        print(f"Class 0 mean: {np.mean(mean_class_0)}, Class 1 mean: {np.mean(mean_class_1)}")
        print("t-statistic:", t_stat)
        print("p-value:", p_value)

        if p_value < 0.05:
            data_svm = split_data_svm[i].squeeze(axis=1)
            X_train, X_test, y_train, y_test = train_test_split(data_svm, y_remaining, test_size=0.2, random_state=42, stratify=y_remaining)
            print(f"Train set class 0 count: {np.sum(y_train == 0)}")
            print(f"Train set class 1 count: {np.sum(y_train == 1)}")

            svm_model = SVC(kernel='rbf', C=1, gamma='scale', probability=True)
            svm_model.fit(X_train, y_train)
            y_pred = svm_model.predict(X_test)
            y_prob = svm_model.predict_proba(X_test)[:, 1]

            print(f"Test set class 0 count: {np.sum(y_test == 0)}")
            print(f"Test set class 1 count: {np.sum(y_test == 1)}")
            print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
            print(f"AUC: {roc_auc_score(y_test, y_prob)}")
            print(f"Mean of test prediction: {np.mean(y_pred)}")

    return t_stats, p_values

def main(data_path):
    X, y = load_data(data_path)
    sampled_class_0_indices, sampled_class_1_indices = sample_classes(y)
    remaining_indices = np.setdiff1d(np.arange(len(y)), np.concatenate([sampled_class_0_indices, sampled_class_1_indices]))
    X_remaining, y_remaining = X[remaining_indices], y[remaining_indices]

    X_bandpower_ttest = apply_transformations(X, np.concatenate([sampled_class_0_indices, sampled_class_1_indices]))
    X_bandpower_svm = apply_transformations(X, remaining_indices)

    bands = {
        "delta": (0.5, 4),
        "theta": (4, 8),
        "alpha": (8, 13),
        "beta": (13, 30),
        "low_gamma": (30, 60),
        "high_gamma": (60, 150)
    }
    bandname_list = list(bands.keys())

    t_stats, p_values = perform_ttest_and_svm(X_bandpower_ttest, X_bandpower_svm, y_remaining, sampled_class_0_indices, bandname_list)

if __name__ == "__main__":
    main(data_path)

Band: delta:
Class 0 mean: 1.5532223558515687e-16, Class 1 mean: -1.6583593137463721e-15
t-statistic: 1.8654113882911778
p-value: 0.06985857077650229
Band: theta:
Class 0 mean: 2.557581836832819e-18, Class 1 mean: -1.2439148024595997e-17
t-statistic: 0.5403996383305457
p-value: 0.5920749668281431
Band: alpha:
Class 0 mean: 1.1160357106179583e-17, Class 1 mean: -1.860059517696601e-18
t-statistic: 0.4149664366358154
p-value: 0.680500079791103
Band: beta:
Class 0 mean: -2.208820677264709e-17, Class 1 mean: -1.429920754229259e-17
t-statistic: -0.515151237836395
p-value: 0.6094312931015778
Band: low_gamma:
Class 0 mean: 2.6389594407320474e-17, Class 1 mean: -6.51020831193809e-18
t-statistic: 3.0535835804815643
p-value: 0.004116475915604346
Train set class 0 count: 22
Train set class 1 count: 42
Test set class 0 count: 6
Test set class 1 count: 10
Accuracy: 0.625
AUC: 0.5
Mean of test prediction: 1.0
Band: high_gamma:
Class 0 mean: 3.720119035393194e-18, Class 1 mean: 5.231417393521681e-18
t

# Baseline 2. BrainBERT

In [5]:
# 현재 작업 디렉토리의 상위 디렉토리를 sys.path에 추가
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)

import models

# GPU 메모리 초기화 (필요하지 않음)
# torch.cuda.empty_cache()

def build_model(cfg):
    ckpt_path = cfg.upstream_ckpt
    init_state = torch.load(ckpt_path, map_location='cpu')  # CPU로 로드
    upstream_cfg = init_state["model_cfg"]
    upstream = models.build_model(upstream_cfg)
    return upstream

def load_model_weights(model, states, multi_gpu):
    if multi_gpu:
        model.module.load_weights(states)
    else:
        model.load_weights(states)

ckpt_path = "../pretrained_weights/stft_large_pretrained.pth"
cfg = OmegaConf.create({"upstream_ckpt": ckpt_path})
pretrained_model = build_model(cfg)
pretrained_model.to('cpu')  # 모델을 CPU로 이동

# 메모리 문제를 피하기 위해 map_location을 사용하여 모델 가중치 로드
init_state = torch.load(ckpt_path, map_location='cpu')  # CPU로 로드
load_model_weights(pretrained_model, init_state['model'], False)



In [10]:
def split_data(X, segments):
    split_points = np.cumsum([int(2000 * seg) for seg in segments])
    split_data = [X[:, :, start:end] for start, end in zip([0] + split_points[:-1].tolist(), split_points)]
    return split_data

def process_data_BrainBERT(data_path, segments=5, using_transform=False, transform=None, time_to_take=None, using_timewiseavg=False, pooling=None, PCA_ncomp=None):
    # Load labels and split into train, validation, and test sets
    labels_pd = pd.read_csv(os.path.join(data_path, "labels.csv")).sort_index()
    train_pd, valid_pd, test_pd = utils.pd_train_val_test_split(labels_pd, valid_split, test_split, shuffle_seed=42)

    # Convert dataframes to arrays
    X_train, y_train = utils.arr_from_pd(data_path, train_pd)
    X_valid, y_valid = utils.arr_from_pd(data_path, valid_pd)
    X_test, y_test = utils.arr_from_pd(data_path, test_pd)
    
    # Split data into segments
    train_segments = split_data(X_train, segments)
    test_segments = split_data(X_test, segments)

    # Apply transformations
    if using_transform:
        if transform == 'stft':
            train_segments = [apply_stft(segment, nperseg=400, clip_fs=40)[2] for segment in train_segments]
            test_segments = [apply_stft(segment, nperseg=400, clip_fs=40)[2] for segment in test_segments]
        elif transform == 'superlet':
            train_segments = [apply_superlet(segment) for segment in train_segments]
            test_segments = [apply_superlet(segment) for segment in test_segments]

    # Apply BrainBERT
    train_segments = [utils.make_ready_for_BrainBERT(segment, device='cpu') for segment in train_segments]
    test_segments = [utils.make_ready_for_BrainBERT(segment, device='cpu') for segment in test_segments]

    train_masks = [torch.zeros((segment.shape[:2])).bool().to('cpu') for segment in train_segments]
    test_masks = [torch.zeros((segment.shape[:2])).bool().to('cpu') for segment in test_segments]

    pretrained_model.eval()
    with torch.no_grad():
        train_outputs = [pretrained_model.forward(segment, mask, intermediate_rep=True) for segment, mask in zip(train_segments, train_masks)]
        test_outputs = [pretrained_model.forward(segment, mask, intermediate_rep=True) for segment, mask in zip(test_segments, test_masks)]

        if pooling == 'max':
            train_outputs = [F.max_pool1d(output.transpose(1, 2), kernel_size=10).transpose(1, 2) for output in train_outputs]
            test_outputs = [F.max_pool1d(output.transpose(1, 2), kernel_size=10).transpose(1, 2) for output in test_outputs]
        elif pooling == 'mean':
            train_outputs = [F.avg_pool1d(output.transpose(1, 2), kernel_size=10).transpose(1, 2) for output in train_outputs]
            test_outputs = [F.avg_pool1d(output.transpose(1, 2), kernel_size=10).transpose(1, 2) for output in test_outputs]
        
    train_segments = [utils.revert_BrainBERT_back(output) for output in train_outputs]
    test_segments = [utils.revert_BrainBERT_back(output) for output in test_outputs]

    if using_timewiseavg:
        train_segments = [segment.mean(axis=-1, keepdims=True) for segment in train_segments]
        test_segments = [segment.mean(axis=-1, keepdims=True) for segment in test_segments]
        if valid_split != 0:
            X_valid = X_valid.mean(axis=-1, keepdims=True)

    train_segments_flattened = [segment.reshape(segment.shape[0], -1) for segment in train_segments]
    test_segments_flattened = [segment.reshape(segment.shape[0], -1) for segment in test_segments]

    if PCA_ncomp is not None:
        train_segments_flattened_transformed = []
        test_segments_flattened_transformed = []
        for train_segment, test_segment in zip(train_segments_flattened, test_segments_flattened):
            pca = PCA(n_components=PCA_ncomp)
            train_segment_transformed = pca.fit_transform(train_segment)
            test_segment_transformed = pca.transform(test_segment)
            train_segments_flattened_transformed.append(train_segment_transformed)
            test_segments_flattened_transformed.append(test_segment_transformed)
        
        train_segments_flattened = train_segments_flattened_transformed
        test_segments_flattened = test_segments_flattened_transformed

    train_segments_flattened = [torch.tensor(segment) for segment in train_segments_flattened]
    test_segments_flattened = [torch.tensor(segment) for segment in test_segments_flattened]

    train_data = torch.cat(train_segments_flattened, dim=1)
    test_data = torch.cat(test_segments_flattened, dim=1)

    if valid_split != 0:
        X_valid_flattened = flatten_and_normalize(X_valid, using_transform, using_segment, time_to_take)
        return train_data, y_train, X_valid_flattened, y_valid, test_data, y_test
    else:
        return train_data, y_train, None, None, test_data, y_test

In [7]:
using_transform = True # True or False
transform = 'stft' # 'stft' or 'superlet'
time_to_take = 191 # How many points to take from the time axis when flattened
using_timewiseavg = False
pooling = ['max', 'mean', None] # 'max', 'mean', 'sum', None (kernel 사이즈는 위 process_data_BrainBERT 함수에서 수정해야 함)
segments = [5] # 5초를 몇 초 간격으로 자를 것인지
PCA_ncomp = None # PCA를 쓴다면 몇 개의 component를 사용할 것인지, None이면 PCA 사용 안 함

for pooling_method in pooling:
    print(f"Using {transform} transform with {time_to_take} time points") if using_transform else print("Not using transform")
    print("Using time-wise average") if using_timewiseavg else print("Not using time-wise average")
    print(f"Using {pooling_method} pooling") if pooling_method != None else print("Not using pooling")
    print(f"5 sec divided by {segments} sec")
    print("="*50)

    X_train, y_train, X_valid, y_valid, X_test, y_test = process_data_BrainBERT(data_path, segments, using_transform, transform, time_to_take, using_timewiseavg, pooling_method, PCA_ncomp)
    print(f"X_train: {X_train.shape} | X_test: {X_test.shape}")
    print("X_valid:", X_valid.shape) if X_valid is not None else print("No validation set")
    print("="*50)

    clf_models = [
        LogisticRegression(max_iter=10000, random_state=42), 
        SVC(kernel='rbf', C=1, gamma='scale', probability=True, random_state=42),
        MLPClassifier(hidden_layer_sizes=(1024, 512, 256, 128), validation_fraction=0.3,
                        learning_rate='adaptive', alpha=0.001, batch_size=32,
                        early_stopping=True, max_iter=1000, random_state=42)    
    ]

    for model in clf_models:
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        print(model)
        print("Mean prediction:", y_pred.mean())
        print("Accuracy:", accuracy_score(y_test, y_pred))
        print(classification_report(y_test, y_pred, zero_division=0))
        print("ROC AUC:", roc_auc_score(y_test, y_pred))
        print("\n")

Using stft transform with 191 time points
Not using time-wise average
Using max pooling
5 sec divided by [5] sec
X_train: torch.Size([78, 14592]) | X_test: torch.Size([42, 14592])
No validation set
LogisticRegression(max_iter=10000, random_state=42)
Mean prediction: 0.7380952380952381
Accuracy: 0.5714285714285714
              precision    recall  f1-score   support

           0       0.27      0.23      0.25        13
           1       0.68      0.72      0.70        29

    accuracy                           0.57        42
   macro avg       0.48      0.48      0.47        42
weighted avg       0.55      0.57      0.56        42

ROC AUC: 0.47745358090185674


SVC(C=1, probability=True, random_state=42)
Mean prediction: 1.0
Accuracy: 0.6904761904761905
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        13
           1       0.69      1.00      0.82        29

    accuracy                           0.69        42
   macro avg   



MLPClassifier(alpha=0.001, batch_size=32, early_stopping=True,
              hidden_layer_sizes=(1024, 512, 256, 128),
              learning_rate='adaptive', max_iter=1000, random_state=42,
              validation_fraction=0.3)
Mean prediction: 1.0
Accuracy: 0.6904761904761905
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        13
           1       0.69      1.00      0.82        29

    accuracy                           0.69        42
   macro avg       0.35      0.50      0.41        42
weighted avg       0.48      0.69      0.56        42

ROC AUC: 0.5




In [8]:
using_transform = True # True or False
transform = 'stft' # 'stft' or 'superlet'
time_to_take = 191 # How many points to take from the time axis when flattened
using_timewiseavg = False
pooling = ['max', 'mean', None] # 'max', 'mean', 'sum', None (kernel 사이즈는 위 process_data_BrainBERT 함수에서 수정해야 함)
segments = [2, 1.5, 1.5] # 5초를 몇 초 간격으로 자를 것인지
PCA_ncomp = None # PCA를 쓴다면 몇 개의 component를 사용할 것인지, None이면 PCA 사용 안 함

for pooling_method in pooling:
    print(f"Using {transform} transform with {time_to_take} time points") if using_transform else print("Not using transform")
    print("Using time-wise average") if using_timewiseavg else print("Not using time-wise average")
    print(f"Using {pooling_method} pooling") if pooling_method != None else print("Not using pooling")
    print(f"5 sec divided by {segments} sec")
    print("="*50)

    X_train, y_train, X_valid, y_valid, X_test, y_test = process_data_BrainBERT(data_path, segments, using_transform, transform, time_to_take, using_timewiseavg, pooling_method, PCA_ncomp)
    print(f"X_train: {X_train.shape} | X_test: {X_test.shape}")
    print("X_valid:", X_valid.shape) if X_valid is not None else print("No validation set")
    print("="*50)

    clf_models = [
        LogisticRegression(max_iter=10000, random_state=42), 
        SVC(kernel='rbf', C=1, gamma='scale', probability=True, random_state=42),
        MLPClassifier(hidden_layer_sizes=(1024, 512, 256, 128), validation_fraction=0.3,
                        learning_rate='adaptive', alpha=0.001, batch_size=32,
                        early_stopping=True, max_iter=1000, random_state=42)   
    ]

    for model in clf_models:
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        print(model)
        print("Mean prediction:", y_pred.mean())
        print("Accuracy:", accuracy_score(y_test, y_pred))
        print(classification_report(y_test, y_pred, zero_division=0))
        print("ROC AUC:", roc_auc_score(y_test, y_pred))
        print("\n")

Using stft transform with 191 time points
Not using time-wise average
Using max pooling
5 sec divided by [2, 1.5, 1.5] sec
X_train: torch.Size([78, 13056]) | X_test: torch.Size([42, 13056])
No validation set
LogisticRegression(max_iter=10000, random_state=42)
Mean prediction: 0.8095238095238095
Accuracy: 0.5952380952380952
              precision    recall  f1-score   support

           0       0.25      0.15      0.19        13
           1       0.68      0.79      0.73        29

    accuracy                           0.60        42
   macro avg       0.46      0.47      0.46        42
weighted avg       0.54      0.60      0.56        42

ROC AUC: 0.473474801061008


SVC(C=1, probability=True, random_state=42)
Mean prediction: 1.0
Accuracy: 0.6904761904761905
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        13
           1       0.69      1.00      0.82        29

    accuracy                           0.69        42
   macr



MLPClassifier(alpha=0.001, batch_size=32, early_stopping=True,
              hidden_layer_sizes=(1024, 512, 256, 128),
              learning_rate='adaptive', max_iter=1000, random_state=42,
              validation_fraction=0.3)
Mean prediction: 1.0
Accuracy: 0.6904761904761905
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        13
           1       0.69      1.00      0.82        29

    accuracy                           0.69        42
   macro avg       0.35      0.50      0.41        42
weighted avg       0.48      0.69      0.56        42

ROC AUC: 0.5




In [29]:
using_transform = True # True or False
transform = 'stft' # 'stft' or 'superlet'
time_to_take = 191 # How many points to take from the time axis when flattened
using_timewiseavg = False
pooling = None # 'max', 'mean', 'sum', None (kernel 사이즈는 위 process_data_BrainBERT 함수에서 수정해야 함)
segments = [2, 1.5, 1.5] # 5초를 몇 초 간격으로 자를 것인지
PCA_ncomp = 10 # PCA를 쓴다면 몇 개의 component를 사용할 것인지, None이면 PCA 사용 안 함

print(f"Using {transform} transform with {time_to_take} time points") if using_transform else print("Not using transform")
print("Using time-wise average") if using_timewiseavg else print("Not using time-wise average")
print(f"Using {pooling} pooling") if pooling else print("Not using pooling")
print(f"5 sec divided by {segments} sec")
print("="*50)

X_train, y_train, X_valid, y_valid, X_test, y_test = process_data_BrainBERT(data_path, segments, using_transform, transform, time_to_take, using_timewiseavg, pooling, PCA_ncomp)
print(f"X_train: {X_train.shape} | X_test: {X_test.shape}")
print("X_valid:", X_valid.shape) if X_valid is not None else print("No validation set")
print("="*50)

clf_models = [
    LogisticRegression(max_iter=10000, random_state=42), 
    SVC(kernel='rbf', C=1, gamma='scale', probability=True, random_state=42),
    MLPClassifier(hidden_layer_sizes=(1024, 512, 256, 128), validation_fraction=0.3,
                    learning_rate='adaptive', alpha=0.001, batch_size=32,
                    early_stopping=True, max_iter=1000, random_state=42)   
]

for model in clf_models:
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    print(model)
    print("Mean prediction:", y_pred.mean())
    print("Accuracy:", accuracy_score(y_test, y_pred))
    print(classification_report(y_test, y_pred, zero_division=0))
    print("ROC AUC:", roc_auc_score(y_test, y_pred))
    print("\n")

Using stft transform with 191 time points
Not using time-wise average
Using PCA pooling
5 sec divided by (2, 1.5, 1.5) sec


X_train: torch.Size([78, 51]) | X_test: torch.Size([42, 51])
No validation set
LogisticRegression(max_iter=10000, random_state=42)
Mean prediction: 0.5
Accuracy: 0.42857142857142855
              precision    recall  f1-score   support

           0       0.24      0.38      0.29        13
           1       0.62      0.45      0.52        29

    accuracy                           0.43        42
   macro avg       0.43      0.42      0.41        42
weighted avg       0.50      0.43      0.45        42

ROC AUC: 0.41644562334217505


SVC(C=1, probability=True, random_state=42)
Mean prediction: 1.0
Accuracy: 0.6904761904761905
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        13
           1       0.69      1.00      0.82        29

    accuracy                           0.69        42
   macro avg       0.35      0.50      0.41        42
weighted avg       0.48      0.69      0.56        42

ROC AUC: 0.5


MLPClassifier(alpha=0.00

In [18]:
using_transform = True # True or False
transform = 'stft' # 'stft' or 'superlet'
time_to_take = 191 # How many points to take from the time axis when flattened
using_timewiseavg = False
pooling = 'max' # 'max', 'mean', 'sum', None (kernel 사이즈는 위 process_data_BrainBERT 함수에서 수정해야 함)
segments = [1, 1, 1, 1, 1] # 5초를 몇 초 간격으로 자를 것인지
PCA_ncomp = None # PCA를 쓴다면 몇 개의 component를 사용할 것인지, None이면 PCA 사용 안 함

print(f"Using {transform} transform with {time_to_take} time points") if using_transform else print("Not using transform")
print("Using time-wise average") if using_timewiseavg else print("Not using time-wise average")
print(f"Using {pooling} pooling") if pooling else print("Not using pooling")
print(f"5 sec divided by {segments} sec")
print("="*50)

X_train, y_train, X_valid, y_valid, X_test, y_test = process_data_BrainBERT(data_path, segments, using_transform, transform, time_to_take, using_timewiseavg, pooling, PCA_ncomp)
print(f"X_train: {X_train.shape} | X_test: {X_test.shape}")
print("X_valid:", X_valid.shape) if X_valid is not None else print("No validation set")
print("="*50)

clf_models = [
    LogisticRegression(max_iter=10000, random_state=42), 
    SVC(kernel='rbf', C=1, gamma='scale', probability=True, random_state=42),
    MLPClassifier(hidden_layer_sizes=(1024, 512, 256, 128), validation_fraction=0.3,
                    learning_rate='adaptive', alpha=0.001, batch_size=32,
                    early_stopping=True, max_iter=1000, random_state=42)
]

for model in clf_models:
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    print(model)
    print("Mean prediction:", y_pred.mean())
    print("Accuracy:", accuracy_score(y_test, y_pred))
    print(classification_report(y_test, y_pred, zero_division=0))
    print("ROC AUC:", roc_auc_score(y_test, y_pred))
    print("\n")

Using stft transform with 191 time points
Not using time-wise average
Using max pooling
5 sec divided by [1, 1, 1, 1, 1] sec


X_train: torch.Size([78, 11520]) | X_test: torch.Size([42, 11520])
No validation set
LogisticRegression(max_iter=10000, random_state=42)
Mean prediction: 0.7857142857142857
Accuracy: 0.6190476190476191
              precision    recall  f1-score   support

           0       0.33      0.23      0.27        13
           1       0.70      0.79      0.74        29

    accuracy                           0.62        42
   macro avg       0.52      0.51      0.51        42
weighted avg       0.58      0.62      0.60        42

ROC AUC: 0.5119363395225465


SVC(C=1, probability=True, random_state=42)
Mean prediction: 1.0
Accuracy: 0.6904761904761905
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        13
           1       0.69      1.00      0.82        29

    accuracy                           0.69        42
   macro avg       0.35      0.50      0.41        42
weighted avg       0.48      0.69      0.56        42

ROC AUC: 0.5


MLPCl

# Supple. BrainBERT 사용하지 않는 ver.

In [2]:
def process_data(data_path, using_transform=False, transform=None, time_to_take=None, using_segment=False, seg_len=None, using_timewiseavg=False, using_bandpower=False, valid_split=0.2, test_split=0.2, sampling_freq=1000):
    # 1. Load labels and split into train, validation, and test sets
    labels_pd = pd.read_csv(os.path.join(data_path, "labels.csv")).sort_index()
    train_pd, valid_pd, test_pd = utils.pd_train_val_test_split(labels_pd, valid_split, test_split, shuffle_seed=42)

    # 2. Convert dataframes to arrays
    X_train, y_train = utils.arr_from_pd(data_path, train_pd)
    X_valid, y_valid = utils.arr_from_pd(data_path, valid_pd)
    X_test, y_test = utils.arr_from_pd(data_path, test_pd)

    # 3. Apply transformations
    if using_transform:
        if transform == 'stft':
            X_train = apply_stft(X_train, nperseg=400, clip_fs=5, sampling_freq=sampling_freq)
            X_test = apply_stft(X_test, nperseg=400, clip_fs=5, sampling_freq=sampling_freq)
            if valid_split != 0:
                X_valid = apply_stft(X_valid, nperseg=400, clip_fs=5, sampling_freq=sampling_freq)
        elif transform == 'superlet':
            X_train = apply_superlet(X_train, sampling_freq=sampling_freq)
            X_test = apply_superlet(X_test, sampling_freq=sampling_freq)
            if valid_split != 0:
                X_valid = apply_superlet(X_valid, sampling_freq=sampling_freq)
        print(f"Transformed shape: X train: {X_train.shape}, y train: {y_train.shape}, X test: {X_test.shape}, y test: {y_test.shape}")
    
    # 4. Segment data if required
    if using_segment:
        X_train, y_train = utils.segment_data_and_labels(X_train, y_train, seg_len)
        X_test, y_test = utils.segment_data_and_labels(X_test, y_test, seg_len)
        if valid_split != 0:
            X_valid, y_valid = utils.segment_data_and_labels(X_valid, y_valid, seg_len)
        print(f"Segmented shape: X train: {X_train.shape}, y train: {y_train.shape}, X test: {X_test.shape}, y test: {y_test.shape}")        
    
    # 5. Apply time-wise averaging if required
    if using_timewiseavg:
        X_train = X_train.mean(axis=-1, keepdims=True)
        X_test = X_test.mean(axis=-1, keepdims=True)
        if valid_split != 0:
            X_valid = X_valid.mean(axis=-1, keepdims=True)

    # 6. Apply bandpower transformation if required
    if using_transform and using_bandpower:
        X_train = utils.get_bandpower(X_train)
        X_test = utils.get_bandpower(X_test)
        if valid_split != 0:
            X_valid = utils.get_bandpower(X_valid)
        print(f"Bandpower shape: X train: {X_train.shape}, y train: {y_train.shape}, X test: {X_test.shape}, y test: {y_test.shape}")
   
    # 7. Flatten and normalize data
    X_train_flattened = flatten_and_normalize(X_train, using_transform, using_segment, time_to_take)
    X_test_flattened = flatten_and_normalize(X_test, using_transform, using_segment, time_to_take)

    # 8. Print final shapes and return data
    print(f"Input shape: X train: {X_train_flattened.shape}, X test: {X_test_flattened.shape}")
    print(f"Train set 0: {np.sum(y_train == 0)}, Train set 1: {np.sum(y_train == 1)}")
    print(f"Test set 0: {np.sum(y_test == 0)}, Test set 1: {np.sum(y_test == 1)}")
    print("="*50)

    if valid_split != 0:
        X_valid_flattened = flatten_and_normalize(X_valid, using_transform, using_segment, time_to_take)
        return X_train_flattened, y_train, X_valid_flattened, y_valid, X_test_flattened, y_test
    else:
        return X_train_flattened, y_train, None, None, X_test_flattened, y_test

In [11]:
using_transform = False # True or False
transform = 'stft' # 'stft' or 'superlet'
time_to_take = 191 # How many points to take from the time axis when flattened
using_segment = False # True or False
seg_len = 10
using_timewiseavg = False
using_bandpower = [False, True] # if using_transform is False, this will be ignored

print(f"Using {transform} transform with {time_to_take} time points") if using_transform else print("Not using transform")
print(f"Using segment length {seg_len}") if using_segment else print("Not using segment")
print("Using time-wise average") if using_timewiseavg else print("Not using time-wise average")
print("Using bandpower") if using_bandpower else print("Not using bandpower")
print("="*50)

for using_bandpower_bool in using_bandpower:
    print(f"Using bandpower: {using_bandpower_bool}")
    
    X_train, y_train, X_valid, y_valid, X_test, y_test = process_data(data_path, using_transform, transform, time_to_take, using_segment, seg_len,  using_timewiseavg, using_bandpower_bool)
    
    clf_models = [LogisticRegression(max_iter=10000), 
          MLPClassifier(hidden_layer_sizes=(1024, 512, 256, 128), validation_fraction=0.3,
                        learning_rate='adaptive', alpha=0.001, batch_size=32,
                        early_stopping=True, max_iter=1000)]
    
    for model in clf_models:
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        print(model)
        print("Mean prediction:", y_pred.mean())
        print("Accuracy:", accuracy_score(y_test, y_pred))
        print(classification_report(y_test, y_pred, zero_division=0))
        print("ROC AUC:", roc_auc_score(y_test, y_pred))
        print("\n")

Not using transform
Not using segment
Not using time-wise average
Using bandpower
Using bandpower: False
(78, 1, 10000)
(42, 1, 10000)


AttributeError: 'NoneType' object has no attribute 'shape'